├── .gitignore ├── LICENSE ├── README.md ├── bashscripts └── voc12 │ ├── train_segsort.sh │ ├── train_segsort_mgpu.sh │ └── train_segsort_unsup.sh ├── dataset └── voc12 │ ├── sbd_clsimg.zip │ ├── test.txt │ ├── train+.txt │ ├── train.txt │ ├── trainval.txt │ └── val.txt ├── misc ├── colormapvoc.mat └── main.png ├── network ├── __init__.py ├── common │ ├── __init__.py │ ├── layers.py │ └── resnet_v1.py ├── multigpu │ ├── __init__.py │ ├── layers.py │ ├── resnet_v1.py │ └── utils.py └── segsort │ ├── common_utils.py │ ├── eval_utils.py │ ├── train_utils.py │ └── vis_utils.py ├── pyscripts ├── benchmark │ └── benchmark_by_mIoU.py ├── inference │ ├── extract_prototypes.py │ ├── inference.py │ ├── inference_msc.py │ ├── inference_patch.py │ ├── inference_segsort.py │ ├── inference_segsort_msc.py │ ├── inference_vmf.py │ ├── inference_vmf_embedding.py │ ├── prototype_embedding_fine.py │ ├── prototype_embedding_rgb.py │ ├── prototype_embedding_with_flip.py │ └── prototype_unsup.py └── train │ ├── train_segsort.py │ ├── train_segsort_mgpu.py │ └── train_segsort_unsup.py ├── seg_models ├── __init__.py ├── image_reader.py └── models │ ├── __init__.py │ ├── deeplab.py │ ├── fcn.py │ ├── pspnet.py │ └── pspnet_mgpu.py └── utils ├── __init__.py ├── general.py └── metrics.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # cache dir 104 | snapshots/ 105 | logs/ 106 | 107 | # vi/vim cache files 108 | *.swp 109 | *.so 110 | 111 | # html files 112 | index*.html 113 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Jyh-Jing Hwang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SegSort: Segmentation by Discriminative Sorting of Segments 2 | 3 | By [Jyh-Jing Hwang](https://jyhjinghwang.github.io/), [Stella X. Yu](http://www1.icsi.berkeley.edu/~stellayu/), [Jianbo Shi](https://www.cis.upenn.edu/~jshi/), Maxwell D. Collins, Tien-Ju Yang, Xiao Zhang, and [Liang-Chieh Chen](http://liangchiehchen.com/) 4 | 5 | 6 | 7 | Almost all existing deep learning approaches for semantic segmentation tackle this task as a pixel-wise classification problem. Yet humans understand a scene not in terms of pixels, but by decomposing it into perceptual groups and structures that are the basic building blocks of recognition. This motivates us to propose an end-to-end pixel-wise metric learning approach that mimics this process. 
In our approach, the optimal visual representation determines the right segmentation within individual images and associates segments with the same semantic classes across images. The core visual learning problem is therefore to maximize the similarity within segments and minimize the similarity between segments. Given a model trained this way, inference is performed consistently by extracting pixel-wise embeddings and clustering, with the semantic label determined by the majority vote of its nearest neighbors from an annotated set. 8 | 9 | As a result, we present SegSort as a first attempt at unsupervised semantic segmentation with deep learning, achieving 76% of the performance of its supervised counterpart. When supervision is available, SegSort shows consistent improvements over conventional approaches based on pixel-wise softmax training. Additionally, our approach produces more precise boundaries and consistent region predictions. The proposed SegSort further produces an interpretable result, as each choice of label can be easily understood from the retrieved nearest segments. 10 | 11 | SegSort was published in ICCV 2019; see [our paper](https://arxiv.org/abs/1910.06962) for more details. 12 | 13 | 14 | ## Codebase 15 | This release of SegSort is based on our previously published codebase [AAF](https://github.com/twke18/Adaptive_Affinity_Fields) from ECCV 2018. The SegSort modules in [network/segsort/](https://github.com/jyhjinghwang/SegSort/tree/master/network/segsort) are also easy to integrate with the popular [DeepLab](https://github.com/tensorflow/models/tree/master/research/deeplab) codebase. 16 | 17 | ## Prerequisites 18 | 19 | 1. Linux 20 | 2. Python 2.7 or Python 3 (>= 3.5) 21 | 3. CUDA 8.0 and cuDNN 6 22 | 23 | ## Required Python Packages 24 | 25 | 1. tensorflow 1.X 26 | 2. numpy 27 | 3. scipy 28 | 4. tqdm 29 | 5. PIL 30 | 6. opencv 31 | 32 | ## Data Preparation 33 | 34 | * [PASCAL VOC 2012](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/) 35 | * Augmented PASCAL VOC training set from [SBD](http://home.bharathh.info/pubs/codes/SBD/download.html). The processed ground truth masks are provided as [SegSort/dataset/voc12/sbd_clsimg.zip](https://github.com/jyhjinghwang/SegSort/blob/master/dataset/voc12/sbd_clsimg.zip). Please unzip it and place it beside the VOC2012/ folder as sbd/dataset/clsimg/. 36 | * The ground truth semantic segmentation masks are reformatted as grayscale images, or you can download them [here](https://www.dropbox.com/sh/fd2m7s87gk7jeyh/AAC6tN6syhFKmEaDYCwUIgnXa?dl=0). Please put them under the VOC2012/ folder. 37 | * The oversegmentation masks (from contours) can be produced by combining any contour detector with gPb-owt-ucm. We provide the HED-owt-ucm masks [here](https://www.dropbox.com/sh/fd2m7s87gk7jeyh/AAC6tN6syhFKmEaDYCwUIgnXa?dl=0). Please put them under the VOC2012/ folder. 38 | * Dataset folder structure: 39 | 40 | sbd/ 41 | - dataset/ 42 | - clsimg/ 43 | 44 | VOC2012/ 45 | - JPEGImages/ 46 | - segcls/ 47 | - hed/ 48 | 49 | ## ImageNet Pre-Trained Models 50 | 51 | Download ResNet101.v1 from [Tensorflow-Slim](https://github.com/tensorflow/models/tree/master/research/slim). 52 | Please put it under a new directory SegSort/snapshots/imagenet/trained/. 53 | 54 | We also provide our SegSort models (supervised/unsupervised) trained on PASCAL VOC, along with their results, [here](https://www.dropbox.com/sh/fd2m7s87gk7jeyh/AAC6tN6syhFKmEaDYCwUIgnXa?dl=0). 
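## Inference at a Glance

As described above, inference is retrieval-based: pixel-wise embeddings are grouped into segments, each segment is summarized by its mean direction (a von-Mises Fisher prototype), and the segment's class is decided by a majority vote among its k nearest prototypes from an annotated bank (the `K_IN_NEAREST_NEIGHBORS` flag in the scripts below). The snippet that follows is only a minimal NumPy sketch of that voting step; the function names are illustrative, and the actual TF1 implementation lives in [network/segsort/](https://github.com/jyhjinghwang/SegSort/tree/master/network/segsort) (e.g. common_utils.py and eval_utils.py).

```
import numpy as np

def segment_prototypes(embeddings, segment_ids):
    # Mean direction per segment, L2-normalized onto the unit hypersphere.
    emb = embeddings.reshape(-1, embeddings.shape[-1])
    ids = segment_ids.reshape(-1)
    one_hot = np.eye(ids.max() + 1)[ids]       # [num_pixels, num_segments]
    protos = one_hot.T @ emb                   # [num_segments, dim]
    return protos / np.linalg.norm(protos, axis=-1, keepdims=True)

def knn_vote(queries, bank, bank_labels, k=15):
    # Majority vote over the k most cosine-similar prototypes in the bank.
    sims = queries @ bank.T                    # [num_queries, bank_size]
    nearest = np.argsort(-sims, axis=1)[:, :k]
    return np.array([np.bincount(bank_labels[row]).argmax() for row in nearest])

# Toy example: a 4x4 image with 8-D embeddings split into 2 segments,
# voted against a random bank of 100 labeled prototypes (21 VOC classes).
rng = np.random.RandomState(0)
emb = rng.randn(4, 4, 8)
emb /= np.linalg.norm(emb, axis=-1, keepdims=True)
seg = np.zeros((4, 4), dtype=int)
seg[:, 2:] = 1
bank = rng.randn(100, 8)
bank /= np.linalg.norm(bank, axis=-1, keepdims=True)
print(knn_vote(segment_prototypes(emb, seg), bank, rng.randint(0, 21, 100)))
```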
55 | 56 | ## Bashscripts to Get Started 57 | 58 | * SegSort (Single-GPU and fast training) 59 | ``` 60 | source bashscripts/voc12/train_segsort.sh 61 | ``` 62 | 63 | * SegSort (Multi-GPUs) 64 | ``` 65 | source bashscripts/voc12/train_segsort_mgpu.sh 66 | ``` 67 | 68 | * Unsupervised SegSort (Single-GPU) 69 | ``` 70 | source bashscripts/voc12/train_segsort_unsup.sh 71 | ``` 72 | 73 | * Baseline Models: Please refer to our previous codebase [AAF](https://github.com/twke18/Adaptive_Affinity_Fields). 74 | 75 | 76 | ## Citation 77 | If you find this code useful for your research, please consider citing our paper [SegSort: Segmentation by Discriminative Sorting of Segments](https://arxiv.org/abs/1910.06962). 78 | 79 | ``` 80 | @inproceedings{hwang2019segsort, 81 | title={SegSort: Segmentation by Discriminative Sorting of Segments}, 82 | author={Hwang, Jyh-Jing and Yu, Stella X and Shi, Jianbo and Collins, Maxwell D and Yang, Tien-Ju and Zhang, Xiao and Chen, Liang-Chieh}, 83 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 84 | pages={7334--7344}, 85 | year={2019} 86 | } 87 | ``` 88 | 89 | ## License 90 | SegSort is released under the MIT License (refer to the LICENSE file for details). 91 | -------------------------------------------------------------------------------- /bashscripts/voc12/train_segsort.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used for training, inference and benchmarking 3 | # the PSPNet with SegSort on PASCAL VOC 2012. 4 | # 5 | # Usage: 6 | # # From SegSort/ directory. 7 | # bash bashscripts/voc12/train_segsort.sh 8 | 9 | 10 | # Set up training hyper-parameters. 11 | BATCH_SIZE=8 12 | TRAIN_INPUT_SIZE=336,336 13 | WEIGHT_DECAY=5e-4 14 | ITER_SIZE=1 15 | NUM_STEPS=30000 16 | NUM_CLASSES=21 17 | LEARNING_RATE=2e-3 18 | 19 | # Set up parameters for inference. 20 | INFERENCE_INPUT_SIZE=480,480 21 | INFERENCE_STRIDES=320,320 22 | INFERENCE_SPLIT=val 23 | 24 | # Set up SegSort hyper-parameters. 25 | CONCENTRATION=10 26 | NUM_BANKS=2 27 | EMBEDDING_DIM=32 28 | NUM_CLUSTERS=5 29 | KMEANS_ITERATIONS=10 30 | K_IN_NEAREST_NEIGHBORS=15 31 | 32 | # Set up path for saving models. 33 | SNAPSHOT_DIR=snapshots/voc12/segsort/segsort_lr2e-3_it30000 34 | 35 | # Set up the procedure pipeline. 36 | IS_TRAIN_1=1 37 | IS_PROTOTYPE_1=1 38 | IS_INFERENCE_1=1 39 | IS_INFERENCE_MSC_1=0 40 | IS_BENCHMARK_1=1 41 | IS_TRAIN_2=1 42 | IS_PROTOTYPE_2=1 43 | IS_INFERENCE_2=1 44 | IS_INFERENCE_MSC_2=1 45 | IS_BENCHMARK_2=1 46 | 47 | # Update PYTHONPATH. 48 | export PYTHONPATH=`pwd`:$PYTHONPATH 49 | 50 | # Set up the data directory. 51 | DATAROOT=/ssd/jyh/datasets 52 | 53 | # Train for the 1st stage. 
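# Note: each IS_* flag above toggles one step of the pipeline. Stage 1
# trains the ImageNet-initialized ResNet-101 on train+ (the SBD-augmented
# list) at LR 2e-3; stage 2 then fine-tunes the stage-1 checkpoint on
# train at LR 2e-4. Prototypes are re-extracted from the corresponding
# training split before each inference and benchmark step.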
54 | if [ ${IS_TRAIN_1} -eq 1 ]; then 55 | python3 pyscripts/train/train_segsort.py\ 56 | --snapshot_dir ${SNAPSHOT_DIR}/stage1\ 57 | --restore_from snapshots/imagenet/trained/resnet_v1_101.ckpt\ 58 | --data_list dataset/voc12/train+.txt\ 59 | --data_dir ${DATAROOT}/VOCdevkit/\ 60 | --batch_size ${BATCH_SIZE}\ 61 | --save_pred_every ${NUM_STEPS}\ 62 | --update_tb_every 50\ 63 | --input_size ${TRAIN_INPUT_SIZE}\ 64 | --learning_rate ${LEARNING_RATE}\ 65 | --weight_decay ${WEIGHT_DECAY}\ 66 | --iter_size ${ITER_SIZE}\ 67 | --num_classes ${NUM_CLASSES}\ 68 | --num_steps $(($NUM_STEPS+1))\ 69 | --concentration ${CONCENTRATION}\ 70 | --num_banks ${NUM_BANKS}\ 71 | --embedding_dim ${EMBEDDING_DIM}\ 72 | --num_clusters ${NUM_CLUSTERS}\ 73 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 74 | --random_mirror\ 75 | --random_scale\ 76 | --random_crop\ 77 | --not_restore_classifier\ 78 | --is_training 79 | fi 80 | 81 | # Prototype for the 1st stage. 82 | if [ ${IS_PROTOTYPE_1} -eq 1 ]; then 83 | python3 pyscripts/inference/extract_prototypes.py\ 84 | --data_dir ${DATAROOT}/VOCdevkit/\ 85 | --data_list dataset/voc12/train+.txt\ 86 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS}\ 87 | --input_size ${INFERENCE_INPUT_SIZE}\ 88 | --strides ${INFERENCE_STRIDES}\ 89 | --num_classes ${NUM_CLASSES}\ 90 | --ignore_label 255\ 91 | --embedding_dim ${EMBEDDING_DIM}\ 92 | --num_clusters ${NUM_CLUSTERS}\ 93 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 94 | --save_dir ${SNAPSHOT_DIR}/stage1/results/train+ 95 | fi 96 | 97 | # Single-scale inference for the 1st stage. 98 | if [ ${IS_INFERENCE_1} -eq 1 ]; then 99 | python3 pyscripts/inference/inference_segsort.py\ 100 | --data_dir ${DATAROOT}/VOCdevkit/\ 101 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 102 | --input_size 720,720\ 103 | --strides ${INFERENCE_STRIDES}\ 104 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS}\ 105 | --colormap misc/colormapvoc.mat\ 106 | --num_classes ${NUM_CLASSES}\ 107 | --ignore_label 255\ 108 | --embedding_dim ${EMBEDDING_DIM}\ 109 | --num_clusters ${NUM_CLUSTERS}\ 110 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 111 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 112 | --save_dir ${SNAPSHOT_DIR}/stage1/results/${INFERENCE_SPLIT}\ 113 | --prototype_dir ${SNAPSHOT_DIR}/stage1/results/train+/prototypes 114 | fi 115 | 116 | # Multi-scale inference for the 1st stage. 117 | if [ ${IS_INFERENCE_MSC_1} -eq 1 ]; then 118 | python3 pyscripts/inference/inference_segsort_msc.py\ 119 | --data_dir ${DATAROOT}/VOCdevkit/\ 120 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 121 | --input_size ${INFERENCE_INPUT_SIZE}\ 122 | --strides ${INFERENCE_STRIDES}\ 123 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS}\ 124 | --colormap misc/colormapvoc.mat\ 125 | --num_classes ${NUM_CLASSES}\ 126 | --ignore_label 255\ 127 | --flip_aug\ 128 | --scale_aug\ 129 | --embedding_dim ${EMBEDDING_DIM}\ 130 | --num_clusters ${NUM_CLUSTERS}\ 131 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 132 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 133 | --save_dir ${SNAPSHOT_DIR}/stage1/results/${INFERENCE_SPLIT}\ 134 | --prototype_dir ${SNAPSHOT_DIR}/stage1/results/train+/prototypes 135 | fi 136 | 137 | # Benchmark for the 1st stage. 
138 | if [ ${IS_BENCHMARK_1} -eq 1 ]; then 139 | python3 pyscripts/benchmark/benchmark_by_mIoU.py\ 140 | --pred_dir ${SNAPSHOT_DIR}/stage1/results/${INFERENCE_SPLIT}/gray/\ 141 | --gt_dir ${DATAROOT}/VOCdevkit/VOC2012/segcls/\ 142 | --num_classes ${NUM_CLASSES} 143 | fi 144 | 145 | 146 | LEARNING_RATE=2e-4 147 | # Train for the 2nd stage. 148 | if [ ${IS_TRAIN_2} -eq 1 ]; then 149 | python3 pyscripts/train/train_segsort.py\ 150 | --snapshot_dir ${SNAPSHOT_DIR}/stage2\ 151 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS}\ 152 | --data_list dataset/voc12/train.txt\ 153 | --data_dir ${DATAROOT}/VOCdevkit/\ 154 | --batch_size ${BATCH_SIZE}\ 155 | --save_pred_every ${NUM_STEPS}\ 156 | --update_tb_every 50\ 157 | --input_size ${TRAIN_INPUT_SIZE}\ 158 | --learning_rate ${LEARNING_RATE}\ 159 | --weight_decay ${WEIGHT_DECAY}\ 160 | --iter_size ${ITER_SIZE}\ 161 | --num_classes ${NUM_CLASSES}\ 162 | --num_steps $(($NUM_STEPS+1))\ 163 | --concentration ${CONCENTRATION}\ 164 | --num_banks ${NUM_BANKS}\ 165 | --embedding_dim ${EMBEDDING_DIM}\ 166 | --num_clusters ${NUM_CLUSTERS}\ 167 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 168 | --random_mirror\ 169 | --random_scale\ 170 | --random_crop\ 171 | --not_restore_classifier\ 172 | --is_training 173 | fi 174 | 175 | # Prototype for the 2nd stage. 176 | if [ ${IS_PROTOTYPE_2} -eq 1 ]; then 177 | python3 pyscripts/inference/extract_prototypes.py\ 178 | --data_dir ${DATAROOT}/VOCdevkit/\ 179 | --data_list dataset/voc12/train.txt\ 180 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS}\ 181 | --input_size ${INFERENCE_INPUT_SIZE}\ 182 | --strides ${INFERENCE_STRIDES}\ 183 | --num_classes ${NUM_CLASSES}\ 184 | --ignore_label 255\ 185 | --embedding_dim ${EMBEDDING_DIM}\ 186 | --num_clusters ${NUM_CLUSTERS}\ 187 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 188 | --save_dir ${SNAPSHOT_DIR}/stage2/results/train 189 | fi 190 | 191 | # Single-scale inference for the 2nd stage. 192 | if [ ${IS_INFERENCE_2} -eq 1 ]; then 193 | python3 pyscripts/inference/inference_segsort.py\ 194 | --data_dir ${DATAROOT}/VOCdevkit/\ 195 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 196 | --input_size 720,720\ 197 | --strides ${INFERENCE_STRIDES}\ 198 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS}\ 199 | --colormap misc/colormapvoc.mat\ 200 | --num_classes ${NUM_CLASSES}\ 201 | --ignore_label 255\ 202 | --embedding_dim ${EMBEDDING_DIM}\ 203 | --num_clusters ${NUM_CLUSTERS}\ 204 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 205 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 206 | --save_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}\ 207 | --prototype_dir ${SNAPSHOT_DIR}/stage2/results/train/prototypes 208 | fi 209 | 210 | # Multi-scale inference for the 2nd stage. 
211 | if [ ${IS_INFERENCE_MSC_2} -eq 1 ]; then 212 | python3 pyscripts/inference/inference_segsort_msc.py\ 213 | --data_dir ${DATAROOT}/VOCdevkit/\ 214 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 215 | --input_size ${INFERENCE_INPUT_SIZE}\ 216 | --strides ${INFERENCE_STRIDES}\ 217 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS}\ 218 | --colormap misc/colormapvoc.mat\ 219 | --num_classes ${NUM_CLASSES}\ 220 | --ignore_label 255\ 221 | --flip_aug\ 222 | --scale_aug\ 223 | --embedding_dim ${EMBEDDING_DIM}\ 224 | --num_clusters ${NUM_CLUSTERS}\ 225 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 226 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 227 | --save_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}\ 228 | --prototype_dir ${SNAPSHOT_DIR}/stage2/results/train/prototypes 229 | fi 230 | 231 | # Benchmark for the 2nd stage. 232 | if [ ${IS_BENCHMARK_2} -eq 1 ]; then 233 | python3 pyscripts/benchmark/benchmark_by_mIoU.py\ 234 | --pred_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}/gray/\ 235 | --gt_dir ${DATAROOT}/VOCdevkit/VOC2012/segcls/\ 236 | --num_classes ${NUM_CLASSES} 237 | fi 238 | -------------------------------------------------------------------------------- /bashscripts/voc12/train_segsort_mgpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used for multi-gpu training, inference and 3 | # benchmarking the PSPNet with SegSort on PASCAL VOC 2012. 4 | # 5 | # Usage: 6 | # # From SegSort/ directory. 7 | # bash bashscripts/voc12/train_segsort_mgpu.sh 8 | 9 | 10 | # Set up parameters for training. 11 | BATCH_SIZE=16 12 | TRAIN_INPUT_SIZE=480,480 13 | WEIGHT_DECAY=5e-4 14 | ITER_SIZE=1 15 | NUM_STEPS1=100000 16 | NUM_STEPS2=30000 17 | NUM_CLASSES=21 18 | NUM_GPU=5 19 | LEARNING_RATE=2e-3 20 | 21 | # Set up parameters for inference. 22 | INFERENCE_INPUT_SIZE=480,480 23 | INFERENCE_STRIDES=320,320 24 | INFERENCE_SPLIT=val 25 | 26 | # Set up SegSort hyper-parameters. 27 | CONCENTRATION=10 28 | NUM_BANKS=2 29 | EMBEDDING_DIM=32 30 | NUM_CLUSTERS=5 31 | KMEANS_ITERATIONS=10 32 | K_IN_NEAREST_NEIGHBORS=21 33 | 34 | # Set up path for saving models. 35 | SNAPSHOT_DIR=snapshots/voc12/segsort/segsort_mgpu_lr2e-3_it100k 36 | 37 | # Set up the procedure pipeline. 38 | IS_TRAIN_1=1 39 | IS_PROTOTYPE_1=1 40 | IS_INFERENCE_1=1 41 | IS_BENCHMARK_1=1 42 | IS_TRAIN_2=1 43 | IS_PROTOTYPE_2=1 44 | IS_INFERENCE_2=0 45 | IS_INFERENCE_MSC_2=1 46 | IS_BENCHMARK_2=1 47 | 48 | # Update PYTHONPATH. 49 | export PYTHONPATH=`pwd`:$PYTHONPATH 50 | 51 | # Set up the data directory. 52 | DATAROOT=/ssd/jyh/datasets 53 | 54 | # Train for the 1st stage. 
55 | if [ ${IS_TRAIN_1} -eq 1 ]; then 56 | python3 pyscripts/train/train_segsort_mgpu.py\ 57 | --snapshot_dir ${SNAPSHOT_DIR}/stage1\ 58 | --restore_from snapshots/imagenet/trained/resnet_v1_101.ckpt\ 59 | --data_list dataset/voc12/train+.txt\ 60 | --data_dir ${DATAROOT}/VOCdevkit/\ 61 | --batch_size ${BATCH_SIZE}\ 62 | --save_pred_every ${NUM_STEPS1}\ 63 | --update_tb_every 50\ 64 | --input_size ${TRAIN_INPUT_SIZE}\ 65 | --learning_rate ${LEARNING_RATE}\ 66 | --weight_decay ${WEIGHT_DECAY}\ 67 | --iter_size ${ITER_SIZE}\ 68 | --num_classes ${NUM_CLASSES}\ 69 | --num_steps $(($NUM_STEPS1+1))\ 70 | --concentration ${CONCENTRATION}\ 71 | --num_banks ${NUM_BANKS}\ 72 | --embedding_dim ${EMBEDDING_DIM}\ 73 | --num_clusters ${NUM_CLUSTERS}\ 74 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 75 | --num_gpu ${NUM_GPU}\ 76 | --random_mirror\ 77 | --random_scale\ 78 | --random_crop\ 79 | --not_restore_classifier\ 80 | --is_training 81 | fi 82 | 83 | # Prototype for the 1st stage. 84 | if [ ${IS_PROTOTYPE_1} -eq 1 ]; then 85 | python3 pyscripts/inference/extract_prototypes.py\ 86 | --data_dir ${DATAROOT}/VOCdevkit/\ 87 | --data_list dataset/voc12/train+.txt\ 88 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS1}\ 89 | --input_size ${INFERENCE_INPUT_SIZE}\ 90 | --strides ${INFERENCE_STRIDES}\ 91 | --num_classes ${NUM_CLASSES}\ 92 | --ignore_label 255\ 93 | --embedding_dim ${EMBEDDING_DIM}\ 94 | --num_clusters ${NUM_CLUSTERS}\ 95 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 96 | --save_dir ${SNAPSHOT_DIR}/stage1/results/train+ 97 | fi 98 | 99 | # Single-scale inference for the 1st stage. 100 | if [ ${IS_INFERENCE_1} -eq 1 ]; then 101 | python3 pyscripts/inference/inference_segsort.py\ 102 | --data_dir ${DATAROOT}/VOCdevkit/\ 103 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 104 | --input_size 720,720\ 105 | --strides ${INFERENCE_STRIDES}\ 106 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS1}\ 107 | --colormap misc/colormapvoc.mat\ 108 | --num_classes ${NUM_CLASSES}\ 109 | --ignore_label 255\ 110 | --embedding_dim ${EMBEDDING_DIM}\ 111 | --num_clusters ${NUM_CLUSTERS}\ 112 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 113 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 114 | --save_dir ${SNAPSHOT_DIR}/stage1/results/${INFERENCE_SPLIT}\ 115 | --prototype_dir ${SNAPSHOT_DIR}/stage1/results/train+/prototypes 116 | fi 117 | 118 | # Benchmark for the 1st stage. 119 | if [ ${IS_BENCHMARK_1} -eq 1 ]; then 120 | python3 pyscripts/benchmark/benchmark_by_mIoU.py\ 121 | --pred_dir ${SNAPSHOT_DIR}/stage1/results/${INFERENCE_SPLIT}/gray/\ 122 | --gt_dir ${DATAROOT}/VOCdevkit/VOC2012/segcls/\ 123 | --num_classes ${NUM_CLASSES} 124 | fi 125 | 126 | 127 | LEARNING_RATE=2e-4 128 | # Train for the 2nd stage. 
129 | if [ ${IS_TRAIN_2} -eq 1 ]; then 130 | python3 pyscripts/train/train_segsort_mgpu.py\ 131 | --snapshot_dir ${SNAPSHOT_DIR}/stage2\ 132 | --restore_from ${SNAPSHOT_DIR}/stage1/model.ckpt-${NUM_STEPS1}\ 133 | --data_list dataset/voc12/train.txt\ 134 | --data_dir ${DATAROOT}/VOCdevkit/\ 135 | --batch_size ${BATCH_SIZE}\ 136 | --save_pred_every 10000\ 137 | --update_tb_every 50\ 138 | --input_size ${TRAIN_INPUT_SIZE}\ 139 | --learning_rate ${LEARNING_RATE}\ 140 | --weight_decay ${WEIGHT_DECAY}\ 141 | --iter_size ${ITER_SIZE}\ 142 | --num_classes ${NUM_CLASSES}\ 143 | --num_steps $(($NUM_STEPS2+1))\ 144 | --num_gpu ${NUM_GPU}\ 145 | --concentration ${CONCENTRATION}\ 146 | --num_banks ${NUM_BANKS}\ 147 | --embedding_dim ${EMBEDDING_DIM}\ 148 | --num_clusters ${NUM_CLUSTERS}\ 149 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 150 | --random_mirror\ 151 | --random_scale\ 152 | --random_crop\ 153 | --is_training 154 | fi 155 | 156 | # Prototype for the 2nd stage. 157 | if [ ${IS_PROTOTYPE_2} -eq 1 ]; then 158 | python3 pyscripts/inference/extract_prototypes.py\ 159 | --data_dir ${DATAROOT}/VOCdevkit/\ 160 | --data_list dataset/voc12/train.txt\ 161 | --input_size ${INFERENCE_INPUT_SIZE}\ 162 | --strides ${INFERENCE_STRIDES}\ 163 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS2}\ 164 | --num_classes ${NUM_CLASSES}\ 165 | --ignore_label 255\ 166 | --embedding_dim ${EMBEDDING_DIM}\ 167 | --num_clusters ${NUM_CLUSTERS}\ 168 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 169 | --save_dir ${SNAPSHOT_DIR}/stage2/results/train 170 | fi 171 | 172 | # Single-scale inference for the 2nd stage. 173 | if [ ${IS_INFERENCE_2} -eq 1 ]; then 174 | python3 pyscripts/inference/inference_segsort.py\ 175 | --data_dir ${DATAROOT}/VOCdevkit/\ 176 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 177 | --input_size 720,720\ 178 | --strides ${INFERENCE_STRIDES}\ 179 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS2}\ 180 | --colormap misc/colormapvoc.mat\ 181 | --num_classes ${NUM_CLASSES}\ 182 | --ignore_label 255\ 183 | --embedding_dim ${EMBEDDING_DIM}\ 184 | --num_clusters ${NUM_CLUSTERS}\ 185 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 186 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 187 | --save_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}\ 188 | --prototype_dir ${SNAPSHOT_DIR}/stage2/results/train/prototypes 189 | fi 190 | 191 | # Inference for the 2nd stage. 192 | if [ ${IS_INFERENCE_MSC_2} -eq 1 ]; then 193 | python3 pyscripts/inference/inference_segsort_msc.py\ 194 | --data_dir ${DATAROOT}/VOCdevkit/\ 195 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 196 | --input_size ${INFERENCE_INPUT_SIZE}\ 197 | --strides ${INFERENCE_STRIDES}\ 198 | --restore_from ${SNAPSHOT_DIR}/stage2/model.ckpt-${NUM_STEPS2}\ 199 | --colormap misc/colormapvoc.mat\ 200 | --num_classes ${NUM_CLASSES}\ 201 | --ignore_label 255\ 202 | --embedding_dim ${EMBEDDING_DIM}\ 203 | --num_clusters ${NUM_CLUSTERS}\ 204 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 205 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 206 | --flip_aug\ 207 | --scale_aug\ 208 | --save_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}\ 209 | --prototype_dir ${SNAPSHOT_DIR}/stage2/results/train/prototypes 210 | fi 211 | 212 | # Benchmark for the 2nd stage. 
213 | if [ ${IS_BENCHMARK_2} -eq 1 ]; then 214 | python3 pyscripts/benchmark/benchmark_by_mIoU.py\ 215 | --pred_dir ${SNAPSHOT_DIR}/stage2/results/${INFERENCE_SPLIT}/gray/\ 216 | --gt_dir ${DATAROOT}/VOCdevkit/VOC2012/segcls/\ 217 | --num_classes ${NUM_CLASSES} 218 | fi 219 | -------------------------------------------------------------------------------- /bashscripts/voc12/train_segsort_unsup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script is used for unsupervised training, inference and 3 | # benchmarking the PSPNet with SegSort on PASCAL VOC 2012. 4 | # 5 | # Usage: 6 | # # From SegSort/ directory. 7 | # bash bashscripts/voc12/train_segsort_unsup.sh 8 | 9 | 10 | # Set up parameters for training. 11 | BATCH_SIZE=8 12 | TRAIN_INPUT_SIZE=336,336 13 | WEIGHT_DECAY=5e-4 14 | ITER_SIZE=1 15 | NUM_STEPS=10000 16 | NUM_CLASSES=21 17 | LEARNING_RATE=2e-3 18 | 19 | # Set up parameters for inference. 20 | INFERENCE_INPUT_SIZE=480,480 21 | INFERENCE_STRIDES=320,320 22 | INFERENCE_SPLIT=val 23 | 24 | # Set up SegSort hyper-parameters. 25 | CONCENTRATION=10 26 | EMBEDDING_DIM=32 27 | NUM_CLUSTERS=5 28 | KMEANS_ITERATIONS=10 29 | K_IN_NEAREST_NEIGHBORS=15 30 | 31 | # Set up path for saving models. 32 | SNAPSHOT_DIR=snapshots/voc12/unsup_segsort/unsup_segsort_lr2e-3_it10k 33 | 34 | # Set up the procedure pipeline. 35 | IS_TRAIN=1 36 | IS_PROTOTYPE=1 37 | IS_INFERENCE_MSC=1 38 | IS_BENCHMARK=1 39 | 40 | # Update PYTHONPATH. 41 | export PYTHONPATH=`pwd`:$PYTHONPATH 42 | 43 | # Set up the data directory. 44 | DATAROOT=/ssd/jyh/datasets 45 | 46 | # Train. 47 | if [ ${IS_TRAIN} -eq 1 ]; then 48 | python3 pyscripts/train/train_segsort_unsup.py\ 49 | --snapshot_dir ${SNAPSHOT_DIR}\ 50 | --restore_from snapshots/imagenet/trained/resnet_v1_101.ckpt\ 51 | --data_list dataset/voc12/train+.txt\ 52 | --data_dir ${DATAROOT}/VOCdevkit/\ 53 | --batch_size ${BATCH_SIZE}\ 54 | --save_pred_every ${NUM_STEPS}\ 55 | --update_tb_every 50\ 56 | --input_size ${TRAIN_INPUT_SIZE}\ 57 | --learning_rate ${LEARNING_RATE}\ 58 | --weight_decay ${WEIGHT_DECAY}\ 59 | --iter_size ${ITER_SIZE}\ 60 | --num_classes ${NUM_CLASSES}\ 61 | --num_steps $(($NUM_STEPS+1))\ 62 | --concentration ${CONCENTRATION}\ 63 | --embedding_dim ${EMBEDDING_DIM}\ 64 | --random_mirror\ 65 | --random_scale\ 66 | --random_crop\ 67 | --not_restore_classifier\ 68 | --is_training 69 | fi 70 | 71 | # Extract prototypes. 72 | if [ ${IS_PROTOTYPE} -eq 1 ]; then 73 | python3 pyscripts/inference/extract_prototypes.py\ 74 | --data_dir ${DATAROOT}/VOCdevkit/\ 75 | --data_list dataset/voc12/train+.txt\ 76 | --input_size ${INFERENCE_INPUT_SIZE}\ 77 | --strides ${INFERENCE_STRIDES}\ 78 | --restore_from ${SNAPSHOT_DIR}/model.ckpt-${NUM_STEPS}\ 79 | --num_classes ${NUM_CLASSES}\ 80 | --ignore_label 255\ 81 | --embedding_dim ${EMBEDDING_DIM}\ 82 | --num_clusters ${NUM_CLUSTERS}\ 83 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 84 | --save_dir ${SNAPSHOT_DIR}/results/train+ 85 | fi 86 | 87 | # Inference. 
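# Note: the unsupervised pipeline has a single training stage; the
# prototypes extracted above from train+ serve as the annotated bank, and
# multi-scale inference assigns each segment the majority label of its
# ${K_IN_NEAREST_NEIGHBORS} nearest prototypes, as in the supervised scripts.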
88 | if [ ${IS_INFERENCE_MSC} -eq 1 ]; then 89 | python3 pyscripts/inference/inference_segsort_msc.py\ 90 | --data_dir ${DATAROOT}/VOCdevkit/\ 91 | --data_list dataset/voc12/${INFERENCE_SPLIT}.txt\ 92 | --input_size ${INFERENCE_INPUT_SIZE}\ 93 | --strides ${INFERENCE_STRIDES}\ 94 | --restore_from ${SNAPSHOT_DIR}/model.ckpt-${NUM_STEPS}\ 95 | --colormap misc/colormapvoc.mat\ 96 | --num_classes ${NUM_CLASSES}\ 97 | --ignore_label 255\ 98 | --save_dir ${SNAPSHOT_DIR}/results/${INFERENCE_SPLIT}\ 99 | --flip_aug\ 100 | --scale_aug\ 101 | --embedding_dim ${EMBEDDING_DIM}\ 102 | --num_clusters ${NUM_CLUSTERS}\ 103 | --kmeans_iterations ${KMEANS_ITERATIONS}\ 104 | --k_in_nearest_neighbors ${K_IN_NEAREST_NEIGHBORS}\ 105 | --prototype_dir ${SNAPSHOT_DIR}/results/train+/prototypes 106 | fi 107 | 108 | # Benchmark. 109 | if [ ${IS_BENCHMARK} -eq 1 ]; then 110 | python3 pyscripts/benchmark/benchmark_by_mIoU.py\ 111 | --pred_dir ${SNAPSHOT_DIR}/results/${INFERENCE_SPLIT}/gray/\ 112 | --gt_dir ${DATAROOT}/VOCdevkit/VOC2012/segcls/\ 113 | --num_classes ${NUM_CLASSES} 114 | fi 115 | -------------------------------------------------------------------------------- /dataset/voc12/sbd_clsimg.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/dataset/voc12/sbd_clsimg.zip -------------------------------------------------------------------------------- /misc/colormapvoc.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/misc/colormapvoc.mat -------------------------------------------------------------------------------- /misc/main.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/misc/main.png -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/network/__init__.py -------------------------------------------------------------------------------- /network/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/network/common/__init__.py -------------------------------------------------------------------------------- /network/common/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import tensorflow as tf 4 | import numpy as np 5 | 6 | 7 | def batch_norm(x, 8 | name, 9 | activation_fn=None, 10 | decay=0.99, 11 | epsilon=0.001, 12 | is_training=True): 13 | """Batch normalization. 14 | 15 | This function performs batch normalization. If it is set for training, 16 | it will update moving mean and moving variance to keep track of global 17 | statistics by exponential decay. 18 | 19 | output = [(x - mean) / sqrt(var)] * gamma + beta. 20 | 21 | Args: 22 | x: A tensor of size [batch_size, height_in, width_in, channels]. 23 | name: The prefix of tensorflow variables defined in this layer. 24 | activation_fn: The non-linear function, such as tf.nn.relu. 
If 25 | activation_fn is None, skip it and maintain a linear activation. 26 | decay: The exponential decay rate. 27 | epsilon: Small float added to variance to avoid dividing by zero. 28 | is_training: enable/disable is_training for updating moving mean and 29 | moving variance by exponential decay. If True, compute batch mean 30 | and batch variance per batch; otherwise, use moving mean and moving 31 | variance as batch mean and batch variance. 32 | 33 | Returns: 34 | A tensor of size [batch_size, height_in, width_in, channels]. 35 | """ 36 | with tf.variable_scope(name) as scope: 37 | shape_x = x.get_shape().as_list() 38 | 39 | beta = tf.get_variable( 40 | 'beta', 41 | shape_x[-1], 42 | initializer=tf.constant_initializer(0.0), 43 | trainable=is_training) 44 | gamma = tf.get_variable( 45 | 'gamma', 46 | shape_x[-1], 47 | initializer=tf.constant_initializer(1.0), 48 | trainable=is_training) 49 | moving_mean = tf.get_variable( 50 | 'moving_mean', 51 | shape_x[-1], 52 | initializer=tf.constant_initializer(0.0), 53 | trainable=False) 54 | moving_var = tf.get_variable( 55 | 'moving_variance', 56 | shape_x[-1], 57 | initializer=tf.constant_initializer(1.0), 58 | trainable=False) 59 | 60 | if is_training: 61 | # Update moving mean and variance before 62 | # applying batch normalization 63 | mean, var = tf.nn.moments(x, 64 | np.arange(len(shape_x)-1), 65 | keep_dims=True) 66 | mean = tf.reshape(mean, 67 | [mean.shape.as_list()[-1]]) 68 | var = tf.reshape(var, 69 | [var.shape.as_list()[-1]]) 70 | 71 | # Update moving mean and moving variance by exponential decay. 72 | update_moving_mean = tf.assign( 73 | moving_mean, 74 | moving_mean*decay + mean*(1-decay)) 75 | update_moving_var = tf.assign( 76 | moving_var, 77 | moving_var*decay + var*(1-decay)) 78 | update_ops = [update_moving_mean, update_moving_var] 79 | 80 | with tf.control_dependencies(update_ops): 81 | output = tf.nn.batch_normalization(x, 82 | mean, 83 | var, 84 | beta, 85 | gamma, 86 | epsilon) 87 | else: 88 | # Use collected moving mean and moving variance for normalization. 89 | mean = moving_mean 90 | var = moving_var 91 | 92 | output = tf.nn.batch_normalization(x, 93 | mean, 94 | var, 95 | beta, 96 | gamma, 97 | epsilon) 98 | 99 | # Apply activation_fn, if it is not None. 100 | if activation_fn: 101 | output = activation_fn(output) 102 | 103 | return output 104 | 105 | 106 | def conv(x, 107 | name, 108 | filters, 109 | kernel_size, 110 | strides, 111 | padding, 112 | relu=True, 113 | biased=True, 114 | bn=True, 115 | decay=0.9997, 116 | is_training=True, 117 | use_global_status=True): 118 | """Convolutional layers with batch normalization and ReLU. 119 | 120 | This function performs convolution, batch_norm (if bn=True), 121 | and ReLU (if relu=True). 122 | 123 | Args: 124 | x: A tensor of size [batch_size, height_in, width_in, channels]. 125 | name: The prefix of tensorflow variables defined in this layer. 126 | filters: A number indicating the number of output channels. 127 | kernel_size: A number indicating the size of convolutional kernels. 128 | strides: A number indicating the stride of the sliding window for 129 | height and width. 130 | padding: 'VALID' or 'SAME'. 131 | relu: enable/disable relu for ReLU as activation function. If relu 132 | is False, maintain linear activation. 133 | biased: enable/disable biased for adding biases after convolution. 134 | bn: enable/disable bn for batch normalization. 135 | decay: A number indicating the decay rate for updating moving mean and 136 | moving variance in batch normalization. 
137 | is_training: If the tensorflow variables defined in this layer 138 | would be used for training. 139 | use_global_status: enable/disable use_global_status for batch 140 | normalization. If True, moving mean and moving variance are updated 141 | by exponential decay. 142 | 143 | Returns: 144 | A tensor of size [batch_size, height_out, width_out, channels_out]. 145 | """ 146 | c_i = x.get_shape().as_list()[-1] # input channels 147 | c_o = filters # output channels 148 | 149 | # Define helper function. 150 | convolve = lambda i,k: tf.nn.conv2d( 151 | i, 152 | k, 153 | [1, strides, strides, 1], 154 | padding=padding) 155 | 156 | with tf.variable_scope(name) as scope: 157 | kernel = tf.get_variable( 158 | name='weights', 159 | shape=[kernel_size, kernel_size, c_i, c_o], 160 | trainable=is_training) 161 | 162 | if strides > 1: 163 | pad = kernel_size - 1 164 | pad_beg = pad // 2 165 | pad_end = pad - pad_beg 166 | pad_h = [pad_beg, pad_end] 167 | pad_w = [pad_beg, pad_end] 168 | x = tf.pad(x, [[0,0], pad_h, pad_w, [0,0]]) 169 | 170 | output = convolve(x, kernel) 171 | 172 | # Add the biases. 173 | if biased: 174 | biases = tf.get_variable('biases', [c_o], trainable=is_training) 175 | output = tf.nn.bias_add(output, biases) 176 | 177 | # Apply batch normalization. 178 | if bn: 179 | is_bn_training = not use_global_status 180 | output = batch_norm(output, 181 | 'BatchNorm', 182 | is_training=is_bn_training, 183 | decay=decay, 184 | activation_fn=None) 185 | 186 | # Apply ReLU as activation function. 187 | if relu: 188 | output = tf.nn.relu(output) 189 | 190 | return output 191 | 192 | 193 | def atrous_conv(x, 194 | name, 195 | filters, 196 | kernel_size, 197 | dilation, 198 | padding, 199 | relu=True, 200 | biased=True, 201 | bn=True, 202 | decay=0.9997, 203 | is_training=True, 204 | use_global_status=True): 205 | """Atrous convolutional layers with batch normalization and ReLU. 206 | 207 | This function performs atrous convolution, batch_norm (if bn=True), 208 | and ReLU (if relu=True). 209 | 210 | Args: 211 | x: A tensor of size [batch_size, height_in, width_in, channels]. 212 | name: The prefix of tensorflow variables defined in this layer. 213 | filters: A number indicating the number of output channels. 214 | kernel_size: A number indicating the size of convolutional kernels. 215 | dilation: A number indicating the dilation factor for height and width. 216 | padding: 'VALID' or 'SAME'. 217 | relu: enable/disable relu for ReLU as activation function. If relu 218 | is False, maintain linear activation. 219 | biased: enable/disable biased for adding biases after convolution. 220 | bn: enable/disable bn for batch normalization. 221 | decay: A number indicating the decay rate for updating moving mean and 222 | moving variance in batch normalization. 223 | is_training: If the tensorflow variables defined in this layer 224 | would be used for training. 225 | use_global_status: enable/disable use_global_status for batch 226 | normalization. If True, moving mean and moving variance are updated 227 | by exponential decay. 228 | 229 | Returns: 230 | A tensor of size [batch_size, height_out, width_out, channels_out]. 231 | """ 232 | c_i = x.get_shape().as_list()[-1] # input channels 233 | c_o = filters # output channels 234 | 235 | # Define helper function. 
236 | convolve = lambda i,k: tf.nn.atrous_conv2d( 237 | i, 238 | k, 239 | dilation, 240 | padding) 241 | 242 | with tf.variable_scope(name) as scope: 243 | kernel = tf.get_variable( 244 | name='weights', 245 | shape=[kernel_size, kernel_size, c_i, c_o], 246 | trainable=is_training,) 247 | output = convolve(x, kernel) 248 | 249 | # Add the biases. 250 | if biased: 251 | biases = tf.get_variable('biases', [c_o], trainable=is_training) 252 | output = tf.nn.bias_add(output, biases) 253 | 254 | # Apply batch normalization. 255 | if bn: 256 | is_bn_training = not use_global_status 257 | output = batch_norm(output, 'BatchNorm', 258 | is_training=is_bn_training, 259 | decay=decay, 260 | activation_fn=None) 261 | 262 | # Apply ReLU as activation function. 263 | if relu: 264 | output = tf.nn.relu(output) 265 | 266 | return output 267 | 268 | def _pool(x, 269 | name, 270 | kernel_size, 271 | strides, 272 | padding, 273 | pool_fn): 274 | """Helper function for spatial pooling layer. 275 | 276 | Args: 277 | x: A tensor of size [batch_size, height_in, width_in, channels]. 278 | name: The prefix of tensorflow variables defined in this layer. 279 | kernel_size: A number indicating the size of pooling kernels. 280 | strides: A number indicating the stride of the sliding window for 281 | height and width. 282 | padding: 'VALID' or 'SAME'. 283 | pool_fn: A tensorflow operation for pooling, such as tf.nn.max_pool. 284 | 285 | Returns: 286 | A tensor of size [batch_size, height_out, width_out, channels]. 287 | """ 288 | k = kernel_size 289 | s = strides 290 | if s > 1 and padding != 'SAME': 291 | pad = k - 1 292 | pad_beg = pad // 2 293 | pad_end = pad - pad_beg 294 | pad_h = [pad_beg, pad_end] 295 | pad_w = [pad_beg, pad_end] 296 | x = tf.pad(x, [[0,0], pad_h, pad_w, [0,0]]) 297 | 298 | 299 | output = pool_fn(x, 300 | ksize=[1, k, k, 1], 301 | strides=[1, s, s, 1], 302 | padding=padding, 303 | name=name) 304 | 305 | return output 306 | 307 | def max_pool(x, 308 | name, 309 | kernel_size, 310 | strides, 311 | padding): 312 | """Max pooling layer. 313 | 314 | Args: 315 | x: A tensor of size [batch_size, height_in, width_in, channels]. 316 | name: The prefix of tensorflow variables defined in this layer. 317 | kernel_size: A number indicating the size of pooling kernels. 318 | strides: A number indicating the stride of the sliding window for 319 | height and width. 320 | padding: 'VALID' or 'SAME'. 321 | 322 | Returns: 323 | A tensor of size [batch_size, height_out, width_out, channels]. 324 | """ 325 | return _pool(x, name, kernel_size, strides, padding, tf.nn.max_pool) 326 | 327 | def avg_pool(x, name, kernel_size, strides, padding): 328 | """Average pooling layer. 329 | 330 | Args: 331 | x: A tensor of size [batch_size, height_in, width_in, channels]. 332 | name: The prefix of tensorflow variables defined in this layer. 333 | kernel_size: A number indicating the size of pooling kernels. 334 | strides: A number indicating the stride of the sliding window for 335 | height and width. 336 | padding: 'VALID' or 'SAME'. 337 | 338 | Returns: 339 | A tensor of size [batch_size, height_out, width_out, channels]. 
340 | """ 341 | return _pool(x, name, kernel_size, strides, padding, tf.nn.avg_pool) 342 | -------------------------------------------------------------------------------- /network/common/resnet_v1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network.common.layers as nn 4 | 5 | 6 | def bottleneck(x, 7 | name, 8 | filters, 9 | strides=None, 10 | dilation=None, 11 | is_training=True, 12 | use_global_status=True): 13 | """Builds the bottleneck module in ResNet. 14 | 15 | This function stack 3 convolutional layers and fuse the output with 16 | the residual connection. 17 | 18 | Args: 19 | x: A tensor of size [batch_size, height_in, width_in, channels]. 20 | name: The prefix of tensorflow variables defined in this layer. 21 | filters: A number indicating the number of output channels. 22 | strides: A number indicating the stride of the sliding window for 23 | height and width. 24 | dilation: A number indicating the dilation factor for height and width. 25 | is_training: If the tensorflow variables defined in this layer 26 | would be used for training. 27 | use_global_status: enable/disable use_global_status for batch 28 | normalization. If True, moving mean and moving variance are updated 29 | by exponential decay. 30 | 31 | Returns: 32 | A tensor of size [batch_size, height_out, width_out, channels_out]. 33 | """ 34 | if strides is None and dilation is None: 35 | raise ValueError('None of strides or dilation is specified, ' 36 | +'set one of them to 1 or bigger number.') 37 | elif strides > 1 and dilation is not None and dilation > 1: 38 | raise ValueError('strides and dilation are both specified, ' 39 | +'set one of them to 1 or None.') 40 | 41 | with tf.variable_scope(name) as scope: 42 | c_i = x.get_shape().as_list()[-1] 43 | 44 | if c_i != filters*4: 45 | # Use a convolutional layer as residual connection when the 46 | # number of input channels is different from output channels. 47 | shortcut = nn.conv(x, 48 | name='shortcut', 49 | filters=filters*4, 50 | kernel_size=1, 51 | strides=strides, 52 | padding='VALID', 53 | biased=False, 54 | bn=True, 55 | relu=False, 56 | is_training=is_training, 57 | use_global_status=use_global_status) 58 | elif strides > 1: 59 | # Use max-pooling as residual connection when the number of 60 | # input channel is same as output channels, but stride is 61 | # larger than 1. 62 | shortcut = nn.max_pool(x, 63 | name='shortcut', 64 | kernel_size=1, 65 | strides=strides, 66 | padding='VALID') 67 | else: 68 | # Otherwise, keep the original input as residual connection. 69 | shortcut = x 70 | 71 | # Build the 1st convolutional layer. 72 | x = nn.conv(x, 73 | name='conv1', 74 | filters=filters, 75 | kernel_size=1, 76 | strides=1, 77 | padding='SAME', 78 | biased=False, 79 | bn=True, 80 | relu=True, 81 | is_training=is_training, 82 | use_global_status=use_global_status) 83 | 84 | if dilation is not None and dilation > 1: 85 | # If dilation > 1, apply atrous conv to the 2nd convolutional layer. 
86 | x = nn.atrous_conv( 87 | x, 88 | name='conv2', 89 | filters=filters, 90 | kernel_size=3, 91 | dilation=dilation, 92 | padding='SAME', 93 | biased=False, 94 | bn=True, 95 | relu=True, 96 | is_training=is_training, 97 | use_global_status=use_global_status) 98 | else: 99 | padding = 'VALID' if strides > 1 else 'SAME' 100 | x = nn.conv( 101 | x, 102 | name='conv2', 103 | filters=filters, 104 | kernel_size=3, 105 | strides=strides, 106 | padding=padding, 107 | biased=False, 108 | bn=True, 109 | relu=True, 110 | is_training=is_training, 111 | use_global_status=use_global_status) 112 | 113 | # Build the 3rd convolutional layer (increase the channels). 114 | x = nn.conv(x, 115 | name='conv3', 116 | filters=filters*4, 117 | kernel_size=1, 118 | strides=1, 119 | padding='SAME', 120 | biased=False, 121 | bn=True, 122 | relu=False, 123 | is_training=is_training, 124 | use_global_status=use_global_status) 125 | 126 | # Fuse the convolutional outputs with residual connection. 127 | x = tf.add_n([x, shortcut], name='add') 128 | x = tf.nn.relu(x, name='relu') 129 | 130 | return x 131 | 132 | 133 | def resnet_v1(x, 134 | name, 135 | filters=[64,128,256,512], 136 | num_blocks=[3,4,23,3], 137 | strides=[2,1,1,1], 138 | dilations=[None, None, 2, 2], 139 | is_training=True, 140 | use_global_status=True, 141 | reuse=False): 142 | """Helper function to build ResNet. 143 | 144 | Args: 145 | x: A tensor of size [batch_size, height_in, width_in, channels]. 146 | name: The prefix of tensorflow variables defined in this network. 147 | filters: A list of numbers indicating the number of output channels 148 | (the actual output channels are 4 times these numbers). 149 | strides: A list of numbers indicating the stride of the sliding window for 150 | height and width. 151 | dilations: A list of numbers or None indicating the dilation factor for each block. 152 | is_training: If the tensorflow variables defined in this layer 153 | would be used for training. 154 | use_global_status: enable/disable use_global_status for batch 155 | normalization. If True, moving mean and moving variance are updated 156 | by exponential decay. 157 | reuse: enable/disable reuse for reusing tensorflow variables. It is 158 | useful for sharing weight parameters across two identical networks. 159 | 160 | Returns: 161 | A tensor of size [batch_size, height_out, width_out, channels_out]. 162 | """ 163 | if len(filters) != len(num_blocks) or len(filters) != len(strides): 164 | raise ValueError('lengths of the lists are not consistent') 165 | 166 | with tf.variable_scope(name, reuse=reuse) as scope: 167 | # Build conv1. 168 | x = nn.conv(x, 169 | name='conv1', 170 | filters=64, 171 | kernel_size=7, 172 | strides=2, 173 | padding='VALID', 174 | biased=False, 175 | bn=True, 176 | relu=True, 177 | is_training=is_training, 178 | use_global_status=use_global_status) 179 | 180 | # Build pool1. 181 | x = nn.max_pool(x, 182 | name='pool1', 183 | kernel_size=3, 184 | strides=2, 185 | padding='VALID') 186 | 187 | # Build residual bottleneck blocks. 188 | for ib in range(len(filters)): 189 | for iu in range(num_blocks[ib]): 190 | name_format = 'block{:d}/unit_{:d}/bottleneck_v1' 191 | block_name = name_format.format(ib+1, iu+1) 192 | 193 | c_o = filters[ib] # output channel 194 | # Apply strides to the last unit of each block. 
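        # (Stride sits in the last unit of each block here, matching the
        # variable layout of TF-slim resnet_v1 checkpoints; note that many
        # other ResNet implementations place the stride in the first unit.)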
195 | s = strides[ib] if iu == num_blocks[ib]-1 else 1 196 | d = dilations[ib] 197 | x = bottleneck(x, 198 | name=block_name, 199 | filters=c_o, 200 | strides=s, 201 | dilation=d, 202 | is_training=is_training, 203 | use_global_status=use_global_status) 204 | 205 | return x 206 | 207 | 208 | def resnet_v1_101(x, 209 | name, 210 | is_training, 211 | use_global_status, 212 | reuse=False): 213 | """Builds ResNet101 v1. 214 | 215 | Args: 216 | x: A tensor of size [batch_size, height_in, width_in, channels]. 217 | name: The prefix of tensorflow variables defined in this network. 218 | is_training: If the tensorflow variables defined in this layer 219 | would be used for training. 220 | use_global_status: enable/disable use_global_status for batch 221 | normalization. If True, moving mean and moving variance are updated 222 | by exponential decay. 223 | reuse: enable/disable reuse for reusing tensorflow variables. It is 224 | useful for sharing weight parameters across two identical networks. 225 | 226 | Returns: 227 | A tensor of size [batch_size, height_out, width_out, channels_out]. 228 | """ 229 | return resnet_v1(x, 230 | name=name, 231 | filters=[64,128,256,512], 232 | num_blocks=[3,4,23,3], 233 | strides=[2,1,1,1], 234 | dilations=[None, None, 2, 4], 235 | is_training=is_training, 236 | use_global_status=use_global_status, 237 | reuse=reuse) 238 | -------------------------------------------------------------------------------- /network/multigpu/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/network/multigpu/__init__.py -------------------------------------------------------------------------------- /network/multigpu/resnet_v1.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | import network.multigpu.layers as nn_mgpu 4 | 5 | 6 | def bottleneck(xs, 7 | name, 8 | filters, 9 | strides=None, 10 | dilation=None, 11 | is_training=True, 12 | use_global_status=True): 13 | """Builds the bottleneck module in ResNet. 14 | 15 | This function stacks 3 convolutional layers and fuses the output with 16 | the residual connection. 17 | 18 | Args: 19 | xs: A list of tensors, in which each tensor is of size 20 | [batch_size, height_in, width_in, channels]. 21 | name: The prefix of tensorflow variables defined in this layer. 22 | filters: A number indicating the number of output channels. 23 | strides: A number indicating the stride of the sliding window for 24 | height and width. 25 | dilation: A number indicating the dilation factor for height and width. 26 | is_training: If the tensorflow variables defined in this layer 27 | would be used for training. 28 | use_global_status: enable/disable use_global_status for batch 29 | normalization. If True, moving mean and moving variance are updated 30 | by exponential decay. 31 | 32 | Returns: 33 | A list of tensors, in which each tensor is of size 34 | [batch_size, height_out, width_out, channels_out]. 
35 | """ 36 | 37 | if strides is None and dilation is None: 38 | raise ValueError('None of strides or dilation is specified, ' 39 | +'set one of them to 1 or bigger number.') 40 | elif strides > 1 and dilation is not None and dilation > 1: 41 | raise ValueError('strides and dilation are both specified, ' 42 | +'set one of them to 1 or None.') 43 | 44 | with tf.variable_scope(name) as scope: 45 | c_i = xs[0].get_shape().as_list()[-1] 46 | 47 | if c_i != filters*4: 48 | # Use a convolutional layer as residual connection when the 49 | # number of input channels is different from output channels. 50 | shortcuts = nn_mgpu.conv(xs, name='shortcut', 51 | filters=filters*4, 52 | kernel_size=1, 53 | strides=strides, 54 | padding='VALID', 55 | biased=False, 56 | bn=True, relu=False, 57 | is_training=is_training, 58 | use_global_status=use_global_status) 59 | elif strides > 1: 60 | # Use max-pooling as residual connection when the number of 61 | # input channel is same as output channels, but stride is 62 | # larger than 1. 63 | shortcuts = nn_mgpu.max_pool(xs, 64 | name='shortcut', 65 | kernel_size=1, 66 | strides=strides, 67 | padding='VALID') 68 | else: 69 | # Otherwise, keep the original input as residual connection. 70 | shortcuts = [x for x in xs] 71 | 72 | # Build the 1st convolutional layer. 73 | xs = nn_mgpu.conv(xs, 74 | name='conv1', 75 | filters=filters, 76 | kernel_size=1, 77 | strides=1, 78 | padding='SAME', 79 | biased=False, 80 | bn=True, 81 | relu=True, 82 | is_training=is_training, 83 | use_global_status=use_global_status) 84 | 85 | if dilation is not None and dilation > 1: 86 | # If dilation > 1, apply atrous conv to the 2nd convolutional layer. 87 | xs = nn_mgpu.atrous_conv( 88 | xs, 89 | name='conv2', 90 | filters=filters, 91 | kernel_size=3, 92 | dilation=dilation, 93 | padding='SAME', 94 | biased=False, 95 | bn=True, 96 | relu=True, 97 | is_training=is_training, 98 | use_global_status=use_global_status) 99 | else: 100 | padding = 'VALID' if strides > 1 else 'SAME' 101 | xs = nn_mgpu.conv( 102 | xs, 103 | name='conv2', 104 | filters=filters, 105 | kernel_size=3, 106 | strides=strides, 107 | padding=padding, 108 | biased=False, 109 | bn=True, 110 | relu=True, 111 | is_training=is_training, 112 | use_global_status=use_global_status) 113 | 114 | # Build the 3rd convolutional layer (increase the channels). 115 | xs = nn_mgpu.conv(xs, 116 | name='conv3', 117 | filters=filters*4, 118 | kernel_size=1, 119 | strides=1, 120 | padding='SAME', 121 | biased=False, 122 | bn=True, 123 | relu=False, 124 | is_training=is_training, 125 | use_global_status=use_global_status) 126 | 127 | # Fuse the convolutional outputs with residual connection. 128 | outputs = [] 129 | for x,shortcut in zip(xs, shortcuts): 130 | with tf.device(x.device): 131 | x = tf.add_n([x, shortcut], name='add') 132 | x = tf.nn.relu(x, name='relu') 133 | outputs.append(x) 134 | 135 | return outputs 136 | 137 | 138 | def resnet_v1(xs, 139 | name, 140 | filters=[64,128,256,512], 141 | num_blocks=[3,4,23,3], 142 | strides=[2,1,1,1], 143 | dilations=[None, None, 2, 2], 144 | is_training=True, 145 | use_global_status=True, 146 | reuse=False): 147 | """Helper function to build ResNet. 148 | 149 | Args: 150 | xs: A list of tensor, in which each tensor is of size 151 | [batch_size, height_in, width_in, channels]. 152 | name: The prefix of tensorflow variables defined in this network. 153 | filters: A list of numbers indicating the number of output channels 154 | (The output channels would be 4 times to the numbers). 
155 | strides: A list of numbers indicating the stride of the sliding window for 156 | height and width. 157 | dilations: A list of numbers or None indicating the dilation factor for each block. 158 | is_training: If the tensorflow variables defined in this layer 159 | would be used for training. 160 | use_global_status: enable/disable use_global_status for batch 161 | normalization. If True, moving mean and moving variance are updated 162 | by exponential decay. 163 | reuse: enable/disable reuse for reusing tensorflow variables. It is 164 | useful for sharing weight parameters across two identical networks. 165 | 166 | Returns: 167 | A list of tensors, in which each tensor is of size 168 | [batch_size, height_out, width_out, channels_out]. 169 | """ 170 | if len(filters) != len(num_blocks)\ 171 | or len(filters) != len(strides): 172 | raise ValueError('lengths of the lists are not consistent') 173 | 174 | with tf.variable_scope(name, reuse=reuse) as scope: 175 | # Build conv1. 176 | xs = nn_mgpu.conv(xs, 177 | name='conv1', 178 | filters=64, 179 | kernel_size=7, 180 | strides=2, 181 | padding='VALID', 182 | biased=False, 183 | bn=True, 184 | relu=True, 185 | is_training=is_training, 186 | use_global_status=use_global_status) 187 | 188 | # Build pool1. 189 | xs = nn_mgpu.max_pool(xs, 190 | name='pool1', 191 | kernel_size=3, 192 | strides=2, 193 | padding='VALID') 194 | 195 | # Build residual bottleneck blocks. 196 | for ib in range(len(filters)): 197 | for iu in range(num_blocks[ib]): 198 | name_format = 'block{:d}/unit_{:d}/bottleneck_v1' 199 | block_name = name_format.format(ib+1, iu+1) 200 | 201 | c_o = filters[ib] # output channel 202 | # Apply strides to the last unit of each block. 203 | s = strides[ib] if iu == num_blocks[ib]-1 else 1 204 | d = dilations[ib] 205 | xs = bottleneck(xs, 206 | name=block_name, 207 | filters=c_o, 208 | strides=s, 209 | dilation=d, 210 | is_training=is_training, 211 | use_global_status=use_global_status) 212 | 213 | return xs 214 | 215 | 216 | def resnet_v1_101(xs, 217 | name, 218 | is_training, 219 | use_global_status, 220 | reuse=False): 221 | """Builds ResNet101 v1. 222 | 223 | Args: 224 | xs: A list of tensors, in which each tensor is of size 225 | [batch_size, height_in, width_in, channels]. 226 | name: The prefix of tensorflow variables defined in this network. 227 | is_training: If the tensorflow variables defined in this layer 228 | would be used for training. 229 | use_global_status: enable/disable use_global_status for batch 230 | normalization. If True, moving mean and moving variance are updated 231 | by exponential decay. 232 | reuse: enable/disable reuse for reusing tensorflow variables. It is 233 | useful for sharing weight parameters across two identical networks. 234 | 235 | Returns: 236 | A list of tensors, in which each tensor is of size 237 | [batch_size, height_out, width_out, channels_out]. 238 | """ 239 | return resnet_v1(xs, 240 | name=name, 241 | filters=[64,128,256,512], 242 | num_blocks=[3,4,23,3], 243 | strides=[2,1,1,1], 244 | dilations=[None, None, 2, 4], 245 | is_training=is_training, 246 | use_global_status=use_global_status, 247 | reuse=reuse) 248 | -------------------------------------------------------------------------------- /network/multigpu/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def on_each_gpu(func): 5 | """A decorator which performs func independently on each GPU. 6 | 7 | This function will call func on each GPU, which is useful for 8 | multi-gpu computations. 
The decorator takes a list of tensor 9 | as inputs. See examples in network/multigpu/layers.py. 10 | 11 | Args: 12 | func: A tensorflow operation which is performed on a single GPU, and 13 | the function takes a tensor as input. 14 | """ 15 | def inner(*args, **kwargs): 16 | xs = args[0] 17 | assert(isinstance(xs, list)) 18 | outputs = [] 19 | for x in xs: 20 | with tf.device(x.device): 21 | outputs.append(func(x, *args[1:], **kwargs)) 22 | 23 | return outputs 24 | 25 | return inner 26 | -------------------------------------------------------------------------------- /network/segsort/common_utils.py: -------------------------------------------------------------------------------- 1 | """Common utility functions for SegSort.""" 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | 7 | def normalize_embedding(embedding): 8 | """Normalizes embedding by L2 norm. 9 | 10 | This function is used to normalize embedding so that the embedding features 11 | lie on a unit hypersphere. 12 | 13 | Args: 14 | embedding: An N-D float tensor with feature embedding in the last dimension. 15 | 16 | Returns: 17 | An N-D float tensor with the same shape as input embedding with feature 18 | embedding normalized by L2 norm in the last dimension. 19 | """ 20 | return embedding / tf.norm(embedding, axis=-1, keep_dims=True) 21 | 22 | 23 | def calculate_prototypes_from_labels(embedding, 24 | labels, 25 | max_label=None): 26 | """Calculates prototypes from labels. 27 | 28 | This function calculates prototypes (mean direction) from embedding features 29 | for each label. This function is also used as the m-step in k-means 30 | clustering. 31 | 32 | Args: 33 | embedding: A 2-D or 4-D float tensor with feature embedding in the last 34 | dimension (embedding_dim). 35 | labels: An N-D int32 label map for each embedding pixel. 36 | max_label: The maximum value of the label map. Calculated on-the-fly if not 37 | specified. 38 | 39 | Returns: 40 | A 2-D float tensor with shape `[num_prototypes, embedding_dim]`. 41 | """ 42 | embedding = tf.reshape(embedding, [-1, tf.shape(embedding)[-1]]) 43 | labels = tf.reshape(labels, [-1]) 44 | 45 | if max_label is None: 46 | max_label = tf.reduce_max(labels) + 1 47 | one_hot_labels = tf.one_hot(labels, tf.cast(max_label, tf.int32)) 48 | prototypes = tf.matmul(one_hot_labels, embedding, transpose_a=True) 49 | prototypes = normalize_embedding(prototypes) 50 | return prototypes 51 | 52 | 53 | def find_nearest_prototypes(embedding, prototypes): 54 | """Finds the nearest prototype for each embedding pixel. 55 | 56 | This function calculates the index of nearest prototype for each embedding 57 | pixel. This function is also used as the e-step in k-means clustering. 58 | 59 | Args: 60 | embedding: An N-D float tensor with embedding features in the last 61 | dimension (embedding_dim). 62 | prototypes: A 2-D float tensor with shape `[num_prototypes, embedding_dim]`. 63 | 64 | Returns: 65 | A 1-D int32 tensor with length `[num_pixels]` containing the index of the 66 | nearest prototype for each pixel. 67 | """ 68 | embedding = tf.reshape(embedding, [-1, tf.shape(prototypes)[-1]]) 69 | similarities = tf.matmul(embedding, prototypes, transpose_b=True) 70 | return tf.argmax(similarities, axis=1) 71 | 72 | 73 | def kmeans_with_initial_labels(embedding, 74 | initial_labels, 75 | max_label=None, 76 | iterations=10): 77 | """Performs the von-Mises Fisher k-means clustering with initial labels. 78 | 79 | Args: 80 | embedding: A 2-D float tensor with shape `[num_pixels, embedding_dim]`. 
81 | initial_labels: A 1-D integer tensor with length [num_pixels]. K-means 82 | clustering will start from these cluster labels. 83 | max_label: An integer for the maximum number of labels. Calculated on-the-fly if not specified. 84 | iterations: Number of iterations for the k-means clustering. 85 | 86 | Returns: 87 | A 1-D integer tensor of the cluster label for each pixel. 88 | """ 89 | if max_label is None: 90 | max_label = tf.reduce_max(initial_labels) + 1 91 | labels = initial_labels 92 | for _ in range(iterations): 93 | # M-step of the vMF k-means clustering. 94 | prototypes = calculate_prototypes_from_labels(embedding, labels, max_label) 95 | # E-step of the vMF k-means clustering. 96 | labels = find_nearest_prototypes(embedding, prototypes) 97 | return labels 98 | 99 | 100 | def kmeans(embedding, num_clusters, iterations=10): 101 | """Performs the von-Mises Fisher k-means clustering. 102 | 103 | Args: 104 | embedding: A 4-D float tensor with shape 105 | `[batch, height, width, embedding_dim]`. 106 | num_clusters: A list of 2 integers for number of clusters in y and x axes. 107 | iterations: Number of iterations for the k-means clustering. 108 | 109 | Returns: 110 | A 3-D integer tensor of the cluster label for each pixel with shape 111 | `[batch, height, width]`. 112 | """ 113 | # shape = embedding.get_shape().as_list() 114 | shape = tf.shape(embedding) 115 | labels = initialize_cluster_labels(num_clusters, [shape[1], shape[2]]) 116 | 117 | embedding = tf.reshape(embedding, [-1, shape[3]]) 118 | labels = tf.reshape(labels, [-1]) 119 | 120 | labels = kmeans_with_initial_labels(embedding, labels, iterations=iterations) 121 | 122 | labels = tf.reshape(labels, [shape[0], shape[1], shape[2]]) 123 | return labels 124 | 125 | 126 | def initialize_cluster_labels(num_clusters, img_dimensions): 127 | """Initializes uniform cluster labels for an image. 128 | 129 | This function is used to initialize cluster labels that uniformly partition 130 | a 2-D image. 131 | 132 | Args: 133 | num_clusters: A list of 2 integers for number of clusters in y and x axes. 134 | img_dimensions: A list of 2 integers for image's y and x dimension. 135 | 136 | Returns: 137 | A 2-D int32 tensor with shape specified by img_dimensions. 138 | """ 139 | yx_range = tf.cast(tf.ceil(tf.cast(img_dimensions, tf.float32) / 140 | tf.cast(num_clusters, tf.float32)), tf.int32) 141 | y_labels = tf.reshape(tf.range(img_dimensions[0]) // yx_range[0], [-1, 1]) 142 | x_labels = tf.reshape(tf.range(img_dimensions[1]) // yx_range[1], [1, -1]) 143 | labels = y_labels + (tf.reduce_max(y_labels) + 1) * x_labels 144 | return labels 145 | 146 | 147 | def generate_location_features(img_dimensions, feature_type='int'): 148 | """Calculates location features for an image. 149 | 150 | This function generates location features for an image. With the 'float' 151 | feature type, the 2-D location features range from 0 to 1 for y and x axes each. 152 | 153 | Args: 154 | img_dimensions: A list of 2 integers for image's y and x dimension. 155 | feature_type: The data type of location features, integer or float. 156 | 157 | Returns: 158 | A 3-D int32 or float32 tensor (matching feature_type) with shape `[img_dimensions[0], img_dimensions[1], 2]`. 159 | 160 | Raises: 161 | ValueError: Type of location features is neither 'int' nor 'float'.
162 | """ 163 | if feature_type == 'int': 164 | y_features = tf.range(img_dimensions[0]) 165 | x_features = tf.range(img_dimensions[1]) 166 | elif feature_type == 'float': 167 | y_features = (tf.range(img_dimensions[0], dtype=tf.float32) / 168 | img_dimensions[0]) 169 | x_features = (tf.range(img_dimensions[1], dtype=tf.float32) / 170 | img_dimensions[1]) 171 | else: 172 | raise ValueError('Type of location features should be either int or float.') 173 | 174 | x_features, y_features = tf.meshgrid(x_features, y_features) 175 | location_features = tf.stack([y_features, x_features], axis=2) 176 | return location_features 177 | 178 | 179 | def generate_location_features_np(img_dimensions): 180 | y_features = np.linspace(0, 1, img_dimensions[0]) 181 | x_features = np.linspace(0, 1, img_dimensions[1]) 182 | 183 | x_features, y_features = np.meshgrid(x_features, y_features) 184 | location_features = np.expand_dims(np.stack([y_features, x_features], axis=2), 0) 185 | return location_features 186 | 187 | def prepare_prototype_labels(semantic_labels, instance_labels, offset=256): 188 | """Prepares prototype labels from semantic and instance labels. 189 | 190 | This function generates unique prototype labels from semantic and instance 191 | labels. Note that instance labels sometimes can be cluster labels. 192 | 193 | Args: 194 | semantic_labels: A 1-D integer tensor for semantic labels. 195 | instance_labels: A 1-D integer tensor for instance labels. 196 | offset: An integer for instance offset. 197 | 198 | Returns: 199 | unique_instance_labels: A 1-D integer tensor for unique instance labels with 200 | the same length as the input semantic labels. 201 | prototype_labels: A 1-D integer tensor for the semantic labels of 202 | prototypes with length as the number of unique instances. 203 | """ 204 | instance_labels = tf.cast(instance_labels, tf.int64) 205 | semantic_labels = tf.cast(semantic_labels, tf.int64) 206 | prototype_labels, unique_instance_labels = tf.unique( 207 | tf.reshape(semantic_labels + instance_labels * offset, [-1])) 208 | 209 | unique_instance_labels = tf.cast(unique_instance_labels, tf.int32) 210 | prototype_labels = tf.cast(prototype_labels % offset, tf.int32) 211 | return unique_instance_labels, prototype_labels 212 | -------------------------------------------------------------------------------- /network/segsort/vis_utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions for visualizing embeddings.""" 2 | 3 | import tensorflow as tf 4 | 5 | 6 | def calculate_principal_components(embedding, num_components=3): 7 | """Calculates the principal components given the embedding features. 8 | 9 | Args: 10 | embedding: A 2-D float tensor with embedding features in the last dimension. 11 | num_components: the number of principal components to return. 12 | 13 | Returns: 14 | A 2-D float tensor with principal components in the last dimension. 15 | """ 16 | embedding -= tf.reduce_mean(embedding, axis=0, keep_dims=True) 17 | sigma = tf.matmul(embedding, embedding, transpose_a=True) 18 | _, u, _ = tf.svd(sigma) 19 | return u[:, :num_components] 20 | 21 | 22 | def pca(embedding, num_components=3, principal_components=None): 23 | """Conducts principal component analysis on the embedding features. 24 | 25 | This function is used to reduce the dimensionality of the embedding, so that 26 | we can visualize the embedding as an RGB image. 
27 | 28 | Args: 29 | embedding: A 4-D float tensor with shape 30 | [batch, height, width, embedding_dims]. 31 | num_components: The number of principal components to be reduced to. 32 | principal_components: A 2-D float tensor used to convert the embedding 33 | features to PCA'ed space, also known as the U matrix from SVD. If not 34 | given, this function will calculate the principal_components given inputs. 35 | 36 | Returns: 37 | A 4-D float tensor with shape [batch, height, width, num_components]. 38 | """ 39 | # shape = embedding.get_shape().as_list() 40 | shape = tf.shape(embedding) 41 | embedding = tf.reshape(embedding, [-1, shape[3]]) 42 | 43 | if principal_components is None: 44 | principal_components = calculate_principal_components(embedding, 45 | num_components) 46 | embedding = tf.matmul(embedding, principal_components) 47 | 48 | embedding = tf.reshape(embedding, 49 | [shape[0], shape[1], shape[2], num_components]) 50 | return embedding 51 | -------------------------------------------------------------------------------- /pyscripts/benchmark/benchmark_by_mIoU.py: -------------------------------------------------------------------------------- 1 | """Script for benchmarking semantic segmentation results by mIoU.""" 2 | import argparse 3 | import os 4 | 5 | import numpy as np 6 | 7 | from PIL import Image 8 | from utils.metrics import iou_stats 9 | 10 | 11 | parser = argparse.ArgumentParser( 12 | description='Benchmark segmentation predictions' 13 | ) 14 | parser.add_argument('--pred_dir', type=str, default='', 15 | help='/path/to/prediction.') 16 | parser.add_argument('--gt_dir', type=str, default='', 17 | help='/path/to/ground-truths') 18 | parser.add_argument('--num_classes', type=int, default=21, 19 | help='number of segmentation classes') 20 | parser.add_argument('--string_replace', type=str, default=',', 21 | help='replace the first string with the second one') 22 | args = parser.parse_args() 23 | 24 | 25 | assert(os.path.isdir(args.pred_dir)) 26 | assert(os.path.isdir(args.gt_dir)) 27 | tp_fn = np.zeros(args.num_classes, dtype=np.float64) 28 | tp_fp = np.zeros(args.num_classes, dtype=np.float64) 29 | tp = np.zeros(args.num_classes, dtype=np.float64) 30 | for dirpath, dirnames, filenames in os.walk(args.pred_dir): 31 | for filename in filenames: 32 | predname = os.path.join(dirpath, filename) 33 | gtname = predname.replace(args.pred_dir, args.gt_dir) 34 | if args.string_replace != '': 35 | stra, strb = args.string_replace.split(',') 36 | gtname = gtname.replace(stra, strb) 37 | 38 | pred = np.asarray( 39 | Image.open(predname).convert(mode='L'), 40 | dtype=np.uint8) 41 | gt = np.asarray( 42 | Image.open(gtname).convert(mode='L'), 43 | dtype=np.uint8) 44 | _tp_fn, _tp_fp, _tp = iou_stats( 45 | pred, 46 | gt, 47 | num_classes=args.num_classes, 48 | background=0) 49 | 50 | tp_fn += _tp_fn 51 | tp_fp += _tp_fp 52 | tp += _tp 53 | 54 | iou = tp / (tp_fn + tp_fp - tp + 1e-12) * 100.0 55 | 56 | class_names = ['Background', 'Aero', 'Bike', 'Bird', 'Boat', 57 | 'Bottle', 'Bus', 'Car', 'Cat', 'Chair','Cow', 58 | 'Table', 'Dog', 'Horse' ,'MBike', 'Person', 59 | 'Plant', 'Sheep', 'Sofa', 'Train', 'TV'] 60 | 61 | for i in range(args.num_classes): 62 | if i >= len(class_names): 63 | break 64 | print('class {:10s}: {:02d}, acc: {:4.4f}%'.format( 65 | class_names[i], i, iou[i])) 66 | mean_iou = iou.sum() / args.num_classes 67 | print('mean IOU: {:4.4f}%'.format(mean_iou)) 68 | 69 | mean_pixel_acc = tp.sum() / (tp_fp.sum() + 1e-12) 70 | print('mean Pixel Acc: 
{:4.4f}%'.format(mean_pixel_acc)) 71 | -------------------------------------------------------------------------------- /pyscripts/inference/extract_prototypes.py: -------------------------------------------------------------------------------- 1 | """Inference script for extracting segment prototypes with SegSort.""" 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import math 7 | import os 8 | 9 | import network.segsort.common_utils as common_utils 10 | import network.segsort.eval_utils as eval_utils 11 | import numpy as np 12 | import tensorflow as tf 13 | 14 | from seg_models.models.pspnet import pspnet_resnet101 as model 15 | from seg_models.image_reader import SegSortImageReader 16 | from tqdm import tqdm 17 | 18 | 19 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 20 | 21 | 22 | def get_arguments(): 23 | """Parses all the arguments provided from the CLI. 24 | 25 | Returns: 26 | A list of parsed arguments. 27 | """ 28 | parser = argparse.ArgumentParser( 29 | description='Extracting Prototypes for Semantic Segmentation') 30 | parser.add_argument('--data_dir', type=str, default='', 31 | help='/path/to/dataset.') 32 | parser.add_argument('--data_list', type=str, default='', 33 | help='/path/to/datalist/file.') 34 | parser.add_argument('--input_size', type=str, default='512,512', 35 | help='Comma-separated string with H and W of image.') 36 | parser.add_argument('--strides', type=str, default='512,512', 37 | help='Comma-separated string with strides of H and W.') 38 | parser.add_argument('--num_classes', type=int, default=21, 39 | help='Number of classes to predict.') 40 | parser.add_argument('--ignore_label', type=int, default=255, 41 | help='Index of label to ignore.') 42 | parser.add_argument('--restore_from', type=str, default='', 43 | help='Where to restore model parameters from.') 44 | parser.add_argument('--save_dir', type=str, default='', 45 | help='/path/to/save/predictions.') 46 | parser.add_argument('--colormap', type=str, default='', 47 | help='/path/to/colormap/file.') 48 | # SegSort parameters. 49 | parser.add_argument('--embedding_dim', type=int, default=32, 50 | help='Dimension of the feature embeddings.') 51 | parser.add_argument('--num_clusters', type=int, default=5, 52 | help='Number of kmeans clusters along each axis.') 53 | parser.add_argument('--kmeans_iterations', type=int, default=10, 54 | help='Number of kmeans iterations.') 55 | 56 | return parser.parse_args() 57 | 58 | 59 | def load(saver, sess, ckpt_path): 60 | """Loads the trained weights. 61 | 62 | Args: 63 | saver: TensorFlow saver object. 64 | sess: TensorFlow session. 65 | ckpt_path: path to checkpoint file with parameters. 66 | """ 67 | saver.restore(sess, ckpt_path) 68 | print('Restored model parameters from {}'.format(ckpt_path)) 69 | 70 | 71 | def parse_commastr(str_comma): 72 | """Reads a comma-separated string.""" 73 | if '' == str_comma: 74 | return None 75 | else: 76 | a, b = map(int, str_comma.split(',')) 77 | 78 | return [a,b] 79 | 80 | def main(): 81 | """Creates the model and starts the inference process.""" 82 | args = get_arguments() 83 | 84 | # Parse image processing arguments. 85 | input_size = parse_commastr(args.input_size) 86 | strides = parse_commastr(args.strides) 87 | assert(input_size is not None and strides is not None) 88 | h, w = input_size 89 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 90 | 91 | 92 | # Create queue coordinator. 93 | coord = tf.train.Coordinator() 94 | 95 | # Load the data reader.
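# The reader yields one image and label at a time; with random scale,
# mirror and crop all disabled, inputs only receive mean subtraction
# (IMG_MEAN). The innet_size computed above assumes an output stride of
# 8, e.g. a 512 x 512 crop maps to ceil(512 / 8) = 64 features along
# each spatial axis.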
96 | with tf.name_scope('create_inputs'): 97 | reader = SegSortImageReader( 98 | args.data_dir, 99 | args.data_list, 100 | None, 101 | False, # No random scale 102 | False, # No random mirror 103 | False, # No random crop, center crop instead 104 | args.ignore_label, 105 | IMG_MEAN) 106 | 107 | image = reader.image 108 | label = reader.label 109 | image_list = reader.image_list 110 | image_batch = tf.expand_dims(image, dim=0) 111 | label_batch = tf.expand_dims(label, dim=0) 112 | 113 | # Create input tensor to the Network. 114 | crop_image_batch = tf.placeholder( 115 | name='crop_image_batch', 116 | shape=[1,input_size[0],input_size[1],3], 117 | dtype=tf.float32) 118 | 119 | # Create network and output prediction. 120 | outputs = model(crop_image_batch, 121 | args.embedding_dim, 122 | False, 123 | True) 124 | 125 | # Grab variable names which should be restored from checkpoints. 126 | restore_var = [ 127 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 128 | 129 | # Output predictions. 130 | output = outputs[0] 131 | output = tf.image.resize_bilinear( 132 | output, 133 | [input_size[0], input_size[1]]) 134 | 135 | # Input full-sized embedding. 136 | label_input = tf.placeholder( 137 | tf.int32, shape=[1, None, None, 1]) 138 | embedding_input = tf.placeholder( 139 | tf.float32, shape=[1, None, None, args.embedding_dim]) 140 | embedding = common_utils.normalize_embedding(embedding_input) 141 | loc_feature = tf.placeholder( 142 | tf.float32, shape=[1, None, None, 2]) 143 | 144 | # Combine embedding with location features and kmeans. 145 | shape = tf.shape(embedding) 146 | cluster_labels = common_utils.initialize_cluster_labels( 147 | [args.num_clusters, args.num_clusters], 148 | [shape[1], shape[2]]) 149 | embedding = tf.reshape(embedding, [-1, args.embedding_dim]) 150 | labels = tf.reshape(label_input, [-1]) 151 | cluster_labels = tf.reshape(cluster_labels, [-1]) 152 | location_features = tf.reshape(loc_feature, [-1, 2]) 153 | 154 | # Extract prototype features and labels from embeddings. 155 | (prototype_features, 156 | prototype_labels, 157 | _) = eval_utils.extract_trained_prototypes( 158 | embedding, location_features, cluster_labels, 159 | args.num_clusters * args.num_clusters, 160 | args.kmeans_iterations, labels, 161 | 1, args.ignore_label, 162 | 'semantic') 163 | 164 | # Set up tf session and initialize variables. 165 | config = tf.ConfigProto() 166 | config.gpu_options.allow_growth = True 167 | sess = tf.Session(config=config) 168 | init = tf.global_variables_initializer() 169 | 170 | sess.run(init) 171 | sess.run(tf.local_variables_initializer()) 172 | 173 | # Load weights. 174 | loader = tf.train.Saver(var_list=restore_var) 175 | if args.restore_from is not None: 176 | load(loader, sess, args.restore_from) 177 | 178 | # Start queue threads. 179 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 180 | 181 | # Create directory for saving prototypes. 182 | save_dir = os.path.join(args.save_dir, 'prototypes') 183 | if not os.path.isdir(save_dir): 184 | os.makedirs(save_dir) 185 | 186 | # Iterate over testing steps. 
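# num_steps below is simply the number of non-empty lines in the data
# list: a list file with, say, 1449 entries and a trailing newline
# splits into 1450 pieces, so len(...) - 1 recovers 1449 images.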
187 | with open(args.data_list, 'r') as listf: 188 | num_steps = len(listf.read().split('\n'))-1 189 | 190 | 191 | pbar = tqdm(range(num_steps)) 192 | for step in pbar: 193 | image_batch_np, label_batch_np = sess.run( 194 | [image_batch, label_batch]) 195 | 196 | img_size = image_batch_np.shape 197 | padded_img_size = list(img_size) # deep copy of img_size 198 | 199 | if input_size[0] > padded_img_size[1]: 200 | padded_img_size[1] = input_size[0] 201 | if input_size[1] > padded_img_size[2]: 202 | padded_img_size[2] = input_size[1] 203 | padded_img_batch = np.zeros(padded_img_size, 204 | dtype=np.float32) 205 | img_h, img_w = img_size[1:3] 206 | padded_img_batch[:, :img_h, :img_w, :] = image_batch_np 207 | 208 | stride_h, stride_w = strides 209 | npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1 210 | npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1 211 | 212 | # Create the ending index of each patch. 213 | patch_indh = np.linspace( 214 | input_size[0], padded_img_size[1], npatches_h, dtype=np.int32) 215 | patch_indw = np.linspace( 216 | input_size[1], padded_img_size[2], npatches_w, dtype=np.int32) 217 | 218 | # Create embedding holder. 219 | padded_img_size[-1] = args.embedding_dim 220 | embedding_all_np = np.zeros(padded_img_size, 221 | dtype=np.float32) 222 | for indh in patch_indh: 223 | for indw in patch_indw: 224 | sh, eh = indh-input_size[0], indh # start & end ind of H 225 | sw, ew = indw-input_size[1], indw # start & end ind of W 226 | cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :] 227 | 228 | embedding_np = sess.run(output, feed_dict={ 229 | crop_image_batch: cropimg_batch}) 230 | embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np 231 | 232 | embedding_all_np = embedding_all_np[:, :img_h, :img_w, :] 233 | loc_feature_np = common_utils.generate_location_features_np([padded_img_size[1], padded_img_size[2]]) 234 | feed_dict = {label_input: label_batch_np, 235 | embedding_input: embedding_all_np, 236 | loc_feature: loc_feature_np} 237 | 238 | (batch_prototype_features_np, 239 | batch_prototype_labels_np) = sess.run( 240 | [prototype_features, prototype_labels], 241 | feed_dict=feed_dict) 242 | 243 | if step == 0: 244 | prototype_features_np = batch_prototype_features_np 245 | prototype_labels_np = batch_prototype_labels_np 246 | else: 247 | prototype_features_np = np.concatenate( 248 | [prototype_features_np, batch_prototype_features_np], axis=0) 249 | prototype_labels_np = np.concatenate( 250 | [prototype_labels_np, 251 | batch_prototype_labels_np], axis=0) 252 | 253 | 254 | print ('Total number of prototypes extracted: ', 255 | len(prototype_labels_np)) 256 | np.save( 257 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'), 258 | mode='w'), prototype_features_np) 259 | np.save( 260 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'), 261 | mode='w'), prototype_labels_np) 262 | 263 | 264 | coord.request_stop() 265 | coord.join(threads) 266 | 267 | if __name__ == '__main__': 268 | main() 269 | -------------------------------------------------------------------------------- /pyscripts/inference/inference.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | import scipy.io 11 | import scipy.misc 12 | from PIL import Image 13 | 14 | from seg_models.models.pspnet import pspnet_resnet101 as model 15 | from 
seg_models.image_reader import ImageReader 16 | import utils.general 17 | 18 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 19 | 20 | 21 | def get_arguments(): 22 | """Parse all the arguments provided from the CLI. 23 | 24 | Returns: 25 | A list of parsed arguments. 26 | """ 27 | parser = argparse.ArgumentParser( 28 | description='Inference for Semantic Segmentation') 29 | parser.add_argument('--data-dir', type=str, default='', 30 | help='/path/to/dataset.') 31 | parser.add_argument('--data-list', type=str, default='', 32 | help='/path/to/datalist/file.') 33 | parser.add_argument('--input-size', type=str, default='512,512', 34 | help='Comma-separated string with H and W of image.') 35 | parser.add_argument('--strides', type=str, default='512,512', 36 | help='Comma-separated string with strides of H and W.') 37 | parser.add_argument('--num-classes', type=int, default=21, 38 | help='Number of classes to predict.') 39 | parser.add_argument('--ignore-label', type=int, default=255, 40 | help='Index of label to ignore.') 41 | parser.add_argument('--restore-from', type=str, default='', 42 | help='Where to restore model parameters from.') 43 | parser.add_argument('--save-dir', type=str, default='', 44 | help='/path/to/save/predictions.') 45 | parser.add_argument('--colormap', type=str, default='', 46 | help='/path/to/colormap/file.') 47 | 48 | return parser.parse_args() 49 | 50 | def load(saver, sess, ckpt_path): 51 | """Load the trained weights. 52 | 53 | Args: 54 | saver: TensorFlow saver object. 55 | sess: TensorFlow session. 56 | ckpt_path: path to checkpoint file with parameters. 57 | """ 58 | saver.restore(sess, ckpt_path) 59 | print('Restored model parameters from {}'.format(ckpt_path)) 60 | 61 | def parse_commastr(str_comma): 62 | """Read a comma-separated string. 63 | """ 64 | if '' == str_comma: 65 | return None 66 | else: 67 | a, b = map(int, str_comma.split(',')) 68 | 69 | return [a,b] 70 | 71 | def main(): 72 | """Create the model and start the inference process. 73 | """ 74 | args = get_arguments() 75 | 76 | # Parse image processing arguments. 77 | input_size = parse_commastr(args.input_size) 78 | strides = parse_commastr(args.strides) 79 | assert(input_size is not None and strides is not None) 80 | h, w = input_size 81 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 82 | 83 | 84 | # Create queue coordinator. 85 | coord = tf.train.Coordinator() 86 | 87 | # Load the data reader. 88 | with tf.name_scope('create_inputs'): 89 | reader = ImageReader( 90 | args.data_dir, 91 | args.data_list, 92 | None, 93 | False, # No random scale. 94 | False, # No random mirror. 95 | False, # No random crop, center crop instead 96 | args.ignore_label, 97 | IMG_MEAN) 98 | image = reader.image 99 | image_list = reader.image_list 100 | image_batch = tf.expand_dims(image, dim=0) 101 | 102 | # Create input tensor to the Network. 103 | crop_image_batch = tf.placeholder( 104 | name='crop_image_batch', 105 | shape=[1,input_size[0],input_size[1],3], 106 | dtype=tf.float32) 107 | 108 | # Create network and output prediction. 109 | outputs = model(crop_image_batch, 110 | args.num_classes, 111 | False, 112 | True) 113 | 114 | # Grab variable names which should be restored from checkpoints. 115 | restore_var = [ 116 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 117 | 118 | # Output predictions.
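# The model returns a list of score maps; the last entry is the final
# logit map. It is resized to the crop resolution and converted to
# per-class probabilities below, so predictions from overlapping crops
# can be accumulated by simple addition.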
119 | output = outputs[-1] 120 | output = tf.image.resize_bilinear( 121 | output, 122 | tf.shape(crop_image_batch)[1:3,]) 123 | output = tf.nn.softmax(output, dim=3) 124 | 125 | # Set up tf session and initialize variables. 126 | config = tf.ConfigProto() 127 | config.gpu_options.allow_growth = True 128 | sess = tf.Session(config=config) 129 | init = tf.global_variables_initializer() 130 | 131 | sess.run(init) 132 | sess.run(tf.local_variables_initializer()) 133 | 134 | # Load weights. 135 | loader = tf.train.Saver(var_list=restore_var) 136 | if args.restore_from is not None: 137 | load(loader, sess, args.restore_from) 138 | 139 | # Start queue threads. 140 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 141 | 142 | # Get colormap. 143 | map_data = scipy.io.loadmat(args.colormap) 144 | key = os.path.basename(args.colormap).replace('.mat','') 145 | colormap = map_data[key] 146 | colormap *= 255 147 | colormap = colormap.astype(np.uint8) 148 | 149 | # Create directory for saving predictions. 150 | pred_dir = os.path.join(args.save_dir, 'gray') 151 | color_dir = os.path.join(args.save_dir, 'color') 152 | if not os.path.isdir(pred_dir): 153 | os.makedirs(pred_dir) 154 | if not os.path.isdir(color_dir): 155 | os.makedirs(color_dir) 156 | 157 | # Iterate over testing steps. 158 | with open(args.data_list, 'r') as listf: 159 | num_steps = len(listf.read().split('\n'))-1 160 | 161 | for step in range(num_steps): 162 | img_batch = sess.run(image_batch) 163 | img_size = img_batch.shape 164 | padimg_size = list(img_size) # copy of img_size 165 | 166 | padimg_h, padimg_w = padimg_size[1:3] 167 | input_h, input_w = input_size 168 | 169 | if input_h > padimg_h: 170 | padimg_h = input_h 171 | if input_w > padimg_w: 172 | padimg_w = input_w 173 | 174 | # Update padded image size. 175 | padimg_size[1] = padimg_h 176 | padimg_size[2] = padimg_w 177 | padimg_batch = np.zeros(padimg_size, dtype=np.float32) 178 | img_h, img_w = img_size[1:3] 179 | padimg_batch[:, :img_h, :img_w, :] = img_batch 180 | 181 | # Create padded label array. 182 | lab_size = list(padimg_size) 183 | lab_size[-1] = args.num_classes 184 | lab_batch = np.zeros(lab_size, dtype=np.float32) 185 | lab_batch.fill(args.ignore_label) 186 | 187 | stride_h, stride_w = strides 188 | npatches_h = math.ceil(1.0*(padimg_h-input_h)/stride_h) + 1 189 | npatches_w = math.ceil(1.0*(padimg_w-input_w)/stride_w) + 1 190 | 191 | # Create the ending index of each patch.
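# For example (hypothetical sizes): padimg_h = 600, input_h = 512 and
# stride_h = 512 give
#   npatches_h = math.ceil(1.0 * (600 - 512) / 512) + 1   # = 2
#   np.linspace(512, 600, 2, dtype=np.int32)              # -> [512, 600]
# so the two crops cover rows 0:512 and 88:600 and overlap in 88:512.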
192 | patch_indh = np.linspace( 193 | input_h, padimg_h, npatches_h, dtype=np.int32) 194 | patch_indw = np.linspace( 195 | input_w, padimg_w, npatches_w, dtype=np.int32) 196 | 197 | for indh in patch_indh: 198 | for indw in patch_indw: 199 | sh, eh = indh-input_h, indh # start&end ind of H 200 | sw, ew = indw-input_w, indw # start&end ind of W 201 | cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :] 202 | feed_dict = {crop_image_batch: cropimg_batch} 203 | 204 | out = sess.run(output, feed_dict=feed_dict) 205 | lab_batch[:, sh:eh, sw:ew, :] += out 206 | 207 | lab_batch = lab_batch[0, :img_h, :img_w, :] 208 | lab_batch = np.argmax(lab_batch, axis=-1) 209 | lab_batch = lab_batch.astype(np.uint8) 210 | 211 | basename = os.path.basename(image_list[step]) 212 | basename = basename.replace('jpg', 'png') 213 | 214 | predname = os.path.join(pred_dir, basename) 215 | Image.fromarray(lab_batch, mode='L').save(predname) 216 | 217 | colorname = os.path.join(color_dir, basename) 218 | color = colormap[lab_batch] 219 | Image.fromarray(color, mode='RGB').save(colorname) 220 | 221 | coord.request_stop() 222 | coord.join(threads) 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /pyscripts/inference/inference_msc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | import scipy.io 11 | import scipy.misc 12 | from PIL import Image 13 | import cv2 14 | 15 | from seg_models.models.pspnet import pspnet_resnet101 as model 16 | from seg_models.image_reader import ImageReader 17 | import utils.general 18 | 19 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 20 | 21 | 22 | def get_arguments(): 23 | """Parse all the arguments provided from the CLI. 24 | 25 | Returns: 26 | A list of parsed arguments. 27 | """ 28 | parser = argparse.ArgumentParser( 29 | description='Inference of Semantic Segmentation.') 30 | parser.add_argument('--data-dir', type=str, default='', 31 | help='/path/to/dataset.') 32 | parser.add_argument('--data-list', type=str, default='', 33 | help='/path/to/datalist/file.') 34 | parser.add_argument('--input-size', type=str, default='512,512', 35 | help='Comma-separated string with H and W of image.') 36 | parser.add_argument('--strides', type=str, default='512,512', 37 | help='Comma-separated string with strides of H and W.') 38 | parser.add_argument('--num-classes', type=int, default=21, 39 | help='Number of classes to predict.') 40 | parser.add_argument('--ignore-label', type=int, default=255, 41 | help='Index of label to ignore.') 42 | parser.add_argument('--restore-from', type=str, default='', 43 | help='Where restore model parameters from.') 44 | parser.add_argument('--save-dir', type=str, default='', 45 | help='/path/to/save/predictions.') 46 | parser.add_argument('--colormap', type=str, default='', 47 | help='/path/to/colormap/file.') 48 | parser.add_argument('--flip-aug', action='store_true', 49 | help='Augment data by horizontal flipping.') 50 | parser.add_argument('--scale-aug', action='store_true', 51 | help='Augment data with multi-scale.') 52 | 53 | return parser.parse_args() 54 | 55 | def load(saver, sess, ckpt_path): 56 | """Load the trained weights. 57 | 58 | Args: 59 | saver: TensorFlow saver object. 60 | sess: TensorFlow session. 
61 | ckpt_path: path to checkpoint file with parameters. 62 | """ 63 | saver.restore(sess, ckpt_path) 64 | print('Restored model parameters from {}'.format(ckpt_path)) 65 | 66 | def parse_commastr(str_comma): 67 | """Read a comma-separated string. 68 | """ 69 | if '' == str_comma: 70 | return None 71 | else: 72 | a, b = map(int, str_comma.split(',')) 73 | 74 | return [a,b] 75 | 76 | def main(): 77 | """Create the model and start the inference process. 78 | """ 79 | args = get_arguments() 80 | 81 | # Parse image processing arguments. 82 | input_size = parse_commastr(args.input_size) 83 | strides = parse_commastr(args.strides) 84 | assert(input_size is not None and strides is not None) 85 | h, w = input_size 86 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 87 | 88 | 89 | # Create queue coordinator. 90 | coord = tf.train.Coordinator() 91 | 92 | # Load the data reader. 93 | with tf.name_scope('create_inputs'): 94 | reader = ImageReader( 95 | args.data_dir, 96 | args.data_list, 97 | None, 98 | False, # No random scale. 99 | False, # No random mirror. 100 | False, # No random crop, center crop instead 101 | args.ignore_label, 102 | IMG_MEAN) 103 | image = reader.image 104 | image_list = reader.image_list 105 | image_batch = tf.expand_dims(image, dim=0) 106 | 107 | # Create multi-scale augmented data. 108 | rescale_image_batches = [] 109 | is_flipped = [] 110 | scales = [0.5, 0.75, 1, 1.25, 1.5, 1.75] if args.scale_aug else [1] 111 | for scale in scales: 112 | h_new = tf.to_int32( 113 | tf.multiply(tf.to_float(tf.shape(image_batch)[1]), scale)) 114 | w_new = tf.to_int32( 115 | tf.multiply(tf.to_float(tf.shape(image_batch)[2]), scale)) 116 | new_shape = tf.stack([h_new, w_new]) 117 | new_image_batch = tf.image.resize_images(image_batch, new_shape) 118 | rescale_image_batches.append(new_image_batch) 119 | is_flipped.append(False) 120 | 121 | # Create horizontally flipped augmented data. 122 | if args.flip_aug: 123 | for i in range(len(scales)): 124 | img = rescale_image_batches[i] 125 | is_flip = is_flipped[i] 126 | img = tf.squeeze(img, axis=0) 127 | flip_img = tf.image.flip_left_right(img) 128 | flip_img = tf.expand_dims(flip_img, axis=0) 129 | rescale_image_batches.append(flip_img) 130 | is_flipped.append(True) 131 | 132 | # Create input tensor to the Network. 133 | crop_image_batch = tf.placeholder( 134 | name='crop_image_batch', 135 | shape=[1,input_size[0],input_size[1],3], 136 | dtype=tf.float32) 137 | 138 | # Create network. 139 | outputs = model(crop_image_batch, 140 | args.num_classes, 141 | False, 142 | True) 143 | 144 | # Grab variable names which should be restored from checkpoints. 145 | restore_var = [ 146 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 147 | 148 | # Output predictions. 149 | output = outputs[-1] 150 | output = tf.image.resize_bilinear( 151 | output, 152 | tf.shape(crop_image_batch)[1:3,]) 153 | output = tf.nn.softmax(output, dim=3) 154 | 155 | # Set up tf session and initialize variables. 156 | config = tf.ConfigProto() 157 | config.gpu_options.allow_growth = True 158 | sess = tf.Session(config=config) 159 | init = tf.global_variables_initializer() 160 | 161 | sess.run(init) 162 | sess.run(tf.local_variables_initializer()) 163 | 164 | # Load weights. 165 | loader = tf.train.Saver(var_list=restore_var) 166 | if args.restore_from is not None: 167 | load(loader, sess, args.restore_from) 168 | 169 | # Start queue threads. 170 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 171 | 172 | # Get colormap.
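# The .mat file stores an N x 3 array of RGB values in [0, 1] under a
# key matching its basename; after scaling to [0, 255], fancy indexing
# such as colormap[label_map] turns an (H, W) uint8 prediction into an
# (H, W, 3) color image.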
173 | map_data = scipy.io.loadmat(args.colormap) 174 | key = os.path.basename(args.colormap).replace('.mat','') 175 | colormap = map_data[key] 176 | colormap *= 255 177 | colormap = colormap.astype(np.uint8) 178 | 179 | # Create directory for saving predictions. 180 | pred_dir = os.path.join(args.save_dir, 'gray') 181 | color_dir = os.path.join(args.save_dir, 'color') 182 | if not os.path.isdir(pred_dir): 183 | os.makedirs(pred_dir) 184 | if not os.path.isdir(color_dir): 185 | os.makedirs(color_dir) 186 | 187 | # Iterate over testing steps. 188 | with open(args.data_list, 'r') as listf: 189 | num_steps = len(listf.read().split('\n'))-1 190 | 191 | for step in range(num_steps): 192 | rescale_img_batches = sess.run(rescale_image_batches) 193 | # Final segmentation results (average across multiple scales). 194 | scale_ind = 2 if args.scale_aug else 0 195 | final_lab_size = list(rescale_img_batches[scale_ind].shape[1:]) 196 | final_lab_size[-1] = args.num_classes 197 | final_lab_batch = np.zeros(final_lab_size) 198 | 199 | # Iterate over multiple scales. 200 | for img_batch,is_flip in zip(rescale_img_batches, is_flipped): 201 | img_size = img_batch.shape 202 | padimg_size = list(img_size) # copy of img_size 203 | 204 | padimg_h, padimg_w = padimg_size[1:3] 205 | input_h, input_w = input_size 206 | 207 | if input_h > padimg_h: 208 | padimg_h = input_h 209 | if input_w > padimg_w: 210 | padimg_w = input_w 211 | # Update padded image size. 212 | padimg_size[1] = padimg_h 213 | padimg_size[2] = padimg_w 214 | padimg_batch = np.zeros(padimg_size, dtype=np.float32) 215 | img_h, img_w = img_size[1:3] 216 | padimg_batch[:, :img_h, :img_w, :] = img_batch 217 | 218 | # Create padded label array. 219 | lab_size = list(padimg_size) 220 | lab_size[-1] = args.num_classes 221 | lab_batch = np.zeros(lab_size, dtype=np.float32) 222 | lab_batch.fill(args.ignore_label) 223 | num_batch = np.zeros(lab_size[:-1], dtype=np.float32) 224 | 225 | stride_h, stride_w = strides 226 | npatches_h = math.ceil(1.0*(padimg_h-input_h)/stride_h) + 1 227 | npatches_w = math.ceil(1.0*(padimg_w-input_w)/stride_w) + 1 228 | 229 | # Create the ending index of each patch. 230 | patch_indh = np.linspace( 231 | input_h, padimg_h, npatches_h, dtype=np.int32) 232 | patch_indw = np.linspace( 233 | input_w, padimg_w, npatches_w, dtype=np.int32) 234 | 235 | for indh in patch_indh: 236 | for indw in patch_indw: 237 | sh, eh = indh-input_h, indh # start&end ind of H 238 | sw, ew = indw-input_w, indw # start&end ind of W 239 | cropimg_batch = padimg_batch[:, sh:eh, sw:ew, :] 240 | feed_dict = {crop_image_batch: cropimg_batch} 241 | 242 | out = sess.run(output, feed_dict=feed_dict) 243 | lab_batch[:, sh:eh, sw:ew, :] += out 244 | num_batch[:, sh:eh, sw:ew] += 1 245 | 246 | lab_batch /= num_batch[..., np.newaxis] 247 | lab_batch = lab_batch[0, :img_h, :img_w, :] 248 | # Rescale prediction back to original resolution. 249 | lab_batch = cv2.resize( 250 | lab_batch, 251 | (final_lab_size[1], final_lab_size[0]), 252 | interpolation=cv2.INTER_LINEAR) 253 | if is_flip: 254 | # Flip the prediction back to the original orientation.
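# Mirroring only permutes the width axis, so reversing it with
# lab_batch[:, ::-1, :] realigns scores computed from a flipped input
# with the unflipped ones before they are summed into final_lab_batch.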
255 | lab_batch = lab_batch[:, ::-1, :] 256 | final_lab_batch += lab_batch 257 | 258 | final_lab_ind = np.argmax(final_lab_batch, axis=-1) 259 | final_lab_ind = final_lab_ind.astype(np.uint8) 260 | 261 | basename = os.path.basename(image_list[step]) 262 | basename = basename.replace('jpg', 'png') 263 | 264 | predname = os.path.join(pred_dir, basename) 265 | Image.fromarray(final_lab_ind, mode='L').save(predname) 266 | 267 | colorname = os.path.join(color_dir, basename) 268 | color = colormap[final_lab_ind] 269 | Image.fromarray(color, mode='RGB').save(colorname) 270 | 271 | coord.request_stop() 272 | coord.join(threads) 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /pyscripts/inference/inference_patch.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import copy 5 | import os 6 | import time 7 | import math 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import scipy.io 12 | import scipy.misc 13 | from PIL import Image 14 | import network.vmf.common_utils as common_utils 15 | import network.vmf.eval_utils as eval_utils 16 | import network.vmf.vis_utils as vis_utils 17 | 18 | from seg_models.models.pspnet import pspnet_resnet101 as model 19 | from seg_models.image_reader import VMFImageReader 20 | import utils.general 21 | import utils.html_helper as html_helper 22 | 23 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 24 | 25 | 26 | def get_arguments(): 27 | """Parse all the arguments provided from the CLI. 28 | 29 | Returns: 30 | A list of parsed arguments. 31 | """ 32 | parser = argparse.ArgumentParser( 33 | description='Inference for Semantic Segmentation') 34 | parser.add_argument('--data-dir', type=str, default='', 35 | help='/path/to/dataset.') 36 | parser.add_argument('--data-list', type=str, default='', 37 | help='/path/to/datalist/file.') 38 | parser.add_argument('--input-size', type=str, default='512,512', 39 | help='Comma-separated string with H and W of image.') 40 | parser.add_argument('--strides', type=str, default='512,512', 41 | help='Comma-separated string with strides of H and W.') 42 | parser.add_argument('--num-classes', type=int, default=21, 43 | help='Number of classes to predict.') 44 | parser.add_argument('--ignore-label', type=int, default=255, 45 | help='Index of label to ignore.') 46 | parser.add_argument('--restore-from', type=str, default='', 47 | help='Where restore model parameters from.') 48 | parser.add_argument('--save-dir', type=str, default='', 49 | help='/path/to/save/predictions.') 50 | parser.add_argument('--colormap', type=str, default='', 51 | help='/path/to/colormap/file.') 52 | # vMF parameters 53 | parser.add_argument('--prototype_dir', type=str, default='', 54 | help='/path/to/prototype/file.') 55 | parser.add_argument('--embedding_dim', type=int, default=32, 56 | help='Dimension of the feature embeddings.') 57 | parser.add_argument('--num_clusters', type=int, default=5, 58 | help='Number of kmeans clusters along each axis') 59 | parser.add_argument('--kmeans_iterations', type=int, default=10, 60 | help='Number of kmeans iterations.') 61 | parser.add_argument('--k_in_nearest_neighbors', type=int, default=15, 62 | help='K in k-nearest neighbor search.') 63 | 64 | return parser.parse_args() 65 | 66 | def load(saver, sess, ckpt_path): 67 | """Load the trained weights. 68 | 69 | Args: 70 | saver: TensorFlow saver object. 
71 | sess: TensorFlow session. 72 | ckpt_path: path to checkpoint file with parameters. 73 | """ 74 | saver.restore(sess, ckpt_path) 75 | print('Restored model parameters from {}'.format(ckpt_path)) 76 | 77 | def parse_commastr(str_comma): 78 | """Read a comma-separated string. 79 | """ 80 | if '' == str_comma: 81 | return None 82 | else: 83 | a, b = map(int, str_comma.split(',')) 84 | 85 | return [a,b] 86 | 87 | def main(): 88 | """Create the model and start the inference process. 89 | """ 90 | args = get_arguments() 91 | 92 | # Create queue coordinator. 93 | coord = tf.train.Coordinator() 94 | 95 | # Load the data reader. 96 | with tf.name_scope('create_inputs'): 97 | reader = VMFImageReader( 98 | args.data_dir, 99 | args.data_list, 100 | None, 101 | False, # No random scale. 102 | False, # No random mirror. 103 | False, # No random crop, center crop instead 104 | args.ignore_label, 105 | IMG_MEAN) 106 | 107 | image_list = reader.image_list 108 | image = reader.image 109 | cluster_label = reader.cluster_label 110 | loc_feature = reader.loc_feature 111 | height = reader.height 112 | width = reader.width 113 | 114 | # Create network and output prediction. 115 | outputs = model(tf.expand_dims(image, dim=0), 116 | args.embedding_dim, 117 | False, 118 | True) 119 | 120 | # Grab variable names which should be restored from checkpoints. 121 | restore_var = [v for v in tf.global_variables()] 122 | 123 | # Output predictions. 124 | output = outputs[0] 125 | output = tf.image.resize_bilinear( 126 | output, 127 | tf.shape(image)[:2,]) 128 | embedding = common_utils.normalize_embedding(output) 129 | embedding = tf.squeeze(embedding, axis=0) 130 | 131 | image = image[:height, :width] 132 | embedding = tf.reshape( 133 | embedding[:height, :width], [-1, args.embedding_dim]) 134 | cluster_label = tf.reshape(cluster_label[:height, :width], [-1]) 135 | loc_feature = tf.reshape( 136 | loc_feature[:height, :width], [-1, 2]) 137 | 138 | # Prototype placeholders. 139 | prototype_features = tf.placeholder(tf.float32, 140 | shape=[None, args.embedding_dim]) 141 | prototype_labels = tf.placeholder(tf.int32) 142 | 143 | # Combine embedding with location features and run kmeans. 144 | embedding_with_location = tf.concat([embedding, loc_feature], 1) 145 | embedding_with_location = common_utils.normalize_embedding( 146 | embedding_with_location) 147 | cluster_label = common_utils.kmeans_with_initial_labels( 148 | embedding_with_location, 149 | cluster_label, 150 | args.num_clusters * args.num_clusters, 151 | args.kmeans_iterations) 152 | _, cluster_labels = tf.unique(cluster_label) 153 | test_prototypes = common_utils.calculate_prototypes_from_labels( 154 | embedding, cluster_labels) 155 | 156 | cluster_labels = tf.reshape(cluster_labels, [height, width]) 157 | 158 | # Predict semantic labels.
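# Shapes in the k-NN step below: similarities is
# [num_test_segments, num_train_prototypes]; top_k keeps the
# k_in_nearest_neighbors most similar training prototypes per test
# segment, and k_nearest_neighbors reduces their semantic labels to one
# prediction per segment (effectively a majority vote).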
159 | similarities = tf.matmul(test_prototypes, 160 | prototype_features, 161 | transpose_b=True) 162 | _, k_predictions = tf.nn.top_k(similarities, k=args.k_in_nearest_neighbors, sorted=True) 163 | 164 | prototype_semantic_predictions = eval_utils.k_nearest_neighbors( 165 | k_predictions, prototype_labels) 166 | semantic_predictions = tf.gather(prototype_semantic_predictions, 167 | cluster_labels) 168 | # semantic_predictions = tf.squeeze(semantic_predictions) 169 | 170 | # Visualize embedding using PCA 171 | embedding = vis_utils.pca(tf.reshape(embedding, [1, height, width, args.embedding_dim])) 172 | embedding = ((embedding - tf.reduce_min(embedding)) / 173 | (tf.reduce_max(embedding) - tf.reduce_min(embedding))) 174 | embedding = tf.cast(embedding * 255, tf.uint8) 175 | embedding = tf.squeeze(embedding, axis=0) 176 | 177 | # Set up tf session and initialize variables. 178 | config = tf.ConfigProto() 179 | config.gpu_options.allow_growth = True 180 | sess = tf.Session(config=config) 181 | init = tf.global_variables_initializer() 182 | 183 | sess.run(init) 184 | sess.run(tf.local_variables_initializer()) 185 | 186 | # Load weights. 187 | loader = tf.train.Saver(var_list=restore_var) 188 | if args.restore_from is not None: 189 | load(loader, sess, args.restore_from) 190 | 191 | # Start queue threads. 192 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 193 | 194 | # Get colormap. 195 | map_data = scipy.io.loadmat(args.colormap) 196 | key = os.path.basename(args.colormap).replace('.mat','') 197 | colormap = map_data[key] 198 | colormap *= 255 199 | colormap = colormap.astype(np.uint8) 200 | 201 | # Create directory for saving predictions. 202 | pred_dir = os.path.join(args.save_dir, 'gray') 203 | color_dir = os.path.join(args.save_dir, 'color') 204 | cluster_dir = os.path.join(args.save_dir, 'cluster') 205 | embedding_dir = os.path.join(args.save_dir, 'embedding') 206 | patch_dir = os.path.join(args.save_dir, 'test_patches') 207 | if not os.path.isdir(pred_dir): 208 | os.makedirs(pred_dir) 209 | if not os.path.isdir(color_dir): 210 | os.makedirs(color_dir) 211 | if not os.path.isdir(cluster_dir): 212 | os.makedirs(cluster_dir) 213 | if not os.path.isdir(embedding_dir): 214 | os.makedirs(embedding_dir) 215 | if not os.path.isdir(patch_dir): 216 | os.makedirs(patch_dir) 217 | 218 | # Iterate over testing steps. 
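# The prototype bank loaded next pairs a [num_prototypes, embedding_dim]
# float feature array with a [num_prototypes] int label array, as
# written by pyscripts/inference/extract_prototypes.py.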
219 | with open(args.data_list, 'r') as listf: 220 | num_steps = len(listf.read().split('\n'))-1 221 | 222 | # Load prototype features and labels 223 | prototype_features_np = np.load( 224 | os.path.join(args.prototype_dir, 'prototype_features.npy')) 225 | prototype_labels_np = np.load( 226 | os.path.join(args.prototype_dir, 'prototype_labels.npy')) 227 | 228 | feed_dict = {prototype_features: prototype_features_np, 229 | prototype_labels: prototype_labels_np} 230 | 231 | f = html_helper.open_html_for_write(os.path.join(args.save_dir, 'index.html'), 232 | 'Visualization for Segment Collaging') 233 | for step in range(num_steps): 234 | image_np, semantic_predictions_np, cluster_labels_np, embedding_np, k_predictions_np = sess.run( 235 | [image, semantic_predictions, cluster_labels, embedding, k_predictions], 236 | feed_dict=feed_dict) 237 | 238 | imgname = os.path.basename(image_list[step]) 239 | basename = imgname.replace('jpg', 'png') 240 | 241 | predname = os.path.join(pred_dir, basename) 242 | Image.fromarray(semantic_predictions_np, mode='L').save(predname) 243 | 244 | colorname = os.path.join(color_dir, basename) 245 | color = colormap[semantic_predictions_np] 246 | Image.fromarray(color, mode='RGB').save(colorname) 247 | 248 | clustername = os.path.join(cluster_dir, basename) 249 | cluster = colormap[cluster_labels_np] 250 | Image.fromarray(cluster, mode='RGB').save(clustername) 251 | 252 | embeddingname = os.path.join(embedding_dir, basename) 253 | Image.fromarray(embedding_np, mode='RGB').save(embeddingname) 254 | 255 | image_np = (image_np + IMG_MEAN).astype(np.uint8) 256 | for i in range(np.max(cluster_labels_np) + 1): 257 | image_temp = copy.deepcopy(image_np) 258 | image_temp[cluster_labels_np != i] = 0 259 | coords = np.where(cluster_labels_np == i) 260 | crop = image_temp[np.min(coords[0]):np.max(coords[0]), np.min(coords[1]):np.max(coords[1])] 261 | scipy.misc.imsave(patch_dir + '/' + basename + str(i).zfill(3) + '.png', crop) 262 | 263 | html_helper.write_vmf_to_html(f, './images/' + imgname, './labels/' + basename, 264 | './color/' + basename, './cluster/' + basename, 265 | './embedding/' + basename, './test_patches/' + basename, './patches/', k_predictions_np) 266 | 267 | if (step + 1) % 100 == 0: 268 | print('Processed batches: ', (step + 1), '/', num_steps) 269 | 270 | html_helper.close_html(f) 271 | coord.request_stop() 272 | coord.join(threads) 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /pyscripts/inference/inference_segsort.py: -------------------------------------------------------------------------------- 1 | """Single-scale inference script for predicting segmentations using SegSort.""" 2 | 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import math 7 | import os 8 | import time 9 | import scipy.io 10 | 11 | import network.segsort.common_utils as common_utils 12 | import network.segsort.eval_utils as eval_utils 13 | import tensorflow as tf 14 | import numpy as np 15 | 16 | from PIL import Image 17 | from seg_models.image_reader import SegSortImageReader 18 | from seg_models.models.pspnet import pspnet_resnet101 as model 19 | from tqdm import tqdm 20 | 21 | 22 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 23 | 24 | 25 | def get_arguments(): 26 | """Parse all the arguments provided from the CLI. 27 | 28 | Returns: 29 | A list of parsed arguments. 
30 | """ 31 | parser = argparse.ArgumentParser( 32 | description='Inference for Semantic Segmentation') 33 | parser.add_argument('--data_dir', type=str, default='', 34 | help='/path/to/dataset.') 35 | parser.add_argument('--data_list', type=str, default='', 36 | help='/path/to/datalist/file.') 37 | parser.add_argument('--input_size', type=str, default='512,512', 38 | help='Comma-separated string with H and W of image.') 39 | parser.add_argument('--strides', type=str, default='512,512', 40 | help='Comma-separated string with strides of H and W.') 41 | parser.add_argument('--num_classes', type=int, default=21, 42 | help='Number of classes to predict.') 43 | parser.add_argument('--ignore_label', type=int, default=255, 44 | help='Index of label to ignore.') 45 | parser.add_argument('--restore_from', type=str, default='', 46 | help='Where restore model parameters from.') 47 | parser.add_argument('--save_dir', type=str, default='', 48 | help='/path/to/save/predictions.') 49 | parser.add_argument('--colormap', type=str, default='', 50 | help='/path/to/colormap/file.') 51 | # SegSort parameters. 52 | parser.add_argument('--prototype_dir', type=str, default='', 53 | help='/path/to/prototype/file.') 54 | parser.add_argument('--embedding_dim', type=int, default=32, 55 | help='Dimension of the feature embeddings.') 56 | parser.add_argument('--num_clusters', type=int, default=5, 57 | help='Number of kmeans clusters along each axis') 58 | parser.add_argument('--kmeans_iterations', type=int, default=10, 59 | help='Number of kmeans iterations.') 60 | parser.add_argument('--k_in_nearest_neighbors', type=int, default=15, 61 | help='K in k-nearest neighbor search.') 62 | 63 | return parser.parse_args() 64 | 65 | def load(saver, sess, ckpt_path): 66 | """Load the trained weights. 67 | 68 | Args: 69 | saver: TensorFlow saver object. 70 | sess: TensorFlow session. 71 | ckpt_path: path to checkpoint file with parameters. 72 | """ 73 | saver.restore(sess, ckpt_path) 74 | print('Restored model parameters from {}'.format(ckpt_path)) 75 | 76 | def parse_commastr(str_comma): 77 | """Read comma-sperated string. 78 | """ 79 | if '' == str_comma: 80 | return None 81 | else: 82 | a, b = map(int, str_comma.split(',')) 83 | 84 | return [a,b] 85 | 86 | def main(): 87 | """Create the model and start the Inference process.""" 88 | args = get_arguments() 89 | 90 | # Create queue coordinator. 91 | coord = tf.train.Coordinator() 92 | 93 | # Load the data reader. 94 | with tf.name_scope('create_inputs'): 95 | reader = SegSortImageReader( 96 | args.data_dir, 97 | args.data_list, 98 | parse_commastr(args.input_size), 99 | False, # No random scale 100 | False, # No random mirror 101 | False, # No random crop, center crop instead 102 | args.ignore_label, 103 | IMG_MEAN) 104 | 105 | image_list = reader.image_list 106 | image_batch = tf.expand_dims(reader.image, dim=0) 107 | label_batch = tf.expand_dims(reader.label, dim=0) 108 | cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0) 109 | loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0) 110 | height = reader.height 111 | width = reader.width 112 | 113 | # Create network and output prediction. 114 | outputs = model(image_batch, 115 | args.embedding_dim, 116 | False, 117 | True) 118 | 119 | # Grab variable names which should be restored from checkpoints. 120 | restore_var = [ 121 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 122 | 123 | # Output predictions. 
124 | output = outputs[0] 125 | output = tf.image.resize_bilinear( 126 | output, 127 | tf.shape(image_batch)[1:3,]) 128 | embedding = common_utils.normalize_embedding(output) 129 | 130 | # Prototype placeholders. 131 | prototype_features = tf.placeholder(tf.float32, 132 | shape=[None, args.embedding_dim]) 133 | prototype_labels = tf.placeholder(tf.int32) 134 | 135 | # Combine embedding with location features. 136 | embedding_with_location = tf.concat([embedding, loc_feature_batch], 3) 137 | embedding_with_location = common_utils.normalize_embedding( 138 | embedding_with_location) 139 | 140 | # Kmeans clustering. 141 | cluster_labels = common_utils.kmeans( 142 | embedding_with_location, 143 | [args.num_clusters, args.num_clusters], 144 | args.kmeans_iterations) 145 | test_prototypes = common_utils.calculate_prototypes_from_labels( 146 | embedding, cluster_labels) 147 | 148 | # Predict semantic labels. 149 | semantic_predictions, _ = eval_utils.predict_semantic_instance_labels( 150 | cluster_labels, 151 | test_prototypes, 152 | prototype_features, 153 | prototype_labels, 154 | None, 155 | args.k_in_nearest_neighbors) 156 | semantic_predictions = tf.cast(semantic_predictions, tf.uint8) 157 | semantic_predictions = tf.squeeze(semantic_predictions) 158 | 159 | 160 | # Set up tf session and initialize variables. 161 | config = tf.ConfigProto() 162 | config.gpu_options.allow_growth = True 163 | sess = tf.Session(config=config) 164 | init = tf.global_variables_initializer() 165 | 166 | sess.run(init) 167 | sess.run(tf.local_variables_initializer()) 168 | 169 | # Load weights. 170 | loader = tf.train.Saver(var_list=restore_var) 171 | if args.restore_from is not None: 172 | load(loader, sess, args.restore_from) 173 | 174 | # Start queue threads. 175 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 176 | 177 | # Get colormap. 178 | map_data = scipy.io.loadmat(args.colormap) 179 | key = os.path.basename(args.colormap).replace('.mat','') 180 | colormap = map_data[key] 181 | colormap *= 255 182 | colormap = colormap.astype(np.uint8) 183 | 184 | # Create directory for saving predictions. 185 | pred_dir = os.path.join(args.save_dir, 'gray') 186 | color_dir = os.path.join(args.save_dir, 'color') 187 | if not os.path.isdir(pred_dir): 188 | os.makedirs(pred_dir) 189 | if not os.path.isdir(color_dir): 190 | os.makedirs(color_dir) 191 | 192 | # Iterate over testing steps. 193 | with open(args.data_list, 'r') as listf: 194 | num_steps = len(listf.read().split('\n'))-1 195 | 196 | # Load prototype features and labels. 
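# Both arrays are plain np.load results and are fed once through
# feed_dict; the same prototype bank is then reused for every test
# image in the loop below.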
197 | prototype_features_np = np.load( 198 | os.path.join(args.prototype_dir, 'prototype_features.npy')) 199 | prototype_labels_np = np.load( 200 | os.path.join(args.prototype_dir, 'prototype_labels.npy')) 201 | 202 | feed_dict = {prototype_features: prototype_features_np, 203 | prototype_labels: prototype_labels_np} 204 | 205 | for step in tqdm(range(num_steps)): 206 | semantic_predictions_np, height_np, width_np = sess.run( 207 | [semantic_predictions, height, width], feed_dict=feed_dict) 208 | 209 | semantic_predictions_np = semantic_predictions_np[:height_np, :width_np] 210 | 211 | basename = os.path.basename(image_list[step]) 212 | basename = basename.replace('jpg', 'png') 213 | 214 | predname = os.path.join(pred_dir, basename) 215 | Image.fromarray(semantic_predictions_np, mode='L').save(predname) 216 | 217 | colorname = os.path.join(color_dir, basename) 218 | color = colormap[semantic_predictions_np] 219 | Image.fromarray(color, mode='RGB').save(colorname) 220 | 221 | coord.request_stop() 222 | coord.join(threads) 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /pyscripts/inference/inference_vmf.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | import scipy.io 11 | import scipy.misc 12 | from PIL import Image 13 | import network.vmf.common_utils as common_utils 14 | import network.vmf.eval_utils as eval_utils 15 | 16 | from seg_models.models.pspnet import pspnet_resnet101 as model 17 | from seg_models.image_reader import VMFImageReader 18 | import utils.general 19 | 20 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 21 | 22 | 23 | def get_arguments(): 24 | """Parse all the arguments provided from the CLI. 25 | 26 | Returns: 27 | A list of parsed arguments. 
28 | """ 29 | parser = argparse.ArgumentParser( 30 | description='Inference for Semantic Segmentation') 31 | parser.add_argument('--data-dir', type=str, default='', 32 | help='/path/to/dataset.') 33 | parser.add_argument('--data-list', type=str, default='', 34 | help='/path/to/datalist/file.') 35 | parser.add_argument('--input-size', type=str, default='512,512', 36 | help='Comma-separated string with H and W of image.') 37 | parser.add_argument('--strides', type=str, default='512,512', 38 | help='Comma-separated string with strides of H and W.') 39 | parser.add_argument('--num-classes', type=int, default=21, 40 | help='Number of classes to predict.') 41 | parser.add_argument('--ignore-label', type=int, default=255, 42 | help='Index of label to ignore.') 43 | parser.add_argument('--restore-from', type=str, default='', 44 | help='Where restore model parameters from.') 45 | parser.add_argument('--save-dir', type=str, default='', 46 | help='/path/to/save/predictions.') 47 | parser.add_argument('--colormap', type=str, default='', 48 | help='/path/to/colormap/file.') 49 | # vMF parameters 50 | parser.add_argument('--prototype_dir', type=str, default='', 51 | help='/path/to/prototype/file.') 52 | parser.add_argument('--embedding_dim', type=int, default=32, 53 | help='Dimension of the feature embeddings.') 54 | parser.add_argument('--num_clusters', type=int, default=5, 55 | help='Number of kmeans clusters along each axis') 56 | parser.add_argument('--kmeans_iterations', type=int, default=10, 57 | help='Number of kmeans iterations.') 58 | parser.add_argument('--k_in_nearest_neighbors', type=int, default=15, 59 | help='K in k-nearest neighbor search.') 60 | 61 | return parser.parse_args() 62 | 63 | def load(saver, sess, ckpt_path): 64 | """Load the trained weights. 65 | 66 | Args: 67 | saver: TensorFlow saver object. 68 | sess: TensorFlow session. 69 | ckpt_path: path to checkpoint file with parameters. 70 | """ 71 | saver.restore(sess, ckpt_path) 72 | print('Restored model parameters from {}'.format(ckpt_path)) 73 | 74 | def parse_commastr(str_comma): 75 | """Read comma-sperated string. 76 | """ 77 | if '' == str_comma: 78 | return None 79 | else: 80 | a, b = map(int, str_comma.split(',')) 81 | 82 | return [a,b] 83 | 84 | def main(): 85 | """Create the model and start the Inference process. 86 | """ 87 | args = get_arguments() 88 | 89 | # Create queue coordinator. 90 | coord = tf.train.Coordinator() 91 | 92 | # Load the data reader. 93 | with tf.name_scope('create_inputs'): 94 | reader = VMFImageReader( 95 | args.data_dir, 96 | args.data_list, 97 | None, 98 | False, # No random scale. 99 | False, # No random mirror. 100 | False, # No random crop, center crop instead 101 | args.ignore_label, 102 | IMG_MEAN) 103 | 104 | image_list = reader.image_list 105 | image_batch = tf.expand_dims(reader.image, dim=0) 106 | label_batch = tf.expand_dims(reader.label, dim=0) 107 | cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0) 108 | loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0) 109 | height = reader.height 110 | width = reader.width 111 | 112 | # Create network and output prediction. 113 | outputs = model(image_batch, 114 | args.embedding_dim, 115 | False, 116 | True) 117 | 118 | # Grab variable names which should be restored from checkpoints. 119 | restore_var = [ 120 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 121 | 122 | # Output predictions. 
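  # Label transfer at test time: the pixel embedding is over-segmented by
  # spherical k-means into num_clusters x num_clusters segments (5 x 5 = 25
  # by default), each segment is reduced to a prototype, and every test
  # prototype is classified by k-nearest-neighbor search (k = 15 by default)
  # against the stored training prototypes.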
123 | output = outputs[0] 124 | output = tf.image.resize_bilinear( 125 | output, 126 | tf.shape(image_batch)[1:3,]) 127 | embedding = common_utils.normalize_embedding(output) 128 | 129 | # Prototype placeholders. 130 | prototype_features = tf.placeholder(tf.float32, 131 | shape=[None, args.embedding_dim]) 132 | prototype_labels = tf.placeholder(tf.int32) 133 | 134 | # Combine embedding with location features. 135 | embedding_with_location = tf.concat([embedding, loc_feature_batch], 3) 136 | embedding_with_location = common_utils.normalize_embedding( 137 | embedding_with_location) 138 | 139 | # Kmeans clustering. 140 | cluster_labels = common_utils.kmeans( 141 | embedding_with_location, 142 | [args.num_clusters, args.num_clusters], 143 | args.kmeans_iterations) 144 | test_prototypes = common_utils.calculate_prototypes_from_labels( 145 | embedding, cluster_labels) 146 | 147 | # Predict semantic labels. 148 | semantic_predictions, _ = eval_utils.predict_semantic_instance_labels( 149 | cluster_labels, 150 | test_prototypes, 151 | prototype_features, 152 | prototype_labels, 153 | None, 154 | args.k_in_nearest_neighbors) 155 | semantic_predictions = tf.cast(semantic_predictions, tf.uint8) 156 | semantic_predictions = tf.squeeze(semantic_predictions) 157 | 158 | 159 | # Set up tf session and initialize variables. 160 | config = tf.ConfigProto() 161 | config.gpu_options.allow_growth = True 162 | sess = tf.Session(config=config) 163 | init = tf.global_variables_initializer() 164 | 165 | sess.run(init) 166 | sess.run(tf.local_variables_initializer()) 167 | 168 | # Load weights. 169 | loader = tf.train.Saver(var_list=restore_var) 170 | if args.restore_from is not None: 171 | load(loader, sess, args.restore_from) 172 | 173 | # Start queue threads. 174 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 175 | 176 | # Get colormap. 177 | map_data = scipy.io.loadmat(args.colormap) 178 | key = os.path.basename(args.colormap).replace('.mat','') 179 | colormap = map_data[key] 180 | colormap *= 255 181 | colormap = colormap.astype(np.uint8) 182 | 183 | # Create directory for saving predictions. 184 | pred_dir = os.path.join(args.save_dir, 'gray') 185 | color_dir = os.path.join(args.save_dir, 'color') 186 | if not os.path.isdir(pred_dir): 187 | os.makedirs(pred_dir) 188 | if not os.path.isdir(color_dir): 189 | os.makedirs(color_dir) 190 | 191 | # Iterate over testing steps. 
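  # The count below assumes the data list ends with exactly one trailing
  # newline. A sketch that tolerates a missing or repeated trailing newline:
  #
  #   with open(args.data_list, 'r') as listf:
  #     num_steps = sum(1 for line in listf if line.strip())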
192 | with open(args.data_list, 'r') as listf: 193 | num_steps = len(listf.read().split('\n'))-1 194 | 195 | # Load prototype features and labels 196 | prototype_features_np = np.load( 197 | os.path.join(args.prototype_dir, 'prototype_features.npy')) 198 | prototype_labels_np = np.load( 199 | os.path.join(args.prototype_dir, 'prototype_labels.npy')) 200 | 201 | feed_dict = {prototype_features: prototype_features_np, 202 | prototype_labels: prototype_labels_np} 203 | 204 | for step in range(num_steps): 205 | semantic_predictions_np, height_np, width_np = sess.run( 206 | [semantic_predictions, height, width], feed_dict=feed_dict) 207 | 208 | semantic_predictions_np = semantic_predictions_np[:height_np, :width_np] 209 | 210 | basename = os.path.basename(image_list[step]) 211 | basename = basename.replace('jpg', 'png') 212 | 213 | predname = os.path.join(pred_dir, basename) 214 | Image.fromarray(semantic_predictions_np, mode='L').save(predname) 215 | 216 | colorname = os.path.join(color_dir, basename) 217 | color = colormap[semantic_predictions_np] 218 | Image.fromarray(color, mode='RGB').save(colorname) 219 | 220 | if (step + 1) % 100 == 0: 221 | print('Processed batches: ', (step + 1), '/', num_steps) 222 | 223 | coord.request_stop() 224 | coord.join(threads) 225 | 226 | if __name__ == '__main__': 227 | main() 228 | -------------------------------------------------------------------------------- /pyscripts/inference/prototype_embedding_fine.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | from tqdm import tqdm 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import scipy.io 12 | import scipy.misc 13 | import network.vmf.common_utils as common_utils 14 | import network.vmf.eval_utils as eval_utils 15 | from PIL import Image 16 | 17 | from seg_models.models.pspnet import pspnet_resnet101 as model 18 | from seg_models.image_reader import VMFImageReader 19 | import utils.general 20 | 21 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 22 | 23 | 24 | def get_arguments(): 25 | """Parse all the arguments provided from the CLI. 26 | 27 | Returns: 28 | A list of parsed arguments. 
29 | """ 30 | parser = argparse.ArgumentParser( 31 | description='Extracting Prototypes for Semantic Segmentation') 32 | parser.add_argument('--data-dir', type=str, default='', 33 | help='/path/to/dataset.') 34 | parser.add_argument('--data-list', type=str, default='', 35 | help='/path/to/datalist/file.') 36 | parser.add_argument('--input-size', type=str, default='512,512', 37 | help='Comma-separated string with H and W of image.') 38 | parser.add_argument('--strides', type=str, default='512,512', 39 | help='Comma-separated string with strides of H and W.') 40 | parser.add_argument('--num-classes', type=int, default=21, 41 | help='Number of classes to predict.') 42 | parser.add_argument('--ignore-label', type=int, default=255, 43 | help='Index of label to ignore.') 44 | parser.add_argument('--restore-from', type=str, default='', 45 | help='Where restore model parameters from.') 46 | parser.add_argument('--save-dir', type=str, default='', 47 | help='/path/to/save/predictions.') 48 | parser.add_argument('--colormap', type=str, default='', 49 | help='/path/to/colormap/file.') 50 | # vMF parameters 51 | parser.add_argument('--embedding_dim', type=int, default=32, 52 | help='Dimension of the feature embeddings.') 53 | parser.add_argument('--num_clusters', type=int, default=5, 54 | help='Number of kmeans clusters along each axis') 55 | parser.add_argument('--kmeans_iterations', type=int, default=10, 56 | help='Number of kmeans iterations.') 57 | 58 | 59 | return parser.parse_args() 60 | 61 | def load(saver, sess, ckpt_path): 62 | """Load the trained weights. 63 | 64 | Args: 65 | saver: TensorFlow saver object. 66 | sess: TensorFlow session. 67 | ckpt_path: path to checkpoint file with parameters. 68 | """ 69 | saver.restore(sess, ckpt_path) 70 | print('Restored model parameters from {}'.format(ckpt_path)) 71 | 72 | def parse_commastr(str_comma): 73 | """Read comma-sperated string. 74 | """ 75 | if '' == str_comma: 76 | return None 77 | else: 78 | a, b = map(int, str_comma.split(',')) 79 | 80 | return [a,b] 81 | 82 | def main(): 83 | """Create the model and start the Inference process. 84 | """ 85 | args = get_arguments() 86 | 87 | # Parse image processing arguments. 88 | input_size = parse_commastr(args.input_size) 89 | strides = parse_commastr(args.strides) 90 | assert(input_size is not None and strides is not None) 91 | h, w = input_size 92 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 93 | 94 | 95 | # Create queue coordinator. 96 | coord = tf.train.Coordinator() 97 | 98 | # Load the data reader. 99 | with tf.name_scope('create_inputs'): 100 | reader = VMFImageReader( 101 | args.data_dir, 102 | args.data_list, 103 | None, 104 | False, # No random scale. 105 | False, # No random mirror. 106 | False, # No random crop, center crop instead 107 | args.ignore_label, 108 | IMG_MEAN) 109 | 110 | image = reader.image 111 | label = reader.label 112 | image_list = reader.image_list 113 | image_batch = tf.expand_dims(image, dim=0) 114 | label_batch = tf.expand_dims(label, dim=0) 115 | 116 | # Create input tensor to the Network 117 | crop_image_batch = tf.placeholder( 118 | name='crop_image_batch', 119 | shape=[1,input_size[0],input_size[1],3], 120 | dtype=tf.float32) 121 | 122 | # Create network and output prediction. 123 | outputs = model(crop_image_batch, 124 | args.embedding_dim, 125 | False, 126 | True) 127 | 128 | # Grab variable names which should be restored from checkpoints. 
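  # The name filter below defensively excludes anything tied to the
  # 'crop_image_batch' input placeholder, which has no counterpart in the
  # checkpoint; for the model weights themselves it is a no-op.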
129 | restore_var = [ 130 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 131 | 132 | # Output predictions. 133 | output = outputs[0] 134 | output = tf.image.resize_bilinear( 135 | output, 136 | [input_size[0], input_size[1]]) 137 | 138 | # Input full-sized embedding 139 | label_input = tf.placeholder( 140 | tf.int32, shape=[1, None, None, 1]) 141 | embedding_input = tf.placeholder( 142 | tf.float32, shape=[1, None, None, args.embedding_dim]) 143 | embedding = common_utils.normalize_embedding(embedding_input) 144 | loc_feature = tf.placeholder( 145 | tf.float32, shape=[1, None, None, 2]) 146 | 147 | # Combine embedding with location features and kmeans 148 | shape = tf.shape(embedding) 149 | # embedding_with_location = tf.concat([embedding, loc_feature], 3) 150 | # embedding_with_location = common_utils.normalize_embedding( 151 | # embedding_with_location) 152 | cluster_labels = common_utils.initialize_cluster_labels( 153 | [args.num_clusters, args.num_clusters], 154 | [shape[1], shape[2]]) 155 | embedding = tf.reshape(embedding, [-1, args.embedding_dim]) 156 | labels = tf.reshape(label_input, [-1]) 157 | cluster_labels = tf.reshape(cluster_labels, [-1]) 158 | location_features = tf.reshape(loc_feature, [-1, 2]) 159 | 160 | # Collect pixels of valid semantic classes. 161 | valid_pixels = tf.where( 162 | tf.not_equal(labels, args.ignore_label)) 163 | labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1) 164 | cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1) 165 | embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1) 166 | location_features = tf.squeeze( 167 | tf.gather(location_features, valid_pixels), axis=1) 168 | 169 | # Generate cluster labels via kmeans clustering. 170 | embedding_with_location = tf.concat([embedding, location_features], 1) 171 | embedding_with_location = common_utils.normalize_embedding( 172 | embedding_with_location) 173 | cluster_labels = common_utils.kmeans_with_initial_labels( 174 | embedding_with_location, 175 | cluster_labels, 176 | args.num_clusters * args.num_clusters, 177 | args.kmeans_iterations) 178 | 179 | cluster_labels, prototype_labels = common_utils.prepare_prototype_labels(labels, cluster_labels) 180 | prototype_features = common_utils.calculate_prototypes_from_labels( 181 | embedding, cluster_labels) 182 | 183 | # Set up tf session and initialize variables. 184 | config = tf.ConfigProto() 185 | config.gpu_options.allow_growth = True 186 | sess = tf.Session(config=config) 187 | init = tf.global_variables_initializer() 188 | 189 | sess.run(init) 190 | sess.run(tf.local_variables_initializer()) 191 | 192 | # Load weights. 193 | loader = tf.train.Saver(var_list=restore_var) 194 | if args.restore_from is not None: 195 | load(loader, sess, args.restore_from) 196 | 197 | # Start queue threads. 198 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 199 | 200 | # Create directory for saving prototypes. 201 | save_dir = os.path.join(args.save_dir, 'prototypes') 202 | if not os.path.isdir(save_dir): 203 | os.makedirs(save_dir) 204 | 205 | # Iterate over testing steps. 
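  # Each image is embedded with a sliding window: it is zero-padded up to
  # input_size, covered by overlapping input_size patches whose end indices
  # are spaced with np.linspace, and the per-patch embeddings are summed into
  # a full-size buffer. Worked example with input_size = strides = (512, 512)
  # and a padded width of 715:
  #
  #   npatches_w = ceil((715 - 512) / 512) + 1 = 2
  #   patch_indw = np.linspace(512, 715, 2)  # -> [512, 715]
  #
  # i.e. two horizontal patches covering columns 0:512 and 203:715.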
206 | with open(args.data_list, 'r') as listf: 207 | num_steps = len(listf.read().split('\n'))-1 208 | 209 | 210 | pbar = tqdm(range(num_steps)) 211 | for step in pbar: 212 | image_batch_np, label_batch_np = sess.run( 213 | [image_batch, label_batch]) 214 | 215 | img_size = image_batch_np.shape 216 | padded_img_size = list(img_size) # deep copy of img_size 217 | 218 | if input_size[0] > padded_img_size[1]: 219 | padded_img_size[1] = input_size[0] 220 | if input_size[1] > padded_img_size[2]: 221 | padded_img_size[2] = input_size[1] 222 | padded_img_batch = np.zeros(padded_img_size, 223 | dtype=np.float32) 224 | img_h, img_w = img_size[1:3] 225 | padded_img_batch[:, :img_h, :img_w, :] = image_batch_np 226 | 227 | stride_h, stride_w = strides 228 | npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1 229 | npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1 230 | 231 | # Create the ending index of each patch. 232 | patch_indh = np.linspace( 233 | input_size[0], padded_img_size[1], npatches_h, dtype=np.int32) 234 | patch_indw = np.linspace( 235 | input_size[1], padded_img_size[2], npatches_w, dtype=np.int32) 236 | 237 | # Create embedding holder. 238 | padded_img_size[-1] = args.embedding_dim 239 | embedding_all_np = np.zeros(padded_img_size, 240 | dtype=np.float32) 241 | for indh in patch_indh: 242 | for indw in patch_indw: 243 | sh, eh = indh-input_size[0], indh # start & end ind of H 244 | sw, ew = indw-input_size[1], indw # start & end ind of W 245 | cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :] 246 | 247 | embedding_np = sess.run(output, feed_dict={ 248 | crop_image_batch: cropimg_batch}) 249 | embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np 250 | 251 | embedding_all_np = embedding_all_np[:, :img_h, :img_w, :] 252 | loc_feature_np = common_utils.generate_location_features_np([padded_img_size[1], padded_img_size[2]]) 253 | feed_dict = {label_input: label_batch_np, 254 | embedding_input: embedding_all_np, 255 | loc_feature: loc_feature_np} 256 | 257 | (batch_prototype_features_np, 258 | batch_prototype_labels_np) = sess.run( 259 | [prototype_features, prototype_labels], 260 | feed_dict=feed_dict) 261 | 262 | if step == 0: 263 | prototype_features_np = batch_prototype_features_np 264 | prototype_labels_np = batch_prototype_labels_np 265 | else: 266 | prototype_features_np = np.concatenate( 267 | [prototype_features_np, batch_prototype_features_np], axis=0) 268 | prototype_labels_np = np.concatenate( 269 | [prototype_labels_np, 270 | batch_prototype_labels_np], axis=0) 271 | 272 | 273 | print ('Total number of prototypes extracted: ', 274 | len(prototype_labels_np)) 275 | np.save( 276 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'), 277 | mode='w'), prototype_features_np) 278 | np.save( 279 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'), 280 | mode='w'), prototype_labels_np) 281 | 282 | 283 | coord.request_stop() 284 | coord.join(threads) 285 | 286 | if __name__ == '__main__': 287 | main() 288 | -------------------------------------------------------------------------------- /pyscripts/inference/prototype_embedding_rgb.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | from tqdm import tqdm 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import scipy.io 12 | import scipy.misc 13 | import network.vmf.common_utils as common_utils 14 | import 
network.vmf.eval_utils as eval_utils 15 | from PIL import Image 16 | 17 | from seg_models.models.pspnet import pspnet_resnet101 as model 18 | from seg_models.image_reader import VMFImageReader 19 | import utils.general 20 | 21 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 22 | 23 | 24 | def get_arguments(): 25 | """Parse all the arguments provided from the CLI. 26 | 27 | Returns: 28 | A list of parsed arguments. 29 | """ 30 | parser = argparse.ArgumentParser( 31 | description='Extracting Prototypes for Semantic Segmentation') 32 | parser.add_argument('--data-dir', type=str, default='', 33 | help='/path/to/dataset.') 34 | parser.add_argument('--data-list', type=str, default='', 35 | help='/path/to/datalist/file.') 36 | parser.add_argument('--input-size', type=str, default='512,512', 37 | help='Comma-separated string with H and W of image.') 38 | parser.add_argument('--strides', type=str, default='512,512', 39 | help='Comma-separated string with strides of H and W.') 40 | parser.add_argument('--num-classes', type=int, default=21, 41 | help='Number of classes to predict.') 42 | parser.add_argument('--ignore-label', type=int, default=255, 43 | help='Index of label to ignore.') 44 | parser.add_argument('--restore-from', type=str, default='', 45 | help='Where restore model parameters from.') 46 | parser.add_argument('--save-dir', type=str, default='', 47 | help='/path/to/save/predictions.') 48 | parser.add_argument('--colormap', type=str, default='', 49 | help='/path/to/colormap/file.') 50 | # vMF parameters 51 | parser.add_argument('--embedding_dim', type=int, default=32, 52 | help='Dimension of the feature embeddings.') 53 | parser.add_argument('--num_clusters', type=int, default=5, 54 | help='Number of kmeans clusters along each axis') 55 | parser.add_argument('--kmeans_iterations', type=int, default=10, 56 | help='Number of kmeans iterations.') 57 | 58 | 59 | return parser.parse_args() 60 | 61 | def load(saver, sess, ckpt_path): 62 | """Load the trained weights. 63 | 64 | Args: 65 | saver: TensorFlow saver object. 66 | sess: TensorFlow session. 67 | ckpt_path: path to checkpoint file with parameters. 68 | """ 69 | saver.restore(sess, ckpt_path) 70 | print('Restored model parameters from {}'.format(ckpt_path)) 71 | 72 | def parse_commastr(str_comma): 73 | """Read comma-sperated string. 74 | """ 75 | if '' == str_comma: 76 | return None 77 | else: 78 | a, b = map(int, str_comma.split(',')) 79 | 80 | return [a,b] 81 | 82 | def main(): 83 | """Create the model and start the Inference process. 84 | """ 85 | args = get_arguments() 86 | 87 | # Parse image processing arguments. 88 | input_size = parse_commastr(args.input_size) 89 | strides = parse_commastr(args.strides) 90 | assert(input_size is not None and strides is not None) 91 | h, w = input_size 92 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 93 | 94 | 95 | # Create queue coordinator. 96 | coord = tf.train.Coordinator() 97 | 98 | # Load the data reader. 99 | with tf.name_scope('create_inputs'): 100 | reader = VMFImageReader( 101 | args.data_dir, 102 | args.data_list, 103 | None, 104 | False, # No random scale. 105 | False, # No random mirror. 
106 | False, # No random crop, center crop instead 107 | args.ignore_label, 108 | IMG_MEAN) 109 | 110 | image = reader.image 111 | label = reader.label 112 | image_list = reader.image_list 113 | image_batch = tf.expand_dims(image, dim=0) 114 | label_batch = tf.expand_dims(label, dim=0) 115 | 116 | # Create input tensor to the Network 117 | crop_image_batch = tf.placeholder( 118 | name='crop_image_batch', 119 | shape=[1,input_size[0],input_size[1],3], 120 | dtype=tf.float32) 121 | 122 | # Create network and output prediction. 123 | outputs = model(crop_image_batch, 124 | args.embedding_dim, 125 | False, 126 | True) 127 | 128 | # Grab variable names which should be restored from checkpoints. 129 | restore_var = [ 130 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 131 | 132 | # Output predictions. 133 | output = outputs[0] 134 | output = tf.image.resize_bilinear( 135 | output, 136 | [input_size[0], input_size[1]]) 137 | 138 | # Input full-sized embedding 139 | label_input = tf.placeholder( 140 | tf.int32, shape=[1, None, None, 1]) 141 | embedding_input = tf.placeholder( 142 | tf.float32, shape=[1, None, None, args.embedding_dim]) 143 | embedding = common_utils.normalize_embedding(embedding_input) 144 | loc_feature = tf.placeholder( 145 | tf.float32, shape=[1, None, None, 2]) 146 | rgb_feature = tf.placeholder( 147 | tf.float32, shape=[1, None, None, 3]) 148 | 149 | # Combine embedding with location features and kmeans 150 | shape = tf.shape(embedding) 151 | cluster_labels = common_utils.initialize_cluster_labels( 152 | [args.num_clusters, args.num_clusters], 153 | [shape[1], shape[2]]) 154 | embedding = tf.reshape(embedding, [-1, args.embedding_dim]) 155 | labels = tf.reshape(label_input, [-1]) 156 | cluster_labels = tf.reshape(cluster_labels, [-1]) 157 | location_features = tf.reshape(loc_feature, [-1, 2]) 158 | rgb_features = common_utils.normalize_embedding( 159 | tf.reshape(rgb_feature, [-1, 3])) / args.embedding_dim 160 | 161 | # Collect pixels of valid semantic classes. 162 | valid_pixels = tf.where( 163 | tf.not_equal(labels, args.ignore_label)) 164 | labels = tf.squeeze(tf.gather(labels, valid_pixels), axis=1) 165 | cluster_labels = tf.squeeze(tf.gather(cluster_labels, valid_pixels), axis=1) 166 | embedding = tf.squeeze(tf.gather(embedding, valid_pixels), axis=1) 167 | location_features = tf.squeeze( 168 | tf.gather(location_features, valid_pixels), axis=1) 169 | rgb_features = tf.squeeze(tf.gather(rgb_features, valid_pixels), axis=1) 170 | 171 | # Generate cluster labels via kmeans clustering. 172 | embedding_with_location = tf.concat( 173 | [embedding, location_features, rgb_features], 1) 174 | embedding_with_location = common_utils.normalize_embedding( 175 | embedding_with_location) 176 | cluster_labels = common_utils.kmeans_with_initial_labels( 177 | embedding_with_location, 178 | cluster_labels, 179 | args.num_clusters * args.num_clusters, 180 | args.kmeans_iterations) 181 | _, cluster_labels = tf.unique(cluster_labels) 182 | 183 | # Find pixels of majority semantic classes. 184 | select_pixels, prototype_labels = eval_utils.find_majority_label_index( 185 | labels, cluster_labels) 186 | 187 | # Calculate the prototype features. 
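  # Only the pixels selected above (those agreeing with their segment's
  # majority semantic class) contribute. Assuming
  # calculate_prototypes_from_labels averages the embeddings per segment and
  # renormalizes, each prototype is the spherical mean of its segment.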
188 | cluster_labels = tf.squeeze(tf.gather(cluster_labels, select_pixels), axis=1) 189 | embedding = tf.squeeze(tf.gather(embedding, select_pixels), axis=1) 190 | 191 | prototype_features = common_utils.calculate_prototypes_from_labels( 192 | embedding, cluster_labels) 193 | 194 | 195 | # Set up tf session and initialize variables. 196 | config = tf.ConfigProto() 197 | config.gpu_options.allow_growth = True 198 | sess = tf.Session(config=config) 199 | init = tf.global_variables_initializer() 200 | 201 | sess.run(init) 202 | sess.run(tf.local_variables_initializer()) 203 | 204 | # Load weights. 205 | loader = tf.train.Saver(var_list=restore_var) 206 | if args.restore_from is not None: 207 | load(loader, sess, args.restore_from) 208 | 209 | # Start queue threads. 210 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 211 | 212 | # Create directory for saving prototypes. 213 | save_dir = os.path.join(args.save_dir, 'prototypes') 214 | if not os.path.isdir(save_dir): 215 | os.makedirs(save_dir) 216 | 217 | # Iterate over testing steps. 218 | with open(args.data_list, 'r') as listf: 219 | num_steps = len(listf.read().split('\n'))-1 220 | 221 | 222 | pbar = tqdm(range(num_steps)) 223 | for step in pbar: 224 | image_batch_np, label_batch_np = sess.run( 225 | [image_batch, label_batch]) 226 | 227 | img_size = image_batch_np.shape 228 | padded_img_size = list(img_size) # deep copy of img_size 229 | 230 | if input_size[0] > padded_img_size[1]: 231 | padded_img_size[1] = input_size[0] 232 | if input_size[1] > padded_img_size[2]: 233 | padded_img_size[2] = input_size[1] 234 | padded_img_batch = np.zeros(padded_img_size, 235 | dtype=np.float32) 236 | img_h, img_w = img_size[1:3] 237 | padded_img_batch[:, :img_h, :img_w, :] = image_batch_np 238 | 239 | stride_h, stride_w = strides 240 | npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1 241 | npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1 242 | 243 | # Create the ending index of each patch. 244 | patch_indh = np.linspace( 245 | input_size[0], padded_img_size[1], npatches_h, dtype=np.int32) 246 | patch_indw = np.linspace( 247 | input_size[1], padded_img_size[2], npatches_w, dtype=np.int32) 248 | 249 | # Create embedding holder. 
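  # Overlapping patch embeddings are accumulated by summation into this
  # holder; because the result is L2-normalized again once fed back through
  # embedding_input, overlap regions effectively contribute a renormalized
  # average of the patches covering them.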
250 | padded_img_size[-1] = args.embedding_dim 251 | embedding_all_np = np.zeros(padded_img_size, 252 | dtype=np.float32) 253 | for indh in patch_indh: 254 | for indw in patch_indw: 255 | sh, eh = indh-input_size[0], indh # start & end ind of H 256 | sw, ew = indw-input_size[1], indw # start & end ind of W 257 | cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :] 258 | 259 | embedding_np = sess.run(output, feed_dict={ 260 | crop_image_batch: cropimg_batch}) 261 | embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np 262 | 263 | embedding_all_np = embedding_all_np[:, :img_h, :img_w, :] 264 | loc_feature_np = common_utils.generate_location_features_np([padded_img_size[1], padded_img_size[2]]) 265 | feed_dict = {label_input: label_batch_np, 266 | embedding_input: embedding_all_np, 267 | loc_feature: loc_feature_np, 268 | rgb_feature: padded_img_batch} 269 | 270 | (batch_prototype_features_np, 271 | batch_prototype_labels_np) = sess.run( 272 | [prototype_features, prototype_labels], 273 | feed_dict=feed_dict) 274 | 275 | if step == 0: 276 | prototype_features_np = batch_prototype_features_np 277 | prototype_labels_np = batch_prototype_labels_np 278 | else: 279 | prototype_features_np = np.concatenate( 280 | [prototype_features_np, batch_prototype_features_np], axis=0) 281 | prototype_labels_np = np.concatenate( 282 | [prototype_labels_np, 283 | batch_prototype_labels_np], axis=0) 284 | 285 | 286 | print ('Total number of prototypes extracted: ', 287 | len(prototype_labels_np)) 288 | np.save( 289 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'), 290 | mode='w'), prototype_features_np) 291 | np.save( 292 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'), 293 | mode='w'), prototype_labels_np) 294 | 295 | 296 | coord.request_stop() 297 | coord.join(threads) 298 | 299 | if __name__ == '__main__': 300 | main() 301 | -------------------------------------------------------------------------------- /pyscripts/inference/prototype_embedding_with_flip.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | from tqdm import tqdm 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import scipy.io 12 | import scipy.misc 13 | import network.vmf.common_utils as common_utils 14 | import network.vmf.eval_utils as eval_utils 15 | from PIL import Image 16 | 17 | from seg_models.models.pspnet import pspnet_resnet101 as model 18 | from seg_models.image_reader import VMFImageReader 19 | import utils.general 20 | 21 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 22 | 23 | 24 | def get_arguments(): 25 | """Parse all the arguments provided from the CLI. 26 | 27 | Returns: 28 | A list of parsed arguments. 
29 | """ 30 | parser = argparse.ArgumentParser( 31 | description='Extracting Prototypes for Semantic Segmentation') 32 | parser.add_argument('--data-dir', type=str, default='', 33 | help='/path/to/dataset.') 34 | parser.add_argument('--data-list', type=str, default='', 35 | help='/path/to/datalist/file.') 36 | parser.add_argument('--input-size', type=str, default='512,512', 37 | help='Comma-separated string with H and W of image.') 38 | parser.add_argument('--strides', type=str, default='512,512', 39 | help='Comma-separated string with strides of H and W.') 40 | parser.add_argument('--num-classes', type=int, default=21, 41 | help='Number of classes to predict.') 42 | parser.add_argument('--ignore-label', type=int, default=255, 43 | help='Index of label to ignore.') 44 | parser.add_argument('--restore-from', type=str, default='', 45 | help='Where restore model parameters from.') 46 | parser.add_argument('--save-dir', type=str, default='', 47 | help='/path/to/save/predictions.') 48 | parser.add_argument('--colormap', type=str, default='', 49 | help='/path/to/colormap/file.') 50 | parser.add_argument('--flip-aug', action='store_true', 51 | help='Augment data by horizontal flipping.') 52 | # vMF parameters 53 | parser.add_argument('--embedding_dim', type=int, default=32, 54 | help='Dimension of the feature embeddings.') 55 | parser.add_argument('--num_clusters', type=int, default=5, 56 | help='Number of kmeans clusters along each axis') 57 | parser.add_argument('--kmeans_iterations', type=int, default=10, 58 | help='Number of kmeans iterations.') 59 | 60 | 61 | return parser.parse_args() 62 | 63 | def load(saver, sess, ckpt_path): 64 | """Load the trained weights. 65 | 66 | Args: 67 | saver: TensorFlow saver object. 68 | sess: TensorFlow session. 69 | ckpt_path: path to checkpoint file with parameters. 70 | """ 71 | saver.restore(sess, ckpt_path) 72 | print('Restored model parameters from {}'.format(ckpt_path)) 73 | 74 | def parse_commastr(str_comma): 75 | """Read comma-sperated string. 76 | """ 77 | if '' == str_comma: 78 | return None 79 | else: 80 | a, b = map(int, str_comma.split(',')) 81 | 82 | return [a,b] 83 | 84 | def main(): 85 | """Create the model and start the Inference process. 86 | """ 87 | args = get_arguments() 88 | 89 | # Parse image processing arguments. 90 | input_size = parse_commastr(args.input_size) 91 | strides = parse_commastr(args.strides) 92 | assert(input_size is not None and strides is not None) 93 | h, w = input_size 94 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 95 | 96 | 97 | # Create queue coordinator. 98 | coord = tf.train.Coordinator() 99 | 100 | # Load the data reader. 101 | with tf.name_scope('create_inputs'): 102 | reader = VMFImageReader( 103 | args.data_dir, 104 | args.data_list, 105 | None, 106 | False, # No random scale. 107 | False, # No random mirror. 108 | False, # No random crop, center crop instead 109 | args.ignore_label, 110 | IMG_MEAN) 111 | 112 | image = reader.image 113 | label = reader.label 114 | image_list = reader.image_list 115 | image_batch = tf.expand_dims(image, dim=0) 116 | label_batch = tf.expand_dims(label, dim=0) 117 | 118 | # Create input tensor to the Network 119 | crop_image_batch = tf.placeholder( 120 | name='crop_image_batch', 121 | shape=[1,input_size[0],input_size[1],3], 122 | dtype=tf.float32) 123 | 124 | # Create network and output prediction. 125 | outputs = model(crop_image_batch, 126 | args.embedding_dim, 127 | False, 128 | True) 129 | 130 | # Grab variable names which should be restored from checkpoints. 
131 | restore_var = [ 132 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 133 | 134 | # Output predictions. 135 | output = outputs[0] 136 | output = tf.image.resize_bilinear( 137 | output, 138 | [input_size[0], input_size[1]]) 139 | 140 | # Input full-sized embedding 141 | label_input = tf.placeholder( 142 | tf.int32, shape=[1, None, None, 1]) 143 | embedding_input = tf.placeholder( 144 | tf.float32, shape=[1, None, None, args.embedding_dim]) 145 | embedding = common_utils.normalize_embedding(embedding_input) 146 | loc_feature = tf.placeholder( 147 | tf.float32, shape=[1, None, None, 2]) 148 | 149 | # Combine embedding with location features and kmeans 150 | shape = tf.shape(embedding) 151 | # embedding_with_location = tf.concat([embedding, loc_feature], 3) 152 | # embedding_with_location = common_utils.normalize_embedding( 153 | # embedding_with_location) 154 | cluster_labels = common_utils.initialize_cluster_labels( 155 | [args.num_clusters, args.num_clusters], 156 | [shape[1], shape[2]]) 157 | embedding = tf.reshape(embedding, [-1, args.embedding_dim]) 158 | labels = tf.reshape(label_input, [-1]) 159 | cluster_labels = tf.reshape(cluster_labels, [-1]) 160 | location_features = tf.reshape(loc_feature, [-1, 2]) 161 | 162 | (prototype_features, 163 | prototype_labels, 164 | _) = eval_utils.extract_trained_prototypes( 165 | embedding, location_features, cluster_labels, 166 | args.num_clusters * args.num_clusters, 167 | args.kmeans_iterations, labels, 168 | 1, args.ignore_label, 169 | 'semantic') 170 | 171 | # Set up tf session and initialize variables. 172 | config = tf.ConfigProto() 173 | config.gpu_options.allow_growth = True 174 | sess = tf.Session(config=config) 175 | init = tf.global_variables_initializer() 176 | 177 | sess.run(init) 178 | sess.run(tf.local_variables_initializer()) 179 | 180 | # Load weights. 181 | loader = tf.train.Saver(var_list=restore_var) 182 | if args.restore_from is not None: 183 | load(loader, sess, args.restore_from) 184 | 185 | # Start queue threads. 186 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 187 | 188 | # Create directory for saving prototypes. 189 | save_dir = os.path.join(args.save_dir, 'prototypes') 190 | if not os.path.isdir(save_dir): 191 | os.makedirs(save_dir) 192 | 193 | # Iterate over testing steps. 194 | with open(args.data_list, 'r') as listf: 195 | num_steps = len(listf.read().split('\n'))-1 196 | 197 | 198 | def _compute_prototypes(image_batch_np, label_batch_np): 199 | img_size = image_batch_np.shape 200 | padded_img_size = list(img_size) # deep copy of img_size 201 | 202 | if input_size[0] > padded_img_size[1]: 203 | padded_img_size[1] = input_size[0] 204 | if input_size[1] > padded_img_size[2]: 205 | padded_img_size[2] = input_size[1] 206 | padded_img_batch = np.zeros(padded_img_size, 207 | dtype=np.float32) 208 | img_h, img_w = img_size[1:3] 209 | padded_img_batch[:, :img_h, :img_w, :] = image_batch_np 210 | 211 | stride_h, stride_w = strides 212 | npatches_h = math.ceil(1.0*(padded_img_size[1]-input_size[0])/stride_h) + 1 213 | npatches_w = math.ceil(1.0*(padded_img_size[2]-input_size[1])/stride_w) + 1 214 | 215 | # Create the ending index of each patch. 216 | patch_indh = np.linspace( 217 | input_size[0], padded_img_size[1], npatches_h, dtype=np.int32) 218 | patch_indw = np.linspace( 219 | input_size[1], padded_img_size[2], npatches_w, dtype=np.int32) 220 | 221 | # Create embedding holder. 
222 | padded_img_size[-1] = args.embedding_dim 223 | embedding_all_np = np.zeros(padded_img_size, 224 | dtype=np.float32) 225 | for indh in patch_indh: 226 | for indw in patch_indw: 227 | sh, eh = indh-input_size[0], indh # start & end ind of H 228 | sw, ew = indw-input_size[1], indw # start & end ind of W 229 | cropimg_batch = padded_img_batch[:, sh:eh, sw:ew, :] 230 | 231 | embedding_np = sess.run(output, feed_dict={ 232 | crop_image_batch: cropimg_batch}) 233 | embedding_all_np[:, sh:eh, sw:ew, :] += embedding_np 234 | 235 | embedding_all_np = embedding_all_np[:, :img_h, :img_w, :] 236 | loc_feature_np = common_utils.generate_location_features_np([padded_img_size[1], padded_img_size[2]]) 237 | feed_dict = {label_input: label_batch_np, 238 | embedding_input: embedding_all_np, 239 | loc_feature: loc_feature_np} 240 | 241 | (batch_prototype_features_np, 242 | batch_prototype_labels_np) = sess.run( 243 | [prototype_features, prototype_labels], 244 | feed_dict=feed_dict) 245 | return batch_prototype_features_np, batch_prototype_labels_np 246 | 247 | 248 | pbar = tqdm(range(num_steps)) 249 | for step in pbar: 250 | image_batch_np, label_batch_np = sess.run( 251 | [image_batch, label_batch]) 252 | 253 | (batch_prototype_features_np, 254 | batch_prototype_labels_np) = _compute_prototypes( 255 | image_batch_np, label_batch_np) 256 | 257 | if step == 0: 258 | prototype_features_np = batch_prototype_features_np 259 | prototype_labels_np = batch_prototype_labels_np 260 | else: 261 | prototype_features_np = np.concatenate( 262 | [prototype_features_np, batch_prototype_features_np], axis=0) 263 | prototype_labels_np = np.concatenate( 264 | [prototype_labels_np, 265 | batch_prototype_labels_np], axis=0) 266 | 267 | if args.flip_aug: 268 | image_batch_np = image_batch_np[:, :, ::-1, :] 269 | label_batch_np = label_batch_np[:, :, ::-1, :] 270 | 271 | (batch_prototype_features_np, 272 | batch_prototype_labels_np) = _compute_prototypes( 273 | image_batch_np, label_batch_np) 274 | 275 | prototype_features_np = np.concatenate( 276 | [prototype_features_np, batch_prototype_features_np], axis=0) 277 | prototype_labels_np = np.concatenate( 278 | [prototype_labels_np, 279 | batch_prototype_labels_np], axis=0) 280 | 281 | print ('Total number of prototypes extracted: ', 282 | len(prototype_labels_np)) 283 | np.save( 284 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'), 285 | mode='w'), prototype_features_np) 286 | np.save( 287 | tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'), 288 | mode='w'), prototype_labels_np) 289 | 290 | 291 | coord.request_stop() 292 | coord.join(threads) 293 | 294 | if __name__ == '__main__': 295 | main() 296 | -------------------------------------------------------------------------------- /pyscripts/inference/prototype_unsup.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import math 7 | 8 | import tensorflow as tf 9 | import numpy as np 10 | import scipy.io 11 | import scipy.misc 12 | import network.vmf.common_utils as common_utils 13 | import network.vmf.eval_utils as eval_utils 14 | from PIL import Image 15 | 16 | from seg_models.models.pspnet import pspnet_resnet101 as model 17 | from seg_models.image_reader import VMFImageReader 18 | import utils.general 19 | 20 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32) 21 | 22 | 23 | def get_arguments(): 24 | """Parse all the arguments provided from the CLI. 
25 | 26 | Returns: 27 | A list of parsed arguments. 28 | """ 29 | parser = argparse.ArgumentParser( 30 | description='Extracting Prototypes for Semantic Segmentation') 31 | parser.add_argument('--data-dir', type=str, default='', 32 | help='/path/to/dataset.') 33 | parser.add_argument('--data-list', type=str, default='', 34 | help='/path/to/datalist/file.') 35 | parser.add_argument('--input-size', type=str, default='512,512', 36 | help='Comma-separated string with H and W of image.') 37 | parser.add_argument('--strides', type=str, default='512,512', 38 | help='Comma-separated string with strides of H and W.') 39 | parser.add_argument('--num-classes', type=int, default=21, 40 | help='Number of classes to predict.') 41 | parser.add_argument('--ignore-label', type=int, default=255, 42 | help='Index of label to ignore.') 43 | parser.add_argument('--restore-from', type=str, default='', 44 | help='Where restore model parameters from.') 45 | parser.add_argument('--save-dir', type=str, default='', 46 | help='/path/to/save/predictions.') 47 | parser.add_argument('--colormap', type=str, default='', 48 | help='/path/to/colormap/file.') 49 | # vMF parameters 50 | parser.add_argument('--embedding_dim', type=int, default=32, 51 | help='Dimension of the feature embeddings.') 52 | parser.add_argument('--num_clusters', type=int, default=5, 53 | help='Number of kmeans clusters along each axis') 54 | parser.add_argument('--kmeans_iterations', type=int, default=10, 55 | help='Number of kmeans iterations.') 56 | 57 | 58 | return parser.parse_args() 59 | 60 | def load(saver, sess, ckpt_path): 61 | """Load the trained weights. 62 | 63 | Args: 64 | saver: TensorFlow saver object. 65 | sess: TensorFlow session. 66 | ckpt_path: path to checkpoint file with parameters. 67 | """ 68 | saver.restore(sess, ckpt_path) 69 | print('Restored model parameters from {}'.format(ckpt_path)) 70 | 71 | def parse_commastr(str_comma): 72 | """Read comma-sperated string. 73 | """ 74 | if '' == str_comma: 75 | return None 76 | else: 77 | a, b = map(int, str_comma.split(',')) 78 | 79 | return [a,b] 80 | 81 | def main(): 82 | """Create the model and start the Inference process. 83 | """ 84 | args = get_arguments() 85 | 86 | # Parse image processing arguments. 87 | input_size = parse_commastr(args.input_size) 88 | strides = parse_commastr(args.strides) 89 | assert(input_size is not None and strides is not None) 90 | h, w = input_size 91 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 92 | 93 | 94 | # Create queue coordinator. 95 | coord = tf.train.Coordinator() 96 | 97 | # Load the data reader. 98 | with tf.name_scope('create_inputs'): 99 | reader = VMFImageReader( 100 | args.data_dir, 101 | args.data_list, 102 | None, 103 | False, # No random scale. 104 | False, # No random mirror. 105 | False, # No random crop, center crop instead 106 | args.ignore_label, 107 | IMG_MEAN) 108 | 109 | image_batch = tf.expand_dims(reader.image, dim=0) 110 | label_batch = tf.expand_dims(reader.label, dim=0) 111 | cluster_label_batch = tf.expand_dims(reader.cluster_label, dim=0) 112 | loc_feature_batch = tf.expand_dims(reader.loc_feature, dim=0) 113 | 114 | # Create network and output prediction. 115 | outputs = model(image_batch, 116 | args.embedding_dim, 117 | False, 118 | True) 119 | 120 | # Grab variable names which should be restored from checkpoints. 121 | restore_var = [ 122 | v for v in tf.global_variables() if 'crop_image_batch' not in v.name] 123 | 124 | # Output predictions. 
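  # Unsupervised variant: the initial segments are not computed from the
  # embeddings here but come precomputed from the reader
  # (reader.cluster_label, e.g. an off-line over-segmentation);
  # extract_trained_prototypes then refines them before pooling prototypes.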
125 | output = outputs[0] 126 | output = tf.image.resize_bilinear( 127 | output, 128 | tf.shape(image_batch)[1:3,]) 129 | embedding = common_utils.normalize_embedding(output) 130 | 131 | shape = embedding.get_shape().as_list() 132 | batch_size = shape[0] 133 | 134 | labels = label_batch 135 | initial_cluster_labels = cluster_label_batch[0, :, :] 136 | location_features = tf.reshape(loc_feature_batch[0, :, :], [-1, 2]) 137 | 138 | prototype_feature_list = [] 139 | prototype_label_list = [] 140 | for bs in range(batch_size): 141 | cur_labels = tf.reshape(labels[bs], [-1]) 142 | cur_cluster_labels = tf.reshape(initial_cluster_labels, [-1]) 143 | cur_embedding = tf.reshape(embedding[bs], [-1, args.embedding_dim]) 144 | 145 | (prototype_features, 146 | prototype_labels, 147 | _) = eval_utils.extract_trained_prototypes( 148 | cur_embedding, location_features, cur_cluster_labels, 149 | args.num_clusters * args.num_clusters, 150 | args.kmeans_iterations, cur_labels, 151 | 1, args.ignore_label, 152 | 'semantic') 153 | 154 | prototype_feature_list.append(prototype_features) 155 | prototype_label_list.append(prototype_labels) 156 | 157 | prototype_features = tf.concat(prototype_feature_list, axis=0) 158 | prototype_labels = tf.concat(prototype_label_list, axis=0) 159 | 160 | 161 | # Set up tf session and initialize variables. 162 | config = tf.ConfigProto() 163 | config.gpu_options.allow_growth = True 164 | sess = tf.Session(config=config) 165 | init = tf.global_variables_initializer() 166 | 167 | sess.run(init) 168 | sess.run(tf.local_variables_initializer()) 169 | 170 | # Load weights. 171 | loader = tf.train.Saver(var_list=restore_var) 172 | if args.restore_from is not None: 173 | load(loader, sess, args.restore_from) 174 | 175 | # Start queue threads. 176 | threads = tf.train.start_queue_runners(coord=coord, sess=sess) 177 | 178 | # Create directory for saving prototypes. 179 | save_dir = os.path.join(args.save_dir, 'prototypes') 180 | if not os.path.isdir(save_dir): 181 | os.makedirs(save_dir) 182 | 183 | # Iterate over testing steps. 
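  # Prototypes are accumulated across images with repeated np.concatenate,
  # which re-copies the growing arrays every step. A sketch of a linear-time
  # alternative (variable names hypothetical):
  #
  #   feature_chunks, label_chunks = [], []
  #   for step in range(num_steps):
  #     f, l = sess.run([prototype_features, prototype_labels])
  #     feature_chunks.append(f)
  #     label_chunks.append(l)
  #   prototype_features_np = np.concatenate(feature_chunks, axis=0)
  #   prototype_labels_np = np.concatenate(label_chunks, axis=0)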
184 |   with open(args.data_list, 'r') as listf:
185 |     num_steps = len(listf.read().split('\n'))-1
186 | 
187 |   for step in range(num_steps):
188 |     (batch_prototype_features_np,
189 |      batch_prototype_labels_np) = sess.run(
190 |         [prototype_features, prototype_labels])
191 | 
192 |     if step == 0:
193 |       prototype_features_np = batch_prototype_features_np
194 |       prototype_labels_np = batch_prototype_labels_np
195 |     else:
196 |       prototype_features_np = np.concatenate(
197 |           [prototype_features_np, batch_prototype_features_np], axis=0)
198 |       prototype_labels_np = np.concatenate(
199 |           [prototype_labels_np,
200 |            batch_prototype_labels_np], axis=0)
201 | 
202 |     if (step + 1) % 100 == 0:
203 |       print('Processed batches: ', (step + 1), '/', num_steps)
204 | 
205 |   print('Total number of prototypes extracted: ',
206 |         len(prototype_labels_np))
207 |   np.save(
208 |       tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_features'),
209 |                     mode='w'), prototype_features_np)
210 |   np.save(
211 |       tf.gfile.Open('%s/%s.npy' % (save_dir, 'prototype_labels'),
212 |                     mode='w'), prototype_labels_np)
213 | 
214 | 
215 |   coord.request_stop()
216 |   coord.join(threads)
217 | 
218 | if __name__ == '__main__':
219 |   main()
220 | 
-------------------------------------------------------------------------------- /pyscripts/train/train_segsort.py: --------------------------------------------------------------------------------
1 | """Training script for PSPNet/ResNet-101 with SegSort."""
2 | 
3 | from __future__ import print_function
4 | 
5 | import argparse
6 | import math
7 | import os
8 | import time
9 | import utils.general
10 | 
11 | import network.common.layers as nn
12 | import network.segsort.train_utils as train_utils
13 | import numpy as np
14 | import tensorflow as tf
15 | 
16 | from seg_models.models.pspnet import pspnet_resnet101 as model
17 | from seg_models.image_reader import SegSortImageReader
18 | from tqdm import tqdm
19 | 
20 | 
21 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32)
22 | 
23 | 
24 | def get_arguments():
25 |   """Parses all the arguments provided from the CLI.
26 | 
27 |   Returns:
28 |     A list of parsed arguments.
29 |   """
30 |   parser = argparse.ArgumentParser(
31 |       description='SegSort: Segmentation by Discriminative Sorting of Segments.')
32 |   # Data parameters.
33 |   parser.add_argument('--batch_size', type=int, default=1,
34 |                       help='Number of images in a batch.')
35 |   parser.add_argument('--data_dir', type=str, default='',
36 |                       help='/path/to/dataset/.')
37 |   parser.add_argument('--data_list', type=str, default='',
38 |                       help='/path/to/datalist/file.')
39 |   parser.add_argument('--ignore_label', type=int, default=255,
40 |                       help='The index of the label to ignore.')
41 |   parser.add_argument('--input_size', type=str, default='336,336',
42 |                       help='Comma-separated string with H and W of image.')
43 |   parser.add_argument('--random_seed', type=int, default=1234,
44 |                       help='Random seed for reproducible results.')
45 |   # Training parameters.
46 |   parser.add_argument('--is_training', action='store_true',
47 |                       help='Whether to update weights.')
48 |   parser.add_argument('--use_global_status', action='store_true',
49 |                       help='Whether to update the moving mean and variance.')
50 |   parser.add_argument('--learning_rate', type=float, default=1e-3,
51 |                       help='Base learning rate.')
52 |   parser.add_argument('--power', type=float, default=0.9,
53 |                       help='Decay for the poly learning rate policy.')
54 |   parser.add_argument('--momentum', type=float, default=0.9,
55 |                       help='Momentum component of the optimiser.')
56 |   parser.add_argument('--weight_decay', type=float, default=5e-4,
57 |                       help='Regularization hyperparameter for L2-loss.')
58 |   parser.add_argument('--num_classes', type=int, default=21,
59 |                       help='Number of classes to predict.')
60 |   parser.add_argument('--num_steps', type=int, default=20000,
61 |                       help='Number of training steps.')
62 |   parser.add_argument('--iter_size', type=int, default=10,
63 |                       help='Number of weight updates per training step.')
64 |   parser.add_argument('--random_mirror', action='store_true',
65 |                       help='Whether to randomly mirror the inputs.')
66 |   parser.add_argument('--random_crop', action='store_true',
67 |                       help='Whether to randomly crop the inputs.')
68 |   parser.add_argument('--random_scale', action='store_true',
69 |                       help='Whether to randomly scale the inputs.')
70 |   # SegSort parameters.
71 |   parser.add_argument('--embedding_dim', type=int, default=32,
72 |                       help='Dimension of the feature embeddings.')
73 |   parser.add_argument('--concentration', type=float, default=10.0,
74 |                       help='Concentration of the vMF distribution.')
75 |   parser.add_argument('--num_clusters', type=int, default=5,
76 |                       help='Number of kmeans clusters along each axis.')
77 |   parser.add_argument('--kmeans_iterations', type=int, default=10,
78 |                       help='Number of kmeans iterations.')
79 |   parser.add_argument('--num_banks', type=int, default=2,
80 |                       help='Number of memory banks for prototypes.')
81 |   # Misc parameters.
82 |   parser.add_argument('--restore_from', type=str, default='',
83 |                       help='Where to restore model parameters from.')
84 |   parser.add_argument('--save_pred_every', type=int, default=10000,
85 |                       help='Interval (in steps) for saving summaries and checkpoints.')
86 |   parser.add_argument('--update_tb_every', type=int, default=20,
87 |                       help='Interval (in steps) for updating summaries.')
88 |   parser.add_argument('--snapshot_dir', type=str, default='',
89 |                       help='Where to save snapshots of the model.')
90 |   parser.add_argument('--not_restore_classifier', action='store_true',
91 |                       help='Whether to skip restoring the classifier layers.')
92 | 
93 |   return parser.parse_args()
94 | 
95 | 
96 | def save(saver, sess, logdir, step):
97 |   """Saves the trained weights.
98 | 
99 |   Args:
100 |     saver: TensorFlow Saver object.
101 |     sess: TensorFlow session.
102 |     logdir: path to the snapshots directory.
103 |     step: current training step.
104 |   """
105 |   model_name = 'model.ckpt'
106 |   checkpoint_path = os.path.join(logdir, model_name)
107 | 
108 |   if not os.path.exists(logdir):
109 |     os.makedirs(logdir)
110 |   saver.save(sess, checkpoint_path, global_step=step)
111 |   print('The checkpoint has been created.')
112 | 
113 | 
114 | def load(saver, sess, ckpt_path):
115 |   """Loads the trained weights.
116 | 
117 |   Args:
118 |     saver: TensorFlow Saver object.
119 |     sess: TensorFlow session.
120 |     ckpt_path: path to checkpoint file with parameters.
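
    Example (checkpoint path hypothetical):
      loader = tf.train.Saver(var_list=restore_var)
      load(loader, sess, 'snapshots/model.ckpt-20000')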
121 | """ 122 | saver.restore(sess, ckpt_path) 123 | print('Restored model parameters from {}'.format(ckpt_path)) 124 | 125 | 126 | def main(): 127 | """Creates the model and start training.""" 128 | 129 | # Read CL arguments and snapshot the arguments into text file. 130 | args = get_arguments() 131 | utils.general.snapshot_arg(args) 132 | 133 | # The segmentation network is stride 8 by default. 134 | h, w = map(int, args.input_size.split(',')) 135 | input_size = (h, w) 136 | innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8))) 137 | 138 | # Initialize the random seed. 139 | tf.set_random_seed(args.random_seed) 140 | 141 | # Create queue coordinator. 142 | coord = tf.train.Coordinator() 143 | 144 | # current step 145 | step_ph = tf.placeholder(dtype=tf.float32, shape=()) 146 | 147 | # Load the data reader. 148 | with tf.device('/cpu:0'): 149 | with tf.name_scope('create_inputs'): 150 | reader = SegSortImageReader( 151 | args.data_dir, 152 | args.data_list, 153 | input_size, 154 | args.random_scale, 155 | args.random_mirror, 156 | args.random_crop, 157 | args.ignore_label, 158 | IMG_MEAN, 159 | [args.num_clusters, args.num_clusters]) 160 | 161 | image_batch, label_batch, cluster_label_batch, loc_feature_batch = ( 162 | reader.dequeue(args.batch_size)) 163 | 164 | # Shrink labels to the size of the network output. 165 | labels = tf.image.resize_nearest_neighbor( 166 | label_batch, innet_size, name='label_shrink') 167 | cluster_labels = tf.image.resize_nearest_neighbor( 168 | cluster_label_batch, innet_size) 169 | loc_features = tf.image.resize_nearest_neighbor( 170 | loc_feature_batch, innet_size) 171 | 172 | # Create network and predictions. 173 | outputs = model(image_batch, 174 | args.embedding_dim, 175 | args.is_training, 176 | args.use_global_status) 177 | 178 | # Grab variable names which should be restored from checkpoints. 179 | restore_var = [ 180 | v for v in tf.global_variables() 181 | if 'block5' not in v.name or not args.not_restore_classifier 182 | ] 183 | 184 | # Add the SegSort loss. 185 | seg_losses = train_utils.add_segsort_loss( 186 | outputs[0], labels, args.embedding_dim, args.ignore_label, 187 | args.concentration, cluster_labels, args.num_clusters, 188 | args.kmeans_iterations, args.num_banks, loc_features) 189 | 190 | # Define weight regularization loss. 191 | w = args.weight_decay 192 | l2_losses = [w*tf.nn.l2_loss(v) for v in tf.trainable_variables() 193 | if 'weights' in v.name] 194 | 195 | # Sum all loss terms. 196 | mean_seg_loss = seg_losses 197 | mean_l2_loss = tf.add_n(l2_losses) 198 | reduced_loss = mean_seg_loss + mean_l2_loss 199 | 200 | # Grab variable names which are used for training. 201 | all_trainable = tf.trainable_variables() 202 | fc_trainable = [v for v in all_trainable if 'block5' in v.name] # lr*10 203 | base_trainable = [v for v in all_trainable if 'block5' not in v.name] # lr*1 204 | 205 | # Computes gradients per iteration. 206 | grads = tf.gradients(reduced_loss, base_trainable+fc_trainable) 207 | grads_base = grads[:len(base_trainable)] 208 | grads_fc = grads[len(base_trainable):] 209 | 210 | # Define optimisation parameters. 211 | base_lr = tf.constant(args.learning_rate) 212 | learning_rate = tf.scalar_mul( 213 | base_lr, 214 | tf.pow((1-step_ph/args.num_steps), args.power)) 215 | 216 | opt_base = tf.train.MomentumOptimizer(learning_rate*1.0, args.momentum) 217 | opt_fc = tf.train.MomentumOptimizer(learning_rate*10.0, args.momentum) 218 | 219 | # Define tensorflow operations which apply gradients to update variables. 
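  # Two optimizers implement the usual 1x/10x learning-rate split: the
  # backbone trains at the poly-decayed base rate and the 'block5' head at
  # ten times that rate. With the defaults (learning_rate=1e-3, power=0.9),
  # e.g. halfway through training:
  #
  #   lr(step=10000, num_steps=20000) = 1e-3 * (1 - 0.5)**0.9 ~= 5.4e-4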
220 | train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable)) 221 | train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable)) 222 | train_op = tf.group(train_op_base, train_op_fc) 223 | 224 | # Process for visualisation. 225 | with tf.device('/cpu:0'): 226 | # Image summary for input image, ground-truth label and prediction. 227 | 228 | # Visualize first 3 channels of embeddings. 229 | # Can also perform PCA by calling the vis_utils.pca function. 230 | output_vis = tf.image.resize_nearest_neighbor( 231 | outputs[-1], tf.shape(image_batch)[1:3,]) 232 | output_vis = tf.argmax(output_vis, axis=3) 233 | output_vis = tf.expand_dims(output_vis, dim=3) 234 | output_vis = tf.cast(output_vis, dtype=tf.uint8) 235 | 236 | labels_vis = tf.cast(label_batch, dtype=tf.uint8) 237 | 238 | in_summary = tf.py_func( 239 | utils.general.inv_preprocess, 240 | [image_batch, IMG_MEAN], 241 | tf.uint8) 242 | gt_summary = tf.py_func( 243 | utils.general.decode_labels, 244 | [labels_vis, args.num_classes], 245 | tf.uint8) 246 | out_summary = tf.py_func( 247 | utils.general.decode_labels, 248 | [output_vis, args.num_classes], 249 | tf.uint8) 250 | # Concatenate image summaries in a row. 251 | total_summary = tf.summary.image( 252 | 'images', 253 | tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]), 254 | max_outputs=args.batch_size) 255 | 256 | # Scalar summary for different loss terms. 257 | seg_loss_summary = tf.summary.scalar( 258 | 'seg_loss', mean_seg_loss) 259 | total_summary = tf.summary.merge_all() 260 | 261 | summary_writer = tf.summary.FileWriter( 262 | args.snapshot_dir, 263 | graph=tf.get_default_graph()) 264 | 265 | # Set up tf session and initialize variables. 266 | config = tf.ConfigProto() 267 | config.gpu_options.allow_growth = True 268 | sess = tf.Session(config=config) 269 | init = tf.global_variables_initializer() 270 | 271 | sess.run(init) 272 | 273 | # Saver for storing checkpoints of the model. 274 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10) 275 | 276 | # Load variables if the checkpoint is provided. 277 | if args.restore_from is not None: 278 | loader = tf.train.Saver(var_list=restore_var) 279 | load(loader, sess, args.restore_from) 280 | 281 | # Start queue threads. 282 | threads = tf.train.start_queue_runners( 283 | coord=coord, sess=sess) 284 | 285 | # Iterate over training steps. 286 | pbar = tqdm(range(args.num_steps)) 287 | for step in pbar: 288 | start_time = time.time() 289 | feed_dict = {step_ph : step} 290 | 291 | step_loss = 0 292 | for it in range(args.iter_size): 293 | # Update summary periodically. 294 | if it == args.iter_size-1 and step % args.update_tb_every == 0: 295 | sess_outs = [reduced_loss, total_summary, train_op] 296 | loss_value, summary, _ = sess.run(sess_outs, 297 | feed_dict=feed_dict) 298 | summary_writer.add_summary(summary, step) 299 | else: 300 | sess_outs = [reduced_loss, train_op] 301 | loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict) 302 | 303 | step_loss += loss_value 304 | 305 | step_loss /= args.iter_size 306 | 307 | lr = sess.run(learning_rate, feed_dict=feed_dict) 308 | 309 | # Save trained model periodically. 
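# Note: checkpoints are written every args.save_pred_every steps (step 0 is
# skipped), and the Saver above keeps at most the 10 most recent snapshots.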
310 |     if step % args.save_pred_every == 0 and step > 0:
311 |       save(saver, sess, args.snapshot_dir, step)
312 | 
313 |     duration = time.time() - start_time
314 |     desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr)
315 |     pbar.set_description(desc)
316 | 
317 |   coord.request_stop()
318 |   coord.join(threads)
319 | 
320 | if __name__ == '__main__':
321 |   main()
322 | 
--------------------------------------------------------------------------------
/pyscripts/train/train_segsort_unsup.py:
--------------------------------------------------------------------------------
1 | """Training script for unsupervised training of PSPNet/ResNet-101 with SegSort."""
2 | 
3 | from __future__ import print_function
4 | 
5 | import argparse
6 | import math
7 | import os
8 | import time
9 | import utils.general
10 | 
11 | import network.common.layers as nn
12 | import network.segsort.train_utils as train_utils
13 | import numpy as np
14 | import tensorflow as tf
15 | 
16 | from seg_models.models.pspnet import pspnet_resnet101 as model
17 | from seg_models.image_reader import SegSortUnsupImageReader
18 | from tqdm import tqdm
19 | 
20 | 
21 | IMG_MEAN = np.array((122.675, 116.669, 104.008), dtype=np.float32)
22 | 
23 | 
24 | def get_arguments():
25 |   """Parse all the arguments provided from the CLI.
26 | 
27 |   Returns:
28 |     The parsed arguments.
29 |   """
30 |   parser = argparse.ArgumentParser(description='Semantic Segmentation')
31 |   # Data parameters.
32 |   parser.add_argument('--batch_size', type=int, default=1,
33 |                       help='Number of images in one step.')
34 |   parser.add_argument('--data_dir', type=str, default='',
35 |                       help='/path/to/dataset/.')
36 |   parser.add_argument('--data_list', type=str, default='',
37 |                       help='/path/to/datalist/file.')
38 |   parser.add_argument('--ignore_label', type=int, default=255,
39 |                       help='The index of the label to ignore.')
40 |   parser.add_argument('--input_size', type=str, default='336,336',
41 |                       help='Comma-separated string with H and W of image.')
42 |   parser.add_argument('--random_seed', type=int, default=1234,
43 |                       help='Random seed to have reproducible results.')
44 |   # Training parameters.
45 |   parser.add_argument('--is_training', action='store_true',
46 |                       help='Whether to update weights.')
47 |   parser.add_argument('--use_global_status', action='store_true',
48 |                       help='Whether to update the moving mean and variance.')
49 |   parser.add_argument('--learning_rate', type=float, default=2.5e-4,
50 |                       help='Base learning rate.')
51 |   parser.add_argument('--power', type=float, default=0.9,
52 |                       help='Decay for the poly learning rate policy.')
53 |   parser.add_argument('--momentum', type=float, default=0.9,
54 |                       help='Momentum component of the optimiser.')
55 |   parser.add_argument('--weight_decay', type=float, default=5e-4,
56 |                       help='Regularization hyperparameter for the L2 loss.')
57 |   parser.add_argument('--num_classes', type=int, default=21,
58 |                       help='Number of classes to predict.')
59 |   parser.add_argument('--num_steps', type=int, default=20000,
60 |                       help='Number of training steps.')
61 |   parser.add_argument('--iter_size', type=int, default=10,
62 |                       help='Number of inner iterations per training step.')
63 |   parser.add_argument('--random_mirror', action='store_true',
64 |                       help='Whether to randomly mirror the inputs.')
65 |   parser.add_argument('--random_crop', action='store_true',
66 |                       help='Whether to randomly crop the inputs.')
67 |   parser.add_argument('--random_scale', action='store_true',
68 |                       help='Whether to randomly scale the inputs.')
69 |   # SegSort parameters.
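# Note: --embedding_dim sets the dimensionality of the per-pixel embeddings,
# which SegSort models with von Mises-Fisher (vMF) distributions on the unit
# hypersphere; --concentration is the vMF kappa (larger values correspond to
# tighter clusters around each segment prototype).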
70 |   parser.add_argument('--embedding_dim', type=int, default=32,
71 |                       help='Dimension of the feature embeddings.')
72 |   parser.add_argument('--concentration', type=float, default=10.0,
73 |                       help='Concentration of the vMF distribution.')
74 |   # Misc parameters.
75 |   parser.add_argument('--restore_from', type=str, default='',
76 |                       help='Where to restore model parameters from.')
77 |   parser.add_argument('--save_pred_every', type=int, default=10000,
78 |                       help='Save a checkpoint every this many steps.')
79 |   parser.add_argument('--update_tb_every', type=int, default=20,
80 |                       help='Update TensorBoard summaries every this many steps.')
81 |   parser.add_argument('--snapshot_dir', type=str, default='',
82 |                       help='Where to save snapshots of the model.')
83 |   parser.add_argument('--not_restore_classifier', action='store_true',
84 |                       help='Whether to skip restoring the classifier layers.')
85 | 
86 |   return parser.parse_args()
87 | 
88 | 
89 | def save(saver, sess, logdir, step):
90 |   """Saves the trained weights.
91 | 
92 |   Args:
93 |     saver: TensorFlow Saver object.
94 |     sess: TensorFlow session.
95 |     logdir: path to the snapshots directory.
96 |     step: current training step.
97 |   """
98 |   model_name = 'model.ckpt'
99 |   checkpoint_path = os.path.join(logdir, model_name)
100 | 
101 |   if not os.path.exists(logdir):
102 |     os.makedirs(logdir)
103 |   saver.save(sess, checkpoint_path, global_step=step)
104 |   print('The checkpoint has been created.')
105 | 
106 | 
107 | def load(saver, sess, ckpt_path):
108 |   """Loads the trained weights.
109 | 
110 |   Args:
111 |     saver: TensorFlow Saver object.
112 |     sess: TensorFlow session.
113 |     ckpt_path: path to checkpoint file with parameters.
114 |   """
115 |   saver.restore(sess, ckpt_path)
116 |   print('Restored model parameters from {}'.format(ckpt_path))
117 | 
118 | 
119 | def main():
120 |   """Creates the model and starts training.
121 |   """
122 |   # Read CL arguments and snapshot the arguments into text file.
123 |   args = get_arguments()
124 |   utils.general.snapshot_arg(args)
125 | 
126 |   # The segmentation network is stride 8 by default.
127 |   h, w = map(int, args.input_size.split(','))
128 |   input_size = (h, w)
129 |   innet_size = (int(math.ceil(h/8)), int(math.ceil(w/8)))
130 | 
131 |   # Initialize the random seed.
132 |   tf.set_random_seed(args.random_seed)
133 | 
134 |   # Create queue coordinator.
135 |   coord = tf.train.Coordinator()
136 | 
137 |   # Current training step.
138 |   step_ph = tf.placeholder(dtype=tf.float32, shape=())
139 | 
140 |   # Load the data reader.
141 |   with tf.device('/cpu:0'):
142 |     with tf.name_scope('create_inputs'):
143 |       reader = SegSortUnsupImageReader(
144 |         args.data_dir,
145 |         args.data_list,
146 |         input_size,
147 |         args.random_scale,
148 |         args.random_mirror,
149 |         args.random_crop,
150 |         args.ignore_label,
151 |         IMG_MEAN)
152 | 
153 |       image_batch, _, cluster_label_batch = (
154 |         reader.dequeue(args.batch_size))
155 | 
156 |   # Shrink labels to the size of the network output.
157 |   cluster_labels = tf.image.resize_nearest_neighbor(
158 |     cluster_label_batch, innet_size)
159 | 
160 |   # Create network and predictions.
161 |   outputs = model(image_batch,
162 |                   args.embedding_dim,
163 |                   args.is_training,
164 |                   args.use_global_status)
165 | 
166 |   # Grab variable names which should be restored from checkpoints.
167 |   restore_var = [
168 |     v for v in tf.global_variables()
169 |     if 'block5' not in v.name or not args.not_restore_classifier
170 |   ]
171 | 
172 |   # Add the unsupervised SegSort loss.
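# In the unsupervised variant there are no ground-truth labels: the loss
# below pulls each pixel embedding toward the vMF prototype of its own
# pseudo-segment (the low-level cluster labels produced by the image
# reader) and pushes it away from the prototypes of other segments.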
173 | seg_losses = train_utils.add_unsupervised_segsort_loss( 174 | outputs[0], args.concentration, cluster_labels, ) 175 | 176 | # Define weight regularization loss. 177 | w = args.weight_decay 178 | l2_losses = [w*tf.nn.l2_loss(v) for v in tf.trainable_variables() 179 | if 'weights' in v.name] 180 | 181 | # Sum all loss terms. 182 | mean_seg_loss = seg_losses 183 | mean_l2_loss = tf.add_n(l2_losses) 184 | reduced_loss = mean_seg_loss + mean_l2_loss 185 | 186 | # Grab variable names which are used for training. 187 | all_trainable = tf.trainable_variables() 188 | fc_trainable = [v for v in all_trainable if 'block5' in v.name] # lr*10 189 | base_trainable = [v for v in all_trainable if 'block5' not in v.name] # lr*1 190 | 191 | # Computes gradients per iteration. 192 | grads = tf.gradients(reduced_loss, base_trainable+fc_trainable) 193 | grads_base = grads[:len(base_trainable)] 194 | grads_fc = grads[len(base_trainable):] 195 | 196 | # Define optimisation parameters. 197 | base_lr = tf.constant(args.learning_rate) 198 | learning_rate = tf.scalar_mul( 199 | base_lr, 200 | tf.pow((1-step_ph/args.num_steps), args.power)) 201 | 202 | opt_base = tf.train.MomentumOptimizer(learning_rate*1.0, args.momentum) 203 | opt_fc = tf.train.MomentumOptimizer(learning_rate*10.0, args.momentum) 204 | 205 | # Define tensorflow operations which apply gradients to update variables. 206 | train_op_base = opt_base.apply_gradients(zip(grads_base, base_trainable)) 207 | train_op_fc = opt_fc.apply_gradients(zip(grads_fc, fc_trainable)) 208 | train_op = tf.group(train_op_base, train_op_fc) 209 | 210 | # Process for visualisation. 211 | with tf.device('/cpu:0'): 212 | # Image summary for input image, ground-truth label and prediction. 213 | output_vis = tf.image.resize_nearest_neighbor( 214 | outputs[-1], tf.shape(image_batch)[1:3,]) 215 | output_vis = tf.argmax(output_vis, axis=3) 216 | output_vis = tf.expand_dims(output_vis, dim=3) 217 | output_vis = tf.cast(output_vis, dtype=tf.uint8) 218 | 219 | labels_vis = tf.cast(cluster_label_batch, dtype=tf.uint8) 220 | 221 | in_summary = tf.py_func( 222 | utils.general.inv_preprocess, 223 | [image_batch, IMG_MEAN], 224 | tf.uint8) 225 | gt_summary = tf.py_func( 226 | utils.general.decode_labels, 227 | [labels_vis, args.num_classes], 228 | tf.uint8) 229 | out_summary = tf.py_func( 230 | utils.general.decode_labels, 231 | [output_vis, args.num_classes], 232 | tf.uint8) 233 | # Concatenate image summaries in a row. 234 | total_summary = tf.summary.image( 235 | 'images', 236 | tf.concat(axis=2, values=[in_summary, gt_summary, out_summary]), 237 | max_outputs=args.batch_size) 238 | 239 | # Scalar summary for different loss terms. 240 | seg_loss_summary = tf.summary.scalar( 241 | 'seg_loss', mean_seg_loss) 242 | total_summary = tf.summary.merge_all() 243 | 244 | summary_writer = tf.summary.FileWriter( 245 | args.snapshot_dir, 246 | graph=tf.get_default_graph()) 247 | 248 | # Set up tf session and initialize variables. 249 | config = tf.ConfigProto() 250 | config.gpu_options.allow_growth = True 251 | sess = tf.Session(config=config) 252 | init = tf.global_variables_initializer() 253 | 254 | sess.run(init) 255 | 256 | # Saver for storing checkpoints of the model. 257 | saver = tf.train.Saver(var_list=tf.global_variables(), max_to_keep=10) 258 | 259 | # Load variables if the checkpoint is provided. 260 | if args.restore_from is not None: 261 | loader = tf.train.Saver(var_list=restore_var) 262 | load(loader, sess, args.restore_from) 263 | 264 | # Start queue threads. 
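# Note: the TF1 queue-based input pipeline built by the image reader is
# driven by background threads; they must be started before the first
# sess.run call and are joined via the coordinator on shutdown.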
265 | threads = tf.train.start_queue_runners( 266 | coord=coord, sess=sess) 267 | 268 | # Iterate over training steps. 269 | pbar = tqdm(range(args.num_steps)) 270 | for step in pbar: 271 | start_time = time.time() 272 | feed_dict = {step_ph : step} 273 | 274 | step_loss = 0 275 | for it in range(args.iter_size): 276 | # Update summary periodically. 277 | if it == args.iter_size-1 and step % args.update_tb_every == 0: 278 | sess_outs = [reduced_loss, total_summary, train_op] 279 | loss_value, summary, _ = sess.run(sess_outs, 280 | feed_dict=feed_dict) 281 | summary_writer.add_summary(summary, step) 282 | else: 283 | sess_outs = [reduced_loss, train_op] 284 | loss_value, _ = sess.run(sess_outs, feed_dict=feed_dict) 285 | 286 | step_loss += loss_value 287 | 288 | step_loss /= args.iter_size 289 | 290 | lr = sess.run(learning_rate, feed_dict=feed_dict) 291 | 292 | # Save trained model periodically. 293 | if step % args.save_pred_every == 0 and step > 0: 294 | save(saver, sess, args.snapshot_dir, step) 295 | 296 | duration = time.time() - start_time 297 | desc = 'loss = {:.3f}, lr = {:.6f}'.format(step_loss, lr) 298 | pbar.set_description(desc) 299 | 300 | coord.request_stop() 301 | coord.join(threads) 302 | 303 | if __name__ == '__main__': 304 | main() 305 | -------------------------------------------------------------------------------- /seg_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/seg_models/__init__.py -------------------------------------------------------------------------------- /seg_models/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/seg_models/models/__init__.py -------------------------------------------------------------------------------- /seg_models/models/deeplab.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from network.common.resnet_v1 import resnet_v1_101 4 | import network.common.layers as nn 5 | 6 | def _deeplab_builder(x, 7 | name, 8 | cnn_fn, 9 | num_classes, 10 | is_training, 11 | use_global_status, 12 | reuse=False): 13 | """Helper function to build Deeplab v2 model for semantic segmentation. 14 | 15 | The Deeplab v2 model is composed of one base network (ResNet101) and 16 | one ASPP module (4 Atrous Convolutional layers of different size). The 17 | segmentation prediction is the summation of 4 outputs of the ASPP module. 18 | 19 | Args: 20 | x: A tensor of size [batch_size, height_in, width_in, channels]. 21 | name: The prefix of tensorflow variables defined in this network. 22 | cnn_fn: A function which builds the base network (ResNet101). 23 | num_classes: Number of predicted classes for classification tasks. 24 | is_training: If the tensorflow variables defined in this network 25 | would be used for training. 26 | use_global_status: enable/disable use_global_status for batch 27 | normalization. If True, moving mean and moving variance are updated 28 | by exponential decay. 29 | reuse: enable/disable reuse for reusing tensorflow variables. It is 30 | useful for sharing weight parameters across two identical networks. 31 | 32 | Returns: 33 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 34 | """ 35 | # Build the base network. 
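# The backbone is a dilated ResNet-101 running at output stride 8, so the
# ASPP branches below (dilation rates 6, 12, 18 and 24) operate on
# h/8 x w/8 feature maps; their per-class logits are summed into one score.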
36 | x = cnn_fn(x, name, is_training, use_global_status, reuse) 37 | 38 | with tf.variable_scope(name, reuse=reuse) as scope: 39 | # Build the ASPP module. 40 | aspp = [] 41 | for i,dilation in enumerate([6, 12, 18, 24]): 42 | score = nn.atrous_conv( 43 | x, 44 | name='fc1_c{:d}'.format(i), 45 | filters=num_classes, 46 | kernel_size=3, 47 | dilation=dilation, 48 | padding='SAME', 49 | relu=False, 50 | biased=True, 51 | bn=False, 52 | is_training=is_training) 53 | aspp.append(score) 54 | 55 | score = tf.add_n(aspp, name='fc1_sum') 56 | 57 | return score 58 | 59 | 60 | def deeplab_resnet101(x, 61 | num_classes, 62 | is_training, 63 | use_global_status, 64 | reuse=False): 65 | """Builds Deeplab v2 based on ResNet101. 66 | 67 | Args: 68 | x: A tensor of size [batch_size, height_in, width_in, channels]. 69 | name: The prefix of tensorflow variables defined in this network. 70 | num_classes: Number of predicted classes for classification tasks. 71 | is_training: If the tensorflow variables defined in this network 72 | would be used for training. 73 | use_global_status: enable/disable use_global_status for batch 74 | normalization. If True, moving mean and moving variance are updated 75 | by exponential decay. 76 | reuse: enable/disable reuse for reusing tensorflow variables. It is 77 | useful for sharing weight parameters across two identical networks. 78 | 79 | Returns: 80 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 81 | """ 82 | h, w = x.get_shape().as_list()[1:3] # NxHxWxC 83 | 84 | scores = [] 85 | for i,scale in enumerate([1]): 86 | with tf.name_scope('scale_{:d}'.format(i)) as scope: 87 | x_in = x 88 | 89 | score = _deeplab_builder( 90 | x_in, 91 | 'resnet_v1_101', 92 | resnet_v1_101, 93 | num_classes, 94 | is_training, 95 | use_global_status, 96 | reuse=reuse) 97 | 98 | scores.append(score) 99 | 100 | return scores 101 | -------------------------------------------------------------------------------- /seg_models/models/fcn.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from network.common.resnet_v1 import resnet_v1_101 4 | import network.common.layers as nn 5 | 6 | def _fcn_builder(x, 7 | name, 8 | cnn_fn, 9 | num_classes, 10 | is_training, 11 | use_global_status, 12 | reuse=False): 13 | """Helper function to build FCN8s model for semantic segmentation. 14 | 15 | The FCN8s model is composed of one base network (ResNet101) and 16 | one classifier. 17 | 18 | Args: 19 | x: A tensor of size [batch_size, height_in, width_in, channels]. 20 | name: The prefix of tensorflow variables defined in this network. 21 | cnn_fn: A function which builds the base network (ResNet101). 22 | num_classes: Number of predicted classes for classification tasks. 23 | is_training: If the tensorflow variables defined in this network 24 | would be used for training. 25 | use_global_status: enable/disable use_global_status for batch 26 | normalization. If True, moving mean and moving variance are updated 27 | by exponential decay. 28 | reuse: enable/disable reuse for reusing tensorflow variables. It is 29 | useful for sharing weight parameters across two identical networks. 30 | 31 | Returns: 32 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 33 | """ 34 | h, w = x.get_shape().as_list()[1:3] # NxHxWxC 35 | assert(h%48 == 0 and w%48 == 0 and h == w) 36 | 37 | # Build the base network. 
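# Unlike the ASPP head in deeplab.py, this FCN head simply appends a single
# 1x1 convolution ('block5/fc1_voc12') to the stride-8 backbone features to
# produce the per-class logits.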
38 | x = cnn_fn(x, name, is_training, use_global_status, reuse) 39 | 40 | with tf.variable_scope(name, reuse=reuse) as scope: 41 | x = nn.conv(x, 42 | 'block5/fc1_voc12', 43 | num_classes, 44 | 1, 45 | 1, 46 | padding='SAME', 47 | biased=True, 48 | bn=False, 49 | relu=False, 50 | is_training=is_training) 51 | 52 | return x 53 | 54 | 55 | def fcn8s_resnet101(x, 56 | num_classes, 57 | is_training, 58 | use_global_status, 59 | reuse=False): 60 | """Builds FCN8s model based on ResNet101. 61 | 62 | Args: 63 | x: A tensor of size [batch_size, height_in, width_in, channels]. 64 | name: The prefix of tensorflow variables defined in this network. 65 | num_classes: Number of predicted classes for classification tasks. 66 | is_training: If the tensorflow variables defined in this network 67 | would be used for training. 68 | use_global_status: enable/disable use_global_status for batch 69 | normalization. If True, moving mean and moving variance are updated 70 | by exponential decay. 71 | reuse: enable/disable reuse for reusing tensorflow variables. It is 72 | useful for sharing weight parameters across two identical networks. 73 | 74 | Returns: 75 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 76 | """ 77 | scores = [] 78 | with tf.name_scope('scale_0') as scope: 79 | score = _fcn_builder( 80 | x, 81 | 'resnet_v1_101', 82 | resnet_v1_101, 83 | num_classes, 84 | is_training, 85 | use_global_status, 86 | reuse=reuse) 87 | 88 | scores.append(score) 89 | 90 | return scores 91 | -------------------------------------------------------------------------------- /seg_models/models/pspnet.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from network.common.resnet_v1 import resnet_v1_101 4 | import network.common.layers as nn 5 | 6 | def _pspnet_builder(x, 7 | name, 8 | cnn_fn, 9 | num_classes, 10 | is_training, 11 | use_global_status, 12 | reuse=False): 13 | """Helper function to build PSPNet model for semantic segmentation. 14 | 15 | The PSPNet model is composed of one base network (ResNet101) and 16 | one pyramid spatial pooling (PSP) module, followed with concatenation 17 | and two more convlutional layers for segmentation prediction. 18 | 19 | Args: 20 | x: A tensor of size [batch_size, height_in, width_in, channels]. 21 | name: The prefix of tensorflow variables defined in this network. 22 | cnn_fn: A function which builds the base network (ResNet101). 23 | num_classes: Number of predicted classes for classification tasks. 24 | is_training: If the tensorflow variables defined in this network 25 | would be used for training. 26 | use_global_status: enable/disable use_global_status for batch 27 | normalization. If True, moving mean and moving variance are updated 28 | by exponential decay. 29 | reuse: enable/disable reuse for reusing tensorflow variables. It is 30 | useful for sharing weight parameters across two identical networks. 31 | 32 | Returns: 33 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 34 | """ 35 | # Ensure that the size of input data is valid (should be multiple of 6x8=48). 36 | h, w = x.get_shape().as_list()[1:3] # NxHxWxC 37 | assert(h%48 == 0 and w%48 == 0 and h == w) 38 | 39 | # Build the base network. 40 | x = cnn_fn(x, name, is_training, use_global_status, reuse) 41 | 42 | with tf.variable_scope(name, reuse=reuse) as scope: 43 | # Build the PSP module 44 | pool_k = int(h/8) # the base network is stride 8 by default. 
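# Note: with 336x336 inputs (e.g. the unsupervised training script's
# default), pool_k = 42, so the four branches below pool with kernels
# 42, 21, 14 and 7, producing 1x1, 2x2, 3x3 and 6x6 grids; each grid is
# projected to 512 channels and bilinearly upsampled back to
# pool_k x pool_k before concatenation.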
45 | 46 | # Build pooling layer results in 1x1 output. 47 | pool1 = tf.nn.avg_pool(x, 48 | name='block5/pool1', 49 | ksize=[1,pool_k,pool_k,1], 50 | strides=[1,pool_k,pool_k,1], 51 | padding='VALID') 52 | pool1 = nn.conv(pool1, 53 | 'block5/pool1/conv1', 54 | 512, 55 | 1, 56 | 1, 57 | padding='SAME', 58 | biased=False, 59 | bn=True, 60 | relu=True, 61 | is_training=is_training, 62 | decay=0.99, 63 | use_global_status=use_global_status) 64 | pool1 = tf.image.resize_bilinear(pool1, [pool_k, pool_k]) 65 | 66 | # Build pooling layer results in 2x2 output. 67 | pool2 = tf.nn.avg_pool(x, 68 | name='block5/pool2', 69 | ksize=[1,pool_k//2,pool_k//2,1], 70 | strides=[1,pool_k//2,pool_k//2,1], 71 | padding='VALID') 72 | pool2 = nn.conv(pool2, 73 | 'block5/pool2/conv1', 74 | 512, 75 | 1, 76 | 1, 77 | padding='SAME', 78 | biased=False, 79 | bn=True, 80 | relu=True, 81 | is_training=is_training, 82 | decay=0.99, 83 | use_global_status=use_global_status) 84 | pool2 = tf.image.resize_bilinear(pool2, [pool_k, pool_k]) 85 | 86 | # Build pooling layer results in 3x3 output. 87 | pool3 = tf.nn.avg_pool(x, 88 | name='block5/pool3', 89 | ksize=[1,pool_k//3,pool_k//3,1], 90 | strides=[1,pool_k//3,pool_k//3,1], 91 | padding='VALID') 92 | pool3 = nn.conv(pool3, 93 | 'block5/pool3/conv1', 94 | 512, 95 | 1, 96 | 1, 97 | padding='SAME', 98 | biased=False, 99 | bn=True, 100 | relu=True, 101 | is_training=is_training, 102 | decay=0.99, 103 | use_global_status=use_global_status) 104 | pool3 = tf.image.resize_bilinear(pool3, [pool_k, pool_k]) 105 | 106 | # Build pooling layer results in 6x6 output. 107 | pool6 = tf.nn.avg_pool(x, 108 | name='block5/pool6', 109 | ksize=[1,pool_k//6,pool_k//6,1], 110 | strides=[1,pool_k//6,pool_k//6,1], 111 | padding='VALID') 112 | pool6 = nn.conv(pool6, 113 | 'block5/pool6/conv1', 114 | 512, 115 | 1, 116 | 1, 117 | padding='SAME', 118 | biased=False, 119 | bn=True, 120 | relu=True, 121 | is_training=is_training, 122 | decay=0.99, 123 | use_global_status=use_global_status) 124 | pool6 = tf.image.resize_bilinear(pool6, [pool_k, pool_k]) 125 | 126 | # Fuse the pooled feature maps with its input, and generate 127 | # segmentation prediction. 128 | x = tf.concat([pool1, pool2, pool3, pool6, x], 129 | name='block5/concat', 130 | axis=3) 131 | x = nn.conv(x, 132 | 'block5/conv2', 133 | 512, 134 | 3, 135 | 1, 136 | padding='SAME', 137 | biased=False, 138 | bn=True, 139 | relu=True, 140 | is_training=is_training, 141 | decay=0.99, 142 | use_global_status=use_global_status) 143 | x = nn.conv(x, 144 | 'block5/fc1_voc12', 145 | num_classes, 146 | 1, 147 | 1, 148 | padding='SAME', 149 | biased=True, 150 | bn=False, 151 | relu=False, 152 | is_training=is_training) 153 | 154 | return x 155 | 156 | 157 | def pspnet_resnet101(x, 158 | num_classes, 159 | is_training, 160 | use_global_status, 161 | reuse=False): 162 | """Helper function to build PSPNet model for semantic segmentation. 163 | 164 | The PSPNet model is composed of one base network (ResNet101) and 165 | one pyramid spatial pooling (PSP) module, followed with concatenation 166 | and two more convlutional layers for segmentation prediction. 167 | 168 | Args: 169 | x: A tensor of size [batch_size, height_in, width_in, channels]. 170 | num_classes: Number of predicted classes for classification tasks. 171 | is_training: If the tensorflow variables defined in this network 172 | would be used for training. 173 | use_global_status: enable/disable use_global_status for batch 174 | normalization. 
If True, moving mean and moving variance are updated 175 | by exponential decay. 176 | reuse: enable/disable reuse for reusing tensorflow variables. It is 177 | useful for sharing weight parameters across two identical networks. 178 | 179 | Returns: 180 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 181 | """ 182 | 183 | scores = [] 184 | with tf.name_scope('scale_0') as scope: 185 | score = _pspnet_builder( 186 | x, 187 | 'resnet_v1_101', 188 | resnet_v1_101, 189 | num_classes, 190 | is_training, 191 | use_global_status, 192 | reuse=reuse) 193 | 194 | scores.append(score) 195 | 196 | return scores 197 | -------------------------------------------------------------------------------- /seg_models/models/pspnet_mgpu.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from network.multigpu.resnet_v1 import resnet_v1_101 4 | import network.multigpu.layers as nn_mgpu 5 | from network.multigpu.utils import on_each_gpu 6 | 7 | 8 | @on_each_gpu 9 | def avg_pools(x, 10 | name, 11 | kernel_size, 12 | strides, 13 | padding): 14 | k = kernel_size 15 | s = strides 16 | return tf.nn.avg_pool(x, 17 | name=name, 18 | ksize=[1,k,k,1], 19 | strides=[1,s,s,1], 20 | padding=padding) 21 | 22 | 23 | @on_each_gpu 24 | def upsample_bilinears(x, 25 | new_h, 26 | new_w): 27 | return tf.image.resize_bilinear(x, [new_h, new_w]) 28 | 29 | 30 | def _pspnet_builder(xs, 31 | name, 32 | cnn_fn, 33 | num_classes, 34 | is_training, 35 | use_global_status, 36 | reuse=False): 37 | """Helper function to build PSPNet model for semantic segmentation. 38 | 39 | The PSPNet model is composed of one base network (ResNet101) and 40 | one pyramid spatial pooling (PSP) module, followed with concatenation 41 | and two more convlutional layers for segmentation prediction. 42 | 43 | Args: 44 | x: A tensor of size [batch_size, height_in, width_in, channels]. 45 | name: The prefix of tensorflow variables defined in this network. 46 | cnn_fn: A function which builds the base network (ResNet101). 47 | num_classes: Number of predicted classes for classification tasks. 48 | is_training: If the tensorflow variables defined in this network 49 | would be used for training. 50 | use_global_status: enable/disable use_global_status for batch 51 | normalization. If True, moving mean and moving variance are updated 52 | by exponential decay. 53 | reuse: enable/disable reuse for reusing tensorflow variables. It is 54 | useful for sharing weight parameters across two identical networks. 55 | 56 | Returns: 57 | A tensor of size [batch_size, height_in/8, width_in/8, num_classes]. 58 | """ 59 | # Ensure that the size of input data is valid (should be multiple of 6x8=48). 60 | h, w = xs[0].get_shape().as_list()[1:3] # NxHxWxC 61 | assert(h%48 == 0 and w%48 == 0 and h == w) 62 | 63 | # Build the base network. 64 | xs = cnn_fn(xs, name, is_training, use_global_status, reuse) 65 | 66 | with tf.variable_scope(name, reuse=reuse) as scope: 67 | # Build the PSP module 68 | pool_k = int(h/8) # the base network is stride 8 by default. 69 | 70 | # Build pooling layer results in 1x1 output. 
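# Note: unlike the single-GPU version above, xs is a list holding one tensor
# per GPU; the @on_each_gpu helpers and the nn_mgpu ops below apply the same
# layer (presumably with shared weights via variable reuse) to each element.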
71 | pool1s = avg_pools(xs, 72 | 'block5/pool1', 73 | pool_k, 74 | pool_k, 75 | 'VALID') 76 | pool1s = nn_mgpu.conv(pool1s, 77 | 'block5/pool1/conv1', 78 | 512, 79 | 1, 80 | 1, 81 | padding='SAME', 82 | biased=False, 83 | bn=True, 84 | relu=True, 85 | is_training=is_training, 86 | decay=0.99, 87 | use_global_status=use_global_status) 88 | pool1s = upsample_bilinears(pool1s, pool_k, pool_k) 89 | 90 | # Build pooling layer results in 2x2 output. 91 | pool2s = avg_pools(xs, 92 | 'block5/pool2', 93 | pool_k//2, 94 | pool_k//2, 95 | 'VALID') 96 | pool2s = nn_mgpu.conv(pool2s, 97 | 'block5/pool2/conv1', 98 | 512, 99 | 1, 100 | 1, 101 | padding='SAME', 102 | biased=False, 103 | bn=True, 104 | relu=True, 105 | is_training=is_training, 106 | decay=0.99, 107 | use_global_status=use_global_status) 108 | pool2s = upsample_bilinears(pool2s, pool_k, pool_k) 109 | 110 | # Build pooling layer results in 3x3 output. 111 | pool3s = avg_pools(xs, 112 | 'block5/pool3', 113 | pool_k//3, 114 | pool_k//3, 115 | 'VALID') 116 | pool3s = nn_mgpu.conv(pool3s, 117 | 'block5/pool3/conv1', 118 | 512, 119 | 1, 120 | 1, 121 | padding='SAME', 122 | biased=False, 123 | bn=True, 124 | relu=True, 125 | is_training=is_training, 126 | decay=0.99, 127 | use_global_status=use_global_status) 128 | pool3s = upsample_bilinears(pool3s, pool_k, pool_k) 129 | 130 | # Build pooling layer results in 6x6 output. 131 | pool6s = avg_pools(xs, 132 | 'block5/pool6', 133 | pool_k//6, 134 | pool_k//6, 135 | 'VALID') 136 | pool6s = nn_mgpu.conv(pool6s, 137 | 'block5/pool6/conv1', 138 | 512, 139 | 1, 140 | 1, 141 | padding='SAME', 142 | biased=False, 143 | bn=True, 144 | relu=True, 145 | is_training=is_training, 146 | decay=0.99, 147 | use_global_status=use_global_status) 148 | pool6s = upsample_bilinears(pool6s, pool_k, pool_k) 149 | 150 | # Fuse the pooled feature maps with its input, and generate 151 | # segmentation prediction. 152 | xs = nn_mgpu.concat( 153 | [pool1s, pool2s, pool3s, pool6s, xs], 154 | name='block5/concat', 155 | axis=3) 156 | xs = nn_mgpu.conv(xs, 157 | 'block5/conv2', 158 | 512, 159 | 3, 160 | 1, 161 | padding='SAME', 162 | biased=False, 163 | bn=True, 164 | relu=True, 165 | is_training=is_training, 166 | decay=0.99, 167 | use_global_status=use_global_status) 168 | xs = nn_mgpu.conv(xs, 169 | 'block5/fc1_voc12', 170 | num_classes, 171 | 1, 172 | 1, 173 | padding='SAME', 174 | biased=True, 175 | bn=False, 176 | relu=False, 177 | is_training=is_training) 178 | 179 | return xs 180 | 181 | 182 | def pspnet_resnet101(xs, 183 | num_classes, 184 | is_training, 185 | use_global_status, 186 | reuse=False): 187 | """Helper function to build PSPNet model for semantic segmentation. 188 | 189 | The PSPNet model is composed of one base network (ResNet101) and 190 | one pyramid spatial pooling (PSP) module, followed with concatenation 191 | and two more convlutional layers for segmentation prediction. 192 | 193 | Args: 194 | x: A tensor of size [batch_size, height_in, width_in, channels]. 195 | num_classes: Number of predicted classes for classification tasks. 196 | is_training: If the tensorflow variables defined in this network 197 | would be used for training. 198 | use_global_status: enable/disable use_global_status for batch 199 | normalization. If True, moving mean and moving variance are updated 200 | by exponential decay. 201 | reuse: enable/disable reuse for reusing tensorflow variables. It is 202 | useful for sharing weight parameters across two identical networks. 
203 | 
204 |   Returns:
205 |     A tensor of size [batch_size, height_in/8, width_in/8, num_classes].
206 |   """
207 | 
208 |   num_gpu = len(xs)
209 |   scores = []
210 |   with tf.name_scope('scale_0') as scope:
211 |     score = _pspnet_builder(
212 |       xs,
213 |       'resnet_v1_101',
214 |       resnet_v1_101,
215 |       num_classes,
216 |       is_training,
217 |       use_global_status,
218 |       reuse=reuse)
219 |     for s in score:
220 |       scores.append([s])
221 | 
222 |   return scores
223 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jyhjinghwang/SegSort/4e3278b2a7732d62f3784eb629a323acc0ca68f3/utils/__init__.py
--------------------------------------------------------------------------------
/utils/general.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from PIL import Image
4 | import numpy as np
5 | import scipy.io
6 | import tensorflow as tf
7 | 
8 | LABEL_COLORS = scipy.io.loadmat('misc/colormapvoc.mat')['colormapvoc']
9 | LABEL_COLORS *= 255
10 | LABEL_COLORS = LABEL_COLORS.astype(np.uint8)
11 | 
12 | 
13 | def decode_labels(labels, num_classes=21):
14 |   """Encodes label indices to color maps.
15 | 
16 |   Args:
17 |     labels: A tensor of size [batch_size, height_in, width_in, 1].
18 |     num_classes: A number indicating the number of valid classes.
19 | 
20 |   Returns:
21 |     A tensor of size [batch_size, height_in, width_in, 3].
22 |   """
23 |   n, h, w, c = labels.shape
24 |   outputs = np.zeros((n, h, w, 3), dtype=np.uint8)
25 |   for i in range(n):
26 |     outputs[i] = LABEL_COLORS[labels[i,:,:,0]]
27 | 
28 |   return outputs
29 | 
30 | 
31 | def inv_preprocess(imgs, img_mean):
32 |   """Inverts the image preprocessing of the input images.
33 | 
34 |   This function adds back the mean vector and converts BGR to RGB.
35 | 
36 |   Args:
37 |     imgs: A tensor of size [batch_size, height_in, width_in, 3].
38 |     img_mean: A 1-D tensor indicating the vector of mean colour values.
39 | 
40 |   Returns:
41 |     A tensor of size [batch_size, height_in, width_in, 3].
42 |   """
43 |   n, h, w, c = imgs.shape
44 |   outputs = np.zeros((n, h, w, c), dtype=np.uint8)
45 |   for i in range(n):
46 |     outputs[i] = (imgs[i] + img_mean).astype(np.uint8)
47 | 
48 |   return outputs
49 | 
50 | 
51 | def snapshot_arg(args):
52 |   """Prints and snapshots command-line arguments to a text file.
53 |   """
54 |   snap_dir = args.snapshot_dir
55 |   dictargs = vars(args)
56 |   if not os.path.isdir(snap_dir):
57 |     os.makedirs(snap_dir)
58 |   print('-----------------------------------------------')
59 |   print('-----------------------------------------------')
60 |   with open(os.path.join(snap_dir, 'config'), 'w') as argsfile:
61 |     for key, val in dictargs.items():
62 |       line = '| {0} = {1}'.format(key, val)
63 |       print(line)
64 |       argsfile.write(line+'\n')
65 |   print('-----------------------------------------------')
66 |   print('-----------------------------------------------')
67 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | def iou_stats(pred, target, num_classes=21, background=0):
5 |   """Computes statistics of true positives (TP), false negatives (FN) and
6 |   false positives (FP).
7 | 
8 |   Args:
9 |     pred: A numpy array.
10 |     target: A numpy array which should be of the same size as pred.
11 |     num_classes: A number indicating the number of valid classes.
12 |     background: A number indicating the class index of the background (unused).
13 | 
14 |   Returns:
15 |     Three num_classes-dimensional vectors indicating the statistics of
16 |     (TP+FN), (TP+FP) and TP across each class.
17 |   """
18 |   # Only consider valid labels in [0, num_classes).
19 |   locs = np.logical_and(target > -1, target < num_classes)
20 | 
21 |   # true positive + false negative
22 |   tp_fn, _ = np.histogram(target[locs],
23 |                           bins=np.arange(num_classes+1))
24 |   # true positive + false positive
25 |   tp_fp, _ = np.histogram(pred[locs],
26 |                           bins=np.arange(num_classes+1))
27 |   # true positive
28 |   tp_locs = np.logical_and(locs, pred == target)
29 |   tp, _ = np.histogram(target[tp_locs],
30 |                        bins=np.arange(num_classes+1))
31 | 
32 |   return tp_fn, tp_fp, tp
33 | 
34 | 
35 | def confusion_matrix(pred, target, num_classes=21):
36 |   """Computes the confusion matrix between prediction and ground-truth.
37 | 
38 |   Args:
39 |     pred: A numpy array.
40 |     target: A numpy array which should be of the same size as pred.
41 |     num_classes: A number indicating the number of valid classes.
42 | 
43 |   Returns:
44 |     A (num_classes)x(num_classes) 2-D array, in which each row denotes
45 |     the ground-truth class, and each column represents a predicted class.
46 |   """
47 |   mat = np.zeros((num_classes, num_classes))
48 |   for c in range(num_classes):
49 |     mask = target == c
50 |     if mask.any():
51 |       vec, _ = np.histogram(pred[mask],
52 |                             bins=np.arange(num_classes+1))
53 |       mat[c, :] += vec
54 | 
55 |   return mat
56 | 
57 | 
58 | def accuracy(pred, target):
59 |   """Computes pixel accuracy.
60 | 
61 |   acc = correctly_predicted_pixels / total_pixels
62 | 
63 |   Args:
64 |     pred: A numpy array.
65 |     target: A numpy array which should be of the same size as pred.
66 | 
67 |   Returns:
68 |     A number indicating the average accuracy.
69 |   """
70 |   N = pred.size  # total number of elements; works for arrays of any shape
71 |   return (pred == target).sum() * 1.0 / N
72 | 
--------------------------------------------------------------------------------
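For reference, a minimal usage sketch of utils/metrics.py (not part of the repository): it assumes a hypothetical label_pairs iterable of flattened (pred, target) numpy label arrays and reduces the accumulated statistics to mean IoU.

import numpy as np

from utils.metrics import iou_stats

# Hypothetical placeholder: fill with flattened (pred, target) label arrays.
label_pairs = []

num_classes = 21
tp_fn = np.zeros(num_classes)
tp_fp = np.zeros(num_classes)
tp = np.zeros(num_classes)
for pred, target in label_pairs:
  stat_fn, stat_fp, stat_tp = iou_stats(pred, target, num_classes=num_classes)
  tp_fn += stat_fn
  tp_fp += stat_fp
  tp += stat_tp

# IoU = TP / (TP + FN + FP); the union equals (TP+FN) + (TP+FP) - TP.
iou = tp / np.maximum(tp_fn + tp_fp - tp, 1e-12)
print('mIoU = {:.2f}%'.format(iou.mean() * 100.0))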