├── .gitignore
├── .gitmodules
├── Dockerfile
├── LICENSE
├── PREPARE_DATASETS.md
├── README.md
├── config.py
├── datasets
│   ├── __init__.py
│   ├── camvid.py
│   ├── cityscapes.py
│   ├── cityscapes_labels.py
│   ├── comma10k.py
│   ├── kitti.py
│   ├── mapillary.py
│   ├── nullloader.py
│   ├── sampler.py
│   └── uniform.py
├── demo.py
├── demo_folder.py
├── eval.py
├── images
│   ├── method.png
│   └── vis.png
├── loss.py
├── network
│   ├── Resnet.py
│   ├── SEresnext.py
│   ├── __init__.py
│   ├── deepv3.py
│   ├── mynn.py
│   └── wider_resnet.py
├── optimizer.py
├── scripts
│   ├── eval_cityscapes_SEResNeXt50.sh
│   ├── eval_cityscapes_WideResNet38.sh
│   ├── submit_cityscapes_WideResNet38.sh
│   ├── test_kitti_WideResNet38.sh
│   ├── train_cityscapes_SEResNeXt50.sh
│   ├── train_cityscapes_WideResNet38.sh
│   ├── train_comma10k_WideResNet38.sh
│   ├── train_kitti_WideResNet38.sh
│   ├── train_mapillary_SEResNeXt50.sh
│   └── train_mapillary_WideResNet38.sh
├── sdcnet
│   ├── _aug.sh
│   ├── _eval.sh
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── dataset_utils.py
│   │   └── frame_loader.py
│   ├── main.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── model_utils.py
│   │   └── sdc_net2d.py
│   ├── sdc_aug.py
│   ├── spatialdisplconv_package
│   │   ├── __init__.py
│   │   ├── setup.py
│   │   ├── spatialdisplconv.py
│   │   ├── spatialdisplconv_cuda.cc
│   │   ├── spatialdisplconv_kernel.cu
│   │   ├── spatialdisplconv_kernel.cuh
│   │   └── test_spatialdisplconv.py
│   └── utility
│       ├── Dockerfile
│       └── tools.py
├── train.py
├── transforms
│   ├── __init__.py
│   ├── joint_transforms.py
│   └── transforms.py
└── utils
    ├── __init__.py
    ├── attr_dict.py
    ├── misc.py
    └── my_data_parallel.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpt
2 | setenv
3 | weights
4 | *~
5 | *.pyc
6 | dump_imgs_train
7 | tb
8 | __pycache__/
9 | .idea/
10 | build/
11 | *.egg-info/
12 | dist/
13 | *.py[cod]
14 | *.swp
15 | *.o
16 | *.so
17 | .torch
18 | .DS_Store
19 | pretrained_models
20 | logs
21 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "sdcnet/flownet2_pytorch"]
2 | path = sdcnet/flownet2_pytorch
3 | url = https://github.com/NVIDIA/flownet2-pytorch
4 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # NVIDIA Pytorch Image
2 | FROM nvcr.io/nvidia/pytorch:19.05-py3
3 |
4 | RUN pip install numpy
5 | RUN pip install sklearn
6 | RUN pip install h5py
7 | RUN pip install jupyter
8 | RUN pip install scikit-image
9 | RUN pip install pillow
10 | RUN pip install piexif
11 | RUN pip install cffi
12 | RUN pip install tqdm
13 | RUN pip install dominate
14 | RUN pip install tensorboardX
15 | RUN pip install opencv-python
16 | RUN pip install nose
17 | RUN pip install ninja
18 |
19 | RUN apt-get update
20 | RUN apt-get install libgtk2.0-dev -y && rm -rf /var/lib/apt/lists/*
21 |
22 |
23 | # Install Apex
24 | RUN cd /home/ && git clone https://github.com/NVIDIA/apex.git apex && cd apex && python setup.py install --cuda_ext --cpp_ext
25 | WORKDIR /home/
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (C) 2019 NVIDIA Corporation. Yi Zhu, Karan Sapra, Fitsum A. Reda, Kevin J. Shih, Shawn Newsam, Andrew Tao and Bryan Catanzaro.
2 | All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 |
5 | Permission to use, copy, modify, and distribute this software and its documentation
6 | for any non-commercial purpose is hereby granted without fee, provided that the above
7 | copyright notice appear in all copies and that both that copyright notice and this
8 | permission notice appear in supporting documentation, and that the name of the author
9 | not be used in advertising or publicity pertaining to distribution of the software
10 | without specific, written prior permission.
11 |
12 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
13 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
14 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
15 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
16 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
17 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
18 |
--------------------------------------------------------------------------------
/PREPARE_DATASETS.md:
--------------------------------------------------------------------------------
1 | ## Mapillary Vistas Dataset
2 |
3 | First, request the research edition of the dataset from [here](https://www.mapillary.com/dataset/vistas/). The downloaded file is named `mapillary-vistas-dataset_public_v1.1.zip`.
4 |
5 | Then simply unzip the file:
6 | ```shell
7 | unzip mapillary-vistas-dataset_public_v1.1.zip
8 | ```
9 |
10 | The folder structure will look like:
11 | ```
12 | Mapillary
13 | ├── config.json
14 | ├── demo.py
15 | ├── Mapillary Vistas Research Edition License.pdf
16 | ├── README
17 | ├── requirements.txt
18 | ├── training
19 | │ ├── images
20 | │ ├── instances
21 | │ ├── labels
22 | │ ├── panoptic
23 | ├── validation
24 | │ ├── images
25 | │ ├── instances
26 | │ ├── labels
27 | │ ├── panoptic
28 | ├── testing
29 | │ ├── images
30 | │ ├── instances
31 | │ ├── labels
32 | │ ├── panoptic
33 | ```
34 | Note that the `instances`, `labels` and `panoptic` folders inside `testing` are empty.
35 |
36 | Assuming you store the dataset at `~/username/data/Mapillary`, update the dataset path in `config.py`:
37 | ```
38 | __C.DATASET.MAPILLARY_DIR = '~/username/data/Mapillary'
39 | ```
40 |
41 | ## Cityscapes Dataset
42 |
43 | ### Download Dataset
44 | First, request the dataset from [here](https://www.cityscapes-dataset.com/). You need the following files.
45 | ```
46 | - leftImg8bit_trainvaltest.zip
47 | - gtFine_trainvaltest.zip
48 | - leftImg8bit_trainextra.zip
49 | - gtCoarse.zip
50 | - leftImg8bit_sequence.zip # This file is very large, 324G. You only need it if you want to run sdc_aug experiments.
51 | ```
52 |
53 | If you prefer to download the dataset from the command line (e.g., with `wget`):
54 | ```
55 | # First step, obtain your login credentials.
56 | Please register an account at https://www.cityscapes-dataset.com/login/.
57 |
58 | # Second step, log into the Cityscapes system; replace USERNAME and PASSWORD below with your credentials.
59 | wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=USERNAME&password=PASSWORD&submit=Login' https://www.cityscapes-dataset.com/login/
60 |
61 | # Third step, download the zip files you need.
62 | wget -c -t 0 --load-cookies cookies.txt --content-disposition https://www.cityscapes-dataset.com/file-handling/?packageID=3
63 |
64 | # The corresponding packageID is listed below,
65 | 1 -> gtFine_trainvaltest.zip (241MB) md5sum: 4237c19de34c8a376e9ba46b495d6f66
66 | 2 -> gtCoarse.zip (1.3GB) md5sum: 1c7b95c84b1d36cc59a9194d8e5b989f
67 | 3 -> leftImg8bit_trainvaltest.zip (11GB) md5sum: 0a6e97e94b616a514066c9e2adb0c97f
68 | 4 -> leftImg8bit_trainextra.zip (44GB) md5sum: 9167a331a158ce3e8989e166c95d56d4
69 | 14 -> leftImg8bit_sequence.zip (324GB) md5sum: 4348961b135d856c1777f7f1098f7266
70 | ```
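Before unzipping, you can optionally verify the archives against the md5sums listed above; a minimal check, assuming the zip files are in your current directory:
```shell
# Compare the printed hashes with the md5sums in the list above.
md5sum gtFine_trainvaltest.zip leftImg8bit_trainvaltest.zip
```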
71 |
72 | ### Prepare Folder Structure
73 |
74 | Now unzip those files. The desired folder structure will look like:
75 | ```
76 | Cityscapes
77 | ├── leftImg8bit_trainvaltest
78 | │ ├── leftImg8bit
79 | │ │ ├── train
80 | │ │ │ ├── aachen
81 | │ │ │ │ ├── aachen_000000_000019_leftImg8bit.png
82 | │ │ │ │ ├── aachen_000001_000019_leftImg8bit.png
83 | │ │ │ │ ├── ...
84 | │ │ │ ├── bochum
85 | │ │ │ ├── ...
86 | │ │ ├── val
87 | │ │ ├── test
88 | ├── gtFine_trainvaltest
89 | │ ├── gtFine
90 | │ │ ├── train
91 | │ │ │ ├── aachen
92 | │ │ │ │ ├── aachen_000000_000019_gtFine_color.png
93 | │ │ │ │ ├── aachen_000000_000019_gtFine_instanceIds.png
94 | │ │ │ │ ├── aachen_000000_000019_gtFine_labelIds.png
95 | │ │ │ │ ├── aachen_000000_000019_gtFine_polygons.json
96 | │ │ │ │ ├── ...
97 | │ │ │ ├── bochum
98 | │ │ │ ├── ...
99 | │ │ ├── val
100 | │ │ ├── test
101 | ├── leftImg8bit_trainextra
102 | │ ├── leftImg8bit
103 | │ │ ├── train_extra
104 | │ │ │ ├── augsburg
105 | │ │ │ ├── bad-honnef
106 | │ │ │ ├── ...
107 | ├── gtCoarse
108 | │ ├── gtCoarse
109 | │ │ ├── train
110 | │ │ ├── train_extra
111 | │ │ ├── val
112 | ├── leftImg8bit_sequence
113 | │ ├── train
114 | │ ├── val
115 | │ ├── test
116 | ```
117 |
118 | ## CamVid Dataset
119 |
120 | Please download and prepare this dataset according to the [tutorial](https://github.com/alexgkendall/SegNet-Tutorial/tree/master/CamVid). The desired folder structure will look like:
121 | ```
122 | CamVid
123 | ├── train
124 | ├── trainannot
125 | ├── val
126 | ├── valannot
127 | ├── test
128 | ├── testannot
129 | ```
130 |
131 | ## KITTI Dataset
132 |
133 | Please download this dataset from the KITTI Semantic Segmentation benchmark [webpage](http://www.cvlibs.net/datasets/kitti/eval_semantics.php).
134 |
135 | Now unzip the file. The desired folder structure will look like:
136 | ```
137 | KITTI
138 | ├── training
139 | │ ├── image_2
140 | │ ├── instance
141 | │ ├── semantic
142 | ├── test
143 | │ ├── image_2
144 | ```
145 | There is no official training/validation split because the dataset has only `200` training samples. We randomly create three splits [here](https://github.com/NVIDIA/semantic-segmentation/blob/master/datasets/kitti.py#L41-L44) to perform cross-validation.
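To train or evaluate on one particular split, pass the corresponding cross-validation index to the training script. A minimal sketch; the `--cv` flag name is inferred from the data loader code, so check the KITTI script for the exact arguments it passes:
```shell
# Locate where the split index is set, then change it to 0, 1 or 2 (e.g. --cv 1).
grep -n "cv" scripts/train_kitti_WideResNet38.sh
```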
146 |
147 |
148 |
149 |
150 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Improving Semantic Segmentation via Video Propagation and Label Relaxation
2 | ### [Project](https://nv-adlr.github.io/publication/2018-Segmentation) | [Paper](https://arxiv.org/pdf/1812.01593.pdf) | [YouTube](https://www.youtube.com/watch?v=aEbXjGZDZSQ) | [Cityscapes Score](https://www.cityscapes-dataset.com/anonymous-results/?id=555fc2b66c6e00b953c72b98b100e396c37274e0788e871a85f1b7b4f4fa130e) | [Kitti Score](http://www.cvlibs.net/datasets/kitti/eval_semseg_detail.php?benchmark=semantics2015&result=83cac7efbd41b1f2fc095f9bc1168bc548b48885)
3 | PyTorch implementation of our CVPR 2019 oral paper, which achieves state-of-the-art semantic segmentation results using a DeepLabV3+-like architecture with a WideResNet38 trunk. We present a video prediction-based methodology to scale up training sets by synthesizing new training samples and propose a novel label relaxation technique to make training objectives robust to label noise.
4 |
5 | [Improving Semantic Segmentation via Video Propagation and Label Relaxation](https://nv-adlr.github.io/publication/2018-Segmentation)
6 | Yi Zhu<sup>1*</sup>, Karan Sapra<sup>2*</sup>, [Fitsum A. Reda](https://scholar.google.com/citations?user=quZ_qLYAAAAJ&hl=en)<sup>2</sup>, Kevin J. Shih<sup>2</sup>, Shawn Newsam<sup>1</sup>, Andrew Tao<sup>2</sup>, [Bryan Catanzaro](http://catanzaro.name/)<sup>2</sup>
7 | <sup>1</sup>UC Merced, <sup>2</sup>NVIDIA Corporation
8 | In CVPR 2019 (* equal contributions).
9 |
10 | [SDCNet: Video Prediction using Spatially Displaced Convolution](https://nv-adlr.github.io/publication/2018-SDCNet)
11 | [Fitsum A. Reda](https://scholar.google.com/citations?user=quZ_qLYAAAAJ&hl=en), Guilin Liu, Kevin J. Shih, Robert Kirby, Jon Barker, David Tarjan, Andrew Tao, [Bryan Catanzaro](http://catanzaro.name/)
12 | NVIDIA Corporation
13 | In ECCV 2018.
14 |
15 | ![Method overview](images/method.png)
16 |
17 | ## Installation
18 |
19 | # Get Semantic Segmentation source code
20 | git clone --recursive https://github.com/NVIDIA/semantic-segmentation.git
21 | cd semantic-segmentation
22 |
23 | # Build Docker Image
24 | docker build -t nvidia-segmentation -f Dockerfile .
25 |
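Once the image is built, you can start a container for training and inference. A minimal sketch; the data mount path is only an example, and on newer Docker versions you may need `--gpus all` instead of `--runtime=nvidia`:
```
docker run --runtime=nvidia -it \
    -v /path/to/your/data:/data \
    -v $PWD:/workspace/semantic-segmentation \
    nvidia-segmentation bash
```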
26 | If you prefer not to use docker, you can manually install the following requirements:
27 |
28 | * An NVIDIA GPU and CUDA 9.0 or higher. Some operations only have GPU implementations.
29 | * PyTorch (>= 0.5.1)
30 | * Python 3
31 | * numpy
32 | * sklearn
33 | * h5py
34 | * scikit-image
35 | * pillow
36 | * piexif
37 | * cffi
38 | * tqdm
39 | * dominate
40 | * tensorboardX
41 | * opencv-python
42 | * nose
43 | * ninja
44 |
45 |
46 | Multi-GPU training and mixed-precision training are supported, and the code provides examples for training and inference. For more help, type
47 |
48 | python3 train.py --help
49 |
50 |
51 |
52 | ## Network architectures
53 |
54 | Our repo now supports DeepLabV3+ architecture with different backbones, including `WideResNet38`, `SEResNeXt(50, 101)` and `ResNet(50,101)`.
55 |
56 |
57 | ## Pre-trained models
58 | We've included pre-trained models. Download the checkpoints to a folder named `pretrained_models`; a command-line download example follows the list below.
59 |
60 | * [pretrained_models/cityscapes_best.pth](https://drive.google.com/file/d/1P4kPaMY-SmQ3yPJQTJ7xMGAB_Su-1zTl/view?usp=sharing)[1071MB, WideResNet38 backbone]
61 | * [pretrained_models/camvid_best.pth](https://drive.google.com/file/d/1OzUCbFdXulB2P80Qxm7C3iNTeTP0Mvb_/view?usp=sharing)[1071MB, WideResNet38 backbone]
62 | * [pretrained_models/kitti_best.pth](https://drive.google.com/file/d/1OrTcqH_I3PHFiMlTTZJgBy8l_pladwtg/view?usp=sharing)[1071MB, WideResNet38 backbone]
63 | * [pretrained_models/sdc_cityscapes_vrec.pth.tar](https://drive.google.com/file/d/1OxnJo2tFEQs3vuY01ibPFjn3cRCo2yWt/view?usp=sharing)[38MB]
64 | * [pretrained_models/FlowNet2_checkpoint.pth.tar](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB]
65 |
66 | ImageNet Weights
67 | * [pretrained_models/wider_resnet38.pth.tar](https://drive.google.com/file/d/1OfKQPQXbXGbWAQJj2R82x6qyz6f-1U6t/view?usp=sharing)[833MB]
68 |
69 | Other Weights
70 | * [pretrained_models/cityscapes_cv0_seresnext50_nosdcaug.pth](https://drive.google.com/file/d/1aGdA1WAKKkU2y-87wSOE1prwrIzs_L-h/view?usp=sharing)[324MB]
71 | * [pretrained_models/cityscapes_cv0_wideresnet38_nosdcaug.pth](https://drive.google.com/file/d/1CKB7gpcPLgDLA7LuFJc46rYcNzF3aWzH/view?usp=sharing)[1.1GB]
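If you prefer to fetch checkpoints from the command line, one option is the `gdown` utility for Google Drive links; a sketch for the Cityscapes checkpoint, with the file id taken from the `cityscapes_best.pth` link above:
```
pip install gdown
mkdir -p pretrained_models
gdown 'https://drive.google.com/uc?id=1P4kPaMY-SmQ3yPJQTJ7xMGAB_Su-1zTl' -O pretrained_models/cityscapes_best.pth
```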
72 |
73 | ## Data Loaders
74 |
75 | Dataloaders for Cityscapes, Mapillary, Camvid and Kitti are available in [datasets](./datasets). Details on preparing each dataset can be found in [PREPARE_DATASETS.md](https://github.com/NVIDIA/semantic-segmentation/blob/master/PREPARE_DATASETS.md).
76 |
77 |
78 | ## Semantic segmentation demo for a single image
79 |
80 | If you want to try our trained model on any driving scene images, simply use
81 |
82 | ```
83 | CUDA_VISIBLE_DEVICES=0 python demo.py --demo-image YOUR_IMG --snapshot ./pretrained_models/cityscapes_best.pth --save-dir YOUR_SAVE_DIR
84 | ```
85 | This snapshot was trained on the Cityscapes dataset, with the `DeepLabV3+` architecture and a `WideResNet38` backbone. The predicted segmentation masks will be saved to `YOUR_SAVE_DIR`. Check it out.
86 |
87 | ## Semantic segmentation demo for a folder of images
88 |
89 | If you want to try our trained model on a folder of driving scene images, simply use
90 |
91 | ```
92 | CUDA_VISIBLE_DEVICES=0 python demo_folder.py --demo-folder YOUR_FOLDER --snapshot ./pretrained_models/cityscapes_best.pth --save-dir YOUR_SAVE_DIR
93 | ```
94 | This snapshot was trained on the Cityscapes dataset, with the `DeepLabV3+` architecture and a `WideResNet38` backbone. The predicted segmentation masks will be saved to `YOUR_SAVE_DIR`. Check it out.
95 |
96 | ## A quick start with light SEResNeXt50 backbone
97 |
98 | Note that, in this section, we use the standard train/val split in Cityscapes to train our model, which is `cv 0`.
99 |
100 | If you have fewer than 8 GPUs in your machine, change `--nproc_per_node=8` to the number of GPUs you have in all the .sh files under the `scripts` folder.
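For example, on a 4-GPU machine the launch line inside a script would change roughly as follows (a sketch; `<args>` stands for the remaining training arguments, which differ per script):
```
# before (8 GPUs): python -m torch.distributed.launch --nproc_per_node=8 train.py <args>
# after (4 GPUs):
python -m torch.distributed.launch --nproc_per_node=4 train.py <args>
```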
101 |
102 | ### Pre-Training on Mapillary
103 | First, you can pre-train a DeepLabV3+ model with an `SEResNeXt50` trunk on the Mapillary dataset. Set `__C.DATASET.MAPILLARY_DIR` in `config.py` to where you store the Mapillary data.
104 |
105 | ```
106 | ./scripts/train_mapillary_SEResNeXt50.sh
107 | ```
108 |
109 | When you first run training on a new dataset with the `--class_uniform_pct` flag on, it will take some time to preprocess the dataset. Depending on your machine, the preprocessing can take half an hour or more. Once it finishes, you will have a json file in your root folder, e.g., `mapillary_tile1024.json`. You can read more about `class uniform sampling` in our paper; the idea is to make sure that all classes are chosen approximately uniformly during training.
110 |
111 | ### Fine-tuning on Cityscapes
112 | Once you have the Mapillary pre-trained model (training mIoU should be 50+), you can start fine-tuning the model on the Cityscapes dataset. Set `__C.DATASET.CITYSCAPES_DIR` in `config.py` to where you store the Cityscapes data. Your final training mIoU should be 80+.
113 | ```
114 | ./scripts/train_cityscapes_SEResNeXt50.sh
115 | ```
116 |
117 | ### Inference
118 |
119 | Our inference code supports two evaluation modes: pooling-based and sliding-based. Pooling-based eval is faster than sliding-based eval but gives slightly lower numbers. We use `sliding` as the default.
120 | ```
121 | ./scripts/eval_cityscapes_SEResNeXt50.sh
122 | ```
123 |
124 | In the `result_save_location` you set, you will find several folders: `rgb`, `pred`, `compose` and `diff`. `rgb` contains the color-encoded predicted segmentation masks. `pred` contains what you need to submit to the evaluation server. `compose` contains the original video frames overlaid with the color-encoded predicted segmentation masks. `diff` contains the difference between our predictions and the ground truth.
125 |
126 | Right now, our inference code only supports the Cityscapes dataset.
127 |
128 |
129 | ## Reproducing our results with heavy WideResNet38 backbone
130 |
131 | Note that, in this section, we use an alternative train/val split in Cityscapes to train our model, which is `cv 2`. You can find the difference between `cv 0` and `cv 2` in the supplementary material section in our arXiv paper.
132 |
133 | ### Pre-Training on Mapillary
134 | ```
135 | ./scripts/train_mapillary_WideResNet38.sh
136 | ```
137 |
138 | ### Fine-tuning on Cityscapes
139 | ```
140 | ./scripts/train_cityscapes_WideResNet38.sh
141 | ```
142 |
143 | ### Inference
144 | ```
145 | ./scripts/eval_cityscapes_WideResNet38.sh
146 | ```
147 |
148 | For submitting to the Cityscapes benchmark, we switch to a multi-scale setting.
149 | ```
150 | ./scripts/submit_cityscapes_WideResNet38.sh
151 | ```
152 |
153 | Now you can zip the `pred` folder and upload it to the Cityscapes leaderboard. For the test submission, the `diff` folder is empty because there is no ground truth.
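For example, assuming your eval output went to `/path/to/result_save_location` (an illustrative path):
```
cd /path/to/result_save_location
zip -r cityscapes_submission.zip pred
```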
154 |
155 | At this point, you can already achieve top performance on the Cityscapes benchmark (83+ mIoU). To further boost segmentation performance, we can use the augmented dataset to improve the model's generalization capability.
156 |
157 | ### Label Propagation using Video Prediction
158 | First, you need to download the Cityscapes sequence dataset. Note that the sequence dataset is very large (a 324GB .zip file). Then we use the video prediction model to propagate GT segmentation masks to adjacent video frames, so that we have more annotated image-label pairs during training.
159 |
160 | ```
161 | cd ./sdcnet
162 |
163 | bash flownet2_pytorch/install.sh
164 |
165 | ./_aug.sh
166 | ```
167 |
168 | By default, we predict five past frames and five future frames, which effectively enlarges the dataset 10 times. If you prefer to propagate fewer or more time steps, change the `--propagate` argument accordingly. Enjoy the augmented dataset.
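For instance, to propagate only two frames in each direction instead of five, change the propagation argument where `_aug.sh` sets it; a sketch, assuming the script passes `--propagate 5` by default (check `_aug.sh` for the exact command it wraps):
```
# From inside ./sdcnet: reduce the propagation range from 5 to 2.
sed -i 's/--propagate 5/--propagate 2/' _aug.sh
```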
169 |
170 |
171 | ## Results on Cityscapes
172 |
173 | ![Results on Cityscapes](images/vis.png)
174 |
175 | ## Training IOU using fp16
176 |
177 | Training results for WideResNet38 and SEResNeXt50 backbones trained in fp16 on a DGX-1 (8× V100 GPUs). fp16 can significantly speed up experiments without losing much accuracy.
178 |
179 |
180 | | Model Name                        | Mean IOU | Training Time |
181 | |-----------------------------------|----------|---------------|
182 | | DeepWV3Plus(no sdc-aug)           | 81.4     | ~14 hrs       |
183 | | DeepSRNX50V3PlusD_m1(no sdc-aug)  | 80.0     | ~9 hrs        |
197 |
198 | ## Reference
199 |
200 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper or code accordingly:
201 |
202 | ```
203 | @inproceedings{semantic_cvpr19,
204 | author = {Yi Zhu* and Karan Sapra* and Fitsum A. Reda and Kevin J. Shih and Shawn Newsam and Andrew Tao and Bryan Catanzaro},
205 | title = {Improving Semantic Segmentation via Video Propagation and Label Relaxation},
206 | booktitle = {IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
207 | month = {June},
208 | year = {2019},
209 | url = {https://nv-adlr.github.io/publication/2018-Segmentation}
210 | }
211 | * indicates equal contribution
212 |
213 | @inproceedings{reda2018sdc,
214 | title={SDC-Net: Video prediction using spatially-displaced convolution},
215 | author={Reda, Fitsum A and Liu, Guilin and Shih, Kevin J and Kirby, Robert and Barker, Jon and Tarjan, David and Tao, Andrew and Catanzaro, Bryan},
216 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
217 | pages={718--733},
218 | year={2018}
219 | }
220 | ```
221 | We encourage people to contribute to our code base, provide suggestions, point out any issues, or submit solutions via merge requests; we hope this repo is useful.
222 |
223 | ## Acknowledgments
224 |
225 | Parts of the code were heavily derived from [pytorch-semantic-segmentation](https://github.com/ZijunDeng/pytorch-semantic-segmentation), [inplace-abn](https://github.com/mapillary/inplace_abn), [Pytorch](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py), [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch), [NVIDIA/flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch) and [Cadene](https://github.com/Cadene/pretrained-models.pytorch).
226 |
227 | Our initial models used SyncBN from [Synchronized Batch Norm](https://github.com/zhanghang1989/PyTorch-Encoding) but since then have been ported to [Apex SyncBN](https://github.com/NVIDIA/apex) developed by Jie Jiang.
228 |
229 | We would also like to thank Ming-Yu Liu and Peter Kontschieder.
230 |
231 | ## Coding style
232 | * 4 spaces for indentation rather than tabs
233 | * 100 character line length
234 | * PEP8 formatting
235 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 | ##############################################################################
30 | #Config
31 | ##############################################################################
32 |
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 | from __future__ import unicode_literals
38 |
39 |
40 | import torch
41 |
42 |
43 | from utils.attr_dict import AttrDict
44 |
45 |
46 | __C = AttrDict()
47 | cfg = __C
48 | __C.EPOCH = 0
49 | # Use Class Uniform Sampling to give each class proper sampling
50 | __C.CLASS_UNIFORM_PCT = 0.0
51 |
52 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch
53 | __C.BATCH_WEIGHTING = False
54 |
55 | # Border Relaxation Count
56 | __C.BORDER_WINDOW = 1
57 | # Epoch after which border relaxation is turned off
58 | __C.REDUCE_BORDER_EPOCH = -1
59 | # Comma-separated list of class ids to relax
60 | __C.STRICTBORDERCLASS = None
61 |
62 |
63 |
64 | #Attribute Dictionary for Dataset
65 | __C.DATASET = AttrDict()
66 | #Cityscapes Dir Location
67 | __C.DATASET.CITYSCAPES_DIR = ''
68 | #SDC Augmented Cityscapes Dir Location
69 | __C.DATASET.CITYSCAPES_AUG_DIR = ''
70 | #Mapillary Dataset Dir Location
71 | __C.DATASET.MAPILLARY_DIR = ''
72 | #Kitti Dataset Dir Location
73 | __C.DATASET.KITTI_DIR = ''
74 | #SDC Augmented Kitti Dataset Dir Location
75 | __C.DATASET.KITTI_AUG_DIR = ''
76 | #Camvid Dataset Dir Location
77 | __C.DATASET.CAMVID_DIR = ''
78 | #Number of splits to support
79 | __C.DATASET.CV_SPLITS = 3
80 |
81 |
82 | __C.MODEL = AttrDict()
83 | __C.MODEL.BN = 'regularnorm'
84 | __C.MODEL.BNFUNC = None
85 |
86 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True):
87 | """Call this function in your script after you have finished setting all cfg
88 | values that are necessary (e.g., merging a config from a file, merging
89 | command line config options, etc.). By default, this function will also
90 | mark the global cfg as immutable to prevent changing the global cfg settings
91 | during script execution (which can lead to hard to debug errors or code
92 | that's harder to understand than is necessary).
93 | """
94 |
95 | if hasattr(args, 'syncbn') and args.syncbn:
96 | if args.apex:
97 | import apex
98 | __C.MODEL.BN = 'apex-syncnorm'
99 | __C.MODEL.BNFUNC = apex.parallel.SyncBatchNorm
100 | else:
101 | raise Exception('No Support for SyncBN without Apex')
102 | else:
103 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
104 | print('Using regular batch norm')
105 |
106 | if not train_mode:
107 | cfg.immutable(True)
108 | return
109 | if args.class_uniform_pct:
110 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct
111 |
112 | if args.batch_weighting:
113 | __C.BATCH_WEIGHTING = True
114 |
115 | if args.jointwtborder:
116 | if args.strict_bdr_cls != '':
117 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")]
118 | if args.rlx_off_epoch > -1:
119 | __C.REDUCE_BORDER_EPOCH = args.rlx_off_epoch
120 |
121 | if make_immutable:
122 | cfg.immutable(True)
123 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataset setup and loaders
3 | """
4 | from datasets import cityscapes
5 | from datasets import mapillary
6 | from datasets import kitti
7 | from datasets import comma10k
8 | from datasets import camvid
9 | import torchvision.transforms as standard_transforms
10 |
11 | import transforms.joint_transforms as joint_transforms
12 | import transforms.transforms as extended_transforms
13 | from torch.utils.data import DataLoader
14 |
15 |
16 | def setup_loaders(args):
17 | """
18 | Setup data loaders [currently supports Cityscapes, Mapillary, ADE20k, KITTI, CamVid and comma10k]
19 | input: arguments passed by the user
20 | return: training data loader, validation data loader, train_set
21 | """
22 |
23 | if args.dataset == 'cityscapes':
24 | args.dataset_cls = cityscapes
25 | args.train_batch_size = args.bs_mult * args.ngpu
26 | if args.bs_mult_val > 0:
27 | args.val_batch_size = args.bs_mult_val * args.ngpu
28 | else:
29 | args.val_batch_size = args.bs_mult * args.ngpu
30 | elif args.dataset == 'mapillary':
31 | args.dataset_cls = mapillary
32 | args.train_batch_size = args.bs_mult * args.ngpu
33 | args.val_batch_size = 4
34 | elif args.dataset == 'ade20k':
35 | args.dataset_cls = ade20k
36 | args.train_batch_size = args.bs_mult * args.ngpu
37 | args.val_batch_size = 4
38 | elif args.dataset == 'kitti':
39 | args.dataset_cls = kitti
40 | args.train_batch_size = args.bs_mult * args.ngpu
41 | if args.bs_mult_val > 0:
42 | args.val_batch_size = args.bs_mult_val * args.ngpu
43 | else:
44 | args.val_batch_size = args.bs_mult * args.ngpu
45 | elif args.dataset == 'comma10k':
46 | args.dataset_cls = comma10k
47 | args.train_batch_size = args.bs_mult * args.ngpu
48 | if args.bs_mult_val > 0:
49 | args.val_batch_size = args.bs_mult_val * args.ngpu
50 | else:
51 | args.val_batch_size = args.bs_mult * args.ngpu
52 | elif args.dataset == 'camvid':
53 | args.dataset_cls = camvid
54 | args.train_batch_size = args.bs_mult * args.ngpu
55 | if args.bs_mult_val > 0:
56 | args.val_batch_size = args.bs_mult_val * args.ngpu
57 | else:
58 | args.val_batch_size = args.bs_mult * args.ngpu
59 | elif args.dataset == 'null_loader':
60 | args.dataset_cls = null_loader
61 | args.train_batch_size = args.bs_mult * args.ngpu
62 | if args.bs_mult_val > 0:
63 | args.val_batch_size = args.bs_mult_val * args.ngpu
64 | else:
65 | args.val_batch_size = args.bs_mult * args.ngpu
66 | else:
67 | raise Exception('Dataset {} is not supported'.format(args.dataset))
68 |
69 | # Readjust batch size to mini-batch size for apex
70 | if args.apex:
71 | args.train_batch_size = args.bs_mult
72 | args.val_batch_size = args.bs_mult_val
73 |
74 | args.num_workers = 4 * args.ngpu
75 | if args.test_mode:
76 | args.num_workers = 1
77 |
78 |
79 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
80 |
81 | # Geometric image transformations
82 | train_joint_transform_list = [
83 | joint_transforms.RandomSizeAndCrop(args.crop_size,
84 | False,
85 | pre_size=args.pre_size,
86 | scale_min=args.scale_min,
87 | scale_max=args.scale_max,
88 | ignore_index=args.dataset_cls.ignore_label),
89 | joint_transforms.Resize(args.crop_size),
90 | joint_transforms.RandomHorizontallyFlip()]
91 | train_joint_transform = joint_transforms.Compose(train_joint_transform_list)
92 |
93 | # Image appearance transformations
94 | train_input_transform = []
95 | if args.color_aug:
96 | train_input_transform += [extended_transforms.ColorJitter(
97 | brightness=args.color_aug,
98 | contrast=args.color_aug,
99 | saturation=args.color_aug,
100 | hue=args.color_aug)]
101 |
102 | if args.bblur:
103 | train_input_transform += [extended_transforms.RandomBilateralBlur()]
104 | elif args.gblur:
105 | train_input_transform += [extended_transforms.RandomGaussianBlur()]
106 | else:
107 | pass
108 |
109 |
110 |
111 | train_input_transform += [standard_transforms.ToTensor(),
112 | standard_transforms.Normalize(*mean_std)]
113 | train_input_transform = standard_transforms.Compose(train_input_transform)
114 |
115 | val_input_transform = standard_transforms.Compose([
116 | standard_transforms.ToTensor(),
117 | standard_transforms.Normalize(*mean_std)
118 | ])
119 |
120 | target_transform = extended_transforms.MaskToTensor()
121 |
122 | if args.jointwtborder:
123 | target_train_transform = extended_transforms.RelaxedBoundaryLossToTensor(args.dataset_cls.ignore_label,
124 | args.dataset_cls.num_classes)
125 | else:
126 | target_train_transform = extended_transforms.MaskToTensor()
127 |
128 | if args.dataset == 'cityscapes':
129 | city_mode = 'train' ## Can be trainval
130 | city_quality = 'fine'
131 | if args.class_uniform_pct:
132 | if args.coarse_boost_classes:
133 | coarse_boost_classes = \
134 | [int(c) for c in args.coarse_boost_classes.split(',')]
135 | else:
136 | coarse_boost_classes = None
137 | train_set = args.dataset_cls.CityScapesUniform(
138 | city_quality, city_mode, args.maxSkip,
139 | joint_transform_list=train_joint_transform_list,
140 | transform=train_input_transform,
141 | target_transform=target_train_transform,
142 | dump_images=args.dump_augmentation_images,
143 | cv_split=args.cv,
144 | class_uniform_pct=args.class_uniform_pct,
145 | class_uniform_tile=args.class_uniform_tile,
146 | test=args.test_mode,
147 | coarse_boost_classes=coarse_boost_classes)
148 | else:
149 | train_set = args.dataset_cls.CityScapes(
150 | city_quality, city_mode, 0,
151 | joint_transform=train_joint_transform,
152 | transform=train_input_transform,
153 | target_transform=target_train_transform,
154 | dump_images=args.dump_augmentation_images,
155 | cv_split=args.cv)
156 |
157 | val_set = args.dataset_cls.CityScapes('fine', 'val', 0,
158 | transform=val_input_transform,
159 | target_transform=target_transform,
160 | cv_split=args.cv)
161 | elif args.dataset == 'mapillary':
162 | eval_size = 1536
163 | val_joint_transform_list = [
164 | joint_transforms.ResizeHeight(eval_size),
165 | joint_transforms.CenterCropPad(eval_size, ignore_index=args.dataset_cls.ignore_label)]
166 | train_set = args.dataset_cls.Mapillary(
167 | 'semantic', 'train',
168 | joint_transform_list=train_joint_transform_list,
169 | transform=train_input_transform,
170 | target_transform=target_train_transform,
171 | dump_images=args.dump_augmentation_images,
172 | class_uniform_pct=args.class_uniform_pct,
173 | class_uniform_tile=args.class_uniform_tile,
174 | test=args.test_mode)
175 | val_set = args.dataset_cls.Mapillary(
176 | 'semantic', 'val',
177 | joint_transform_list=val_joint_transform_list,
178 | transform=val_input_transform,
179 | target_transform=target_transform,
180 | test=False)
181 | elif args.dataset == 'ade20k':
182 | eval_size = 384
183 | val_joint_transform_list = [
184 | joint_transforms.ResizeHeight(eval_size),
185 | joint_transforms.CenterCropPad(eval_size)]
186 |
187 | train_set = args.dataset_cls.ade20k(
188 | 'semantic', 'train',
189 | joint_transform_list=train_joint_transform_list,
190 | transform=train_input_transform,
191 | target_transform=target_train_transform,
192 | dump_images=args.dump_augmentation_images,
193 | class_uniform_pct=args.class_uniform_pct,
194 | class_uniform_tile=args.class_uniform_tile,
195 | test=args.test_mode)
196 | val_set = args.dataset_cls.ade20k(
197 | 'semantic', 'val',
198 | joint_transform_list=val_joint_transform_list,
199 | transform=val_input_transform,
200 | target_transform=target_transform,
201 | test=False)
202 | elif args.dataset == 'kitti':
203 | # eval_size_h = 384
204 | # eval_size_w = 1280
205 | # val_joint_transform_list = [
206 | # joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
207 |
208 | train_set = args.dataset_cls.KITTI(
209 | 'semantic', 'train', args.maxSkip,
210 | joint_transform_list=train_joint_transform_list,
211 | transform=train_input_transform,
212 | target_transform=target_train_transform,
213 | dump_images=args.dump_augmentation_images,
214 | class_uniform_pct=args.class_uniform_pct,
215 | class_uniform_tile=args.class_uniform_tile,
216 | test=args.test_mode,
217 | cv_split=args.cv,
218 | scf=args.scf,
219 | hardnm=args.hardnm)
220 | val_set = args.dataset_cls.KITTI(
221 | 'semantic', 'trainval', 0,
222 | joint_transform_list=None,
223 | transform=val_input_transform,
224 | target_transform=target_transform,
225 | test=False,
226 | cv_split=args.cv,
227 | scf=None)
228 | elif args.dataset == 'comma10k':
229 | train_set = args.dataset_cls.COMMA10K(
230 | 'semantic', 'train', args.maxSkip,
231 | joint_transform_list=train_joint_transform_list,
232 | transform=train_input_transform,
233 | target_transform=target_train_transform,
234 | dump_images=args.dump_augmentation_images,
235 | class_uniform_pct=args.class_uniform_pct,
236 | class_uniform_tile=args.class_uniform_tile,
237 | test=args.test_mode,
238 | cv_split=args.cv,
239 | scf=args.scf,
240 | hardnm=args.hardnm)
241 | val_set = args.dataset_cls.COMMA10K(
242 | 'semantic', 'trainval', 0,
243 | joint_transform_list=None,
244 | transform=val_input_transform,
245 | target_transform=target_transform,
246 | test=False,
247 | cv_split=args.cv,
248 | scf=None)
249 | elif args.dataset == 'camvid':
250 | # eval_size_h = 384
251 | # eval_size_w = 1280
252 | # val_joint_transform_list = [
253 | # joint_transforms.ResizeHW(eval_size_h, eval_size_w)]
254 |
255 | train_set = args.dataset_cls.CAMVID(
256 | 'semantic', 'trainval', args.maxSkip,
257 | joint_transform_list=train_joint_transform_list,
258 | transform=train_input_transform,
259 | target_transform=target_train_transform,
260 | dump_images=args.dump_augmentation_images,
261 | class_uniform_pct=args.class_uniform_pct,
262 | class_uniform_tile=args.class_uniform_tile,
263 | test=args.test_mode,
264 | cv_split=args.cv,
265 | scf=args.scf,
266 | hardnm=args.hardnm)
267 | val_set = args.dataset_cls.CAMVID(
268 | 'semantic', 'test', 0,
269 | joint_transform_list=None,
270 | transform=val_input_transform,
271 | target_transform=target_transform,
272 | test=False,
273 | cv_split=args.cv,
274 | scf=None)
275 |
276 | elif args.dataset == 'null_loader':
277 | train_set = args.dataset_cls.null_loader(args.crop_size)
278 | val_set = args.dataset_cls.null_loader(args.crop_size)
279 | else:
280 | raise Exception('Dataset {} is not supported'.format(args.dataset))
281 |
282 | if args.apex:
283 | from datasets.sampler import DistributedSampler
284 | train_sampler = DistributedSampler(train_set, pad=True, permutation=True, consecutive_sample=False)
285 | val_sampler = DistributedSampler(val_set, pad=False, permutation=False, consecutive_sample=False)
286 |
287 | else:
288 | train_sampler = None
289 | val_sampler = None
290 |
291 | train_loader = DataLoader(train_set, batch_size=args.train_batch_size,
292 | num_workers=args.num_workers, shuffle=(train_sampler is None), drop_last=True, sampler = train_sampler)
293 | val_loader = DataLoader(val_set, batch_size=args.val_batch_size,
294 | num_workers=args.num_workers // 2 , shuffle=False, drop_last=False, sampler = val_sampler)
295 |
296 | return train_loader, val_loader, train_set
297 |
298 |
--------------------------------------------------------------------------------
/datasets/camvid.py:
--------------------------------------------------------------------------------
1 | """
2 | Camvid Dataset Loader
3 | """
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils import data
10 | import logging
11 | import datasets.uniform as uniform
12 | import json
13 | from config import cfg
14 |
15 |
16 | # trainid_to_name = cityscapes_labels.trainId2name
17 | # id_to_trainid = cityscapes_labels.label2trainid
18 | num_classes = 11
19 | ignore_label = 11
20 | root = cfg.DATASET.CAMVID_DIR
21 |
22 | palette = [128, 128, 128,
23 | 128, 0, 0,
24 | 192, 192, 128,
25 | 128, 64, 128,
26 | 0, 0, 192,
27 | 128, 128, 0,
28 | 192, 128, 128,
29 | 64, 64, 128,
30 | 64, 0, 128,
31 | 64, 64, 0,
32 | 0, 128, 192]
33 |
34 |
35 | CAMVID_CLASSES = ['Sky',
36 | 'Building',
37 | 'Column-Pole',
38 | 'Road',
39 | 'Sidewalk',
40 | 'Tree',
41 | 'Sign-Symbol',
42 | 'Fence',
43 | 'Car',
44 | 'Pedestrian',
45 | 'Bicyclist',
46 | 'Void']
47 |
48 | CAMVID_CLASS_COLORS = [
49 | (128, 128, 128),
50 | (128, 0, 0),
51 | (192, 192, 128),
52 | (128, 64, 128),
53 | (0, 0, 192),
54 | (128, 128, 0),
55 | (192, 128, 128),
56 | (64, 64, 128),
57 | (64, 0, 128),
58 | (64, 64, 0),
59 | (0, 128, 192),
60 | (0, 0, 0),
61 | ]
62 |
63 | zero_pad = 256 * 3 - len(palette)
64 | for i in range(zero_pad):
65 | palette.append(0)
66 |
67 | def colorize_mask(mask):
68 | # mask: numpy array of the mask
69 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
70 | new_mask.putpalette(palette)
71 | return new_mask
72 |
73 | def add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip):
74 |
75 | c_items = os.listdir(img_path)
76 | c_items.sort()
77 | items = []
78 | aug_items = []
79 |
80 | for it in c_items:
81 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
82 | items.append(item)
83 | if mode != 'test' and maxSkip > 0:
84 | seq_info = it.split("_")
85 | cur_seq_id = seq_info[-1][:-4]
86 |
87 | if seq_info[0] == "0001TP":
88 | prev_seq_id = "%06d" % (int(cur_seq_id) - maxSkip)
89 | next_seq_id = "%06d" % (int(cur_seq_id) + maxSkip)
90 | elif seq_info[0] == "0006R0":
91 | prev_seq_id = "f%05d" % (int(cur_seq_id[1:]) - maxSkip)
92 | next_seq_id = "f%05d" % (int(cur_seq_id[1:]) + maxSkip)
93 | else:
94 | prev_seq_id = "%05d" % (int(cur_seq_id) - maxSkip)
95 | next_seq_id = "%05d" % (int(cur_seq_id) + maxSkip)
96 |
97 | prev_it = seq_info[0] + "_" + prev_seq_id + '.png'
98 | next_it = seq_info[0] + "_" + next_seq_id + '.png'
99 |
100 | prev_item = (os.path.join(aug_img_path, prev_it), os.path.join(aug_mask_path, prev_it))
101 | next_item = (os.path.join(aug_img_path, next_it), os.path.join(aug_mask_path, next_it))
102 | if os.path.isfile(prev_item[0]) and os.path.isfile(prev_item[1]):
103 | aug_items.append(prev_item)
104 | if os.path.isfile(next_item[0]) and os.path.isfile(next_item[1]):
105 | aug_items.append(next_item)
106 | return items, aug_items
107 |
108 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
109 |
110 | items = []
111 | aug_items = []
112 | assert quality == 'semantic'
113 | assert mode in ['train', 'val', 'trainval', 'test']
114 |
115 | # img_dir_name = "SegNet/CamVid"
116 | original_img_dir = "LargeScale/CamVid"
117 | augmented_img_dir = "camvid_aug3/CamVid"
118 |
119 | img_path = os.path.join(root, original_img_dir, 'train')
120 | mask_path = os.path.join(root, original_img_dir, 'trainannot')
121 | aug_img_path = os.path.join(root, augmented_img_dir, 'train')
122 | aug_mask_path = os.path.join(root, augmented_img_dir, 'trainannot')
123 |
124 | train_items, train_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
125 | logging.info('Camvid has a total of {} train images'.format(len(train_items)))
126 |
127 | img_path = os.path.join(root, original_img_dir, 'val')
128 | mask_path = os.path.join(root, original_img_dir, 'valannot')
129 | aug_img_path = os.path.join(root, augmented_img_dir, 'val')
130 | aug_mask_path = os.path.join(root, augmented_img_dir, 'valannot')
131 |
132 | val_items, val_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
133 | logging.info('Camvid has a total of {} validation images'.format(len(val_items)))
134 |
135 | if mode == 'test':
136 | img_path = os.path.join(root, original_img_dir, 'test')
137 | mask_path = os.path.join(root, original_img_dir, 'testannot')
138 | test_items, test_aug_items = add_items(img_path, mask_path, aug_img_path, aug_mask_path, mode, maxSkip)
139 | logging.info('Camvid has a total of {} test images'.format(len(test_items)))
140 |
141 | if mode == 'train':
142 | items = train_items
143 | elif mode == 'val':
144 | items = val_items
145 | elif mode == 'trainval':
146 | items = train_items + val_items
147 | aug_items = train_aug_items + val_aug_items
148 | elif mode == 'test':
149 | items = test_items
150 | aug_items = []
151 | else:
152 | logging.info('Unknown mode {}'.format(mode))
153 | sys.exit()
154 |
155 | logging.info('Camvid-{}: {} images'.format(mode, len(items)))
156 |
157 | return items, aug_items
158 |
159 | class CAMVID(data.Dataset):
160 |
161 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
162 | transform=None, target_transform=None, dump_images=False,
163 | class_uniform_pct=0, class_uniform_tile=0, test=False,
164 | cv_split=None, scf=None, hardnm=0):
165 |
166 | self.quality = quality
167 | self.mode = mode
168 | self.maxSkip = maxSkip
169 | self.joint_transform_list = joint_transform_list
170 | self.transform = transform
171 | self.target_transform = target_transform
172 | self.dump_images = dump_images
173 | self.class_uniform_pct = class_uniform_pct
174 | self.class_uniform_tile = class_uniform_tile
175 | self.scf = scf
176 | self.hardnm = hardnm
177 | self.cv_split = cv_split
178 | self.centroids = []
179 |
180 | self.imgs, self.aug_imgs = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
181 | assert len(self.imgs), 'Found 0 images, please check the data set'
182 |
183 | # Centroids for GT data
184 | if self.class_uniform_pct > 0:
185 | json_fn = 'camvid_tile{}_cv{}_{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode)
186 |
187 | if os.path.isfile(json_fn):
188 | with open(json_fn, 'r') as json_data:
189 | centroids = json.load(json_data)
190 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
191 | else:
192 | self.centroids = uniform.class_centroids_all(
193 | self.imgs,
194 | num_classes,
195 | id2trainid=None,
196 | tile_size=class_uniform_tile)
197 | with open(json_fn, 'w') as outfile:
198 | json.dump(self.centroids, outfile, indent=4)
199 |
200 | self.fine_centroids = self.centroids.copy()
201 |
202 | if self.maxSkip > 0:
203 | json_fn = 'camvid_tile{}_cv{}_{}_skip{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.maxSkip)
204 | if os.path.isfile(json_fn):
205 | with open(json_fn, 'r') as json_data:
206 | centroids = json.load(json_data)
207 | self.aug_centroids = {int(idx): centroids[idx] for idx in centroids}
208 | else:
209 | self.aug_centroids = uniform.class_centroids_all(
210 | self.aug_imgs,
211 | num_classes,
212 | id2trainid=None,
213 | tile_size=class_uniform_tile)
214 | with open(json_fn, 'w') as outfile:
215 | json.dump(self.aug_centroids, outfile, indent=4)
216 |
217 | for class_id in range(num_classes):
218 | self.centroids[class_id].extend(self.aug_centroids[class_id])
219 |
220 | self.build_epoch()
221 |
222 | def build_epoch(self, cut=False):
223 |
224 | if self.class_uniform_pct > 0:
225 | if cut:
226 | self.imgs_uniform = uniform.build_epoch(self.imgs,
227 | self.fine_centroids,
228 | num_classes,
229 | cfg.CLASS_UNIFORM_PCT)
230 | else:
231 | self.imgs_uniform = uniform.build_epoch(self.imgs,
232 | self.centroids,
233 | num_classes,
234 | cfg.CLASS_UNIFORM_PCT)
235 | else:
236 | self.imgs_uniform = self.imgs
237 |
238 |
239 | def __getitem__(self, index):
240 | elem = self.imgs_uniform[index]
241 | centroid = None
242 | if len(elem) == 4:
243 | img_path, mask_path, centroid, class_id = elem
244 | else:
245 | img_path, mask_path = elem
246 |
247 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
248 | img_name = os.path.splitext(os.path.basename(img_path))[0]
249 |
250 | # Image Transformations
251 | if self.joint_transform_list is not None:
252 | for idx, xform in enumerate(self.joint_transform_list):
253 | if idx == 0 and centroid is not None:
254 | # HACK
255 | # We assume that the first transform is capable of taking
256 | # in a centroid
257 | img, mask = xform(img, mask, centroid)
258 | else:
259 | img, mask = xform(img, mask)
260 |
261 | # Debug
262 | if self.dump_images and centroid is not None:
263 | outdir = './dump_imgs_{}'.format(self.mode)
264 | os.makedirs(outdir, exist_ok=True)
265 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
266 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
267 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
268 | mask_img = colorize_mask(np.array(mask))
269 | img.save(out_img_fn)
270 | mask_img.save(out_msk_fn)
271 |
272 | if self.transform is not None:
273 | img = self.transform(img)
274 | if self.target_transform is not None:
275 | mask = self.target_transform(mask)
276 |
277 | return img, mask, img_name
278 |
279 | def __len__(self):
280 | return len(self.imgs_uniform)
281 |
282 |
283 |
--------------------------------------------------------------------------------
/datasets/cityscapes_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | # File taken from https://github.com/mcordts/cityscapesScripts/
3 | # License File Available at:
4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt
5 |
6 | # ----------------------
7 | # The Cityscapes Dataset
8 | # ----------------------
9 | #
10 | #
11 | # License agreement
12 | # -----------------
13 | #
14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
15 | #
16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
21 | #
22 | #
23 | # Contact
24 | # -------
25 | #
26 | # Marius Cordts, Mohamed Omran
27 | # www.cityscapes-dataset.net
28 |
29 | """
30 | from collections import namedtuple
31 |
32 |
33 | #--------------------------------------------------------------------------------
34 | # Definitions
35 | #--------------------------------------------------------------------------------
36 |
37 | # a label and all meta information
38 | Label = namedtuple( 'Label' , [
39 |
40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
41 | # We use them to uniquely name a class
42 |
43 | 'id' , # An integer ID that is associated with this label.
44 | # The IDs are used to represent the label in ground truth images
45 | # An ID of -1 means that this label does not have an ID and thus
46 | # is ignored when creating ground truth images (e.g. license plate).
47 | # Do not modify these IDs, since exactly these IDs are expected by the
48 | # evaluation server.
49 |
50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
51 | # ground truth images with train IDs, using the tools provided in the
52 | # 'preparation' folder. However, make sure to validate or submit results
53 | # to our evaluation server using the regular IDs above!
54 | # For trainIds, multiple labels might have the same ID. Then, these labels
55 | # are mapped to the same class in the ground truth images. For the inverse
56 | # mapping, we use the label that is defined first in the list below.
57 | # For example, mapping all void-type classes to the same ID in training,
58 | # might make sense for some approaches.
59 | # Max value is 255!
60 |
61 | 'category' , # The name of the category that this label belongs to
62 |
63 | 'categoryId' , # The ID of this category. Used to create ground truth images
64 | # on category level.
65 |
66 | 'hasInstances', # Whether this label distinguishes between single instances or not
67 |
68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
69 | # during evaluations or not
70 |
71 | 'color' , # The color of this label
72 | ] )
73 |
74 |
75 | #--------------------------------------------------------------------------------
76 | # A list of all labels
77 | #--------------------------------------------------------------------------------
78 |
79 | # Please adapt the train IDs as appropriate for you approach.
80 | # Note that you might want to ignore labels with ID 255 during training.
81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
82 | # Make sure to provide your results using the original IDs and not the training IDs.
83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
84 |
85 | labels = [
86 | # name id trainId category catId hasInstances ignoreInEval color
87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ),
106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ),
122 | ]
123 |
124 |
125 | #--------------------------------------------------------------------------------
126 | # Create dictionaries for a fast lookup
127 | #--------------------------------------------------------------------------------
128 |
129 | # Please refer to the main method below for example usages!
130 |
131 | # name to label object
132 | name2label = { label.name : label for label in labels }
133 | # id to label object
134 | id2label = { label.id : label for label in labels }
135 | # trainId to label object
136 | trainId2label = { label.trainId : label for label in reversed(labels) }
137 | # label2trainid
138 | label2trainid = { label.id : label.trainId for label in labels }
139 | # trainId to label object
140 | trainId2name = { label.trainId : label.name for label in labels }
141 | trainId2color = { label.trainId : label.color for label in labels }
142 | # category to list of label objects
143 | category2labels = {}
144 | for label in labels:
145 | category = label.category
146 | if category in category2labels:
147 | category2labels[category].append(label)
148 | else:
149 | category2labels[category] = [label]
150 |
151 | #--------------------------------------------------------------------------------
152 | # Assure single instance name
153 | #--------------------------------------------------------------------------------
154 |
155 | # returns the label name that describes a single instance (if possible)
156 | # e.g. input | output
157 | # ----------------------
158 | # car | car
159 | # cargroup | car
160 | # foo | None
161 | # foogroup | None
162 | # skygroup | None
163 | def assureSingleInstanceName( name ):
164 | # if the name is known, it is not a group
165 | if name in name2label:
166 | return name
167 | # test if the name actually denotes a group
168 | if not name.endswith("group"):
169 | return None
170 | # remove group
171 | name = name[:-len("group")]
172 | # test if the new name exists
173 | if not name in name2label:
174 | return None
175 | # test if the new name denotes a label that actually has instances
176 | if not name2label[name].hasInstances:
177 | return None
178 | # all good then
179 | return name
180 |
181 | #--------------------------------------------------------------------------------
182 | # Main for testing
183 | #--------------------------------------------------------------------------------
184 |
185 | # just a dummy main
186 | if __name__ == "__main__":
187 | # Print all the labels
188 | print("List of cityscapes labels:")
189 | print("")
190 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )))
191 | print((" " + ('-' * 98)))
192 | for label in labels:
193 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )))
194 | print("")
195 |
196 | print("Example usages:")
197 |
198 | # Map from name to label
199 | name = 'car'
200 | id = name2label[name].id
201 | print(("ID of label '{name}': {id}".format( name=name, id=id )))
202 |
203 | # Map from ID to label
204 | category = id2label[id].category
205 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category )))
206 |
207 | # Map from trainID to label
208 | trainId = 0
209 | name = trainId2label[trainId].name
210 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )))
211 |
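These lookup tables are what the rest of the repo relies on to move between raw label ids, train ids, and colors. A minimal sketch of that round trip, assuming a 2-D numpy array of raw Cityscapes ids (the array values below are only illustrative):

    import numpy as np
    import datasets.cityscapes_labels as csl

    raw = np.array([[7, 26], [33, 0]], dtype=np.uint8)      # road, car, bicycle, unlabeled

    # raw id -> train id (ignored classes map to 255)
    train = np.full(raw.shape, 255, dtype=np.int32)
    for label_id, train_id in csl.label2trainid.items():
        train[raw == label_id] = train_id

    # train id -> RGB color, e.g. for visualizing a prediction
    color = np.zeros(train.shape + (3,), dtype=np.uint8)
    for train_id, rgb in csl.trainId2color.items():
        if 0 <= train_id < 19:                              # skip the ignore / -1 entries
            color[train == train_id] = rgb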
--------------------------------------------------------------------------------
/datasets/comma10k.py:
--------------------------------------------------------------------------------
1 | """
2 | Comma10k Dataset Loader (adapted from the KITTI loader)
3 | https://github.com/commaai/comma10k
4 | """
5 |
6 | import os
7 | import sys
8 | import numpy as np
9 | from PIL import Image
10 | from torch.utils import data
11 | import logging
12 | import datasets.uniform as uniform
13 | import datasets.cityscapes_labels as cityscapes_labels
14 | import json
15 | from config import cfg
16 |
17 |
18 | trainid_to_name = cityscapes_labels.trainId2name
19 | #id_to_trainid = cityscapes_labels.label2trainid
20 |
21 |
22 | id_to_trainid = {
23 | 0x40: 0, # road
24 | 0xff: 1, # lane marking
25 | 0x80: 2, # sky/undrivable
26 | 0x00: 3, # car
27 | 0xcc: 4} # ego
28 |
29 | num_classes = 5
30 | ignore_label = 255
31 | root = cfg.DATASET.KITTI_DIR
32 | aug_root = cfg.DATASET.KITTI_AUG_DIR
33 |
34 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
35 | 153, 153, 153, 250, 170, 30,
36 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
37 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
38 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
39 | zero_pad = 256 * 3 - len(palette)
40 | for i in range(zero_pad):
41 | palette.append(0)
42 |
43 | def colorize_mask(mask):
44 | # mask: numpy array of the mask
45 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
46 | new_mask.putpalette(palette)
47 | return new_mask
48 |
49 | def get_train_val(cv_split, all_items):
50 |     # 90/10 train/val split over the first 200 sorted images, three fixed splits for cross validation (kept from the KITTI loader)
51 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
52 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
53 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
54 |
55 | train_set = []
56 | val_set = []
57 |
58 | if cv_split == 0:
59 | for i in range(200):
60 | if i in val_0:
61 | val_set.append(all_items[i])
62 | else:
63 | train_set.append(all_items[i])
64 | elif cv_split == 1:
65 | for i in range(200):
66 | if i in val_1:
67 | val_set.append(all_items[i])
68 | else:
69 | train_set.append(all_items[i])
70 | elif cv_split == 2:
71 | for i in range(200):
72 | if i in val_2:
73 | val_set.append(all_items[i])
74 | else:
75 | train_set.append(all_items[i])
76 | else:
77 | logging.info('Unknown cv_split {}'.format(cv_split))
78 | sys.exit()
79 |
80 | return train_set, val_set
81 |
82 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
83 | items = []
84 | all_items = []
85 | aug_items = []
86 |
87 | assert quality == 'semantic'
88 | assert mode in ['train', 'val', 'trainval']
89 | # note that train and val are randomly determined, no official split
90 |
91 | """
92 | img_dir_name = "training"
93 | img_path = os.path.join(root, img_dir_name, 'image_2')
94 | mask_path = os.path.join(root, img_dir_name, 'semantic')
95 | """
96 |
97 | img_path = "/raid/comma10k/imgs"
98 | mask_path = "/raid/comma10k/masks"
99 |
100 | c_items = os.listdir(img_path)
101 | c_items.sort()
102 |
103 | for it in c_items:
104 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
105 | all_items.append(item)
106 |     logging.info('Comma10k has a total of {} images'.format(len(all_items)))
107 |
108 | # split into train/val
109 | train_set, val_set = get_train_val(cv_split, all_items)
110 |
111 | if mode == 'train':
112 | items = train_set
113 | elif mode == 'val':
114 | items = val_set
115 | elif mode == 'trainval':
116 | items = train_set + val_set
117 | else:
118 | logging.info('Unknown mode {}'.format(mode))
119 | sys.exit()
120 |
121 |     logging.info('Comma10k-{}: {} images'.format(mode, len(items)))
122 |
123 | return items, aug_items
124 |
125 | def make_test_dataset(quality, mode, maxSkip=0, cv_split=0):
126 | items = []
127 | assert quality == 'semantic'
128 | assert mode == 'test'
129 |
130 | img_dir_name = "testing"
131 | img_path = os.path.join(root, img_dir_name, 'image_2')
132 |
133 | c_items = os.listdir(img_path)
134 | c_items.sort()
135 | for it in c_items:
136 | item = (os.path.join(img_path, it), None)
137 | items.append(item)
138 | logging.info('KITTI has a total of {} test images'.format(len(items)))
139 |
140 | return items, []
141 |
142 | class COMMA10K(data.Dataset):
143 |
144 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
145 | transform=None, target_transform=None, dump_images=False,
146 | class_uniform_pct=0, class_uniform_tile=0, test=False,
147 | cv_split=None, scf=None, hardnm=0):
148 |
149 | self.quality = quality
150 | self.mode = mode
151 | self.maxSkip = maxSkip
152 | self.joint_transform_list = joint_transform_list
153 | self.transform = transform
154 | self.target_transform = target_transform
155 | self.dump_images = dump_images
156 | self.class_uniform_pct = class_uniform_pct
157 | self.class_uniform_tile = class_uniform_tile
158 | self.scf = scf
159 | self.hardnm = hardnm
160 |
161 | if cv_split:
162 | self.cv_split = cv_split
163 | assert cv_split < cfg.DATASET.CV_SPLITS, \
164 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
165 | cv_split, cfg.DATASET.CV_SPLITS)
166 | else:
167 | self.cv_split = 0
168 |
169 | if self.mode == 'test':
170 | self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split)
171 | else:
172 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
173 | assert len(self.imgs), 'Found 0 images, please check the data set'
174 |
175 | # Centroids for GT data
176 | if self.class_uniform_pct > 0:
177 | if self.scf:
178 |                 json_fn = 'comma10k_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split)
179 |             else:
180 |                 json_fn = 'comma10k_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm)
181 | if os.path.isfile(json_fn):
182 | with open(json_fn, 'r') as json_data:
183 | centroids = json.load(json_data)
184 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
185 | else:
186 | if self.scf:
187 |                     self.centroids = uniform.class_centroids_all(
188 | self.imgs,
189 | num_classes,
190 | id2trainid=id_to_trainid,
191 | tile_size=class_uniform_tile)
192 | else:
193 | self.centroids = uniform.class_centroids_all(
194 | self.imgs,
195 | num_classes,
196 | id2trainid=id_to_trainid,
197 | tile_size=class_uniform_tile)
198 | with open(json_fn, 'w') as outfile:
199 | json.dump(self.centroids, outfile, indent=4)
200 |
201 | self.build_epoch()
202 |
203 | def build_epoch(self, cut=False):
204 | if self.class_uniform_pct > 0:
205 | self.imgs_uniform = uniform.build_epoch(self.imgs,
206 | self.centroids,
207 | num_classes,
208 | cfg.CLASS_UNIFORM_PCT)
209 | else:
210 | self.imgs_uniform = self.imgs
211 |
212 | def __getitem__(self, index):
213 | elem = self.imgs_uniform[index]
214 | centroid = None
215 | if len(elem) == 4:
216 | img_path, mask_path, centroid, class_id = elem
217 | else:
218 | img_path, mask_path = elem
219 |
220 | if self.mode == 'test':
221 | img, mask = Image.open(img_path).convert('RGB'), None
222 | else:
223 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
224 | img_name = os.path.splitext(os.path.basename(img_path))[0]
225 |
226 | # kitti scale correction factor
227 | if self.mode == 'train' or self.mode == 'trainval':
228 | if self.scf:
229 | width, height = img.size
230 | img = img.resize((width*2, height*2), Image.BICUBIC)
231 | mask = mask.resize((width*2, height*2), Image.NEAREST)
232 | elif self.mode == 'val':
233 | width, height = 1242, 376
234 | img = img.resize((width, height), Image.BICUBIC)
235 | mask = mask.resize((width, height), Image.NEAREST)
236 | elif self.mode == 'test':
237 | img_keepsize = img.copy()
238 | width, height = 1280, 384
239 | img = img.resize((width, height), Image.BICUBIC)
240 | else:
241 |             logging.info('Unknown mode {}'.format(self.mode))
242 | sys.exit()
243 |
244 | if self.mode != 'test':
245 | mask = np.array(mask)[:, :, 0]
246 | mask_copy = mask.copy()
247 |
248 | for k, v in id_to_trainid.items():
249 | mask_copy[mask == k] = v
250 | mask = Image.fromarray(mask_copy.astype(np.uint8))
251 |
252 | # Image Transformations
253 | if self.joint_transform_list is not None:
254 | for idx, xform in enumerate(self.joint_transform_list):
255 | if idx == 0 and centroid is not None:
256 | # HACK
257 | # We assume that the first transform is capable of taking
258 | # in a centroid
259 | img, mask = xform(img, mask, centroid)
260 | else:
261 | img, mask = xform(img, mask)
262 |
263 | # Debug
264 | if self.dump_images and centroid is not None:
265 | outdir = './dump_imgs_{}'.format(self.mode)
266 | os.makedirs(outdir, exist_ok=True)
267 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
268 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
269 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
270 | mask_img = colorize_mask(np.array(mask))
271 | img.save(out_img_fn)
272 | mask_img.save(out_msk_fn)
273 |
274 | if self.transform is not None:
275 | img = self.transform(img)
276 | if self.mode == 'test':
277 | img_keepsize = self.transform(img_keepsize)
278 | mask = img_keepsize
279 | if self.target_transform is not None:
280 | if self.mode != 'test':
281 | mask = self.target_transform(mask)
282 |
283 | return img, mask, img_name
284 |
285 | def __len__(self):
286 | return len(self.imgs_uniform)
287 |
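For reference on the id_to_trainid table above: comma10k masks are RGB PNGs whose red channel carries a fixed byte per class (0x40 road, 0xff lane markings, 0x80 sky/undrivable, 0x00 car, 0xcc ego), which is why __getitem__ keeps only channel 0 before remapping. A small standalone sketch of that decoding step (the mask path is a placeholder):

    import numpy as np
    from PIL import Image
    from datasets.comma10k import id_to_trainid

    # Keep only the red channel; it holds the class byte for each pixel.
    mask = np.array(Image.open('/raid/comma10k/masks/example.png'))[:, :, 0]

    train_ids = np.full(mask.shape, 255, dtype=np.uint8)    # 255 = ignore, as in the loader
    for raw_value, train_id in id_to_trainid.items():
        train_ids[mask == raw_value] = train_id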
--------------------------------------------------------------------------------
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | KITTI Dataset Loader
3 | http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015
4 | """
5 |
6 | import os
7 | import sys
8 | import numpy as np
9 | from PIL import Image
10 | from torch.utils import data
11 | import logging
12 | import datasets.uniform as uniform
13 | import datasets.cityscapes_labels as cityscapes_labels
14 | import json
15 | from config import cfg
16 |
17 |
18 | trainid_to_name = cityscapes_labels.trainId2name
19 | id_to_trainid = cityscapes_labels.label2trainid
20 | num_classes = 19
21 | ignore_label = 255
22 | root = cfg.DATASET.KITTI_DIR
23 | aug_root = cfg.DATASET.KITTI_AUG_DIR
24 |
25 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
26 | 153, 153, 153, 250, 170, 30,
27 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
28 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
29 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
30 | zero_pad = 256 * 3 - len(palette)
31 | for i in range(zero_pad):
32 | palette.append(0)
33 |
34 | def colorize_mask(mask):
35 | # mask: numpy array of the mask
36 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
37 | new_mask.putpalette(palette)
38 | return new_mask
39 |
40 | def get_train_val(cv_split, all_items):
41 | # 90/10 train/val split, three random splits for cross validation
42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
45 |
46 | train_set = []
47 | val_set = []
48 |
49 | if cv_split == 0:
50 | for i in range(200):
51 | if i in val_0:
52 | val_set.append(all_items[i])
53 | else:
54 | train_set.append(all_items[i])
55 | elif cv_split == 1:
56 | for i in range(200):
57 | if i in val_1:
58 | val_set.append(all_items[i])
59 | else:
60 | train_set.append(all_items[i])
61 | elif cv_split == 2:
62 | for i in range(200):
63 | if i in val_2:
64 | val_set.append(all_items[i])
65 | else:
66 | train_set.append(all_items[i])
67 | else:
68 | logging.info('Unknown cv_split {}'.format(cv_split))
69 | sys.exit()
70 |
71 | return train_set, val_set
72 |
73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
74 | items = []
75 | all_items = []
76 | aug_items = []
77 |
78 | assert quality == 'semantic'
79 | assert mode in ['train', 'val', 'trainval']
80 | # note that train and val are randomly determined, no official split
81 |
82 | img_dir_name = "training"
83 | img_path = os.path.join(root, img_dir_name, 'image_2')
84 | mask_path = os.path.join(root, img_dir_name, 'semantic')
85 |
86 | c_items = os.listdir(img_path)
87 | c_items.sort()
88 |
89 | for it in c_items:
90 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
91 | all_items.append(item)
92 | logging.info('KITTI has a total of {} images'.format(len(all_items)))
93 |
94 | # split into train/val
95 | train_set, val_set = get_train_val(cv_split, all_items)
96 |
97 | if mode == 'train':
98 | items = train_set
99 | elif mode == 'val':
100 | items = val_set
101 | elif mode == 'trainval':
102 | items = train_set + val_set
103 | else:
104 | logging.info('Unknown mode {}'.format(mode))
105 | sys.exit()
106 |
107 | logging.info('KITTI-{}: {} images'.format(mode, len(items)))
108 |
109 | return items, aug_items
110 |
111 | def make_test_dataset(quality, mode, maxSkip=0, cv_split=0):
112 | items = []
113 | assert quality == 'semantic'
114 | assert mode == 'test'
115 |
116 | img_dir_name = "testing"
117 | img_path = os.path.join(root, img_dir_name, 'image_2')
118 |
119 | c_items = os.listdir(img_path)
120 | c_items.sort()
121 | for it in c_items:
122 | item = (os.path.join(img_path, it), None)
123 | items.append(item)
124 | logging.info('KITTI has a total of {} test images'.format(len(items)))
125 |
126 | return items, []
127 |
128 | class KITTI(data.Dataset):
129 |
130 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
131 | transform=None, target_transform=None, dump_images=False,
132 | class_uniform_pct=0, class_uniform_tile=0, test=False,
133 | cv_split=None, scf=None, hardnm=0):
134 |
135 | self.quality = quality
136 | self.mode = mode
137 | self.maxSkip = maxSkip
138 | self.joint_transform_list = joint_transform_list
139 | self.transform = transform
140 | self.target_transform = target_transform
141 | self.dump_images = dump_images
142 | self.class_uniform_pct = class_uniform_pct
143 | self.class_uniform_tile = class_uniform_tile
144 | self.scf = scf
145 | self.hardnm = hardnm
146 |
147 | if cv_split:
148 | self.cv_split = cv_split
149 | assert cv_split < cfg.DATASET.CV_SPLITS, \
150 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
151 | cv_split, cfg.DATASET.CV_SPLITS)
152 | else:
153 | self.cv_split = 0
154 |
155 | if self.mode == 'test':
156 | self.imgs, _ = make_test_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split)
157 | else:
158 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
159 | assert len(self.imgs), 'Found 0 images, please check the data set'
160 |
161 | # Centroids for GT data
162 | if self.class_uniform_pct > 0:
163 | if self.scf:
164 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split)
165 | else:
166 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm)
167 | if os.path.isfile(json_fn):
168 | with open(json_fn, 'r') as json_data:
169 | centroids = json.load(json_data)
170 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
171 | else:
172 | if self.scf:
173 |                     self.centroids = uniform.class_centroids_all(
174 | self.imgs,
175 | num_classes,
176 | id2trainid=id_to_trainid,
177 | tile_size=class_uniform_tile)
178 | else:
179 | self.centroids = uniform.class_centroids_all(
180 | self.imgs,
181 | num_classes,
182 | id2trainid=id_to_trainid,
183 | tile_size=class_uniform_tile)
184 | with open(json_fn, 'w') as outfile:
185 | json.dump(self.centroids, outfile, indent=4)
186 |
187 | self.build_epoch()
188 |
189 | def build_epoch(self, cut=False):
190 | if self.class_uniform_pct > 0:
191 | self.imgs_uniform = uniform.build_epoch(self.imgs,
192 | self.centroids,
193 | num_classes,
194 | cfg.CLASS_UNIFORM_PCT)
195 | else:
196 | self.imgs_uniform = self.imgs
197 |
198 | def __getitem__(self, index):
199 | elem = self.imgs_uniform[index]
200 | centroid = None
201 | if len(elem) == 4:
202 | img_path, mask_path, centroid, class_id = elem
203 | else:
204 | img_path, mask_path = elem
205 |
206 | if self.mode == 'test':
207 | img, mask = Image.open(img_path).convert('RGB'), None
208 | else:
209 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
210 | img_name = os.path.splitext(os.path.basename(img_path))[0]
211 |
212 | # kitti scale correction factor
213 | if self.mode == 'train' or self.mode == 'trainval':
214 | if self.scf:
215 | width, height = img.size
216 | img = img.resize((width*2, height*2), Image.BICUBIC)
217 | mask = mask.resize((width*2, height*2), Image.NEAREST)
218 | elif self.mode == 'val':
219 | width, height = 1242, 376
220 | img = img.resize((width, height), Image.BICUBIC)
221 | mask = mask.resize((width, height), Image.NEAREST)
222 | elif self.mode == 'test':
223 | img_keepsize = img.copy()
224 | width, height = 1280, 384
225 | img = img.resize((width, height), Image.BICUBIC)
226 | else:
227 |             logging.info('Unknown mode {}'.format(self.mode))
228 | sys.exit()
229 |
230 | if self.mode != 'test':
231 | mask = np.array(mask)
232 | mask_copy = mask.copy()
233 |
234 | for k, v in id_to_trainid.items():
235 | mask_copy[mask == k] = v
236 | mask = Image.fromarray(mask_copy.astype(np.uint8))
237 |
238 | # Image Transformations
239 | if self.joint_transform_list is not None:
240 | for idx, xform in enumerate(self.joint_transform_list):
241 | if idx == 0 and centroid is not None:
242 | # HACK
243 | # We assume that the first transform is capable of taking
244 | # in a centroid
245 | img, mask = xform(img, mask, centroid)
246 | else:
247 | img, mask = xform(img, mask)
248 |
249 | # Debug
250 | if self.dump_images and centroid is not None:
251 | outdir = './dump_imgs_{}'.format(self.mode)
252 | os.makedirs(outdir, exist_ok=True)
253 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
254 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
255 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
256 | mask_img = colorize_mask(np.array(mask))
257 | img.save(out_img_fn)
258 | mask_img.save(out_msk_fn)
259 |
260 | if self.transform is not None:
261 | img = self.transform(img)
262 | if self.mode == 'test':
263 | img_keepsize = self.transform(img_keepsize)
264 | mask = img_keepsize
265 | if self.target_transform is not None:
266 | if self.mode != 'test':
267 | mask = self.target_transform(mask)
268 |
269 | return img, mask, img_name
270 |
271 | def __len__(self):
272 | return len(self.imgs_uniform)
273 |
--------------------------------------------------------------------------------
/datasets/mapillary.py:
--------------------------------------------------------------------------------
1 | """
2 | Mapillary Dataset Loader
3 | """
4 | from PIL import Image
5 | from torch.utils import data
6 | import os
7 | import numpy as np
8 | import json
9 | import datasets.uniform as uniform
10 | from config import cfg
11 |
12 | num_classes = 65
13 | ignore_label = 65
14 | root = cfg.DATASET.MAPILLARY_DIR
15 | config_fn = os.path.join(root, 'config.json')
16 | id_to_ignore_or_group = {}
17 | color_mapping = []
18 | id_to_trainid = {}
19 |
20 |
21 | def colorize_mask(image_array):
22 | """
23 | Colorize a segmentation mask
24 | """
25 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P')
26 | new_mask.putpalette(color_mapping)
27 | return new_mask
28 |
29 |
30 | def make_dataset(quality, mode):
31 | """
32 | Create File List
33 | """
34 | assert (quality == 'semantic' and mode in ['train', 'val'])
35 | img_dir_name = None
36 | if quality == 'semantic':
37 | if mode == 'train':
38 | img_dir_name = 'training'
39 | if mode == 'val':
40 | img_dir_name = 'validation'
41 | mask_path = os.path.join(root, img_dir_name, 'labels')
42 | else:
43 |         raise BaseException("Instance segmentation not supported")
44 |
45 | img_path = os.path.join(root, img_dir_name, 'images')
46 | print(img_path)
47 | if quality != 'video':
48 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)])
49 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)])
50 | assert imgs == msks
51 |
52 | items = []
53 | c_items = os.listdir(img_path)
54 | if '.DS_Store' in c_items:
55 | c_items.remove('.DS_Store')
56 |
57 | for it in c_items:
58 | if quality == 'video':
59 | item = (os.path.join(img_path, it), os.path.join(img_path, it))
60 | else:
61 | item = (os.path.join(img_path, it),
62 | os.path.join(mask_path, it.replace(".jpg", ".png")))
63 | items.append(item)
64 | return items
65 |
66 |
67 | def gen_colormap():
68 | """
69 | Get Color Map from file
70 | """
71 | global color_mapping
72 |
73 | # load mapillary config
74 | with open(config_fn) as config_file:
75 | config = json.load(config_file)
76 | config_labels = config['labels']
77 |
78 | # calculate label color mapping
79 | colormap = []
80 | id2name = {}
81 | for i in range(0, len(config_labels)):
82 | colormap = colormap + config_labels[i]['color']
83 | id2name[i] = config_labels[i]['readable']
84 | color_mapping = colormap
85 | return id2name
86 |
87 |
88 | class Mapillary(data.Dataset):
89 | def __init__(self, quality, mode, joint_transform_list=None,
90 | transform=None, target_transform=None, dump_images=False,
91 | class_uniform_pct=0, class_uniform_tile=768, test=False):
92 | """
93 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform.
94 | 0.0 means fully random.
95 |         class_uniform_tile = Class uniform tile size (in pixels)
96 | """
97 | self.quality = quality
98 | self.mode = mode
99 | self.joint_transform_list = joint_transform_list
100 | self.transform = transform
101 | self.target_transform = target_transform
102 | self.dump_images = dump_images
103 | self.class_uniform_pct = class_uniform_pct
104 | self.class_uniform_tile = class_uniform_tile
105 | self.id2name = gen_colormap()
106 | self.imgs_uniform = None
107 | for i in range(num_classes):
108 | id_to_trainid[i] = i
109 |
110 | # find all images
111 | self.imgs = make_dataset(quality, mode)
112 | if len(self.imgs) == 0:
113 | raise RuntimeError('Found 0 images, please check the data set')
114 | if test:
115 | np.random.shuffle(self.imgs)
116 | self.imgs = self.imgs[:200]
117 |
118 | if self.class_uniform_pct:
119 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile)
120 | if os.path.isfile(json_fn):
121 | with open(json_fn, 'r') as json_data:
122 | centroids = json.load(json_data)
123 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
124 | else:
125 | # centroids is a dict (indexed by class) of lists of centroids
126 | self.centroids = uniform.class_centroids_all(
127 | self.imgs,
128 | num_classes,
129 | id2trainid=None,
130 | tile_size=self.class_uniform_tile)
131 | with open(json_fn, 'w') as outfile:
132 | json.dump(self.centroids, outfile, indent=4)
133 | else:
134 | self.centroids = []
135 | self.build_epoch()
136 |
137 | def build_epoch(self):
138 | if self.class_uniform_pct != 0:
139 | self.imgs_uniform = uniform.build_epoch(self.imgs,
140 | self.centroids,
141 | num_classes,
142 | self.class_uniform_pct)
143 | else:
144 | self.imgs_uniform = self.imgs
145 |
146 | def __getitem__(self, index):
147 | if len(self.imgs_uniform[index]) == 2:
148 | img_path, mask_path = self.imgs_uniform[index]
149 | centroid = None
150 | class_id = None
151 | else:
152 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index]
153 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
154 | img_name = os.path.splitext(os.path.basename(img_path))[0]
155 |
156 | mask = np.array(mask)
157 | mask_copy = mask.copy()
158 | for k, v in id_to_ignore_or_group.items():
159 | mask_copy[mask == k] = v
160 | mask = Image.fromarray(mask_copy.astype(np.uint8))
161 |
162 | # Image Transformations
163 | if self.joint_transform_list is not None:
164 | for idx, xform in enumerate(self.joint_transform_list):
165 | if idx == 0 and centroid is not None:
166 | # HACK! Assume the first transform accepts a centroid
167 | img, mask = xform(img, mask, centroid)
168 | else:
169 | img, mask = xform(img, mask)
170 |
171 | if self.dump_images:
172 | outdir = 'dump_imgs_{}'.format(self.mode)
173 | os.makedirs(outdir, exist_ok=True)
174 | if centroid is not None:
175 | dump_img_name = self.id2name[class_id] + '_' + img_name
176 | else:
177 | dump_img_name = img_name
178 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
179 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
180 | mask_img = colorize_mask(np.array(mask))
181 | img.save(out_img_fn)
182 | mask_img.save(out_msk_fn)
183 |
184 | if self.transform is not None:
185 | img = self.transform(img)
186 | if self.target_transform is not None:
187 | mask = self.target_transform(mask)
188 | return img, mask, img_name
189 |
190 | def __len__(self):
191 | return len(self.imgs_uniform)
192 |
193 | def calculate_weights(self):
194 | raise BaseException("not supported yet")
195 |
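gen_colormap() above relies only on the 'labels' list in the Mapillary Vistas config.json, reading the 'color' and 'readable' fields of each entry in order. A sketch of the assumed structure (entries and colors are illustrative, not the real palette):

    # Assumed shape of <MAPILLARY_DIR>/config.json as consumed by gen_colormap()
    example_config = {
        "labels": [
            {"name": "animal--bird", "readable": "Bird", "color": [165, 42, 42]},
            {"name": "construction--flat--road", "readable": "Road", "color": [128, 64, 128]},
            # ... one entry per class, 65 in total for this loader
        ]
    }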
--------------------------------------------------------------------------------
/datasets/nullloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Null Loader
3 | """
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 |
8 | num_classes = 19
9 | ignore_label = 255
10 |
11 | class NullLoader(data.Dataset):
12 | """
13 | Null Dataset for Performance
14 | """
15 | def __init__(self,crop_size):
16 | self.imgs = range(200)
17 | self.crop_size = crop_size
18 |
19 | def __getitem__(self, index):
20 | #Return img, mask, name
21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index)
22 |
23 | def __len__(self):
24 | return len(self.imgs)
--------------------------------------------------------------------------------
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 |
37 |
38 | import math
39 | import torch
40 | from torch.distributed import get_world_size, get_rank
41 | from torch.utils.data import Sampler
42 |
43 | class DistributedSampler(Sampler):
44 | """Sampler that restricts data loading to a subset of the dataset.
45 |
46 | It is especially useful in conjunction with
47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
48 | process can pass a DistributedSampler instance as a DataLoader sampler,
49 | and load a subset of the original dataset that is exclusive to it.
50 |
51 | .. note::
52 | Dataset is assumed to be of constant size.
53 |
54 | Arguments:
55 | dataset: Dataset used for sampling.
56 | num_replicas (optional): Number of processes participating in
57 | distributed training.
58 | rank (optional): Rank of the current process within num_replicas.
59 | """
60 |
61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None):
62 | if num_replicas is None:
63 | num_replicas = get_world_size()
64 | if rank is None:
65 | rank = get_rank()
66 | self.dataset = dataset
67 | self.num_replicas = num_replicas
68 | self.rank = rank
69 | self.epoch = 0
70 | self.consecutive_sample = consecutive_sample
71 | self.permutation = permutation
72 | if pad:
73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
74 | else:
75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
76 | self.total_size = self.num_samples * self.num_replicas
77 |
78 | def __iter__(self):
79 | # deterministically shuffle based on epoch
80 | g = torch.Generator()
81 | g.manual_seed(self.epoch)
82 |
83 | if self.permutation:
84 | indices = list(torch.randperm(len(self.dataset), generator=g))
85 | else:
86 |             indices = list(range(len(self.dataset)))
87 |
88 | # add extra samples to make it evenly divisible
89 | if self.total_size > len(indices):
90 | indices += indices[:(self.total_size - len(indices))]
91 |
92 | # subsample
93 | if self.consecutive_sample:
94 | offset = self.num_samples * self.rank
95 | indices = indices[offset:offset + self.num_samples]
96 | else:
97 | indices = indices[self.rank:self.total_size:self.num_replicas]
98 | assert len(indices) == self.num_samples
99 |
100 | return iter(indices)
101 |
102 | def __len__(self):
103 | return self.num_samples
104 |
105 | def set_epoch(self, epoch):
106 | self.epoch = epoch
107 |
108 | def set_num_samples(self):
109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
110 | self.total_size = self.num_samples * self.num_replicas
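
A minimal usage sketch for the sampler above, assuming torch.distributed has already been initialized by the launcher (train_set and num_epochs are placeholders; num_replicas and rank can also be passed explicitly):

    from torch.utils.data import DataLoader
    from datasets.sampler import DistributedSampler

    sampler = DistributedSampler(train_set, pad=True, permutation=True)
    loader = DataLoader(train_set, batch_size=2, sampler=sampler, num_workers=4)

    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)              # reseed so each epoch gets a fresh permutation
        for imgs, masks, names in loader:
            ...                               # training step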
--------------------------------------------------------------------------------
/datasets/uniform.py:
--------------------------------------------------------------------------------
1 | """
2 | Uniform sampling of classes.
3 | For all images, for all classes, generate centroids around which to sample.
4 |
5 | All images are divided into tiles.
6 | For each tile, a class can be present or not. If it is
7 | present, calculate the centroid of the class and record it.
8 |
9 | We would like to thank Peter Kontschieder for the inspiration of this idea.
10 | """
11 |
12 | import logging
13 | from collections import defaultdict
14 | from PIL import Image
15 | import numpy as np
16 | from scipy import ndimage
17 | from tqdm import tqdm
18 |
19 | pbar = None
20 |
21 | class Point():
22 | """
23 | Point Class For X and Y Location
24 | """
25 | def __init__(self, x, y):
26 | self.x = x
27 | self.y = y
28 |
29 |
30 | def calc_tile_locations(tile_size, image_size):
31 | """
32 | Divide an image into tiles to help us cover classes that are spread out.
33 | tile_size: size of tile to distribute
34 | image_size: original image size
35 | return: locations of the tiles
36 | """
37 | image_size_y, image_size_x = image_size
38 | locations = []
39 | for y in range(image_size_y // tile_size):
40 | for x in range(image_size_x // tile_size):
41 | x_offs = x * tile_size
42 | y_offs = y * tile_size
43 | locations.append((x_offs, y_offs))
44 | return locations
45 |
46 |
47 | def class_centroids_image(item, tile_size, num_classes, id2trainid):
48 | """
49 | For one image, calculate centroids for all classes present in image.
50 | item: image, image_name
51 | tile_size:
52 | num_classes:
53 | id2trainid: mapping from original id to training ids
54 |     return: dict mapping each class_id to a list of centroids, one per tile where the class occurs
55 | """
56 | image_fn, label_fn = item
57 | centroids = defaultdict(list)
58 | mask = np.array(Image.open(label_fn))
59 | image_size = mask.shape
60 | tile_locations = calc_tile_locations(tile_size, image_size)
61 |
62 | mask_copy = mask.copy()
63 | if id2trainid:
64 | for k, v in id2trainid.items():
65 | mask[mask_copy == k] = v
66 |
67 | for x_offs, y_offs in tile_locations:
68 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
69 | for class_id in range(num_classes):
70 | if class_id in patch:
71 | patch_class = (patch == class_id).astype(int)
72 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
73 | centroid_y = int(centroid_y) + y_offs
74 | centroid_x = int(centroid_x) + x_offs
75 | centroid = (centroid_x, centroid_y)
76 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
77 | pbar.update(1)
78 | return centroids
79 |
80 |
81 |
82 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
83 | """
84 | Calculate class centroids for all classes for all images for all tiles.
85 | items: list of (image_fn, label_fn)
86 | tile size: size of tile
87 | returns: dict that contains a list of centroids for each class
88 | """
89 | from multiprocessing.dummy import Pool
90 | from functools import partial
91 | pool = Pool(32)
92 | global pbar
93 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
94 | class_centroids_item = partial(class_centroids_image,
95 | num_classes=num_classes,
96 | id2trainid=id2trainid,
97 | tile_size=tile_size)
98 |
99 | centroids = defaultdict(list)
100 | new_centroids = pool.map(class_centroids_item, items)
101 | pool.close()
102 | pool.join()
103 |
104 | # combine each image's items into a single global dict
105 | for image_items in new_centroids:
106 | for class_id in image_items:
107 | centroids[class_id].extend(image_items[class_id])
108 | return centroids
109 |
110 |
111 | def unpooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
112 | """
113 | Calculate class centroids for all classes for all images for all tiles.
114 | items: list of (image_fn, label_fn)
115 | tile size: size of tile
116 | returns: dict that contains a list of centroids for each class
117 | """
118 | centroids = defaultdict(list)
119 | global pbar
120 | pbar = tqdm(total=len(items), desc='centroid extraction')
121 | for image, label in items:
122 | new_centroids = class_centroids_image((image, label),
123 | tile_size,
124 |                                               num_classes, id2trainid)
125 | for class_id in new_centroids:
126 | centroids[class_id].extend(new_centroids[class_id])
127 |
128 | return centroids
129 |
130 |
131 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
132 | """
133 | intermediate function to call pooled_class_centroid
134 | """
135 |
136 | pooled_centroids = pooled_class_centroids_all(items, num_classes,
137 | id2trainid, tile_size)
138 | return pooled_centroids
139 |
140 |
141 | def random_sampling(alist, num):
142 | """
143 | Randomly sample num items from the list
144 | alist: list of centroids to sample from
145 | num: can be larger than the list and if so, then wrap around
146 | return: class uniform samples from the list
147 | """
148 | sampling = []
149 | len_list = len(alist)
150 | assert len_list, 'len_list is zero!'
151 | indices = np.arange(len_list)
152 | np.random.shuffle(indices)
153 |
154 | for i in range(num):
155 | item = alist[indices[i % len_list]]
156 | sampling.append(item)
157 | return sampling
158 |
159 |
160 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct):
161 | """
162 |     Generate an epoch's worth of crops using uniform sampling. Needs to be called every epoch.
163 | imgs: list of imgs
164 | centroids:
165 | num_classes:
166 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch )
167 | """
168 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct))
169 | num_epoch = int(len(imgs))
170 |
171 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch))
172 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes)
173 | num_rand = num_epoch - num_per_class * num_classes
174 | # create random crops
175 | imgs_uniform = random_sampling(imgs, num_rand)
176 |
177 | # now add uniform sampling
178 | for class_id in range(num_classes):
179 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id]))
180 | logging.info(string_format)
181 | for class_id in range(num_classes):
182 | centroid_len = len(centroids[class_id])
183 | if centroid_len == 0:
184 | pass
185 | else:
186 | class_centroids = random_sampling(centroids[class_id], num_per_class)
187 | imgs_uniform.extend(class_centroids)
188 |
189 | return imgs_uniform
190 |
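As a concrete check of the arithmetic in build_epoch(): with the Cityscapes training set (2975 images), class_uniform_pct=0.5 and num_classes=19, the split works out as below (a worked example, not output from this repo):

    num_epoch     = 2975
    num_per_class = int((num_epoch * 0.5) / 19)      # 78 centroid-based samples per class
    num_rand      = num_epoch - num_per_class * 19   # 1493 randomly sampled images
    # each epoch still yields 2975 items: 1493 random + 19 * 78 centered on class centroids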
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | from PIL import Image
5 | import numpy as np
6 | import cv2
7 |
8 | import torch
9 | from torch.backends import cudnn
10 | import torchvision.transforms as transforms
11 |
12 | import network
13 | from optimizer import restore_snapshot
14 | from datasets import cityscapes
15 | from config import assert_and_infer_cfg
16 |
17 | parser = argparse.ArgumentParser(description='demo')
18 | parser.add_argument('--demo-image', type=str, default='', help='path to demo image', required=True)
19 | parser.add_argument('--snapshot', type=str, default='./pretrained_models/cityscapes_best.pth', help='pre-trained checkpoint', required=True)
20 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepWV3Plus', help='network architecture used for inference')
21 | parser.add_argument('--save-dir', type=str, default='./save', help='path to save your results')
22 | args = parser.parse_args()
23 | assert_and_infer_cfg(args, train_mode=False)
24 | cudnn.benchmark = False
25 | torch.cuda.empty_cache()
26 |
27 | # get net
28 | args.dataset_cls = cityscapes
29 | net = network.get_net(args, criterion=None)
30 | net = torch.nn.DataParallel(net).cuda()
31 | print('Net built.')
32 | net, _ = restore_snapshot(net, optimizer=None, snapshot=args.snapshot, restore_optimizer_bool=False)
33 | net.eval()
34 | print('Net restored.')
35 |
36 | # get data
37 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
38 | img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*mean_std)])
39 | img = Image.open(args.demo_image).convert('RGB')
40 | img_tensor = img_transform(img)
41 |
42 | # predict
43 | with torch.no_grad():
44 | img = img_tensor.unsqueeze(0).cuda()
45 | pred = net(img)
46 | print('Inference done.')
47 |
48 | pred = pred.cpu().numpy().squeeze()
49 | pred = np.argmax(pred, axis=0)
50 |
51 | if not os.path.exists(args.save_dir):
52 | os.makedirs(args.save_dir)
53 |
54 | colorized = args.dataset_cls.colorize_mask(pred)
55 | colorized.save(os.path.join(args.save_dir, 'color_mask.png'))
56 |
57 | label_out = np.zeros_like(pred)
58 | for label_id, train_id in args.dataset_cls.id_to_trainid.items():
59 | label_out[np.where(pred == train_id)] = label_id
60 | cv2.imwrite(os.path.join(args.save_dir, 'pred_mask.png'), label_out)
61 | print('Results saved.')
62 |
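A typical invocation of this script, with placeholder image and checkpoint paths:

    python demo.py --demo-image test_img.png --snapshot ./pretrained_models/cityscapes_best.pth --save-dir ./save

It writes color_mask.png (colorized train ids) and pred_mask.png (raw label ids) into the chosen save directory.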
--------------------------------------------------------------------------------
/demo_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import argparse
5 | from PIL import Image
6 | import numpy as np
7 | import cv2
8 |
9 | import torch
10 | from torch.backends import cudnn
11 | import torchvision.transforms as transforms
12 |
13 | import network
14 | from optimizer import restore_snapshot
15 | from datasets import cityscapes
16 | from config import assert_and_infer_cfg
17 |
18 | parser = argparse.ArgumentParser(description='demo')
19 | parser.add_argument('--demo-folder', type=str, default='', help='path to the folder containing demo images', required=True)
20 | parser.add_argument('--snapshot', type=str, default='./pretrained_models/cityscapes_best.pth', help='pre-trained checkpoint', required=True)
21 | parser.add_argument('--arch', type=str, default='network.deepv3.DeepWV3Plus', help='network architecture used for inference')
22 | parser.add_argument('--save-dir', type=str, default='./save', help='path to save your results')
23 | args = parser.parse_args()
24 | assert_and_infer_cfg(args, train_mode=False)
25 | cudnn.benchmark = False
26 | torch.cuda.empty_cache()
27 |
28 | # get net
29 | args.dataset_cls = cityscapes
30 | net = network.get_net(args, criterion=None)
31 | net = torch.nn.DataParallel(net).cuda()
32 | print('Net built.')
33 | net, _ = restore_snapshot(net, optimizer=None, snapshot=args.snapshot, restore_optimizer_bool=False)
34 | net.eval()
35 | print('Net restored.')
36 |
37 | # get data
38 | data_dir = args.demo_folder
39 | images = os.listdir(data_dir)
40 | if len(images) == 0:
41 | print('There are no images at directory %s. Check the data path.' % (data_dir))
42 | else:
43 | print('There are %d images to be processed.' % (len(images)))
44 | images.sort()
45 |
46 | mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
47 | img_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(*mean_std)])
48 | if not os.path.exists(args.save_dir):
49 | os.makedirs(args.save_dir)
50 |
51 | start_time = time.time()
52 | for img_id, img_name in enumerate(images):
53 | img_dir = os.path.join(data_dir, img_name)
54 | img = Image.open(img_dir).convert('RGB')
55 | img_tensor = img_transform(img)
56 |
57 | # predict
58 | with torch.no_grad():
59 | pred = net(img_tensor.unsqueeze(0).cuda())
60 | print('%04d/%04d: Inference done.' % (img_id + 1, len(images)))
61 |
62 | pred = pred.cpu().numpy().squeeze()
63 | pred = np.argmax(pred, axis=0)
64 |
65 | color_name = 'color_mask_' + img_name
66 | overlap_name = 'overlap_' + img_name
67 | pred_name = 'pred_mask_' + img_name
68 |
69 | # save colorized predictions
70 | colorized = args.dataset_cls.colorize_mask(pred)
71 | colorized.save(os.path.join(args.save_dir, color_name))
72 |
73 | # save colorized predictions overlapped on original images
74 | overlap = cv2.addWeighted(np.array(img), 0.5, np.array(colorized.convert('RGB')), 0.5, 0)
75 | cv2.imwrite(os.path.join(args.save_dir, overlap_name), overlap[:, :, ::-1])
76 |
77 | # save label-based predictions, e.g. for submission purpose
78 | label_out = np.zeros_like(pred)
79 | for label_id, train_id in args.dataset_cls.id_to_trainid.items():
80 | label_out[np.where(pred == train_id)] = label_id
81 | cv2.imwrite(os.path.join(args.save_dir, pred_name), label_out)
82 | end_time = time.time()
83 |
84 | print('Results saved.')
85 | print('Inference takes %4.2f seconds, which is %4.2f seconds per image, including saving results.' % (end_time - start_time, (end_time - start_time)/len(images)))
86 |
--------------------------------------------------------------------------------
/images/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/images/method.png
--------------------------------------------------------------------------------
/images/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/images/vis.png
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss.py
3 | """
4 |
5 | import logging
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from config import cfg
11 |
12 |
13 | def get_loss(args):
14 | """
15 | Get the criterion based on the loss function
16 | args: commandline arguments
17 | return: criterion, criterion_val
18 | """
19 |
20 | if args.img_wt_loss:
21 | criterion = ImageBasedCrossEntropyLoss2d(
22 | classes=args.dataset_cls.num_classes, size_average=True,
23 | ignore_index=args.dataset_cls.ignore_label,
24 | upper_bound=args.wt_bound).cuda()
25 | elif args.jointwtborder:
26 | criterion = ImgWtLossSoftNLL(classes=args.dataset_cls.num_classes,
27 | ignore_index=args.dataset_cls.ignore_label,
28 | upper_bound=args.wt_bound).cuda()
29 | else:
30 | criterion = CrossEntropyLoss2d(size_average=True,
31 | ignore_index=args.dataset_cls.ignore_label).cuda()
32 |
33 | criterion_val = CrossEntropyLoss2d(size_average=True,
34 | weight=None,
35 | ignore_index=args.dataset_cls.ignore_label).cuda()
36 | return criterion, criterion_val
37 |
38 |
39 | class ImageBasedCrossEntropyLoss2d(nn.Module):
40 | """
41 | Image Weighted Cross Entropy Loss
42 | """
43 |
44 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255,
45 | norm=False, upper_bound=1.0):
46 | super(ImageBasedCrossEntropyLoss2d, self).__init__()
47 | logging.info("Using Per Image based weighted loss")
48 | self.num_classes = classes
49 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index)
50 | self.norm = norm
51 | self.upper_bound = upper_bound
52 | self.batch_weights = cfg.BATCH_WEIGHTING
53 |
54 | def calculate_weights(self, target):
55 | """
56 | Calculate weights of classes based on the training crop
57 | """
58 | hist = np.histogram(target.flatten(), range(
59 | self.num_classes + 1), normed=True)[0]
60 | if self.norm:
61 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
62 | else:
63 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
64 | return hist
65 |
66 | def forward(self, inputs, targets):
67 |
68 | target_cpu = targets.data.cpu().numpy()
69 | if self.batch_weights:
70 | weights = self.calculate_weights(target_cpu)
71 | self.nll_loss.weight = torch.Tensor(weights).cuda()
72 |
73 | loss = 0.0
74 | for i in range(0, inputs.shape[0]):
75 | if not self.batch_weights:
76 | weights = self.calculate_weights(target_cpu[i])
77 | self.nll_loss.weight = torch.Tensor(weights).cuda()
78 |
79 | loss += self.nll_loss(F.log_softmax(inputs[i].unsqueeze(0)),
80 | targets[i].unsqueeze(0))
81 | return loss
82 |
83 |
84 |
85 | class CrossEntropyLoss2d(nn.Module):
86 | """
87 |     Cross Entropy NLL Loss
88 | """
89 |
90 | def __init__(self, weight=None, size_average=True, ignore_index=255):
91 | super(CrossEntropyLoss2d, self).__init__()
92 | logging.info("Using Cross Entropy Loss")
93 | self.nll_loss = nn.NLLLoss2d(weight, size_average, ignore_index)
94 | # self.weight = weight
95 |
96 | def forward(self, inputs, targets):
97 | return self.nll_loss(F.log_softmax(inputs), targets)
98 |
99 | def customsoftmax(inp, multihotmask):
100 | """
101 | Custom Softmax
102 | """
103 | soft = F.softmax(inp)
104 |     # Pool the softmax mass over the classes allowed by the multi-hot border mask,
105 |     # then take the element-wise max of that pooled value and the plain softmax.
106 | return torch.log(
107 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True)))
108 | )
109 |
110 | class ImgWtLossSoftNLL(nn.Module):
111 | """
112 | Relax Loss
113 | """
114 |
115 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
116 | norm=False):
117 | super(ImgWtLossSoftNLL, self).__init__()
118 | self.weights = weights
119 | self.num_classes = classes
120 | self.ignore_index = ignore_index
121 | self.upper_bound = upper_bound
122 | self.norm = norm
123 | self.batch_weights = cfg.BATCH_WEIGHTING
124 | self.fp16 = False
125 |
126 |
127 | def calculate_weights(self, target):
128 | """
129 | Calculate weights of the classes based on training crop
130 | """
131 | if len(target.shape) == 3:
132 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
133 | else:
134 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
135 | if self.norm:
136 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
137 | else:
138 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
139 | return hist[:-1]
140 |
141 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
142 | """
143 | NLL Relaxed Loss Implementation
144 | """
145 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
146 | border_weights = 1 / border_weights
147 | target[target > 1] = 1
148 | if self.fp16:
149 | loss_matrix = (-1 / border_weights *
150 | (target[:, :-1, :, :].half() *
151 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
152 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \
153 | (1. - mask.half())
154 | else:
155 | loss_matrix = (-1 / border_weights *
156 | (target[:, :-1, :, :].float() *
157 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
158 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
159 | (1. - mask.float())
160 |
161 | # loss_matrix[border_weights > 1] = 0
162 | loss = loss_matrix.sum()
163 |
164 | # +1 to prevent division by 0
165 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
166 | return loss
167 |
168 | def forward(self, inputs, target):
169 | if self.fp16:
170 | weights = target[:, :-1, :, :].sum(1).half()
171 | else:
172 | weights = target[:, :-1, :, :].sum(1).float()
173 | ignore_mask = (weights == 0)
174 | weights[ignore_mask] = 1
175 |
176 | loss = 0
177 | target_cpu = target.data.cpu().numpy()
178 |
179 | if self.batch_weights:
180 | class_weights = self.calculate_weights(target_cpu)
181 |
182 | for i in range(0, inputs.shape[0]):
183 | if not self.batch_weights:
184 | class_weights = self.calculate_weights(target_cpu[i])
185 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
186 | target[i].unsqueeze(0),
187 | class_weights=torch.Tensor(class_weights).cuda(),
188 | border_weights=weights, mask=ignore_mask[i])
189 |
190 | return loss
191 |
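To make the weighting in ImageBasedCrossEntropyLoss2d.calculate_weights() concrete: with upper_bound=1.0 and norm=False, classes that occupy less of the crop get a larger weight and absent classes stay at 1. A small numeric sketch (the frequencies are made up):

    import numpy as np

    hist = np.array([0.90, 0.10, 0.0])               # per-class pixel fraction in the crop
    weights = ((hist != 0) * 1.0 * (1 - hist)) + 1   # -> [1.1, 1.9, 1.0]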
--------------------------------------------------------------------------------
/network/Resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code Adapted from:
3 | # https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 | import torch.nn as nn
37 | import torch.utils.model_zoo as model_zoo
38 | import network.mynn as mynn
39 |
40 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
41 | 'resnet152']
42 |
43 |
44 | model_urls = {
45 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
46 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
47 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
48 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
49 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
50 | }
51 |
52 |
53 | def conv3x3(in_planes, out_planes, stride=1):
54 | """3x3 convolution with padding"""
55 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
56 | padding=1, bias=False)
57 |
58 |
59 | class BasicBlock(nn.Module):
60 | """
61 | Basic Block for Resnet
62 | """
63 | expansion = 1
64 |
65 | def __init__(self, inplanes, planes, stride=1, downsample=None):
66 | super(BasicBlock, self).__init__()
67 | self.conv1 = conv3x3(inplanes, planes, stride)
68 | self.bn1 = mynn.Norm2d(planes)
69 | self.relu = nn.ReLU(inplace=True)
70 | self.conv2 = conv3x3(planes, planes)
71 | self.bn2 = mynn.Norm2d(planes)
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | residual = x
77 |
78 | out = self.conv1(x)
79 | out = self.bn1(out)
80 | out = self.relu(out)
81 |
82 | out = self.conv2(out)
83 | out = self.bn2(out)
84 |
85 | if self.downsample is not None:
86 | residual = self.downsample(x)
87 |
88 | out += residual
89 | out = self.relu(out)
90 |
91 | return out
92 |
93 |
94 | class Bottleneck(nn.Module):
95 | """
96 | Bottleneck Layer for Resnet
97 | """
98 | expansion = 4
99 |
100 | def __init__(self, inplanes, planes, stride=1, downsample=None):
101 | super(Bottleneck, self).__init__()
102 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
103 | self.bn1 = mynn.Norm2d(planes)
104 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
105 | padding=1, bias=False)
106 | self.bn2 = mynn.Norm2d(planes)
107 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
108 | self.bn3 = mynn.Norm2d(planes * self.expansion)
109 | self.relu = nn.ReLU(inplace=True)
110 | self.downsample = downsample
111 | self.stride = stride
112 |
113 | def forward(self, x):
114 | residual = x
115 |
116 | out = self.conv1(x)
117 | out = self.bn1(out)
118 | out = self.relu(out)
119 |
120 | out = self.conv2(out)
121 | out = self.bn2(out)
122 | out = self.relu(out)
123 |
124 | out = self.conv3(out)
125 | out = self.bn3(out)
126 |
127 | if self.downsample is not None:
128 | residual = self.downsample(x)
129 |
130 | out += residual
131 | out = self.relu(out)
132 |
133 | return out
134 |
135 |
136 | class ResNet(nn.Module):
137 | """
138 | Resnet Global Module for Initialization
139 | """
140 | def __init__(self, block, layers, num_classes=1000):
141 | self.inplanes = 64
142 | super(ResNet, self).__init__()
143 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
144 | bias=False)
145 | self.bn1 = mynn.Norm2d(64)
146 | self.relu = nn.ReLU(inplace=True)
147 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
148 | self.layer1 = self._make_layer(block, 64, layers[0])
149 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
150 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
151 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
152 | self.avgpool = nn.AvgPool2d(7, stride=1)
153 | self.fc = nn.Linear(512 * block.expansion, num_classes)
154 |
155 | for m in self.modules():
156 | if isinstance(m, nn.Conv2d):
157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158 | elif isinstance(m, nn.BatchNorm2d):
159 | nn.init.constant_(m.weight, 1)
160 | nn.init.constant_(m.bias, 0)
161 |
162 | def _make_layer(self, block, planes, blocks, stride=1):
163 | downsample = None
164 | if stride != 1 or self.inplanes != planes * block.expansion:
165 | downsample = nn.Sequential(
166 | nn.Conv2d(self.inplanes, planes * block.expansion,
167 | kernel_size=1, stride=stride, bias=False),
168 | mynn.Norm2d(planes * block.expansion),
169 | )
170 |
171 | layers = []
172 | layers.append(block(self.inplanes, planes, stride, downsample))
173 | self.inplanes = planes * block.expansion
174 | for index in range(1, blocks):
175 | layers.append(block(self.inplanes, planes))
176 |
177 | return nn.Sequential(*layers)
178 |
179 | def forward(self, x):
180 | x = self.conv1(x)
181 | x = self.bn1(x)
182 | x = self.relu(x)
183 | x = self.maxpool(x)
184 |
185 | x = self.layer1(x)
186 | x = self.layer2(x)
187 | x = self.layer3(x)
188 | x = self.layer4(x)
189 |
190 | x = self.avgpool(x)
191 | x = x.view(x.size(0), -1)
192 | x = self.fc(x)
193 |
194 | return x
195 |
196 |
197 | def resnet18(pretrained=True, **kwargs):
198 | """Constructs a ResNet-18 model.
199 |
200 | Args:
201 | pretrained (bool): If True, returns a model pre-trained on ImageNet
202 | """
203 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
204 | if pretrained:
205 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
206 | return model
207 |
208 |
209 | def resnet34(pretrained=True, **kwargs):
210 | """Constructs a ResNet-34 model.
211 |
212 | Args:
213 | pretrained (bool): If True, returns a model pre-trained on ImageNet
214 | """
215 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
216 | if pretrained:
217 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
218 | return model
219 |
220 |
221 | def resnet50(pretrained=True, **kwargs):
222 | """Constructs a ResNet-50 model.
223 |
224 | Args:
225 | pretrained (bool): If True, returns a model pre-trained on ImageNet
226 | """
227 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
228 | if pretrained:
229 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
230 | return model
231 |
232 |
233 | def resnet101(pretrained=True, **kwargs):
234 | """Constructs a ResNet-101 model.
235 |
236 | Args:
237 | pretrained (bool): If True, returns a model pre-trained on ImageNet
238 | """
239 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
240 | if pretrained:
241 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
242 | return model
243 |
244 |
245 | def resnet152(pretrained=True, **kwargs):
246 | """Constructs a ResNet-152 model.
247 |
248 | Args:
249 | pretrained (bool): If True, returns a model pre-trained on ImageNet
250 | """
251 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
252 | if pretrained:
253 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
254 | return model
255 |
--------------------------------------------------------------------------------
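A minimal usage sketch for the factory functions above, mirroring how deepv3.py turns them into a segmentation trunk. It assumes the repo root is on PYTHONPATH and that Apex and config.py are importable (network.mynn, used for Norm2d, imports both); the BNFUNC assignment stands in for what train.py normally does, and the input size is illustrative.

import torch
import torch.nn as nn
from config import cfg

# mynn.Norm2d instantiates whatever class is stored here; train.py normally
# sets it (regular BN, or apex synchronized BN when run with --syncbn).
cfg.MODEL.BNFUNC = nn.BatchNorm2d

from network import Resnet

resnet = Resnet.resnet50(pretrained=False)   # pretrained=True pulls the ImageNet weights via model_zoo
# deepv3.py groups the stem into a single "layer0" before reusing layer1..layer4
resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)

x = torch.randn(1, 3, 224, 224)              # illustrative input
feat = resnet.layer4(resnet.layer3(resnet.layer2(resnet.layer1(resnet.layer0(x)))))
print(feat.shape)                            # torch.Size([1, 2048, 7, 7])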
/network/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Network Initializations
3 | """
4 |
5 | import logging
6 | import importlib
7 | import torch
8 |
9 |
10 |
11 | def get_net(args, criterion):
12 | """
13 | Get Network Architecture based on arguments provided
14 | """
15 | net = get_model(network=args.arch, num_classes=args.dataset_cls.num_classes,
16 | criterion=criterion)
17 | num_params = sum([param.nelement() for param in net.parameters()])
18 | logging.info('Model params = {:2.1f}M'.format(num_params / 1000000))
19 |
20 | net = net.cuda()
21 | return net
22 |
23 |
24 | def wrap_network_in_dataparallel(net, use_apex_data_parallel=False):
25 | """
26 | Wrap the network in Dataparallel
27 | """
28 | if use_apex_data_parallel:
29 | import apex
30 | net = apex.parallel.DistributedDataParallel(net)
31 | else:
32 | net = torch.nn.DataParallel(net)
33 | return net
34 |
35 |
36 | def get_model(network, num_classes, criterion):
37 | """
38 | Fetch Network Function Pointer
39 | """
40 | module = network[:network.rfind('.')]
41 | model = network[network.rfind('.') + 1:]
42 | mod = importlib.import_module(module)
43 | net_func = getattr(mod, model)
44 | net = net_func(num_classes=num_classes, criterion=criterion)
45 | return net
46 |
--------------------------------------------------------------------------------
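For reference, get_model() above resolves the --arch string with importlib: everything before the last dot is a module path, the rest is the function to call. A standalone sketch (assumes the repo root is on PYTHONPATH and the network package's dependencies, such as Apex, are installed):

import importlib

arch = 'network.deepv3.DeepWV3Plus'            # same format as the --arch flag
module_name = arch[:arch.rfind('.')]           # 'network.deepv3'
func_name = arch[arch.rfind('.') + 1:]         # 'DeepWV3Plus'

net_func = getattr(importlib.import_module(module_name), func_name)
# net = net_func(num_classes=19, criterion=loss_fn)   # as get_model() would then call it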
/network/deepv3.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code Adapted from:
3 | # https://github.com/sthalles/deeplab_v3
4 | #
5 | # MIT License
6 | #
7 | # Copyright (c) 2018 Thalles Santos Silva
8 | #
9 | # Permission is hereby granted, free of charge, to any person obtaining a copy
10 | # of this software and associated documentation files (the "Software"), to deal
11 | # in the Software without restriction, including without limitation the rights
12 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | # copies of the Software, and to permit persons to whom the Software is
14 | # furnished to do so, subject to the following conditions:
15 | #
16 | # The above copyright notice and this permission notice shall be included in all
17 | # copies or substantial portions of the Software.
18 | #
19 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
33 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 | """
26 | import logging
27 | import torch
28 | from torch import nn
29 | from network import SEresnext
30 | from network import Resnet
31 | from network.wider_resnet import wider_resnet38_a2
32 | from network.mynn import initialize_weights, Norm2d, Upsample
33 |
34 |
35 | class _AtrousSpatialPyramidPoolingModule(nn.Module):
36 | """
37 | operations performed:
38 | 1x1 x depth
39 | 3x3 x depth dilation 6
40 | 3x3 x depth dilation 12
41 | 3x3 x depth dilation 18
42 | image pooling
43 | concatenate all together
44 | Final 1x1 conv
45 | """
46 |
47 | def __init__(self, in_dim, reduction_dim=256, output_stride=16, rates=(6, 12, 18)):
48 | super(_AtrousSpatialPyramidPoolingModule, self).__init__()
49 |
50 | if output_stride == 8:
51 | rates = [2 * r for r in rates]
52 | elif output_stride == 16:
53 | pass
54 | else:
55 |             raise ValueError('output stride of {} not supported'.format(output_stride))
56 |
57 | self.features = []
58 | # 1x1
59 | self.features.append(
60 | nn.Sequential(nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
61 | Norm2d(reduction_dim), nn.ReLU(inplace=True)))
62 | # other rates
63 | for r in rates:
64 | self.features.append(nn.Sequential(
65 | nn.Conv2d(in_dim, reduction_dim, kernel_size=3,
66 | dilation=r, padding=r, bias=False),
67 | Norm2d(reduction_dim),
68 | nn.ReLU(inplace=True)
69 | ))
70 | self.features = torch.nn.ModuleList(self.features)
71 |
72 | # img level features
73 | self.img_pooling = nn.AdaptiveAvgPool2d(1)
74 | self.img_conv = nn.Sequential(
75 | nn.Conv2d(in_dim, reduction_dim, kernel_size=1, bias=False),
76 | Norm2d(reduction_dim), nn.ReLU(inplace=True))
77 |
78 | def forward(self, x):
79 | x_size = x.size()
80 |
81 | img_features = self.img_pooling(x)
82 | img_features = self.img_conv(img_features)
83 | img_features = Upsample(img_features, x_size[2:])
84 | out = img_features
85 |
86 | for f in self.features:
87 | y = f(x)
88 | out = torch.cat((out, y), 1)
89 | return out
90 |
91 |
92 | class DeepV3Plus(nn.Module):
93 | """
94 | Implement DeepLabV3 model
95 | A: stride8
96 | B: stride16
97 | with skip connections
98 | """
99 |
100 | def __init__(self, num_classes, trunk='seresnext-50', criterion=None, variant='D',
101 | skip='m1', skip_num=48):
102 | super(DeepV3Plus, self).__init__()
103 | self.criterion = criterion
104 | self.variant = variant
105 | self.skip = skip
106 | self.skip_num = skip_num
107 |
108 | if trunk == 'seresnext-50':
109 | resnet = SEresnext.se_resnext50_32x4d()
110 | elif trunk == 'seresnext-101':
111 | resnet = SEresnext.se_resnext101_32x4d()
112 | elif trunk == 'resnet-50':
113 | resnet = Resnet.resnet50()
114 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
115 | elif trunk == 'resnet-101':
116 | resnet = Resnet.resnet101()
117 | resnet.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
118 | else:
119 | raise ValueError("Not a valid network arch")
120 |
121 | self.layer0 = resnet.layer0
122 | self.layer1, self.layer2, self.layer3, self.layer4 = \
123 | resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4
124 |
125 | if self.variant == 'D':
126 | for n, m in self.layer3.named_modules():
127 | if 'conv2' in n:
128 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
129 | elif 'downsample.0' in n:
130 | m.stride = (1, 1)
131 | for n, m in self.layer4.named_modules():
132 | if 'conv2' in n:
133 | m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
134 | elif 'downsample.0' in n:
135 | m.stride = (1, 1)
136 | elif self.variant == 'D16':
137 | for n, m in self.layer4.named_modules():
138 | if 'conv2' in n:
139 | m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
140 | elif 'downsample.0' in n:
141 | m.stride = (1, 1)
142 | else:
143 | print("Not using Dilation ")
144 |
145 | self.aspp = _AtrousSpatialPyramidPoolingModule(2048, 256,
146 | output_stride=8)
147 |
148 | if self.skip == 'm1':
149 | self.bot_fine = nn.Conv2d(256, self.skip_num, kernel_size=1, bias=False)
150 | elif self.skip == 'm2':
151 | self.bot_fine = nn.Conv2d(512, self.skip_num, kernel_size=1, bias=False)
152 | else:
153 | raise Exception('Not a valid skip')
154 |
155 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
156 |
157 | self.final = nn.Sequential(
158 | nn.Conv2d(256 + self.skip_num, 256, kernel_size=3, padding=1, bias=False),
159 | Norm2d(256),
160 | nn.ReLU(inplace=True),
161 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
162 | Norm2d(256),
163 | nn.ReLU(inplace=True),
164 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
165 |
166 | initialize_weights(self.aspp)
167 | initialize_weights(self.bot_aspp)
168 | initialize_weights(self.bot_fine)
169 | initialize_weights(self.final)
170 |
171 | def forward(self, x, gts=None):
172 |
173 | x_size = x.size() # 800
174 | x0 = self.layer0(x) # 400
175 | x1 = self.layer1(x0) # 400
176 | x2 = self.layer2(x1) # 100
177 | x3 = self.layer3(x2) # 100
178 | x4 = self.layer4(x3) # 100
179 | xp = self.aspp(x4)
180 |
181 | dec0_up = self.bot_aspp(xp)
182 | if self.skip == 'm1':
183 | dec0_fine = self.bot_fine(x1)
184 | dec0_up = Upsample(dec0_up, x1.size()[2:])
185 | else:
186 | dec0_fine = self.bot_fine(x2)
187 | dec0_up = Upsample(dec0_up, x2.size()[2:])
188 |
189 | dec0 = [dec0_fine, dec0_up]
190 | dec0 = torch.cat(dec0, 1)
191 | dec1 = self.final(dec0)
192 | main_out = Upsample(dec1, x_size[2:])
193 |
194 | if self.training:
195 | return self.criterion(main_out, gts)
196 |
197 | return main_out
198 |
199 |
200 | class DeepWV3Plus(nn.Module):
201 | """
202 | WideResNet38 version of DeepLabV3
203 | mod1
204 | pool2
205 | mod2 bot_fine
206 | pool3
207 | mod3-7
208 | bot_aspp
209 |
210 | structure: [3, 3, 6, 3, 1, 1]
211 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
212 | (1024, 2048, 4096)]
213 | """
214 |
215 | def __init__(self, num_classes, trunk='WideResnet38', criterion=None):
216 |
217 | super(DeepWV3Plus, self).__init__()
218 | self.criterion = criterion
219 | logging.info("Trunk: %s", trunk)
220 |
221 | wide_resnet = wider_resnet38_a2(classes=1000, dilation=True)
222 | wide_resnet = torch.nn.DataParallel(wide_resnet)
223 | if criterion is not None:
224 | try:
225 | checkpoint = torch.load('./pretrained_models/wider_resnet38.pth.tar', map_location='cpu')
226 | wide_resnet.load_state_dict(checkpoint['state_dict'])
227 | del checkpoint
228 |             except Exception:
229 |                 print("Please download the ImageNet weights of WideResNet38 from our repo and place them at ./pretrained_models/wider_resnet38.pth.tar.")
230 | raise RuntimeError("=====================Could not load ImageNet weights of WideResNet38 network.=======================")
231 | wide_resnet = wide_resnet.module
232 |
233 | self.mod1 = wide_resnet.mod1
234 | self.mod2 = wide_resnet.mod2
235 | self.mod3 = wide_resnet.mod3
236 | self.mod4 = wide_resnet.mod4
237 | self.mod5 = wide_resnet.mod5
238 | self.mod6 = wide_resnet.mod6
239 | self.mod7 = wide_resnet.mod7
240 | self.pool2 = wide_resnet.pool2
241 | self.pool3 = wide_resnet.pool3
242 | del wide_resnet
243 |
244 | self.aspp = _AtrousSpatialPyramidPoolingModule(4096, 256,
245 | output_stride=8)
246 |
247 | self.bot_fine = nn.Conv2d(128, 48, kernel_size=1, bias=False)
248 | self.bot_aspp = nn.Conv2d(1280, 256, kernel_size=1, bias=False)
249 |
250 | self.final = nn.Sequential(
251 | nn.Conv2d(256 + 48, 256, kernel_size=3, padding=1, bias=False),
252 | Norm2d(256),
253 | nn.ReLU(inplace=True),
254 | nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=False),
255 | Norm2d(256),
256 | nn.ReLU(inplace=True),
257 | nn.Conv2d(256, num_classes, kernel_size=1, bias=False))
258 |
259 | initialize_weights(self.final)
260 |
261 | def forward(self, inp, gts=None):
262 |
263 | x_size = inp.size()
264 | x = self.mod1(inp)
265 | m2 = self.mod2(self.pool2(x))
266 | x = self.mod3(self.pool3(m2))
267 | x = self.mod4(x)
268 | x = self.mod5(x)
269 | x = self.mod6(x)
270 | x = self.mod7(x)
271 | x = self.aspp(x)
272 | dec0_up = self.bot_aspp(x)
273 |
274 | dec0_fine = self.bot_fine(m2)
275 | dec0_up = Upsample(dec0_up, m2.size()[2:])
276 | dec0 = [dec0_fine, dec0_up]
277 | dec0 = torch.cat(dec0, 1)
278 |
279 | dec1 = self.final(dec0)
280 | out = Upsample(dec1, x_size[2:])
281 |
282 | if self.training:
283 | return self.criterion(out, gts)
284 |
285 | return out
286 |
287 |
288 | def DeepSRNX50V3PlusD_m1(num_classes, criterion):
289 | """
290 | SEResNeXt-50 Based Network
291 | """
292 | return DeepV3Plus(num_classes, trunk='seresnext-50', criterion=criterion, variant='D',
293 | skip='m1')
294 |
295 | def DeepR50V3PlusD_m1(num_classes, criterion):
296 | """
297 | ResNet-50 Based Network
298 | """
299 | return DeepV3Plus(num_classes, trunk='resnet-50', criterion=criterion, variant='D', skip='m1')
300 |
301 |
302 | def DeepSRNX101V3PlusD_m1(num_classes, criterion):
303 | """
304 | SEResNeXt-101 Based Network
305 | """
306 | return DeepV3Plus(num_classes, trunk='seresnext-101', criterion=criterion, variant='D',
307 | skip='m1')
308 |
309 |
--------------------------------------------------------------------------------
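One number worth spelling out: bot_aspp takes 1280 input channels because the ASPP concatenates five branches (image pooling, the 1x1 branch and the three dilated 3x3 branches), each reduction_dim = 256 wide. A small shape check, assuming Apex and config.py are importable (this module pulls in network.mynn) and using an illustrative stride-8 feature map:

import torch
import torch.nn as nn
from config import cfg

cfg.MODEL.BNFUNC = nn.BatchNorm2d        # Norm2d reads this; train.py normally sets it
from network.deepv3 import _AtrousSpatialPyramidPoolingModule

aspp = _AtrousSpatialPyramidPoolingModule(in_dim=2048, reduction_dim=256, output_stride=8)
x = torch.randn(2, 2048, 32, 32)         # e.g. layer4 / mod7 features at stride 8
print(aspp(x).shape)                     # torch.Size([2, 1280, 32, 32])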
/network/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization
3 | """
4 | import torch.nn as nn
5 | from config import cfg
6 |
7 | from apex import amp
8 |
9 | def Norm2d(in_channels):
10 | """
11 | Custom Norm Function to allow flexible switching
12 | """
13 | layer = getattr(cfg.MODEL, 'BNFUNC')
14 | normalization_layer = layer(in_channels)
15 | return normalization_layer
16 |
17 |
18 | def initialize_weights(*models):
19 | """
20 | Initialize Model Weights
21 | """
22 | for model in models:
23 | for module in model.modules():
24 | if isinstance(module, (nn.Conv2d, nn.Linear)):
25 | nn.init.kaiming_normal_(module.weight)
26 | if module.bias is not None:
27 | module.bias.data.zero_()
28 | elif isinstance(module, nn.BatchNorm2d):
29 | module.weight.data.fill_(1)
30 | module.bias.data.zero_()
31 |
32 |
33 | @amp.float_function
34 | def Upsample(x, size):
35 | """
36 | Wrapper Around the Upsample Call
37 | """
38 | return nn.functional.interpolate(x, size=size, mode='bilinear',
39 | align_corners=True)
40 |
--------------------------------------------------------------------------------
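A small sketch of the two remaining helpers, assuming Apex is installed (this module imports apex.amp at import time) and config.py is importable; the shapes are illustrative:

import torch
import torch.nn as nn
from network.mynn import Upsample, initialize_weights

head = nn.Conv2d(256, 19, kernel_size=1, bias=False)
initialize_weights(head)                  # Kaiming-normal init, in place

logits = torch.randn(1, 19, 96, 96)       # e.g. stride-8 logits
full = Upsample(logits, size=(768, 768))  # bilinear, align_corners=True, back to input size
print(full.shape)                         # torch.Size([1, 19, 768, 768])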
/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pytorch Optimizer and Scheduler Related Task
3 | """
4 | import math
5 | import logging
6 | import torch
7 | from torch import optim
8 | from config import cfg
9 |
10 |
11 | def get_optimizer(args, net):
12 | """
13 | Decide Optimizer (Adam or SGD)
14 | """
15 | param_groups = net.parameters()
16 |
17 | if args.sgd:
18 | optimizer = optim.SGD(param_groups,
19 | lr=args.lr,
20 | weight_decay=args.weight_decay,
21 | momentum=args.momentum,
22 | nesterov=False)
23 | elif args.adam:
24 | amsgrad = False
25 | if args.amsgrad:
26 | amsgrad = True
27 | optimizer = optim.Adam(param_groups,
28 | lr=args.lr,
29 | weight_decay=args.weight_decay,
30 | amsgrad=amsgrad
31 | )
32 | else:
33 | raise ValueError('Not a valid optimizer')
34 |
35 | if args.lr_schedule == 'scl-poly':
36 | if cfg.REDUCE_BORDER_EPOCH == -1:
37 | raise ValueError('ERROR Cannot Do Scale Poly')
38 |
39 | rescale_thresh = cfg.REDUCE_BORDER_EPOCH
40 | scale_value = args.rescale
41 | lambda1 = lambda epoch: \
42 | math.pow(1 - epoch / args.max_epoch,
43 | args.poly_exp) if epoch < rescale_thresh else scale_value * math.pow(
44 | 1 - (epoch - rescale_thresh) / (args.max_epoch - rescale_thresh),
45 | args.repoly)
46 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
47 | elif args.lr_schedule == 'poly':
48 | lambda1 = lambda epoch: math.pow(1 - epoch / args.max_epoch, args.poly_exp)
49 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
50 | else:
51 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
52 |
53 | return optimizer, scheduler
54 |
55 |
56 | def load_weights(net, optimizer, snapshot_file, restore_optimizer_bool=False):
57 | """
58 | Load weights from snapshot file
59 | """
60 | logging.info("Loading weights from model %s", snapshot_file)
61 | net, optimizer = restore_snapshot(net, optimizer, snapshot_file, restore_optimizer_bool)
62 | return net, optimizer
63 |
64 |
65 | def restore_snapshot(net, optimizer, snapshot, restore_optimizer_bool):
66 | """
67 |     Restore weights and optimizer (if needed) for resuming a job.
68 | """
69 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
70 |     logging.info("Checkpoint Load Complete")
71 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
72 | optimizer.load_state_dict(checkpoint['optimizer'])
73 |
74 | if 'state_dict' in checkpoint:
75 | net = forgiving_state_restore(net, checkpoint['state_dict'])
76 | else:
77 | net = forgiving_state_restore(net, checkpoint)
78 |
79 | return net, optimizer
80 |
81 |
82 | def forgiving_state_restore(net, loaded_dict):
83 | """
84 | Handle partial loading when some tensors don't match up in size.
85 | Because we want to use models that were trained off a different
86 | number of classes.
87 | """
88 | net_state_dict = net.state_dict()
89 | new_loaded_dict = {}
90 | for k in net_state_dict:
91 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
92 | new_loaded_dict[k] = loaded_dict[k]
93 | else:
94 | logging.info("Skipped loading parameter %s", k)
95 | net_state_dict.update(new_loaded_dict)
96 | net.load_state_dict(net_state_dict)
97 | return net
98 |
--------------------------------------------------------------------------------
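A worked example of the two LR multipliers above; the numbers are illustrative and only exercise the lambdas, not the optimizer:

import math

# 'poly': multiplier decays as (1 - epoch / max_epoch) ** poly_exp
max_epoch, poly_exp = 175, 1.0
poly = lambda epoch: math.pow(1 - epoch / max_epoch, poly_exp)
print(round(poly(0), 3), round(poly(87), 3), round(poly(174), 3))    # 1.0 0.503 0.006

# 'scl-poly': same curve until REDUCE_BORDER_EPOCH, then a rescaled poly tail
rescale_thresh, scale_value, repoly = 150, 1.0, 1.5
scl_poly = lambda epoch: (
    math.pow(1 - epoch / max_epoch, poly_exp) if epoch < rescale_thresh
    else scale_value * math.pow(1 - (epoch - rescale_thresh) / (max_epoch - rescale_thresh), repoly))
print(round(scl_poly(160), 3))                                       # 0.465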
/scripts/eval_cityscapes_SEResNeXt50.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \
7 | --inference_mode sliding \
8 | --scales 1.0 \
9 | --split val \
10 | --cv_split 0 \
11 | --dump_images \
12 | --ckpt_path ${2} \
13 | --snapshot ${1}
14 |
--------------------------------------------------------------------------------
/scripts/eval_cityscapes_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepWV3Plus \
7 | --inference_mode sliding \
8 | --scales 1.0 \
9 | --split val \
10 | --cv_split 0 \
11 | --dump_images \
12 | --ckpt_path ${2} \
13 | --snapshot ${1}
14 |
--------------------------------------------------------------------------------
/scripts/submit_cityscapes_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \
5 | --dataset cityscapes \
6 | --arch network.deepv3.DeepWV3Plus \
7 | --inference_mode sliding \
8 | --scales 0.5,1.0,2.0 \
9 | --split test \
10 | --cv_split 0 \
11 | --dump_images \
12 | --ckpt_path ${2} \
13 | --snapshot ${1}
14 |
--------------------------------------------------------------------------------
/scripts/test_kitti_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 | PYTHONPATH=$PWD:$PYTHONPATH python3 eval.py \
5 | --dataset kitti \
6 | --arch network.deepv3.DeepWV3Plus \
7 | --mode semantic \
8 | --split test \
9 | --inference_mode sliding \
10 | --cv_split 0 \
11 | --scales 1.5,2.0,2.5 \
12 | --crop_size 368 \
13 | --dump_images \
14 | --snapshot ${1} \
15 | --ckpt_path ${2}
16 |
--------------------------------------------------------------------------------
/scripts/train_cityscapes_SEResNeXt50.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Example on Cityscapes
4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
5 | --dataset cityscapes \
6 | --cv 0 \
7 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \
8 | --snapshot ./pretrained_models/YOUR_TRAINED_MAPILLARY_MODEL \
9 | --class_uniform_pct 0.5 \
10 | --class_uniform_tile 1024 \
11 | --max_cu_epoch 150 \
12 | --lr 0.001 \
13 | --lr_schedule scl-poly \
14 | --poly_exp 1.0 \
15 | --repoly 1.5 \
16 | --rescale 1.0 \
17 | --syncbn \
18 | --sgd \
19 | --crop_size 896 \
20 | --scale_min 0.5 \
21 | --scale_max 2.0 \
22 | --color_aug 0.25 \
23 | --gblur \
24 | --max_epoch 175 \
25 | --coarse_boost_classes 14,15,16,3,12,17,4 \
26 | --jointwtborder \
27 | --strict_bdr_cls 5,6,7,11,12,17,18 \
28 | --rlx_off_epoch 100 \
29 | --wt_bound 1.0 \
30 | --bs_mult 2 \
31 | --apex \
32 | --exp cityscapes_ft \
33 | --ckpt ./logs/ \
34 | --tb_path ./logs/
35 |
--------------------------------------------------------------------------------
/scripts/train_cityscapes_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Example on Cityscapes
4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
5 | --dataset cityscapes \
6 | --cv 2 \
7 | --arch network.deepv3.DeepWV3Plus \
8 | --snapshot ./pretrained_models/YOUR_TRAINED_MAPILLARY_MODEL \
9 | --class_uniform_pct 0.5 \
10 | --class_uniform_tile 1024 \
11 | --max_cu_epoch 150 \
12 | --lr 0.001 \
13 | --lr_schedule scl-poly \
14 | --poly_exp 1.0 \
15 | --repoly 1.5 \
16 | --rescale 1.0 \
17 | --syncbn \
18 | --sgd \
19 | --crop_size 896 \
20 | --scale_min 0.5 \
21 | --scale_max 2.0 \
22 | --color_aug 0.25 \
23 | --gblur \
24 | --max_epoch 175 \
25 | --coarse_boost_classes 14,15,16,3,12,17,4 \
26 | --jointwtborder \
27 | --strict_bdr_cls 5,6,7,11,12,17,18 \
28 | --rlx_off_epoch 100 \
29 | --wt_bound 1.0 \
30 | --bs_mult 2 \
31 | --apex \
32 | --exp cityscapes_ft \
33 | --ckpt ./logs/ \
34 | --tb_path ./logs/
35 |
--------------------------------------------------------------------------------
/scripts/train_comma10k_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1 train.py \
4 | --dataset comma10k \
5 | --cv 2 \
6 | --arch network.deepv3.DeepWV3Plus \
7 | --class_uniform_pct 0.0 \
8 | --class_uniform_tile 300 \
9 | --lr 2e-2 \
10 | --lr_schedule poly \
11 | --poly_exp 1.0 \
12 | --adam \
13 | --crop_size 360 \
14 | --scale_min 1.0 \
15 | --scale_max 2.0 \
16 | --color_aug 0.25 \
17 | --max_epoch 90 \
18 | --img_wt_loss \
19 | --wt_bound 1.0 \
20 | --bs_mult 2 \
21 | --exp comma10k \
22 | --ckpt ./logs/ \
23 | --tb_path ./logs/
24 |
25 | #--lr 0.001 \
26 | #--snapshot ./pretrained_models/cityscapes_best.pth \
27 | #--apex \
28 | #--syncbn \
29 |
--------------------------------------------------------------------------------
/scripts/train_kitti_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Example on KITTI, fine tune
4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
5 | --dataset kitti \
6 | --cv 2 \
7 | --arch network.deepv3.DeepWV3Plus \
8 | --snapshot ./pretrained_models/cityscapes_best.pth \
9 | --class_uniform_pct 0.5 \
10 | --class_uniform_tile 300 \
11 | --lr 0.001 \
12 | --lr_schedule poly \
13 | --poly_exp 1.0 \
14 | --syncbn \
15 | --sgd \
16 | --crop_size 360 \
17 | --scale_min 1.0 \
18 | --scale_max 2.0 \
19 | --color_aug 0.25 \
20 | --max_epoch 90 \
21 | --img_wt_loss \
22 | --wt_bound 1.0 \
23 | --bs_mult 2 \
24 | --apex \
25 | --exp kitti_ft \
26 | --ckpt ./logs/ \
27 | --tb_path ./logs/
28 |
29 |
30 |
--------------------------------------------------------------------------------
/scripts/train_mapillary_SEResNeXt50.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Example on Mapillary
4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
5 | --dataset mapillary \
6 | --arch network.deepv3.DeepSRNX50V3PlusD_m1 \
7 | --class_uniform_pct 0.5 \
8 | --class_uniform_tile 1024 \
9 | --syncbn \
10 | --sgd \
11 | --lr 2e-2 \
12 | --lr_schedule poly \
13 | --poly_exp 1.0 \
14 | --crop_size 896 \
15 | --scale_min 0.5 \
16 | --scale_max 2.0 \
17 | --color_aug 0.25 \
18 | --gblur \
19 | --max_epoch 175 \
20 | --img_wt_loss \
21 | --wt_bound 6.0 \
22 | --bs_mult 2 \
23 | --apex \
24 | --exp mapillary_pretrain \
25 | --ckpt ./logs/ \
26 | --tb_path ./logs/
27 |
--------------------------------------------------------------------------------
/scripts/train_mapillary_WideResNet38.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Example on Mapillary
4 | python -m torch.distributed.launch --nproc_per_node=8 train.py \
5 | --dataset mapillary \
6 | --arch network.deepv3.DeepWV3Plus \
7 | --class_uniform_pct 0.5 \
8 | --class_uniform_tile 1024 \
9 | --syncbn \
10 | --sgd \
11 | --lr 2e-2 \
12 | --lr_schedule poly \
13 | --poly_exp 1.0 \
14 | --crop_size 896 \
15 | --scale_min 0.5 \
16 | --scale_max 2.0 \
17 | --color_aug 0.25 \
18 | --gblur \
19 | --max_epoch 175 \
20 | --img_wt_loss \
21 | --wt_bound 6.0 \
22 | --bs_mult 2 \
23 | --apex \
24 | --exp mapillary_pretrain \
25 | --ckpt ./logs/ \
26 | --tb_path ./logs/
27 |
--------------------------------------------------------------------------------
/sdcnet/_aug.sh:
--------------------------------------------------------------------------------
1 | python sdc_aug.py --propagate 5 --vis \
2 | --pretrained ../pretrained_models/sdc_cityscapes_vrec.pth.tar \
3 | --flownet2_checkpoint ../pretrained_models/FlowNet2_checkpoint.pth.tar \
4 | --source_dir /home/ubuntu/yizhu/data/Cityscapes \
5 | --target_dir /home/ubuntu/yizhu/data/Cityscapes/cs_aug
6 |
7 |
8 |
9 |
10 |
--------------------------------------------------------------------------------
/sdcnet/_eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # Run SDC2DRecon on Cityscapes dataset
3 |
4 | # Root folder of cityscapes images
5 | VAL_FILE=~/data/tmp/tinycs
6 | SDC2DREC_CHECKPOINT=../pretrained_models/sdc_cityscapes_vrec.pth.tar
7 | FLOWNET2_CHECKPOINT=../pretrained_models/FlowNet2_checkpoint.pth.tar
8 |
9 | python3 main.py \
10 | --eval \
11 | --sequence_length 2 \
12 | --save ./ \
13 | --name __evalrun \
14 | --val_n_batches 1 \
15 | --write_images \
16 | --dataset FrameLoader \
17 | --model SDCNet2DRecon \
18 | --val_file ${VAL_FILE} \
19 | --resume ${SDC2DREC_CHECKPOINT} \
20 | --flownet2_checkpoint ${FLOWNET2_CHECKPOINT}
21 |
22 |
--------------------------------------------------------------------------------
/sdcnet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .frame_loader import *
--------------------------------------------------------------------------------
/sdcnet/datasets/dataset_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import torch
5 |
6 | class StaticRandomCrop(object):
7 | """
8 | Helper function for random spatial crop
9 | """
10 | def __init__(self, size, image_shape):
11 | h, w = image_shape
12 | self.th, self.tw = size
13 | self.h1 = torch.randint(0, h - self.th + 1, (1,)).item()
14 | self.w1 = torch.randint(0, w - self.tw + 1, (1,)).item()
15 |
16 | def __call__(self, img):
17 | return img[self.h1:(self.h1 + self.th), self.w1:(self.w1 + self.tw), :]
18 |
--------------------------------------------------------------------------------
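The crop is "static" in the sense that the offsets are drawn once in __init__, so mapping one instance over a frame sequence crops every frame at the same window. A short sketch, assuming it is run from the sdcnet directory (importing the datasets package also pulls in cv2 and natsort via frame_loader):

import numpy as np
from datasets.dataset_utils import StaticRandomCrop

frames = [np.zeros((256, 512, 3), dtype=np.uint8) for _ in range(3)]   # illustrative frames
cropper = StaticRandomCrop(size=(128, 128), image_shape=frames[0].shape[:2])
cropped = [cropper(frame) for frame in frames]
print(cropped[0].shape)   # (128, 128, 3), identical window for all three frames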
/sdcnet/datasets/frame_loader.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import os
5 | import natsort
6 | import numpy as np
7 | import cv2
8 |
9 |
10 | import torch
11 | from torch.utils import data
12 | from datasets.dataset_utils import StaticRandomCrop
13 |
14 | class FrameLoader(data.Dataset):
15 | def __init__(self, args, root, is_training = False, transform=None):
16 |
17 | self.is_training = is_training
18 | self.transform = transform
19 | self.chsize = 3
20 |
21 | # carry over command line arguments
22 | assert args.sequence_length > 1, 'sequence length must be > 1'
23 | self.sequence_length = args.sequence_length
24 |
25 | assert args.sample_rate > 0, 'sample rate must be > 0'
26 | self.sample_rate = args.sample_rate
27 |
28 | self.crop_size = args.crop_size
29 | self.start_index = args.start_index
30 | self.stride = args.stride
31 |
32 | assert (os.path.exists(root))
33 | if self.is_training:
34 | self.start_index = 0
35 |
36 |         # collect colors, motion vectors, and depth
37 | self.ref = self.collect_filelist(root)
38 |
39 | counts = [((len(el) - self.sequence_length) // (self.sample_rate)) for el in self.ref]
40 | self.total = np.sum(counts)
41 | self.cum_sum = list(np.cumsum([0] + [el for el in counts]))
42 |
43 | def collect_filelist(self, root):
44 |         include_ext = [".png", ".jpg", ".jpeg", ".bmp"]
45 | # collect subfolders, excluding hidden files, but following symlinks
46 | dirs = [x[0] for x in os.walk(root, followlinks=True) if not x[0].startswith('.')]
47 |
48 | # naturally sort, both dirs and individual images, while skipping hidden files
49 | dirs = natsort.natsorted(dirs)
50 |
51 | datasets = [
52 | [os.path.join(fdir, el) for el in natsort.natsorted(os.listdir(fdir))
53 | if os.path.isfile(os.path.join(fdir, el))
54 | and not el.startswith('.')
55 | and any([el.endswith(ext) for ext in include_ext])]
56 | for fdir in dirs
57 | ]
58 |
59 | return [el for el in datasets if el]
60 |
61 | def __len__(self):
62 | return self.total
63 |
64 | def __getitem__(self, index):
65 | # adjust index
66 | index = len(self) + index if index < 0 else index
67 | index = index + self.start_index
68 |
69 | dataset_index = np.searchsorted(self.cum_sum, index + 1)
70 | index = self.sample_rate * (index - self.cum_sum[np.maximum(0, dataset_index - 1)])
71 |
72 | image_list = self.ref[dataset_index - 1]
73 | input_files = [ image_list[index + offset] for offset in range(self.sequence_length + 1)]
74 |
75 | # reverse image order with p=0.5
76 | if self.is_training and torch.randint(0, 2, (1,)).item():
77 | input_files = input_files[::-1]
78 |
79 | # images = [imageio.imread(imfile)[..., :self.chsize] for imfile in input_files]
80 | images = [cv2.imread(imfile)[..., :self.chsize] for imfile in input_files]
81 | input_shape = images[0].shape[:2]
82 | if self.is_training:
83 | cropper = StaticRandomCrop(self.crop_size, input_shape)
84 |             images = [cropper(im) for im in images]
85 |
86 | # Pad images along height and width to fit them evenly into models.
87 | height, width = input_shape
88 | if (height % self.stride) != 0:
89 | padded_height = (height // self.stride + 1) * self.stride
90 | images = [ np.pad(im, ((0, padded_height - height), (0,0), (0,0)), 'reflect') for im in images]
91 |
92 | if (width % self.stride) != 0:
93 | padded_width = (width // self.stride + 1) * self.stride
94 | images = [np.pad(im, ((0, 0), (0, padded_width - width), (0, 0)), 'reflect') for im in images]
95 |
96 | input_images = [torch.from_numpy(im.transpose(2, 0, 1)).float() for im in images]
97 |
98 | output_dict = {
99 | 'image': input_images, 'ishape': input_shape, 'input_files': input_files
100 | }
101 |
102 | return output_dict
--------------------------------------------------------------------------------
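The flat-index bookkeeping in __getitem__ deserves a note: counts holds how many training windows each video folder contributes, cum_sum its running total, and searchsorted maps a global index back to a (folder, local offset) pair. A standalone sketch with made-up counts (the sample_rate multiplication is omitted):

import numpy as np

counts = [10, 4, 7]                       # windows contributed by three folders
cum_sum = list(np.cumsum([0] + counts))   # [0, 10, 14, 21]

index = 12                                # global index into the concatenated dataset
dataset_index = np.searchsorted(cum_sum, index + 1)           # 2 -> second folder
offset = index - cum_sum[np.maximum(0, dataset_index - 1)]    # 12 - 10 = 2
print(dataset_index, offset)              # 2 2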
/sdcnet/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .sdc_net2d import *
--------------------------------------------------------------------------------
/sdcnet/models/model_utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from __future__ import print_function
3 |
4 | import torch.nn as nn
5 |
6 | def conv2d(channels_in, channels_out, kernel_size=3, stride=1, bias = True):
7 | return nn.Sequential(
8 | nn.Conv2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
9 | nn.LeakyReLU(0.1,inplace=True)
10 | )
11 |
12 | def deconv2d(channels_in, channels_out, kernel_size=4, stride=2, padding=1, bias=True):
13 | return nn.Sequential(
14 | nn.ConvTranspose2d(channels_in, channels_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
15 | nn.LeakyReLU(0.1,inplace=True)
16 | )
--------------------------------------------------------------------------------
/sdcnet/models/sdc_net2d.py:
--------------------------------------------------------------------------------
1 | '''
2 | Portions of this code are adapted from:
3 | https://github.com/NVIDIA/flownet2-pytorch/blob/master/networks/FlowNetS.py
4 | https://github.com/ClementPinard/FlowNetPytorch/blob/master/models/FlowNetS.py
5 | '''
6 | from __future__ import division
7 | from __future__ import print_function
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.nn import init
12 | import os
13 |
14 | from models.model_utils import conv2d, deconv2d
15 |
16 | from flownet2_pytorch.models import FlowNet2
17 | from flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d
18 |
19 |
20 | class SDCNet2D(nn.Module):
21 | def __init__(self, args):
22 | super(SDCNet2D,self).__init__()
23 |
24 | self.rgb_max = args.rgb_max
25 | self.sequence_length = args.sequence_length
26 |
27 | factor = 2
28 | input_channels = self.sequence_length * 3 + (self.sequence_length - 1) * 2
29 |
30 | self.conv1 = conv2d(input_channels, 64 // factor, kernel_size=7, stride=2)
31 | self.conv2 = conv2d(64 // factor, 128 // factor, kernel_size=5, stride=2)
32 | self.conv3 = conv2d(128 // factor, 256 // factor, kernel_size=5, stride=2)
33 | self.conv3_1 = conv2d(256 // factor, 256 // factor)
34 | self.conv4 = conv2d(256 // factor, 512 // factor, stride=2)
35 | self.conv4_1 = conv2d(512 // factor, 512 // factor)
36 | self.conv5 = conv2d(512 // factor, 512 // factor, stride=2)
37 | self.conv5_1 = conv2d(512 // factor, 512 // factor)
38 | self.conv6 = conv2d(512 // factor, 1024 // factor, stride=2)
39 | self.conv6_1 = conv2d(1024 // factor, 1024 // factor)
40 |
41 | self.deconv5 = deconv2d(1024 // factor, 512 // factor)
42 | self.deconv4 = deconv2d(1024 // factor, 256 // factor)
43 | self.deconv3 = deconv2d(768 // factor, 128 // factor)
44 | self.deconv2 = deconv2d(384 // factor, 64 // factor)
45 | self.deconv1 = deconv2d(192 // factor, 32 // factor)
46 | self.deconv0 = deconv2d(96 // factor, 16 // factor)
47 |
48 | self.final_flow = nn.Conv2d(input_channels + 16 // factor, 2,
49 | kernel_size=3, stride=1, padding=1, bias=True)
50 |
51 |
52 | # init parameters, when doing convtranspose3d, do bilinear init
53 | for m in self.modules():
54 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose3d):
55 | if m.bias is not None:
56 | init.uniform_(m.bias)
57 | init.xavier_uniform_(m.weight)
58 |
59 | self.flownet2 = FlowNet2(args, batchNorm=False)
60 | assert os.path.exists(args.flownet2_checkpoint), "flownet2 checkpoint must be provided."
61 | flownet2_checkpoint = torch.load(args.flownet2_checkpoint)
62 | self.flownet2.load_state_dict(flownet2_checkpoint['state_dict'], strict=False)
63 |
64 | for param in self.flownet2.parameters():
65 | param.requires_grad = False
66 |
67 | self.warp_nn = Resample2d(bilinear=False)
68 | self.warp_bilinear = Resample2d(bilinear=True)
69 |
70 | self.L1Loss = nn.L1Loss()
71 |
72 | flow_mean = torch.FloatTensor([-0.94427323, -1.23077035]).view(1, 2, 1, 1)
73 | flow_std = torch.FloatTensor([13.77204132, 7.47463894]).view(1, 2, 1, 1)
74 | rgb_mean = torch.FloatTensor([106.7747911, 96.13649598, 76.61428884]).view(1, 3, 1, 1)
75 |
76 | self.register_buffer('flow_mean', flow_mean)
77 | self.register_buffer('flow_std', flow_std)
78 | self.register_buffer('rgb_mean', rgb_mean)
79 |
80 | self.ignore_keys = ['flownet2']
81 | return
82 |
83 | def interframe_optical_flow(self, input_images):
84 |
85 | #FIXME: flownet2 implementation expects RGB images,
86 |         # while both SDCNet and DeepLabV3 expect BGR images.
87 | # input_images = [torch.flip(input_image, dims=[1]) for input_image in input_images]
88 |
89 |         # Create image pairs for flownet, then merge batch and frame dimensions
90 |         # so that only a single call to flownet2 is needed.
91 | flownet2_inputs = torch.stack(
92 | [torch.cat([input_images[i + 1].unsqueeze(2), input_images[i].unsqueeze(2)], dim=2) for i in
93 | range(0, self.sequence_length - 1)], dim=0).contiguous()
94 |
95 | batch_size, channel_count, height, width = input_images[0].shape
96 | flownet2_inputs_flattened = flownet2_inputs.view(-1,channel_count, 2, height, width)
97 | flownet2_outputs = [self.flownet2(flownet2_input) for flownet2_input in
98 | torch.chunk(flownet2_inputs_flattened, self.sequence_length - 1)]
99 |
100 | #FIXME: flipback images to BGR,
101 | # input_images = [torch.flip(input_image, dims=[1]) for input_image in input_images]
102 |
103 | return flownet2_outputs
104 |
105 | def network_output(self, input_images, input_flows):
106 |
107 | # Normalize input flows
108 | input_flows = [(input_flow - self.flow_mean) / (3 * self.flow_std) for
109 | input_flow in input_flows]
110 |
111 | # Normalize input via flownet2-type normalisation
112 | concated_images = torch.cat([image.unsqueeze(2) for image in input_images], dim=2).contiguous()
113 | rgb_mean = concated_images.view(concated_images.size()[:2] + (-1,)).mean(dim=-1).view(
114 | concated_images.size()[:2] + 2 * (1,))
115 | input_images = [(input_image - rgb_mean) / self.rgb_max for input_image in input_images]
116 | bsize, channels, height, width = input_flows[0].shape
117 |
118 | # Atypical concatenation of input images along channels (done for compatibility with pre-trained models)
119 |         # for two rgb images, concatenated channels appear as (r1r2g1g2b1b2),
120 |         # instead of the typical (r1g1b1r2g2b2) that torch.cat(.., dim=1) would give
121 | input_images = torch.cat([input_image.unsqueeze(2) for input_image in input_images], dim=2)
122 | input_images = input_images.contiguous().view(bsize, -1, height, width)
123 |
124 | # same atypical concatenation done for input flows.
125 | input_flows = torch.cat([input_flow.unsqueeze(2) for input_flow in input_flows], dim=2)
126 | input_flows = input_flows.contiguous().view(bsize, -1, height, width)
127 |
128 | # Network input
129 | images_and_flows = torch.cat((input_flows, input_images), dim=1)
130 |
131 | out_conv1 = self.conv1(images_and_flows)
132 |
133 | out_conv2 = self.conv2(out_conv1)
134 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
135 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
136 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
137 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
138 |
139 | out_deconv5 = self.deconv5(out_conv6)
140 | concat5 = torch.cat((out_conv5, out_deconv5), 1)
141 |
142 | out_deconv4 = self.deconv4(concat5)
143 | concat4 = torch.cat((out_conv4, out_deconv4), 1)
144 |
145 | out_deconv3 = self.deconv3(concat4)
146 | concat3 = torch.cat((out_conv3, out_deconv3), 1)
147 |
148 | out_deconv2 = self.deconv2(concat3)
149 | concat2 = torch.cat((out_conv2, out_deconv2), 1)
150 |
151 | out_deconv1 = self.deconv1(concat2)
152 | concat1 = torch.cat((out_conv1, out_deconv1), 1)
153 |
154 | out_deconv0 = self.deconv0(concat1)
155 |
156 | concat0 = torch.cat((images_and_flows, out_deconv0), 1)
157 | output_flow = self.final_flow(concat0)
158 |
159 | flow_prediction = 3 * self.flow_std * output_flow + self.flow_mean
160 |
161 | return flow_prediction
162 |
163 | def prepare_inputs(self, input_dict):
164 | images = input_dict['image'] # expects a list
165 |
166 | input_images = images[:-1]
167 |
168 | target_image = images[-1]
169 |
170 | last_image = (input_images[-1]).clone()
171 |
172 | return input_images, last_image, target_image
173 |
174 | def forward(self, input_dict, label_image=None):
175 |
176 | input_images, last_image, target_image = self.prepare_inputs(input_dict)
177 |
178 | input_flows = self.interframe_optical_flow(input_images)
179 |
180 | flow_prediction = self.network_output(input_images, input_flows)
181 |
182 | image_prediction = self.warp_bilinear(last_image, flow_prediction)
183 |
184 | if label_image is not None:
185 | label_prediction = self.warp_nn(label_image, flow_prediction)
186 |
187 | # calculate losses
188 | losses = {}
189 |
190 | losses['color'] = self.L1Loss(image_prediction/self.rgb_max, target_image/self.rgb_max)
191 |
192 | losses['color_gradient'] = self.L1Loss(torch.abs(image_prediction[...,1:] - image_prediction[...,:-1]), \
193 | torch.abs(target_image[...,1:] - target_image[...,:-1])) + \
194 | self.L1Loss(torch.abs(image_prediction[..., 1:,:] - image_prediction[..., :-1,:]), \
195 | torch.abs(target_image[..., 1:,:] - target_image[..., :-1,:]))
196 |
197 | losses['flow_smoothness'] = self.L1Loss(flow_prediction[...,1:], flow_prediction[...,:-1]) + \
198 | self.L1Loss(flow_prediction[..., 1:,:], flow_prediction[..., :-1,:])
199 |
200 | losses['tot'] = 0.7 * losses['color'] + 0.2 * losses['color_gradient'] + 0.1 * losses['flow_smoothness']
201 |
202 | if label_image is not None:
203 | image_prediction = label_prediction
204 |
205 | return losses, image_prediction, target_image
206 |
207 |
208 | class SDCNet2DRecon(SDCNet2D):
209 | def __init__(self, args):
210 | args.sequence_length += 1
211 | super(SDCNet2DRecon, self).__init__(args)
212 |
213 | def prepare_inputs(self, input_dict):
214 | images = input_dict['image'] # expects a list
215 |
216 | input_images = images
217 |
218 | target_image = images[-1]
219 |
220 | last_image = (input_images[-2]).clone()
221 |
222 | return input_images, last_image, target_image
--------------------------------------------------------------------------------
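A worked example of the channel bookkeeping above: SDCNet2D stacks sequence_length RGB frames plus (sequence_length - 1) two-channel flow fields along the channel dimension before conv1, and SDCNet2DRecon bumps sequence_length by one in its constructor. Pure arithmetic, no model is built:

def sdcnet_input_channels(sequence_length):
    return sequence_length * 3 + (sequence_length - 1) * 2

print(sdcnet_input_channels(2))   # 8  -> two frames + one flow field (SDCNet2D)
print(sdcnet_input_channels(3))   # 13 -> SDCNet2DRecon launched with --sequence_length 2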
/sdcnet/spatialdisplconv_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/sdcnet/spatialdisplconv_package/__init__.py
--------------------------------------------------------------------------------
/sdcnet/spatialdisplconv_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='spatialdisplconv_cuda',
21 | ext_modules=[
22 | CUDAExtension('spatialdisplconv_cuda', [
23 | 'spatialdisplconv_cuda.cc',
24 | 'spatialdisplconv_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/sdcnet/spatialdisplconv_package/spatialdisplconv.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.autograd import Function, Variable
3 | import spatialdisplconv_cuda
4 |
5 | class SpatialDisplConvFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, input1, input2, input3, input4, kernel_size = 1):
9 | assert input1.is_contiguous(), "spatialdisplconv forward - input1 is not contiguous"
10 | assert input2.is_contiguous(), "spatialdisplconv forward - input2 is not contiguous"
11 | assert input3.is_contiguous(), "spatialdisplconv forward - input3 is not contiguous"
12 | assert input4.is_contiguous(), "spatialdisplconv forward - input4 is not contiguous"
13 |
14 | ctx.save_for_backward(input1, input2, input3, input4)
15 | ctx.kernel_size = kernel_size
16 |
17 | _, image_channels, _, _ = input1.size()
18 | batch_size, _, height, width = input2.size()
19 | output = input1.new(batch_size, image_channels, height, width).zero_()
20 |
21 | spatialdisplconv_cuda.forward(
22 | input1,
23 | input2,
24 | input3,
25 | input4,
26 | output,
27 | kernel_size
28 | )
29 |
30 | return output
31 |
32 | @staticmethod
33 | def backward(ctx, grad_output):
34 | grad_output = grad_output.contiguous()
35 | assert grad_output.is_contiguous()
36 |
37 | input1, input2, input3, input4 = ctx.saved_tensors
38 |
39 | grad_input1 = Variable(input1.new(input1.size()).zero_())
40 | grad_input2 = Variable(input2.new(input2.size()).zero_())
41 | grad_input3 = Variable(input3.new(input3.size()).zero_())
42 | grad_input4 = Variable(input4.new(input4.size()).zero_())
43 |
44 | spatialdisplconv_cuda.backward(
45 | input1,
46 | input2,
47 | input3,
48 | input4,
49 | grad_output.data,
50 | grad_input1.data,
51 | grad_input2.data,
52 | grad_input3.data,
53 | grad_input4.data,
54 | ctx.kernel_size
55 | )
56 |
57 | return grad_input1, grad_input2, grad_input3, grad_input4, None
58 |
59 | class SpatialDisplConv(Module):
60 | def __init__(self, kernel_size = 1):
61 | super(SpatialDisplConv, self).__init__()
62 | self.kernel_size = kernel_size
63 |
64 |
65 | def forward(self, input1, input2, input3, input4):
66 |
67 | return SpatialDisplConvFunction.apply(input1, input2, input3, input4, self.kernel_size)
68 |
--------------------------------------------------------------------------------
/sdcnet/spatialdisplconv_package/spatialdisplconv_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/extension.h>
3 |
4 | #include "spatialdisplconv_kernel.cuh"
5 |
6 | int spatialdisplconv_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& input2,
9 | at::Tensor& input3,
10 | at::Tensor& input4,
11 | at::Tensor& output,
12 | int kernel_size) {
13 |
14 | spatialdisplconv_kernel_forward(
15 | input1,
16 | input2,
17 | input3,
18 | input4,
19 | output,
20 | kernel_size
21 | );
22 |
23 | return 1;
24 | }
25 |
26 |
27 | int spatialdisplconv_cuda_backward(
28 | at::Tensor& input1,
29 | at::Tensor& input2,
30 | at::Tensor& input3,
31 | at::Tensor& input4,
32 | at::Tensor& gradOutput,
33 | at::Tensor& gradInput1,
34 | at::Tensor& gradInput2,
35 | at::Tensor& gradInput3,
36 | at::Tensor& gradInput4,
37 | int kernel_size
38 |
39 | ) {
40 | spatialdisplconv_kernel_backward(
41 | input1,
42 | input2,
43 | input3,
44 | input4,
45 | gradOutput,
46 | gradInput1,
47 | gradInput2,
48 | gradInput3,
49 | gradInput4,
50 | kernel_size
51 | );
52 |
53 | return 1;
54 | }
55 |
56 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
57 | m.def("forward", &spatialdisplconv_cuda_forward, "SpatialDisplConv forward (CUDA)");
58 | m.def("backward", &spatialdisplconv_cuda_backward, "SpatialDisplConv backward (CUDA)");
59 | }
60 |
--------------------------------------------------------------------------------
/sdcnet/spatialdisplconv_package/spatialdisplconv_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void spatialdisplconv_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& input2,
8 | at::Tensor& input3,
9 | at::Tensor& input4,
10 | at::Tensor& output,
11 | int kernel_size
12 | );
13 |
14 | void spatialdisplconv_kernel_backward(
15 | at::Tensor& input1,
16 | at::Tensor& input2,
17 | at::Tensor& input3,
18 | at::Tensor& input4,
19 | at::Tensor& gradOutput,
20 | at::Tensor& gradInput1,
21 | at::Tensor& gradInput2,
22 | at::Tensor& gradInput3,
23 | at::Tensor& gradInput4,
24 | int kernel_size
25 |
26 | );
27 |
--------------------------------------------------------------------------------
/sdcnet/spatialdisplconv_package/test_spatialdisplconv.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from spatialdisplconv import SpatialDisplConv
4 |
5 | assert torch.cuda.is_available()
6 | cuda_device = torch.device("cuda") # device object representing GPU
7 |
8 | n = 8
9 | h = 224
10 | w = 224
11 |
12 | offset = 9 # 11
13 |
14 | #input1 = N, 3, H + 11, W + 11
15 | #input2 = N, 11, H, W
16 | #input3 = N, 11, H, W
17 | #input4 = N, 2, H, W
18 |
19 | # Note the device=cuda_device arguments here
20 | a = torch.randn(n, 3, h + offset, w + offset, device=cuda_device, requires_grad=True).contiguous()
21 |
22 | b = torch.randn(n, offset, h, w, device=cuda_device, requires_grad=True).contiguous()
23 |
24 | c = torch.randn(n, offset, h, w, device=cuda_device, requires_grad=True).contiguous()
25 |
26 | d = torch.randn(n, 2, h, w, device=cuda_device, requires_grad=True).contiguous()
27 |
28 | sdc_layer = SpatialDisplConv(kernel_size=1).cuda()
29 |
30 | forward = 0
31 | backward = 0
32 | num_runs = 100
33 |
34 | for _ in range(num_runs):
35 |
36 | start = time.time()
37 |
38 | result = sdc_layer.forward(a, b, c, d)
39 | torch.cuda.synchronize()
40 | forward += time.time() - start
41 |
42 | sdc_layer.zero_grad()
43 |
44 | start = time.time()
45 |
46 | result_sum = result.sum()
47 |
48 | result_sum.backward()
49 | torch.cuda.synchronize()
50 | backward += time.time() - start
51 |
52 | print("Forward time per iteration: %.4f ms" % (forward * 1.0 / num_runs * 1000))
53 | print("Backward time per iteration: %.4f ms" % (forward * 1.0 / num_runs * 1000))
54 |
--------------------------------------------------------------------------------
/sdcnet/utility/Dockerfile:
--------------------------------------------------------------------------------
1 | # ===========
2 | # base images
3 | # ===========
4 | FROM nvidia/cuda:10.0-cudnn7-devel-ubuntu16.04
5 |
6 |
7 | # ===============
8 | # system packages
9 | # ===============
10 | RUN apt-get update
11 | RUN apt-get install -y bash-completion \
12 | emacs \
13 | ffmpeg \
14 | git \
15 | graphviz \
16 | htop \
17 | libopenexr-dev \
18 | openssh-server \
19 | rsync \
20 | wget
21 |
22 | # ===========
23 | # latest apex
24 | # ===========
25 | RUN pip uninstall -y apex
26 | RUN git clone https://github.com/NVIDIA/apex.git ~/apex && \
27 | cd ~/apex && \
28 | pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
29 |
30 | # ============
31 | # pip packages
32 | # ============
33 | RUN pip install --upgrade pip
34 | RUN pip install --upgrade setuptools
35 | RUN pip install --upgrade boto3
36 | RUN pip install --upgrade cffi
37 | RUN pip install --upgrade colorama==0.3.7
38 | RUN pip install --upgrade Cython
39 | RUN pip install --upgrade dominate
40 | RUN pip install --upgrade ffmpeg
41 | RUN pip install --upgrade graphviz
42 | RUN pip install --upgrade imageio
43 | RUN pip install --upgrade ipython
44 | RUN pip install --upgrade matplotlib
45 | RUN pip install --upgrade natsort
46 | RUN pip install --upgrade nltk
47 | RUN pip install --upgrade numpy
48 | RUN pip install --upgrade openexr
49 | RUN pip install --upgrade packaging
50 | RUN pip install --upgrade pandas
51 | RUN pip install --upgrade pillow
52 | RUN pip install --upgrade pylint
53 | RUN pip install --upgrade pytz
54 | RUN pip install --upgrade pyyaml
55 | RUN pip install --upgrade requests
56 | RUN pip install --upgrade scikit-image
57 | RUN pip install --upgrade scikit-learn
58 | RUN pip install --upgrade scipy
59 | RUN pip install --upgrade sentencepiece
60 | RUN pip install --upgrade setproctitle
61 | RUN pip install --upgrade tensorboard
62 | RUN pip install --upgrade tensorboardX
63 | RUN pip install --upgrade tensorflow
64 | RUN pip install --upgrade torchvision
65 | RUN pip install --upgrade tqdm
66 | RUN pip install --upgrade youtube_dl
67 | RUN pip install --upgrade opencv-python
68 |
--------------------------------------------------------------------------------
/sdcnet/utility/tools.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | import time
4 | from inspect import isclass
5 |
6 | class TimerBlock:
7 | def __init__(self, title):
8 | print(("{}".format(title)))
9 |
10 | def __enter__(self):
11 |         self.start = time.perf_counter()
12 | return self
13 |
14 | def __exit__(self, exc_type, exc_value, traceback):
15 |         self.end = time.perf_counter()
16 | self.interval = self.end - self.start
17 |
18 | if exc_type is not None:
19 | self.log("Operation failed\n")
20 | else:
21 | self.log("Operation finished\n")
22 |
23 | def log(self, string):
24 |         duration = time.perf_counter() - self.start
25 | units = 's'
26 | if duration > 60:
27 | duration = duration / 60.
28 | units = 'm'
29 | print(" [{:.3f}{}] {}".format(duration, units, string), flush = True)
30 |
31 | def module_to_dict(module, exclude=[]):
32 | return dict([(x, getattr(module, x)) for x in dir(module)
33 | if isclass(getattr(module, x))
34 | and x not in exclude
35 | and getattr(module, x) not in exclude])
36 |
37 |
38 | # create_pipe: adapted from https://stackoverflow.com/questions/23709893/popen-write-operation-on-closed-file-images-to-video-using-ffmpeg/23709937#23709937
39 | # start an ffmpeg pipe for creating RGB8 for color images or FFV1 for depth
40 | # NOTE: this is REALLY lossy and not optimal for HDR data. when it comes time to train
41 | # on HDR data, you'll need to figure out the way to save to pix_fmt=rgb48 or something
42 | # similar
43 | def create_pipe(pipe_filename, width, height, frame_rate=60, quite=True):
44 | # default extension and tonemapper
45 | pix_fmt = 'rgb24'
46 | out_fmt = 'yuv420p'
47 | codec = 'h264'
48 |
49 | command = ['ffmpeg',
50 | '-threads', '2', # number of threads to start
51 | '-y', # (optional) overwrite output file if it exists
52 | '-f', 'rawvideo', # input format
53 | '-vcodec', 'rawvideo', # input codec
54 | '-s', str(width) + 'x' + str(height), # size of one frame
55 | '-pix_fmt', pix_fmt, # input pixel format
56 | '-r', str(frame_rate), # frames per second
57 |                '-i', '-', # The input comes from a pipe
58 | '-an', # Tells FFMPEG not to expect any audio
59 | '-codec:v', codec, # output codec
60 | '-crf', '18',
61 | # compression quality for h264 (maybe h265 too?) - http://slhck.info/video/2017/02/24/crf-guide.html
62 | # '-compression_level', '10', # compression level for libjpeg if doing lossy depth
63 |                '-strict', '-2', # experimental 16 bit support necessary for gray16le
64 | '-pix_fmt', out_fmt, # output pixel format
65 | '-s', str(width) + 'x' + str(height), # output size
66 | pipe_filename]
67 | cmd = ' '.join(command)
68 | if not quite:
69 |         print('opening a pipe ...\n' + cmd + '\n')
70 |
71 | # open the pipe, and ignore stdout and stderr output
72 | DEVNULL = open(os.devnull, 'wb')
73 | return subprocess.Popen(command, stdin=subprocess.PIPE, stdout=DEVNULL, stderr=DEVNULL, close_fds=True)
74 |
75 | # AverageMeter: code from https://github.com/pytorch/examples/blob/master/imagenet/main.py
76 | class AverageMeter(object):
77 | """Computes and stores the average and current value"""
78 | def __init__(self, name, fmt=':f'):
79 | self.name = name
80 | self.fmt = fmt
81 | self.reset()
82 |
83 | def reset(self):
84 | self.val = 0
85 | self.avg = 0
86 | self.sum = 0
87 | self.count = 0
88 |
89 | def update(self, val, n=1):
90 | self.val = val
91 | self.sum += val * n
92 | self.count += n
93 | self.avg = self.sum / self.count
94 |
95 | def __str__(self):
96 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
97 | return fmtstr.format(**self.__dict__)
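98 | 
99 | # Minimal usage sketch for TimerBlock and create_pipe. The output path, frame
100 | # size and frame rate below are illustrative assumptions; create_pipe expects
101 | # raw rgb24 frames written to the subprocess stdin, and needs ffmpeg on PATH.
102 | if __name__ == '__main__':
103 |     import numpy as np
104 | 
105 |     with TimerBlock("Writing a short dummy clip") as block:
106 |         pipe = create_pipe('/tmp/dummy.mp4', 320, 240, frame_rate=30, quite=False)
107 |         frame = np.zeros((240, 320, 3), dtype=np.uint8)  # one black rgb24 frame
108 |         for _ in range(30):
109 |             pipe.stdin.write(frame.tobytes())
110 |         pipe.stdin.close()
111 |         pipe.wait()
112 |         block.log("wrote 30 frames")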
--------------------------------------------------------------------------------
/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/transforms/__init__.py
--------------------------------------------------------------------------------
/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code borrowed from:
3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py
4 | #
5 | #
6 | # MIT License
7 | #
8 | # Copyright (c) 2017 ZijunDeng
9 | #
10 | # Permission is hereby granted, free of charge, to any person obtaining a copy
11 | # of this software and associated documentation files (the "Software"), to deal
12 | # in the Software without restriction, including without limitation the rights
13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | # copies of the Software, and to permit persons to whom the Software is
15 | # furnished to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be included in all
18 | # copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | # SOFTWARE.
27 |
28 | """
29 |
30 | """
31 | Standard Transform
32 | """
33 |
34 | import random
35 | import numpy as np
36 | from skimage.filters import gaussian
37 | from skimage.restoration import denoise_bilateral
38 | import torch
39 | from PIL import Image, ImageEnhance
40 | import torchvision.transforms as torch_tr
41 | from config import cfg
42 | from scipy.ndimage.interpolation import shift
43 |
44 | from skimage.segmentation import find_boundaries
45 |
46 | try:
47 | import accimage
48 | except ImportError:
49 | accimage = None
50 |
51 |
52 | class RandomVerticalFlip(object):
53 | def __call__(self, img):
54 | if random.random() < 0.5:
55 | return img.transpose(Image.FLIP_TOP_BOTTOM)
56 | return img
57 |
58 |
59 | class DeNormalize(object):
60 | def __init__(self, mean, std):
61 | self.mean = mean
62 | self.std = std
63 |
64 | def __call__(self, tensor):
65 | for t, m, s in zip(tensor, self.mean, self.std):
66 | t.mul_(s).add_(m)
67 | return tensor
68 |
69 |
70 | class MaskToTensor(object):
71 | def __call__(self, img):
72 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
73 |
74 | class RelaxedBoundaryLossToTensor(object):
75 | """
76 | Boundary Relaxation
77 | """
78 | def __init__(self,ignore_id, num_classes):
79 | self.ignore_id=ignore_id
80 | self.num_classes= num_classes
81 |
82 |
83 | def new_one_hot_converter(self,a):
84 | ncols = self.num_classes+1
85 | out = np.zeros( (a.size,ncols), dtype=np.uint8)
86 | out[np.arange(a.size),a.ravel()] = 1
87 | out.shape = a.shape + (ncols,)
88 | return out
89 |
90 | def __call__(self,img):
91 |
92 | img_arr = np.array(img)
93 | img_arr[img_arr==self.ignore_id]=self.num_classes
94 |
95 |         if cfg.STRICTBORDERCLASS is not None:
96 | one_hot_orig = self.new_one_hot_converter(img_arr)
97 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1]))
98 | for cls in cfg.STRICTBORDERCLASS:
99 | mask = np.logical_or(mask,(img_arr == cls))
100 | one_hot = 0
101 |
102 | border = cfg.BORDER_WINDOW
103 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
104 | border = border // 2
105 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8)
106 |
107 | for i in range(-border,border+1):
108 | for j in range(-border, border+1):
109 | shifted= shift(img_arr,(i,j), cval=self.num_classes)
110 | one_hot += self.new_one_hot_converter(shifted)
111 |
112 | one_hot[one_hot>1] = 1
113 |
114 |         if cfg.STRICTBORDERCLASS is not None:
115 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot)
116 |
117 | one_hot = np.moveaxis(one_hot,-1,0)
118 |
119 |
120 | if (cfg.REDUCE_BORDER_EPOCH !=-1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
121 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot)
122 | # print(one_hot.shape)
123 | return torch.from_numpy(one_hot).byte()
124 |
125 | class ResizeHeight(object):
126 | def __init__(self, size, interpolation=Image.BILINEAR):
127 | self.target_h = size
128 | self.interpolation = interpolation
129 |
130 | def __call__(self, img):
131 | w, h = img.size
132 | target_w = int(w / h * self.target_h)
133 | return img.resize((target_w, self.target_h), self.interpolation)
134 |
135 |
136 | class FreeScale(object):
137 | def __init__(self, size, interpolation=Image.BILINEAR):
138 | self.size = tuple(reversed(size)) # size: (h, w)
139 | self.interpolation = interpolation
140 |
141 | def __call__(self, img):
142 | return img.resize(self.size, self.interpolation)
143 |
144 |
145 | class FlipChannels(object):
146 | """
147 |     Reverse the channel order (e.g. RGB -> BGR)
148 | """
149 | def __call__(self, img):
150 | img = np.array(img)[:, :, ::-1]
151 | return Image.fromarray(img.astype(np.uint8))
152 |
153 |
154 | class RandomGaussianBlur(object):
155 | """
156 | Apply Gaussian Blur
157 | """
158 | def __call__(self, img):
159 | sigma = 0.15 + random.random() * 1.15
160 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
161 | blurred_img *= 255
162 | return Image.fromarray(blurred_img.astype(np.uint8))
163 |
164 |
165 | class RandomBilateralBlur(object):
166 | """
167 | Apply Bilateral Filtering
168 |
169 | """
170 | def __call__(self, img):
171 | sigma = random.uniform(0.05,0.75)
172 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
173 | blurred_img *= 255
174 | return Image.fromarray(blurred_img.astype(np.uint8))
175 |
176 | def _is_pil_image(img):
177 | if accimage is not None:
178 | return isinstance(img, (Image.Image, accimage.Image))
179 | else:
180 | return isinstance(img, Image.Image)
181 |
182 |
183 | def adjust_brightness(img, brightness_factor):
184 | """Adjust brightness of an Image.
185 |
186 | Args:
187 | img (PIL Image): PIL Image to be adjusted.
188 | brightness_factor (float): How much to adjust the brightness. Can be
189 | any non negative number. 0 gives a black image, 1 gives the
190 | original image while 2 increases the brightness by a factor of 2.
191 |
192 | Returns:
193 | PIL Image: Brightness adjusted image.
194 | """
195 | if not _is_pil_image(img):
196 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
197 |
198 | enhancer = ImageEnhance.Brightness(img)
199 | img = enhancer.enhance(brightness_factor)
200 | return img
201 |
202 |
203 | def adjust_contrast(img, contrast_factor):
204 | """Adjust contrast of an Image.
205 |
206 | Args:
207 | img (PIL Image): PIL Image to be adjusted.
208 | contrast_factor (float): How much to adjust the contrast. Can be any
209 | non negative number. 0 gives a solid gray image, 1 gives the
210 | original image while 2 increases the contrast by a factor of 2.
211 |
212 | Returns:
213 | PIL Image: Contrast adjusted image.
214 | """
215 | if not _is_pil_image(img):
216 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
217 |
218 | enhancer = ImageEnhance.Contrast(img)
219 | img = enhancer.enhance(contrast_factor)
220 | return img
221 |
222 |
223 | def adjust_saturation(img, saturation_factor):
224 | """Adjust color saturation of an image.
225 |
226 | Args:
227 | img (PIL Image): PIL Image to be adjusted.
228 | saturation_factor (float): How much to adjust the saturation. 0 will
229 | give a black and white image, 1 will give the original image while
230 | 2 will enhance the saturation by a factor of 2.
231 |
232 | Returns:
233 | PIL Image: Saturation adjusted image.
234 | """
235 | if not _is_pil_image(img):
236 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
237 |
238 | enhancer = ImageEnhance.Color(img)
239 | img = enhancer.enhance(saturation_factor)
240 | return img
241 |
242 |
243 | def adjust_hue(img, hue_factor):
244 | """Adjust hue of an image.
245 |
246 | The image hue is adjusted by converting the image to HSV and
247 | cyclically shifting the intensities in the hue channel (H).
248 | The image is then converted back to original image mode.
249 |
250 | `hue_factor` is the amount of shift in H channel and must be in the
251 | interval `[-0.5, 0.5]`.
252 |
253 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
254 |
255 | Args:
256 | img (PIL Image): PIL Image to be adjusted.
257 | hue_factor (float): How much to shift the hue channel. Should be in
258 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
259 | HSV space in positive and negative direction respectively.
260 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
261 | with complementary colors while 0 gives the original image.
262 |
263 | Returns:
264 | PIL Image: Hue adjusted image.
265 | """
266 | if not(-0.5 <= hue_factor <= 0.5):
267 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
268 |
269 | if not _is_pil_image(img):
270 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
271 |
272 | input_mode = img.mode
273 | if input_mode in {'L', '1', 'I', 'F'}:
274 | return img
275 |
276 | h, s, v = img.convert('HSV').split()
277 |
278 | np_h = np.array(h, dtype=np.uint8)
279 |     # uint8 addition takes care of rotation across boundaries
280 | with np.errstate(over='ignore'):
281 | np_h += np.uint8(hue_factor * 255)
282 | h = Image.fromarray(np_h, 'L')
283 |
284 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
285 | return img
286 |
287 |
288 | class ColorJitter(object):
289 | """Randomly change the brightness, contrast and saturation of an image.
290 |
291 | Args:
292 | brightness (float): How much to jitter brightness. brightness_factor
293 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
294 | contrast (float): How much to jitter contrast. contrast_factor
295 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
296 | saturation (float): How much to jitter saturation. saturation_factor
297 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
298 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
299 | [-hue, hue]. Should be >=0 and <= 0.5.
300 | """
301 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
302 | self.brightness = brightness
303 | self.contrast = contrast
304 | self.saturation = saturation
305 | self.hue = hue
306 |
307 | @staticmethod
308 | def get_params(brightness, contrast, saturation, hue):
309 | """Get a randomized transform to be applied on image.
310 |
311 | Arguments are same as that of __init__.
312 |
313 | Returns:
314 | Transform which randomly adjusts brightness, contrast and
315 | saturation in a random order.
316 | """
317 | transforms = []
318 | if brightness > 0:
319 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
320 | transforms.append(
321 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
322 |
323 | if contrast > 0:
324 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
325 | transforms.append(
326 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
327 |
328 | if saturation > 0:
329 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
330 | transforms.append(
331 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
332 |
333 | if hue > 0:
334 | hue_factor = np.random.uniform(-hue, hue)
335 | transforms.append(
336 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
337 |
338 | np.random.shuffle(transforms)
339 | transform = torch_tr.Compose(transforms)
340 |
341 | return transform
342 |
343 | def __call__(self, img):
344 | """
345 | Args:
346 | img (PIL Image): Input image.
347 |
348 | Returns:
349 | PIL Image: Color jittered image.
350 | """
351 | transform = self.get_params(self.brightness, self.contrast,
352 | self.saturation, self.hue)
353 | return transform(img)
354 |
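355 | # Minimal usage sketch: compose the photometric transforms above the way a
356 | # training pipeline typically would. The jitter strengths and the random input
357 | # image below are illustrative assumptions, not values taken from the training scripts.
358 | if __name__ == '__main__':
359 |     input_transform = torch_tr.Compose([
360 |         ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.1),
361 |         RandomGaussianBlur(),
362 |     ])
363 |     dummy = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))
364 |     out = input_transform(dummy)
365 |     print(out.size, out.mode)  # same size, photometrically jittered PIL image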
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/geohot/semantic-segmentation/d5b6ec7ec9d4e296fc0e1aa85163ed0e98f54f92/utils/__init__.py
--------------------------------------------------------------------------------
/utils/attr_dict.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 |
30 | class AttrDict(dict):
31 |
32 | IMMUTABLE = '__immutable__'
33 |
34 | def __init__(self, *args, **kwargs):
35 | super(AttrDict, self).__init__(*args, **kwargs)
36 | self.__dict__[AttrDict.IMMUTABLE] = False
37 |
38 | def __getattr__(self, name):
39 | if name in self.__dict__:
40 | return self.__dict__[name]
41 | elif name in self:
42 | return self[name]
43 | else:
44 | raise AttributeError(name)
45 |
46 | def __setattr__(self, name, value):
47 | if not self.__dict__[AttrDict.IMMUTABLE]:
48 | if name in self.__dict__:
49 | self.__dict__[name] = value
50 | else:
51 | self[name] = value
52 | else:
53 | raise AttributeError(
54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
55 | format(name, value)
56 | )
57 |
58 | def immutable(self, is_immutable):
59 | """Set immutability to is_immutable and recursively apply the setting
60 | to all nested AttrDicts.
61 | """
62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
63 | # Recursively set immutable state
64 | for v in self.__dict__.values():
65 | if isinstance(v, AttrDict):
66 | v.immutable(is_immutable)
67 | for v in self.values():
68 | if isinstance(v, AttrDict):
69 | v.immutable(is_immutable)
70 |
71 | def is_immutable(self):
72 | return self.__dict__[AttrDict.IMMUTABLE]
73 |
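74 | # Minimal usage sketch: AttrDict allows attribute-style access to dict keys and
75 | # can be frozen recursively. The keys below are illustrative, not the project's
76 | # actual config schema.
77 | if __name__ == '__main__':
78 |     cfg = AttrDict()
79 |     cfg.MODEL = AttrDict()
80 |     cfg.MODEL.BN = 'regularnorm'
81 |     assert cfg['MODEL'].BN == 'regularnorm'  # dict access and attribute access agree
82 |     cfg.immutable(True)
83 |     try:
84 |         cfg.MODEL.BN = 'syncnorm'  # raises AttributeError: the dict is frozen
85 |     except AttributeError as err:
86 |         print(err)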
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellaneous Functions
3 | """
4 |
5 | import sys
6 | import re
7 | import os
8 | import shutil
9 | import torch
10 | from datetime import datetime
11 | import logging
12 | from subprocess import call
13 | import shlex
14 | from tensorboardX import SummaryWriter
15 | import numpy as np
16 | import torchvision.transforms as standard_transforms
17 | import torchvision.utils as vutils
18 |
19 |
20 | # Create unique output dir name based on non-default command line args
21 | def make_exp_name(args, parser):
22 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:])
23 | dict_args = vars(args)
24 |
25 | # sort so that we get a consistent directory name
26 | argnames = sorted(dict_args)
27 | ignorelist = ['exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch',
28 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt']
29 | # build experiment name with non-default args
30 | for argname in argnames:
31 | if dict_args[argname] != parser.get_default(argname):
32 | if argname in ignorelist:
33 | continue
34 | if argname == 'snapshot':
35 | arg_str = 'PT'
36 | argname = ''
37 | elif argname == 'nosave':
38 | arg_str = ''
39 | argname=''
40 | elif argname == 'freeze_trunk':
41 | argname = ''
42 | arg_str = 'ft'
43 | elif argname == 'syncbn':
44 | argname = ''
45 | arg_str = 'sbn'
46 | elif argname == 'jointwtborder':
47 | argname = ''
48 | arg_str = 'rlx_loss'
49 | elif isinstance(dict_args[argname], bool):
50 | arg_str = 'T' if dict_args[argname] else 'F'
51 | else:
52 | arg_str = str(dict_args[argname])[:7]
53 |             if argname != '':
54 | exp_name += '_{}_{}'.format(str(argname), arg_str)
55 | else:
56 | exp_name += '_{}'.format(arg_str)
57 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
58 | return exp_name
59 |
60 | def fast_hist(label_pred, label_true, num_classes):
61 | mask = (label_true >= 0) & (label_true < num_classes)
62 | hist = np.bincount(
63 | num_classes * label_true[mask].astype(int) +
64 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
65 | return hist
66 |
67 | def per_class_iu(hist):
68 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
69 |
70 | def save_log(prefix, output_dir, date_str, rank=0):
71 | fmt = '%(asctime)s.%(msecs)03d %(message)s'
72 | date_fmt = '%m-%d %H:%M:%S'
73 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log')
74 | print("Logging :", filename)
75 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
76 | filename=filename, filemode='w')
77 | console = logging.StreamHandler()
78 | console.setLevel(logging.INFO)
79 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
80 | console.setFormatter(formatter)
81 | if rank == 0:
82 | logging.getLogger('').addHandler(console)
83 | else:
84 | fh = logging.FileHandler(filename)
85 | logging.getLogger('').addHandler(fh)
86 |
87 |
88 |
89 | def prep_experiment(args, parser):
90 | """
91 | Make output directories, setup logging, Tensorboard, snapshot code.
92 | """
93 | ckpt_path = args.ckpt
94 | tb_path = args.tb_path
95 | exp_name = make_exp_name(args, parser)
96 | args.exp_path = os.path.join(ckpt_path, args.exp, exp_name)
97 | args.tb_exp_path = os.path.join(tb_path, args.exp, exp_name)
98 | args.ngpu = torch.cuda.device_count()
99 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
100 | args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
101 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
102 | args.last_record = {}
103 | if args.local_rank == 0:
104 | os.makedirs(args.exp_path, exist_ok=True)
105 | os.makedirs(args.tb_exp_path, exist_ok=True)
106 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank)
107 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write(
108 | str(args) + '\n\n')
109 | writer = SummaryWriter(logdir=args.tb_exp_path, comment=args.tb_tag)
110 | return writer
111 | return None
112 |
113 | def evaluate_eval_for_inference(hist, dataset=None):
114 | """
115 |     Modified IoU mechanism for on-the-fly IoU calculation (prevents memory overflow on
116 |     large datasets). Only applies to eval/eval.py.
117 | """
118 | # axis 0: gt, axis 1: prediction
119 | acc = np.diag(hist).sum() / hist.sum()
120 | acc_cls = np.diag(hist) / hist.sum(axis=1)
121 | acc_cls = np.nanmean(acc_cls)
122 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
123 |
124 | print_evaluate_results(hist, iu, dataset=dataset)
125 | freq = hist.sum(axis=1) / hist.sum()
126 | mean_iu = np.nanmean(iu)
127 | logging.info('mean {}'.format(mean_iu))
128 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
129 | return acc, acc_cls, mean_iu, fwavacc
130 |
131 |
132 |
133 | def evaluate_eval(args, net, optimizer, val_loss, hist, dump_images, writer, epoch=0, dataset=None):
134 |     """
135 |     Modified IoU mechanism for on-the-fly IoU calculation (prevents memory overflow on
136 |     large datasets). Also saves the latest and best snapshots and logs metrics to Tensorboard.
137 |     """
138 | # axis 0: gt, axis 1: prediction
139 | acc = np.diag(hist).sum() / hist.sum()
140 | acc_cls = np.diag(hist) / hist.sum(axis=1)
141 | acc_cls = np.nanmean(acc_cls)
142 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
143 |
144 | print_evaluate_results(hist, iu, dataset)
145 | freq = hist.sum(axis=1) / hist.sum()
146 | mean_iu = np.nanmean(iu)
147 | logging.info('mean {}'.format(mean_iu))
148 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
149 |
150 | # update latest snapshot
151 | if 'mean_iu' in args.last_record:
152 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(
153 | args.last_record['epoch'], args.last_record['mean_iu'])
154 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
155 | try:
156 | os.remove(last_snapshot)
157 | except OSError:
158 | pass
159 | last_snapshot = 'last_epoch_{}_mean-iu_{:.5f}.pth'.format(epoch, mean_iu)
160 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
161 | args.last_record['mean_iu'] = mean_iu
162 | args.last_record['epoch'] = epoch
163 |
164 | torch.cuda.synchronize()
165 |
166 | torch.save({
167 | 'state_dict': net.state_dict(),
168 | 'optimizer': optimizer.state_dict(),
169 | 'epoch': epoch,
170 | 'mean_iu': mean_iu,
171 | 'command': ' '.join(sys.argv[1:])
172 | }, last_snapshot)
173 |
174 | # update best snapshot
175 | if mean_iu > args.best_record['mean_iu'] :
176 | # remove old best snapshot
177 | if args.best_record['epoch'] != -1:
178 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
179 | args.best_record['epoch'], args.best_record['mean_iu'])
180 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
181 | assert os.path.exists(best_snapshot), \
182 |                 'cannot find old snapshot {}'.format(best_snapshot)
183 | os.remove(best_snapshot)
184 |
185 |
186 | # save new best
187 | args.best_record['val_loss'] = val_loss.avg
188 | args.best_record['epoch'] = epoch
189 | args.best_record['acc'] = acc
190 | args.best_record['acc_cls'] = acc_cls
191 | args.best_record['mean_iu'] = mean_iu
192 | args.best_record['fwavacc'] = fwavacc
193 |
194 | best_snapshot = 'best_epoch_{}_mean-iu_{:.5f}.pth'.format(
195 | args.best_record['epoch'], args.best_record['mean_iu'])
196 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
197 | shutil.copyfile(last_snapshot, best_snapshot)
198 |
199 |
200 | to_save_dir = os.path.join(args.exp_path, 'best_images')
201 | os.makedirs(to_save_dir, exist_ok=True)
202 | val_visual = []
203 |
204 | idx = 0
205 |
206 | visualize = standard_transforms.Compose([
207 |         standard_transforms.Resize(384),  # torchvision renamed Scale to Resize
208 | standard_transforms.ToTensor()
209 | ])
210 | for bs_idx, bs_data in enumerate(dump_images):
211 | for local_idx, data in enumerate(zip(bs_data[0], bs_data[1],bs_data[2])):
212 | gt_pil = args.dataset_cls.colorize_mask(data[0].cpu().numpy())
213 | predictions_pil = args.dataset_cls.colorize_mask(data[1].cpu().numpy())
214 | img_name = data[2]
215 |
216 | prediction_fn = '{}_prediction.png'.format(img_name)
217 | predictions_pil.save(os.path.join(to_save_dir, prediction_fn))
218 | gt_fn = '{}_gt.png'.format(img_name)
219 | gt_pil.save(os.path.join(to_save_dir, gt_fn))
220 | val_visual.extend([visualize(gt_pil.convert('RGB')),
221 | visualize(predictions_pil.convert('RGB'))])
222 | if local_idx >= 9:
223 | break
224 | val_visual = torch.stack(val_visual, 0)
225 | val_visual = vutils.make_grid(val_visual, nrow=10, padding=5)
226 | writer.add_image('imgs', val_visual, epoch )
227 |
228 | logging.info('-' * 107)
229 | fmt_str = '[epoch %d], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
230 | '[mean_iu %.5f], [fwavacc %.5f]'
231 | logging.info(fmt_str % (epoch, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))
232 | fmt_str = 'best record: [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
233 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
234 | logging.info(fmt_str % (args.best_record['val_loss'], args.best_record['acc'],
235 | args.best_record['acc_cls'], args.best_record['mean_iu'],
236 | args.best_record['fwavacc'], args.best_record['epoch']))
237 | logging.info('-' * 107)
238 |
239 | # tensorboard logging of validation phase metrics
240 |
241 | writer.add_scalar('training/acc', acc, epoch)
242 | writer.add_scalar('training/acc_cls', acc_cls, epoch)
243 | writer.add_scalar('training/mean_iu', mean_iu, epoch)
244 | writer.add_scalar('training/val_loss', val_loss.avg, epoch)
245 |
246 |
247 |
248 | def fast_hist(label_pred, label_true, num_classes):
249 | mask = (label_true >= 0) & (label_true < num_classes)
250 | hist = np.bincount(
251 | num_classes * label_true[mask].astype(int) +
252 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
253 | return hist
254 |
255 |
256 |
257 | def print_evaluate_results(hist, iu, dataset=None):
258 | # fixme: Need to refactor this dict
259 | try:
260 | id2cat = dataset.id2cat
261 | except:
262 | id2cat = {i: i for i in range(dataset.num_classes)}
263 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
264 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
265 | iu_true_positive = np.diag(hist)
266 |
267 | logging.info('IoU:')
268 | logging.info('label_id label iU Precision Recall TP FP FN')
269 | for idx, i in enumerate(iu):
270 | # Format all of the strings:
271 | idx_string = "{:2d}".format(idx)
272 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
273 | iu_string = '{:5.2f}'.format(i * 100)
274 | total_pixels = hist.sum()
275 | tp = '{:5.2f}'.format(100 * iu_true_positive[idx] / total_pixels)
276 | fp = '{:5.2f}'.format(
277 | iu_false_positive[idx] / iu_true_positive[idx])
278 | fn = '{:5.2f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
279 | precision = '{:5.2f}'.format(
280 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
281 | recall = '{:5.2f}'.format(
282 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
283 | logging.info('{} {} {} {} {} {} {} {}'.format(
284 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
285 |
286 |
287 |
288 |
289 | class AverageMeter(object):
290 |
291 | def __init__(self):
292 | self.reset()
293 |
294 | def reset(self):
295 | self.val = 0
296 | self.avg = 0
297 | self.sum = 0
298 | self.count = 0
299 |
300 | def update(self, val, n=1):
301 | self.val = val
302 | self.sum += val * n
303 | self.count += n
304 | self.avg = self.sum / self.count
305 |
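306 | # Minimal usage sketch of the streaming IoU helpers: accumulate a confusion
307 | # matrix with fast_hist over (prediction, label) pairs, then reduce it with
308 | # per_class_iu. The random arrays and the class count below are illustrative.
309 | if __name__ == '__main__':
310 |     num_classes = 19
311 |     hist = np.zeros((num_classes, num_classes))
312 |     for _ in range(4):  # pretend these are four validation images
313 |         pred = np.random.randint(0, num_classes, (512, 1024))
314 |         label = np.random.randint(0, num_classes, (512, 1024))
315 |         hist += fast_hist(pred.flatten(), label.flatten(), num_classes)
316 |     print('per-class IoU:', per_class_iu(hist))
317 |     print('mean IoU:', np.nanmean(per_class_iu(hist)))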
--------------------------------------------------------------------------------
/utils/my_data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | # Code adapted from:
4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Copyright (c) 2017,
9 | # All rights reserved.
10 | #
11 | # Redistribution and use in source and binary forms, with or without
12 | # modification, are permitted provided that the following conditions are met:
13 | #
14 | # * Redistributions of source code must retain the above copyright notice, this
15 | # list of conditions and the following disclaimer.
16 | #
17 | # * Redistributions in binary form must reproduce the above copyright notice,
18 | # this list of conditions and the following disclaimer in the documentation
19 | # and/or other materials provided with the distribution.
20 | #
21 | # * Neither the name of the copyright holder nor the names of its
22 | # contributors may be used to endorse or promote products derived from
23 | # this software without specific prior written permission.
24 | #
25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 | """
36 |
37 |
38 | import operator
39 | import torch
40 | import warnings
41 | from torch.nn.modules import Module
42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
43 | from torch.nn.parallel.replicate import replicate
44 | from torch.nn.parallel.parallel_apply import parallel_apply
45 |
46 |
47 | def _check_balance(device_ids):
48 | imbalance_warn = """
49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which
50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting
51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 | environment variable."""
53 |
54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 |
56 | def warn_imbalance(get_prop):
57 | values = [get_prop(props) for props in dev_props]
58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 | if min_val / max_val < 0.75:
61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 | return True
63 | return False
64 |
65 | if warn_imbalance(lambda props: props.total_memory):
66 | return
67 | if warn_imbalance(lambda props: props.multi_processor_count):
68 | return
69 |
70 |
71 |
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather_output=True):
73 | """
74 | Evaluates module(input) in parallel across the GPUs given in device_ids.
75 | This is the functional version of the DataParallel module.
76 | Args:
77 | module: the module to evaluate in parallel
78 | inputs: inputs to the module
79 | device_ids: GPU ids on which to replicate module
80 | output_device: GPU location of the output Use -1 to indicate the CPU.
81 | (default: device_ids[0])
82 | Returns:
83 | a Tensor containing the result of module(input) located on
84 | output_device
85 | """
86 | if not isinstance(inputs, tuple):
87 | inputs = (inputs,)
88 |
89 | if device_ids is None:
90 | device_ids = list(range(torch.cuda.device_count()))
91 |
92 | if output_device is None:
93 | output_device = device_ids[0]
94 |
95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 | if len(device_ids) == 1:
97 | return module(*inputs[0], **module_kwargs[0])
98 | used_device_ids = device_ids[:len(inputs)]
99 | replicas = replicate(module, used_device_ids)
100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 |     if gather_output:
102 | return gather(outputs, output_device, dim)
103 | else:
104 | return outputs
105 |
106 |
107 |
108 | class MyDataParallel(Module):
109 | """
110 | Implements data parallelism at the module level.
111 | This container parallelizes the application of the given module by
112 | splitting the input across the specified devices by chunking in the batch
113 | dimension. In the forward pass, the module is replicated on each device,
114 | and each replica handles a portion of the input. During the backwards
115 | pass, gradients from each replica are summed into the original module.
116 | The batch size should be larger than the number of GPUs used.
117 | See also: :ref:`cuda-nn-dataparallel-instead`
118 | Arbitrary positional and keyword inputs are allowed to be passed into
119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 | specified (default 0). Primitive types will be broadcasted, but all
121 | other types will be a shallow copy and can be corrupted if written to in
122 | the model's forward pass.
123 | .. warning::
124 | Forward and backward hooks defined on :attr:`module` and its submodules
125 | will be invoked ``len(device_ids)`` times, each with inputs located on
126 | a particular device. Particularly, the hooks are only guaranteed to be
127 | executed in correct order with respect to operations on corresponding
128 | devices. For example, it is not guaranteed that hooks set via
129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
131 | that each such hook be executed before the corresponding
132 | :meth:`~torch.nn.Module.forward` call of that device.
133 | .. warning::
134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
135 | :func:`forward`, this wrapper will return a vector of length equal to
136 | number of devices used in data parallelism, containing the result from
137 | each device.
138 | .. note::
139 | There is a subtlety in using the
140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
143 | details.
144 | Args:
145 | module: module to be parallelized
146 | device_ids: CUDA devices (default: all devices)
147 | output_device: device location of output (default: device_ids[0])
148 | Attributes:
149 | module (Module): the module to be parallelized
150 | Example::
151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
152 | >>> output = net(input_var)
153 | """
154 |
155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
156 |
157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True):
158 | super(MyDataParallel, self).__init__()
159 |
160 | if not torch.cuda.is_available():
161 | self.module = module
162 | self.device_ids = []
163 | return
164 |
165 | if device_ids is None:
166 | device_ids = list(range(torch.cuda.device_count()))
167 | if output_device is None:
168 | output_device = device_ids[0]
169 | self.dim = dim
170 | self.module = module
171 | self.device_ids = device_ids
172 | self.output_device = output_device
173 | self.gather_bool = gather
174 |
175 | _check_balance(self.device_ids)
176 |
177 | if len(self.device_ids) == 1:
178 | self.module.cuda(device_ids[0])
179 |
180 | def forward(self, *inputs, **kwargs):
181 | if not self.device_ids:
182 | return self.module(*inputs, **kwargs)
183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
184 | if len(self.device_ids) == 1:
185 | return [self.module(*inputs[0], **kwargs[0])]
186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
187 | outputs = self.parallel_apply(replicas, inputs, kwargs)
188 | if self.gather_bool:
189 | return self.gather(outputs, self.output_device)
190 | else:
191 | return outputs
192 |
193 | def replicate(self, module, device_ids):
194 | return replicate(module, device_ids)
195 |
196 | def scatter(self, inputs, kwargs, device_ids):
197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
198 |
199 | def parallel_apply(self, replicas, inputs, kwargs):
200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
201 |
202 | def gather(self, outputs, output_device):
203 | return gather(outputs, output_device, dim=self.dim)
204 |
205 |
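206 | # Minimal usage sketch: unlike torch.nn.DataParallel, MyDataParallel can skip the
207 | # final gather and return one output per GPU replica. The toy module and batch
208 | # below are illustrative assumptions.
209 | if __name__ == '__main__':
210 |     if torch.cuda.is_available():
211 |         model = torch.nn.Linear(8, 2).cuda()
212 |         wrapped = MyDataParallel(model, gather=False)
213 |         batch = torch.randn(16, 8).cuda()
214 |         outputs = wrapped(batch)  # a list with one tensor per GPU used
215 |         print([o.shape for o in outputs])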
--------------------------------------------------------------------------------