├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── PREPARE_DATASETS.md ├── README.md ├── config.py ├── datasets ├── __init__.py ├── camvid.py ├── cityscapes.py ├── cityscapes_labels.py ├── comma10k.py ├── kitti.py ├── mapillary.py ├── nullloader.py ├── sampler.py └── uniform.py ├── demo.py ├── demo_folder.py ├── eval.py ├── images ├── method.png └── vis.png ├── loss.py ├── network ├── Resnet.py ├── SEresnext.py ├── __init__.py ├── deepv3.py ├── mynn.py └── wider_resnet.py ├── optimizer.py ├── scripts ├── eval_cityscapes_SEResNeXt50.sh ├── eval_cityscapes_WideResNet38.sh ├── submit_cityscapes_WideResNet38.sh ├── test_kitti_WideResNet38.sh ├── train_cityscapes_SEResNeXt50.sh ├── train_cityscapes_WideResNet38.sh ├── train_comma10k_WideResNet38.sh ├── train_kitti_WideResNet38.sh ├── train_mapillary_SEResNeXt50.sh └── train_mapillary_WideResNet38.sh ├── sdcnet ├── _aug.sh ├── _eval.sh ├── datasets │ ├── __init__.py │ ├── dataset_utils.py │ └── frame_loader.py ├── main.py ├── models │ ├── __init__.py │ ├── model_utils.py │ └── sdc_net2d.py ├── sdc_aug.py ├── spatialdisplconv_package │ ├── __init__.py │ ├── setup.py │ ├── spatialdisplconv.py │ ├── spatialdisplconv_cuda.cc │ ├── spatialdisplconv_kernel.cu │ ├── spatialdisplconv_kernel.cuh │ └── test_spatialdisplconv.py └── utility │ ├── Dockerfile │ └── tools.py ├── train.py ├── transforms ├── __init__.py ├── joint_transforms.py └── transforms.py └── utils ├── __init__.py ├── attr_dict.py ├── misc.py └── my_data_parallel.py /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt 2 | setenv 3 | weights 4 | *~ 5 | *.pyc 6 | dump_imgs_train 7 | tb 8 | __pycache__/ 9 | .idea/ 10 | build/ 11 | *.egg-info/ 12 | dist/ 13 | *.py[cod] 14 | *.swp 15 | *.o 16 | *.so 17 | .torch 18 | .DS_Store 19 | pretrained_models 20 | logs 21 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "sdcnet/flownet2_pytorch"] 2 | path = sdcnet/flownet2_pytorch 3 | url = https://github.com/NVIDIA/flownet2-pytorch 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # NVIDIA Pytorch Image 2 | FROM nvcr.io/nvidia/pytorch:19.05-py3 3 | 4 | RUN pip install numpy 5 | RUN pip install sklearn 6 | RUN pip install h5py 7 | RUN pip install jupyter 8 | RUN pip install scikit-image 9 | RUN pip install pillow 10 | RUN pip install piexif 11 | RUN pip install cffi 12 | RUN pip install tqdm 13 | RUN pip install dominate 14 | RUN pip install tensorboardX 15 | RUN pip install opencv-python 16 | RUN pip install nose 17 | RUN pip install ninja 18 | 19 | RUN apt-get update 20 | RUN apt-get install libgtk2.0-dev -y && rm -rf /var/lib/apt/lists/* 21 | 22 | 23 | # Install Apex 24 | RUN cd /home/ && git clone https://github.com/NVIDIA/apex.git apex && cd apex && python setup.py install --cuda_ext --cpp_ext 25 | WORKDIR /home/ 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 NVIDIA Corporation. Yi Zhu, Karan Sapra, Fitsum A. Reda, Kevin J. Shih, Shawn Newsam, Andrew Tao and Bryan Catanzaro. 2 | All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
4 | 5 | Permission to use, copy, modify, and distribute this software and its documentation 6 | for any non-commercial purpose is hereby granted without fee, provided that the above 7 | copyright notice appear in all copies and that both that copyright notice and this 8 | permission notice appear in supporting documentation, and that the name of the author 9 | not be used in advertising or publicity pertaining to distribution of the software 10 | without specific, written prior permission. 11 | 12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 18 | -------------------------------------------------------------------------------- /PREPARE_DATASETS.md: -------------------------------------------------------------------------------- 1 | ## Mapillary Vistas Dataset 2 | 3 | First of all, please request the research edition dataset from [here](https://www.mapillary.com/dataset/vistas/). The downloaded file is named as `mapillary-vistas-dataset_public_v1.1.zip`. 4 | 5 | Then simply unzip the file by 6 | ```shell 7 | unzip mapillary-vistas-dataset_public_v1.1.zip 8 | ``` 9 | 10 | The folder structure will look like: 11 | ``` 12 | Mapillary 13 | ├── config.json 14 | ├── demo.py 15 | ├── Mapillary Vistas Research Edition License.pdf 16 | ├── README 17 | ├── requirements.txt 18 | ├── training 19 | │ ├── images 20 | │ ├── instances 21 | │ ├── labels 22 | │ ├── panoptic 23 | ├── validation 24 | │ ├── images 25 | │ ├── instances 26 | │ ├── labels 27 | │ ├── panoptic 28 | ├── testing 29 | │ ├── images 30 | │ ├── instances 31 | │ ├── labels 32 | │ ├── panoptic 33 | ``` 34 | Note that, the `instances`, `labels` and `panoptic` folders inside `testing` are empty. 35 | 36 | Suppose you store your dataset at `~/username/data/Mapillary`, please update the dataset path in `config.py`, 37 | ``` 38 | __C.DATASET.MAPILLARY_DIR = '~/username/data/Mapillary' 39 | ``` 40 | 41 | ## Cityscapes Dataset 42 | 43 | ### Download Dataset 44 | First of all, please request the dataset from [here](https://www.cityscapes-dataset.com/). You need multiple files. 45 | ``` 46 | - leftImg8bit_trainvaltest.zip 47 | - gtFine_trainvaltest.zip 48 | - leftImg8bit_trainextra.zip 49 | - gtCoarse.zip 50 | - leftImg8bit_sequence.zip # This file is very large, 324G. You only need it if you want to run sdc_aug experiments. 51 | ``` 52 | 53 | If you prefer to use command lines (e.g., `wget`) to download the dataset, 54 | ``` 55 | # First step, obtain your login credentials. 56 | Please register an account at https://www.cityscapes-dataset.com/login/. 57 | 58 | # Second step, log into cityscapes system, suppose you already have a USERNAME and a PASSWORD. 59 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=USERNAME&password=PASSWORD&submit=Login' https://www.cityscapes-dataset.com/login/ 60 | 61 | # Third step, download the zip files you need. 
62 | wget -c -t 0 --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3 63 | 64 | # The corresponding packageID is listed below, 65 | 1 -> gtFine_trainvaltest.zip (241MB) md5sum: 4237c19de34c8a376e9ba46b495d6f66 66 | 2 -> gtCoarse.zip (1.3GB) md5sum: 1c7b95c84b1d36cc59a9194d8e5b989f 67 | 3 -> leftImg8bit_trainvaltest.zip (11GB) md5sum: 0a6e97e94b616a514066c9e2adb0c97f 68 | 4 -> leftImg8bit_trainextra.zip (44GB) md5sum: 9167a331a158ce3e8989e166c95d56d4 69 | 14 -> leftImg8bit_sequence.zip (324GB) md5sum: 4348961b135d856c1777f7f1098f7266 70 | ``` 71 | 72 | ### Prepare Folder Structure 73 | 74 | Now unzip those files, the desired folder structure will look like, 75 | ``` 76 | Cityscapes 77 | ├── leftImg8bit_trainvaltest 78 | │ ├── leftImg8bit 79 | │ │ ├── train 80 | │ │ │ ├── aachen 81 | │ │ │ │ ├── aachen_000000_000019_leftImg8bit.png 82 | │ │ │ │ ├── aachen_000001_000019_leftImg8bit.png 83 | │ │ │ │ ├── ... 84 | │ │ │ ├── bochum 85 | │ │ │ ├── ... 86 | │ │ ├── val 87 | │ │ ├── test 88 | ├── gtFine_trainvaltest 89 | │ ├── gtFine 90 | │ │ ├── train 91 | │ │ │ ├── aachen 92 | │ │ │ │ ├── aachen_000000_000019_gtFine_color.png 93 | │ │ │ │ ├── aachen_000000_000019_gtFine_instanceIds.png 94 | │ │ │ │ ├── aachen_000000_000019_gtFine_labelIds.png 95 | │ │ │ │ ├── aachen_000000_000019_gtFine_polygons.json 96 | │ │ │ │ ├── ... 97 | │ │ │ ├── bochum 98 | │ │ │ ├── ... 99 | │ │ ├── val 100 | │ │ ├── test 101 | ├── leftImg8bit_trainextra 102 | │ ├── leftImg8bit 103 | │ │ ├── train_extra 104 | │ │ │ ├── augsburg 105 | │ │ │ ├── bad-honnef 106 | │ │ │ ├── ... 107 | ├── gtCoarse 108 | │ ├── gtCoarse 109 | │ │ ├── train 110 | │ │ ├── train_extra 111 | │ │ ├── val 112 | ├── leftImg8bit_sequence 113 | │ ├── train 114 | │ ├── val 115 | │ ├── test 116 | ``` 117 | 118 | ## CamVid Dataset 119 | 120 | Please download and prepare this dataset according to the [tutorial](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). The desired folder structure will look like, 121 | ``` 122 | CamVid 123 | ├── train 124 | ├── trainannot 125 | ├── val 126 | ├── valannot 127 | ├── test 128 | ├── testannot 129 | ``` 130 | 131 | ## KITTI Dataset 132 | 133 | Please download this dataset at the KITTI Semantic Segmentation benchmark [webpage](http://www.cvlibs.net/datasets/kitti/eval_semantics.php). 134 | 135 | Now unzip the file, the desired folder structure will look like, 136 | ``` 137 | KITTI 138 | ├── training 139 | │ ├── image_2 140 | │ ├── instance 141 | │ ├── semantic 142 | ├── test 143 | │ ├── image_2 144 | ``` 145 | There is no official training/validation split as the dataset only has `200` training samples. We randomly create three splits at [here](https://github.com/NVIDIA/semantic-segmentation/blob/master/datasets/kitti.py#L41-L44) in order to perform cross-validation. 
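For reference, each split is just a fixed list of validation indices over the 200 annotated images. Below is a minimal sketch of how such reproducible 90/10 splits could be generated; the actual indices used by this repo are hard-coded in `datasets/kitti.py`, so treat this as an illustration only.

```python
# Illustration only: three reproducible train/val splits over 200 images.
# The repo hard-codes its own split indices in datasets/kitti.py.
import random

def make_cv_splits(num_samples=200, num_splits=3, val_fraction=0.1, seed=0):
    rng = random.Random(seed)
    splits = []
    for _ in range(num_splits):
        indices = list(range(num_samples))
        rng.shuffle(indices)
        num_val = int(num_samples * val_fraction)
        splits.append({'val': sorted(indices[:num_val]),
                       'train': sorted(indices[num_val:])})
    return splits

if __name__ == '__main__':
    for i, split in enumerate(make_cv_splits()):
        print('cv {}: {} train / {} val'.format(i, len(split['train']), len(split['val'])))
```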
146 | 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improving Semantic Segmentation via Video Propagation and Label Relaxation 2 | ### [Project](https://nv-adlr.github.io/publication/2018-Segmentation) | [Paper](https://arxiv.org/pdf/1812.01593.pdf) | [YouTube](https://www.youtube.com/watch?v=aEbXjGZDZSQ) | [Cityscapes Score](https://www.cityscapes-dataset.com/anonymous-results/?id=555fc2b66c6e00b953c72b98b100e396c37274e0788e871a85f1b7b4f4fa130e) | [Kitti Score](http://www.cvlibs.net/datasets/kitti/eval_semseg_detail.php?benchmark=semantics2015&result=83cac7efbd41b1f2fc095f9bc1168bc548b48885)
3 | PyTorch implementation of our CVPR 2019 oral paper, which achieves state-of-the-art semantic segmentation results using a DeepLabV3+-like architecture with a WideResNet38 trunk. We present a video prediction-based methodology to scale up training sets by synthesizing new training samples, and we propose a novel label relaxation technique to make training objectives robust to label noise.
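In rough form, label relaxation replaces the usual one-hot cross-entropy at boundary pixels with the negative log-likelihood of the union of classes present in a small neighborhood around each pixel. The sketch below illustrates that idea only; it is not the repo's `loss.py` (the multi-hot border targets are built by `RelaxedBoundaryLossToTensor` in `transforms/transforms.py`).

```python
# Conceptual sketch of boundary label relaxation (not the repo's loss.py).
# Near class boundaries the loss maximizes the total probability of all
# classes occurring in a small window around the pixel:
#   L = -log( sum_{c in border set} p_c )
import torch
import torch.nn.functional as F

def relaxed_boundary_loss(logits, border_onehot):
    """
    logits:        (N, C, H, W) raw network outputs
    border_onehot: (N, C, H, W) multi-hot targets, 1 for every class present in
                   the pixel's neighborhood (exactly one 1 away from borders)
    """
    probs = F.softmax(logits, dim=1)
    union_prob = (probs * border_onehot).sum(dim=1).clamp(min=1e-8)
    return -torch.log(union_prob).mean()
```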
4 | 5 | [Improving Semantic Segmentation via Video Propagation and Label Relaxation](https://nv-adlr.github.io/publication/2018-Segmentation)
6 | Yi Zhu1*, Karan Sapra2*, [Fitsum A. Reda](https://scholar.google.com/citations?user=quZ_qLYAAAAJ&hl=en)2, Kevin J. Shih2, Shawn Newsam1, Andrew Tao2, [Bryan Catanzaro](http://catanzaro.name/)2 7 | 1UC Merced, 2NVIDIA Corporation
8 | In CVPR 2019 (* equal contributions). 9 | 10 | [SDCNet: Video Prediction using Spatially Displaced Convolution](https://nv-adlr.github.io/publication/2018-SDCNet) 11 | [Fitsum A. Reda](https://scholar.google.com/citations?user=quZ_qLYAAAAJ&hl=en), Guilin Liu, Kevin J. Shih, Robert Kirby, Jon Barker, David Tarjan, Andrew Tao, [Bryan Catanzaro](http://catanzaro.name/)
12 | NVIDIA Corporation
13 | In ECCV 2018. 14 | 15 | ![alt text](images/method.png) 16 | 17 | ## Installation 18 | 19 | # Get Semantic Segmentation source code 20 | git clone --recursive https://github.com/NVIDIA/semantic-segmentation.git 21 | cd semantic-segmentation 22 | 23 | # Build Docker Image 24 | docker build -t nvidia-segmentation -f Dockerfile . 25 | 26 | If you prefer not to use Docker, you can manually install the following requirements: 27 | 28 | * An NVIDIA GPU and CUDA 9.0 or higher. Some operations only have GPU implementations. 29 | * PyTorch (>= 0.5.1) 30 | * Python 3 31 | * numpy 32 | * sklearn 33 | * h5py 34 | * scikit-image 35 | * pillow 36 | * piexif 37 | * cffi 38 | * tqdm 39 | * dominate 40 | * tensorboardX 41 | * opencv-python 42 | * nose 43 | * ninja 44 | 45 | 46 | Multiple GPU training and mixed precision training are supported, and the code provides examples for training and inference. For more help, type
47 | 48 | python3 train.py --help 49 | 50 | 51 | 52 | ## Network architectures 53 | 54 | Our repo now supports DeepLabV3+ architecture with different backbones, including `WideResNet38`, `SEResNeXt(50, 101)` and `ResNet(50,101)`. 55 | 56 | 57 | ## Pre-trained models 58 | We've included pre-trained models. Download checkpoints to a folder `pretrained_models`. 59 | 60 | * [pretrained_models/cityscapes_best.pth](https://drive.google.com/file/d/1P4kPaMY-SmQ3yPJQTJ7xMGAB_Su-1zTl/view?usp=sharing)[1071MB, WideResNet38 backbone] 61 | * [pretrained_models/camvid_best.pth](https://drive.google.com/file/d/1OzUCbFdXulB2P80Qxm7C3iNTeTP0Mvb_/view?usp=sharing)[1071MB, WideResNet38 backbone] 62 | * [pretrained_models/kitti_best.pth](https://drive.google.com/file/d/1OrTcqH_I3PHFiMlTTZJgBy8l_pladwtg/view?usp=sharing)[1071MB, WideResNet38 backbone] 63 | * [pretrained_models/sdc_cityscapes_vrec.pth.tar](https://drive.google.com/file/d/1OxnJo2tFEQs3vuY01ibPFjn3cRCo2yWt/view?usp=sharing)[38MB] 64 | * [pretrained_models/FlowNet2_checkpoint.pth.tar](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB] 65 | 66 | ImageNet Weights 67 | * [pretrained_models/wider_resnet38.pth.tar](https://drive.google.com/file/d/1OfKQPQXbXGbWAQJj2R82x6qyz6f-1U6t/view?usp=sharing)[833MB] 68 | 69 | Other Weights 70 | * [pretrained_models/cityscapes_cv0_seresnext50_nosdcaug.pth](https://drive.google.com/file/d/1aGdA1WAKKkU2y-87wSOE1prwrIzs_L-h/view?usp=sharing)[324MB] 71 | * [pretrained_models/cityscapes_cv0_wideresnet38_nosdcaug.pth](https://drive.google.com/file/d/1CKB7gpcPLgDLA7LuFJc46rYcNzF3aWzH/view?usp=sharing)[1.1GB] 72 | 73 | ## Data Loaders 74 | 75 | Dataloaders for Cityscapes, Mapillary, Camvid and Kitti are available in [datasets](./datasets). Details of preparing each dataset can be found at [PREPARE_DATASETS.md](https://github.com/NVIDIA/semantic-segmentation/blob/master/PREPARE_DATASETS.md)
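All of these loaders are consumed through `datasets.setup_loaders(args)` (defined in `datasets/__init__.py`, reproduced further below). Here is a minimal sketch of calling it directly; the argument values are placeholders standing in for what `train.py` normally builds from its command line, and `__C.DATASET.CITYSCAPES_DIR` must already be set in `config.py`.

```python
# Sketch only: drive the repo's data loaders without train.py's argparse.
from types import SimpleNamespace

import datasets  # this repo's datasets/__init__.py

args = SimpleNamespace(
    dataset='cityscapes', bs_mult=2, bs_mult_val=1, ngpu=1, apex=False,
    test_mode=False, crop_size=720, pre_size=None, scale_min=0.5, scale_max=2.0,
    color_aug=0.25, bblur=False, gblur=True, jointwtborder=False,
    class_uniform_pct=0.0, dump_augmentation_images=False, cv=0)

train_loader, val_loader, train_set = datasets.setup_loaders(args)
for images, masks, names in train_loader:
    print(images.shape, masks.shape)  # batches of crops and their label masks
    break
```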
76 | 77 | 78 | ## Semantic segmentation demo for a single image 79 | 80 | If you want to try our trained model on any driving scene image, simply run 81 | 82 | ``` 83 | CUDA_VISIBLE_DEVICES=0 python demo.py --demo-image YOUR_IMG --snapshot ./pretrained_models/cityscapes_best.pth --save-dir YOUR_SAVE_DIR 84 | ``` 85 | This snapshot was trained on the Cityscapes dataset with the `DeepLabV3+` architecture and a `WideResNet38` backbone. The predicted segmentation masks will be saved to `YOUR_SAVE_DIR`. Check it out. 86 | 87 | ## Semantic segmentation demo for a folder of images 88 | 89 | If you want to try our trained model on a folder of driving scene images, simply run 90 | 91 | ``` 92 | CUDA_VISIBLE_DEVICES=0 python demo_folder.py --demo-folder YOUR_FOLDER --snapshot ./pretrained_models/cityscapes_best.pth --save-dir YOUR_SAVE_DIR 93 | ``` 94 | This snapshot was trained on the Cityscapes dataset with the `DeepLabV3+` architecture and a `WideResNet38` backbone. The predicted segmentation masks will be saved to `YOUR_SAVE_DIR`. Check it out. 95 | 96 | ## A quick start with light SEResNeXt50 backbone 97 | 98 | Note that, in this section, we use the standard train/val split in Cityscapes to train our model, which is `cv 0`. 99 | 100 | If you have fewer than 8 GPUs in your machine, please change `--nproc_per_node=8` to the number of GPUs you have in all the .sh files under the folder `scripts`. 101 | 102 | ### Pre-Training on Mapillary 103 | First, you can pre-train a DeepLabV3+ model with a `SEResNeXt50` trunk on the Mapillary dataset. Set `__C.DATASET.MAPILLARY_DIR` in `config.py` to where you store the Mapillary data. 104 | 105 | ``` 106 | ./scripts/train_mapillary_SEResNeXt50.sh 107 | ``` 108 | 109 | When you first run training on a new dataset with the `--class_uniform_pct` flag on, it will take some time to preprocess the dataset. Depending on your machine, the preprocessing can take half an hour or more. Once it finishes, you will have a json file in your root folder, e.g., `mapillary_tile1024.json`. You can read more details about `class uniform sampling` in our paper; the idea is to make sure that all classes are sampled approximately uniformly during training. 110 | 111 | ### Fine-tuning on Cityscapes 112 | Once you have the Mapillary pre-trained model (training mIoU should be 50+), you can start fine-tuning the model on the Cityscapes dataset. Set `__C.DATASET.CITYSCAPES_DIR` in `config.py` to where you store the Cityscapes data. Your training mIoU in the end should be 80+. 113 | ``` 114 | ./scripts/train_cityscapes_SEResNeXt50.sh 115 | ``` 116 | 117 | ### Inference 118 | 119 | Our inference code supports two evaluation modes: pooling-based and sliding-based. Pooling-based evaluation is faster than sliding-based evaluation but gives slightly lower numbers. We use `sliding` as the default. 120 | ``` 121 | ./scripts/eval_cityscapes_SEResNeXt50.sh 122 | ``` 123 | 124 | In the `result_save_location` you set, you will find several folders: `rgb`, `pred`, `compose` and `diff`. `rgb` contains the color-encoded predicted segmentation masks. `pred` contains what you need to submit to the evaluation server. `compose` contains overlays of the original video frames with the color-encoded predicted segmentation masks. `diff` contains the difference between our prediction and the ground truth. 125 | 126 | Right now, our inference code only supports the Cityscapes dataset.
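For intuition, the sliding-based evaluation mentioned above tiles the input into overlapping crops, runs the network on each crop, and averages the per-class scores wherever crops overlap. The snippet below is a conceptual sketch that assumes the image is at least as large as one crop; the repo's actual logic lives in `eval.py` and differs in details (crop size, stride, flipping, multi-scale).

```python
# Conceptual sketch of sliding-window evaluation, not the repo's eval.py.
import torch

def sliding_eval(net, image, num_classes, crop=768, stride=512):
    """image: (1, 3, H, W) normalized tensor, with H and W assumed >= crop."""
    _, _, h, w = image.shape
    assert h >= crop and w >= crop, "sketch assumes the image is larger than one crop"
    scores = torch.zeros(1, num_classes, h, w)
    counts = torch.zeros(1, 1, h, w)
    ys = list(range(0, h - crop + 1, stride))
    xs = list(range(0, w - crop + 1, stride))
    if ys[-1] != h - crop:
        ys.append(h - crop)  # make the last window touch the bottom border
    if xs[-1] != w - crop:
        xs.append(w - crop)  # ... and the right border
    with torch.no_grad():
        for y in ys:
            for x in xs:
                out = net(image[:, :, y:y + crop, x:x + crop])  # (1, C, crop, crop)
                scores[:, :, y:y + crop, x:x + crop] += out.cpu()
                counts[:, :, y:y + crop, x:x + crop] += 1
    return scores / counts  # averaged class scores; argmax over dim 1 is the prediction
```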
127 | 128 | 129 | ## Reproducing our results with heavy WideResNet38 backbone 130 | 131 | Note that, in this section, we use an alternative train/val split in Cityscapes to train our model, which is `cv 2`. You can find the difference between `cv 0` and `cv 2` in the supplementary material section of our arXiv paper. 132 | 133 | ### Pre-Training on Mapillary 134 | ``` 135 | ./scripts/train_mapillary_WideResNet38.sh 136 | ``` 137 | 138 | ### Fine-tuning on Cityscapes 139 | ``` 140 | ./scripts/train_cityscapes_WideResNet38.sh 141 | ``` 142 | 143 | ### Inference 144 | ``` 145 | ./scripts/eval_cityscapes_WideResNet38.sh 146 | ``` 147 | 148 | For submitting to the Cityscapes benchmark, we switch to a multi-scale inference setting. 149 | ``` 150 | ./scripts/submit_cityscapes_WideResNet38.sh 151 | ``` 152 | 153 | Now you can zip the `pred` folder and upload it to the Cityscapes leaderboard. For the test submission, there is nothing in the `diff` folder because we don't have ground truth. 154 | 155 | At this point, you can already achieve top performance on the Cityscapes benchmark (83+ mIoU). In order to further boost the segmentation performance, we can use the augmented dataset to improve the model's generalization capability. 156 | 157 | ### Label Propagation using Video Prediction 158 | First, you need to download the Cityscapes sequence dataset. Note that the sequence dataset is very large (a 324GB .zip file). Then we can use the video prediction model to propagate GT segmentation masks to adjacent video frames, so that we have more annotated image-label pairs during training. 159 | 160 | ``` 161 | cd ./sdcnet 162 | 163 | bash flownet2_pytorch/install.sh 164 | 165 | ./_aug.sh 166 | ``` 167 | 168 | By default, we predict five past frames and five future frames, which effectively enlarges the dataset 10 times. If you prefer to propagate fewer or more time steps, you can change the `--propagate` flag accordingly. Enjoy the augmented dataset. 169 | 170 | 171 | ## Results on Cityscapes 172 | 173 | ![alt text](images/vis.png) 174 | 175 | ## Training IOU using fp16 176 | 177 | Training results for WideResNet38 and SEResNeXt50 trained in fp16 on a DGX-1 (8-GPU V100). fp16 can significantly speed up experiments without losing much accuracy. 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 |
| Model Name | Mean IOU | Training Time |
| --- | --- | --- |
| DeepWV3Plus (no sdc-aug) | 81.4 | ~14 hrs |
| DeepSRNX50V3PlusD_m1 (no sdc-aug) | 80.0 | ~9 hrs |
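As a rough sketch of how fp16 training is typically wired up with NVIDIA Apex AMP (which the Dockerfile installs): the model and optimizer below are stand-ins, and the repo's `train.py` exposes its own flags for this (see `python3 train.py --help`).

```python
# Sketch of mixed-precision training with Apex AMP; requires a CUDA GPU and apex.
import torch
from apex import amp

model = torch.nn.Conv2d(3, 19, kernel_size=3, padding=1).cuda()  # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# opt_level "O1": fp16 where safe, fp32 master weights kept by AMP
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

x = torch.randn(2, 3, 128, 128).cuda()
target = torch.randint(0, 19, (2, 128, 128)).cuda()

out = model(x)
loss = torch.nn.functional.cross_entropy(out, target)

optimizer.zero_grad()
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()  # gradients are scaled to avoid fp16 underflow
optimizer.step()
```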
196 | 197 | 198 | ## Reference 199 | 200 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper or code accordingly: 201 | 202 | ``` 203 | @inproceedings{semantic_cvpr19, 204 | author = {Yi Zhu*, Karan Sapra*, Fitsum A. Reda, Kevin J. Shih, Shawn Newsam, Andrew Tao, Bryan Catanzaro}, 205 | title = {Improving Semantic Segmentation via Video Propagation and Label Relaxation}, 206 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 207 | month = {June}, 208 | year = {2019}, 209 | url = {https://nv-adlr.github.io/publication/2018-Segmentation} 210 | } 211 | * indicates equal contribution 212 | 213 | @inproceedings{reda2018sdc, 214 | title={SDC-Net: Video prediction using spatially-displaced convolution}, 215 | author={Reda, Fitsum A and Liu, Guilin and Shih, Kevin J and Kirby, Robert and Barker, Jon and Tarjan, David and Tao, Andrew and Catanzaro, Bryan}, 216 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 217 | pages={718--733}, 218 | year={2018} 219 | } 220 | ``` 221 | We encourage people to contribute to our code base and provide suggestions, point any issues, or solution using merge request, and we hope this repo is useful. 222 | 223 | ## Acknowledgments 224 | 225 | Parts of the code were heavily derived from [pytorch-semantic-segmentation](https://github.com/ZijunDeng/pytorch-semantic-segmentation), [inplace-abn](https://github.com/mapillary/inplace_abn), [Pytorch](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch), [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch) and [Cadene](https://github.com/Cadene/pretrained-models.pytorch). 226 | 227 | Our initial models used SyncBN from [Synchronized Batch Norm](https://github.com/zhanghang1989/PyTorch-Encoding) but since then have been ported to [Apex SyncBN](https://github.com/NVIDIA/apex) developed by Jie Jiang. 228 | 229 | We would also like to thank Ming-Yu Liu and Peter Kontschieder. 230 | 231 | ## Coding style 232 | * 4 spaces for indentation rather than tabs 233 | * 100 character line length 234 | * PEP8 formatting 235 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | ############################################################################## 30 | #Config 31 | ############################################################################## 32 | 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | from __future__ import unicode_literals 38 | 39 | 40 | import torch 41 | 42 | 43 | from utils.attr_dict import AttrDict 44 | 45 | 46 | __C = AttrDict() 47 | cfg = __C 48 | __C.EPOCH = 0 49 | # Use Class Uniform Sampling to give each class proper sampling 50 | __C.CLASS_UNIFORM_PCT = 0.0 51 | 52 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch 53 | __C.BATCH_WEIGHTING = False 54 | 55 | # Border Relaxation Count 56 | __C.BORDER_WINDOW = 1 57 | # Number of epoch to use before turn off border restriction 58 | __C.REDUCE_BORDER_EPOCH = -1 59 | # Comma Seperated List of class id to relax 60 | __C.STRICTBORDERCLASS = None 61 | 62 | 63 | 64 | #Attribute Dictionary for Dataset 65 | __C.DATASET = AttrDict() 66 | #Cityscapes Dir Location 67 | __C.DATASET.CITYSCAPES_DIR = '' 68 | #SDC Augmented Cityscapes Dir Location 69 | __C.DATASET.CITYSCAPES_AUG_DIR = '' 70 | #Mapillary Dataset Dir Location 71 | __C.DATASET.MAPILLARY_DIR = '' 72 | #Kitti Dataset Dir Location 73 | __C.DATASET.KITTI_DIR = '' 74 | #SDC Augmented Kitti Dataset Dir Location 75 | __C.DATASET.KITTI_AUG_DIR = '' 76 | #Camvid Dataset Dir Location 77 | __C.DATASET.CAMVID_DIR = '' 78 | #Number of splits to support 79 | __C.DATASET.CV_SPLITS = 3 80 | 81 | 82 | __C.MODEL = AttrDict() 83 | __C.MODEL.BN = 'regularnorm' 84 | __C.MODEL.BNFUNC = None 85 | 86 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True): 87 | """Call this function in your script after you have finished setting all cfg 88 | values that are necessary (e.g., merging a config from a file, merging 89 | command line config options, etc.). By default, this function will also 90 | mark the global cfg as immutable to prevent changing the global cfg settings 91 | during script execution (which can lead to hard to debug errors or code 92 | that's harder to understand than is necessary). 
93 | """ 94 | 95 | if hasattr(args, 'syncbn') and args.syncbn: 96 | if args.apex: 97 | import apex 98 | __C.MODEL.BN = 'apex-syncnorm' 99 | __C.MODEL.BNFUNC = apex.parallel.SyncBatchNorm 100 | else: 101 | raise Exception('No Support for SyncBN without Apex') 102 | else: 103 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 104 | print('Using regular batch norm') 105 | 106 | if not train_mode: 107 | cfg.immutable(True) 108 | return 109 | if args.class_uniform_pct: 110 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct 111 | 112 | if args.batch_weighting: 113 | __C.BATCH_WEIGHTING = True 114 | 115 | if args.jointwtborder: 116 | if args.strict_bdr_cls != '': 117 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")] 118 | if args.rlx_off_epoch > -1: 119 | __C.REDUCE_BORDER_EPOCH = args.rlx_off_epoch 120 | 121 | if make_immutable: 122 | cfg.immutable(True) 123 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset setup and loaders 3 | """ 4 | from datasets import cityscapes 5 | from datasets import mapillary 6 | from datasets import kitti 7 | from datasets import comma10k 8 | from datasets import camvid 9 | import torchvision.transforms as standard_transforms 10 | 11 | import transforms.joint_transforms as joint_transforms 12 | import transforms.transforms as extended_transforms 13 | from torch.utils.data import DataLoader 14 | 15 | 16 | def setup_loaders(args): 17 | """ 18 | Setup Data Loaders[Currently supports Cityscapes, Mapillary and ADE20kin] 19 | input: argument passed by the user 20 | return: training data loader, validation data loader loader, train_set 21 | """ 22 | 23 | if args.dataset == 'cityscapes': 24 | args.dataset_cls = cityscapes 25 | args.train_batch_size = args.bs_mult * args.ngpu 26 | if args.bs_mult_val > 0: 27 | args.val_batch_size = args.bs_mult_val * args.ngpu 28 | else: 29 | args.val_batch_size = args.bs_mult * args.ngpu 30 | elif args.dataset == 'mapillary': 31 | args.dataset_cls = mapillary 32 | args.train_batch_size = args.bs_mult * args.ngpu 33 | args.val_batch_size = 4 34 | elif args.dataset == 'ade20k': 35 | args.dataset_cls = ade20k 36 | args.train_batch_size = args.bs_mult * args.ngpu 37 | args.val_batch_size = 4 38 | elif args.dataset == 'kitti': 39 | args.dataset_cls = kitti 40 | args.train_batch_size = args.bs_mult * args.ngpu 41 | if args.bs_mult_val > 0: 42 | args.val_batch_size = args.bs_mult_val * args.ngpu 43 | else: 44 | args.val_batch_size = args.bs_mult * args.ngpu 45 | elif args.dataset == 'comma10k': 46 | args.dataset_cls = comma10k 47 | args.train_batch_size = args.bs_mult * args.ngpu 48 | if args.bs_mult_val > 0: 49 | args.val_batch_size = args.bs_mult_val * args.ngpu 50 | else: 51 | args.val_batch_size = args.bs_mult * args.ngpu 52 | elif args.dataset == 'camvid': 53 | args.dataset_cls = camvid 54 | args.train_batch_size = args.bs_mult * args.ngpu 55 | if args.bs_mult_val > 0: 56 | args.val_batch_size = args.bs_mult_val * args.ngpu 57 | else: 58 | args.val_batch_size = args.bs_mult * args.ngpu 59 | elif args.dataset == 'null_loader': 60 | args.dataset_cls = null_loader 61 | args.train_batch_size = args.bs_mult * args.ngpu 62 | if args.bs_mult_val > 0: 63 | args.val_batch_size = args.bs_mult_val * args.ngpu 64 | else: 65 | args.val_batch_size = args.bs_mult * args.ngpu 66 | else: 67 | raise Exception('Dataset {} is not supported'.format(args.dataset)) 68 | 69 | # Readjust batch size to 
mini-batch size for apex 70 | if args.apex: 71 | args.train_batch_size = args.bs_mult 72 | args.val_batch_size = args.bs_mult_val 73 | 74 | args.num_workers = 4 * args.ngpu 75 | if args.test_mode: 76 | args.num_workers = 1 77 | 78 | 79 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 80 | 81 | # Geometric image transformations 82 | train_joint_transform_list = [ 83 | joint_transforms.RandomSizeAndCrop(args.crop_size, 84 | False, 85 | pre_size=args.pre_size, 86 | scale_min=args.scale_min, 87 | scale_max=args.scale_max, 88 | ignore_index=args.dataset_cls.ignore_label), 89 | joint_transforms.Resize(args.crop_size), 90 | joint_transforms.RandomHorizontallyFlip()] 91 | train_joint_transform = joint_transforms.Compose(train_joint_transform_list) 92 | 93 | # Image appearance transformations 94 | train_input_transform = [] 95 | if args.color_aug: 96 | train_input_transform += [extended_transforms.ColorJitter( 97 | brightness=args.color_aug, 98 | contrast=args.color_aug, 99 | saturation=args.color_aug, 100 | hue=args.color_aug)] 101 | 102 | if args.bblur: 103 | train_input_transform += [extended_transforms.RandomBilateralBlur()] 104 | elif args.gblur: 105 | train_input_transform += [extended_transforms.RandomGaussianBlur()] 106 | else: 107 | pass 108 | 109 | 110 | 111 | train_input_transform += [standard_transforms.ToTensor(), 112 | standard_transforms.Normalize(*mean_std)] 113 | train_input_transform = standard_transforms.Compose(train_input_transform) 114 | 115 | val_input_transform = standard_transforms.Compose([ 116 | standard_transforms.ToTensor(), 117 | standard_transforms.Normalize(*mean_std) 118 | ]) 119 | 120 | target_transform = extended_transforms.MaskToTensor() 121 | 122 | if args.jointwtborder: 123 | target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label, 124 | args.dataset_cls.num_classes) 125 | else: 126 | target_train_transform = extended_transforms.MaskToTensor() 127 | 128 | if args.dataset == 'cityscapes': 129 | city_mode = 'train' ## Can be trainval 130 | city_quality = 'fine' 131 | if args.class_uniform_pct: 132 | if args.coarse_boost_classes: 133 | coarse_boost_classes = \ 134 | [int(c) for c in args.coarse_boost_classes.split(',')] 135 | else: 136 | coarse_boost_classes = None 137 | train_set = args.dataset_cls.CityScapesUniform( 138 | city_quality, city_mode, args.maxSkip, 139 | joint_transform_list=train_joint_transform_list, 140 | transform=train_input_transform, 141 | target_transform=target_train_transform, 142 | dump_images=args.dump_augmentation_images, 143 | cv_split=args.cv, 144 | class_uniform_pct=args.class_uniform_pct, 145 | class_uniform_tile=args.class_uniform_tile, 146 | test=args.test_mode, 147 | coarse_boost_classes=coarse_boost_classes) 148 | else: 149 | train_set = args.dataset_cls.CityScapes( 150 | city_quality, city_mode, 0, 151 | joint_transform=train_joint_transform, 152 | transform=train_input_transform, 153 | target_transform=target_train_transform, 154 | dump_images=args.dump_augmentation_images, 155 | cv_split=args.cv) 156 | 157 | val_set = args.dataset_cls.CityScapes('fine', 'val', 0, 158 | transform=val_input_transform, 159 | target_transform=target_transform, 160 | cv_split=args.cv) 161 | elif args.dataset == 'mapillary': 162 | eval_size = 1536 163 | val_joint_transform_list = [ 164 | joint_transforms.ResizeHeight(eval_size), 165 | joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)] 166 | train_set = args.dataset_cls.Mapillary( 167 | 'semantic', 'train', 
168 | joint_transform_list=train_joint_transform_list, 169 | transform=train_input_transform, 170 | target_transform=target_train_transform, 171 | dump_images=args.dump_augmentation_images, 172 | class_uniform_pct=args.class_uniform_pct, 173 | class_uniform_tile=args.class_uniform_tile, 174 | test=args.test_mode) 175 | val_set = args.dataset_cls.Mapillary( 176 | 'semantic', 'val', 177 | joint_transform_list=val_joint_transform_list, 178 | transform=val_input_transform, 179 | target_transform=target_transform, 180 | test=False) 181 | elif args.dataset == 'ade20k': 182 | eval_size = 384 183 | val_joint_transform_list = [ 184 | joint_transforms.ResizeHeight(eval_size), 185 | joint_transforms.CenterCropPad(eval_size)] 186 | 187 | train_set = args.dataset_cls.ade20k( 188 | 'semantic', 'train', 189 | joint_transform_list=train_joint_transform_list, 190 | transform=train_input_transform, 191 | target_transform=target_train_transform, 192 | dump_images=args.dump_augmentation_images, 193 | class_uniform_pct=args.class_uniform_pct, 194 | class_uniform_tile=args.class_uniform_tile, 195 | test=args.test_mode) 196 | val_set = args.dataset_cls.ade20k( 197 | 'semantic', 'val', 198 | joint_transform_list=val_joint_transform_list, 199 | transform=val_input_transform, 200 | target_transform=target_transform, 201 | test=False) 202 | elif args.dataset == 'kitti': 203 | # eval_size_h = 384 204 | # eval_size_w = 1280 205 | # val_joint_transform_list = [ 206 | # joint_transforms.ResizeHW(eval_size_h, eval_size_w)] 207 | 208 | train_set = args.dataset_cls.KITTI( 209 | 'semantic', 'train', args.maxSkip, 210 | joint_transform_list=train_joint_transform_list, 211 | transform=train_input_transform, 212 | target_transform=target_train_transform, 213 | dump_images=args.dump_augmentation_images, 214 | class_uniform_pct=args.class_uniform_pct, 215 | class_uniform_tile=args.class_uniform_tile, 216 | test=args.test_mode, 217 | cv_split=args.cv, 218 | scf=args.scf, 219 | hardnm=args.hardnm) 220 | val_set = args.dataset_cls.KITTI( 221 | 'semantic', 'trainval', 0, 222 | joint_transform_list=None, 223 | transform=val_input_transform, 224 | target_transform=target_transform, 225 | test=False, 226 | cv_split=args.cv, 227 | scf=None) 228 | elif args.dataset == 'comma10k': 229 | train_set = args.dataset_cls.COMMA10K( 230 | 'semantic', 'train', args.maxSkip, 231 | joint_transform_list=train_joint_transform_list, 232 | transform=train_input_transform, 233 | target_transform=target_train_transform, 234 | dump_images=args.dump_augmentation_images, 235 | class_uniform_pct=args.class_uniform_pct, 236 | class_uniform_tile=args.class_uniform_tile, 237 | test=args.test_mode, 238 | cv_split=args.cv, 239 | scf=args.scf, 240 | hardnm=args.hardnm) 241 | val_set = args.dataset_cls.COMMA10K( 242 | 'semantic', 'trainval', 0, 243 | joint_transform_list=None, 244 | transform=val_input_transform, 245 | target_transform=target_transform, 246 | test=False, 247 | cv_split=args.cv, 248 | scf=None) 249 | elif args.dataset == 'camvid': 250 | # eval_size_h = 384 251 | # eval_size_w = 1280 252 | # val_joint_transform_list = [ 253 | # joint_transforms.ResizeHW(eval_size_h, eval_size_w)] 254 | 255 | train_set = args.dataset_cls.CAMVID( 256 | 'semantic', 'trainval', args.maxSkip, 257 | joint_transform_list=train_joint_transform_list, 258 | transform=train_input_transform, 259 | target_transform=target_train_transform, 260 | dump_images=args.dump_augmentation_images, 261 | class_uniform_pct=args.class_uniform_pct, 262 | 
class_uniform_tile=args.class_uniform_tile, 263 | test=args.test_mode, 264 | cv_split=args.cv, 265 | scf=args.scf, 266 | hardnm=args.hardnm) 267 | val_set = args.dataset_cls.CAMVID( 268 | 'semantic', 'test', 0, 269 | joint_transform_list=None, 270 | transform=val_input_transform, 271 | target_transform=target_transform, 272 | test=False, 273 | cv_split=args.cv, 274 | scf=None) 275 | 276 | elif args.dataset == 'null_loader': 277 | train_set = args.dataset_cls.null_loader(args.crop_size) 278 | val_set = args.dataset_cls.null_loader(args.crop_size) 279 | else: 280 | raise Exception('Dataset {} is not supported'.format(args.dataset)) 281 | 282 | if args.apex: 283 | from datasets.sampler import DistributedSampler 284 | train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False) 285 | val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False) 286 | 287 | else: 288 | train_sampler = None 289 | val_sampler = None 290 | 291 | train_loader = DataLoader(train_set, batch_size=args.train_batch_size, 292 | num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler) 293 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size, 294 | num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler) 295 | 296 | return train_loader, val_loader, train_set 297 | 298 | -------------------------------------------------------------------------------- /datasets/camvid.py: -------------------------------------------------------------------------------- 1 | """ 2 | Camvid Dataset Loader 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | from PIL import Image 9 | from torch.utils import data 10 | import logging 11 | import datasets.uniform as uniform 12 | import json 13 | from config import cfg 14 | 15 | 16 | # trainid_to_name = cityscapes_labels.trainId2name 17 | # id_to_trainid = cityscapes_labels.label2trainid 18 | num_classes = 11 19 | ignore_label = 11 20 | root = cfg.DATASET.CAMVID_DIR 21 | 22 | palette = [128, 128, 128, 23 | 128, 0, 0, 24 | 192, 192, 128, 25 | 128, 64, 128, 26 | 0, 0, 192, 27 | 128, 128, 0, 28 | 192, 128, 128, 29 | 64, 64, 128, 30 | 64, 0, 128, 31 | 64, 64, 0, 32 | 0, 128, 192] 33 | 34 | 35 | CAMVID_CLASSES = ['Sky', 36 | 'Building', 37 | 'Column-Pole', 38 | 'Road', 39 | 'Sidewalk', 40 | 'Tree', 41 | 'Sign-Symbol', 42 | 'Fence', 43 | 'Car', 44 | 'Pedestrain', 45 | 'Bicyclist', 46 | 'Void'] 47 | 48 | CAMVID_CLASS_COLORS = [ 49 | (128, 128, 128), 50 | (128, 0, 0), 51 | (192, 192, 128), 52 | (128, 64, 128), 53 | (0, 0, 192), 54 | (128, 128, 0), 55 | (192, 128, 128), 56 | (64, 64, 128), 57 | (64, 0, 128), 58 | (64, 64, 0), 59 | (0, 128, 192), 60 | (0, 0, 0), 61 | ] 62 | 63 | zero_pad = 256 * 3 - len(palette) 64 | for i in range(zero_pad): 65 | palette.append(0) 66 | 67 | def colorize_mask(mask): 68 | # mask: numpy array of the mask 69 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 70 | new_mask.putpalette(palette) 71 | return new_mask 72 | 73 | def add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip): 74 | 75 | c_items = os.listdir(img_path) 76 | c_items.sort() 77 | items = [] 78 | aug_items = [] 79 | 80 | for it in c_items: 81 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 82 | items.append(item) 83 | if mode != 'test' and maxSkip > 0: 84 | seq_info = it.split("_") 85 | cur_seq_id = seq_info[-1][:-4] 86 | 87 | if seq_info[0] == "0001TP": 88 | prev_seq_id = "%06d" % 
(int(cur_seq_id) - maxSkip) 89 | next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip) 90 | elif seq_info[0] == "0006R0": 91 | prev_seq_id = "f%05d" % (int(cur_seq_id[1:]) - maxSkip) 92 | next_seq_id = "f%05d" % (int(cur_seq_id[1:]) + maxSkip) 93 | else: 94 | prev_seq_id = "%05d" % (int(cur_seq_id) - maxSkip) 95 | next_seq_id = "%05d" % (int(cur_seq_id) + maxSkip) 96 | 97 | prev_it = seq_info[0] + "_" + prev_seq_id + '.png' 98 | next_it = seq_info[0] + "_" + next_seq_id + '.png' 99 | 100 | prev_item = (os.path.join(aug_img_path, prev_it), os.path.join(aug_mask_path, prev_it)) 101 | next_item = (os.path.join(aug_img_path, next_it), os.path.join(aug_mask_path, next_it)) 102 | if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]): 103 | aug_items.append(prev_item) 104 | if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]): 105 | aug_items.append(next_item) 106 | return items, aug_items 107 | 108 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 109 | 110 | items = [] 111 | aug_items = [] 112 | assert quality == 'semantic' 113 | assert mode in ['train', 'val', 'trainval', 'test'] 114 | 115 | # img_dir_name = "SegNet/CamVid" 116 | original_img_dir = "LargeScale/CamVid" 117 | augmented_img_dir = "camvid_aug3/CamVid" 118 | 119 | img_path = os.path.join(root, original_img_dir, 'train') 120 | mask_path = os.path.join(root, original_img_dir, 'trainannot') 121 | aug_img_path = os.path.join(root, augmented_img_dir, 'train') 122 | aug_mask_path = os.path.join(root, augmented_img_dir, 'trainannot') 123 | 124 | train_items, train_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 125 | logging.info('Camvid has a total of {} train images'.format(len(train_items))) 126 | 127 | img_path = os.path.join(root, original_img_dir, 'val') 128 | mask_path = os.path.join(root, original_img_dir, 'valannot') 129 | aug_img_path = os.path.join(root, augmented_img_dir, 'val') 130 | aug_mask_path = os.path.join(root, augmented_img_dir, 'valannot') 131 | 132 | val_items, val_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 133 | logging.info('Camvid has a total of {} validation images'.format(len(val_items))) 134 | 135 | if mode == 'test': 136 | img_path = os.path.join(root, original_img_dir, 'test') 137 | mask_path = os.path.join(root, original_img_dir, 'testannot') 138 | test_items, test_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip) 139 | logging.info('Camvid has a total of {} test images'.format(len(test_items))) 140 | 141 | if mode == 'train': 142 | items = train_items 143 | elif mode == 'val': 144 | items = val_items 145 | elif mode == 'trainval': 146 | items = train_items + val_items 147 | aug_items = train_aug_items + val_aug_items 148 | elif mode == 'test': 149 | items = test_items 150 | aug_items = [] 151 | else: 152 | logging.info('Unknown mode {}'.format(mode)) 153 | sys.exit() 154 | 155 | logging.info('Camvid-{}: {} images'.format(mode, len(items))) 156 | 157 | return items, aug_items 158 | 159 | class CAMVID(data.Dataset): 160 | 161 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 162 | transform=None, target_transform=None, dump_images=False, 163 | class_uniform_pct=0, class_uniform_tile=0, test=False, 164 | cv_split=None, scf=None, hardnm=0): 165 | 166 | self.quality = quality 167 | self.mode = mode 168 | self.maxSkip = maxSkip 169 | self.joint_transform_list = joint_transform_list 170 | self.transform = transform 171 | 
self.target_transform = target_transform 172 | self.dump_images = dump_images 173 | self.class_uniform_pct = class_uniform_pct 174 | self.class_uniform_tile = class_uniform_tile 175 | self.scf = scf 176 | self.hardnm = hardnm 177 | self.cv_split = cv_split 178 | self.centroids = [] 179 | 180 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 181 | assert len(self.imgs), 'Found 0 images, please check the data set' 182 | 183 | # Centroids for GT data 184 | if self.class_uniform_pct > 0: 185 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode) 186 | 187 | if os.path.isfile(json_fn): 188 | with open(json_fn, 'r') as json_data: 189 | centroids = json.load(json_data) 190 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 191 | else: 192 | self.centroids = uniform.class_centroids_all( 193 | self.imgs, 194 | num_classes, 195 | id2trainid=None, 196 | tile_size=class_uniform_tile) 197 | with open(json_fn, 'w') as outfile: 198 | json.dump(self.centroids, outfile, indent=4) 199 | 200 | self.fine_centroids = self.centroids.copy() 201 | 202 | if self.maxSkip > 0: 203 | json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip) 204 | if os.path.isfile(json_fn): 205 | with open(json_fn, 'r') as json_data: 206 | centroids = json.load(json_data) 207 | self.aug_centroids = {int(idx): centroids[idx] for idx in centroids} 208 | else: 209 | self.aug_centroids = uniform.class_centroids_all( 210 | self.aug_imgs, 211 | num_classes, 212 | id2trainid=None, 213 | tile_size=class_uniform_tile) 214 | with open(json_fn, 'w') as outfile: 215 | json.dump(self.aug_centroids, outfile, indent=4) 216 | 217 | for class_id in range(num_classes): 218 | self.centroids[class_id].extend(self.aug_centroids[class_id]) 219 | 220 | self.build_epoch() 221 | 222 | def build_epoch(self, cut=False): 223 | 224 | if self.class_uniform_pct > 0: 225 | if cut: 226 | self.imgs_uniform = uniform.build_epoch(self.imgs, 227 | self.fine_centroids, 228 | num_classes, 229 | cfg.CLASS_UNIFORM_PCT) 230 | else: 231 | self.imgs_uniform = uniform.build_epoch(self.imgs, 232 | self.centroids, 233 | num_classes, 234 | cfg.CLASS_UNIFORM_PCT) 235 | else: 236 | self.imgs_uniform = self.imgs 237 | 238 | 239 | def __getitem__(self, index): 240 | elem = self.imgs_uniform[index] 241 | centroid = None 242 | if len(elem) == 4: 243 | img_path, mask_path, centroid, class_id = elem 244 | else: 245 | img_path, mask_path = elem 246 | 247 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 248 | img_name = os.path.splitext(os.path.basename(img_path))[0] 249 | 250 | # Image Transformations 251 | if self.joint_transform_list is not None: 252 | for idx, xform in enumerate(self.joint_transform_list): 253 | if idx == 0 and centroid is not None: 254 | # HACK 255 | # We assume that the first transform is capable of taking 256 | # in a centroid 257 | img, mask = xform(img, mask, centroid) 258 | else: 259 | img, mask = xform(img, mask) 260 | 261 | # Debug 262 | if self.dump_images and centroid is not None: 263 | outdir = './dump_imgs_{}'.format(self.mode) 264 | os.makedirs(outdir, exist_ok=True) 265 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 266 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 267 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 268 | mask_img = colorize_mask(np.array(mask)) 269 | img.save(out_img_fn) 270 | 
mask_img.save(out_msk_fn) 271 | 272 | if self.transform is not None: 273 | img = self.transform(img) 274 | if self.target_transform is not None: 275 | mask = self.target_transform(mask) 276 | 277 | return img, mask, img_name 278 | 279 | def __len__(self): 280 | return len(self.imgs_uniform) 281 | 282 | 283 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. 
However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 
'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | # category to list of label objects 143 | category2labels = {} 144 | for label in labels: 145 | category = label.category 146 | if category in category2labels: 147 | category2labels[category].append(label) 148 | else: 149 | category2labels[category] = [label] 150 | 151 | #-------------------------------------------------------------------------------- 152 | # Assure single instance name 153 | #-------------------------------------------------------------------------------- 154 | 155 | # returns the label name that describes a single instance (if possible) 156 | # e.g. 
input | output 157 | # ---------------------- 158 | # car | car 159 | # cargroup | car 160 | # foo | None 161 | # foogroup | None 162 | # skygroup | None 163 | def assureSingleInstanceName( name ): 164 | # if the name is known, it is not a group 165 | if name in name2label: 166 | return name 167 | # test if the name actually denotes a group 168 | if not name.endswith("group"): 169 | return None 170 | # remove group 171 | name = name[:-len("group")] 172 | # test if the new name exists 173 | if not name in name2label: 174 | return None 175 | # test if the new name denotes a label that actually has instances 176 | if not name2label[name].hasInstances: 177 | return None 178 | # all good then 179 | return name 180 | 181 | #-------------------------------------------------------------------------------- 182 | # Main for testing 183 | #-------------------------------------------------------------------------------- 184 | 185 | # just a dummy main 186 | if __name__ == "__main__": 187 | # Print all the labels 188 | print("List of cityscapes labels:") 189 | print("") 190 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 191 | print((" " + ('-' * 98))) 192 | for label in labels: 193 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 194 | print("") 195 | 196 | print("Example usages:") 197 | 198 | # Map from name to label 199 | name = 'car' 200 | id = name2label[name].id 201 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 202 | 203 | # Map from ID to label 204 | category = id2label[id].category 205 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 206 | 207 | # Map from trainID to label 208 | trainId = 0 209 | name = trainId2label[trainId].name 210 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 211 | -------------------------------------------------------------------------------- /datasets/comma10k.py: -------------------------------------------------------------------------------- 1 | """ 2 | KITTI Dataset Loader 3 | http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 4 | """ 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | from PIL import Image 10 | from torch.utils import data 11 | import logging 12 | import datasets.uniform as uniform 13 | import datasets.cityscapes_labels as cityscapes_labels 14 | import json 15 | from config import cfg 16 | 17 | 18 | trainid_to_name = cityscapes_labels.trainId2name 19 | #id_to_trainid = cityscapes_labels.label2trainid 20 | 21 | 22 | id_to_trainid = { 23 | 0x40: 0, # road 24 | 0xff: 1, # lane marking 25 | 0x80: 2, # sky/undrivable 26 | 0x00: 3, # car 27 | 0xcc: 4} # ego 28 | 29 | num_classes = 5 30 | ignore_label = 255 31 | root = cfg.DATASET.KITTI_DIR 32 | aug_root = cfg.DATASET.KITTI_AUG_DIR 33 | 34 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 35 | 153, 153, 153, 250, 170, 30, 36 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 37 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 38 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 39 | zero_pad = 256 * 3 - len(palette) 40 | for i in range(zero_pad): 41 | palette.append(0) 42 | 43 | def colorize_mask(mask): 44 | # mask: numpy array of the mask 45 | new_mask = 
Image.fromarray(mask.astype(np.uint8)).convert('P') 46 | new_mask.putpalette(palette) 47 | return new_mask 48 | 49 | def get_train_val(cv_split, all_items): 50 | # 90/10 train/val split, three random splits for cross validation 51 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198] 52 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197] 53 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199] 54 | 55 | train_set = [] 56 | val_set = [] 57 | 58 | if cv_split == 0: 59 | for i in range(200): 60 | if i in val_0: 61 | val_set.append(all_items[i]) 62 | else: 63 | train_set.append(all_items[i]) 64 | elif cv_split == 1: 65 | for i in range(200): 66 | if i in val_1: 67 | val_set.append(all_items[i]) 68 | else: 69 | train_set.append(all_items[i]) 70 | elif cv_split == 2: 71 | for i in range(200): 72 | if i in val_2: 73 | val_set.append(all_items[i]) 74 | else: 75 | train_set.append(all_items[i]) 76 | else: 77 | logging.info('Unknown cv_split {}'.format(cv_split)) 78 | sys.exit() 79 | 80 | return train_set, val_set 81 | 82 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 83 | items = [] 84 | all_items = [] 85 | aug_items = [] 86 | 87 | assert quality == 'semantic' 88 | assert mode in ['train', 'val', 'trainval'] 89 | # note that train and val are randomly determined, no official split 90 | 91 | """ 92 | img_dir_name = "training" 93 | img_path = os.path.join(root, img_dir_name, 'image_2') 94 | mask_path = os.path.join(root, img_dir_name, 'semantic') 95 | """ 96 | 97 | img_path = "/raid/comma10k/imgs" 98 | mask_path = "/raid/comma10k/masks" 99 | 100 | c_items = os.listdir(img_path) 101 | c_items.sort() 102 | 103 | for it in c_items: 104 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 105 | all_items.append(item) 106 | logging.info('KITTI has a total of {} images'.format(len(all_items))) 107 | 108 | # split into train/val 109 | train_set, val_set = get_train_val(cv_split, all_items) 110 | 111 | if mode == 'train': 112 | items = train_set 113 | elif mode == 'val': 114 | items = val_set 115 | elif mode == 'trainval': 116 | items = train_set + val_set 117 | else: 118 | logging.info('Unknown mode {}'.format(mode)) 119 | sys.exit() 120 | 121 | logging.info('KITTI-{}: {} images'.format(mode, len(items))) 122 | 123 | return items, aug_items 124 | 125 | def make_test_dataset(quality, mode, maxSkip=0, cv_split=0): 126 | items = [] 127 | assert quality == 'semantic' 128 | assert mode == 'test' 129 | 130 | img_dir_name = "testing" 131 | img_path = os.path.join(root, img_dir_name, 'image_2') 132 | 133 | c_items = os.listdir(img_path) 134 | c_items.sort() 135 | for it in c_items: 136 | item = (os.path.join(img_path, it), None) 137 | items.append(item) 138 | logging.info('KITTI has a total of {} test images'.format(len(items))) 139 | 140 | return items, [] 141 | 142 | class COMMA10K(data.Dataset): 143 | 144 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 145 | transform=None, target_transform=None, dump_images=False, 146 | class_uniform_pct=0, class_uniform_tile=0, test=False, 147 | cv_split=None, scf=None, hardnm=0): 148 | 149 | self.quality = quality 150 | self.mode = mode 151 | self.maxSkip = maxSkip 152 | self.joint_transform_list = joint_transform_list 153 | self.transform = transform 154 | self.target_transform = target_transform 155 | self.dump_images = dump_images 156 | self.class_uniform_pct = class_uniform_pct 157 | self.class_uniform_tile = 
class_uniform_tile 158 | self.scf = scf 159 | self.hardnm = hardnm 160 | 161 | if cv_split: 162 | self.cv_split = cv_split 163 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 164 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 165 | cv_split, cfg.DATASET.CV_SPLITS) 166 | else: 167 | self.cv_split = 0 168 | 169 | if self.mode == 'test': 170 | self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split) 171 | else: 172 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 173 | assert len(self.imgs), 'Found 0 images, please check the data set' 174 | 175 | # Centroids for GT data 176 | if self.class_uniform_pct > 0: 177 | if self.scf: 178 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split) 179 | else: 180 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm) 181 | if os.path.isfile(json_fn): 182 | with open(json_fn, 'r') as json_data: 183 | centroids = json.load(json_data) 184 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 185 | else: 186 | if self.scf: 187 | self.centroids = kitti_uniform.class_centroids_all( 188 | self.imgs, 189 | num_classes, 190 | id2trainid=id_to_trainid, 191 | tile_size=class_uniform_tile) 192 | else: 193 | self.centroids = uniform.class_centroids_all( 194 | self.imgs, 195 | num_classes, 196 | id2trainid=id_to_trainid, 197 | tile_size=class_uniform_tile) 198 | with open(json_fn, 'w') as outfile: 199 | json.dump(self.centroids, outfile, indent=4) 200 | 201 | self.build_epoch() 202 | 203 | def build_epoch(self, cut=False): 204 | if self.class_uniform_pct > 0: 205 | self.imgs_uniform = uniform.build_epoch(self.imgs, 206 | self.centroids, 207 | num_classes, 208 | cfg.CLASS_UNIFORM_PCT) 209 | else: 210 | self.imgs_uniform = self.imgs 211 | 212 | def __getitem__(self, index): 213 | elem = self.imgs_uniform[index] 214 | centroid = None 215 | if len(elem) == 4: 216 | img_path, mask_path, centroid, class_id = elem 217 | else: 218 | img_path, mask_path = elem 219 | 220 | if self.mode == 'test': 221 | img, mask = Image.open(img_path).convert('RGB'), None 222 | else: 223 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 224 | img_name = os.path.splitext(os.path.basename(img_path))[0] 225 | 226 | # kitti scale correction factor 227 | if self.mode == 'train' or self.mode == 'trainval': 228 | if self.scf: 229 | width, height = img.size 230 | img = img.resize((width*2, height*2), Image.BICUBIC) 231 | mask = mask.resize((width*2, height*2), Image.NEAREST) 232 | elif self.mode == 'val': 233 | width, height = 1242, 376 234 | img = img.resize((width, height), Image.BICUBIC) 235 | mask = mask.resize((width, height), Image.NEAREST) 236 | elif self.mode == 'test': 237 | img_keepsize = img.copy() 238 | width, height = 1280, 384 239 | img = img.resize((width, height), Image.BICUBIC) 240 | else: 241 | logging.info('Unknown mode {}'.format(mode)) 242 | sys.exit() 243 | 244 | if self.mode != 'test': 245 | mask = np.array(mask)[:, :, 0] 246 | mask_copy = mask.copy() 247 | 248 | for k, v in id_to_trainid.items(): 249 | mask_copy[mask == k] = v 250 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 251 | 252 | # Image Transformations 253 | if self.joint_transform_list is not None: 254 | for idx, xform in enumerate(self.joint_transform_list): 255 | if idx == 0 and centroid is not None: 256 | # HACK 257 | # We assume that the first transform is capable of taking 258 | # in a 
centroid 259 | img, mask = xform(img, mask, centroid) 260 | else: 261 | img, mask = xform(img, mask) 262 | 263 | # Debug 264 | if self.dump_images and centroid is not None: 265 | outdir = './dump_imgs_{}'.format(self.mode) 266 | os.makedirs(outdir, exist_ok=True) 267 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 268 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 269 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 270 | mask_img = colorize_mask(np.array(mask)) 271 | img.save(out_img_fn) 272 | mask_img.save(out_msk_fn) 273 | 274 | if self.transform is not None: 275 | img = self.transform(img) 276 | if self.mode == 'test': 277 | img_keepsize = self.transform(img_keepsize) 278 | mask = img_keepsize 279 | if self.target_transform is not None: 280 | if self.mode != 'test': 281 | mask = self.target_transform(mask) 282 | 283 | return img, mask, img_name 284 | 285 | def __len__(self): 286 | return len(self.imgs_uniform) 287 | -------------------------------------------------------------------------------- /datasets/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | KITTI Dataset Loader 3 | http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 4 | """ 5 | 6 | import os 7 | import sys 8 | import numpy as np 9 | from PIL import Image 10 | from torch.utils import data 11 | import logging 12 | import datasets.uniform as uniform 13 | import datasets.cityscapes_labels as cityscapes_labels 14 | import json 15 | from config import cfg 16 | 17 | 18 | trainid_to_name = cityscapes_labels.trainId2name 19 | id_to_trainid = cityscapes_labels.label2trainid 20 | num_classes = 19 21 | ignore_label = 255 22 | root = cfg.DATASET.KITTI_DIR 23 | aug_root = cfg.DATASET.KITTI_AUG_DIR 24 | 25 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 26 | 153, 153, 153, 250, 170, 30, 27 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 28 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 29 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 30 | zero_pad = 256 * 3 - len(palette) 31 | for i in range(zero_pad): 32 | palette.append(0) 33 | 34 | def colorize_mask(mask): 35 | # mask: numpy array of the mask 36 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 37 | new_mask.putpalette(palette) 38 | return new_mask 39 | 40 | def get_train_val(cv_split, all_items): 41 | # 90/10 train/val split, three random splits for cross validation 42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198] 43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197] 44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199] 45 | 46 | train_set = [] 47 | val_set = [] 48 | 49 | if cv_split == 0: 50 | for i in range(200): 51 | if i in val_0: 52 | val_set.append(all_items[i]) 53 | else: 54 | train_set.append(all_items[i]) 55 | elif cv_split == 1: 56 | for i in range(200): 57 | if i in val_1: 58 | val_set.append(all_items[i]) 59 | else: 60 | train_set.append(all_items[i]) 61 | elif cv_split == 2: 62 | for i in range(200): 63 | if i in val_2: 64 | val_set.append(all_items[i]) 65 | else: 66 | train_set.append(all_items[i]) 67 | else: 68 | logging.info('Unknown cv_split {}'.format(cv_split)) 69 | sys.exit() 70 | 71 | return train_set, val_set 72 | 73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 74 | items = [] 75 | all_items = [] 76 | aug_items = [] 77 | 78 | assert quality == 'semantic' 
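# The train/val partition comes from get_train_val() above: for the selected cv_split,
# the 21 indices listed in val_0/val_1/val_2 (out of the 200 KITTI semantic training
# frames) form the validation set and the remaining 179 images are used for training,
# i.e. roughly a 90/10 split.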
79 | assert mode in ['train', 'val', 'trainval'] 80 | # note that train and val are randomly determined, no official split 81 | 82 | img_dir_name = "training" 83 | img_path = os.path.join(root, img_dir_name, 'image_2') 84 | mask_path = os.path.join(root, img_dir_name, 'semantic') 85 | 86 | c_items = os.listdir(img_path) 87 | c_items.sort() 88 | 89 | for it in c_items: 90 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 91 | all_items.append(item) 92 | logging.info('KITTI has a total of {} images'.format(len(all_items))) 93 | 94 | # split into train/val 95 | train_set, val_set = get_train_val(cv_split, all_items) 96 | 97 | if mode == 'train': 98 | items = train_set 99 | elif mode == 'val': 100 | items = val_set 101 | elif mode == 'trainval': 102 | items = train_set + val_set 103 | else: 104 | logging.info('Unknown mode {}'.format(mode)) 105 | sys.exit() 106 | 107 | logging.info('KITTI-{}: {} images'.format(mode, len(items))) 108 | 109 | return items, aug_items 110 | 111 | def make_test_dataset(quality, mode, maxSkip=0, cv_split=0): 112 | items = [] 113 | assert quality == 'semantic' 114 | assert mode == 'test' 115 | 116 | img_dir_name = "testing" 117 | img_path = os.path.join(root, img_dir_name, 'image_2') 118 | 119 | c_items = os.listdir(img_path) 120 | c_items.sort() 121 | for it in c_items: 122 | item = (os.path.join(img_path, it), None) 123 | items.append(item) 124 | logging.info('KITTI has a total of {} test images'.format(len(items))) 125 | 126 | return items, [] 127 | 128 | class KITTI(data.Dataset): 129 | 130 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 131 | transform=None, target_transform=None, dump_images=False, 132 | class_uniform_pct=0, class_uniform_tile=0, test=False, 133 | cv_split=None, scf=None, hardnm=0): 134 | 135 | self.quality = quality 136 | self.mode = mode 137 | self.maxSkip = maxSkip 138 | self.joint_transform_list = joint_transform_list 139 | self.transform = transform 140 | self.target_transform = target_transform 141 | self.dump_images = dump_images 142 | self.class_uniform_pct = class_uniform_pct 143 | self.class_uniform_tile = class_uniform_tile 144 | self.scf = scf 145 | self.hardnm = hardnm 146 | 147 | if cv_split: 148 | self.cv_split = cv_split 149 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 150 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 151 | cv_split, cfg.DATASET.CV_SPLITS) 152 | else: 153 | self.cv_split = 0 154 | 155 | if self.mode == 'test': 156 | self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split) 157 | else: 158 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 159 | assert len(self.imgs), 'Found 0 images, please check the data set' 160 | 161 | # Centroids for GT data 162 | if self.class_uniform_pct > 0: 163 | if self.scf: 164 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split) 165 | else: 166 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm) 167 | if os.path.isfile(json_fn): 168 | with open(json_fn, 'r') as json_data: 169 | centroids = json.load(json_data) 170 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 171 | else: 172 | if self.scf: 173 | self.centroids = kitti_uniform.class_centroids_all( 174 | self.imgs, 175 | num_classes, 176 | id2trainid=id_to_trainid, 177 | tile_size=class_uniform_tile) 178 | else: 179 | self.centroids = uniform.class_centroids_all( 180 | self.imgs, 181 | 
num_classes, 182 | id2trainid=id_to_trainid, 183 | tile_size=class_uniform_tile) 184 | with open(json_fn, 'w') as outfile: 185 | json.dump(self.centroids, outfile, indent=4) 186 | 187 | self.build_epoch() 188 | 189 | def build_epoch(self, cut=False): 190 | if self.class_uniform_pct > 0: 191 | self.imgs_uniform = uniform.build_epoch(self.imgs, 192 | self.centroids, 193 | num_classes, 194 | cfg.CLASS_UNIFORM_PCT) 195 | else: 196 | self.imgs_uniform = self.imgs 197 | 198 | def __getitem__(self, index): 199 | elem = self.imgs_uniform[index] 200 | centroid = None 201 | if len(elem) == 4: 202 | img_path, mask_path, centroid, class_id = elem 203 | else: 204 | img_path, mask_path = elem 205 | 206 | if self.mode == 'test': 207 | img, mask = Image.open(img_path).convert('RGB'), None 208 | else: 209 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 210 | img_name = os.path.splitext(os.path.basename(img_path))[0] 211 | 212 | # kitti scale correction factor 213 | if self.mode == 'train' or self.mode == 'trainval': 214 | if self.scf: 215 | width, height = img.size 216 | img = img.resize((width*2, height*2), Image.BICUBIC) 217 | mask = mask.resize((width*2, height*2), Image.NEAREST) 218 | elif self.mode == 'val': 219 | width, height = 1242, 376 220 | img = img.resize((width, height), Image.BICUBIC) 221 | mask = mask.resize((width, height), Image.NEAREST) 222 | elif self.mode == 'test': 223 | img_keepsize = img.copy() 224 | width, height = 1280, 384 225 | img = img.resize((width, height), Image.BICUBIC) 226 | else: 227 | logging.info('Unknown mode {}'.format(mode)) 228 | sys.exit() 229 | 230 | if self.mode != 'test': 231 | mask = np.array(mask) 232 | mask_copy = mask.copy() 233 | 234 | for k, v in id_to_trainid.items(): 235 | mask_copy[mask == k] = v 236 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 237 | 238 | # Image Transformations 239 | if self.joint_transform_list is not None: 240 | for idx, xform in enumerate(self.joint_transform_list): 241 | if idx == 0 and centroid is not None: 242 | # HACK 243 | # We assume that the first transform is capable of taking 244 | # in a centroid 245 | img, mask = xform(img, mask, centroid) 246 | else: 247 | img, mask = xform(img, mask) 248 | 249 | # Debug 250 | if self.dump_images and centroid is not None: 251 | outdir = './dump_imgs_{}'.format(self.mode) 252 | os.makedirs(outdir, exist_ok=True) 253 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 254 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 255 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 256 | mask_img = colorize_mask(np.array(mask)) 257 | img.save(out_img_fn) 258 | mask_img.save(out_msk_fn) 259 | 260 | if self.transform is not None: 261 | img = self.transform(img) 262 | if self.mode == 'test': 263 | img_keepsize = self.transform(img_keepsize) 264 | mask = img_keepsize 265 | if self.target_transform is not None: 266 | if self.mode != 'test': 267 | mask = self.target_transform(mask) 268 | 269 | return img, mask, img_name 270 | 271 | def __len__(self): 272 | return len(self.imgs_uniform) 273 | -------------------------------------------------------------------------------- /datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mapillary Dataset Loader 3 | """ 4 | from PIL import Image 5 | from torch.utils import data 6 | import os 7 | import numpy as np 8 | import json 9 | import datasets.uniform as uniform 10 | from config import cfg 11 | 12 | num_classes = 65 13 | 
ignore_label = 65 14 | root = cfg.DATASET.MAPILLARY_DIR 15 | config_fn = os.path.join(root, 'config.json') 16 | id_to_ignore_or_group = {} 17 | color_mapping = [] 18 | id_to_trainid = {} 19 | 20 | 21 | def colorize_mask(image_array): 22 | """ 23 | Colorize a segmentation mask 24 | """ 25 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 26 | new_mask.putpalette(color_mapping) 27 | return new_mask 28 | 29 | 30 | def make_dataset(quality, mode): 31 | """ 32 | Create File List 33 | """ 34 | assert (quality == 'semantic' and mode in ['train', 'val']) 35 | img_dir_name = None 36 | if quality == 'semantic': 37 | if mode == 'train': 38 | img_dir_name = 'training' 39 | if mode == 'val': 40 | img_dir_name = 'validation' 41 | mask_path = os.path.join(root, img_dir_name, 'labels') 42 | else: 43 | raise BaseException("Instance Segmentation Not support") 44 | 45 | img_path = os.path.join(root, img_dir_name, 'images') 46 | print(img_path) 47 | if quality != 'video': 48 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)]) 49 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)]) 50 | assert imgs == msks 51 | 52 | items = [] 53 | c_items = os.listdir(img_path) 54 | if '.DS_Store' in c_items: 55 | c_items.remove('.DS_Store') 56 | 57 | for it in c_items: 58 | if quality == 'video': 59 | item = (os.path.join(img_path, it), os.path.join(img_path, it)) 60 | else: 61 | item = (os.path.join(img_path, it), 62 | os.path.join(mask_path, it.replace(".jpg", ".png"))) 63 | items.append(item) 64 | return items 65 | 66 | 67 | def gen_colormap(): 68 | """ 69 | Get Color Map from file 70 | """ 71 | global color_mapping 72 | 73 | # load mapillary config 74 | with open(config_fn) as config_file: 75 | config = json.load(config_file) 76 | config_labels = config['labels'] 77 | 78 | # calculate label color mapping 79 | colormap = [] 80 | id2name = {} 81 | for i in range(0, len(config_labels)): 82 | colormap = colormap + config_labels[i]['color'] 83 | id2name[i] = config_labels[i]['readable'] 84 | color_mapping = colormap 85 | return id2name 86 | 87 | 88 | class Mapillary(data.Dataset): 89 | def __init__(self, quality, mode, joint_transform_list=None, 90 | transform=None, target_transform=None, dump_images=False, 91 | class_uniform_pct=0, class_uniform_tile=768, test=False): 92 | """ 93 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform. 94 | 0.0 means fully random. 
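For example, with class_uniform_pct=0.5 and an epoch of 2,000 images, build_epoch()
below draws roughly half of the samples around per-class centroids (split evenly
over the 65 classes) and fills the remainder with random samples.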
95 | class_uniform_tile_size = Class uniform tile size 96 | """ 97 | self.quality = quality 98 | self.mode = mode 99 | self.joint_transform_list = joint_transform_list 100 | self.transform = transform 101 | self.target_transform = target_transform 102 | self.dump_images = dump_images 103 | self.class_uniform_pct = class_uniform_pct 104 | self.class_uniform_tile = class_uniform_tile 105 | self.id2name = gen_colormap() 106 | self.imgs_uniform = None 107 | for i in range(num_classes): 108 | id_to_trainid[i] = i 109 | 110 | # find all images 111 | self.imgs = make_dataset(quality, mode) 112 | if len(self.imgs) == 0: 113 | raise RuntimeError('Found 0 images, please check the data set') 114 | if test: 115 | np.random.shuffle(self.imgs) 116 | self.imgs = self.imgs[:200] 117 | 118 | if self.class_uniform_pct: 119 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile) 120 | if os.path.isfile(json_fn): 121 | with open(json_fn, 'r') as json_data: 122 | centroids = json.load(json_data) 123 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 124 | else: 125 | # centroids is a dict (indexed by class) of lists of centroids 126 | self.centroids = uniform.class_centroids_all( 127 | self.imgs, 128 | num_classes, 129 | id2trainid=None, 130 | tile_size=self.class_uniform_tile) 131 | with open(json_fn, 'w') as outfile: 132 | json.dump(self.centroids, outfile, indent=4) 133 | else: 134 | self.centroids = [] 135 | self.build_epoch() 136 | 137 | def build_epoch(self): 138 | if self.class_uniform_pct != 0: 139 | self.imgs_uniform = uniform.build_epoch(self.imgs, 140 | self.centroids, 141 | num_classes, 142 | self.class_uniform_pct) 143 | else: 144 | self.imgs_uniform = self.imgs 145 | 146 | def __getitem__(self, index): 147 | if len(self.imgs_uniform[index]) == 2: 148 | img_path, mask_path = self.imgs_uniform[index] 149 | centroid = None 150 | class_id = None 151 | else: 152 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index] 153 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 154 | img_name = os.path.splitext(os.path.basename(img_path))[0] 155 | 156 | mask = np.array(mask) 157 | mask_copy = mask.copy() 158 | for k, v in id_to_ignore_or_group.items(): 159 | mask_copy[mask == k] = v 160 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 161 | 162 | # Image Transformations 163 | if self.joint_transform_list is not None: 164 | for idx, xform in enumerate(self.joint_transform_list): 165 | if idx == 0 and centroid is not None: 166 | # HACK! 
Assume the first transform accepts a centroid 167 | img, mask = xform(img, mask, centroid) 168 | else: 169 | img, mask = xform(img, mask) 170 | 171 | if self.dump_images: 172 | outdir = 'dump_imgs_{}'.format(self.mode) 173 | os.makedirs(outdir, exist_ok=True) 174 | if centroid is not None: 175 | dump_img_name = self.id2name[class_id] + '_' + img_name 176 | else: 177 | dump_img_name = img_name 178 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 179 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 180 | mask_img = colorize_mask(np.array(mask)) 181 | img.save(out_img_fn) 182 | mask_img.save(out_msk_fn) 183 | 184 | if self.transform is not None: 185 | img = self.transform(img) 186 | if self.target_transform is not None: 187 | mask = self.target_transform(mask) 188 | return img, mask, img_name 189 | 190 | def __len__(self): 191 | return len(self.imgs_uniform) 192 | 193 | def calculate_weights(self): 194 | raise BaseException("not supported yet") 195 | -------------------------------------------------------------------------------- /datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Null Loader 3 | """ 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | num_classes = 19 9 | ignore_label = 255 10 | 11 | class NullLoader(data.Dataset): 12 | """ 13 | Null Dataset for Performance 14 | """ 15 | def __init__(self,crop_size): 16 | self.imgs = range(200) 17 | self.crop_size = crop_size 18 | 19 | def __getitem__(self, index): 20 | #Return img, mask, name 21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index) 22 | 23 | def __len__(self): 24 | return len(self.imgs) -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uniform sampling of classes. 3 | For all images, for all classes, generate centroids around which to sample. 4 | 5 | All images are divided into tiles. 
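(With the default 1024-pixel tile size, for example, a 2048x1024 Cityscapes frame is
divided into two tiles.)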
6 | For each tile, a class can be present or not. If it is 7 | present, calculate the centroid of the class and record it. 8 | 9 | We would like to thank Peter Kontschieder for the inspiration of this idea. 10 | """ 11 | 12 | import logging 13 | from collections import defaultdict 14 | from PIL import Image 15 | import numpy as np 16 | from scipy import ndimage 17 | from tqdm import tqdm 18 | 19 | pbar = None 20 | 21 | class Point(): 22 | """ 23 | Point Class For X and Y Location 24 | """ 25 | def __init__(self, x, y): 26 | self.x = x 27 | self.y = y 28 | 29 | 30 | def calc_tile_locations(tile_size, image_size): 31 | """ 32 | Divide an image into tiles to help us cover classes that are spread out. 33 | tile_size: size of tile to distribute 34 | image_size: original image size 35 | return: locations of the tiles 36 | """ 37 | image_size_y, image_size_x = image_size 38 | locations = [] 39 | for y in range(image_size_y // tile_size): 40 | for x in range(image_size_x // tile_size): 41 | x_offs = x * tile_size 42 | y_offs = y * tile_size 43 | locations.append((x_offs, y_offs)) 44 | return locations 45 | 46 | 47 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 48 | """ 49 | For one image, calculate centroids for all classes present in image. 50 | item: image, image_name 51 | tile_size: 52 | num_classes: 53 | id2trainid: mapping from original id to training ids 54 | return: Centroids are calculated for each tile. 55 | """ 56 | image_fn, label_fn = item 57 | centroids = defaultdict(list) 58 | mask = np.array(Image.open(label_fn)) 59 | image_size = mask.shape 60 | tile_locations = calc_tile_locations(tile_size, image_size) 61 | 62 | mask_copy = mask.copy() 63 | if id2trainid: 64 | for k, v in id2trainid.items(): 65 | mask[mask_copy == k] = v 66 | 67 | for x_offs, y_offs in tile_locations: 68 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 69 | for class_id in range(num_classes): 70 | if class_id in patch: 71 | patch_class = (patch == class_id).astype(int) 72 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 73 | centroid_y = int(centroid_y) + y_offs 74 | centroid_x = int(centroid_x) + x_offs 75 | centroid = (centroid_x, centroid_y) 76 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 77 | pbar.update(1) 78 | return centroids 79 | 80 | 81 | 82 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 83 | """ 84 | Calculate class centroids for all classes for all images for all tiles. 85 | items: list of (image_fn, label_fn) 86 | tile size: size of tile 87 | returns: dict that contains a list of centroids for each class 88 | """ 89 | from multiprocessing.dummy import Pool 90 | from functools import partial 91 | pool = Pool(32) 92 | global pbar 93 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 94 | class_centroids_item = partial(class_centroids_image, 95 | num_classes=num_classes, 96 | id2trainid=id2trainid, 97 | tile_size=tile_size) 98 | 99 | centroids = defaultdict(list) 100 | new_centroids = pool.map(class_centroids_item, items) 101 | pool.close() 102 | pool.join() 103 | 104 | # combine each image's items into a single global dict 105 | for image_items in new_centroids: 106 | for class_id in image_items: 107 | centroids[class_id].extend(image_items[class_id]) 108 | return centroids 109 | 110 | 111 | def unpooled_class_centroids_all(items, num_classes, tile_size=1024): 112 | """ 113 | Calculate class centroids for all classes for all images for all tiles. 
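This is the single-process counterpart of pooled_class_centroids_all above; note that
class_centroids_all() below calls the pooled version.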
114 | items: list of (image_fn, label_fn) 115 | tile size: size of tile 116 | returns: dict that contains a list of centroids for each class 117 | """ 118 | centroids = defaultdict(list) 119 | global pbar 120 | pbar = tqdm(total=len(items), desc='centroid extraction') 121 | for image, label in items: 122 | new_centroids = class_centroids_image((image, label), 123 | tile_size, 124 | num_classes) 125 | for class_id in new_centroids: 126 | centroids[class_id].extend(new_centroids[class_id]) 127 | 128 | return centroids 129 | 130 | 131 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 132 | """ 133 | intermediate function to call pooled_class_centroid 134 | """ 135 | 136 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 137 | id2trainid, tile_size) 138 | return pooled_centroids 139 | 140 | 141 | def random_sampling(alist, num): 142 | """ 143 | Randomly sample num items from the list 144 | alist: list of centroids to sample from 145 | num: can be larger than the list and if so, then wrap around 146 | return: class uniform samples from the list 147 | """ 148 | sampling = [] 149 | len_list = len(alist) 150 | assert len_list, 'len_list is zero!' 151 | indices = np.arange(len_list) 152 | np.random.shuffle(indices) 153 | 154 | for i in range(num): 155 | item = alist[indices[i % len_list]] 156 | sampling.append(item) 157 | return sampling 158 | 159 | 160 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct): 161 | """ 162 | Generate an epochs-worth of crops using uniform sampling. Needs to be called every 163 | imgs: list of imgs 164 | centroids: 165 | num_classes: 166 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch ) 167 | """ 168 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct)) 169 | num_epoch = int(len(imgs)) 170 | 171 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch)) 172 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 173 | num_rand = num_epoch - num_per_class * num_classes 174 | # create random crops 175 | imgs_uniform = random_sampling(imgs, num_rand) 176 | 177 | # now add uniform sampling 178 | for class_id in range(num_classes): 179 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id])) 180 | logging.info(string_format) 181 | for class_id in range(num_classes): 182 | centroid_len = len(centroids[class_id]) 183 | if centroid_len == 0: 184 | pass 185 | else: 186 | class_centroids = random_sampling(centroids[class_id], num_per_class) 187 | imgs_uniform.extend(class_centroids) 188 | 189 | return imgs_uniform 190 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from PIL import Image 5 | import numpy as np 6 | import cv2 7 | 8 | import torch 9 | from torch.backends import cudnn 10 | import torchvision.transforms as transforms 11 | 12 | import network 13 | from optimizer import restore_snapshot 14 | from datasets import cityscapes 15 | from config import assert_and_infer_cfg 16 | 17 | parser = argparse.ArgumentParser(description='demo') 18 | parser.add_argument('--demo-image', type=str, default='', help='path to demo image', required=True) 19 | parser.add_argument('--snapshot', type=str, default='./pretrained_models/cityscapes_best.pth', help='pre-trained checkpoint', required=True) 20 | parser.add_argument('--arch', type=str, 
default='network.deepv3.DeepWV3Plus', help='network architecture used for inference') 21 | parser.add_argument('--save-dir', type=str, default='./save', help='path to save your results') 22 | args = parser.parse_args() 23 | assert_and_infer_cfg(args, train_mode=False) 24 | cudnn.benchmark = False 25 | torch.cuda.empty_cache() 26 | 27 | # get net 28 | args.dataset_cls = cityscapes 29 | net = network.get_net(args, criterion=None) 30 | net = torch.nn.DataParallel(net).cuda() 31 | print('Net built.') 32 | net, _ = restore_snapshot(net, optimizer=None, snapshot=args.snapshot, restore_optimizer_bool=False) 33 | net.eval() 34 | print('Net restored.') 35 | 36 | # get data 37 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 38 | img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*mean_std)]) 39 | img = Image.open(args.demo_image).convert('RGB') 40 | img_tensor = img_transform(img) 41 | 42 | # predict 43 | with torch.no_grad(): 44 | img = img_tensor.unsqueeze(0).cuda() 45 | pred = net(img) 46 | print('Inference done.') 47 | 48 | pred = pred.cpu().numpy().squeeze() 49 | pred = np.argmax(pred, axis=0) 50 | 51 | if not os.path.exists(args.save_dir): 52 | os.makedirs(args.save_dir) 53 | 54 | colorized = args.dataset_cls.colorize_mask(pred) 55 | colorized.save(os.path.join(args.save_dir, 'color_mask.png')) 56 | 57 | label_out = np.zeros_like(pred) 58 | for label_id, train_id in args.dataset_cls.id_to_trainid.items(): 59 | label_out[np.where(pred == train_id)] = label_id 60 | cv2.imwrite(os.path.join(args.save_dir, 'pred_mask.png'), label_out) 61 | print('Results saved.') 62 | -------------------------------------------------------------------------------- /demo_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import argparse 5 | from PIL import Image 6 | import numpy as np 7 | import cv2 8 | 9 | import torch 10 | from torch.backends import cudnn 11 | import torchvision.transforms as transforms 12 | 13 | import network 14 | from optimizer import restore_snapshot 15 | from datasets import cityscapes 16 | from config import assert_and_infer_cfg 17 | 18 | parser = argparse.ArgumentParser(description='demo') 19 | parser.add_argument('--demo-folder', type=str, default='', help='path to the folder containing demo images', required=True) 20 | parser.add_argument('--snapshot', type=str, default='./pretrained_models/cityscapes_best.pth', help='pre-trained checkpoint', required=True) 21 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepWV3Plus', help='network architecture used for inference') 22 | parser.add_argument('--save-dir', type=str, default='./save', help='path to save your results') 23 | args = parser.parse_args() 24 | assert_and_infer_cfg(args, train_mode=False) 25 | cudnn.benchmark = False 26 | torch.cuda.empty_cache() 27 | 28 | # get net 29 | args.dataset_cls = cityscapes 30 | net = network.get_net(args, criterion=None) 31 | net = torch.nn.DataParallel(net).cuda() 32 | print('Net built.') 33 | net, _ = restore_snapshot(net, optimizer=None, snapshot=args.snapshot, restore_optimizer_bool=False) 34 | net.eval() 35 | print('Net restored.') 36 | 37 | # get data 38 | data_dir = args.demo_folder 39 | images = os.listdir(data_dir) 40 | if len(images) == 0: 41 | print('There are no images at directory %s. Check the data path.' % (data_dir)) 42 | else: 43 | print('There are %d images to be processed.' 
% (len(images))) 44 | images.sort() 45 | 46 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 47 | img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*mean_std)]) 48 | if not os.path.exists(args.save_dir): 49 | os.makedirs(args.save_dir) 50 | 51 | start_time = time.time() 52 | for img_id, img_name in enumerate(images): 53 | img_dir = os.path.join(data_dir, img_name) 54 | img = Image.open(img_dir).convert('RGB') 55 | img_tensor = img_transform(img) 56 | 57 | # predict 58 | with torch.no_grad(): 59 | pred = net(img_tensor.unsqueeze(0).cuda()) 60 | print('%04d/%04d: Inference done.' % (img_id + 1, len(images))) 61 | 62 | pred = pred.cpu().numpy().squeeze() 63 | pred = np.argmax(pred, axis=0) 64 | 65 | color_name = 'color_mask_' + img_name 66 | overlap_name = 'overlap_' + img_name 67 | pred_name = 'pred_mask_' + img_name 68 | 69 | # save colorized predictions 70 | colorized = args.dataset_cls.colorize_mask(pred) 71 | colorized.save(os.path.join(args.save_dir, color_name)) 72 | 73 | # save colorized predictions overlapped on original images 74 | overlap = cv2.addWeighted(np.array(img), 0.5, np.array(colorized.convert('RGB')), 0.5, 0) 75 | cv2.imwrite(os.path.join(args.save_dir, overlap_name), overlap[:, :, ::-1]) 76 | 77 | # save label-based predictions, e.g. for submission purpose 78 | label_out = np.zeros_like(pred) 79 | for label_id, train_id in args.dataset_cls.id_to_trainid.items(): 80 | label_out[np.where(pred == train_id)] = label_id 81 | cv2.imwrite(os.path.join(args.save_dir, pred_name), label_out) 82 | end_time = time.time() 83 | 84 | print('Results saved.') 85 | print('Inference takes %4.2f seconds, which is %4.2f seconds per image, including saving results.' % (end_time - start_time, (end_time - start_time)/len(images))) 86 | -------------------------------------------------------------------------------- /images/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/images/method.png -------------------------------------------------------------------------------- /images/vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/images/vis.png -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss.py 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from config import cfg 11 | 12 | 13 | def get_loss(args): 14 | """ 15 | Get the criterion based on the loss function 16 | args: commandline arguments 17 | return: criterion, criterion_val 18 | """ 19 | 20 | if args.img_wt_loss: 21 | criterion = ImageBasedCrossEntropyLoss2d( 22 | classes=args.dataset_cls.num_classes, size_average=True, 23 | ignore_index=args.dataset_cls.ignore_label, 24 | upper_bound=args.wt_bound).cuda() 25 | elif args.jointwtborder: 26 | criterion = ImgWtLossSoftNLL(classes=args.dataset_cls.num_classes, 27 | ignore_index=args.dataset_cls.ignore_label, 28 | upper_bound=args.wt_bound).cuda() 29 | else: 30 | criterion = CrossEntropyLoss2d(size_average=True, 31 | ignore_index=args.dataset_cls.ignore_label).cuda() 32 | 33 | criterion_val = CrossEntropyLoss2d(size_average=True, 34 | weight=None, 
35 | ignore_index=args.dataset_cls.ignore_label).cuda() 36 | return criterion, criterion_val 37 | 38 | 39 | class ImageBasedCrossEntropyLoss2d(nn.Module): 40 | """ 41 | Image Weighted Cross Entropy Loss 42 | """ 43 | 44 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255, 45 | norm=False, upper_bound=1.0): 46 | super(ImageBasedCrossEntropyLoss2d, self).__init__() 47 | logging.info("Using Per Image based weighted loss") 48 | self.num_classes = classes 49 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 50 | self.norm = norm 51 | self.upper_bound = upper_bound 52 | self.batch_weights = cfg.BATCH_WEIGHTING 53 | 54 | def calculate_weights(self, target): 55 | """ 56 | Calculate weights of classes based on the training crop 57 | """ 58 | hist = np.histogram(target.flatten(), range( 59 | self.num_classes + 1), normed=True)[0] 60 | if self.norm: 61 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 62 | else: 63 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 64 | return hist 65 | 66 | def forward(self, inputs, targets): 67 | 68 | target_cpu = targets.data.cpu().numpy() 69 | if self.batch_weights: 70 | weights = self.calculate_weights(target_cpu) 71 | self.nll_loss.weight = torch.Tensor(weights).cuda() 72 | 73 | loss = 0.0 74 | for i in range(0, inputs.shape[0]): 75 | if not self.batch_weights: 76 | weights = self.calculate_weights(target_cpu[i]) 77 | self.nll_loss.weight = torch.Tensor(weights).cuda() 78 | 79 | loss += self.nll_loss(F.log_softmax(inputs[i].unsqueeze(0)), 80 | targets[i].unsqueeze(0)) 81 | return loss 82 | 83 | 84 | 85 | class CrossEntropyLoss2d(nn.Module): 86 | """ 87 | Cross Entroply NLL Loss 88 | """ 89 | 90 | def __init__(self, weight=None, size_average=True, ignore_index=255): 91 | super(CrossEntropyLoss2d, self).__init__() 92 | logging.info("Using Cross Entropy Loss") 93 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index) 94 | # self.weight = weight 95 | 96 | def forward(self, inputs, targets): 97 | return self.nll_loss(F.log_softmax(inputs), targets) 98 | 99 | def customsoftmax(inp, multihotmask): 100 | """ 101 | Custom Softmax 102 | """ 103 | soft = F.softmax(inp) 104 | # This takes the mask * softmax ( sums it up hence summing up the classes in border 105 | # then takes of summed up version vs no summed version 106 | return torch.log( 107 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True))) 108 | ) 109 | 110 | class ImgWtLossSoftNLL(nn.Module): 111 | """ 112 | Relax Loss 113 | """ 114 | 115 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 116 | norm=False): 117 | super(ImgWtLossSoftNLL, self).__init__() 118 | self.weights = weights 119 | self.num_classes = classes 120 | self.ignore_index = ignore_index 121 | self.upper_bound = upper_bound 122 | self.norm = norm 123 | self.batch_weights = cfg.BATCH_WEIGHTING 124 | self.fp16 = False 125 | 126 | 127 | def calculate_weights(self, target): 128 | """ 129 | Calculate weights of the classes based on training crop 130 | """ 131 | if len(target.shape) == 3: 132 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 133 | else: 134 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 135 | if self.norm: 136 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 137 | else: 138 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 139 | return hist[:-1] 140 | 141 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 142 | """ 143 | NLL Relaxed Loss 
Implementation 144 | """ 145 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 146 | border_weights = 1 / border_weights 147 | target[target > 1] = 1 148 | if self.fp16: 149 | loss_matrix = (-1 / border_weights * 150 | (target[:, :-1, :, :].half() * 151 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 152 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \ 153 | (1. - mask.half()) 154 | else: 155 | loss_matrix = (-1 / border_weights * 156 | (target[:, :-1, :, :].float() * 157 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 158 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 159 | (1. - mask.float()) 160 | 161 | # loss_matrix[border_weights > 1] = 0 162 | loss = loss_matrix.sum() 163 | 164 | # +1 to prevent division by 0 165 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 166 | return loss 167 | 168 | def forward(self, inputs, target): 169 | if self.fp16: 170 | weights = target[:, :-1, :, :].sum(1).half() 171 | else: 172 | weights = target[:, :-1, :, :].sum(1).float() 173 | ignore_mask = (weights == 0) 174 | weights[ignore_mask] = 1 175 | 176 | loss = 0 177 | target_cpu = target.data.cpu().numpy() 178 | 179 | if self.batch_weights: 180 | class_weights = self.calculate_weights(target_cpu) 181 | 182 | for i in range(0, inputs.shape[0]): 183 | if not self.batch_weights: 184 | class_weights = self.calculate_weights(target_cpu[i]) 185 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 186 | target[i].unsqueeze(0), 187 | class_weights=torch.Tensor(class_weights).cuda(), 188 | border_weights=weights, mask=ignore_mask[i]) 189 | 190 | return loss 191 | -------------------------------------------------------------------------------- /network/Resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | 36 | import torch.nn as nn 37 | import torch.utils.model_zoo as model_zoo 38 | import network.mynn as mynn 39 | 40 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 41 | 'resnet152'] 42 | 43 | 44 | model_urls = { 45 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 46 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 47 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 48 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 49 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 50 | } 51 | 52 | 53 | def conv3x3(in_planes, out_planes, stride=1): 54 | """3x3 convolution with padding""" 55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 56 | padding=1, bias=False) 57 | 58 | 59 | class BasicBlock(nn.Module): 60 | """ 61 | Basic Block for Resnet 62 | """ 63 | expansion = 1 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(BasicBlock, self).__init__() 67 | self.conv1 = conv3x3(inplanes, planes, stride) 68 | self.bn1 = mynn.Norm2d(planes) 69 | self.relu = nn.ReLU(inplace=True) 70 | self.conv2 = conv3x3(planes, planes) 71 | self.bn2 = mynn.Norm2d(planes) 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | residual = x 77 | 78 | out = self.conv1(x) 79 | out = self.bn1(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv2(out) 83 | out = self.bn2(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class Bottleneck(nn.Module): 95 | """ 96 | Bottleneck Layer for Resnet 97 | """ 98 | expansion = 4 99 | 100 | def __init__(self, inplanes, planes, stride=1, downsample=None): 101 | super(Bottleneck, self).__init__() 102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 103 | self.bn1 = mynn.Norm2d(planes) 104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 105 | padding=1, bias=False) 106 | self.bn2 = mynn.Norm2d(planes) 107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 108 | self.bn3 = mynn.Norm2d(planes * self.expansion) 109 | self.relu = nn.ReLU(inplace=True) 110 | self.downsample = downsample 111 | self.stride = stride 112 | 113 | def forward(self, x): 114 | residual = x 115 | 116 | out = self.conv1(x) 117 | out = self.bn1(out) 118 | out = self.relu(out) 119 | 120 | out = self.conv2(out) 121 | out = self.bn2(out) 122 | out = self.relu(out) 123 | 124 | out = self.conv3(out) 125 | out = self.bn3(out) 126 | 127 | if self.downsample is not None: 128 | residual = self.downsample(x) 129 | 130 | out += residual 131 | out = self.relu(out) 132 | 133 | return out 134 | 135 | 136 | class ResNet(nn.Module): 137 | """ 138 | Resnet Global Module for Initialization 139 | """ 140 | def __init__(self, block, 
layers, num_classes=1000): 141 | self.inplanes = 64 142 | super(ResNet, self).__init__() 143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 144 | bias=False) 145 | self.bn1 = mynn.Norm2d(64) 146 | self.relu = nn.ReLU(inplace=True) 147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 148 | self.layer1 = self._make_layer(block, 64, layers[0]) 149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 152 | self.avgpool = nn.AvgPool2d(7, stride=1) 153 | self.fc = nn.Linear(512 * block.expansion, num_classes) 154 | 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, nn.BatchNorm2d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | def _make_layer(self, block, planes, blocks, stride=1): 163 | downsample = None 164 | if stride != 1 or self.inplanes != planes * block.expansion: 165 | downsample = nn.Sequential( 166 | nn.Conv2d(self.inplanes, planes * block.expansion, 167 | kernel_size=1, stride=stride, bias=False), 168 | mynn.Norm2d(planes * block.expansion), 169 | ) 170 | 171 | layers = [] 172 | layers.append(block(self.inplanes, planes, stride, downsample)) 173 | self.inplanes = planes * block.expansion 174 | for index in range(1, blocks): 175 | layers.append(block(self.inplanes, planes)) 176 | 177 | return nn.Sequential(*layers) 178 | 179 | def forward(self, x): 180 | x = self.conv1(x) 181 | x = self.bn1(x) 182 | x = self.relu(x) 183 | x = self.maxpool(x) 184 | 185 | x = self.layer1(x) 186 | x = self.layer2(x) 187 | x = self.layer3(x) 188 | x = self.layer4(x) 189 | 190 | x = self.avgpool(x) 191 | x = x.view(x.size(0), -1) 192 | x = self.fc(x) 193 | 194 | return x 195 | 196 | 197 | def resnet18(pretrained=True, **kwargs): 198 | """Constructs a ResNet-18 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 204 | if pretrained: 205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 206 | return model 207 | 208 | 209 | def resnet34(pretrained=True, **kwargs): 210 | """Constructs a ResNet-34 model. 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 216 | if pretrained: 217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 218 | return model 219 | 220 | 221 | def resnet50(pretrained=True, **kwargs): 222 | """Constructs a ResNet-50 model. 223 | 224 | Args: 225 | pretrained (bool): If True, returns a model pre-trained on ImageNet 226 | """ 227 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 228 | if pretrained: 229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 230 | return model 231 | 232 | 233 | def resnet101(pretrained=True, **kwargs): 234 | """Constructs a ResNet-101 model. 235 | 236 | Args: 237 | pretrained (bool): If True, returns a model pre-trained on ImageNet 238 | """ 239 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 240 | if pretrained: 241 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 242 | return model 243 | 244 | 245 | def resnet152(pretrained=True, **kwargs): 246 | """Constructs a ResNet-152 model. 
247 | 248 | Args: 249 | pretrained (bool): If True, returns a model pre-trained on ImageNet 250 | """ 251 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 252 | if pretrained: 253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 254 | return model 255 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import logging 6 | import importlib 7 | import torch 8 | 9 | 10 | 11 | def get_net(args, criterion): 12 | """ 13 | Get Network Architecture based on arguments provided 14 | """ 15 | net = get_model(network=args.arch, num_classes=args.dataset_cls.num_classes, 16 | criterion=criterion) 17 | num_params = sum([param.nelement() for param in net.parameters()]) 18 | logging.info('Model params = {:2.1f}M'.format(num_params / 1000000)) 19 | 20 | net = net.cuda() 21 | return net 22 | 23 | 24 | def wrap_network_in_dataparallel(net, use_apex_data_parallel=False): 25 | """ 26 | Wrap the network in Dataparallel 27 | """ 28 | if use_apex_data_parallel: 29 | import apex 30 | net = apex.parallel.DistributedDataParallel(net) 31 | else: 32 | net = torch.nn.DataParallel(net) 33 | return net 34 | 35 | 36 | def get_model(network, num_classes, criterion): 37 | """ 38 | Fetch Network Function Pointer 39 | """ 40 | module = network[:network.rfind('.')] 41 | model = network[network.rfind('.') + 1:] 42 | mod = importlib.import_module(module) 43 | net_func = getattr(mod, model) 44 | net = net_func(num_classes=num_classes, criterion=criterion) 45 | return net 46 | -------------------------------------------------------------------------------- /network/deepv3.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code Adapted from: 3 | # https://github.com/sthalles/deeplab_v3 4 | # 5 | # MIT License 6 | # 7 | # Copyright (c) 2018 Thalles Santos Silva 8 | # 9 | # Permission is hereby granted, free of charge, to any person obtaining a copy 10 | # of this software and associated documentation files (the "Software"), to deal 11 | # in the Software without restriction, including without limitation the rights 12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | # copies of the Software, and to permit persons to whom the Software is 14 | # furnished to do so, subject to the following conditions: 15 | # 16 | # The above copyright notice and this permission notice shall be included in all 17 | # copies or substantial portions of the Software. 18 | # 19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | """ 26 | import logging 27 | import torch 28 | from torch import nn 29 | from network import SEresnext 30 | from network import Resnet 31 | from network.wider_resnet import wider_resnet38_a2 32 | from network.mynn import initialize_weights, Norm2d, Upsample 33 | 34 | 35 | class _AtrousSpatialPyramidPoolingModule(nn.Module): 36 | """ 37 | operations performed: 38 | 1x1 x depth 39 | 3x3 x depth dilation 6 40 | 3x3 x depth dilation 12 41 | 3x3 x depth dilation 18 42 | image pooling 43 | concatenate all together 44 | Final 1x1 conv 45 | """ 46 | 47 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)): 48 | super(_AtrousSpatialPyramidPoolingModule, self).__init__() 49 | 50 | if output_stride == 8: 51 | rates = [2 * r for r in rates] 52 | elif output_stride == 16: 53 | pass 54 | else: 55 | raise ValueError('output stride of {} not supported'.format(output_stride)) 56 | 57 | self.features = [] 58 | # 1x1 59 | self.features.append( 60 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 61 | Norm2d(reduction_dim), nn.ReLU(inplace=True))) 62 | # other rates 63 | for r in rates: 64 | self.features.append(nn.Sequential( 65 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3, 66 | dilation=r, padding=r, bias=False), 67 | Norm2d(reduction_dim), 68 | nn.ReLU(inplace=True) 69 | )) 70 | self.features = torch.nn.ModuleList(self.features) 71 | 72 | # img level features 73 | self.img_pooling = nn.AdaptiveAvgPool2d(1) 74 | self.img_conv = nn.Sequential( 75 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False), 76 | Norm2d(reduction_dim), nn.ReLU(inplace=True)) 77 | 78 | def forward(self, x): 79 | x_size = x.size() 80 | 81 | img_features = self.img_pooling(x) 82 | img_features = self.img_conv(img_features) 83 | img_features = Upsample(img_features, x_size[2:]) 84 | out = img_features 85 | 86 | for f in self.features: 87 | y = f(x) 88 | out = torch.cat((out, y), 1) 89 | return out 90 | 91 | 92 | class DeepV3Plus(nn.Module): 93 | """ 94 | Implement DeepLabV3 model 95 | A: stride8 96 | B: stride16 97 | with skip connections 98 | """ 99 | 100 | def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D', 101 | skip='m1', skip_num=48): 102 | super(DeepV3Plus, self).__init__() 103 | self.criterion = criterion 104 | self.variant = variant 105 | self.skip = skip 106 | self.skip_num = skip_num 107 | 108 | if trunk == 'seresnext-50': 109 | resnet = SEresnext.se_resnext50_32x4d() 110 | elif trunk == 'seresnext-101': 111 | resnet = SEresnext.se_resnext101_32x4d() 112 | elif trunk == 'resnet-50': 113 | resnet = Resnet.resnet50() 114 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 115 | elif trunk == 'resnet-101': 116 | resnet = Resnet.resnet101() 117 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool) 118 | else: 119 | raise ValueError("Not a valid network arch") 120 | 121 | self.layer0 = resnet.layer0 122 | self.layer1, self.layer2, self.layer3, self.layer4 = \ 123 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4 124 | 125 | if self.variant == 'D': 126 | for n, m in self.layer3.named_modules(): 127 | if 'conv2' in n: 128 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 129 | elif 
'downsample.0' in n: 130 | m.stride = (1, 1) 131 | for n, m in self.layer4.named_modules(): 132 | if 'conv2' in n: 133 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1) 134 | elif 'downsample.0' in n: 135 | m.stride = (1, 1) 136 | elif self.variant == 'D16': 137 | for n, m in self.layer4.named_modules(): 138 | if 'conv2' in n: 139 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1) 140 | elif 'downsample.0' in n: 141 | m.stride = (1, 1) 142 | else: 143 | print("Not using Dilation ") 144 | 145 | self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256, 146 | output_stride=8) 147 | 148 | if self.skip == 'm1': 149 | self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False) 150 | elif self.skip == 'm2': 151 | self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False) 152 | else: 153 | raise Exception('Not a valid skip') 154 | 155 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False) 156 | 157 | self.final = nn.Sequential( 158 | nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False), 159 | Norm2d(256), 160 | nn.ReLU(inplace=True), 161 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 162 | Norm2d(256), 163 | nn.ReLU(inplace=True), 164 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 165 | 166 | initialize_weights(self.aspp) 167 | initialize_weights(self.bot_aspp) 168 | initialize_weights(self.bot_fine) 169 | initialize_weights(self.final) 170 | 171 | def forward(self, x, gts=None): 172 | 173 | x_size = x.size() # 800 174 | x0 = self.layer0(x) # 400 175 | x1 = self.layer1(x0) # 400 176 | x2 = self.layer2(x1) # 100 177 | x3 = self.layer3(x2) # 100 178 | x4 = self.layer4(x3) # 100 179 | xp = self.aspp(x4) 180 | 181 | dec0_up = self.bot_aspp(xp) 182 | if self.skip == 'm1': 183 | dec0_fine = self.bot_fine(x1) 184 | dec0_up = Upsample(dec0_up, x1.size()[2:]) 185 | else: 186 | dec0_fine = self.bot_fine(x2) 187 | dec0_up = Upsample(dec0_up, x2.size()[2:]) 188 | 189 | dec0 = [dec0_fine, dec0_up] 190 | dec0 = torch.cat(dec0, 1) 191 | dec1 = self.final(dec0) 192 | main_out = Upsample(dec1, x_size[2:]) 193 | 194 | if self.training: 195 | return self.criterion(main_out, gts) 196 | 197 | return main_out 198 | 199 | 200 | class DeepWV3Plus(nn.Module): 201 | """ 202 | WideResNet38 version of DeepLabV3 203 | mod1 204 | pool2 205 | mod2 bot_fine 206 | pool3 207 | mod3-7 208 | bot_aspp 209 | 210 | structure: [3, 3, 6, 3, 1, 1] 211 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 212 | (1024, 2048, 4096)] 213 | """ 214 | 215 | def __init__(self, num_classes, trunk='WideResnet38', criterion=None): 216 | 217 | super(DeepWV3Plus, self).__init__() 218 | self.criterion = criterion 219 | logging.info("Trunk: %s", trunk) 220 | 221 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True) 222 | wide_resnet = torch.nn.DataParallel(wide_resnet) 223 | if criterion is not None: 224 | try: 225 | checkpoint = torch.load('./pretrained_models/wider_resnet38.pth.tar', map_location='cpu') 226 | wide_resnet.load_state_dict(checkpoint['state_dict']) 227 | del checkpoint 228 | except: 229 | print("Please download the ImageNet weights of WideResNet38 in our repo to ./pretrained_models/wider_resnet38.pth.tar.") 230 | raise RuntimeError("=====================Could not load ImageNet weights of WideResNet38 network.=======================") 231 | wide_resnet = wide_resnet.module 232 | 233 | self.mod1 = wide_resnet.mod1 234 | self.mod2 = wide_resnet.mod2 235 | self.mod3 = wide_resnet.mod3 236 | self.mod4 = 
wide_resnet.mod4 237 | self.mod5 = wide_resnet.mod5 238 | self.mod6 = wide_resnet.mod6 239 | self.mod7 = wide_resnet.mod7 240 | self.pool2 = wide_resnet.pool2 241 | self.pool3 = wide_resnet.pool3 242 | del wide_resnet 243 | 244 | self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256, 245 | output_stride=8) 246 | 247 | self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False) 248 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False) 249 | 250 | self.final = nn.Sequential( 251 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False), 252 | Norm2d(256), 253 | nn.ReLU(inplace=True), 254 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False), 255 | Norm2d(256), 256 | nn.ReLU(inplace=True), 257 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False)) 258 | 259 | initialize_weights(self.final) 260 | 261 | def forward(self, inp, gts=None): 262 | 263 | x_size = inp.size() 264 | x = self.mod1(inp) 265 | m2 = self.mod2(self.pool2(x)) 266 | x = self.mod3(self.pool3(m2)) 267 | x = self.mod4(x) 268 | x = self.mod5(x) 269 | x = self.mod6(x) 270 | x = self.mod7(x) 271 | x = self.aspp(x) 272 | dec0_up = self.bot_aspp(x) 273 | 274 | dec0_fine = self.bot_fine(m2) 275 | dec0_up = Upsample(dec0_up, m2.size()[2:]) 276 | dec0 = [dec0_fine, dec0_up] 277 | dec0 = torch.cat(dec0, 1) 278 | 279 | dec1 = self.final(dec0) 280 | out = Upsample(dec1, x_size[2:]) 281 | 282 | if self.training: 283 | return self.criterion(out, gts) 284 | 285 | return out 286 | 287 | 288 | def DeepSRNX50V3PlusD_m1(num_classes, criterion): 289 | """ 290 | SEResNeXt-50 Based Network 291 | """ 292 | return DeepV3Plus(num_classes, trunk='seresnext-50', criterion=criterion, variant='D', 293 | skip='m1') 294 | 295 | def DeepR50V3PlusD_m1(num_classes, criterion): 296 | """ 297 | ResNet-50 Based Network 298 | """ 299 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion, variant='D', skip='m1') 300 | 301 | 302 | def DeepSRNX101V3PlusD_m1(num_classes, criterion): 303 | """ 304 | SEResNeXt-101 Based Network 305 | """ 306 | return DeepV3Plus(num_classes, trunk='seresnext-101', criterion=criterion, variant='D', 307 | skip='m1') 308 | 309 | -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch.nn as nn 5 | from config import cfg 6 | 7 | from apex import amp 8 | 9 | def Norm2d(in_channels): 10 | """ 11 | Custom Norm Function to allow flexible switching 12 | """ 13 | layer = getattr(cfg.MODEL, 'BNFUNC') 14 | normalization_layer = layer(in_channels) 15 | return normalization_layer 16 | 17 | 18 | def initialize_weights(*models): 19 | """ 20 | Initialize Model Weights 21 | """ 22 | for model in models: 23 | for module in model.modules(): 24 | if isinstance(module, (nn.Conv2d, nn.Linear)): 25 | nn.init.kaiming_normal_(module.weight) 26 | if module.bias is not None: 27 | module.bias.data.zero_() 28 | elif isinstance(module, nn.BatchNorm2d): 29 | module.weight.data.fill_(1) 30 | module.bias.data.zero_() 31 | 32 | 33 | @amp.float_function 34 | def Upsample(x, size): 35 | """ 36 | Wrapper Around the Upsample Call 37 | """ 38 | return nn.functional.interpolate(x, size=size, mode='bilinear', 39 | align_corners=True) 40 | -------------------------------------------------------------------------------- /optimizer.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Pytorch Optimizer and Scheduler Related Task 3 | """ 4 | import math 5 | import logging 6 | import torch 7 | from torch import optim 8 | from config import cfg 9 | 10 | 11 | def get_optimizer(args, net): 12 | """ 13 | Decide Optimizer (Adam or SGD) 14 | """ 15 | param_groups = net.parameters() 16 | 17 | if args.sgd: 18 | optimizer = optim.SGD(param_groups, 19 | lr=args.lr, 20 | weight_decay=args.weight_decay, 21 | momentum=args.momentum, 22 | nesterov=False) 23 | elif args.adam: 24 | amsgrad = False 25 | if args.amsgrad: 26 | amsgrad = True 27 | optimizer = optim.Adam(param_groups, 28 | lr=args.lr, 29 | weight_decay=args.weight_decay, 30 | amsgrad=amsgrad 31 | ) 32 | else: 33 | raise ValueError('Not a valid optimizer') 34 | 35 | if args.lr_schedule == 'scl-poly': 36 | if cfg.REDUCE_BORDER_EPOCH == -1: 37 | raise ValueError('ERROR Cannot Do Scale Poly') 38 | 39 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH 40 | scale_value = args.rescale 41 | lambda1 = lambda epoch: \ 42 | math.pow(1 - epoch / args.max_epoch, 43 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow( 44 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh), 45 | args.repoly) 46 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 47 | elif args.lr_schedule == 'poly': 48 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp) 49 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 50 | else: 51 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 52 | 53 | return optimizer, scheduler 54 | 55 | 56 | def load_weights(net, optimizer, snapshot_file, restore_optimizer_bool=False): 57 | """ 58 | Load weights from snapshot file 59 | """ 60 | logging.info("Loading weights from model %s", snapshot_file) 61 | net, optimizer = restore_snapshot(net, optimizer, snapshot_file, restore_optimizer_bool) 62 | return net, optimizer 63 | 64 | 65 | def restore_snapshot(net, optimizer, snapshot, restore_optimizer_bool): 66 | """ 67 | Restore weights and optimizer (if needed) for resuming job. 68 | """ 69 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 70 | logging.info("Checkpoint Load Complete") 71 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 72 | optimizer.load_state_dict(checkpoint['optimizer']) 73 | 74 | if 'state_dict' in checkpoint: 75 | net = forgiving_state_restore(net, checkpoint['state_dict']) 76 | else: 77 | net = forgiving_state_restore(net, checkpoint) 78 | 79 | return net, optimizer 80 | 81 | 82 | def forgiving_state_restore(net, loaded_dict): 83 | """ 84 | Handle partial loading when some tensors don't match up in size. 85 | Because we want to use models that were trained off a different 86 | number of classes. 
87 | """ 88 | net_state_dict = net.state_dict() 89 | new_loaded_dict = {} 90 | for k in net_state_dict: 91 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 92 | new_loaded_dict[k] = loaded_dict[k] 93 | else: 94 | logging.info("Skipped loading parameter %s", k) 95 | net_state_dict.update(new_loaded_dict) 96 | net.load_state_dict(net_state_dict) 97 | return net 98 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes_SEResNeXt50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \ 7 | --inference_mode sliding \ 8 | --scales 1.0 \ 9 | --split val \ 10 | --cv_split 0 \ 11 | --dump_images \ 12 | --ckpt_path ${2} \ 13 | --snapshot ${1} 14 | -------------------------------------------------------------------------------- /scripts/eval_cityscapes_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepWV3Plus \ 7 | --inference_mode sliding \ 8 | --scales 1.0 \ 9 | --split val \ 10 | --cv_split 0 \ 11 | --dump_images \ 12 | --ckpt_path ${2} \ 13 | --snapshot ${1} 14 | -------------------------------------------------------------------------------- /scripts/submit_cityscapes_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \ 5 | --dataset cityscapes \ 6 | --arch network.deepv3.DeepWV3Plus \ 7 | --inference_mode sliding \ 8 | --scales 0.5,1.0,2.0 \ 9 | --split test \ 10 | --cv_split 0 \ 11 | --dump_images \ 12 | --ckpt_path ${2} \ 13 | --snapshot ${1} 14 | -------------------------------------------------------------------------------- /scripts/test_kitti_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \ 5 | --dataset kitti \ 6 | --arch network.deepv3.DeepWV3Plus \ 7 | --mode semantic \ 8 | --split test \ 9 | --inference_mode sliding \ 10 | --cv_split 0 \ 11 | --scales 1.5,2.0,2.5 \ 12 | --crop_size 368 \ 13 | --dump_images \ 14 | --snapshot ${1} \ 15 | --ckpt_path ${2} 16 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_SEResNeXt50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Example on Cityscapes 4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \ 5 | --dataset cityscapes \ 6 | --cv 0 \ 7 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \ 8 | --snapshot ./pretrained_models/YOUR_TRAINED_MAPILLARY_MODEL \ 9 | --class_uniform_pct 0.5 \ 10 | --class_uniform_tile 1024 \ 11 | --max_cu_epoch 150 \ 12 | --lr 0.001 \ 13 | --lr_schedule scl-poly \ 14 | --poly_exp 1.0 \ 15 | --repoly 1.5 \ 16 | --rescale 1.0 \ 17 | --syncbn \ 18 | --sgd \ 19 | --crop_size 896 \ 20 | --scale_min 0.5 \ 21 | --scale_max 2.0 \ 22 | --color_aug 
0.25 \ 23 | --gblur \ 24 | --max_epoch 175 \ 25 | --coarse_boost_classes 14,15,16,3,12,17,4 \ 26 | --jointwtborder \ 27 | --strict_bdr_cls 5,6,7,11,12,17,18 \ 28 | --rlx_off_epoch 100 \ 29 | --wt_bound 1.0 \ 30 | --bs_mult 2 \ 31 | --apex \ 32 | --exp cityscapes_ft \ 33 | --ckpt ./logs/ \ 34 | --tb_path ./logs/ 35 | -------------------------------------------------------------------------------- /scripts/train_cityscapes_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Example on Cityscapes 4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \ 5 | --dataset cityscapes \ 6 | --cv 2 \ 7 | --arch network.deepv3.DeepWV3Plus \ 8 | --snapshot ./pretrained_models/YOUR_TRAINED_MAPILLARY_MODEL \ 9 | --class_uniform_pct 0.5 \ 10 | --class_uniform_tile 1024 \ 11 | --max_cu_epoch 150 \ 12 | --lr 0.001 \ 13 | --lr_schedule scl-poly \ 14 | --poly_exp 1.0 \ 15 | --repoly 1.5 \ 16 | --rescale 1.0 \ 17 | --syncbn \ 18 | --sgd \ 19 | --crop_size 896 \ 20 | --scale_min 0.5 \ 21 | --scale_max 2.0 \ 22 | --color_aug 0.25 \ 23 | --gblur \ 24 | --max_epoch 175 \ 25 | --coarse_boost_classes 14,15,16,3,12,17,4 \ 26 | --jointwtborder \ 27 | --strict_bdr_cls 5,6,7,11,12,17,18 \ 28 | --rlx_off_epoch 100 \ 29 | --wt_bound 1.0 \ 30 | --bs_mult 2 \ 31 | --apex \ 32 | --exp cityscapes_ft \ 33 | --ckpt ./logs/ \ 34 | --tb_path ./logs/ 35 | -------------------------------------------------------------------------------- /scripts/train_comma10k_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 train.py \ 4 | --dataset comma10k \ 5 | --cv 2 \ 6 | --arch network.deepv3.DeepWV3Plus \ 7 | --class_uniform_pct 0.0 \ 8 | --class_uniform_tile 300 \ 9 | --lr 2e-2 \ 10 | --lr_schedule poly \ 11 | --poly_exp 1.0 \ 12 | --adam \ 13 | --crop_size 360 \ 14 | --scale_min 1.0 \ 15 | --scale_max 2.0 \ 16 | --color_aug 0.25 \ 17 | --max_epoch 90 \ 18 | --img_wt_loss \ 19 | --wt_bound 1.0 \ 20 | --bs_mult 2 \ 21 | --exp comma10k \ 22 | --ckpt ./logs/ \ 23 | --tb_path ./logs/ 24 | 25 | #--lr 0.001 \ 26 | #--snapshot ./pretrained_models/cityscapes_best.pth \ 27 | #--apex \ 28 | #--syncbn \ 29 | -------------------------------------------------------------------------------- /scripts/train_kitti_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Example on KITTI, fine tune 4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \ 5 | --dataset kitti \ 6 | --cv 2 \ 7 | --arch network.deepv3.DeepWV3Plus \ 8 | --snapshot ./pretrained_models/cityscapes_best.pth \ 9 | --class_uniform_pct 0.5 \ 10 | --class_uniform_tile 300 \ 11 | --lr 0.001 \ 12 | --lr_schedule poly \ 13 | --poly_exp 1.0 \ 14 | --syncbn \ 15 | --sgd \ 16 | --crop_size 360 \ 17 | --scale_min 1.0 \ 18 | --scale_max 2.0 \ 19 | --color_aug 0.25 \ 20 | --max_epoch 90 \ 21 | --img_wt_loss \ 22 | --wt_bound 1.0 \ 23 | --bs_mult 2 \ 24 | --apex \ 25 | --exp kitti_ft \ 26 | --ckpt ./logs/ \ 27 | --tb_path ./logs/ 28 | 29 | 30 | -------------------------------------------------------------------------------- /scripts/train_mapillary_SEResNeXt50.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Example on Mapillary 4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \ 5 | 
--dataset mapillary \ 6 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \ 7 | --class_uniform_pct 0.5 \ 8 | --class_uniform_tile 1024 \ 9 | --syncbn \ 10 | --sgd \ 11 | --lr 2e-2 \ 12 | --lr_schedule poly \ 13 | --poly_exp 1.0 \ 14 | --crop_size 896 \ 15 | --scale_min 0.5 \ 16 | --scale_max 2.0 \ 17 | --color_aug 0.25 \ 18 | --gblur \ 19 | --max_epoch 175 \ 20 | --img_wt_loss \ 21 | --wt_bound 6.0 \ 22 | --bs_mult 2 \ 23 | --apex \ 24 | --exp mapillary_pretrain \ 25 | --ckpt ./logs/ \ 26 | --tb_path ./logs/ 27 | -------------------------------------------------------------------------------- /scripts/train_mapillary_WideResNet38.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Example on Mapillary 4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \ 5 | --dataset mapillary \ 6 | --arch network.deepv3.DeepWV3Plus \ 7 | --class_uniform_pct 0.5 \ 8 | --class_uniform_tile 1024 \ 9 | --syncbn \ 10 | --sgd \ 11 | --lr 2e-2 \ 12 | --lr_schedule poly \ 13 | --poly_exp 1.0 \ 14 | --crop_size 896 \ 15 | --scale_min 0.5 \ 16 | --scale_max 2.0 \ 17 | --color_aug 0.25 \ 18 | --gblur \ 19 | --max_epoch 175 \ 20 | --img_wt_loss \ 21 | --wt_bound 6.0 \ 22 | --bs_mult 2 \ 23 | --apex \ 24 | --exp mapillary_pretrain \ 25 | --ckpt ./logs/ \ 26 | --tb_path ./logs/ 27 | -------------------------------------------------------------------------------- /sdcnet/_aug.sh: -------------------------------------------------------------------------------- 1 | python sdc_aug.py --propagate 5 --vis \ 2 | --pretrained ../pretrained_models/sdc_cityscapes_vrec.pth.tar \ 3 | --flownet2_checkpoint ../pretrained_models/FlowNet2_checkpoint.pth.tar \ 4 | --source_dir /home/ubuntu/yizhu/data/Cityscapes \ 5 | --target_dir /home/ubuntu/yizhu/data/Cityscapes/cs_aug 6 | 7 | 8 | 9 | 10 | -------------------------------------------------------------------------------- /sdcnet/_eval.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # Run SDC2DRecon on Cityscapes dataset 3 | 4 | # Root folder of cityscapes images 5 | VAL_FILE=~/data/tmp/tinycs 6 | SDC2DREC_CHECKPOINT=../pretrained_models/sdc_cityscapes_vrec.pth.tar 7 | FLOWNET2_CHECKPOINT=../pretrained_models/FlowNet2_checkpoint.pth.tar 8 | 9 | python3 main.py \ 10 | --eval \ 11 | --sequence_length 2 \ 12 | --save ./ \ 13 | --name __evalrun \ 14 | --val_n_batches 1 \ 15 | --write_images \ 16 | --dataset FrameLoader \ 17 | --model SDCNet2DRecon \ 18 | --val_file ${VAL_FILE} \ 19 | --resume ${SDC2DREC_CHECKPOINT} \ 20 | --flownet2_checkpoint ${FLOWNET2_CHECKPOINT} 21 | 22 | -------------------------------------------------------------------------------- /sdcnet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .frame_loader import * -------------------------------------------------------------------------------- /sdcnet/datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import torch 5 | 6 | class StaticRandomCrop(object): 7 | """ 8 | Helper function for random spatial crop 9 | """ 10 | def __init__(self, size, image_shape): 11 | h, w = image_shape 12 | self.th, self.tw = size 13 | self.h1 = torch.randint(0, h - self.th + 1, (1,)).item() 14 | self.w1 = torch.randint(0, w - self.tw + 1, (1,)).item() 15 | 16 | def __call__(self, img): 17 | return 
img[self.h1:(self.h1 + self.th), self.w1:(self.w1 + self.tw), :] 18 | -------------------------------------------------------------------------------- /sdcnet/datasets/frame_loader.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import os 5 | import natsort 6 | import numpy as np 7 | import cv2 8 | 9 | 10 | import torch 11 | from torch.utils import data 12 | from datasets.dataset_utils import StaticRandomCrop 13 | 14 | class FrameLoader(data.Dataset): 15 | def __init__(self, args, root, is_training = False, transform=None): 16 | 17 | self.is_training = is_training 18 | self.transform = transform 19 | self.chsize = 3 20 | 21 | # carry over command line arguments 22 | assert args.sequence_length > 1, 'sequence length must be > 1' 23 | self.sequence_length = args.sequence_length 24 | 25 | assert args.sample_rate > 0, 'sample rate must be > 0' 26 | self.sample_rate = args.sample_rate 27 | 28 | self.crop_size = args.crop_size 29 | self.start_index = args.start_index 30 | self.stride = args.stride 31 | 32 | assert (os.path.exists(root)) 33 | if self.is_training: 34 | self.start_index = 0 35 | 36 | # collect, colors, motion vectors, and depth 37 | self.ref = self.collect_filelist(root) 38 | 39 | counts = [((len(el) - self.sequence_length) // (self.sample_rate)) for el in self.ref] 40 | self.total = np.sum(counts) 41 | self.cum_sum = list(np.cumsum([0] + [el for el in counts])) 42 | 43 | def collect_filelist(self, root): 44 | include_ext = [".png", ".jpg", "jpeg", ".bmp"] 45 | # collect subfolders, excluding hidden files, but following symlinks 46 | dirs = [x[0] for x in os.walk(root, followlinks=True) if not x[0].startswith('.')] 47 | 48 | # naturally sort, both dirs and individual images, while skipping hidden files 49 | dirs = natsort.natsorted(dirs) 50 | 51 | datasets = [ 52 | [os.path.join(fdir, el) for el in natsort.natsorted(os.listdir(fdir)) 53 | if os.path.isfile(os.path.join(fdir, el)) 54 | and not el.startswith('.') 55 | and any([el.endswith(ext) for ext in include_ext])] 56 | for fdir in dirs 57 | ] 58 | 59 | return [el for el in datasets if el] 60 | 61 | def __len__(self): 62 | return self.total 63 | 64 | def __getitem__(self, index): 65 | # adjust index 66 | index = len(self) + index if index < 0 else index 67 | index = index + self.start_index 68 | 69 | dataset_index = np.searchsorted(self.cum_sum, index + 1) 70 | index = self.sample_rate * (index - self.cum_sum[np.maximum(0, dataset_index - 1)]) 71 | 72 | image_list = self.ref[dataset_index - 1] 73 | input_files = [ image_list[index + offset] for offset in range(self.sequence_length + 1)] 74 | 75 | # reverse image order with p=0.5 76 | if self.is_training and torch.randint(0, 2, (1,)).item(): 77 | input_files = input_files[::-1] 78 | 79 | # images = [imageio.imread(imfile)[..., :self.chsize] for imfile in input_files] 80 | images = [cv2.imread(imfile)[..., :self.chsize] for imfile in input_files] 81 | input_shape = images[0].shape[:2] 82 | if self.is_training: 83 | cropper = StaticRandomCrop(self.crop_size, input_shape) 84 | images = map(cropper, images) 85 | 86 | # Pad images along height and width to fit them evenly into models. 
87 | height, width = input_shape 88 | if (height % self.stride) != 0: 89 | padded_height = (height // self.stride + 1) * self.stride 90 | images = [ np.pad(im, ((0, padded_height - height), (0,0), (0,0)), 'reflect') for im in images] 91 | 92 | if (width % self.stride) != 0: 93 | padded_width = (width // self.stride + 1) * self.stride 94 | images = [np.pad(im, ((0, 0), (0, padded_width - width), (0, 0)), 'reflect') for im in images] 95 | 96 | input_images = [torch.from_numpy(im.transpose(2, 0, 1)).float() for im in images] 97 | 98 | output_dict = { 99 | 'image': input_images, 'ishape': input_shape, 'input_files': input_files 100 | } 101 | 102 | return output_dict -------------------------------------------------------------------------------- /sdcnet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .sdc_net2d import * -------------------------------------------------------------------------------- /sdcnet/models/model_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import torch.nn as nn 5 | 6 | def conv2d(channels_in, channels_out, kernel_size=3, stride=1, bias = True): 7 | return nn.Sequential( 8 | nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 9 | nn.LeakyReLU(0.1,inplace=True) 10 | ) 11 | 12 | def deconv2d(channels_in, channels_out, kernel_size=4, stride=2, padding=1, bias=True): 13 | return nn.Sequential( 14 | nn.ConvTranspose2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), 15 | nn.LeakyReLU(0.1,inplace=True) 16 | ) -------------------------------------------------------------------------------- /sdcnet/models/sdc_net2d.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code are adapted from: 3 | https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetS.py 4 | https://github.com/ClementPinard/FlowNetPytorch/blob/master/models/FlowNetS.py 5 | ''' 6 | from __future__ import division 7 | from __future__ import print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import init 12 | import os 13 | 14 | from models.model_utils import conv2d, deconv2d 15 | 16 | from flownet2_pytorch.models import FlowNet2 17 | from flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d 18 | 19 | 20 | class SDCNet2D(nn.Module): 21 | def __init__(self, args): 22 | super(SDCNet2D,self).__init__() 23 | 24 | self.rgb_max = args.rgb_max 25 | self.sequence_length = args.sequence_length 26 | 27 | factor = 2 28 | input_channels = self.sequence_length * 3 + (self.sequence_length - 1) * 2 29 | 30 | self.conv1 = conv2d(input_channels, 64 // factor, kernel_size=7, stride=2) 31 | self.conv2 = conv2d(64 // factor, 128 // factor, kernel_size=5, stride=2) 32 | self.conv3 = conv2d(128 // factor, 256 // factor, kernel_size=5, stride=2) 33 | self.conv3_1 = conv2d(256 // factor, 256 // factor) 34 | self.conv4 = conv2d(256 // factor, 512 // factor, stride=2) 35 | self.conv4_1 = conv2d(512 // factor, 512 // factor) 36 | self.conv5 = conv2d(512 // factor, 512 // factor, stride=2) 37 | self.conv5_1 = conv2d(512 // factor, 512 // factor) 38 | self.conv6 = conv2d(512 // factor, 1024 // factor, stride=2) 39 | self.conv6_1 = conv2d(1024 // factor, 1024 // factor) 40 | 41 | self.deconv5 = deconv2d(1024 // factor, 
512 // factor) 42 | self.deconv4 = deconv2d(1024 // factor, 256 // factor) 43 | self.deconv3 = deconv2d(768 // factor, 128 // factor) 44 | self.deconv2 = deconv2d(384 // factor, 64 // factor) 45 | self.deconv1 = deconv2d(192 // factor, 32 // factor) 46 | self.deconv0 = deconv2d(96 // factor, 16 // factor) 47 | 48 | self.final_flow = nn.Conv2d(input_channels + 16 // factor, 2, 49 | kernel_size=3, stride=1, padding=1, bias=True) 50 | 51 | 52 | # init parameters, when doing convtranspose3d, do bilinear init 53 | for m in self.modules(): 54 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose3d): 55 | if m.bias is not None: 56 | init.uniform_(m.bias) 57 | init.xavier_uniform_(m.weight) 58 | 59 | self.flownet2 = FlowNet2(args, batchNorm=False) 60 | assert os.path.exists(args.flownet2_checkpoint), "flownet2 checkpoint must be provided." 61 | flownet2_checkpoint = torch.load(args.flownet2_checkpoint) 62 | self.flownet2.load_state_dict(flownet2_checkpoint['state_dict'], strict=False) 63 | 64 | for param in self.flownet2.parameters(): 65 | param.requires_grad = False 66 | 67 | self.warp_nn = Resample2d(bilinear=False) 68 | self.warp_bilinear = Resample2d(bilinear=True) 69 | 70 | self.L1Loss = nn.L1Loss() 71 | 72 | flow_mean = torch.FloatTensor([-0.94427323, -1.23077035]).view(1, 2, 1, 1) 73 | flow_std = torch.FloatTensor([13.77204132, 7.47463894]).view(1, 2, 1, 1) 74 | rgb_mean = torch.FloatTensor([106.7747911, 96.13649598, 76.61428884]).view(1, 3, 1, 1) 75 | 76 | self.register_buffer('flow_mean', flow_mean) 77 | self.register_buffer('flow_std', flow_std) 78 | self.register_buffer('rgb_mean', rgb_mean) 79 | 80 | self.ignore_keys = ['flownet2'] 81 | return 82 | 83 | def interframe_optical_flow(self, input_images): 84 | 85 | #FIXME: flownet2 implementation expects RGB images, 86 | # while image formats for both SDCNet and DeepLabV3 expect BGR. 87 | # input_images = [torch.flip(input_image, dims=[1]) for input_image in input_images] 88 | 89 | # Create image pairs for flownet, then merge batch and frame dimension 90 | # so that only a single call to flownet2 is needed. 
91 | flownet2_inputs = torch.stack( 92 | [torch.cat([input_images[i + 1].unsqueeze(2), input_images[i].unsqueeze(2)], dim=2) for i in 93 | range(0, self.sequence_length - 1)], dim=0).contiguous() 94 | 95 | batch_size, channel_count, height, width = input_images[0].shape 96 | flownet2_inputs_flattened = flownet2_inputs.view(-1,channel_count, 2, height, width) 97 | flownet2_outputs = [self.flownet2(flownet2_input) for flownet2_input in 98 | torch.chunk(flownet2_inputs_flattened, self.sequence_length - 1)] 99 | 100 | #FIXME: flip images back to BGR, 101 | # input_images = [torch.flip(input_image, dims=[1]) for input_image in input_images] 102 | 103 | return flownet2_outputs 104 | 105 | def network_output(self, input_images, input_flows): 106 | 107 | # Normalize input flows 108 | input_flows = [(input_flow - self.flow_mean) / (3 * self.flow_std) for 109 | input_flow in input_flows] 110 | 111 | # Normalize input via flownet2-type normalisation 112 | concated_images = torch.cat([image.unsqueeze(2) for image in input_images], dim=2).contiguous() 113 | rgb_mean = concated_images.view(concated_images.size()[:2] + (-1,)).mean(dim=-1).view( 114 | concated_images.size()[:2] + 2 * (1,)) 115 | input_images = [(input_image - rgb_mean) / self.rgb_max for input_image in input_images] 116 | bsize, channels, height, width = input_flows[0].shape 117 | 118 | # Atypical concatenation of input images along channels (done for compatibility with pre-trained models) 119 | # for two rgb images, concated channels would appear as (r1r2g1g2b1b2), 120 | # instead of typical (r1g1b1r2g2b2) that can be obtained by torch.cat(..,dim=1) 121 | input_images = torch.cat([input_image.unsqueeze(2) for input_image in input_images], dim=2) 122 | input_images = input_images.contiguous().view(bsize, -1, height, width) 123 | 124 | # same atypical concatenation done for input flows. 
125 | input_flows = torch.cat([input_flow.unsqueeze(2) for input_flow in input_flows], dim=2) 126 | input_flows = input_flows.contiguous().view(bsize, -1, height, width) 127 | 128 | # Network input 129 | images_and_flows = torch.cat((input_flows, input_images), dim=1) 130 | 131 | out_conv1 = self.conv1(images_and_flows) 132 | 133 | out_conv2 = self.conv2(out_conv1) 134 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 135 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 136 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 137 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 138 | 139 | out_deconv5 = self.deconv5(out_conv6) 140 | concat5 = torch.cat((out_conv5, out_deconv5), 1) 141 | 142 | out_deconv4 = self.deconv4(concat5) 143 | concat4 = torch.cat((out_conv4, out_deconv4), 1) 144 | 145 | out_deconv3 = self.deconv3(concat4) 146 | concat3 = torch.cat((out_conv3, out_deconv3), 1) 147 | 148 | out_deconv2 = self.deconv2(concat3) 149 | concat2 = torch.cat((out_conv2, out_deconv2), 1) 150 | 151 | out_deconv1 = self.deconv1(concat2) 152 | concat1 = torch.cat((out_conv1, out_deconv1), 1) 153 | 154 | out_deconv0 = self.deconv0(concat1) 155 | 156 | concat0 = torch.cat((images_and_flows, out_deconv0), 1) 157 | output_flow = self.final_flow(concat0) 158 | 159 | flow_prediction = 3 * self.flow_std * output_flow + self.flow_mean 160 | 161 | return flow_prediction 162 | 163 | def prepare_inputs(self, input_dict): 164 | images = input_dict['image'] # expects a list 165 | 166 | input_images = images[:-1] 167 | 168 | target_image = images[-1] 169 | 170 | last_image = (input_images[-1]).clone() 171 | 172 | return input_images, last_image, target_image 173 | 174 | def forward(self, input_dict, label_image=None): 175 | 176 | input_images, last_image, target_image = self.prepare_inputs(input_dict) 177 | 178 | input_flows = self.interframe_optical_flow(input_images) 179 | 180 | flow_prediction = self.network_output(input_images, input_flows) 181 | 182 | image_prediction = self.warp_bilinear(last_image, flow_prediction) 183 | 184 | if label_image is not None: 185 | label_prediction = self.warp_nn(label_image, flow_prediction) 186 | 187 | # calculate losses 188 | losses = {} 189 | 190 | losses['color'] = self.L1Loss(image_prediction/self.rgb_max, target_image/self.rgb_max) 191 | 192 | losses['color_gradient'] = self.L1Loss(torch.abs(image_prediction[...,1:] - image_prediction[...,:-1]), \ 193 | torch.abs(target_image[...,1:] - target_image[...,:-1])) + \ 194 | self.L1Loss(torch.abs(image_prediction[..., 1:,:] - image_prediction[..., :-1,:]), \ 195 | torch.abs(target_image[..., 1:,:] - target_image[..., :-1,:])) 196 | 197 | losses['flow_smoothness'] = self.L1Loss(flow_prediction[...,1:], flow_prediction[...,:-1]) + \ 198 | self.L1Loss(flow_prediction[..., 1:,:], flow_prediction[..., :-1,:]) 199 | 200 | losses['tot'] = 0.7 * losses['color'] + 0.2 * losses['color_gradient'] + 0.1 * losses['flow_smoothness'] 201 | 202 | if label_image is not None: 203 | image_prediction = label_prediction 204 | 205 | return losses, image_prediction, target_image 206 | 207 | 208 | class SDCNet2DRecon(SDCNet2D): 209 | def __init__(self, args): 210 | args.sequence_length += 1 211 | super(SDCNet2DRecon, self).__init__(args) 212 | 213 | def prepare_inputs(self, input_dict): 214 | images = input_dict['image'] # expects a list 215 | 216 | input_images = images 217 | 218 | target_image = images[-1] 219 | 220 | last_image = (input_images[-2]).clone() 221 | 222 | return input_images, last_image, target_image 
-------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/sdcnet/spatialdisplconv_package/__init__.py -------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='spatialdisplconv_cuda', 21 | ext_modules=[ 22 | CUDAExtension('spatialdisplconv_cuda', [ 23 | 'spatialdisplconv_cuda.cc', 24 | 'spatialdisplconv_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/spatialdisplconv.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.autograd import Function, Variable 3 | import spatialdisplconv_cuda 4 | 5 | class SpatialDisplConvFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, input2, input3, input4, kernel_size = 1): 9 | assert input1.is_contiguous(), "spatialdisplconv forward - input1 is not contiguous" 10 | assert input2.is_contiguous(), "spatialdisplconv forward - input2 is not contiguous" 11 | assert input3.is_contiguous(), "spatialdisplconv forward - input3 is not contiguous" 12 | assert input4.is_contiguous(), "spatialdisplconv forward - input4 is not contiguous" 13 | 14 | ctx.save_for_backward(input1, input2, input3, input4) 15 | ctx.kernel_size = kernel_size 16 | 17 | _, image_channels, _, _ = input1.size() 18 | batch_size, _, height, width = input2.size() 19 | output = input1.new(batch_size, image_channels, height, width).zero_() 20 | 21 | spatialdisplconv_cuda.forward( 22 | input1, 23 | input2, 24 | input3, 25 | input4, 26 | output, 27 | kernel_size 28 | ) 29 | 30 | return output 31 | 32 | @staticmethod 33 | def backward(ctx, grad_output): 34 | grad_output = grad_output.contiguous() 35 | assert grad_output.is_contiguous() 36 | 37 | input1, input2, input3, input4 = ctx.saved_tensors 38 | 39 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 40 | grad_input2 = Variable(input2.new(input2.size()).zero_()) 41 | grad_input3 = Variable(input3.new(input3.size()).zero_()) 42 | grad_input4 = Variable(input4.new(input4.size()).zero_()) 43 | 44 | spatialdisplconv_cuda.backward( 45 | input1, 46 | input2, 47 | input3, 48 | input4, 49 | grad_output.data, 50 | grad_input1.data, 51 | grad_input2.data, 52 | grad_input3.data, 53 | grad_input4.data, 54 | ctx.kernel_size 55 | ) 56 | 57 | return grad_input1, grad_input2, grad_input3, grad_input4, None 58 | 59 | class SpatialDisplConv(Module): 60 | def __init__(self, kernel_size = 1): 61 | 
super(SpatialDisplConv, self).__init__() 62 | self.kernel_size = kernel_size 63 | 64 | 65 | def forward(self, input1, input2, input3, input4): 66 | 67 | return SpatialDisplConvFunction.apply(input1, input2, input3, input4, self.kernel_size) 68 | -------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/spatialdisplconv_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "spatialdisplconv_kernel.cuh" 5 | 6 | int spatialdisplconv_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& input2, 9 | at::Tensor& input3, 10 | at::Tensor& input4, 11 | at::Tensor& output, 12 | int kernel_size) { 13 | 14 | spatialdisplconv_kernel_forward( 15 | input1, 16 | input2, 17 | input3, 18 | input4, 19 | output, 20 | kernel_size 21 | ); 22 | 23 | return 1; 24 | } 25 | 26 | 27 | int spatialdisplconv_cuda_backward( 28 | at::Tensor& input1, 29 | at::Tensor& input2, 30 | at::Tensor& input3, 31 | at::Tensor& input4, 32 | at::Tensor& gradOutput, 33 | at::Tensor& gradInput1, 34 | at::Tensor& gradInput2, 35 | at::Tensor& gradInput3, 36 | at::Tensor& gradInput4, 37 | int kernel_size 38 | 39 | ) { 40 | spatialdisplconv_kernel_backward( 41 | input1, 42 | input2, 43 | input3, 44 | input4, 45 | gradOutput, 46 | gradInput1, 47 | gradInput2, 48 | gradInput3, 49 | gradInput4, 50 | kernel_size 51 | ); 52 | 53 | return 1; 54 | } 55 | 56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 57 | m.def("forward", &spatialdisplconv_cuda_forward, "SpatialDisplConv forward (CUDA)"); 58 | m.def("backward", &spatialdisplconv_cuda_backward, "SpatialDisplConv backward (CUDA)"); 59 | } 60 | -------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/spatialdisplconv_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void spatialdisplconv_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& input2, 8 | at::Tensor& input3, 9 | at::Tensor& input4, 10 | at::Tensor& output, 11 | int kernel_size 12 | ); 13 | 14 | void spatialdisplconv_kernel_backward( 15 | at::Tensor& input1, 16 | at::Tensor& input2, 17 | at::Tensor& input3, 18 | at::Tensor& input4, 19 | at::Tensor& gradOutput, 20 | at::Tensor& gradInput1, 21 | at::Tensor& gradInput2, 22 | at::Tensor& gradInput3, 23 | at::Tensor& gradInput4, 24 | int kernel_size 25 | 26 | ); 27 | -------------------------------------------------------------------------------- /sdcnet/spatialdisplconv_package/test_spatialdisplconv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | from spatialdisplconv import SpatialDisplConv 4 | 5 | assert torch.cuda.is_available() 6 | cuda_device = torch.device("cuda") # device object representing GPU 7 | 8 | n = 8 9 | h = 224 10 | w = 224 11 | 12 | offset = 9 # 11 13 | 14 | #input1 = N, 3, H + 11, W + 11 15 | #input2 = N, 11, H, W 16 | #input3 = N, 11, H, W 17 | #input4 = N, 2, H, W 18 | 19 | # Note the device=cuda_device arguments here 20 | a = torch.randn(n, 3, h + offset, w + offset, device=cuda_device, requires_grad=True).contiguous() 21 | 22 | b = torch.randn(n, offset, h, w, device=cuda_device, requires_grad=True).contiguous() 23 | 24 | c = torch.randn(n, offset, h, w, device=cuda_device, requires_grad=True).contiguous() 25 | 26 | d = torch.randn(n, 2, h, w, device=cuda_device, requires_grad=True).contiguous() 27 | 28 | sdc_layer = 
SpatialDisplConv(kernel_size=1).cuda() 29 | 30 | forward = 0 31 | backward = 0 32 | num_runs = 100 33 | 34 | for _ in range(num_runs): 35 | 36 | start = time.time() 37 | 38 | result = sdc_layer.forward(a, b, c, d) 39 | torch.cuda.synchronize() 40 | forward += time.time() - start 41 | 42 | sdc_layer.zero_grad() 43 | 44 | start = time.time() 45 | 46 | result_sum = result.sum() 47 | 48 | result_sum.backward() 49 | torch.cuda.synchronize() 50 | backward += time.time() - start 51 | 52 | print("Forward time per iteration: %.4f ms" % (forward * 1.0 / num_runs * 1000)) 53 | print("Backward time per iteration: %.4f ms" % (backward * 1.0 / num_runs * 1000)) 54 | -------------------------------------------------------------------------------- /sdcnet/utility/Dockerfile: -------------------------------------------------------------------------------- 1 | # =========== 2 | # base images 3 | # =========== 4 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04 5 | 6 | 7 | # =============== 8 | # system packages 9 | # =============== 10 | RUN apt-get update 11 | RUN apt-get install -y bash-completion \ 12 | emacs \ 13 | ffmpeg \ 14 | git \ 15 | graphviz \ 16 | htop \ 17 | libopenexr-dev \ 18 | openssh-server \ 19 | rsync \ 20 | wget 21 | 22 | # =========== 23 | # latest apex 24 | # =========== 25 | RUN pip uninstall -y apex 26 | RUN git clone https://github.com/NVIDIA/apex.git ~/apex && \ 27 | cd ~/apex && \ 28 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . 29 | 30 | # ============ 31 | # pip packages 32 | # ============ 33 | RUN pip install --upgrade pip 34 | RUN pip install --upgrade setuptools 35 | RUN pip install --upgrade boto3 36 | RUN pip install --upgrade cffi 37 | RUN pip install --upgrade colorama==0.3.7 38 | RUN pip install --upgrade Cython 39 | RUN pip install --upgrade dominate 40 | RUN pip install --upgrade ffmpeg 41 | RUN pip install --upgrade graphviz 42 | RUN pip install --upgrade imageio 43 | RUN pip install --upgrade ipython 44 | RUN pip install --upgrade matplotlib 45 | RUN pip install --upgrade natsort 46 | RUN pip install --upgrade nltk 47 | RUN pip install --upgrade numpy 48 | RUN pip install --upgrade openexr 49 | RUN pip install --upgrade packaging 50 | RUN pip install --upgrade pandas 51 | RUN pip install --upgrade pillow 52 | RUN pip install --upgrade pylint 53 | RUN pip install --upgrade pytz 54 | RUN pip install --upgrade pyyaml 55 | RUN pip install --upgrade requests 56 | RUN pip install --upgrade scikit-image 57 | RUN pip install --upgrade scikit-learn 58 | RUN pip install --upgrade scipy 59 | RUN pip install --upgrade sentencepiece 60 | RUN pip install --upgrade setproctitle 61 | RUN pip install --upgrade tensorboard 62 | RUN pip install --upgrade tensorboardX 63 | RUN pip install --upgrade tensorflow 64 | RUN pip install --upgrade torchvision 65 | RUN pip install --upgrade tqdm 66 | RUN pip install --upgrade youtube_dl 67 | RUN pip install --upgrade opencv-python 68 | -------------------------------------------------------------------------------- /sdcnet/utility/tools.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | from inspect import isclass 5 | 6 | class TimerBlock: 7 | def __init__(self, title): 8 | print(("{}".format(title))) 9 | 10 | def __enter__(self): 11 | self.start = time.clock() 12 | return self 13 | 14 | def __exit__(self, exc_type, exc_value, traceback): 15 | self.end = time.clock() 16 | self.interval = self.end - self.start 17 | 18 
| if exc_type is not None: 19 | self.log("Operation failed\n") 20 | else: 21 | self.log("Operation finished\n") 22 | 23 | def log(self, string): 24 | duration = time.clock() - self.start 25 | units = 's' 26 | if duration > 60: 27 | duration = duration / 60. 28 | units = 'm' 29 | print(" [{:.3f}{}] {}".format(duration, units, string), flush = True) 30 | 31 | def module_to_dict(module, exclude=[]): 32 | return dict([(x, getattr(module, x)) for x in dir(module) 33 | if isclass(getattr(module, x)) 34 | and x not in exclude 35 | and getattr(module, x) not in exclude]) 36 | 37 | 38 | # create_pipe: adapted from https://stackoverflow.com/questions/23709893/popen-write-operation-on-closed-file-images-to-video-using-ffmpeg/23709937#23709937 39 | # start an ffmpeg pipe for creating RGB8 for color images or FFV1 for depth 40 | # NOTE: this is REALLY lossy and not optimal for HDR data. when it comes time to train 41 | # on HDR data, you'll need to figure out the way to save to pix_fmt=rgb48 or something 42 | # similar 43 | def create_pipe(pipe_filename, width, height, frame_rate=60, quite=True): 44 | # default extension and tonemapper 45 | pix_fmt = 'rgb24' 46 | out_fmt = 'yuv420p' 47 | codec = 'h264' 48 | 49 | command = ['ffmpeg', 50 | '-threads', '2', # number of threads to start 51 | '-y', # (optional) overwrite output file if it exists 52 | '-f', 'rawvideo', # input format 53 | '-vcodec', 'rawvideo', # input codec 54 | '-s', str(width) + 'x' + str(height), # size of one frame 55 | '-pix_fmt', pix_fmt, # input pixel format 56 | '-r', str(frame_rate), # frames per second 57 | '-i', '-', # The input comes from a pipe 58 | '-an', # Tells FFMPEG not to expect any audio 59 | '-codec:v', codec, # output codec 60 | '-crf', '18', 61 | # compression quality for h264 (maybe h265 too?) 
- http://slhck.info/video/2017/02/24/crf-guide.html 62 | # '-compression_level', '10', # compression level for libjpeg if doing lossy depth 63 | '-strict', '-2', # experimental 16 bit support nessesary for gray16le 64 | '-pix_fmt', out_fmt, # output pixel format 65 | '-s', str(width) + 'x' + str(height), # output size 66 | pipe_filename] 67 | cmd = ' '.join(command) 68 | if not quite: 69 | print('openning a pip ....\n' + cmd + '\n') 70 | 71 | # open the pipe, and ignore stdout and stderr output 72 | DEVNULL = open(os.devnull, 'wb') 73 | return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=DEVNULL, stderr=DEVNULL, close_fds=True) 74 | 75 | # AverageMeter: code from https://github.com/pytorch/examples/blob/master/imagenet/main.py 76 | class AverageMeter(object): 77 | """Computes and stores the average and current value""" 78 | def __init__(self, name, fmt=':f'): 79 | self.name = name 80 | self.fmt = fmt 81 | self.reset() 82 | 83 | def reset(self): 84 | self.val = 0 85 | self.avg = 0 86 | self.sum = 0 87 | self.count = 0 88 | 89 | def update(self, val, n=1): 90 | self.val = val 91 | self.sum += val * n 92 | self.count += n 93 | self.avg = self.sum / self.count 94 | 95 | def __str__(self): 96 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 97 | return fmtstr.format(**self.__dict__) -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/transforms/__init__.py -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | 46 | try: 47 | import accimage 48 | except ImportError: 49 | accimage = None 50 | 51 | 52 | class RandomVerticalFlip(object): 53 | def __call__(self, img): 54 | if random.random() < 0.5: 55 | return img.transpose(Image.FLIP_TOP_BOTTOM) 56 | return img 57 | 58 | 59 | class DeNormalize(object): 60 | def __init__(self, mean, std): 61 | self.mean = mean 62 | self.std = std 63 | 64 | def __call__(self, tensor): 65 | for t, m, s in zip(tensor, self.mean, self.std): 66 | t.mul_(s).add_(m) 67 | return tensor 68 | 69 | 70 | class MaskToTensor(object): 71 | def __call__(self, img): 72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 73 | 74 | class RelaxedBoundaryLossToTensor(object): 75 | """ 76 | Boundary Relaxation 77 | """ 78 | def __init__(self,ignore_id, num_classes): 79 | self.ignore_id=ignore_id 80 | self.num_classes= num_classes 81 | 82 | 83 | def new_one_hot_converter(self,a): 84 | ncols = self.num_classes+1 85 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 86 | out[np.arange(a.size),a.ravel()] = 1 87 | out.shape = a.shape + (ncols,) 88 | return out 89 | 90 | def __call__(self,img): 91 | 92 | img_arr = np.array(img) 93 | img_arr[img_arr==self.ignore_id]=self.num_classes 94 | 95 | if cfg.STRICTBORDERCLASS != None: 96 | one_hot_orig = self.new_one_hot_converter(img_arr) 97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 98 | for cls in cfg.STRICTBORDERCLASS: 99 | mask = np.logical_or(mask,(img_arr == cls)) 100 | one_hot = 0 101 | 102 | border = cfg.BORDER_WINDOW 103 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 104 | border = border // 2 105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 106 | 107 | for i in range(-border,border+1): 108 | for j in range(-border, border+1): 109 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 110 | one_hot += self.new_one_hot_converter(shifted) 111 | 112 | one_hot[one_hot>1] = 1 113 | 114 | if cfg.STRICTBORDERCLASS != None: 115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 116 | 117 | one_hot = np.moveaxis(one_hot,-1,0) 118 | 119 | 120 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 122 | # print(one_hot.shape) 123 | return torch.from_numpy(one_hot).byte() 124 | 125 | class ResizeHeight(object): 126 | def __init__(self, size, interpolation=Image.BILINEAR): 127 | self.target_h = size 128 | self.interpolation = interpolation 129 | 130 | def __call__(self, img): 131 | w, h = img.size 132 | target_w = int(w / h * self.target_h) 133 | return img.resize((target_w, self.target_h), self.interpolation) 134 | 135 | 136 | class FreeScale(object): 137 | def __init__(self, size, interpolation=Image.BILINEAR): 138 | self.size = tuple(reversed(size)) # size: (h, w) 139 | self.interpolation = interpolation 140 | 141 | def __call__(self, img): 142 | return img.resize(self.size, self.interpolation) 143 | 144 | 145 | class FlipChannels(object): 146 | """ 147 | Flip around the x-axis 148 | """ 149 | def __call__(self, img): 150 
| img = np.array(img)[:, :, ::-1] 151 | return Image.fromarray(img.astype(np.uint8)) 152 | 153 | 154 | class RandomGaussianBlur(object): 155 | """ 156 | Apply Gaussian Blur 157 | """ 158 | def __call__(self, img): 159 | sigma = 0.15 + random.random() * 1.15 160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 161 | blurred_img *= 255 162 | return Image.fromarray(blurred_img.astype(np.uint8)) 163 | 164 | 165 | class RandomBilateralBlur(object): 166 | """ 167 | Apply Bilateral Filtering 168 | 169 | """ 170 | def __call__(self, img): 171 | sigma = random.uniform(0.05,0.75) 172 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 173 | blurred_img *= 255 174 | return Image.fromarray(blurred_img.astype(np.uint8)) 175 | 176 | def _is_pil_image(img): 177 | if accimage is not None: 178 | return isinstance(img, (Image.Image, accimage.Image)) 179 | else: 180 | return isinstance(img, Image.Image) 181 | 182 | 183 | def adjust_brightness(img, brightness_factor): 184 | """Adjust brightness of an Image. 185 | 186 | Args: 187 | img (PIL Image): PIL Image to be adjusted. 188 | brightness_factor (float): How much to adjust the brightness. Can be 189 | any non negative number. 0 gives a black image, 1 gives the 190 | original image while 2 increases the brightness by a factor of 2. 191 | 192 | Returns: 193 | PIL Image: Brightness adjusted image. 194 | """ 195 | if not _is_pil_image(img): 196 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 197 | 198 | enhancer = ImageEnhance.Brightness(img) 199 | img = enhancer.enhance(brightness_factor) 200 | return img 201 | 202 | 203 | def adjust_contrast(img, contrast_factor): 204 | """Adjust contrast of an Image. 205 | 206 | Args: 207 | img (PIL Image): PIL Image to be adjusted. 208 | contrast_factor (float): How much to adjust the contrast. Can be any 209 | non negative number. 0 gives a solid gray image, 1 gives the 210 | original image while 2 increases the contrast by a factor of 2. 211 | 212 | Returns: 213 | PIL Image: Contrast adjusted image. 214 | """ 215 | if not _is_pil_image(img): 216 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 217 | 218 | enhancer = ImageEnhance.Contrast(img) 219 | img = enhancer.enhance(contrast_factor) 220 | return img 221 | 222 | 223 | def adjust_saturation(img, saturation_factor): 224 | """Adjust color saturation of an image. 225 | 226 | Args: 227 | img (PIL Image): PIL Image to be adjusted. 228 | saturation_factor (float): How much to adjust the saturation. 0 will 229 | give a black and white image, 1 will give the original image while 230 | 2 will enhance the saturation by a factor of 2. 231 | 232 | Returns: 233 | PIL Image: Saturation adjusted image. 234 | """ 235 | if not _is_pil_image(img): 236 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 237 | 238 | enhancer = ImageEnhance.Color(img) 239 | img = enhancer.enhance(saturation_factor) 240 | return img 241 | 242 | 243 | def adjust_hue(img, hue_factor): 244 | """Adjust hue of an image. 245 | 246 | The image hue is adjusted by converting the image to HSV and 247 | cyclically shifting the intensities in the hue channel (H). 248 | The image is then converted back to original image mode. 249 | 250 | `hue_factor` is the amount of shift in H channel and must be in the 251 | interval `[-0.5, 0.5]`. 252 | 253 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 254 | 255 | Args: 256 | img (PIL Image): PIL Image to be adjusted. 
257 | hue_factor (float): How much to shift the hue channel. Should be in 258 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 259 | HSV space in positive and negative direction respectively. 260 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 261 | with complementary colors while 0 gives the original image. 262 | 263 | Returns: 264 | PIL Image: Hue adjusted image. 265 | """ 266 | if not(-0.5 <= hue_factor <= 0.5): 267 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 268 | 269 | if not _is_pil_image(img): 270 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 271 | 272 | input_mode = img.mode 273 | if input_mode in {'L', '1', 'I', 'F'}: 274 | return img 275 | 276 | h, s, v = img.convert('HSV').split() 277 | 278 | np_h = np.array(h, dtype=np.uint8) 279 | # uint8 addition take cares of rotation across boundaries 280 | with np.errstate(over='ignore'): 281 | np_h += np.uint8(hue_factor * 255) 282 | h = Image.fromarray(np_h, 'L') 283 | 284 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 285 | return img 286 | 287 | 288 | class ColorJitter(object): 289 | """Randomly change the brightness, contrast and saturation of an image. 290 | 291 | Args: 292 | brightness (float): How much to jitter brightness. brightness_factor 293 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 294 | contrast (float): How much to jitter contrast. contrast_factor 295 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 296 | saturation (float): How much to jitter saturation. saturation_factor 297 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 298 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 299 | [-hue, hue]. Should be >=0 and <= 0.5. 300 | """ 301 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 302 | self.brightness = brightness 303 | self.contrast = contrast 304 | self.saturation = saturation 305 | self.hue = hue 306 | 307 | @staticmethod 308 | def get_params(brightness, contrast, saturation, hue): 309 | """Get a randomized transform to be applied on image. 310 | 311 | Arguments are same as that of __init__. 312 | 313 | Returns: 314 | Transform which randomly adjusts brightness, contrast and 315 | saturation in a random order. 316 | """ 317 | transforms = [] 318 | if brightness > 0: 319 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 320 | transforms.append( 321 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 322 | 323 | if contrast > 0: 324 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 325 | transforms.append( 326 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 327 | 328 | if saturation > 0: 329 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 330 | transforms.append( 331 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 332 | 333 | if hue > 0: 334 | hue_factor = np.random.uniform(-hue, hue) 335 | transforms.append( 336 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 337 | 338 | np.random.shuffle(transforms) 339 | transform = torch_tr.Compose(transforms) 340 | 341 | return transform 342 | 343 | def __call__(self, img): 344 | """ 345 | Args: 346 | img (PIL Image): Input image. 347 | 348 | Returns: 349 | PIL Image: Color jittered image. 
350 | """ 351 | transform = self.get_params(self.brightness, self.contrast, 352 | self.saturation, self.hue) 353 | return transform(img) 354 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/utils/__init__.py -------------------------------------------------------------------------------- /utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | 30 | class AttrDict(dict): 31 | 32 | IMMUTABLE = '__immutable__' 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__[AttrDict.IMMUTABLE] = False 37 | 38 | def __getattr__(self, name): 39 | if name in self.__dict__: 40 | return self.__dict__[name] 41 | elif name in self: 42 | return self[name] 43 | else: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name, value): 47 | if not self.__dict__[AttrDict.IMMUTABLE]: 48 | if name in self.__dict__: 49 | self.__dict__[name] = value 50 | else: 51 | self[name] = value 52 | else: 53 | raise AttributeError( 54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 55 | format(name, value) 56 | ) 57 | 58 | def immutable(self, is_immutable): 59 | """Set immutability to is_immutable and recursively apply the setting 60 | to all nested AttrDicts. 
61 | """ 62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 63 | # Recursively set immutable state 64 | for v in self.__dict__.values(): 65 | if isinstance(v, AttrDict): 66 | v.immutable(is_immutable) 67 | for v in self.values(): 68 | if isinstance(v, AttrDict): 69 | v.immutable(is_immutable) 70 | 71 | def is_immutable(self): 72 | return self.__dict__[AttrDict.IMMUTABLE] 73 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellanous Functions 3 | """ 4 | 5 | import sys 6 | import re 7 | import os 8 | import shutil 9 | import torch 10 | from datetime import datetime 11 | import logging 12 | from subprocess import call 13 | import shlex 14 | from tensorboardX import SummaryWriter 15 | import numpy as np 16 | import torchvision.transforms as standard_transforms 17 | import torchvision.utils as vutils 18 | 19 | 20 | # Create unique output dir name based on non-default command line args 21 | def make_exp_name(args, parser): 22 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:]) 23 | dict_args = vars(args) 24 | 25 | # sort so that we get a consistent directory name 26 | argnames = sorted(dict_args) 27 | ignorelist = ['exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch', 28 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt'] 29 | # build experiment name with non-default args 30 | for argname in argnames: 31 | if dict_args[argname] != parser.get_default(argname): 32 | if argname in ignorelist: 33 | continue 34 | if argname == 'snapshot': 35 | arg_str = 'PT' 36 | argname = '' 37 | elif argname == 'nosave': 38 | arg_str = '' 39 | argname='' 40 | elif argname == 'freeze_trunk': 41 | argname = '' 42 | arg_str = 'ft' 43 | elif argname == 'syncbn': 44 | argname = '' 45 | arg_str = 'sbn' 46 | elif argname == 'jointwtborder': 47 | argname = '' 48 | arg_str = 'rlx_loss' 49 | elif isinstance(dict_args[argname], bool): 50 | arg_str = 'T' if dict_args[argname] else 'F' 51 | else: 52 | arg_str = str(dict_args[argname])[:7] 53 | if argname is not '': 54 | exp_name += '_{}_{}'.format(str(argname), arg_str) 55 | else: 56 | exp_name += '_{}'.format(arg_str) 57 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name) 58 | return exp_name 59 | 60 | def fast_hist(label_pred, label_true, num_classes): 61 | mask = (label_true >= 0) & (label_true < num_classes) 62 | hist = np.bincount( 63 | num_classes * label_true[mask].astype(int) + 64 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 65 | return hist 66 | 67 | def per_class_iu(hist): 68 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 69 | 70 | def save_log(prefix, output_dir, date_str, rank=0): 71 | fmt = '%(asctime)s.%(msecs)03d %(message)s' 72 | date_fmt = '%m-%d %H:%M:%S' 73 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log') 74 | print("Logging :", filename) 75 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt, 76 | filename=filename, filemode='w') 77 | console = logging.StreamHandler() 78 | console.setLevel(logging.INFO) 79 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt) 80 | console.setFormatter(formatter) 81 | if rank == 0: 82 | logging.getLogger('').addHandler(console) 83 | else: 84 | fh = logging.FileHandler(filename) 85 | logging.getLogger('').addHandler(fh) 86 | 87 | 88 | 89 | def prep_experiment(args, 
parser): 90 | """ 91 | Make output directories, setup logging, Tensorboard, snapshot code. 92 | """ 93 | ckpt_path = args.ckpt 94 | tb_path = args.tb_path 95 | exp_name = make_exp_name(args, parser) 96 | args.exp_path = os.path.join(ckpt_path, args.exp, exp_name) 97 | args.tb_exp_path = os.path.join(tb_path, args.exp, exp_name) 98 | args.ngpu = torch.cuda.device_count() 99 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) 100 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 101 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 102 | args.last_record = {} 103 | if args.local_rank == 0: 104 | os.makedirs(args.exp_path, exist_ok=True) 105 | os.makedirs(args.tb_exp_path, exist_ok=True) 106 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank) 107 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write( 108 | str(args) + '\n\n') 109 | writer = SummaryWriter(logdir=args.tb_exp_path, comment=args.tb_tag) 110 | return writer 111 | return None 112 | 113 | def evaluate_eval_for_inference(hist, dataset=None): 114 | """ 115 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 116 | large dataset) Only applies to eval/eval.py 117 | """ 118 | # axis 0: gt, axis 1: prediction 119 | acc = np.diag(hist).sum() / hist.sum() 120 | acc_cls = np.diag(hist) / hist.sum(axis=1) 121 | acc_cls = np.nanmean(acc_cls) 122 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 123 | 124 | print_evaluate_results(hist, iu, dataset=dataset) 125 | freq = hist.sum(axis=1) / hist.sum() 126 | mean_iu = np.nanmean(iu) 127 | logging.info('mean {}'.format(mean_iu)) 128 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 129 | return acc, acc_cls, mean_iu, fwavacc 130 | 131 | 132 | 133 | def evaluate_eval(args, net, optimizer, val_loss, hist, dump_images, writer, epoch=0, dataset=None, ): 134 | """ 135 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 136 | large dataset) Only applies to eval/eval.py 137 | """ 138 | # axis 0: gt, axis 1: prediction 139 | acc = np.diag(hist).sum() / hist.sum() 140 | acc_cls = np.diag(hist) / hist.sum(axis=1) 141 | acc_cls = np.nanmean(acc_cls) 142 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 143 | 144 | print_evaluate_results(hist, iu, dataset) 145 | freq = hist.sum(axis=1) / hist.sum() 146 | mean_iu = np.nanmean(iu) 147 | logging.info('mean {}'.format(mean_iu)) 148 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 149 | 150 | # update latest snapshot 151 | if 'mean_iu' in args.last_record: 152 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format( 153 | args.last_record['epoch'], args.last_record['mean_iu']) 154 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 155 | try: 156 | os.remove(last_snapshot) 157 | except OSError: 158 | pass 159 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu) 160 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 161 | args.last_record['mean_iu'] = mean_iu 162 | args.last_record['epoch'] = epoch 163 | 164 | torch.cuda.synchronize() 165 | 166 | torch.save({ 167 | 'state_dict': net.state_dict(), 168 | 'optimizer': optimizer.state_dict(), 169 | 'epoch': epoch, 170 | 'mean_iu': mean_iu, 171 | 'command': ' '.join(sys.argv[1:]) 172 | }, last_snapshot) 173 | 174 | # update best snapshot 175 | if mean_iu > args.best_record['mean_iu'] : 176 | # remove old best snapshot 177 | if args.best_record['epoch'] != -1: 178 | best_snapshot = 
'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 179 | args.best_record['epoch'], args.best_record['mean_iu']) 180 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 181 | assert os.path.exists(best_snapshot), \ 182 | 'cant find old snapshot {}'.format(best_snapshot) 183 | os.remove(best_snapshot) 184 | 185 | 186 | # save new best 187 | args.best_record['val_loss'] = val_loss.avg 188 | args.best_record['epoch'] = epoch 189 | args.best_record['acc'] = acc 190 | args.best_record['acc_cls'] = acc_cls 191 | args.best_record['mean_iu'] = mean_iu 192 | args.best_record['fwavacc'] = fwavacc 193 | 194 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format( 195 | args.best_record['epoch'], args.best_record['mean_iu']) 196 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 197 | shutil.copyfile(last_snapshot, best_snapshot) 198 | 199 | 200 | to_save_dir = os.path.join(args.exp_path, 'best_images') 201 | os.makedirs(to_save_dir, exist_ok=True) 202 | val_visual = [] 203 | 204 | idx = 0 205 | 206 | visualize = standard_transforms.Compose([ 207 | standard_transforms.Scale(384), 208 | standard_transforms.ToTensor() 209 | ]) 210 | for bs_idx, bs_data in enumerate(dump_images): 211 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])): 212 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy()) 213 | predictions_pil = args.dataset_cls.colorize_mask(data[1].cpu().numpy()) 214 | img_name = data[2] 215 | 216 | prediction_fn = '{}_prediction.png'.format(img_name) 217 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn)) 218 | gt_fn = '{}_gt.png'.format(img_name) 219 | gt_pil.save(os.path.join(to_save_dir, gt_fn)) 220 | val_visual.extend([visualize(gt_pil.convert('RGB')), 221 | visualize(predictions_pil.convert('RGB'))]) 222 | if local_idx >= 9: 223 | break 224 | val_visual = torch.stack(val_visual, 0) 225 | val_visual = vutils.make_grid(val_visual, nrow=10, padding=5) 226 | writer.add_image('imgs', val_visual, epoch ) 227 | 228 | logging.info('-' * 107) 229 | fmt_str = '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 230 | '[mean_iu %.5f], [fwavacc %.5f]' 231 | logging.info(fmt_str % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 232 | fmt_str = 'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 233 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], ' 234 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['acc'], 235 | args.best_record['acc_cls'], args.best_record['mean_iu'], 236 | args.best_record['fwavacc'], args.best_record['epoch'])) 237 | logging.info('-' * 107) 238 | 239 | # tensorboard logging of validation phase metrics 240 | 241 | writer.add_scalar('training/acc', acc, epoch) 242 | writer.add_scalar('training/acc_cls', acc_cls, epoch) 243 | writer.add_scalar('training/mean_iu', mean_iu, epoch) 244 | writer.add_scalar('training/val_loss', val_loss.avg, epoch) 245 | 246 | 247 | 248 | def fast_hist(label_pred, label_true, num_classes): 249 | mask = (label_true >= 0) & (label_true < num_classes) 250 | hist = np.bincount( 251 | num_classes * label_true[mask].astype(int) + 252 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 253 | return hist 254 | 255 | 256 | 257 | def print_evaluate_results(hist, iu, dataset=None): 258 | # fixme: Need to refactor this dict 259 | try: 260 | id2cat = dataset.id2cat 261 | except: 262 | id2cat = {i: i for i in range(dataset.num_classes)} 263 | iu_false_positive = hist.sum(axis=1) - np.diag(hist) 264 | iu_false_negative = 
hist.sum(axis=0) - np.diag(hist) 265 | iu_true_positive = np.diag(hist) 266 | 267 | logging.info('IoU:') 268 | logging.info('label_id label iU Precision Recall TP FP FN') 269 | for idx, i in enumerate(iu): 270 | # Format all of the strings: 271 | idx_string = "{:2d}".format(idx) 272 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else '' 273 | iu_string = '{:5.2f}'.format(i * 100) 274 | total_pixels = hist.sum() 275 | tp = '{:5.2f}'.format(100 * iu_true_positive[idx] / total_pixels) 276 | fp = '{:5.2f}'.format( 277 | iu_false_positive[idx] / iu_true_positive[idx]) 278 | fn = '{:5.2f}'.format(iu_false_negative[idx] / iu_true_positive[idx]) 279 | precision = '{:5.2f}'.format( 280 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx])) 281 | recall = '{:5.2f}'.format( 282 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx])) 283 | logging.info('{} {} {} {} {} {} {} {}'.format( 284 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn)) 285 | 286 | 287 | 288 | 289 | class AverageMeter(object): 290 | 291 | def __init__(self): 292 | self.reset() 293 | 294 | def reset(self): 295 | self.val = 0 296 | self.avg = 0 297 | self.sum = 0 298 | self.count = 0 299 | 300 | def update(self, val, n=1): 301 | self.val = val 302 | self.sum += val * n 303 | self.count += n 304 | self.avg = self.sum / self.count 305 | -------------------------------------------------------------------------------- /utils/my_data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | # Code adapted from: 4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py 5 | # 6 | # BSD 3-Clause License 7 | # 8 | # Copyright (c) 2017, 9 | # All rights reserved. 10 | # 11 | # Redistribution and use in source and binary forms, with or without 12 | # modification, are permitted provided that the following conditions are met: 13 | # 14 | # * Redistributions of source code must retain the above copyright notice, this 15 | # list of conditions and the following disclaimer. 16 | # 17 | # * Redistributions in binary form must reproduce the above copyright notice, 18 | # this list of conditions and the following disclaimer in the documentation 19 | # and/or other materials provided with the distribution. 20 | # 21 | # * Neither the name of the copyright holder nor the names of its 22 | # contributors may be used to endorse or promote products derived from 23 | # this software without specific prior written permission. 24 | # 25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 28 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 35 | """ 36 | 37 | 38 | import operator 39 | import torch 40 | import warnings 41 | from torch.nn.modules import Module 42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 43 | from torch.nn.parallel.replicate import replicate 44 | from torch.nn.parallel.parallel_apply import parallel_apply 45 | 46 | 47 | def _check_balance(device_ids): 48 | imbalance_warn = """ 49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which 50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting 51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES 52 | environment variable.""" 53 | 54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids] 55 | 56 | def warn_imbalance(get_prop): 57 | values = [get_prop(props) for props in dev_props] 58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) 59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) 60 | if min_val / max_val < 0.75: 61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) 62 | return True 63 | return False 64 | 65 | if warn_imbalance(lambda props: props.total_memory): 66 | return 67 | if warn_imbalance(lambda props: props.multi_processor_count): 68 | return 69 | 70 | 71 | 72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, do_gather=True): 73 | """ 74 | Evaluates module(input) in parallel across the GPUs given in device_ids. 75 | This is the functional version of the DataParallel module. 76 | Args: 77 | module: the module to evaluate in parallel 78 | inputs: inputs to the module 79 | device_ids: GPU ids on which to replicate module 80 | output_device: GPU location of the output. Use -1 to indicate the CPU. 81 | (default: device_ids[0]) 82 | Returns: 83 | a Tensor containing the result of module(input) located on 84 | output_device 85 | """ 86 | if not isinstance(inputs, tuple): 87 | inputs = (inputs,) 88 | 89 | if device_ids is None: 90 | device_ids = list(range(torch.cuda.device_count())) 91 | 92 | if output_device is None: 93 | output_device = device_ids[0] 94 | 95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 96 | if len(device_ids) == 1: 97 | return module(*inputs[0], **module_kwargs[0]) 98 | used_device_ids = device_ids[:len(inputs)] 99 | replicas = replicate(module, used_device_ids) 100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 101 | if do_gather: # distinct flag name so it does not shadow the imported gather() 102 | return gather(outputs, output_device, dim) 103 | else: 104 | return outputs 105 | 106 | 107 | 108 | class MyDataParallel(Module): 109 | """ 110 | Implements data parallelism at the module level. 111 | This container parallelizes the application of the given module by 112 | splitting the input across the specified devices by chunking in the batch 113 | dimension. 
In the forward pass, the module is replicated on each device, 114 | and each replica handles a portion of the input. During the backwards 115 | pass, gradients from each replica are summed into the original module. 116 | The batch size should be larger than the number of GPUs used. 117 | See also: :ref:`cuda-nn-dataparallel-instead` 118 | Arbitrary positional and keyword inputs are allowed to be passed into 119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim 120 | specified (default 0). Primitive types will be broadcasted, but all 121 | other types will be a shallow copy and can be corrupted if written to in 122 | the model's forward pass. 123 | .. warning:: 124 | Forward and backward hooks defined on :attr:`module` and its submodules 125 | will be invoked ``len(device_ids)`` times, each with inputs located on 126 | a particular device. Particularly, the hooks are only guaranteed to be 127 | executed in correct order with respect to operations on corresponding 128 | devices. For example, it is not guaranteed that hooks set via 129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before 130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but 131 | that each such hook be executed before the corresponding 132 | :meth:`~torch.nn.Module.forward` call of that device. 133 | .. warning:: 134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in 135 | :func:`forward`, this wrapper will return a vector of length equal to 136 | number of devices used in data parallelism, containing the result from 137 | each device. 138 | .. note:: 139 | There is a subtlety in using the 140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a 141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. 142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for 143 | details. 
144 | Args: 145 | module: module to be parallelized 146 | device_ids: CUDA devices (default: all devices) 147 | output_device: device location of output (default: device_ids[0]) 148 | Attributes: 149 | module (Module): the module to be parallelized 150 | Example:: 151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 152 | >>> output = net(input_var) 153 | """ 154 | 155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 156 | 157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True): 158 | super(MyDataParallel, self).__init__() 159 | 160 | if not torch.cuda.is_available(): 161 | self.module = module 162 | self.device_ids = [] 163 | return 164 | 165 | if device_ids is None: 166 | device_ids = list(range(torch.cuda.device_count())) 167 | if output_device is None: 168 | output_device = device_ids[0] 169 | self.dim = dim 170 | self.module = module 171 | self.device_ids = device_ids 172 | self.output_device = output_device 173 | self.gather_bool = gather 174 | 175 | _check_balance(self.device_ids) 176 | 177 | if len(self.device_ids) == 1: 178 | self.module.cuda(device_ids[0]) 179 | 180 | def forward(self, *inputs, **kwargs): 181 | if not self.device_ids: 182 | return self.module(*inputs, **kwargs) 183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 184 | if len(self.device_ids) == 1: 185 | return [self.module(*inputs[0], **kwargs[0])] 186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 187 | outputs = self.parallel_apply(replicas, inputs, kwargs) 188 | if self.gather_bool: 189 | return self.gather(outputs, self.output_device) 190 | else: 191 | return outputs 192 | 193 | def replicate(self, module, device_ids): 194 | return replicate(module, device_ids) 195 | 196 | def scatter(self, inputs, kwargs, device_ids): 197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 198 | 199 | def parallel_apply(self, replicas, inputs, kwargs): 200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 201 | 202 | def gather(self, outputs, output_device): 203 | return gather(outputs, output_device, dim=self.dim) 204 | 205 | --------------------------------------------------------------------------------
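
The two utility modules above (`utils/misc.py` and `utils/my_data_parallel.py`) are typically used together at validation time: the model is wrapped for multi-GPU inference and the confusion matrix is accumulated with `fast_hist` before `per_class_iu` turns it into per-class IoU. The sketch below is illustrative only and not part of the repository; it assumes the repo root is on `PYTHONPATH` (so `utils.misc` and `utils.my_data_parallel` import as shown) and that the repo's dependencies (torch, torchvision, tensorboardX) are installed. The 1x1-conv "network", the 19-class constant, and the random tensors are placeholder assumptions standing in for a real segmentation model and data loader.

```python
import numpy as np
import torch
import torch.nn as nn

from utils.misc import fast_hist, per_class_iu        # repo helpers shown above
from utils.my_data_parallel import MyDataParallel     # repo wrapper shown above

NUM_CLASSES = 19  # assumption: a Cityscapes-style label space

# Placeholder "network"; MyDataParallel falls back to the bare module on CPU-only machines.
net = MyDataParallel(nn.Conv2d(3, NUM_CLASSES, kernel_size=1), gather=True)

hist = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)
for _ in range(2):  # stand-in for iterating a validation loader
    images = torch.randn(2, 3, 64, 64)
    labels = torch.randint(0, NUM_CLASSES, (2, 64, 64))
    with torch.no_grad():
        logits = net(images)
    # On a single GPU (or with gather=False) the wrapper returns a list of
    # per-device outputs, so collapse it to one tensor before the argmax.
    if isinstance(logits, list):
        logits = torch.cat([out.cpu() for out in logits], dim=0)
    preds = logits.argmax(dim=1).cpu().numpy()
    # Accumulate the confusion matrix with the fast_hist helper.
    hist += fast_hist(preds.flatten(), labels.numpy().flatten(), NUM_CLASSES)

iu = per_class_iu(hist)
print('mean IoU: {:.4f}'.format(np.nanmean(iu)))
```

Note that `MyDataParallel` hands back a list of per-device outputs whenever it runs on exactly one GPU or is constructed with `gather=False`, which is why the sketch normalizes the output before computing predictions.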