├── .gitignore ├── LICENSE ├── README.md ├── Visual.jpg ├── awesome-SegDA └── README.md ├── compute_iou.py ├── convert_robot_color.py ├── dataset ├── __init__.py ├── autoaugment.py ├── cityscapes_dataset.py ├── cityscapes_list │ ├── .DS_Store │ ├── info.json │ ├── label.txt │ ├── train.txt │ └── val.txt ├── cityscapes_pseudo_dataset.py ├── cityscapes_train_dataset.py ├── gta5_dataset.py ├── gta5_list │ └── train.txt ├── robot_dataset.py ├── robot_list │ ├── info.json │ ├── label.txt │ ├── train.txt │ └── val.txt ├── robot_pseudo_dataset.py ├── synthia_dataset.py └── synthia_list │ └── train.txt ├── evaluate_cityscapes.py ├── evaluate_gta5.py ├── evaluate_robot.py ├── generate_plabel_cityscapes.py ├── generate_plabel_cityscapes_SYNTHIA.py ├── generate_plabel_robot.py ├── model ├── __init__.py ├── deeplab.py ├── deeplab_multi.py ├── deeplab_vgg.py ├── discriminator.py └── ms_discriminator.py ├── pdf ├── .gitkeep ├── Zheng-Yang2021_Article_RectifyingPseudoLabelLearningV.pdf └── ijcai20.pdf ├── sitemap.xml ├── test.py ├── train_ft.py ├── train_ft_robot.py ├── train_ft_synthia.py ├── train_ms.py ├── train_ms_robot.py ├── train_ms_synthia.py ├── trainer_ms.py ├── trainer_ms_variance.py ├── try_run.py ├── utils ├── __init__.py ├── autoaugment.py ├── clear_model.py ├── loss.py └── tool.py └── visualize_noisy_label.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.pth 3 | *.png 4 | *.jpg 5 | *.yaml 6 | 7 | log/ 8 | data/ 9 | result/ 10 | snapshots/ 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Zhedong Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Seg_Uncertainty 2 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 3 | [![License: MIT](https://img.shields.io/badge/License-MIT-green.svg)](https://opensource.org/licenses/MIT) 4 | 5 | ![](https://github.com/layumi/Seg_Uncertainty/blob/master/Visual.jpg) 6 | 7 | [Zhedong Zheng](https://zdzheng.xyz), [Yi Yang](https://reler.net) 8 | 9 | In this repo, we provide the code for two papers: 10 | 11 | - MRNet: [Unsupervised Scene Adaptation with Memory Regularization in vivo](https://arxiv.org/pdf/1912.11164.pdf), IJCAI (2020) 12 | 13 | - MRNet+Rectifying: [Rectifying Pseudo Label Learning via Uncertainty Estimation for Domain Adaptive Semantic Segmentation](https://arxiv.org/pdf/2003.03773.pdf), IJCV (2021) [[Intro in Chinese]](https://zhuanlan.zhihu.com/p/130220572) [[Poster]](https://zdzheng.xyz/files/valse_ijcv_poster.pdf) 14 | 15 | - [[Video intro in Chinese]](https://www.bilibili.com/video/BV14p4y1s77p) 16 | 17 | ## Initial Model 18 | The original DeepLab link hosted at UC Merced is no longer available. Please use one of the following links instead. 19 | 20 | [Google Drive] https://drive.google.com/file/d/1BMTTMCNkV98pjZh_rU0Pp47zeVqF3MEc/view?usp=share_link 21 | 22 | [One Drive] https://1drv.ms/u/s!Avx-MJllNj5b3SqR7yurCxTgIUOK?e=A1dq3m 23 | 24 | or use 25 | ``` 26 | pip install gdown 27 | pip install --upgrade gdown 28 | gdown 1BMTTMCNkV98pjZh_rU0Pp47zeVqF3MEc 29 | ``` 30 | 31 | 32 | ## Table of contents 33 | * [Common Q&A](#common-qa) 34 | * [The Core Code](#the-core-code) 35 | * [Prerequisites](#prerequisites) 36 | * [Prepare Data](#prepare-data) 37 | * [Training](#training) 38 | * [Testing](#testing) 39 | * [Trained Model](#trained-model) 40 | * [Related Works](#related-works) 41 | * [Citation](#citation) 42 | 43 | ### News 44 | - [19 Jan 2024] We further apply the uncertainty to compositional image retrieval. The paper is accepted by ICLR'24 [[code]](https://github.com/Monoxide-Chen/uncertainty_retrieval). 45 | - [27 Jan 2023] You are welcome to check our new transformer-based work [PiPa](https://github.com/chen742/PiPa), which achieves **75.6** mIoU on GTA5->Cityscapes. 46 | - [5 Sep 2021] Zheng et al. apply the uncertainty to domain adaptive re-ID and also achieve good performance: "Exploiting Sample Uncertainty for Domain Adaptive Person Re-Identification", Kecheng Zheng, Cuiling Lan, Wenjun Zeng, Zhizheng Zhang, and Zheng-Jun Zha. AAAI 2021 47 | 48 | - [13 Aug 2021] We release a new method based on Adaptive Boosting (AdaBoost) for domain adaptation. You may check the project at https://github.com/layumi/AdaBoost_Seg 49 | 50 | ### Common Q&A 51 | 1. Why is the KL divergence always non-negative (>=0)? 52 | 53 | Please check the Wikipedia article at https://en.wikipedia.org/wiki/Kullback–Leibler_divergence#Properties, which provides a good derivation. 54 | 55 | 2. Why are both log_sm and sm used? 56 | 57 | `nn.KLDivLoss` expects log-probabilities as the input and plain probabilities as the target; see the PyTorch doc at https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html?highlight=nn%20kldivloss#torch.nn.KLDivLoss. 58 | I also follow the discussion at https://discuss.pytorch.org/t/kl-divergence-loss/65393 59 | 60 | ### The Core Code 61 | The core code is relatively simple and can be directly applied to other works; a minimal sketch of the KL consistency term is given below, followed by links to the exact implementations. 
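As a rough, self-contained illustration of the `log_sm`/`sm` pairing discussed in the Q&A above: the sketch below assumes hypothetical classifier outputs `pred1` and `pred2`; the variable names and tensor sizes are illustrative only and are not the exact ones used in `trainer_ms.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical logits from two classifier heads on the same target image;
# shape (batch, num_classes, H, W). Names and sizes are for illustration only.
pred1 = torch.randn(2, 19, 64, 64)
pred2 = torch.randn(2, 19, 64, 64)

kl = nn.KLDivLoss(reduction='batchmean')

# KLDivLoss expects log-probabilities as the input and probabilities as the
# target, hence log_softmax on one branch and softmax on the other.
loss_kl = kl(F.log_softmax(pred1, dim=1), F.softmax(pred2, dim=1)) \
        + kl(F.log_softmax(pred2, dim=1), F.softmax(pred1, dim=1))

# KL divergence is non-negative, so this consistency loss is bounded below by 0.
print(loss_kl.item())
```

The exact implementations, linked below, may differ in weighting and reduction details; the snippet only illustrates why both `log_softmax` and `softmax` appear.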
62 | - Memory in vivo: https://github.com/layumi/Seg_Uncertainty/blob/master/trainer_ms.py#L232 63 | 64 | - Rectifying Pseudo label: https://github.com/layumi/Seg_Uncertainty/blob/master/trainer_ms_variance.py#L166 65 | 66 | ### Prerequisites 67 | - Python 3.6 68 | - GPU Memory >= 11G (e.g., GTX2080Ti or GTX1080Ti) 69 | - PyTorch or [PaddlePaddle](https://www.paddlepaddle.org.cn/) 70 | 71 | 72 | ### Prepare Data 73 | Download [GTA5] and [Cityscapes] to run the basic code. 74 | Alternatively, you could download two extra datasets, [SYNTHIA] and [OxfordRobotCar]. 75 | 76 | - Download [The GTA5 Dataset]( https://download.visinf.tu-darmstadt.de/data/from_games/ ) 77 | 78 | - Download [The SYNTHIA Dataset]( http://synthia-dataset.net/download/808/) SYNTHIA-RAND-CITYSCAPES (CVPR16) 79 | 80 | - Download [The Cityscapes Dataset]( https://www.cityscapes-dataset.com/ ) 81 | 82 | - Download [The Oxford RobotCar Dataset]( http://www.nec-labs.com/~mas/adapt-seg/adapt-seg.html ) 83 | 84 | The data folder is structured as follows: 85 | ``` 86 | ├── data/ 87 | │ ├── Cityscapes/ 88 | | | ├── data/ 89 | | | ├── gtFine/ 90 | | | ├── leftImg8bit/ 91 | │ ├── GTA5/ 92 | | | ├── images/ 93 | | | ├── labels/ 94 | | | ├── ... 95 | │ ├── synthia/ 96 | | | ├── RGB/ 97 | | | ├── GT/ 98 | | | ├── Depth/ 99 | | | ├── ... 100 | │ └── Oxford_Robot_ICCV19 101 | | | ├── train/ 102 | | | ├── ... 103 | ``` 104 | 105 | ### Training 106 | Stage-I: 107 | ```bash 108 | python train_ms.py --snapshot-dir ./snapshots/SE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5 --drop 0.1 --warm-up 5000 --batch-size 2 --learning-rate 2e-4 --crop-size 1024,512 --lambda-seg 0.5 --lambda-adv-target1 0.0002 --lambda-adv-target2 0.001 --lambda-me-target 0 --lambda-kl-target 0.1 --norm-style gn --class-balance --only-hard-label 80 --max-value 7 --gpu-ids 0,1 --often-balance --use-se 109 | ``` 110 | 111 | Generate Pseudo Labels: 112 | ```bash 113 | python generate_plabel_cityscapes.py --restore-from ./snapshots/SE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5/GTA5_25000.pth 114 | ``` 115 | 116 | Stage-II (with rectified pseudo labels): 117 | ```bash 118 | python train_ft.py --snapshot-dir ./snapshots/1280x640_restore_ft_GN_batchsize9_512x256_pp_ms_me0_classbalance7_kl0_lr1_drop0.2_seg0.5_BN_80_255_0.8_Noaug --restore-from ./snapshots/SE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5/GTA5_25000.pth --drop 0.2 --warm-up 5000 --batch-size 9 --learning-rate 1e-4 --crop-size 512,256 --lambda-seg 0.5 --lambda-adv-target1 0 --lambda-adv-target2 0 --lambda-me-target 0 --lambda-kl-target 0 --norm-style gn --class-balance --only-hard-label 80 --max-value 7 --gpu-ids 0,1,2 --often-balance --use-se --input-size 1280,640 --train_bn --autoaug False 119 | ``` 120 | *** If you want to run the code without rectifying the pseudo labels, please change [[this line]](https://github.com/layumi/Seg_Uncertainty/blob/master/train_ft.py#L20) to 'from trainer_ms import AD_Trainer', which applies the conventional pseudo label learning. 
*** 121 | 122 | ### Testing 123 | ```bash 124 | python evaluate_cityscapes.py --restore-from ./snapshots/1280x640_restore_ft_GN_batchsize9_512x256_pp_ms_me0_classbalance7_kl0_lr1_drop0.2_seg0.5_BN_80_255_0.8_Noaug/GTA5_25000.pth 125 | ``` 126 | 127 | ### Trained Model 128 | The trained model is available at https://drive.google.com/file/d/1smh1sbOutJwhrfK8dk-tNvonc0HLaSsw/view?usp=sharing 129 | 130 | - The folder with `SY` in its name is for SYNTHIA-to-Cityscapes 131 | - The folder with `RB` in its name is for Cityscapes-to-Oxford RobotCar 132 | 133 | ### One Note for SYNTHIA-to-Cityscapes 134 | Note that the provided evaluation code for SYNTHIA-to-Cityscapes still averages the IoU over 19 classes. 135 | You need to re-calculate the mean over the 16 classes shared between SYNTHIA and Cityscapes, i.e., sum the per-class IoU of those 16 classes and divide by 16 rather than 19. 136 | The result then matches the value reported in the paper. 137 | 138 | ### Related Works 139 | We would also like to thank the following great works: 140 | - https://github.com/wasidennis/AdaptSegNet 141 | - https://github.com/RoyalVane/CLAN 142 | - https://github.com/yzou2/CRST 143 | 144 | ### Citation 145 | ```bibtex 146 | @inproceedings{zheng2020unsupervised, 147 | title={Unsupervised Scene Adaptation with Memory Regularization in vivo}, 148 | author={Zheng, Zhedong and Yang, Yi}, 149 | booktitle={IJCAI}, 150 | year={2020} 151 | } 152 | @article{zheng2021rectifying, 153 | title={Rectifying Pseudo Label Learning via Uncertainty Estimation for Domain Adaptive Semantic Segmentation}, 154 | author={Zheng, Zhedong and Yang, Yi}, 155 | journal={International Journal of Computer Vision (IJCV)}, 156 | doi={10.1007/s11263-020-01395-y}, 157 | note={\mbox{doi}:\url{10.1007/s11263-020-01395-y}}, 158 | year={2021} 159 | } 160 | ``` 161 | -------------------------------------------------------------------------------- /Visual.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/Visual.jpg -------------------------------------------------------------------------------- /awesome-SegDA/README.md: -------------------------------------------------------------------------------- 1 | ## Awesome Segmentation Domain Adaptation [![Awesome](https://awesome.re/badge.svg)](https://awesome.re) 2 | 3 | If you notice any result or public code that has not been included in this page, please do not hesitate to contact [Zhedong Zheng](mailto:zdzheng12@gmail.com) to add the method. You are welcome to do so, 4 | or to create a pull request. 5 | 6 | Priority is given to papers whose code is published. 
7 | 8 | **Arxiv** 9 | - Cross-Region Domain Adaptation for Class-level Alignment [[14 Sep 2021]](https://arxiv.org/pdf/2109.06422.pdf) 10 | - Contrastive Learning and Self-Training for Unsupervised Domain Adaptation in Semantic Segmentation [[5 May 2021]](https://arxiv.org/abs/2105.02001) 11 | - ACDC: The Adverse Conditions Dataset with Correspondences for Semantic Driving Scene Understanding [[29 April 2021]](https://arxiv.org/abs/2104.13395) 12 | - Domain Adaptive Semantic Segmentation with Self-Supervised Depth Estimation [[28 Apr 2021]](https://arxiv.org/abs/2104.13613) 13 | - Class-Conditional Domain Adaptation on Semantic Segmentation [[27 Nov 2019]](https://arxiv.org/abs/1911.11981v1) 14 | - Domain Bridge for Unpaired Image-to-Image Translation and Unsupervised Domain Adaptation [[23 Oct 2019]](https://arxiv.org/abs/1910.10563) 15 | - Restyling Data: Application to Unsupervised Domain Adaptation [[24 Sep 2019]](https://arxiv.org/abs/1909.10900) 16 | - Adversarial Learning and Self-Teaching Techniques for Domain Adaptation in Semantic Segmentation [[2 Sep 2019]](https://arxiv.org/abs/1909.00781v1) 17 | - FCNs in the Wild: Pixel-level Adversarial and Constraint-based Adaptation [[8 Dec 2016]](https://arxiv.org/abs/1612.02649) 18 | 19 | **Journal** 20 | - Adaptive Boosting for Domain Adaptation: Towards Robust Predictions in Scene Segmentation [[TIP 2022]](https://arxiv.org/abs/2103.15685)[[code]](https://github.com/layumi/AdaBoost_Seg) 21 | - Rectifying Pseudo Label Learning via Uncertainty Estimation for Domain Adaptive Semantic Segmentation 22 | [[IJCV 2021]](https://arxiv.org/abs/2003.03773) 23 | - Affinity Space Adaptation for Semantic Segmentation Across Domains [[TIP2020]](https://arxiv.org/pdf/2009.12559.pdf) 24 | - Semantic-aware short path adversarial training for cross-domain semantic segmentation [[Neurocomputing 2019]](https://www.sciencedirect.com/science/article/pii/S0925231219315656#fig0002) 25 | - Weakly Supervised Adversarial Domain Adaptation for Semantic Segmentation in Urban Scenes [[TIP 2019]](https://arxiv.org/abs/1904.09092v1) 26 | 27 | **Conference** 28 | - PiPa: Pixel- and Patch-wise Self-supervised Learning for Domain Adaptative Semantic Segmentation [[ACM MM2023]](https://arxiv.org/abs/2211.07609) [[Code]](https://github.com/chen742/PiPa) 29 | - DAformer: Improving Network Architectures and Training Strategies for Domain-Adaptive Semantic Segmentation [[CVPR2022]] 30 | - Class-balanced pixel-level self-labeling for domain adaptive semantic segmentation [[CVPR2022]] 31 | - Undoing the damage of label shift for cross-domain semantic segmentation [[CVPR2022]] 32 | - Adas: A direct adaptation strategy for multi-target domain adaptive semantic segmentation [[CVPR2022]] 33 | - Generalize Then Adapt: Source-Free Domain Adaptive Semantic Segmentation [[ICCV2021]](https://openaccess.thecvf.com/content/ICCV2021/papers/Kundu_Generalize_Then_Adapt_Source-Free_Domain_Adaptive_Semantic_Segmentation_ICCV_2021_paper.pdf) 34 | - Prototypical Pseudo Label Denoising and Target Structure Learning for Domain Adaptive Semantic Segmentation [[CVPR2021]](https://arxiv.org/abs/2101.10979) 35 | - Improving Domain Generalization in Urban-Scene Segmentationvia Instance Selective Whitening [[CVPR2021]](https://arxiv.org/abs/2103.15597) 36 | - Rethinking Ensemble-Distillation for Semantic Segmentation Based Unsupervised Domain Adaptation [[CVPRW 2021]](https://arxiv.org/abs/2104.14203) 37 | - Instance Adaptive Self-Training for Unsupervised Domain Adaptation [[ECCV 
2020]](https://arxiv.org/abs/2008.12197) 38 | - Two-phase Pseudo Label Densification for Self-training based Domain Adaptation [[ECCV 2020]](https://arxiv.org/abs/2012.04828) 39 | - Unsupervised Scene Adaptation with Memory Regularization in vivo [[IJCAI 2020]](https://arxiv.org/abs/1912.11164) [[code]](https://github.com/layumi/Seg-Uncertainty) 40 | - Adversarial Style Mining for One-Shot Unsupervised Domain Adaptation [[NeurIPS 2020]](https://proceedings.neurips.cc/paper/2020/hash/ed265bc903a5a097f61d3ec064d96d2e-Abstract.html)[[Pytorch]](https://github.com/RoyalVane/ASM) 41 | - Content-Consistent Matching for Domain Adaptive Semantic Segmentation [[ECCV 2020]](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123590426.pdf) [[Code]](https://github.com/Solacex/CCM) 42 | - Differential Treatment for Stuff and Things: A Simple Unsupervised Domain Adaptation Method for Semantic Segmentation [[CVPR 2020]](https://arxiv.org/pdf/2003.08040.pdf) [[Code]](https://github.com/SHI-Labs/Unsupervised-Domain-Adaptation-with-Differential-Treatment) 43 | - Contextual-Relation Consistent Domain Adaptation for Semantic Segmentation[[ECCV 2020]](https://arxiv.org/pdf/2007.02424.pdf) 44 | - An Adversarial Perturbation Oriented Domain Adaptation Approach for Semantic Segmentation [[AAAI2020]](https://arxiv.org/abs/1912.08954v1) 45 | - Category Anchor-Guided Unsupervised Domain Adaptation for Semantic Segmentation [[NeurIPS2019]](https://arxiv.org/abs/1910.13049)) [[code]](https://github.com/RogerZhangzz/CAG_UDA) 46 | - MLSL: Multi-Level Self-Supervised Learning for Domain Adaptation with Spatially Independent and Semantically Consistent Labeling [[WACV2020]](https://arxiv.org/abs/1909.13776) 47 | - Guided Curriculum Model Adaptation and Uncertainty-Aware Evaluation for 48 | Semantic Nighttime Image Segmentation [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Sakaridis_Guided_Curriculum_Model_Adaptation_and_Uncertainty-Aware_Evaluation_for_Semantic_Nighttime_ICCV_2019_paper.pdf) 49 | - Constructing Self-motivated Pyramid Curriculums for Cross-Domain Semantic 50 | Segmentation: A Non-Adversarial Approach [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Lian_Constructing_Self-Motivated_Pyramid_Curriculums_for_Cross-Domain_Semantic_Segmentation_A_Non-Adversarial_ICCV_2019_paper.pdf) [[Pytorch]](https://github.com/lianqing11/pycda) 51 | - SSF-DAN: Separated Semantic Feature Based Domain Adaptation Network for Semantic Segmentation [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Du_SSF-DAN_Separated_Semantic_Feature_Based_Domain_Adaptation_Network_for_Semantic_ICCV_2019_paper.pdf) 52 | - Significance-aware Information Bottleneck for Domain Adaptive Semantic Segmentation [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Luo_Significance-Aware_Information_Bottleneck_for_Domain_Adaptive_Semantic_Segmentation_ICCV_2019_paper.pdf) 53 | - Domain Adaptation for Semantic Segmentation with Maximum Squares Loss [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Chen_Domain_Adaptation_for_Semantic_Segmentation_With_Maximum_Squares_Loss_ICCV_2019_paper.pdf) [[Pytorch]](https://github.com/ZJULearning/MaxSquareLoss) 54 | - Self-Ensembling with GAN-based Data Augmentation for Domain Adaptation in Semantic Segmentation [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Choi_Self-Ensembling_With_GAN-Based_Data_Augmentation_for_Domain_Adaptation_in_Semantic_ICCV_2019_paper.pdf) 55 | - DADA: Depth-aware Domain Adaptation 
in Semantic Segmentation [[ICCV2019]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Vu_DADA_Depth-Aware_Domain_Adaptation_in_Semantic_Segmentation_ICCV_2019_paper.pdf) [[code]](https://github.com/valeoai/DADA) 56 | - Domain Adaptation for Structured Output via Discriminative Patch Representations [[ICCV2019 Oral]](http://openaccess.thecvf.com/content_ICCV_2019/papers/Tsai_Domain_Adaptation_for_Structured_Output_via_Discriminative_Patch_Representations_ICCV_2019_paper.pdf) [[Project]](https://sites.google.com/site/yihsuantsai/research/iccv19-adapt-seg) 57 | - Not All Areas Are Equal: Transfer Learning for Semantic Segmentation via Hierarchical Region Selection [[CVPR2019(Oral)(PDF Coming Soon)]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Sun_Not_All_Areas_Are_Equal_Transfer_Learning_for_Semantic_Segmentation_CVPR_2019_paper.pdf) 58 | - CrDoCo: Pixel-level Domain Transfer with Cross-Domain Consistency [[CVPR2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_CrDoCo_Pixel-Level_Domain_Transfer_With_Cross-Domain_Consistency_CVPR_2019_paper.pdf) [[Project]](https://yunchunchen.github.io/CrDoCo/) 59 | - Bidirectional Learning for Domain Adaptation of Semantic Segmentation [[CVPR2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Li_Bidirectional_Learning_for_Domain_Adaptation_of_Semantic_Segmentation_CVPR_2019_paper.pdf) [[Pytorch]](https://github.com/liyunsheng13/BDL) 60 | - Learning Semantic Segmentation from Synthetic Data: A Geometrically Guided Input-Output Adaptation Approach [[CVPR2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Chen_Learning_Semantic_Segmentation_From_Synthetic_Data_A_Geometrically_Guided_Input-Output_CVPR_2019_paper.pdf) 61 | - All about Structure: Adapting Structural Information across Domains for Boosting Semantic Segmentation [[CVPR2019]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Chang_All_About_Structure_Adapting_Structural_Information_Across_Domains_for_Boosting_CVPR_2019_paper.pdf) [[Pytorch]](https://github.com/a514514772/DISE-Domain-Invariant-Structure-Extraction) 62 | - DLOW: Domain Flow for Adaptation and Generalization [[CVPR2019 Oral]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Gong_DLOW_Domain_Flow_for_Adaptation_and_Generalization_CVPR_2019_paper.pdf) 63 | - Taking A Closer Look at Domain Shift: Category-level Adversaries for Semantics Consistent Domain Adaptation [[CVPR2019 Oral]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Luo_Taking_a_Closer_Look_at_Domain_Shift_Category-Level_Adversaries_for_CVPR_2019_paper.pdf) [[Pytorch]](https://github.com/RoyalVane/CLAN) 64 | - ADVENT: Adversarial Entropy Minimization for Domain Adaptation in Semantic Segmentation [[CVPR2019 Oral]](http://openaccess.thecvf.com/content_CVPR_2019/papers/Vu_ADVENT_Adversarial_Entropy_Minimization_for_Domain_Adaptation_in_Semantic_Segmentation_CVPR_2019_paper.pdf) [[Pytorch]](https://github.com/valeoai/ADVENT) 65 | - SPIGAN: Privileged Adversarial Learning from Simulation [[ICLR2019]](https://openreview.net/forum?id=rkxoNnC5FQ) 66 | - Penalizing Top Performers: Conservative Loss for Semantic Segmentation Adaptation [[ECCV2018]](http://openaccess.thecvf.com/content_ECCV_2018/papers/Xinge_Zhu_Penalizing_Top_Performers_ECCV_2018_paper.pdf) 67 | - Domain transfer through deep activation matching [[ECCV2018]](http://openaccess.thecvf.com/content_ECCV_2018/papers/Haoshuo_Huang_Domain_transfer_through_ECCV_2018_paper.pdf) 68 | - Unsupervised Domain Adaptation for Semantic Segmentation via Class-Balanced 
Self-Training [[ECCV2018]](http://openaccess.thecvf.com/content_ECCV_2018/papers/Yang_Zou_Unsupervised_Domain_Adaptation_ECCV_2018_paper.pdf) 69 | - DCAN: Dual channel-wise alignment networks for unsupervised scene adaptation [[ECCV2018]](https://eccv2018.org/openaccess/content_ECCV_2018/papers/Zuxuan_Wu_DCAN_Dual_Channel-wise_ECCV_2018_paper.pdf) 70 | - Fully convolutional adaptation networks for semantic 71 | segmentation [[CVPR2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Zhang_Fully_Convolutional_Adaptation_CVPR_2018_paper.pdf) 72 | - Learning to Adapt Structured Output Space for Semantic Segmentation [[CVPR2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Tsai_Learning_to_Adapt_CVPR_2018_paper.pdf) [[Pytorch]](https://github.com/wasidennis/AdaptSegNet) 73 | - Conditional Generative Adversarial Network for Structured Domain Adaptation [[CVPR2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Hong_Conditional_Generative_Adversarial_CVPR_2018_paper.pdf) 74 | - Learning From Synthetic Data: Addressing Domain Shift for Semantic Segmentation [[CVPR2018]](http://openaccess.thecvf.com/content_cvpr_2018/papers/Sankaranarayanan_Learning_From_Synthetic_CVPR_2018_paper.pdf) 75 | - Curriculum Domain Adaptation for Semantic Segmentation of Urban Scenes [[ICCV2017]](http://openaccess.thecvf.com/content_ICCV_2017/papers/Zhang_Curriculum_Domain_Adaptation_ICCV_2017_paper.pdf) [[Journal Version]](https://arxiv.org/abs/1812.09953v3) 76 | - No more discrimina- tion: Cross city adaptation of road scene segmenters [[ICCV2017]](http://openaccess.thecvf.com/content_ICCV_2017/supplemental/Chen_No_More_Discrimination_ICCV_2017_supplemental.pdf) 77 | 78 | ### Reference 79 | - https://github.com/zhaoxin94/awesome-domain-adaptation#semantic-segmentation 80 | -------------------------------------------------------------------------------- /compute_iou.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | from PIL import Image 5 | from os.path import join 6 | 7 | 8 | def fast_hist(a, b, n): 9 | k = (a >= 0) & (a < n) 10 | return np.bincount(n * a[k].astype(int) + b[k], minlength=n ** 2).reshape(n, n) 11 | 12 | 13 | def per_class_iu(hist): 14 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 15 | 16 | 17 | def label_mapping(input, mapping): 18 | output = np.copy(input) 19 | for ind in range(len(mapping)): 20 | output[input == mapping[ind][0]] = mapping[ind][1] 21 | return np.array(output, dtype=np.int64) 22 | 23 | 24 | def compute_mIoU(gt_dir, pred_dir, devkit_dir=''): 25 | """ 26 | Compute IoU given the predicted colorized images and 27 | """ 28 | with open(join(devkit_dir, 'info.json'), 'r') as fp: 29 | info = json.load(fp) 30 | num_classes = np.int(info['classes']) 31 | print(('Num classes', num_classes)) 32 | name_classes = np.array(info['label'], dtype=np.str) 33 | mapping = np.array(info['label2train'], dtype=np.int) 34 | hist = np.zeros((num_classes, num_classes)) 35 | 36 | image_path_list = join(devkit_dir, 'val.txt') 37 | label_path_list = join(devkit_dir, 'label.txt') 38 | gt_imgs = open(label_path_list, 'r').read().splitlines() 39 | gt_imgs = [join(gt_dir, x) for x in gt_imgs] 40 | pred_imgs = open(image_path_list, 'r').read().splitlines() 41 | pred_imgs = [join(pred_dir, x.split('/')[-1]) for x in pred_imgs] 42 | 43 | for ind in range(len(gt_imgs)): 44 | pred = np.array(Image.open(pred_imgs[ind])) 45 | label = np.array(Image.open(gt_imgs[ind])) 46 | label = 
label_mapping(label, mapping) 47 | if len(label.shape) == 3 and label.shape[2]==4: 48 | label = label[:,:,0] 49 | if len(label.flatten()) != len(pred.flatten()): 50 | print(('Skipping: len(gt) = {:d}, len(pred) = {:d}, {:s}, {:s}'.format(len(label.flatten()), len(pred.flatten()), gt_imgs[ind], pred_imgs[ind]))) 51 | continue 52 | hist += fast_hist(label.flatten(), pred.flatten(), num_classes) 53 | if ind > 0 and ind % 10 == 0: 54 | print(('{:d} / {:d}: {:0.2f}'.format(ind, len(gt_imgs), 100*np.mean(per_class_iu(hist))))) 55 | 56 | mIoUs = per_class_iu(hist) 57 | for ind_class in range(num_classes): 58 | print(('===>' + name_classes[ind_class] + ':\t' + str(round(mIoUs[ind_class] * 100, 2)))) 59 | print(('===> mIoU: ' + str(round(np.nanmean(mIoUs) * 100, 2)))) 60 | return mIoUs 61 | 62 | 63 | def main(args): 64 | compute_mIoU(args.gt_dir, args.pred_dir, args.devkit_dir) 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('gt_dir', type=str, help='directory which stores CityScapes val gt images') 70 | parser.add_argument('pred_dir', type=str, help='directory which stores CityScapes val pred images') 71 | parser.add_argument('--devkit_dir', default='dataset/cityscapes_list', help='base directory of cityscapes') 72 | args = parser.parse_args() 73 | main(args) 74 | -------------------------------------------------------------------------------- /convert_robot_color.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import argparse 3 | import json 4 | from PIL import Image 5 | from os.path import join 6 | from evaluate_robot import colorize_mask,save 7 | from compute_iou import label_mapping 8 | 9 | def main(gt_dir='./data/Oxford_Robot_ICCV19/anno', devkit_dir = './dataset/robot_list/'): 10 | """ 11 | Compute IoU given the predicted colorized images and 12 | """ 13 | with open(join(devkit_dir, 'info.json'), 'r') as fp: 14 | info = json.load(fp) 15 | image_path_list = join(devkit_dir, 'val.txt') 16 | label_path_list = join(devkit_dir, 'label.txt') 17 | mapping = np.array(info['label2train'], dtype=np.int) 18 | gt_imgs = open(label_path_list, 'r').read().splitlines() 19 | gt_imgs = [join(gt_dir, x) for x in gt_imgs] 20 | 21 | for ind in range(len(gt_imgs)): 22 | label = np.array(Image.open(gt_imgs[ind])) 23 | label = label_mapping(label, mapping) 24 | label = label[:,:,0].astype(np.uint8) 25 | name_tmp = gt_imgs[ind].replace('anno','anno_color') 26 | save([label, name_tmp]) 27 | 28 | return 29 | 30 | 31 | 32 | if __name__ == "__main__": 33 | main() 34 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | 5 | 6 | class ImageNetPolicy(object): 7 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
8 | 9 | Example: 10 | >>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | 13 | Example as a PyTorch Transform: 14 | >>> transform=transforms.Compose([ 15 | >>> transforms.Resize(256), 16 | >>> ImageNetPolicy(), 17 | >>> transforms.ToTensor()]) 18 | """ 19 | def __init__(self, fillcolor=(128, 128, 128)): 20 | self.policies = [ 21 | #SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 22 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 23 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 24 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 25 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 26 | 27 | #SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 28 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 29 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 30 | #SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 31 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 32 | 33 | #SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 34 | #SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 35 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 37 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 38 | 39 | #SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 40 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 41 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 42 | #SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 43 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 44 | 45 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 46 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 47 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 48 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 49 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 50 | ] 51 | 52 | 53 | def __call__(self, img): 54 | policy_idx = random.randint(0, len(self.policies) - 1) 55 | return self.policies[policy_idx](img) 56 | 57 | def __repr__(self): 58 | return "AutoAugment ImageNet Policy" 59 | 60 | 61 | class CIFAR10Policy(object): 62 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
63 | 64 | Example: 65 | >>> policy = CIFAR10Policy() 66 | >>> transformed = policy(image) 67 | 68 | Example as a PyTorch Transform: 69 | >>> transform=transforms.Compose([ 70 | >>> transforms.Resize(256), 71 | >>> CIFAR10Policy(), 72 | >>> transforms.ToTensor()]) 73 | """ 74 | def __init__(self, fillcolor=(128, 128, 128)): 75 | self.policies = [ 76 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 77 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 78 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 79 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 80 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 81 | 82 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 83 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 84 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 85 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 86 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 87 | 88 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 89 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 90 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 91 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 92 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 93 | 94 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 95 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 96 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 97 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 98 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 99 | 100 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 101 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 102 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 103 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 104 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 105 | ] 106 | 107 | 108 | def __call__(self, img): 109 | policy_idx = random.randint(0, len(self.policies) - 1) 110 | return self.policies[policy_idx](img) 111 | 112 | def __repr__(self): 113 | return "AutoAugment CIFAR10 Policy" 114 | 115 | 116 | class SVHNPolicy(object): 117 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
118 | 119 | Example: 120 | >>> policy = SVHNPolicy() 121 | >>> transformed = policy(image) 122 | 123 | Example as a PyTorch Transform: 124 | >>> transform=transforms.Compose([ 125 | >>> transforms.Resize(256), 126 | >>> SVHNPolicy(), 127 | >>> transforms.ToTensor()]) 128 | """ 129 | def __init__(self, fillcolor=(128, 128, 128)): 130 | self.policies = [ 131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 146 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 147 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 148 | 149 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 150 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 151 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 152 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 153 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 154 | 155 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 156 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 157 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 158 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 159 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 160 | ] 161 | 162 | 163 | def __call__(self, img): 164 | policy_idx = random.randint(0, len(self.policies) - 1) 165 | return self.policies[policy_idx](img) 166 | 167 | def __repr__(self): 168 | return "AutoAugment SVHN Policy" 169 | 170 | 171 | class SubPolicy(object): 172 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 173 | ranges = { 174 | "shearX": np.linspace(0, 0.3, 10), 175 | "shearY": np.linspace(0, 0.3, 10), 176 | "translateX": np.linspace(0, 150 / 331, 10), 177 | "translateY": np.linspace(0, 150 / 331, 10), 178 | "rotate": np.linspace(0, 30, 10), 179 | "color": np.linspace(0.0, 0.9, 10), 180 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 181 | "solarize": np.linspace(256, 0, 10), 182 | "contrast": np.linspace(0.0, 0.9, 10), 183 | "sharpness": np.linspace(0.0, 0.9, 10), 184 | "brightness": np.linspace(0.0, 0.9, 10), 185 | "autocontrast": [0] * 10, 186 | "equalize": [0] * 10, 187 | "invert": [0] * 10 188 | } 189 | 190 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 191 | def rotate_with_fill(img, magnitude): 192 | rot = img.convert("RGBA").rotate(magnitude) 193 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 194 | 195 | func = { 196 | "shearX": lambda img, magnitude: img.transform( 197 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 198 | 
Image.BICUBIC, fillcolor=fillcolor), 199 | "shearY": lambda img, magnitude: img.transform( 200 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 201 | Image.BICUBIC, fillcolor=fillcolor), 202 | "translateX": lambda img, magnitude: img.transform( 203 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 204 | fillcolor=fillcolor), 205 | "translateY": lambda img, magnitude: img.transform( 206 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 207 | fillcolor=fillcolor), 208 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 209 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 210 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 211 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 212 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 213 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 214 | 1 + magnitude * random.choice([-1, 1])), 215 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 216 | 1 + magnitude * random.choice([-1, 1])), 217 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 218 | 1 + magnitude * random.choice([-1, 1])), 219 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 220 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 221 | "invert": lambda img, magnitude: ImageOps.invert(img) 222 | } 223 | 224 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 225 | # operation1, ranges[operation1][magnitude_idx1], 226 | # operation2, ranges[operation2][magnitude_idx2]) 227 | self.p1 = p1 228 | self.operation1 = func[operation1] 229 | self.magnitude1 = ranges[operation1][magnitude_idx1] 230 | self.p2 = p2 231 | self.operation2 = func[operation2] 232 | self.magnitude2 = ranges[operation2][magnitude_idx2] 233 | 234 | 235 | def __call__(self, img): 236 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 237 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 238 | return img 239 | -------------------------------------------------------------------------------- /dataset/cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | import collections 9 | import torch 10 | import torchvision 11 | from torch.utils import data 12 | from PIL import Image, ImageFile 13 | from dataset.autoaugment import ImageNetPolicy 14 | import time 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | 18 | class cityscapesDataSet(data.Dataset): 19 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='val', autoaug=False): 20 | self.root = root 21 | self.list_path = list_path 22 | self.crop_size = crop_size 23 | self.scale = scale 24 | self.ignore_label = ignore_label 25 | self.mean = mean 26 | self.is_mirror = mirror 27 | self.resize_size = resize_size 28 | self.autoaug = autoaug 29 | self.h = crop_size[0] 30 | self.w = crop_size[1] 31 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 32 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 33 | if 
not max_iters==None: 34 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 35 | self.files = [] 36 | self.set = set 37 | # for split in ["train", "trainval", "val"]: 38 | 39 | #https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 40 | self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 41 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 42 | 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18} 43 | 44 | for name in self.img_ids: 45 | img_file = osp.join(self.root, "leftImg8bit/%s/%s" % (self.set, name)) 46 | label_file = osp.join(self.root, "gtFine/%s/%s" % (self.set, name.replace('leftImg8bit', 'gtFine_labelIds') )) 47 | self.files.append({ 48 | "img": img_file, 49 | "label": label_file, 50 | "name": name 51 | }) 52 | 53 | def __len__(self): 54 | return len(self.files) 55 | 56 | def __getitem__(self, index): 57 | #tt = time.time() 58 | datafiles = self.files[index] 59 | name = datafiles["name"] 60 | 61 | image, label = Image.open(datafiles["img"]).convert('RGB'), Image.open(datafiles["label"]) 62 | # resize 63 | image, label = image.resize(self.resize_size, Image.BICUBIC), label.resize(self.resize_size, Image.NEAREST) 64 | if self.autoaug: 65 | policy = ImageNetPolicy() 66 | image = policy(image) 67 | 68 | image, label = np.asarray(image, np.float32), np.asarray(label, np.uint8) 69 | 70 | # re-assign labels to match the format of Cityscapes 71 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 72 | for k, v in list(self.id_to_trainid.items()): 73 | label_copy[label == k] = v 74 | 75 | size = image.shape 76 | image = image[:, :, ::-1] # change to BGR 77 | image -= self.mean 78 | image = image.transpose((2, 0, 1)) 79 | x1 = random.randint(0, image.shape[1] - self.h) 80 | y1 = random.randint(0, image.shape[2] - self.w) 81 | image = image[:, x1:x1+self.h, y1:y1+self.w] 82 | label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 83 | 84 | if self.is_mirror and random.random() < 0.5: 85 | image = np.flip(image, axis = 2) 86 | label_copy = np.flip(label_copy, axis = 1) 87 | #print('Time used: {} sec'.format(time.time()-tt)) 88 | return image.copy(), label_copy.copy(), np.array(size), name 89 | 90 | 91 | if __name__ == '__main__': 92 | dst = cityscapesDataSet('./data/Cityscapes/data', './dataset/cityscapes_list/train.txt', mean=(0,0,0), set = 'train') 93 | trainloader = data.DataLoader(dst, batch_size=4) 94 | for i, data in enumerate(trainloader): 95 | imgs, _, _, _ = data 96 | if i == 0: 97 | img = torchvision.utils.make_grid(imgs).numpy() 98 | img = np.transpose(img, (1, 2, 0)) 99 | img = img[:, :, ::-1] 100 | img = Image.fromarray(np.uint8(img) ) 101 | img.save('Cityscape_Demo.jpg') 102 | break 103 | -------------------------------------------------------------------------------- /dataset/cityscapes_list/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/dataset/cityscapes_list/.DS_Store -------------------------------------------------------------------------------- /dataset/cityscapes_list/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "classes":19, 3 | "label2train":[ 4 | [0, 255], 5 | [1, 255], 6 | [2, 255], 7 | [3, 255], 8 | [4, 255], 9 | [5, 255], 10 | [6, 255], 11 | [7, 0], 12 | [8, 1], 13 | [9, 255], 14 | [10, 255], 15 | [11, 2], 16 | [12, 3], 17 | [13, 4], 18 | [14, 255], 19 | [15, 255], 20 | [16, 255], 21 | 
[17, 5], 22 | [18, 255], 23 | [19, 6], 24 | [20, 7], 25 | [21, 8], 26 | [22, 9], 27 | [23, 10], 28 | [24, 11], 29 | [25, 12], 30 | [26, 13], 31 | [27, 14], 32 | [28, 15], 33 | [29, 255], 34 | [30, 255], 35 | [31, 16], 36 | [32, 17], 37 | [33, 18], 38 | [-1, 255]], 39 | "label":[ 40 | "road", 41 | "sidewalk", 42 | "building", 43 | "wall", 44 | "fence", 45 | "pole", 46 | "light", 47 | "sign", 48 | "vegetation", 49 | "terrain", 50 | "sky", 51 | "person", 52 | "rider", 53 | "car", 54 | "truck", 55 | "bus", 56 | "train", 57 | "motocycle", 58 | "bicycle"], 59 | "palette":[ 60 | [128,64,128], 61 | [244,35,232], 62 | [70,70,70], 63 | [102,102,156], 64 | [190,153,153], 65 | [153,153,153], 66 | [250,170,30], 67 | [220,220,0], 68 | [107,142,35], 69 | [152,251,152], 70 | [70,130,180], 71 | [220,20,60], 72 | [255,0,0], 73 | [0,0,142], 74 | [0,0,70], 75 | [0,60,100], 76 | [0,80,100], 77 | [0,0,230], 78 | [119,11,32], 79 | [0,0,0]], 80 | "mean":[ 81 | 73.158359210711552, 82 | 82.908917542625858, 83 | 72.392398761941593], 84 | "std":[ 85 | 47.675755341814678, 86 | 48.494214368814916, 87 | 47.736546325441594] 88 | } 89 | -------------------------------------------------------------------------------- /dataset/cityscapes_pseudo_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | import collections 9 | import torch 10 | import torchvision 11 | from torch.utils import data 12 | from PIL import Image,ImageFile 13 | from dataset.autoaugment import ImageNetPolicy 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | class cityscapes_pseudo_DataSet(data.Dataset): 18 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=False, mirror=True, ignore_label=255, set='val', autoaug=False, synthia=False, threshold = 1.0): 19 | self.root = root 20 | self.list_path = list_path 21 | self.crop_size = crop_size 22 | self.scale = scale 23 | self.ignore_label = ignore_label 24 | self.mean = mean 25 | self.is_mirror = mirror 26 | self.resize_size = resize_size 27 | self.autoaug = autoaug 28 | self.h = crop_size[0] 29 | self.w = crop_size[1] 30 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 31 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 32 | if not max_iters==None: 33 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 34 | self.files = [] 35 | self.set = set 36 | # for split in ["train", "trainval", "val"]: 37 | 38 | #https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 39 | self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 40 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 41 | 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18} 42 | 43 | for name in self.img_ids: 44 | img_file = osp.join(self.root, "leftImg8bit/%s/%s" % (self.set, name)) 45 | label_file = osp.join(self.root, "pseudo_FULL/%s/%s" % (self.set, name )) 46 | if threshold != 1.0: 47 | label_file = osp.join(self.root, "pseudo_%.1f/%s/%s" % (threshold, self.set, name )) 48 | if synthia: 49 | label_file = osp.join(self.root, "pseudo_SYNTHIA/%s/%s" % (self.set, name )) 50 | self.files.append({ 51 | "img": img_file, 52 | "label": label_file, 53 | "name": name 54 | }) 55 | 56 | def __len__(self): 57 | return len(self.files) 58 | 59 | def __getitem__(self, 
index): 60 | datafiles = self.files[index] 61 | 62 | image = Image.open(datafiles["img"]).convert('RGB') 63 | label = Image.open(datafiles["label"]) 64 | name = datafiles["name"] 65 | 66 | # resize 67 | if self.scale: 68 | random_scale = 0.8 + random.random()*0.4 # 0.8 - 1.2 69 | image = image.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.BICUBIC) 70 | label = label.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.NEAREST) 71 | else: 72 | image = image.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.BICUBIC) 73 | label = label.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.NEAREST) 74 | 75 | if self.autoaug: 76 | policy = ImageNetPolicy() 77 | image = policy(image) 78 | 79 | image = np.asarray(image, np.float32) 80 | label = np.asarray(label, np.uint8) 81 | 82 | # re-assign labels to match the format of Cityscapes 83 | #label_copy = 255 * np.ones(label.shape, dtype=np.float32) 84 | #for k, v in list(self.id_to_trainid.items()): 85 | # label_copy[label == k] = v 86 | label_copy = label 87 | 88 | size = image.shape 89 | image = image[:, :, ::-1] # change to BGR 90 | image -= self.mean 91 | image = image.transpose((2, 0, 1)) 92 | #print(image.shape, label.shape) 93 | for i in range(10): #find hard samples 94 | x1 = random.randint(0, image.shape[1] - self.h) 95 | y1 = random.randint(0, image.shape[2] - self.w) 96 | tmp_label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 97 | tmp_image = image[:, x1:x1+self.h, y1:y1+self.w] 98 | u = np.unique(tmp_label_copy) 99 | if len(u) > 10: 100 | break 101 | #else: 102 | #print('Cityscape-Pseudo: Too young too naive for %d times!'%i) 103 | image = tmp_image 104 | label_copy = tmp_label_copy 105 | 106 | if self.is_mirror and random.random() < 0.5: 107 | image = np.flip(image, axis = 2) 108 | label_copy = np.flip(label_copy, axis = 1) 109 | 110 | return image.copy(), label_copy.copy(), np.array(size), name 111 | 112 | 113 | if __name__ == '__main__': 114 | dst = cityscapes_pseudo_DataSet('./data/Cityscapes/data', './dataset/cityscapes_list/train.txt', mean=(0,0,0), set = 'train', autoaug=True) 115 | trainloader = data.DataLoader(dst, batch_size=4) 116 | for i, data in enumerate(trainloader): 117 | imgs, _, _,_ = data 118 | if i == 0: 119 | img = torchvision.utils.make_grid(imgs).numpy() 120 | img = np.transpose(img, (1, 2, 0)) 121 | img = img[:, :, ::-1] 122 | img = Image.fromarray(np.uint8(img) ) 123 | img.save('Cityscape_Demo.jpg') 124 | break 125 | -------------------------------------------------------------------------------- /dataset/cityscapes_train_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | import collections 9 | import torch 10 | import torchvision 11 | from torch.utils import data 12 | from PIL import Image 13 | from dataset.autoaugment import ImageNetPolicy 14 | import time 15 | 16 | class cityscapesDataSet(data.Dataset): 17 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='train', autoaug=False): 18 | self.root = root 19 | self.list_path = list_path 20 | self.crop_size = crop_size 21 | self.scale = scale 22 | self.ignore_label = ignore_label 23 | self.mean = mean 24 | 
self.is_mirror = mirror 25 | self.resize_size = resize_size 26 | self.autoaug = autoaug 27 | self.h = crop_size[0] 28 | self.w = crop_size[1] 29 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 30 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 31 | if not max_iters==None: 32 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 33 | self.files = [] 34 | self.set = set 35 | # for split in ["train", "trainval", "val"]: 36 | 37 | #https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 38 | ''' 39 | project Cityscapes to Oxford Robot 40 | 7 road -> 8; 8 sidewalk -> 7; building 11 -> 6; wall 12 -> 255; 41 | fence 13 -> 255; pole 17-> 255: light 19 -> 5; sign 20->4; 42 | vegetation -> 255; terrain -> 255; sky 23 -> 0; person 24 -> 1 ; 43 | rider 25 -> 1 ; car 26 -> 3; truck 27 ->3; bus 28 ->3; train 31->255; 44 | motorcycle 32->2 ; bike 33 -> 2; 45 | 46 | ''' 47 | self.id_to_trainid = {7: 8, 8: 7, 11: 6, 48 | 19: 5, 20: 4, 23: 0, 24: 1, 25: 1, 49 | 26: 3, 27: 3, 28: 3, 32: 2, 33: 2} 50 | 51 | for name in self.img_ids: 52 | img_file = osp.join(self.root, "leftImg8bit/%s/%s" % (self.set, name)) 53 | label_file = osp.join(self.root, "gtFine/%s/%s" % (self.set, name.replace('leftImg8bit', 'gtFine_labelIds') )) 54 | self.files.append({ 55 | "img": img_file, 56 | "label": label_file, 57 | "name": name 58 | }) 59 | 60 | def __len__(self): 61 | return len(self.files) 62 | 63 | def __getitem__(self, index): 64 | #tt = time.time() 65 | datafiles = self.files[index] 66 | name = datafiles["name"] 67 | 68 | image, label = Image.open(datafiles["img"]).convert('RGB'), Image.open(datafiles["label"]) 69 | # resize 70 | image, label = image.resize(self.resize_size, Image.BICUBIC), label.resize(self.resize_size, Image.NEAREST) 71 | if self.autoaug: 72 | policy = ImageNetPolicy() 73 | image = policy(image) 74 | 75 | image, label = np.asarray(image, np.float32), np.asarray(label, np.uint8) 76 | 77 | # re-assign labels to match the format of Cityscapes 78 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 79 | for k, v in list(self.id_to_trainid.items()): 80 | label_copy[label == k] = v 81 | 82 | size = image.shape 83 | image = image[:, :, ::-1] # change to BGR 84 | image -= self.mean 85 | image = image.transpose((2, 0, 1)) 86 | x1 = random.randint(0, image.shape[1] - self.h) 87 | y1 = random.randint(0, image.shape[2] - self.w) 88 | image = image[:, x1:x1+self.h, y1:y1+self.w] 89 | label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 90 | 91 | if self.is_mirror and random.random() < 0.5: 92 | image = np.flip(image, axis = 2) 93 | label_copy = np.flip(label_copy, axis = 1) 94 | #print('Time used: {} sec'.format(time.time()-tt)) 95 | return image.copy(), label_copy.copy(), np.array(size), name 96 | 97 | 98 | if __name__ == '__main__': 99 | dst = cityscapesDataSet('./data/Cityscapes/data', './dataset/cityscapes_list/train.txt', mean=(0,0,0), set = 'train') 100 | trainloader = data.DataLoader(dst, batch_size=4) 101 | for i, data in enumerate(trainloader): 102 | imgs, _, _, _ = data 103 | if i == 0: 104 | img = torchvision.utils.make_grid(imgs).numpy() 105 | img = np.transpose(img, (1, 2, 0)) 106 | img = img[:, :, ::-1] 107 | img = Image.fromarray(np.uint8(img) ) 108 | img.save('Cityscape_Demo.jpg') 109 | break 110 | -------------------------------------------------------------------------------- /dataset/gta5_dataset.py: -------------------------------------------------------------------------------- 1 | import 
os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib.pyplot as plt 6 | import collections 7 | import torch 8 | import torchvision 9 | from torch.utils import data 10 | from PIL import Image, ImageFile 11 | from dataset.autoaugment import ImageNetPolicy 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | 16 | class GTA5DataSet(data.Dataset): 17 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=False, mirror=True, ignore_label=255, autoaug = False): 18 | self.root = root 19 | self.list_path = list_path 20 | self.crop_size = crop_size 21 | self.scale = scale 22 | self.ignore_label = ignore_label 23 | self.mean = mean 24 | self.is_mirror = mirror 25 | self.resize_size = resize_size 26 | self.autoaug = autoaug 27 | self.h = crop_size[0] 28 | self.w = crop_size[1] 29 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 30 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 31 | if not max_iters==None: 32 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 33 | self.files = [] 34 | 35 | self.id_to_trainid = {7: 0, 8: 1, 11: 2, 12: 3, 13: 4, 17: 5, 36 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 25: 12, 37 | 26: 13, 27: 14, 28: 15, 31: 16, 32: 17, 33: 18} 38 | 39 | # for split in ["train", "trainval", "val"]: 40 | for name in self.img_ids: 41 | img_file = osp.join(self.root, "images/%s" % name) 42 | label_file = osp.join(self.root, "labels/%s" % name) 43 | self.files.append({ 44 | "img": img_file, 45 | "label": label_file, 46 | "name": name 47 | }) 48 | 49 | def __len__(self): 50 | return len(self.files) 51 | 52 | 53 | def __getitem__(self, index): 54 | datafiles = self.files[index] 55 | 56 | image = Image.open(datafiles["img"]).convert('RGB') 57 | label = Image.open(datafiles["label"]) 58 | name = datafiles["name"] 59 | 60 | # resize 61 | if self.scale: 62 | random_scale = 0.8 + random.random()*0.4 # 0.8 - 1.2 63 | image = image.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.BICUBIC) 64 | label = label.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.NEAREST) 65 | else: 66 | image = image.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.BICUBIC) 67 | label = label.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.NEAREST) 68 | 69 | if self.autoaug: 70 | policy = ImageNetPolicy() 71 | image = policy(image) 72 | 73 | image = np.asarray(image, np.float32) 74 | label = np.asarray(label, np.uint8) 75 | 76 | # re-assign labels to match the format of Cityscapes 77 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 78 | for k, v in list(self.id_to_trainid.items()): 79 | label_copy[label == k] = v 80 | 81 | size = image.shape 82 | image = image[:, :, ::-1] # change to BGR 83 | image -= self.mean 84 | image = image.transpose((2, 0, 1)) 85 | print(image.shape, label.shape) 86 | for i in range(10): #find hard samples 87 | x1 = random.randint(0, image.shape[1] - self.h) 88 | y1 = random.randint(0, image.shape[2] - self.w) 89 | tmp_label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 90 | tmp_image = image[:, x1:x1+self.h, y1:y1+self.w] 91 | u = np.unique(tmp_label_copy) 92 | if len(u) > 10: 93 | break 94 | else: 95 | print('GTA5: Too young too naive for %d times!'%i) 96 | 97 | image = tmp_image 98 | label_copy = tmp_label_copy 99 | 100 | if self.is_mirror and random.random() < 0.5: 101 | 
image = np.flip(image, axis = 2) 102 | label_copy = np.flip(label_copy, axis = 1) 103 | 104 | return image.copy(), label_copy.copy(), np.array(size), name 105 | 106 | 107 | if __name__ == '__main__': 108 | dst = GTA5DataSet('./data/GTA5/', './dataset/gta5_list/train.txt', mean=(0,0,0), autoaug=True) 109 | trainloader = data.DataLoader(dst, batch_size=4) 110 | for i, data in enumerate(trainloader): 111 | imgs, _, _, _ = data 112 | if i == 0: 113 | img = torchvision.utils.make_grid(imgs).numpy() 114 | img = np.transpose(img, (1, 2, 0)) 115 | img = img[:, :, ::-1] 116 | img = Image.fromarray(np.uint8(img) ) 117 | img.save('GTA5_Demo.jpg') 118 | break 119 | -------------------------------------------------------------------------------- /dataset/robot_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | import collections 9 | import torch 10 | import torchvision 11 | from torch.utils import data 12 | from PIL import Image,ImageFile 13 | from dataset.autoaugment import ImageNetPolicy 14 | import time 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | 18 | class robotDataSet(data.Dataset): 19 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='val', autoaug=False): 20 | self.root = root 21 | self.list_path = list_path 22 | self.crop_size = crop_size 23 | self.scale = scale 24 | self.ignore_label = ignore_label 25 | self.mean = mean 26 | self.is_mirror = mirror 27 | self.resize_size = resize_size 28 | self.autoaug = autoaug 29 | self.h = crop_size[0] 30 | self.w = crop_size[1] 31 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 32 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 33 | if not max_iters==None: 34 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 35 | self.files = [] 36 | self.set = set 37 | # for split in ["train", "trainval", "val"]: 38 | 39 | #https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 40 | ''' 41 | 0 sky; 1 person; 2 two-wheel; 3 automobile; 4 sign 42 | 5 light 6 building 7 sidewalk 8 road 43 | ''' 44 | self.id_to_trainid = {1:0, 2:1, 3:2, 4:3, 5:4, 6:5, 7:6, 10:7, 11:8, 12:8, 13:8, 14:8, 17:8} 45 | 46 | for name in self.img_ids: 47 | img_file = osp.join(self.root, "%s/%s" % (self.set, name)) 48 | if set == 'val': 49 | label_file = osp.join(self.root, "anno/%s" %name ) 50 | else: 51 | label_file = '' 52 | self.files.append({ 53 | "img": img_file, 54 | "label": label_file, 55 | "name": name 56 | }) 57 | 58 | def __len__(self): 59 | return len(self.files) 60 | 61 | def __getitem__(self, index): 62 | #tt = time.time() 63 | datafiles = self.files[index] 64 | name = datafiles["name"] 65 | 66 | image = Image.open(datafiles["img"]).convert('RGB') 67 | image= image.resize(self.resize_size, Image.BICUBIC) 68 | 69 | if self.set == 'val': 70 | label = Image.open(datafiles["label"]) 71 | label = label.resize(self.resize_size, Image.NEAREST) 72 | label = np.asarray(label, np.uint8) 73 | # re-assign labels to match the format of Cityscapes 74 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 75 | for k, v in list(self.id_to_trainid.items()): 76 | label_copy[label == k] = v 77 | 78 | if self.autoaug: 79 | policy = ImageNetPolicy() 80 | image = 
policy(image) 81 | image = np.asarray(image, np.float32) 82 | size = image.shape 83 | image = image[:, :, ::-1] # change to BGR 84 | image -= self.mean 85 | image = image.transpose((2, 0, 1)) 86 | x1 = random.randint(0, image.shape[1] - self.h) 87 | y1 = random.randint(0, image.shape[2] - self.w) 88 | image = image[:, x1:x1+self.h, y1:y1+self.w] 89 | 90 | if self.set == 'val': 91 | label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 92 | else: 93 | label_copy = np.ones(image.shape[1:3])*255 94 | 95 | if self.is_mirror and random.random() < 0.5: 96 | image = np.flip(image, axis = 2) 97 | if self.set == 'val': 98 | label_copy = np.flip(label_copy, axis = 1) 99 | 100 | 101 | return image.copy(), label_copy.copy(), np.array(size), name 102 | 103 | 104 | if __name__ == '__main__': 105 | dst = robotDataSet('./data/Oxford_Robot_ICCV19', './dataset/robot_list/train.txt', mean=(0,0,0), set = 'train') 106 | trainloader = data.DataLoader(dst, batch_size=4) 107 | for i, data in enumerate(trainloader): 108 | imgs, _, _, _ = data 109 | if i == 0: 110 | img = torchvision.utils.make_grid(imgs).numpy() 111 | img = np.transpose(img, (1, 2, 0)) 112 | img = img[:, :, ::-1] 113 | img = Image.fromarray(np.uint8(img) ) 114 | img.save('Robot_Demo.jpg') 115 | break 116 | -------------------------------------------------------------------------------- /dataset/robot_list/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "classes":9, 3 | "label2train":[ 4 | [0, 255], 5 | [1, 0], 6 | [2, 1], 7 | [3, 2], 8 | [4, 3], 9 | [5, 4], 10 | [6, 5], 11 | [7, 6], 12 | [8, 255], 13 | [9, 255], 14 | [10, 7], 15 | [11, 8], 16 | [12, 8], 17 | [13, 8], 18 | [14, 8], 19 | [15, 255], 20 | [16, 255], 21 | [17, 8], 22 | [18, 255], 23 | [19, 255], 24 | [20, 255], 25 | [21, 255], 26 | [22, 255], 27 | [23, 255], 28 | [24, 255], 29 | [25, 255], 30 | [26, 255], 31 | [27, 255], 32 | [28, 255], 33 | [29, 255], 34 | [30, 255], 35 | [31, 255], 36 | [32, 255], 37 | [33, 255], 38 | [-1, 255]], 39 | "label":[ 40 | "sky", 41 | "person", 42 | "two-wheel", 43 | "automobile", 44 | "sign", 45 | "light", 46 | "building", 47 | "sidewalk", 48 | "road"], 49 | "palette":[ 50 | [70,130,180], 51 | [220,20,60], 52 | [119,11,32], 53 | [0,0,142], 54 | [220,220,0], 55 | [250,170,30], 56 | [70,70,70], 57 | [244,35,232], 58 | [128,64,128], 59 | [0,0,0]], 60 | "mean":[ 61 | 73.158359210711552, 62 | 82.908917542625858, 63 | 72.392398761941593], 64 | "std":[ 65 | 47.675755341814678, 66 | 48.494214368814916, 67 | 47.736546325441594] 68 | } 69 | -------------------------------------------------------------------------------- /dataset/robot_list/label.txt: -------------------------------------------------------------------------------- 1 | 2014-11-25-09-18-32_01_000038.png 2 | 2014-11-25-09-18-32_01_000039.png 3 | 2014-11-25-09-18-32_01_000040.png 4 | 2014-11-25-09-18-32_01_000041.png 5 | 2014-11-25-09-18-32_01_000042.png 6 | 2014-11-25-09-18-32_01_000043.png 7 | 2014-11-25-09-18-32_01_000044.png 8 | 2014-11-25-09-18-32_01_000045.png 9 | 2014-11-25-09-18-32_01_000046.png 10 | 2014-11-25-09-18-32_01_000047.png 11 | 2014-11-25-09-18-32_01_000048.png 12 | 2014-11-25-09-18-32_01_000049.png 13 | 2014-11-25-09-18-32_01_000050.png 14 | 2014-11-25-09-18-32_01_000051.png 15 | 2014-11-25-09-18-32_01_000052.png 16 | 2014-11-25-09-18-32_01_000053.png 17 | 2014-11-25-09-18-32_01_000054.png 18 | 2014-11-25-09-18-32_01_000055.png 19 | 2014-11-25-09-18-32_01_000056.png 20 | 2014-11-25-09-18-32_01_000057.png 21 | 
2014-11-25-09-18-32_01_000058.png 22 | 2014-11-25-09-18-32_01_000059.png 23 | 2014-11-25-09-18-32_01_000060.png 24 | 2014-11-25-09-18-32_01_000061.png 25 | 2014-11-25-09-18-32_01_000062.png 26 | 2014-11-25-09-18-32_01_000063.png 27 | 2014-11-25-09-18-32_01_000064.png 28 | 2014-11-25-09-18-32_01_000065.png 29 | 2014-11-25-09-18-32_01_000066.png 30 | 2014-11-25-09-18-32_01_000067.png 31 | 2014-11-25-09-18-32_01_000068.png 32 | 2014-11-25-09-18-32_01_000069.png 33 | 2014-11-25-09-18-32_01_000070.png 34 | 2014-11-25-09-18-32_01_000071.png 35 | 2014-11-25-09-18-32_01_000072.png 36 | 2014-11-25-09-18-32_01_000073.png 37 | 2014-11-25-09-18-32_01_000074.png 38 | 2014-11-25-09-18-32_01_000075.png 39 | 2014-11-25-09-18-32_01_000076.png 40 | 2014-11-25-09-18-32_01_000077.png 41 | 2014-11-25-09-18-32_01_000078.png 42 | 2014-11-25-09-18-32_01_000079.png 43 | 2014-11-25-09-18-32_01_000080.png 44 | 2014-11-25-09-18-32_01_000081.png 45 | 2014-11-25-09-18-32_01_000082.png 46 | 2014-11-25-09-18-32_01_000083.png 47 | 2014-11-25-09-18-32_01_000084.png 48 | 2014-11-25-09-18-32_01_000085.png 49 | 2014-11-25-09-18-32_01_000086.png 50 | 2014-11-25-09-18-32_01_000087.png 51 | 2014-11-25-09-18-32_01_000088.png 52 | 2014-11-25-09-18-32_01_000089.png 53 | 2014-11-25-09-18-32_01_000090.png 54 | 2014-11-25-09-18-32_01_000128.png 55 | 2014-11-25-09-18-32_01_000129.png 56 | 2014-11-25-09-18-32_01_000130.png 57 | 2014-11-25-09-18-32_01_000131.png 58 | 2014-11-25-09-18-32_01_000132.png 59 | 2014-11-25-09-18-32_01_000133.png 60 | 2014-11-25-09-18-32_01_000134.png 61 | 2014-11-25-09-18-32_01_000135.png 62 | 2014-11-25-09-18-32_01_000136.png 63 | 2014-11-25-09-18-32_01_000137.png 64 | 2014-11-25-09-18-32_01_000138.png 65 | 2014-11-25-09-18-32_01_000139.png 66 | 2014-11-25-09-18-32_01_000140.png 67 | 2014-11-25-09-18-32_01_000141.png 68 | 2014-11-25-09-18-32_01_000142.png 69 | 2014-11-25-09-18-32_01_000143.png 70 | 2014-11-25-09-18-32_01_000144.png 71 | 2014-11-25-09-18-32_01_000145.png 72 | 2014-11-25-09-18-32_01_000146.png 73 | 2014-11-25-09-18-32_01_000147.png 74 | 2014-11-25-09-18-32_01_000148.png 75 | 2014-11-25-09-18-32_01_000149.png 76 | 2014-11-25-09-18-32_01_000150.png 77 | 2014-11-25-09-18-32_01_000151.png 78 | 2014-11-25-09-18-32_01_000152.png 79 | 2014-11-25-09-18-32_01_000153.png 80 | 2014-11-25-09-18-32_01_000154.png 81 | 2014-11-25-09-18-32_01_000155.png 82 | 2014-11-25-09-18-32_01_000156.png 83 | 2014-11-25-09-18-32_01_000157.png 84 | 2014-11-25-09-18-32_01_000158.png 85 | 2014-11-25-09-18-32_01_000159.png 86 | 2014-11-25-09-18-32_01_000160.png 87 | 2014-11-25-09-18-32_01_000161.png 88 | 2014-11-25-09-18-32_01_000162.png 89 | 2014-11-25-09-18-32_01_000163.png 90 | 2014-11-25-09-18-32_01_000164.png 91 | 2014-11-25-09-18-32_01_000165.png 92 | 2014-11-25-09-18-32_01_000166.png 93 | 2014-11-25-09-18-32_01_000167.png 94 | 2014-11-25-09-18-32_01_000168.png 95 | 2014-11-25-09-18-32_01_000169.png 96 | 2014-11-25-09-18-32_01_000170.png 97 | 2014-11-25-09-18-32_01_000171.png 98 | 2014-11-25-09-18-32_01_000172.png 99 | 2014-11-25-09-18-32_01_000173.png 100 | 2014-11-25-09-18-32_01_000174.png 101 | 2014-11-25-09-18-32_01_000175.png 102 | 2014-11-25-09-18-32_01_000176.png 103 | 2014-11-25-09-18-32_01_000177.png 104 | 2014-11-25-09-18-32_01_000178.png 105 | 2014-11-25-09-18-32_01_000179.png 106 | 2014-11-25-09-18-32_01_000180.png 107 | 2014-11-25-09-18-32_01_000181.png 108 | 2014-11-25-09-18-32_01_000182.png 109 | 2014-11-25-09-18-32_01_000183.png 110 | 2014-11-25-09-18-32_01_000184.png 111 | 
2014-11-25-09-18-32_01_000185.png 112 | 2014-11-25-09-18-32_01_000186.png 113 | 2014-11-25-09-18-32_01_000187.png 114 | 2014-11-25-09-18-32_01_000188.png 115 | 2014-11-25-09-18-32_01_000189.png 116 | 2014-11-25-09-18-32_01_000190.png 117 | 2014-11-25-09-18-32_01_000191.png 118 | 2014-11-25-09-18-32_01_000192.png 119 | 2014-11-25-09-18-32_01_000193.png 120 | 2014-11-25-09-18-32_01_000194.png 121 | 2014-11-25-09-18-32_01_000195.png 122 | 2014-11-25-09-18-32_01_000196.png 123 | 2014-11-25-09-18-32_01_000197.png 124 | 2014-11-25-09-18-32_01_000198.png 125 | 2014-11-25-09-18-32_01_000199.png 126 | 2015-10-29-12-18-17_07_000000.png 127 | 2015-10-29-12-18-17_07_000001.png 128 | 2015-10-29-12-18-17_07_000002.png 129 | 2015-10-29-12-18-17_07_000003.png 130 | 2015-10-29-12-18-17_07_000004.png 131 | 2015-10-29-12-18-17_07_000005.png 132 | 2015-10-29-12-18-17_07_000006.png 133 | 2015-10-29-12-18-17_07_000007.png 134 | 2015-10-29-12-18-17_07_000008.png 135 | 2015-10-29-12-18-17_07_000009.png 136 | 2015-10-29-12-18-17_07_000010.png 137 | 2015-10-29-12-18-17_07_000011.png 138 | 2015-10-29-12-18-17_07_000012.png 139 | 2015-10-29-12-18-17_07_000013.png 140 | 2015-10-29-12-18-17_07_000014.png 141 | 2015-10-29-12-18-17_07_000015.png 142 | 2015-10-29-12-18-17_07_000016.png 143 | 2015-10-29-12-18-17_07_000017.png 144 | 2015-10-29-12-18-17_07_000018.png 145 | 2015-10-29-12-18-17_07_000019.png 146 | 2015-10-29-12-18-17_07_000020.png 147 | 2015-10-29-12-18-17_07_000021.png 148 | 2015-10-29-12-18-17_07_000022.png 149 | 2015-10-29-12-18-17_07_000023.png 150 | 2015-10-29-12-18-17_07_000024.png 151 | 2015-10-29-12-18-17_07_000025.png 152 | 2015-10-29-12-18-17_07_000026.png 153 | 2015-10-29-12-18-17_07_000027.png 154 | 2015-10-29-12-18-17_07_000028.png 155 | 2015-10-29-12-18-17_07_000029.png 156 | 2015-10-29-12-18-17_07_000030.png 157 | 2015-10-29-12-18-17_07_000031.png 158 | 2015-10-29-12-18-17_07_000032.png 159 | 2015-10-29-12-18-17_07_000033.png 160 | 2015-10-29-12-18-17_07_000034.png 161 | 2015-10-29-12-18-17_07_000035.png 162 | 2015-10-29-12-18-17_07_000036.png 163 | 2015-10-29-12-18-17_07_000037.png 164 | 2015-10-29-12-18-17_07_000038.png 165 | 2015-10-29-12-18-17_07_000039.png 166 | 2015-10-29-12-18-17_07_000040.png 167 | 2015-10-29-12-18-17_07_000041.png 168 | 2015-11-06-11-21-12_05_000000.png 169 | 2015-11-06-11-21-12_05_000001.png 170 | 2015-11-06-11-21-12_05_000002.png 171 | 2015-11-06-11-21-12_05_000003.png 172 | 2015-11-06-11-21-12_05_000004.png 173 | 2015-11-06-11-21-12_05_000005.png 174 | 2015-11-06-11-21-12_05_000006.png 175 | 2015-11-06-11-21-12_05_000007.png 176 | 2015-11-06-11-21-12_05_000008.png 177 | 2015-11-06-11-21-12_05_000009.png 178 | 2015-11-06-11-21-12_05_000010.png 179 | 2015-11-06-11-21-12_05_000011.png 180 | 2015-11-06-11-21-12_05_000012.png 181 | 2015-11-06-11-21-12_05_000013.png 182 | 2015-11-06-11-21-12_05_000014.png 183 | 2015-11-06-11-21-12_05_000015.png 184 | 2015-11-06-11-21-12_05_000016.png 185 | 2015-11-06-11-21-12_05_000017.png 186 | 2015-11-06-11-21-12_05_000018.png 187 | 2015-11-06-11-21-12_05_000019.png 188 | 2015-11-06-11-21-12_05_000020.png 189 | 2015-11-06-11-21-12_05_000021.png 190 | 2015-11-06-11-21-12_05_000022.png 191 | 2015-11-06-11-21-12_05_000023.png 192 | 2015-11-06-11-21-12_05_000024.png 193 | 2015-11-06-11-21-12_05_000025.png 194 | 2015-11-06-11-21-12_05_000026.png 195 | 2015-11-06-11-21-12_05_000027.png 196 | 2015-11-06-11-21-12_05_000028.png 197 | 2015-11-06-11-21-12_05_000029.png 198 | 2015-11-06-11-21-12_05_000030.png 199 | 2015-11-06-11-21-12_05_000031.png 
200 | 2015-11-06-11-21-12_05_000032.png 201 | 2015-11-06-11-21-12_05_000033.png 202 | 2015-11-06-11-21-12_05_000034.png 203 | 2015-11-06-11-21-12_05_000035.png 204 | 2015-11-06-11-21-12_05_000036.png 205 | 2015-11-06-11-21-12_05_000037.png 206 | 2015-11-06-11-21-12_05_000038.png 207 | 2015-11-06-11-21-12_05_000039.png 208 | 2015-11-06-11-21-12_05_000040.png 209 | 2015-11-06-11-21-12_05_000041.png 210 | 2015-11-06-11-21-12_05_000042.png 211 | 2015-11-06-11-21-12_05_000043.png 212 | 2015-11-06-11-21-12_05_000044.png 213 | 2015-11-06-11-21-12_05_000045.png 214 | 2015-11-06-11-21-12_05_000046.png 215 | 2015-11-06-11-21-12_05_000047.png 216 | 2015-11-06-11-21-12_05_000049.png 217 | 2015-11-06-11-21-12_05_000050.png 218 | 2015-11-06-11-21-12_05_000051.png 219 | 2015-11-06-11-21-12_05_000052.png 220 | 2015-11-06-11-21-12_05_000053.png 221 | 2015-11-06-11-21-12_05_000054.png 222 | 2015-11-06-11-21-12_05_000055.png 223 | 2015-11-06-11-21-12_05_000056.png 224 | 2015-11-06-11-21-12_05_000057.png 225 | 2015-11-06-11-21-12_05_000058.png 226 | 2015-11-06-11-21-12_05_000059.png 227 | 2015-11-06-11-21-12_05_000060.png 228 | 2015-11-06-11-21-12_05_000061.png 229 | 2015-11-06-11-21-12_05_000062.png 230 | 2015-11-06-11-21-12_05_000063.png 231 | 2015-11-06-11-21-12_05_000064.png 232 | 2015-11-06-11-21-12_05_000065.png 233 | 2015-11-06-11-21-12_05_000066.png 234 | 2015-11-06-11-21-12_05_000067.png 235 | 2015-11-06-11-21-12_05_000068.png 236 | 2015-11-06-11-21-12_05_000069.png 237 | 2015-11-06-11-21-12_05_000071.png 238 | 2015-11-06-11-21-12_05_000072.png 239 | 2015-11-06-11-21-12_05_000073.png 240 | 2015-11-06-11-21-12_05_000074.png 241 | 2015-11-06-11-21-12_05_000075.png 242 | 2015-11-06-11-21-12_05_000076.png 243 | 2015-11-06-11-21-12_05_000078.png 244 | 2015-11-06-11-21-12_05_000079.png 245 | 2015-11-06-11-21-12_05_000080.png 246 | 2015-11-06-11-21-12_05_000081.png 247 | 2015-11-06-11-21-12_05_000082.png 248 | 2015-11-06-11-21-12_05_000083.png 249 | 2015-11-06-11-21-12_05_000084.png 250 | 2015-11-06-11-21-12_05_000085.png 251 | 2015-11-06-11-21-12_05_000086.png 252 | 2015-11-06-11-21-12_05_000087.png 253 | 2015-11-06-11-21-12_05_000088.png 254 | 2015-11-06-11-21-12_05_000089.png 255 | 2015-11-06-11-21-12_05_000090.png 256 | 2015-11-06-11-21-12_05_000091.png 257 | 2015-11-06-11-21-12_05_000092.png 258 | 2015-11-06-11-21-12_05_000095.png 259 | 2015-11-06-11-21-12_05_000096.png 260 | 2015-11-06-11-21-12_05_000097.png 261 | 2015-11-06-11-21-12_05_000099.png 262 | 2015-11-06-11-21-12_05_000100.png 263 | 2015-11-06-11-21-12_05_000101.png 264 | 2015-11-06-11-21-12_05_000102.png 265 | 2015-11-06-11-21-12_05_000103.png 266 | 2015-11-06-11-21-12_05_000104.png 267 | 2015-11-06-11-21-12_05_000105.png 268 | 2015-11-06-11-21-12_05_000106.png 269 | 2015-11-06-11-21-12_05_000107.png 270 | 2015-11-06-11-21-12_05_000108.png 271 | 2015-11-06-11-21-12_05_000109.png 272 | -------------------------------------------------------------------------------- /dataset/robot_list/val.txt: -------------------------------------------------------------------------------- 1 | 2014-11-25-09-18-32_01_000038.png 2 | 2014-11-25-09-18-32_01_000039.png 3 | 2014-11-25-09-18-32_01_000040.png 4 | 2014-11-25-09-18-32_01_000041.png 5 | 2014-11-25-09-18-32_01_000042.png 6 | 2014-11-25-09-18-32_01_000043.png 7 | 2014-11-25-09-18-32_01_000044.png 8 | 2014-11-25-09-18-32_01_000045.png 9 | 2014-11-25-09-18-32_01_000046.png 10 | 2014-11-25-09-18-32_01_000047.png 11 | 2014-11-25-09-18-32_01_000048.png 12 | 2014-11-25-09-18-32_01_000049.png 13 | 
2014-11-25-09-18-32_01_000050.png 14 | 2014-11-25-09-18-32_01_000051.png 15 | 2014-11-25-09-18-32_01_000052.png 16 | 2014-11-25-09-18-32_01_000053.png 17 | 2014-11-25-09-18-32_01_000054.png 18 | 2014-11-25-09-18-32_01_000055.png 19 | 2014-11-25-09-18-32_01_000056.png 20 | 2014-11-25-09-18-32_01_000057.png 21 | 2014-11-25-09-18-32_01_000058.png 22 | 2014-11-25-09-18-32_01_000059.png 23 | 2014-11-25-09-18-32_01_000060.png 24 | 2014-11-25-09-18-32_01_000061.png 25 | 2014-11-25-09-18-32_01_000062.png 26 | 2014-11-25-09-18-32_01_000063.png 27 | 2014-11-25-09-18-32_01_000064.png 28 | 2014-11-25-09-18-32_01_000065.png 29 | 2014-11-25-09-18-32_01_000066.png 30 | 2014-11-25-09-18-32_01_000067.png 31 | 2014-11-25-09-18-32_01_000068.png 32 | 2014-11-25-09-18-32_01_000069.png 33 | 2014-11-25-09-18-32_01_000070.png 34 | 2014-11-25-09-18-32_01_000071.png 35 | 2014-11-25-09-18-32_01_000072.png 36 | 2014-11-25-09-18-32_01_000073.png 37 | 2014-11-25-09-18-32_01_000074.png 38 | 2014-11-25-09-18-32_01_000075.png 39 | 2014-11-25-09-18-32_01_000076.png 40 | 2014-11-25-09-18-32_01_000077.png 41 | 2014-11-25-09-18-32_01_000078.png 42 | 2014-11-25-09-18-32_01_000079.png 43 | 2014-11-25-09-18-32_01_000080.png 44 | 2014-11-25-09-18-32_01_000081.png 45 | 2014-11-25-09-18-32_01_000082.png 46 | 2014-11-25-09-18-32_01_000083.png 47 | 2014-11-25-09-18-32_01_000084.png 48 | 2014-11-25-09-18-32_01_000085.png 49 | 2014-11-25-09-18-32_01_000086.png 50 | 2014-11-25-09-18-32_01_000087.png 51 | 2014-11-25-09-18-32_01_000088.png 52 | 2014-11-25-09-18-32_01_000089.png 53 | 2014-11-25-09-18-32_01_000090.png 54 | 2014-11-25-09-18-32_01_000128.png 55 | 2014-11-25-09-18-32_01_000129.png 56 | 2014-11-25-09-18-32_01_000130.png 57 | 2014-11-25-09-18-32_01_000131.png 58 | 2014-11-25-09-18-32_01_000132.png 59 | 2014-11-25-09-18-32_01_000133.png 60 | 2014-11-25-09-18-32_01_000134.png 61 | 2014-11-25-09-18-32_01_000135.png 62 | 2014-11-25-09-18-32_01_000136.png 63 | 2014-11-25-09-18-32_01_000137.png 64 | 2014-11-25-09-18-32_01_000138.png 65 | 2014-11-25-09-18-32_01_000139.png 66 | 2014-11-25-09-18-32_01_000140.png 67 | 2014-11-25-09-18-32_01_000141.png 68 | 2014-11-25-09-18-32_01_000142.png 69 | 2014-11-25-09-18-32_01_000143.png 70 | 2014-11-25-09-18-32_01_000144.png 71 | 2014-11-25-09-18-32_01_000145.png 72 | 2014-11-25-09-18-32_01_000146.png 73 | 2014-11-25-09-18-32_01_000147.png 74 | 2014-11-25-09-18-32_01_000148.png 75 | 2014-11-25-09-18-32_01_000149.png 76 | 2014-11-25-09-18-32_01_000150.png 77 | 2014-11-25-09-18-32_01_000151.png 78 | 2014-11-25-09-18-32_01_000152.png 79 | 2014-11-25-09-18-32_01_000153.png 80 | 2014-11-25-09-18-32_01_000154.png 81 | 2014-11-25-09-18-32_01_000155.png 82 | 2014-11-25-09-18-32_01_000156.png 83 | 2014-11-25-09-18-32_01_000157.png 84 | 2014-11-25-09-18-32_01_000158.png 85 | 2014-11-25-09-18-32_01_000159.png 86 | 2014-11-25-09-18-32_01_000160.png 87 | 2014-11-25-09-18-32_01_000161.png 88 | 2014-11-25-09-18-32_01_000162.png 89 | 2014-11-25-09-18-32_01_000163.png 90 | 2014-11-25-09-18-32_01_000164.png 91 | 2014-11-25-09-18-32_01_000165.png 92 | 2014-11-25-09-18-32_01_000166.png 93 | 2014-11-25-09-18-32_01_000167.png 94 | 2014-11-25-09-18-32_01_000168.png 95 | 2014-11-25-09-18-32_01_000169.png 96 | 2014-11-25-09-18-32_01_000170.png 97 | 2014-11-25-09-18-32_01_000171.png 98 | 2014-11-25-09-18-32_01_000172.png 99 | 2014-11-25-09-18-32_01_000173.png 100 | 2014-11-25-09-18-32_01_000174.png 101 | 2014-11-25-09-18-32_01_000175.png 102 | 2014-11-25-09-18-32_01_000176.png 103 | 2014-11-25-09-18-32_01_000177.png 104 | 
2014-11-25-09-18-32_01_000178.png 105 | 2014-11-25-09-18-32_01_000179.png 106 | 2014-11-25-09-18-32_01_000180.png 107 | 2014-11-25-09-18-32_01_000181.png 108 | 2014-11-25-09-18-32_01_000182.png 109 | 2014-11-25-09-18-32_01_000183.png 110 | 2014-11-25-09-18-32_01_000184.png 111 | 2014-11-25-09-18-32_01_000185.png 112 | 2014-11-25-09-18-32_01_000186.png 113 | 2014-11-25-09-18-32_01_000187.png 114 | 2014-11-25-09-18-32_01_000188.png 115 | 2014-11-25-09-18-32_01_000189.png 116 | 2014-11-25-09-18-32_01_000190.png 117 | 2014-11-25-09-18-32_01_000191.png 118 | 2014-11-25-09-18-32_01_000192.png 119 | 2014-11-25-09-18-32_01_000193.png 120 | 2014-11-25-09-18-32_01_000194.png 121 | 2014-11-25-09-18-32_01_000195.png 122 | 2014-11-25-09-18-32_01_000196.png 123 | 2014-11-25-09-18-32_01_000197.png 124 | 2014-11-25-09-18-32_01_000198.png 125 | 2014-11-25-09-18-32_01_000199.png 126 | 2015-10-29-12-18-17_07_000000.png 127 | 2015-10-29-12-18-17_07_000001.png 128 | 2015-10-29-12-18-17_07_000002.png 129 | 2015-10-29-12-18-17_07_000003.png 130 | 2015-10-29-12-18-17_07_000004.png 131 | 2015-10-29-12-18-17_07_000005.png 132 | 2015-10-29-12-18-17_07_000006.png 133 | 2015-10-29-12-18-17_07_000007.png 134 | 2015-10-29-12-18-17_07_000008.png 135 | 2015-10-29-12-18-17_07_000009.png 136 | 2015-10-29-12-18-17_07_000010.png 137 | 2015-10-29-12-18-17_07_000011.png 138 | 2015-10-29-12-18-17_07_000012.png 139 | 2015-10-29-12-18-17_07_000013.png 140 | 2015-10-29-12-18-17_07_000014.png 141 | 2015-10-29-12-18-17_07_000015.png 142 | 2015-10-29-12-18-17_07_000016.png 143 | 2015-10-29-12-18-17_07_000017.png 144 | 2015-10-29-12-18-17_07_000018.png 145 | 2015-10-29-12-18-17_07_000019.png 146 | 2015-10-29-12-18-17_07_000020.png 147 | 2015-10-29-12-18-17_07_000021.png 148 | 2015-10-29-12-18-17_07_000022.png 149 | 2015-10-29-12-18-17_07_000023.png 150 | 2015-10-29-12-18-17_07_000024.png 151 | 2015-10-29-12-18-17_07_000025.png 152 | 2015-10-29-12-18-17_07_000026.png 153 | 2015-10-29-12-18-17_07_000027.png 154 | 2015-10-29-12-18-17_07_000028.png 155 | 2015-10-29-12-18-17_07_000029.png 156 | 2015-10-29-12-18-17_07_000030.png 157 | 2015-10-29-12-18-17_07_000031.png 158 | 2015-10-29-12-18-17_07_000032.png 159 | 2015-10-29-12-18-17_07_000033.png 160 | 2015-10-29-12-18-17_07_000034.png 161 | 2015-10-29-12-18-17_07_000035.png 162 | 2015-10-29-12-18-17_07_000036.png 163 | 2015-10-29-12-18-17_07_000037.png 164 | 2015-10-29-12-18-17_07_000038.png 165 | 2015-10-29-12-18-17_07_000039.png 166 | 2015-10-29-12-18-17_07_000040.png 167 | 2015-10-29-12-18-17_07_000041.png 168 | 2015-11-06-11-21-12_05_000000.png 169 | 2015-11-06-11-21-12_05_000001.png 170 | 2015-11-06-11-21-12_05_000002.png 171 | 2015-11-06-11-21-12_05_000003.png 172 | 2015-11-06-11-21-12_05_000004.png 173 | 2015-11-06-11-21-12_05_000005.png 174 | 2015-11-06-11-21-12_05_000006.png 175 | 2015-11-06-11-21-12_05_000007.png 176 | 2015-11-06-11-21-12_05_000008.png 177 | 2015-11-06-11-21-12_05_000009.png 178 | 2015-11-06-11-21-12_05_000010.png 179 | 2015-11-06-11-21-12_05_000011.png 180 | 2015-11-06-11-21-12_05_000012.png 181 | 2015-11-06-11-21-12_05_000013.png 182 | 2015-11-06-11-21-12_05_000014.png 183 | 2015-11-06-11-21-12_05_000015.png 184 | 2015-11-06-11-21-12_05_000016.png 185 | 2015-11-06-11-21-12_05_000017.png 186 | 2015-11-06-11-21-12_05_000018.png 187 | 2015-11-06-11-21-12_05_000019.png 188 | 2015-11-06-11-21-12_05_000020.png 189 | 2015-11-06-11-21-12_05_000021.png 190 | 2015-11-06-11-21-12_05_000022.png 191 | 2015-11-06-11-21-12_05_000023.png 192 | 2015-11-06-11-21-12_05_000024.png 
193 | 2015-11-06-11-21-12_05_000025.png 194 | 2015-11-06-11-21-12_05_000026.png 195 | 2015-11-06-11-21-12_05_000027.png 196 | 2015-11-06-11-21-12_05_000028.png 197 | 2015-11-06-11-21-12_05_000029.png 198 | 2015-11-06-11-21-12_05_000030.png 199 | 2015-11-06-11-21-12_05_000031.png 200 | 2015-11-06-11-21-12_05_000032.png 201 | 2015-11-06-11-21-12_05_000033.png 202 | 2015-11-06-11-21-12_05_000034.png 203 | 2015-11-06-11-21-12_05_000035.png 204 | 2015-11-06-11-21-12_05_000036.png 205 | 2015-11-06-11-21-12_05_000037.png 206 | 2015-11-06-11-21-12_05_000038.png 207 | 2015-11-06-11-21-12_05_000039.png 208 | 2015-11-06-11-21-12_05_000040.png 209 | 2015-11-06-11-21-12_05_000041.png 210 | 2015-11-06-11-21-12_05_000042.png 211 | 2015-11-06-11-21-12_05_000043.png 212 | 2015-11-06-11-21-12_05_000044.png 213 | 2015-11-06-11-21-12_05_000045.png 214 | 2015-11-06-11-21-12_05_000046.png 215 | 2015-11-06-11-21-12_05_000047.png 216 | 2015-11-06-11-21-12_05_000049.png 217 | 2015-11-06-11-21-12_05_000050.png 218 | 2015-11-06-11-21-12_05_000051.png 219 | 2015-11-06-11-21-12_05_000052.png 220 | 2015-11-06-11-21-12_05_000053.png 221 | 2015-11-06-11-21-12_05_000054.png 222 | 2015-11-06-11-21-12_05_000055.png 223 | 2015-11-06-11-21-12_05_000056.png 224 | 2015-11-06-11-21-12_05_000057.png 225 | 2015-11-06-11-21-12_05_000058.png 226 | 2015-11-06-11-21-12_05_000059.png 227 | 2015-11-06-11-21-12_05_000060.png 228 | 2015-11-06-11-21-12_05_000061.png 229 | 2015-11-06-11-21-12_05_000062.png 230 | 2015-11-06-11-21-12_05_000063.png 231 | 2015-11-06-11-21-12_05_000064.png 232 | 2015-11-06-11-21-12_05_000065.png 233 | 2015-11-06-11-21-12_05_000066.png 234 | 2015-11-06-11-21-12_05_000067.png 235 | 2015-11-06-11-21-12_05_000068.png 236 | 2015-11-06-11-21-12_05_000069.png 237 | 2015-11-06-11-21-12_05_000071.png 238 | 2015-11-06-11-21-12_05_000072.png 239 | 2015-11-06-11-21-12_05_000073.png 240 | 2015-11-06-11-21-12_05_000074.png 241 | 2015-11-06-11-21-12_05_000075.png 242 | 2015-11-06-11-21-12_05_000076.png 243 | 2015-11-06-11-21-12_05_000078.png 244 | 2015-11-06-11-21-12_05_000079.png 245 | 2015-11-06-11-21-12_05_000080.png 246 | 2015-11-06-11-21-12_05_000081.png 247 | 2015-11-06-11-21-12_05_000082.png 248 | 2015-11-06-11-21-12_05_000083.png 249 | 2015-11-06-11-21-12_05_000084.png 250 | 2015-11-06-11-21-12_05_000085.png 251 | 2015-11-06-11-21-12_05_000086.png 252 | 2015-11-06-11-21-12_05_000087.png 253 | 2015-11-06-11-21-12_05_000088.png 254 | 2015-11-06-11-21-12_05_000089.png 255 | 2015-11-06-11-21-12_05_000090.png 256 | 2015-11-06-11-21-12_05_000091.png 257 | 2015-11-06-11-21-12_05_000092.png 258 | 2015-11-06-11-21-12_05_000095.png 259 | 2015-11-06-11-21-12_05_000096.png 260 | 2015-11-06-11-21-12_05_000097.png 261 | 2015-11-06-11-21-12_05_000099.png 262 | 2015-11-06-11-21-12_05_000100.png 263 | 2015-11-06-11-21-12_05_000101.png 264 | 2015-11-06-11-21-12_05_000102.png 265 | 2015-11-06-11-21-12_05_000103.png 266 | 2015-11-06-11-21-12_05_000104.png 267 | 2015-11-06-11-21-12_05_000105.png 268 | 2015-11-06-11-21-12_05_000106.png 269 | 2015-11-06-11-21-12_05_000107.png 270 | 2015-11-06-11-21-12_05_000108.png 271 | 2015-11-06-11-21-12_05_000109.png 272 | -------------------------------------------------------------------------------- /dataset/robot_pseudo_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | import collections 9 | 
import torch 10 | import torchvision 11 | from torch.utils import data 12 | from PIL import Image, ImageFile 13 | from dataset.autoaugment import ImageNetPolicy 14 | import time 15 | 16 | ImageFile.LOAD_TRUNCATED_IMAGES = True 17 | 18 | class robot_pseudo_DataSet(data.Dataset): 19 | def __init__(self, root, list_path, max_iters=None, resize_size=(1280, 960), crop_size=(512, 1024), mean=(128, 128, 128), scale=True, mirror=True, ignore_label=255, set='train', autoaug=False): 20 | self.root = root 21 | self.list_path = list_path 22 | self.crop_size = crop_size 23 | self.scale = scale 24 | self.ignore_label = ignore_label 25 | self.mean = mean 26 | self.is_mirror = mirror 27 | self.resize_size = resize_size 28 | self.autoaug = autoaug 29 | self.h = crop_size[0] 30 | self.w = crop_size[1] 31 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 32 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 33 | if not max_iters==None: 34 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 35 | self.files = [] 36 | self.set = set 37 | # for split in ["train", "trainval", "val"]: 38 | 39 | #https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py 40 | ''' 41 | 0 sky; 1 person; 2 two-wheel; 3 automobile; 4 sign 42 | 5 light 6 building 7 sidewalk 8 road 43 | ''' 44 | self.id_to_trainid = {0:0, 1:1, 2:2, 3:3, 4:4, 5:5, 6:6, 7:7, 8:8} 45 | 46 | for name in self.img_ids: 47 | img_file = osp.join(self.root, "%s/%s" % (self.set, name)) 48 | if set == 'val': 49 | label_file = osp.join(self.root, "anno/%s" %name ) 50 | else: 51 | label_file = osp.join(self.root, "pseudo_train_0.9/%s" %name ) 52 | self.files.append({ 53 | "img": img_file, 54 | "label": label_file, 55 | "name": name 56 | }) 57 | 58 | def __len__(self): 59 | return len(self.files) 60 | 61 | def __getitem__(self, index): 62 | #tt = time.time() 63 | datafiles = self.files[index] 64 | name = datafiles["name"] 65 | 66 | image = Image.open(datafiles["img"]).convert('RGB') 67 | label = Image.open(datafiles["label"]) 68 | 69 | if self.scale: 70 | random_scale = 0.8 + random.random()*0.4 # 0.8 - 1.2 71 | image = image.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.BICUBIC) 72 | label = label.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.NEAREST) 73 | else: 74 | image = image.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.BICUBIC) 75 | label = label.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.NEAREST) 76 | 77 | label = np.asarray(label, np.uint8) 78 | # re-assign labels to match the format of Cityscapes 79 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 80 | for k, v in list(self.id_to_trainid.items()): 81 | label_copy[label == k] = v 82 | if self.autoaug: 83 | policy = ImageNetPolicy() 84 | image = policy(image) 85 | image = np.asarray(image, np.float32) 86 | size = image.shape 87 | image = image[:, :, ::-1] # change to BGR 88 | image -= self.mean 89 | image = image.transpose((2, 0, 1)) 90 | if self.set == 'train': 91 | for i in range(10): #find hard samples 92 | x1 = random.randint(0, image.shape[1] - self.h) 93 | y1 = random.randint(0, image.shape[2] - self.w) 94 | tmp_label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 95 | tmp_image = image[:, x1:x1+self.h, y1:y1+self.w] 96 | u = np.unique(tmp_label_copy) 97 | if len(u) > 4: 98 | break 99 | else: 100 | print('RB: Too young too naive for %d times!'%i) 101 | 
else: 102 | x1 = random.randint(0, image.shape[1] - self.h) 103 | y1 = random.randint(0, image.shape[2] - self.w) 104 | tmp_image = image[:, x1:x1+self.h, y1:y1+self.w] 105 | tmp_label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 106 | 107 | image = tmp_image 108 | label_copy = tmp_label_copy 109 | 110 | if self.is_mirror and random.random() < 0.5: 111 | image = np.flip(image, axis = 2) 112 | label_copy = np.flip(label_copy, axis = 1) 113 | 114 | return image.copy(), label_copy.copy(), np.array(size), name 115 | 116 | 117 | if __name__ == '__main__': 118 | dst = robot_pseudo_DataSet('./data/Oxford_Robot_ICCV19', './dataset/robot_list/train.txt', mean=(0,0,0), set = 'train', autoaug=True) 119 | trainloader = data.DataLoader(dst, batch_size=4) 120 | for i, data in enumerate(trainloader): 121 | imgs, _, _, _ = data 122 | if i == 0: 123 | img = torchvision.utils.make_grid(imgs).numpy() 124 | img = np.transpose(img, (1, 2, 0)) 125 | img = img[:, :, ::-1] 126 | img = Image.fromarray(np.uint8(img) ) 127 | img.save('Robot_pseudo_Demo.jpg') 128 | break 129 | -------------------------------------------------------------------------------- /dataset/synthia_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import random 5 | import matplotlib.pyplot as plt 6 | import collections 7 | import torch 8 | import torchvision 9 | from torch.utils import data 10 | from PIL import Image, ImageFile 11 | from dataset.autoaugment import ImageNetPolicy 12 | import imageio 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | 16 | class SynthiaDataSet(data.Dataset): 17 | def __init__(self, root, list_path, max_iters=None, resize_size=(1024, 512), crop_size=(512, 1024), mean=(128, 128, 128), scale=False, mirror=True, ignore_label=255, autoaug = False): 18 | self.root = root 19 | self.list_path = list_path 20 | self.crop_size = crop_size 21 | self.scale = scale 22 | self.ignore_label = ignore_label 23 | self.mean = mean 24 | self.is_mirror = mirror 25 | self.resize_size = resize_size 26 | self.autoaug = autoaug 27 | self.h = crop_size[0] 28 | self.w = crop_size[1] 29 | # self.mean_bgr = np.array([104.00698793, 116.66876762, 122.67891434]) 30 | self.img_ids = [i_id.strip() for i_id in open(list_path)] 31 | if not max_iters==None: 32 | self.img_ids = self.img_ids * int(np.ceil(float(max_iters) / len(self.img_ids))) 33 | self.files = [] 34 | 35 | self.id_to_trainid = {3: 0, 4: 1, 2: 2, 21: 3, 5: 4, 7: 5, 36 | 15: 6, 9: 7, 6: 8, 16: 9, 1: 10, 10: 11, 17: 12, 37 | 8: 13, 18: 14, 19: 15, 20: 16, 12: 17, 11: 18} 38 | # for split in ["train", "trainval", "val"]: 39 | for name in self.img_ids: 40 | img_file = osp.join(self.root, "RGB/%s" % name) 41 | label_file = osp.join(self.root, "GT/LABELS/%s" % name) 42 | self.files.append({ 43 | "img": img_file, 44 | "label": label_file, 45 | "name": name 46 | }) 47 | 48 | def __len__(self): 49 | return len(self.files) 50 | 51 | 52 | def __getitem__(self, index): 53 | datafiles = self.files[index] 54 | 55 | image = Image.open(datafiles["img"]).convert('RGB') 56 | #label = Image.open(datafiles["label"]) 57 | label = np.asarray(imageio.imread(datafiles["label"], format='PNG-FI'))[:,:,0] # uint16 58 | label = Image.fromarray(label) 59 | name = datafiles["name"] 60 | 61 | # resize 62 | if self.scale: 63 | random_scale = 0.8 + random.random()*0.4 # 0.8 - 1.2 64 | image = image.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , 
Image.BICUBIC) 65 | label = label.resize( ( round(self.resize_size[0] * random_scale), round(self.resize_size[1] * random_scale)) , Image.NEAREST) 66 | else: 67 | image = image.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.BICUBIC) 68 | label = label.resize( ( self.resize_size[0], self.resize_size[1] ) , Image.NEAREST) 69 | 70 | if self.autoaug: 71 | policy = ImageNetPolicy() 72 | image = policy(image) 73 | 74 | image = np.asarray(image, np.float32) 75 | label = np.asarray(label, np.uint8) 76 | 77 | # re-assign labels to match the format of Cityscapes 78 | label_copy = 255 * np.ones(label.shape, dtype=np.uint8) 79 | for k, v in list(self.id_to_trainid.items()): 80 | label_copy[label == k] = v 81 | 82 | size = image.shape 83 | image = image[:, :, ::-1] # change to BGR 84 | image -= self.mean 85 | image = image.transpose((2, 0, 1)) 86 | print(image.shape, label.shape) 87 | for i in range(10): #find hard samples 88 | x1 = random.randint(0, image.shape[1] - self.h) 89 | y1 = random.randint(0, image.shape[2] - self.w) 90 | tmp_label_copy = label_copy[x1:x1+self.h, y1:y1+self.w] 91 | tmp_image = image[:, x1:x1+self.h, y1:y1+self.w] 92 | u = np.unique(tmp_label_copy) 93 | if len(u) > 10: 94 | break 95 | else: 96 | continue 97 | #print('GTA5: Too young too naive for %d times!'%i) 98 | 99 | image = tmp_image 100 | label_copy = tmp_label_copy 101 | 102 | if self.is_mirror and random.random() < 0.5: 103 | image = np.flip(image, axis = 2) 104 | label_copy = np.flip(label_copy, axis = 1) 105 | 106 | return image.copy(), label_copy.copy(), np.array(size), name 107 | 108 | 109 | if __name__ == '__main__': 110 | dst = SynthiaDataSet('./data/synthia/', './dataset/synthia_list/train.txt', mean=(0,0,0), autoaug=True) 111 | trainloader = data.DataLoader(dst, batch_size=4) 112 | for i, data in enumerate(trainloader): 113 | imgs, _, _, _ = data 114 | if i == 0: 115 | img = torchvision.utils.make_grid(imgs).numpy() 116 | img = np.transpose(img, (1, 2, 0)) 117 | img = img[:, :, ::-1] 118 | img = Image.fromarray(np.uint8(img) ) 119 | img.save('Synthia_Demo.jpg') 120 | break 121 | -------------------------------------------------------------------------------- /evaluate_cityscapes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | from packaging import version 7 | from multiprocessing import Pool 8 | import torch 9 | from torch.autograd import Variable 10 | import torchvision.models as models 11 | import torch.nn.functional as F 12 | from torch.utils import data, model_zoo 13 | from model.deeplab import Res_Deeplab 14 | from model.deeplab_multi import DeeplabMulti 15 | from model.deeplab_vgg import DeeplabVGG 16 | from dataset.cityscapes_dataset import cityscapesDataSet 17 | from collections import OrderedDict 18 | import os 19 | from PIL import Image 20 | from utils.tool import fliplr 21 | import matplotlib.pyplot as plt 22 | import torch.nn as nn 23 | import yaml 24 | import time 25 | 26 | torch.backends.cudnn.benchmark=True 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATA_DIRECTORY = './data/Cityscapes/data' 31 | DATA_LIST_PATH = './dataset/cityscapes_list/val.txt' 32 | SAVE_PATH = './result/cityscapes' 33 | 34 | IGNORE_LABEL = 255 35 | NUM_CLASSES = 19 36 | NUM_STEPS = 500 # Number of images in the validation set. 
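# Note: main() below ensembles softmax predictions over two input scales and
# their horizontal flips, and visualises per-pixel uncertainty as the KL
# divergence between the model's two classifier outputs (output1, output2).
# A minimal, self-contained sketch of that uncertainty map, using random
# logits in place of real network outputs (`aux_logits` / `main_logits` are
# illustrative names only, not variables defined in this repo):
#
#     import torch
#     import torch.nn as nn
#
#     sm, log_sm = nn.Softmax(dim=1), nn.LogSoftmax(dim=1)
#     kl = nn.KLDivLoss(reduction='none')
#     aux_logits = torch.randn(2, 19, 64, 128)    # stands in for output1
#     main_logits = torch.randn(2, 19, 64, 128)   # stands in for output2
#     heatmap = kl(log_sm(aux_logits), sm(main_logits)).sum(dim=1)  # B x H x W
#     heatmap = torch.log(1 + 10 * heatmap)       # log-compress for display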
37 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 38 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 39 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 40 | SET = 'val' 41 | 42 | MODEL = 'DeeplabMulti' 43 | 44 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30, 45 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 46 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 47 | zero_pad = 256 * 3 - len(palette) 48 | for i in range(zero_pad): 49 | palette.append(0) 50 | 51 | 52 | def colorize_mask(mask): 53 | # mask: numpy array of the mask 54 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 55 | new_mask.putpalette(palette) 56 | 57 | return new_mask 58 | 59 | def get_arguments(): 60 | """Parse all the arguments provided from the CLI. 61 | 62 | Returns: 63 | A list of parsed arguments. 64 | """ 65 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 66 | parser.add_argument("--model", type=str, default=MODEL, 67 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 68 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 69 | help="Path to the directory containing the Cityscapes dataset.") 70 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 71 | help="Path to the file listing the images in the dataset.") 72 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 73 | help="The index of the label to ignore during the training.") 74 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 75 | help="Number of classes to predict (including background).") 76 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 77 | help="Where restore model parameters from.") 78 | parser.add_argument("--gpu", type=int, default=0, 79 | help="choose gpu device.") 80 | parser.add_argument("--batchsize", type=int, default=16, 81 | help="choose gpu device.") 82 | parser.add_argument("--set", type=str, default=SET, 83 | help="choose evaluation set.") 84 | parser.add_argument("--save", type=str, default=SAVE_PATH, 85 | help="Path to save result.") 86 | return parser.parse_args() 87 | 88 | def save(output_name): 89 | output, name = output_name 90 | output_col = colorize_mask(output) 91 | output = Image.fromarray(output) 92 | 93 | output.save('%s' % (name)) 94 | output_col.save('%s_color.png' % (name.split('.jpg')[0])) 95 | return 96 | 97 | def save_heatmap(output_name): 98 | output, name = output_name 99 | fig = plt.figure() 100 | plt.axis('off') 101 | heatmap = plt.imshow(output, cmap='viridis') 102 | #fig.colorbar(heatmap) 103 | fig.savefig('%s_heatmap.png' % (name.split('.jpg')[0])) 104 | return 105 | 106 | def save_scoremap(output_name): 107 | output, name = output_name 108 | fig = plt.figure() 109 | plt.axis('off') 110 | heatmap = plt.imshow(output, cmap='viridis') 111 | #fig.colorbar(heatmap) 112 | fig.savefig('%s_scoremap.png' % (name.split('.jpg')[0])) 113 | return 114 | 115 | def main(): 116 | """Create the model and start the evaluation process.""" 117 | args = get_arguments() 118 | 119 | config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml') 120 | with open(config_path, 'r') as stream: 121 | config = yaml.load(stream) 122 | 123 | args.model = config['model'] 124 | print('ModelType:%s'%args.model) 125 | 
print('NormType:%s'%config['norm_style']) 126 | gpu0 = args.gpu 127 | batchsize = args.batchsize 128 | 129 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 130 | args.save += model_name 131 | 132 | if not os.path.exists(args.save): 133 | os.makedirs(args.save) 134 | 135 | if args.model == 'DeepLab': 136 | model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style']) 137 | elif args.model == 'Oracle': 138 | model = Res_Deeplab(num_classes=args.num_classes) 139 | if args.restore_from == RESTORE_FROM: 140 | args.restore_from = RESTORE_FROM_ORC 141 | elif args.model == 'DeeplabVGG': 142 | model = DeeplabVGG(num_classes=args.num_classes) 143 | if args.restore_from == RESTORE_FROM: 144 | args.restore_from = RESTORE_FROM_VGG 145 | 146 | if args.restore_from[:4] == 'http' : 147 | saved_state_dict = model_zoo.load_url(args.restore_from) 148 | else: 149 | saved_state_dict = torch.load(args.restore_from) 150 | 151 | try: 152 | model.load_state_dict(saved_state_dict) 153 | except: 154 | model = torch.nn.DataParallel(model) 155 | model.load_state_dict(saved_state_dict) 156 | #model = torch.nn.DataParallel(model) 157 | model.eval() 158 | model.cuda(gpu0) 159 | 160 | testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(512, 1024), resize_size=(1024, 512), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 161 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 162 | 163 | scale = 1.25 164 | testloader2 = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(round(512*scale), round(1024*scale) ), resize_size=( round(1024*scale), round(512*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 165 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 166 | scale = 0.9 167 | testloader3 = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(round(512*scale), round(1024*scale) ), resize_size=( round(1024*scale), round(512*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 168 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 169 | 170 | 171 | if version.parse(torch.__version__) >= version.parse('0.4.0'): 172 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True) 173 | else: 174 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear') 175 | 176 | sm = torch.nn.Softmax(dim = 1) 177 | log_sm = torch.nn.LogSoftmax(dim = 1) 178 | kl_distance = nn.KLDivLoss( reduction = 'none') 179 | 180 | for index, img_data in enumerate(zip(testloader, testloader2, testloader3) ): 181 | batch, batch2, batch3 = img_data 182 | image, _, _, name = batch 183 | image2, _, _, name2 = batch2 184 | #image3, _, _, name3 = batch3 185 | 186 | inputs = image.cuda() 187 | inputs2 = image2.cuda() 188 | #inputs3 = Variable(image3).cuda() 189 | print('\r>>>>Extracting feature...%03d/%03d'%(index*batchsize, NUM_STEPS), end='') 190 | if args.model == 'DeepLab': 191 | with torch.no_grad(): 192 | output1, output2 = model(inputs) 193 | output_batch = interp(sm(0.5* output1 + output2)) 194 | heatmap_output1, heatmap_output2 = output1, output2 195 | #output_batch = interp(sm(output1)) 196 | #output_batch = interp(sm(output2)) 197 | output1, output2 = model(fliplr(inputs)) 198 | output1, output2 = fliplr(output1), fliplr(output2) 199 | output_batch += interp(sm(0.5 * output1 + output2)) 200 | heatmap_output1, heatmap_output2 = heatmap_output1+output1, 
heatmap_output2+output2 201 | #output_batch += interp(sm(output1)) 202 | #output_batch += interp(sm(output2)) 203 | del output1, output2, inputs 204 | 205 | output1, output2 = model(inputs2) 206 | output_batch += interp(sm(0.5* output1 + output2)) 207 | #output_batch += interp(sm(output1)) 208 | #output_batch += interp(sm(output2)) 209 | output1, output2 = model(fliplr(inputs2)) 210 | output1, output2 = fliplr(output1), fliplr(output2) 211 | output_batch += interp(sm(0.5 * output1 + output2)) 212 | #output_batch += interp(sm(output1)) 213 | #output_batch += interp(sm(output2)) 214 | del output1, output2, inputs2 215 | output_batch = output_batch.cpu().data.numpy() 216 | heatmap_batch = torch.sum(kl_distance(log_sm(heatmap_output1), sm(heatmap_output2)), dim=1) 217 | heatmap_batch = torch.log(1 + 10*heatmap_batch) # for visualization 218 | heatmap_batch = heatmap_batch.cpu().data.numpy() 219 | 220 | #output1, output2 = model(inputs3) 221 | #output_batch += interp(sm(0.5* output1 + output2)).cpu().data.numpy() 222 | #output1, output2 = model(fliplr(inputs3)) 223 | #output1, output2 = fliplr(output1), fliplr(output2) 224 | #output_batch += interp(sm(0.5 * output1 + output2)).cpu().data.numpy() 225 | #del output1, output2, inputs3 226 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 227 | output_batch = model(Variable(image).cuda()) 228 | output_batch = interp(output_batch).cpu().data.numpy() 229 | 230 | output_batch = output_batch.transpose(0,2,3,1) 231 | scoremap_batch = np.asarray(np.max(output_batch, axis=3)) 232 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 233 | output_iterator = [] 234 | heatmap_iterator = [] 235 | scoremap_iterator = [] 236 | 237 | for i in range(output_batch.shape[0]): 238 | output_iterator.append(output_batch[i,:,:]) 239 | heatmap_iterator.append(heatmap_batch[i,:,:]/np.max(heatmap_batch[i,:,:])) 240 | scoremap_iterator.append(1-scoremap_batch[i,:,:]/np.max(scoremap_batch[i,:,:])) 241 | name_tmp = name[i].split('/')[-1] 242 | name[i] = '%s/%s' % (args.save, name_tmp) 243 | with Pool(4) as p: 244 | p.map(save, zip(output_iterator, name) ) 245 | p.map(save_heatmap, zip(heatmap_iterator, name) ) 246 | p.map(save_scoremap, zip(scoremap_iterator, name) ) 247 | 248 | del output_batch 249 | 250 | 251 | return args.save 252 | 253 | if __name__ == '__main__': 254 | tt = time.time() 255 | with torch.no_grad(): 256 | save_path = main() 257 | print('Time used: {} sec'.format(time.time()-tt)) 258 | os.system('python compute_iou.py ./data/Cityscapes/data/gtFine/val %s'%save_path) 259 | -------------------------------------------------------------------------------- /evaluate_gta5.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | from packaging import version 7 | 8 | import torch 9 | from torch.autograd import Variable 10 | import torchvision.models as models 11 | import torch.nn.functional as F 12 | from torch.utils import data, model_zoo 13 | from model.deeplab import Res_Deeplab 14 | from model.deeplab_multi import DeeplabMulti 15 | from model.deeplab_vgg import DeeplabVGG 16 | from dataset.gta5_dataset import GTA5DataSet 17 | from collections import OrderedDict 18 | import os 19 | from PIL import Image 20 | from utils.tool import fliplr 21 | import matplotlib.pyplot as plt 22 | import torch.nn as nn 23 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 24 | 25 | # 
We just use this file to evaluate the perfromance on the training set 26 | DATA_DIRECTORY = './data/GTA5' 27 | DATA_LIST_PATH = './dataset/gta5_list/train.txt' 28 | SAVE_PATH = './result/GTA5' 29 | 30 | IGNORE_LABEL = 255 31 | NUM_CLASSES = 19 32 | NUM_STEPS = 500 # Number of images in the validation set. 33 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 34 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 35 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 36 | SET = 'val' 37 | 38 | MODEL = 'DeeplabMulti' 39 | 40 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30, 41 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 42 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 43 | zero_pad = 256 * 3 - len(palette) 44 | for i in range(zero_pad): 45 | palette.append(0) 46 | 47 | 48 | def colorize_mask(mask): 49 | # mask: numpy array of the mask 50 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 51 | new_mask.putpalette(palette) 52 | 53 | return new_mask 54 | 55 | def get_arguments(): 56 | """Parse all the arguments provided from the CLI. 57 | 58 | Returns: 59 | A list of parsed arguments. 60 | """ 61 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 62 | parser.add_argument("--model", type=str, default=MODEL, 63 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 64 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 65 | help="Path to the directory containing the Cityscapes dataset.") 66 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 67 | help="Path to the file listing the images in the dataset.") 68 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 69 | help="The index of the label to ignore during the training.") 70 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 71 | help="Number of classes to predict (including background).") 72 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 73 | help="Where restore model parameters from.") 74 | parser.add_argument("--gpu", type=int, default=0, 75 | help="choose gpu device.") 76 | parser.add_argument("--batchsize", type=int, default=10, 77 | help="choose gpu device.") 78 | parser.add_argument("--set", type=str, default=SET, 79 | help="choose evaluation set.") 80 | parser.add_argument("--save", type=str, default=SAVE_PATH, 81 | help="Path to save result.") 82 | return parser.parse_args() 83 | 84 | 85 | def main(): 86 | """Create the model and start the evaluation process.""" 87 | 88 | args = get_arguments() 89 | 90 | gpu0 = args.gpu 91 | batchsize = args.batchsize 92 | 93 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 94 | args.save += model_name 95 | 96 | if not os.path.exists(args.save): 97 | os.makedirs(args.save) 98 | 99 | if args.model == 'DeeplabMulti': 100 | model = DeeplabMulti(num_classes=args.num_classes, train_bn = False, norm_style = 'in') 101 | elif args.model == 'Oracle': 102 | model = Res_Deeplab(num_classes=args.num_classes) 103 | if args.restore_from == RESTORE_FROM: 104 | args.restore_from = RESTORE_FROM_ORC 105 | elif args.model == 'DeeplabVGG': 106 | model = DeeplabVGG(num_classes=args.num_classes) 107 | if args.restore_from == RESTORE_FROM: 108 | args.restore_from = RESTORE_FROM_VGG 109 | 110 | if args.restore_from[:4] 
== 'http' : 111 | saved_state_dict = model_zoo.load_url(args.restore_from) 112 | else: 113 | saved_state_dict = torch.load(args.restore_from) 114 | 115 | try: 116 | model.load_state_dict(saved_state_dict) 117 | except: 118 | model = torch.nn.DataParallel(model) 119 | model.load_state_dict(saved_state_dict) 120 | model.eval() 121 | model.cuda() 122 | 123 | testloader = data.DataLoader(GTA5DataSet(args.data_dir, args.data_list, crop_size=(640, 1280), resize_size=(1280, 640), mean=IMG_MEAN, scale=False, mirror=False), 124 | batch_size=batchsize, shuffle=False, pin_memory=True) 125 | 126 | 127 | if version.parse(torch.__version__) >= version.parse('0.4.0'): 128 | interp = nn.Upsample(size=(640, 1280 ), mode='bilinear', align_corners=True) 129 | else: 130 | interp = nn.Upsample(size=(640, 1280 ), mode='bilinear') 131 | 132 | sm = torch.nn.Softmax(dim = 1) 133 | for index, batch in enumerate(testloader): 134 | if (index*batchsize) % 100 == 0: 135 | print('%d processd' % (index*batchsize)) 136 | image, _, _, name = batch 137 | print(image.shape) 138 | 139 | inputs = Variable(image).cuda() 140 | if args.model == 'DeeplabMulti': 141 | output1, output2 = model(inputs) 142 | output_batch = interp(sm(0.5* output1 + output2)).cpu().data.numpy() 143 | #output1, output2 = model(fliplr(inputs)) 144 | #output2 = fliplr(output2) 145 | #output_batch += interp(output2).cpu().data.numpy() 146 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 147 | output_batch = model(Variable(image).cuda()) 148 | output_batch = interp(output_batch).cpu().data.numpy() 149 | 150 | output_batch = output_batch.transpose(0,2,3,1) 151 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 152 | 153 | for i in range(output_batch.shape[0]): 154 | output = output_batch[i,:,:] 155 | output_col = colorize_mask(output) 156 | output = Image.fromarray(output) 157 | 158 | name_tmp = name[i].split('/')[-1] 159 | output.save('%s/%s' % (args.save, name_tmp)) 160 | output_col.save('%s/%s_color.png' % (args.save, name_tmp.split('.')[0])) 161 | 162 | return args.save 163 | 164 | if __name__ == '__main__': 165 | with torch.no_grad(): 166 | save_path = main() 167 | os.system('python compute_iou.py ./data/GTA5/data/gtFine/val %s'%save_path) 168 | -------------------------------------------------------------------------------- /evaluate_robot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | from packaging import version 7 | from multiprocessing import Pool 8 | import torch 9 | from torch.autograd import Variable 10 | import torchvision.models as models 11 | import torch.nn.functional as F 12 | from torch.utils import data, model_zoo 13 | from model.deeplab import Res_Deeplab 14 | from model.deeplab_multi import DeeplabMulti 15 | from model.deeplab_vgg import DeeplabVGG 16 | from dataset.robot_dataset import robotDataSet 17 | from collections import OrderedDict 18 | import os 19 | from PIL import Image 20 | from utils.tool import fliplr 21 | import matplotlib.pyplot as plt 22 | import torch.nn as nn 23 | import yaml 24 | import time 25 | 26 | torch.backends.cudnn.benchmark=True 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATA_DIRECTORY = './data/Oxford_Robot_ICCV19' 31 | DATA_LIST_PATH = './dataset/robot_list/val.txt' 32 | SAVE_PATH = './result/robot' 33 | 34 | IGNORE_LABEL = 255 35 | NUM_CLASSES = 9 36 | NUM_STEPS = 271 
# Number of images in the validation set. 37 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 38 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 39 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 40 | SET = 'val' 41 | 42 | MODEL = 'DeeplabMulti' 43 | 44 | palette = [ 45 | [70,130,180], 46 | [220,20,60], 47 | [119,11,32], 48 | [0,0,142], 49 | [220,220,0], 50 | [250,170,30], 51 | [70,70,70], 52 | [244,35,232], 53 | [128,64,128], 54 | ] 55 | palette = [item for sublist in palette for item in sublist] 56 | zero_pad = 256 * 3 - len(palette) 57 | for i in range(zero_pad): 58 | palette.append(0) 59 | 60 | def colorize_mask(mask): 61 | # mask: numpy array of the mask 62 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 63 | new_mask.putpalette(palette) 64 | 65 | return new_mask 66 | 67 | def get_arguments(): 68 | """Parse all the arguments provided from the CLI. 69 | 70 | Returns: 71 | A list of parsed arguments. 72 | """ 73 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 74 | parser.add_argument("--model", type=str, default=MODEL, 75 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 76 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 77 | help="Path to the directory containing the Cityscapes dataset.") 78 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 79 | help="Path to the file listing the images in the dataset.") 80 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 81 | help="The index of the label to ignore during the training.") 82 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 83 | help="Number of classes to predict (including background).") 84 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 85 | help="Where restore model parameters from.") 86 | parser.add_argument("--gpu", type=int, default=0, 87 | help="choose gpu device.") 88 | parser.add_argument("--batchsize", type=int, default=12, 89 | help="choose gpu device.") 90 | parser.add_argument("--set", type=str, default=SET, 91 | help="choose evaluation set.") 92 | parser.add_argument("--save", type=str, default=SAVE_PATH, 93 | help="Path to save result.") 94 | return parser.parse_args() 95 | 96 | def save(output_name): 97 | output, name = output_name 98 | output_col = colorize_mask(output) 99 | output = Image.fromarray(output) 100 | 101 | output.save('%s' % (name)) 102 | output_col.save('%s_color.png' % (name.split('.jpg')[0])) 103 | return 104 | 105 | 106 | def main(): 107 | """Create the model and start the evaluation process.""" 108 | args = get_arguments() 109 | 110 | config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml') 111 | with open(config_path, 'r') as stream: 112 | config = yaml.load(stream) 113 | 114 | args.model = config['model'] 115 | print('ModelType:%s'%args.model) 116 | print('NormType:%s'%config['norm_style']) 117 | gpu0 = args.gpu 118 | batchsize = args.batchsize 119 | 120 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 121 | args.save += model_name 122 | 123 | if not os.path.exists(args.save): 124 | os.makedirs(args.save) 125 | 126 | if args.model == 'DeepLab': 127 | model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style']) 128 | elif args.model == 'Oracle': 129 | model = Res_Deeplab(num_classes=args.num_classes) 130 
| if args.restore_from == RESTORE_FROM: 131 | args.restore_from = RESTORE_FROM_ORC 132 | elif args.model == 'DeeplabVGG': 133 | model = DeeplabVGG(num_classes=args.num_classes) 134 | if args.restore_from == RESTORE_FROM: 135 | args.restore_from = RESTORE_FROM_VGG 136 | 137 | if args.restore_from[:4] == 'http' : 138 | saved_state_dict = model_zoo.load_url(args.restore_from) 139 | else: 140 | saved_state_dict = torch.load(args.restore_from) 141 | 142 | try: 143 | model.load_state_dict(saved_state_dict) 144 | except: 145 | model = torch.nn.DataParallel(model) 146 | model.load_state_dict(saved_state_dict) 147 | #model = torch.nn.DataParallel(model) 148 | model.eval() 149 | model.cuda(gpu0) 150 | 151 | th = 960 152 | tw = 1280 153 | 154 | testloader = data.DataLoader(robotDataSet(args.data_dir, args.data_list, crop_size=(th, tw), resize_size=(tw, th), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 155 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 156 | 157 | scale = 0.8 158 | testloader2 = data.DataLoader(robotDataSet(args.data_dir, args.data_list, crop_size=(round(th*scale), round(tw*scale) ), resize_size=( round(tw*scale), round(th*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 159 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 160 | scale = 0.9 161 | testloader3 = data.DataLoader(robotDataSet(args.data_dir, args.data_list, crop_size=(round(th*scale), round(tw*scale) ), resize_size=( round(tw*scale), round(th*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 162 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 163 | 164 | 165 | if version.parse(torch.__version__) >= version.parse('0.4.0'): 166 | interp = nn.Upsample(size=(960, 1280), mode='bilinear', align_corners=True) 167 | else: 168 | interp = nn.Upsample(size=(960, 1280), mode='bilinear') 169 | 170 | sm = torch.nn.Softmax(dim = 1) 171 | for index, img_data in enumerate(zip(testloader, testloader2, testloader3) ): 172 | batch, batch2, batch3 = img_data 173 | image, _, _, name = batch 174 | image2, _, _, name2 = batch2 175 | image3, _, _, name3 = batch3 176 | 177 | inputs = image.cuda() 178 | inputs2 = image2.cuda() 179 | inputs3 = Variable(image3).cuda() 180 | print('\r>>>>Extracting feature...%03d/%03d'%(index*batchsize, NUM_STEPS), end='') 181 | if args.model == 'DeepLab': 182 | with torch.no_grad(): 183 | output1, output2 = model(inputs) 184 | output_batch = interp(sm(0.5* output1 + output2)) 185 | #output_batch = interp(sm(output1)) 186 | #output_batch = interp(sm(output2)) 187 | output1, output2 = model(fliplr(inputs)) 188 | output1, output2 = fliplr(output1), fliplr(output2) 189 | output_batch += interp(sm(0.5 * output1 + output2)) 190 | #output_batch += interp(sm(output1)) 191 | #output_batch += interp(sm(output2)) 192 | del output1, output2, inputs 193 | 194 | output1, output2 = model(inputs2) 195 | output_batch += interp(sm(0.5* output1 + output2)) 196 | #output_batch += interp(sm(output1)) 197 | #output_batch += interp(sm(output2)) 198 | output1, output2 = model(fliplr(inputs2)) 199 | output1, output2 = fliplr(output1), fliplr(output2) 200 | output_batch += interp(sm(0.5 * output1 + output2)) 201 | #output_batch += interp(sm(output1)) 202 | #output_batch += interp(sm(output2)) 203 | del output1, output2, inputs2 204 | 205 | #output1, output2 = model(inputs3) 206 | #output_batch += interp(sm(0.5* output1 + output2)) 207 | #output1, output2 = model(fliplr(inputs3)) 208 | #output1, output2 = fliplr(output1), 
fliplr(output2) 209 | #output_batch += interp(sm(0.5 * output1 + output2)) 210 | #del output1, output2, inputs3 211 | 212 | output_batch = output_batch.cpu().data.numpy() 213 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 214 | output_batch = model(Variable(image).cuda()) 215 | output_batch = interp(output_batch).cpu().data.numpy() 216 | 217 | output_batch = output_batch.transpose(0,2,3,1) 218 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 219 | output_iterator = [] 220 | for i in range(output_batch.shape[0]): 221 | output_iterator.append(output_batch[i,:,:]) 222 | name_tmp = name[i].split('/')[-1] 223 | name[i] = '%s/%s' % (args.save, name_tmp) 224 | with Pool(4) as p: 225 | p.map(save, zip(output_iterator, name) ) 226 | del output_batch 227 | 228 | 229 | return args.save 230 | 231 | if __name__ == '__main__': 232 | tt = time.time() 233 | with torch.no_grad(): 234 | save_path = main() 235 | print('Time used: {} sec'.format(time.time()-tt)) 236 | devkit_path='dataset/robot_list' 237 | os.system('python compute_iou.py ./data/Oxford_Robot_ICCV19/anno %s --devkit_dir %s'%(save_path, devkit_path)) 238 | -------------------------------------------------------------------------------- /generate_plabel_cityscapes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | import re 7 | from packaging import version 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torchvision.models as models 12 | import torch.nn.functional as F 13 | from torch.utils import data, model_zoo 14 | from model.deeplab import Res_Deeplab 15 | from model.deeplab_multi import DeeplabMulti 16 | from model.deeplab_vgg import DeeplabVGG 17 | from dataset.cityscapes_dataset import cityscapesDataSet 18 | from collections import OrderedDict 19 | import os 20 | from PIL import Image 21 | from utils.tool import fliplr 22 | import matplotlib.pyplot as plt 23 | import torch.nn as nn 24 | import yaml 25 | 26 | torch.backends.cudnn.benchmark=True 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATA_DIRECTORY = './data/Cityscapes/data' 31 | DATA_LIST_PATH = './dataset/cityscapes_list/train.txt' 32 | SAVE_PATH = './data/Cityscapes/data/pseudo/train' 33 | 34 | if not os.path.isdir('./data/Cityscapes/data/pseudo/'): 35 | os.mkdir('./data/Cityscapes/data/pseudo/') 36 | os.mkdir(SAVE_PATH) 37 | 38 | IGNORE_LABEL = 255 39 | NUM_CLASSES = 19 40 | NUM_STEPS = 2975 # Number of images in the validation set. 
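# Note (descriptive comment, not part of the original script): 2975 is the number of
# images in the Cityscapes *training* split, not the validation set; this script
# generates pseudo labels for SET = 'train', together with the per-pixel uncertainty
# heatmaps computed below as the KL divergence between the two classifier heads.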
41 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 42 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 43 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 44 | SET = 'train' # We generate pseudo label for training set 45 | 46 | MODEL = 'DeeplabMulti' 47 | 48 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30, 49 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 50 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 51 | zero_pad = 256 * 3 - len(palette) 52 | for i in range(zero_pad): 53 | palette.append(0) 54 | 55 | 56 | def colorize_mask(mask): 57 | # mask: numpy array of the mask 58 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 59 | new_mask.putpalette(palette) 60 | 61 | return new_mask 62 | 63 | def get_arguments(): 64 | """Parse all the arguments provided from the CLI. 65 | 66 | Returns: 67 | A list of parsed arguments. 68 | """ 69 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 70 | parser.add_argument("--model", type=str, default=MODEL, 71 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 72 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 73 | help="Path to the directory containing the Cityscapes dataset.") 74 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 75 | help="Path to the file listing the images in the dataset.") 76 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 77 | help="The index of the label to ignore during the training.") 78 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 79 | help="Number of classes to predict (including background).") 80 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 81 | help="Where restore model parameters from.") 82 | parser.add_argument("--gpu", type=int, default=0, 83 | help="choose gpu device.") 84 | parser.add_argument("--batchsize", type=int, default=12, 85 | help="choose gpu device.") 86 | parser.add_argument("--set", type=str, default=SET, 87 | help="choose evaluation set.") 88 | parser.add_argument("--save", type=str, default=SAVE_PATH, 89 | help="Path to save result.") 90 | return parser.parse_args() 91 | 92 | def save_heatmap(output_name): 93 | output, name = output_name 94 | fig = plt.figure() 95 | plt.axis('off') 96 | heatmap = plt.imshow(output, cmap='viridis') 97 | fig.colorbar(heatmap) 98 | fig.savefig('%s_heatmap.png' % (name.split('.jpg')[0])) 99 | return 100 | 101 | def main(): 102 | """Create the model and start the evaluation process.""" 103 | 104 | args = get_arguments() 105 | 106 | config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml') 107 | with open(config_path, 'r') as stream: 108 | config = yaml.load(stream) 109 | 110 | args.model = config['model'] 111 | print('ModelType:%s'%args.model) 112 | print('NormType:%s'%config['norm_style']) 113 | gpu0 = args.gpu 114 | batchsize = args.batchsize 115 | 116 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 117 | #args.save += model_name 118 | 119 | if not os.path.exists(args.save): 120 | os.makedirs(args.save) 121 | 122 | if args.model == 'DeepLab': 123 | model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style']) 124 | elif args.model == 'Oracle': 125 | model = 
Res_Deeplab(num_classes=args.num_classes) 126 | if args.restore_from == RESTORE_FROM: 127 | args.restore_from = RESTORE_FROM_ORC 128 | elif args.model == 'DeeplabVGG': 129 | model = DeeplabVGG(num_classes=args.num_classes) 130 | if args.restore_from == RESTORE_FROM: 131 | args.restore_from = RESTORE_FROM_VGG 132 | 133 | if args.restore_from[:4] == 'http' : 134 | saved_state_dict = model_zoo.load_url(args.restore_from) 135 | else: 136 | saved_state_dict = torch.load(args.restore_from) 137 | 138 | try: 139 | model.load_state_dict(saved_state_dict) 140 | except: 141 | model = torch.nn.DataParallel(model) 142 | model.load_state_dict(saved_state_dict) 143 | model.eval() 144 | model.cuda(gpu0) 145 | 146 | testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(512, 1024), resize_size=(1024, 512), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 147 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 148 | 149 | scale = 1.25 150 | testloader2 = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(round(512*scale), round(1024*scale) ), resize_size=( round(1024*scale), round(512*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 151 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 152 | 153 | 154 | if version.parse(torch.__version__) >= version.parse('0.4.0'): 155 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True) 156 | else: 157 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear') 158 | 159 | sm = torch.nn.Softmax(dim = 1) 160 | log_sm = torch.nn.LogSoftmax(dim = 1) 161 | kl_distance = nn.KLDivLoss( reduction = 'none') 162 | 163 | for index, img_data in enumerate(zip(testloader, testloader2) ): 164 | batch, batch2 = img_data 165 | image, _, _, name = batch 166 | image2, _, _, name2 = batch2 167 | print(image.shape) 168 | 169 | inputs = image.cuda() 170 | inputs2 = image2.cuda() 171 | print('\r>>>>Extracting feature...%04d/%04d'%(index*batchsize, NUM_STEPS), end='') 172 | if args.model == 'DeepLab': 173 | with torch.no_grad(): 174 | output1, output2 = model(inputs) 175 | output_batch = interp(sm(0.5* output1 + output2)) 176 | 177 | heatmap_batch = torch.sum(kl_distance(log_sm(output1), sm(output2)), dim=1) 178 | 179 | output1, output2 = model(fliplr(inputs)) 180 | output1, output2 = fliplr(output1), fliplr(output2) 181 | output_batch += interp(sm(0.5 * output1 + output2)) 182 | del output1, output2, inputs 183 | 184 | output1, output2 = model(inputs2) 185 | output_batch += interp(sm(0.5* output1 + output2)) 186 | output1, output2 = model(fliplr(inputs2)) 187 | output1, output2 = fliplr(output1), fliplr(output2) 188 | output_batch += interp(sm(0.5 * output1 + output2)) 189 | del output1, output2, inputs2 190 | output_batch = output_batch.cpu().data.numpy() 191 | heatmap_batch = heatmap_batch.cpu().data.numpy() 192 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 193 | output_batch = model(Variable(image).cuda()) 194 | output_batch = interp(output_batch).cpu().data.numpy() 195 | 196 | #output_batch = output_batch.transpose(0,2,3,1) 197 | #output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 198 | output_batch = output_batch.transpose(0,2,3,1) 199 | score_batch = np.max(output_batch, axis=3) 200 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 201 | #output_batch[score_batch<3.2] = 255 #3.2 = 4*0.8 202 | for i in range(output_batch.shape[0]): 203 | output = output_batch[i,:,:] 204 | 
output_col = colorize_mask(output) 205 | output = Image.fromarray(output) 206 | 207 | name_tmp = name[i].split('/')[-1] 208 | dir_name = name[i].split('/')[-2] 209 | save_path = args.save + '/' + dir_name 210 | #save_path = re.replace(save_path, 'leftImg8bit', 'pseudo') 211 | #print(save_path) 212 | if not os.path.isdir(save_path): 213 | os.mkdir(save_path) 214 | output.save('%s/%s' % (save_path, name_tmp)) 215 | print('%s/%s' % (save_path, name_tmp)) 216 | output_col.save('%s/%s_color.png' % (save_path, name_tmp.split('.')[0])) 217 | 218 | heatmap_tmp = heatmap_batch[i,:,:]/np.max(heatmap_batch[i,:,:]) 219 | fig = plt.figure() 220 | plt.axis('off') 221 | heatmap = plt.imshow(heatmap_tmp, cmap='viridis') 222 | fig.colorbar(heatmap) 223 | fig.savefig('%s/%s_heatmap.png' % (save_path, name_tmp.split('.')[0])) 224 | 225 | return args.save 226 | 227 | if __name__ == '__main__': 228 | with torch.no_grad(): 229 | save_path = main() 230 | #os.system('python compute_iou.py ./data/Cityscapes/data/gtFine/train %s'%save_path) 231 | -------------------------------------------------------------------------------- /generate_plabel_cityscapes_SYNTHIA.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | import re 7 | from packaging import version 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torchvision.models as models 12 | import torch.nn.functional as F 13 | from torch.utils import data, model_zoo 14 | from model.deeplab import Res_Deeplab 15 | from model.deeplab_multi import DeeplabMulti 16 | from model.deeplab_vgg import DeeplabVGG 17 | from dataset.cityscapes_dataset import cityscapesDataSet 18 | from collections import OrderedDict 19 | import os 20 | from PIL import Image 21 | from utils.tool import fliplr 22 | import matplotlib.pyplot as plt 23 | import torch.nn as nn 24 | import yaml 25 | 26 | torch.backends.cudnn.benchmark=True 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATA_DIRECTORY = './data/Cityscapes/data' 31 | DATA_LIST_PATH = './dataset/cityscapes_list/train.txt' 32 | SAVE_PATH = './data/Cityscapes/data/pseudo_SYNTHIA/train' 33 | 34 | if not os.path.isdir('./data/Cityscapes/data/pseudo_SYNTHIA/'): 35 | os.mkdir('./data/Cityscapes/data/pseudo_SYNTHIA/') 36 | os.mkdir(SAVE_PATH) 37 | 38 | IGNORE_LABEL = 255 39 | NUM_CLASSES = 19 40 | NUM_STEPS = 2975 # Number of images in the validation set. 
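# Note (descriptive comment, not part of the original script): as in the GTA5 variant,
# 2975 is the size of the Cityscapes *training* split and pseudo labels are generated
# for SET = 'train'. The four augmented predictions below (original + flipped input, at
# two scales) are summed, so the per-pixel score ranges up to 4; pixels scoring below
# 3.6 (an average softmax confidence under 0.9) are later set to the ignore label 255.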
41 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 42 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 43 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 44 | SET = 'train' # We generate pseudo label for training set 45 | 46 | MODEL = 'DeeplabMulti' 47 | 48 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 153, 153, 153, 250, 170, 30, 49 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 255, 0, 0, 0, 0, 142, 0, 0, 70, 50 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 51 | zero_pad = 256 * 3 - len(palette) 52 | for i in range(zero_pad): 53 | palette.append(0) 54 | 55 | 56 | def colorize_mask(mask): 57 | # mask: numpy array of the mask 58 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 59 | new_mask.putpalette(palette) 60 | 61 | return new_mask 62 | 63 | def get_arguments(): 64 | """Parse all the arguments provided from the CLI. 65 | 66 | Returns: 67 | A list of parsed arguments. 68 | """ 69 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 70 | parser.add_argument("--model", type=str, default=MODEL, 71 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 72 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 73 | help="Path to the directory containing the Cityscapes dataset.") 74 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 75 | help="Path to the file listing the images in the dataset.") 76 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 77 | help="The index of the label to ignore during the training.") 78 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 79 | help="Number of classes to predict (including background).") 80 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 81 | help="Where restore model parameters from.") 82 | parser.add_argument("--gpu", type=int, default=0, 83 | help="choose gpu device.") 84 | parser.add_argument("--batchsize", type=int, default=12, 85 | help="choose gpu device.") 86 | parser.add_argument("--set", type=str, default=SET, 87 | help="choose evaluation set.") 88 | parser.add_argument("--save", type=str, default=SAVE_PATH, 89 | help="Path to save result.") 90 | return parser.parse_args() 91 | 92 | 93 | def main(): 94 | """Create the model and start the evaluation process.""" 95 | 96 | args = get_arguments() 97 | 98 | config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml') 99 | with open(config_path, 'r') as stream: 100 | config = yaml.load(stream) 101 | 102 | args.model = config['model'] 103 | print('ModelType:%s'%args.model) 104 | print('NormType:%s'%config['norm_style']) 105 | gpu0 = args.gpu 106 | batchsize = args.batchsize 107 | 108 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 109 | #args.save += model_name 110 | 111 | if not os.path.exists(args.save): 112 | os.makedirs(args.save) 113 | 114 | if args.model == 'DeepLab': 115 | model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style']) 116 | elif args.model == 'Oracle': 117 | model = Res_Deeplab(num_classes=args.num_classes) 118 | if args.restore_from == RESTORE_FROM: 119 | args.restore_from = RESTORE_FROM_ORC 120 | elif args.model == 'DeeplabVGG': 121 | model = DeeplabVGG(num_classes=args.num_classes) 122 | if args.restore_from == RESTORE_FROM: 123 | 
args.restore_from = RESTORE_FROM_VGG 124 | 125 | if args.restore_from[:4] == 'http' : 126 | saved_state_dict = model_zoo.load_url(args.restore_from) 127 | else: 128 | saved_state_dict = torch.load(args.restore_from) 129 | 130 | try: 131 | model.load_state_dict(saved_state_dict) 132 | except: 133 | model = torch.nn.DataParallel(model) 134 | model.load_state_dict(saved_state_dict) 135 | model = torch.nn.DataParallel(model) 136 | model.eval() 137 | model.cuda(gpu0) 138 | 139 | testloader = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(512, 1024), resize_size=(1024, 512), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 140 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 141 | 142 | scale = 1.25 143 | testloader2 = data.DataLoader(cityscapesDataSet(args.data_dir, args.data_list, crop_size=(round(512*scale), round(1024*scale) ), resize_size=( round(1024*scale), round(512*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 144 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 145 | 146 | 147 | if version.parse(torch.__version__) >= version.parse('0.4.0'): 148 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear', align_corners=True) 149 | else: 150 | interp = nn.Upsample(size=(1024, 2048), mode='bilinear') 151 | 152 | sm = torch.nn.Softmax(dim = 1) 153 | for index, img_data in enumerate(zip(testloader, testloader2) ): 154 | batch, batch2 = img_data 155 | image, _, _, name = batch 156 | image2, _, _, name2 = batch2 157 | print(image.shape) 158 | 159 | inputs = image.cuda() 160 | inputs2 = image2.cuda() 161 | print('\r>>>>Extracting feature...%04d/%04d'%(index*batchsize, NUM_STEPS), end='') 162 | if args.model == 'DeepLab': 163 | with torch.no_grad(): 164 | output1, output2 = model(inputs) 165 | output_batch = interp(sm(0.5* output1 + output2)) 166 | output1, output2 = model(fliplr(inputs)) 167 | output1, output2 = fliplr(output1), fliplr(output2) 168 | output_batch += interp(sm(0.5 * output1 + output2)) 169 | del output1, output2, inputs 170 | 171 | output1, output2 = model(inputs2) 172 | output_batch += interp(sm(0.5* output1 + output2)) 173 | output1, output2 = model(fliplr(inputs2)) 174 | output1, output2 = fliplr(output1), fliplr(output2) 175 | output_batch += interp(sm(0.5 * output1 + output2)) 176 | del output1, output2, inputs2 177 | output_batch = output_batch.cpu().data.numpy() 178 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 179 | output_batch = model(Variable(image).cuda()) 180 | output_batch = interp(output_batch).cpu().data.numpy() 181 | 182 | #output_batch = output_batch.transpose(0,2,3,1) 183 | #output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 184 | output_batch = output_batch.transpose(0,2,3,1) 185 | score_batch = np.max(output_batch, axis=3) 186 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 187 | output_batch[score_batch<3.6] = 255 #3.6 = 4*0.9 188 | 189 | 190 | for i in range(output_batch.shape[0]): 191 | output = output_batch[i,:,:] 192 | output_col = colorize_mask(output) 193 | output = Image.fromarray(output) 194 | 195 | name_tmp = name[i].split('/')[-1] 196 | dir_name = name[i].split('/')[-2] 197 | save_path = args.save + '/' + dir_name 198 | #save_path = re.replace(save_path, 'leftImg8bit', 'pseudo') 199 | #print(save_path) 200 | if not os.path.isdir(save_path): 201 | os.mkdir(save_path) 202 | output.save('%s/%s' % (save_path, name_tmp)) 203 | print('%s/%s' % (save_path, name_tmp)) 204 | 
output_col.save('%s/%s_color.png' % (save_path, name_tmp.split('.')[0])) 205 | 206 | return args.save 207 | 208 | if __name__ == '__main__': 209 | with torch.no_grad(): 210 | save_path = main() 211 | #os.system('python compute_iou.py ./data/Cityscapes/data/gtFine/train %s'%save_path) 212 | -------------------------------------------------------------------------------- /generate_plabel_robot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import scipy 3 | from scipy import ndimage 4 | import numpy as np 5 | import sys 6 | import re 7 | from packaging import version 8 | 9 | import torch 10 | from torch.autograd import Variable 11 | import torchvision.models as models 12 | import torch.nn.functional as F 13 | from torch.utils import data, model_zoo 14 | from model.deeplab import Res_Deeplab 15 | from model.deeplab_multi import DeeplabMulti 16 | from model.deeplab_vgg import DeeplabVGG 17 | from dataset.robot_dataset import robotDataSet 18 | from collections import OrderedDict 19 | import os 20 | from PIL import Image 21 | from utils.tool import fliplr 22 | import matplotlib.pyplot as plt 23 | import torch.nn as nn 24 | import yaml 25 | 26 | torch.backends.cudnn.benchmark=True 27 | 28 | IMG_MEAN = np.array((104.00698793,116.66876762,122.67891434), dtype=np.float32) 29 | 30 | DATA_DIRECTORY = './data/Oxford_Robot_ICCV19/' 31 | DATA_LIST_PATH = './dataset/robot_list/train.txt' 32 | SAVE_PATH = './data/Oxford_Robot_ICCV19/pseudo_train' 33 | 34 | if not os.path.isdir(SAVE_PATH): 35 | os.mkdir(SAVE_PATH) 36 | 37 | IGNORE_LABEL = 255 38 | NUM_CLASSES = 9 39 | NUM_STEPS = 894 # Number of images in the validation set. 40 | RESTORE_FROM = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_multi-ed35151c.pth' 41 | RESTORE_FROM_VGG = 'http://vllab.ucmerced.edu/ytsai/CVPR18/GTA2Cityscapes_vgg-ac4ac9f6.pth' 42 | RESTORE_FROM_ORC = 'http://vllab1.ucmerced.edu/~whung/adaptSeg/cityscapes_oracle-b7b9934.pth' 43 | SET = 'train' # We generate pseudo label for training set 44 | 45 | MODEL = 'DeeplabMulti' 46 | 47 | palette = [ 48 | [70,130,180], 49 | [220,20,60], 50 | [119,11,32], 51 | [0,0,142], 52 | [220,220,0], 53 | [250,170,30], 54 | [70,70,70], 55 | [244,35,232], 56 | [128,64,128], 57 | ] 58 | palette = [item for sublist in palette for item in sublist] 59 | 60 | zero_pad = 256 * 3 - len(palette) 61 | for i in range(zero_pad): 62 | palette.append(0) 63 | 64 | 65 | def colorize_mask(mask): 66 | # mask: numpy array of the mask 67 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 68 | new_mask.putpalette(palette) 69 | 70 | return new_mask 71 | 72 | def get_arguments(): 73 | """Parse all the arguments provided from the CLI. 74 | 75 | Returns: 76 | A list of parsed arguments. 
77 | """ 78 | parser = argparse.ArgumentParser(description="DeepLab-ResNet Network") 79 | parser.add_argument("--model", type=str, default=MODEL, 80 | help="Model Choice (DeeplabMulti/DeeplabVGG/Oracle).") 81 | parser.add_argument("--data-dir", type=str, default=DATA_DIRECTORY, 82 | help="Path to the directory containing the Cityscapes dataset.") 83 | parser.add_argument("--data-list", type=str, default=DATA_LIST_PATH, 84 | help="Path to the file listing the images in the dataset.") 85 | parser.add_argument("--ignore-label", type=int, default=IGNORE_LABEL, 86 | help="The index of the label to ignore during the training.") 87 | parser.add_argument("--num-classes", type=int, default=NUM_CLASSES, 88 | help="Number of classes to predict (including background).") 89 | parser.add_argument("--restore-from", type=str, default=RESTORE_FROM, 90 | help="Where restore model parameters from.") 91 | parser.add_argument("--gpu", type=int, default=0, 92 | help="choose gpu device.") 93 | parser.add_argument("--batchsize", type=int, default=12, 94 | help="choose gpu device.") 95 | parser.add_argument("--set", type=str, default=SET, 96 | help="choose evaluation set.") 97 | parser.add_argument("--save", type=str, default=SAVE_PATH, 98 | help="Path to save result.") 99 | return parser.parse_args() 100 | 101 | 102 | def main(): 103 | """Create the model and start the evaluation process.""" 104 | 105 | args = get_arguments() 106 | 107 | config_path = os.path.join(os.path.dirname(args.restore_from),'opts.yaml') 108 | with open(config_path, 'r') as stream: 109 | config = yaml.load(stream) 110 | 111 | args.model = config['model'] 112 | print('ModelType:%s'%args.model) 113 | print('NormType:%s'%config['norm_style']) 114 | gpu0 = args.gpu 115 | batchsize = args.batchsize 116 | 117 | model_name = os.path.basename( os.path.dirname(args.restore_from) ) 118 | #args.save += model_name 119 | 120 | if not os.path.exists(args.save): 121 | os.makedirs(args.save) 122 | 123 | if args.model == 'DeepLab': 124 | model = DeeplabMulti(num_classes=args.num_classes, use_se = config['use_se'], train_bn = False, norm_style = config['norm_style']) 125 | elif args.model == 'Oracle': 126 | model = Res_Deeplab(num_classes=args.num_classes) 127 | if args.restore_from == RESTORE_FROM: 128 | args.restore_from = RESTORE_FROM_ORC 129 | elif args.model == 'DeeplabVGG': 130 | model = DeeplabVGG(num_classes=args.num_classes) 131 | if args.restore_from == RESTORE_FROM: 132 | args.restore_from = RESTORE_FROM_VGG 133 | 134 | if args.restore_from[:4] == 'http' : 135 | saved_state_dict = model_zoo.load_url(args.restore_from) 136 | else: 137 | saved_state_dict = torch.load(args.restore_from) 138 | 139 | try: 140 | model.load_state_dict(saved_state_dict) 141 | except: 142 | model = torch.nn.DataParallel(model) 143 | model.load_state_dict(saved_state_dict) 144 | model = torch.nn.DataParallel(model) 145 | model.eval() 146 | model.cuda(gpu0) 147 | 148 | testloader = data.DataLoader(robotDataSet(args.data_dir, args.data_list, crop_size=(960, 1280), resize_size=(1280, 960), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 149 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 150 | 151 | scale = 1.25 152 | testloader2 = data.DataLoader(robotDataSet(args.data_dir, args.data_list, crop_size=(round(960*scale), round(1280*scale) ), resize_size=( round(1280*scale), round(960*scale)), mean=IMG_MEAN, scale=False, mirror=False, set=args.set), 153 | batch_size=batchsize, shuffle=False, pin_memory=True, num_workers=4) 154 | 155 | 156 | if 
version.parse(torch.__version__) >= version.parse('0.4.0'): 157 | interp = nn.Upsample(size=(960, 1280), mode='bilinear', align_corners=True) 158 | else: 159 | interp = nn.Upsample(size=(960, 1280), mode='bilinear') 160 | 161 | sm = torch.nn.Softmax(dim = 1) 162 | for index, img_data in enumerate(zip(testloader, testloader2) ): 163 | batch, batch2 = img_data 164 | image, _, _, name = batch 165 | image2, _, _, name2 = batch2 166 | print(image.shape) 167 | 168 | inputs = image.cuda() 169 | inputs2 = image2.cuda() 170 | print('\r>>>>Extracting feature...%04d/%04d'%(index*batchsize, NUM_STEPS), end='') 171 | if args.model == 'DeepLab': 172 | with torch.no_grad(): 173 | output1, output2 = model(inputs) 174 | output_batch = interp(sm(0.5* output1 + output2)) 175 | output1, output2 = model(fliplr(inputs)) 176 | output1, output2 = fliplr(output1), fliplr(output2) 177 | output_batch += interp(sm(0.5 * output1 + output2)) 178 | del output1, output2, inputs 179 | 180 | output1, output2 = model(inputs2) 181 | output_batch += interp(sm(0.5* output1 + output2)) 182 | output1, output2 = model(fliplr(inputs2)) 183 | output1, output2 = fliplr(output1), fliplr(output2) 184 | output_batch += interp(sm(0.5 * output1 + output2)) 185 | del output1, output2, inputs2 186 | output_batch = output_batch.cpu().data.numpy() 187 | elif args.model == 'DeeplabVGG' or args.model == 'Oracle': 188 | output_batch = model(Variable(image).cuda()) 189 | output_batch = interp(output_batch).cpu().data.numpy() 190 | 191 | output_batch = output_batch.transpose(0,2,3,1) 192 | score_batch = np.max(output_batch, axis=3) 193 | output_batch = np.asarray(np.argmax(output_batch, axis=3), dtype=np.uint8) 194 | #output_batch[score_batch<3.6] = 255 #3.6 = 4*0.9 195 | 196 | for i in range(output_batch.shape[0]): 197 | output = output_batch[i,:,:] 198 | output_col = colorize_mask(output) 199 | output = Image.fromarray(output) 200 | 201 | name_tmp = name[i].split('/')[-1] 202 | dir_name = name[i].split('/')[-2] 203 | save_path = args.save + '/' + dir_name 204 | #save_path = re.replace(save_path, 'leftImg8bit', 'pseudo') 205 | #print(save_path) 206 | if not os.path.isdir(save_path): 207 | os.mkdir(save_path) 208 | output.save('%s/%s' % (save_path, name_tmp)) 209 | print('%s/%s' % (save_path, name_tmp)) 210 | output_col.save('%s/%s_color.png' % (save_path, name_tmp.split('.')[0])) 211 | 212 | return args.save 213 | 214 | if __name__ == '__main__': 215 | with torch.no_grad(): 216 | save_path = main() 217 | #os.system('python compute_iou.py ./data/Cityscapes/data/gtFine/train %s'%save_path) 218 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/model/__init__.py -------------------------------------------------------------------------------- /model/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import numpy as np 6 | affine_par = True 7 | 8 | 9 | def outS(i): 10 | i = int(i) 11 | i = (i+1)/2 12 | i = int(np.ceil((i+1)/2.0)) 13 | i = (i+1)/2 14 | return i 15 | 16 | def conv3x3(in_planes, out_planes, stride=1): 17 | "3x3 convolution with padding" 18 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 19 | padding=1, bias=False) 20 | 21 | 22 | class 
BasicBlock(nn.Module): 23 | expansion = 1 24 | 25 | def __init__(self, inplanes, planes, stride=1, downsample=None): 26 | super(BasicBlock, self).__init__() 27 | self.conv1 = conv3x3(inplanes, planes, stride) 28 | self.bn1 = nn.BatchNorm2d(planes, affine = affine_par) 29 | self.relu = nn.ReLU(inplace=True) 30 | self.conv2 = conv3x3(planes, planes) 31 | self.bn2 = nn.BatchNorm2d(planes, affine = affine_par) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | expansion = 4 56 | 57 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None): 58 | super(Bottleneck, self).__init__() 59 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 60 | self.bn1 = nn.BatchNorm2d(planes,affine = affine_par) 61 | for i in self.bn1.parameters(): 62 | i.requires_grad = False 63 | 64 | padding = dilation 65 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 66 | padding=padding, bias=False, dilation = dilation) 67 | self.bn2 = nn.BatchNorm2d(planes,affine = affine_par) 68 | for i in self.bn2.parameters(): 69 | i.requires_grad = False 70 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 71 | self.bn3 = nn.BatchNorm2d(planes * 4, affine = affine_par) 72 | for i in self.bn3.parameters(): 73 | i.requires_grad = False 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | 79 | def forward(self, x): 80 | residual = x 81 | 82 | out = self.conv1(x) 83 | out = self.bn1(out) 84 | out = self.relu(out) 85 | 86 | out = self.conv2(out) 87 | out = self.bn2(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv3(out) 91 | out = self.bn3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | out = self.relu(out) 98 | 99 | return out 100 | 101 | class Classifier_Module(nn.Module): 102 | 103 | def __init__(self, dilation_series, padding_series, num_classes): 104 | super(Classifier_Module, self).__init__() 105 | self.conv2d_list = nn.ModuleList() 106 | for dilation, padding in zip(dilation_series, padding_series): 107 | self.conv2d_list.append(nn.Conv2d(2048, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias = True)) 108 | 109 | for m in self.conv2d_list: 110 | m.weight.data.normal_(0, 0.01) 111 | 112 | def forward(self, x): 113 | out = self.conv2d_list[0](x) 114 | for i in range(len(self.conv2d_list)-1): 115 | out += self.conv2d_list[i+1](x) 116 | return out 117 | 118 | 119 | 120 | class ResNet(nn.Module): 121 | def __init__(self, block, layers, num_classes): 122 | self.inplanes = 64 123 | super(ResNet, self).__init__() 124 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 125 | bias=False) 126 | self.bn1 = nn.BatchNorm2d(64, affine = affine_par) 127 | for i in self.bn1.parameters(): 128 | i.requires_grad = False 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 133 | 
self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 134 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 135 | self.layer5 = self._make_pred_layer(Classifier_Module, [6,12,18,24],[6,12,18,24],num_classes) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 140 | m.weight.data.normal_(0, 0.01) 141 | elif isinstance(m, nn.BatchNorm2d): 142 | m.weight.data.fill_(1) 143 | m.bias.data.zero_() 144 | # for i in m.parameters(): 145 | # i.requires_grad = False 146 | 147 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 148 | downsample = None 149 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 150 | downsample = nn.Sequential( 151 | nn.Conv2d(self.inplanes, planes * block.expansion, 152 | kernel_size=1, stride=stride, bias=False), 153 | nn.BatchNorm2d(planes * block.expansion,affine = affine_par)) 154 | for i in downsample._modules['1'].parameters(): 155 | i.requires_grad = False 156 | layers = [] 157 | layers.append(block(self.inplanes, planes, stride,dilation=dilation, downsample=downsample)) 158 | self.inplanes = planes * block.expansion 159 | for i in range(1, blocks): 160 | layers.append(block(self.inplanes, planes, dilation=dilation)) 161 | 162 | return nn.Sequential(*layers) 163 | def _make_pred_layer(self,block, dilation_series, padding_series,num_classes): 164 | return block(dilation_series,padding_series,num_classes) 165 | 166 | def forward(self, x): 167 | x = self.conv1(x) 168 | x = self.bn1(x) 169 | x = self.relu(x) 170 | x = self.maxpool(x) 171 | x = self.layer1(x) 172 | x = self.layer2(x) 173 | x = self.layer3(x) 174 | x = self.layer4(x) 175 | x = self.layer5(x) 176 | 177 | return x 178 | 179 | def get_1x_lr_params_NOscale(self): 180 | """ 181 | This generator returns all the parameters of the net except for 182 | the last classification layer. 
Note that for each batchnorm layer, 183 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 184 | any batchnorm parameter 185 | """ 186 | b = [] 187 | 188 | b.append(self.conv1) 189 | b.append(self.bn1) 190 | b.append(self.layer1) 191 | b.append(self.layer2) 192 | b.append(self.layer3) 193 | b.append(self.layer4) 194 | 195 | 196 | for i in range(len(b)): 197 | for j in b[i].modules(): 198 | jj = 0 199 | for k in j.parameters(): 200 | jj+=1 201 | if k.requires_grad: 202 | yield k 203 | 204 | def get_10x_lr_params(self): 205 | """ 206 | This generator returns all the parameters for the last layer of the net, 207 | which does the classification of pixel into classes 208 | """ 209 | b = [] 210 | b.append(self.layer5.parameters()) 211 | 212 | for j in range(len(b)): 213 | for i in b[j]: 214 | yield i 215 | 216 | 217 | 218 | def optim_parameters(self, args): 219 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate}, 220 | {'params': self.get_10x_lr_params(), 'lr': 10*args.learning_rate}] 221 | 222 | 223 | def Res_Deeplab(num_classes=21): 224 | model = ResNet(Bottleneck,[3, 4, 23, 3], num_classes) 225 | return model 226 | 227 | -------------------------------------------------------------------------------- /model/deeplab_multi.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torch 5 | import numpy as np 6 | 7 | affine_par = True 8 | 9 | 10 | def outS(i): 11 | i = int(i) 12 | i = (i + 1) / 2 13 | i = int(np.ceil((i + 1) / 2.0)) 14 | i = (i + 1) / 2 15 | return i 16 | 17 | 18 | def conv3x3(in_planes, out_planes, stride=1): 19 | "3x3 convolution with padding" 20 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 21 | padding=1, bias=False) 22 | 23 | def NormLayer(norm_dim, norm_style = 'bn'): 24 | if norm_style == 'bn': 25 | norm_layer = nn.BatchNorm2d(norm_dim) 26 | elif norm_style == 'in': 27 | norm_layer = nn.InstanceNorm2d(norm_dim, affine = True) 28 | elif norm_style == 'ln': 29 | norm_layer = nn.LayerNorm(norm_dim, elementwise_affine=True) 30 | elif norm_style == 'gn': 31 | norm_layer = nn.GroupNorm(num_groups=32, num_channels=norm_dim, affine = True) 32 | return norm_layer 33 | 34 | class SEBlock(nn.Module): 35 | def __init__(self, inplanes, r = 16): 36 | super(SEBlock, self).__init__() 37 | self.global_pool = nn.AdaptiveAvgPool2d((1,1)) 38 | self.se = nn.Sequential( 39 | nn.Linear(inplanes, inplanes//r), 40 | nn.ReLU(inplace=True), 41 | nn.Linear(inplanes//r, inplanes), 42 | nn.Sigmoid() 43 | ) 44 | def forward(self, x): 45 | xx = self.global_pool(x) 46 | xx = xx.view(xx.size(0), xx.size(1)) 47 | se_weight = self.se(xx).unsqueeze(-1).unsqueeze(-1) 48 | return x.mul(se_weight) 49 | 50 | class BasicBlock(nn.Module): 51 | expansion = 1 52 | 53 | def __init__(self, inplanes, planes, stride=1, downsample=None): 54 | super(BasicBlock, self).__init__() 55 | self.conv1 = conv3x3(inplanes, planes, stride) 56 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.conv2 = conv3x3(planes, planes) 59 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 60 | self.downsample = downsample 61 | self.stride = stride 62 | 63 | def forward(self, x): 64 | residual = x 65 | 66 | out = self.conv1(x) 67 | out = self.bn1(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv2(out) 71 | out = self.bn2(out) 72 | 73 | if self.downsample is not 
None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class Bottleneck(nn.Module): 83 | expansion = 4 84 | 85 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, train_bn = False): 86 | super(Bottleneck, self).__init__() 87 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False) # change 88 | self.bn1 = nn.BatchNorm2d(planes, affine=affine_par) 89 | for i in self.bn1.parameters(): 90 | i.requires_grad = train_bn 91 | 92 | padding = dilation 93 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, # change 94 | padding=padding, bias=False, dilation=dilation) 95 | self.bn2 = nn.BatchNorm2d(planes, affine=affine_par) 96 | for i in self.bn2.parameters(): 97 | i.requires_grad = train_bn 98 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 99 | self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par) 100 | for i in self.bn3.parameters(): 101 | i.requires_grad = train_bn 102 | self.relu = nn.ReLU(inplace=True) 103 | self.downsample = downsample 104 | self.stride = stride 105 | 106 | def forward(self, x): 107 | residual = x 108 | 109 | out = self.conv1(x) 110 | out = self.bn1(out) 111 | out = self.relu(out) 112 | 113 | out = self.conv2(out) 114 | out = self.bn2(out) 115 | out = self.relu(out) 116 | 117 | out = self.conv3(out) 118 | out = self.bn3(out) 119 | 120 | if self.downsample is not None: 121 | residual = self.downsample(x) 122 | 123 | out += residual 124 | out = self.relu(out) 125 | 126 | return out 127 | 128 | 129 | class Classifier_Module(nn.Module): 130 | def __init__(self, inplanes, dilation_series, padding_series, num_classes, norm_style = 'bn', droprate = 0.1, use_se = False): 131 | super(Classifier_Module, self).__init__() 132 | self.conv2d_list = nn.ModuleList() 133 | self.conv2d_list.append( 134 | nn.Sequential(*[ 135 | nn.Conv2d(inplanes, 256, kernel_size=1, stride=1, padding=0, dilation=1, bias=True), 136 | NormLayer(256, norm_style), 137 | nn.ReLU(inplace=True) ])) 138 | 139 | for dilation, padding in zip(dilation_series, padding_series): 140 | #self.conv2d_list.append( 141 | # nn.BatchNorm2d(inplanes)) 142 | self.conv2d_list.append( 143 | nn.Sequential(*[ 144 | #nn.ReflectionPad2d(padding), 145 | nn.Conv2d(inplanes, 256, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=True), 146 | NormLayer(256, norm_style), 147 | nn.ReLU(inplace=True) ])) 148 | 149 | if use_se: 150 | self.bottleneck = nn.Sequential(*[SEBlock(256 * (len(dilation_series) + 1)), 151 | nn.Conv2d(256 * (len(dilation_series) + 1), 512, kernel_size=3, stride=1, padding=1, dilation=1, bias=True) , 152 | NormLayer(512, norm_style) ]) 153 | else: 154 | self.bottleneck = nn.Sequential(*[ 155 | nn.Conv2d(256 * (len(dilation_series) + 1), 512, kernel_size=3, stride=1, padding=1, dilation=1, bias=True) , 156 | NormLayer(512, norm_style) ]) 157 | 158 | self.head = nn.Sequential(*[nn.Dropout2d(droprate), 159 | nn.Conv2d(512, num_classes, kernel_size=1, padding=0, dilation=1, bias=False) ]) 160 | 161 | ##########init####### 162 | for m in self.conv2d_list: 163 | if isinstance(m, nn.Conv2d): 164 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 165 | m.bias.data.zero_() 166 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d) or isinstance(m, nn.GroupNorm): 167 | m.weight.data.fill_(1) 168 | m.bias.data.zero_() 169 | 170 | for m in self.bottleneck: 171 | if isinstance(m, nn.Conv2d): 172 | 
torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 173 | m.bias.data.zero_() 174 | elif isinstance(m, nn.Linear): 175 | torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 176 | m.bias.data.zero_() 177 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.InstanceNorm2d) or isinstance(m, nn.GroupNorm) or isinstance(m, nn.LayerNorm): 178 | m.weight.data.fill_(1) 179 | m.bias.data.zero_() 180 | 181 | for m in self.head: 182 | if isinstance(m, nn.Conv2d): 183 | m.weight.data.normal_(0, 0.001) 184 | 185 | def forward(self, x): 186 | out = self.conv2d_list[0](x) 187 | for i in range(len(self.conv2d_list) - 1): 188 | out = torch.cat( (out, self.conv2d_list[i+1](x)), 1) 189 | out = self.bottleneck(out) 190 | out = self.head(out) 191 | return out 192 | 193 | 194 | class ResNetMulti(nn.Module): 195 | def __init__(self, block, layers, num_classes, use_se = False, train_bn = False, norm_style = 'bn', droprate = 0.1): 196 | self.inplanes = 64 197 | self.train_bn = train_bn 198 | super(ResNetMulti, self).__init__() 199 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 200 | bias=False) 201 | self.bn1 = nn.BatchNorm2d(64, affine=affine_par) 202 | for i in self.bn1.parameters(): 203 | i.requires_grad = self.train_bn 204 | self.relu = nn.ReLU(inplace=True) 205 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=True) # change 206 | self.layer1 = self._make_layer(block, 64, layers[0]) 207 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 208 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2) 209 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4) 210 | self.layer5 = self._make_pred_layer(Classifier_Module, 1024, [6, 12, 18, 24], [6, 12, 18, 24], num_classes, norm_style, droprate, use_se) 211 | self.layer6 = self._make_pred_layer(Classifier_Module, 1024 + 2048, [6, 12, 18, 24], [6, 12, 18, 24], num_classes, norm_style, droprate, use_se) 212 | 213 | #for m in self.modules(): 214 | # if isinstance(m, nn.Conv2d): 215 | # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 216 | # m.weight.data.normal_(0, 0.01) 217 | # elif isinstance(m, nn.BatchNorm2d): 218 | # m.weight.data.fill_(1) 219 | # m.bias.data.zero_() 220 | # for i in m.parameters(): 221 | # i.requires_grad = False 222 | 223 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 224 | downsample = None 225 | if stride != 1 or self.inplanes != planes * block.expansion or dilation == 2 or dilation == 4: 226 | downsample = nn.Sequential( 227 | nn.Conv2d(self.inplanes, planes * block.expansion, 228 | kernel_size=1, stride=stride, bias=False), 229 | nn.BatchNorm2d(planes * block.expansion, affine=affine_par)) 230 | for i in downsample._modules['1'].parameters(): 231 | i.requires_grad = self.train_bn 232 | layers = [] 233 | layers.append(block(self.inplanes, planes, stride, dilation=dilation, downsample=downsample, train_bn = self.train_bn)) 234 | self.inplanes = planes * block.expansion 235 | for i in range(1, blocks): 236 | layers.append(block(self.inplanes, planes, dilation=dilation, downsample = None, train_bn = self.train_bn)) 237 | 238 | return nn.Sequential(*layers) 239 | 240 | def _make_pred_layer(self, block, inplanes, dilation_series, padding_series, num_classes, norm_style, droprate, use_se): 241 | return block(inplanes, dilation_series, padding_series, num_classes, norm_style, droprate, use_se) 242 | 243 | def forward(self, x): 244 | x = self.conv1(x) 245 | x = self.bn1(x) 246 | x = 
self.relu(x) 247 | x = self.maxpool(x) 248 | x = self.layer1(x) 249 | x = self.layer2(x) 250 | 251 | x = self.layer3(x) 252 | x1 = self.layer5(x) 253 | 254 | x2 = torch.cat((self.layer4(x),x), 1) 255 | x2 = self.layer6(x2) 256 | 257 | return x1, x2 258 | 259 | def get_1x_lr_params_NOscale(self): 260 | """ 261 | This generator returns all the parameters of the net except for 262 | the last classification layer. Note that for each batchnorm layer, 263 | requires_grad is set to False in deeplab_resnet.py, therefore this function does not return 264 | any batchnorm parameter 265 | """ 266 | b = [] 267 | 268 | b.append(self.conv1) 269 | b.append(self.bn1) 270 | b.append(self.layer1) 271 | b.append(self.layer2) 272 | b.append(self.layer3) 273 | b.append(self.layer4) 274 | 275 | for i in range(len(b)): 276 | for j in b[i].modules(): 277 | jj = 0 278 | for k in j.parameters(): 279 | jj += 1 280 | if k.requires_grad: 281 | yield k 282 | 283 | def get_10x_lr_params(self): 284 | """ 285 | This generator returns all the parameters for the last layer of the net, 286 | which does the classification of pixel into classes 287 | """ 288 | b = [] 289 | b.append(self.layer5.parameters()) 290 | b.append(self.layer6.parameters()) 291 | 292 | for j in range(len(b)): 293 | for i in b[j]: 294 | yield i 295 | 296 | def optim_parameters(self, args): 297 | return [{'params': self.get_1x_lr_params_NOscale(), 'lr': args.learning_rate}, 298 | {'params': self.get_10x_lr_params(), 'lr': 10 * args.learning_rate}] 299 | 300 | 301 | def DeeplabMulti(num_classes=21, use_se = False, train_bn = False, norm_style = 'bn', droprate = 0.1): 302 | model = ResNetMulti(Bottleneck, [3, 4, 23, 3], num_classes, use_se = use_se, train_bn = train_bn, norm_style = norm_style, droprate = droprate) 303 | return model 304 | -------------------------------------------------------------------------------- /model/deeplab_vgg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | from torchvision import models 5 | 6 | class Classifier_Module(nn.Module): 7 | 8 | def __init__(self, dims_in, dilation_series, padding_series, num_classes): 9 | super(Classifier_Module, self).__init__() 10 | self.conv2d_list = nn.ModuleList() 11 | for dilation, padding in zip(dilation_series, padding_series): 12 | self.conv2d_list.append(nn.Conv2d(dims_in, num_classes, kernel_size=3, stride=1, padding=padding, dilation=dilation, bias = True)) 13 | 14 | for m in self.conv2d_list: 15 | m.weight.data.normal_(0, 0.01) 16 | 17 | def forward(self, x): 18 | out = self.conv2d_list[0](x) 19 | for i in range(len(self.conv2d_list)-1): 20 | out += self.conv2d_list[i+1](x) 21 | return out 22 | 23 | 24 | class DeeplabVGG(nn.Module): 25 | def __init__(self, num_classes, vgg16_caffe_path=None, pretrained=False): 26 | super(DeeplabVGG, self).__init__() 27 | vgg = models.vgg16() 28 | if pretrained: 29 | vgg.load_state_dict(torch.load(vgg16_caffe_path)) 30 | 31 | features, classifier = list(vgg.features.children()), list(vgg.classifier.children()) 32 | 33 | #remove pool4/pool5 34 | features = nn.Sequential(*(features[i] for i in list(range(23))+list(range(24,30)))) 35 | 36 | for i in [23,25,27]: 37 | features[i].dilation = (2,2) 38 | features[i].padding = (2,2) 39 | 40 | fc6 = nn.Conv2d(512, 1024, kernel_size=3, padding=4, dilation=4) 41 | fc7 = nn.Conv2d(1024, 1024, kernel_size=3, padding=4, dilation=4) 42 | 43 | self.features = nn.Sequential(*([features[i] for i in range(len(features))] + 
[ fc6, nn.ReLU(inplace=True), fc7, nn.ReLU(inplace=True)])) 44 | 45 | self.classifier = Classifier_Module(1024, [6,12,18,24],[6,12,18,24],num_classes) 46 | 47 | 48 | def forward(self, x): 49 | x = self.features(x) 50 | x = self.classifier(x) 51 | return x 52 | 53 | def optim_parameters(self, args): 54 | return self.parameters() 55 | -------------------------------------------------------------------------------- /model/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class FCDiscriminator(nn.Module): 6 | 7 | def __init__(self, num_classes, ndf = 64): 8 | super(FCDiscriminator, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1) 11 | self.conv2 = nn.Conv2d(ndf, ndf*2, kernel_size=4, stride=2, padding=1) 12 | self.conv3 = nn.Conv2d(ndf*2, ndf*4, kernel_size=4, stride=2, padding=1) 13 | self.conv4 = nn.Conv2d(ndf*4, ndf*8, kernel_size=4, stride=2, padding=1) 14 | self.classifier = nn.Conv2d(ndf*8, 1, kernel_size=4, stride=2, padding=1) 15 | 16 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 17 | #self.up_sample = nn.Upsample(scale_factor=32, mode='bilinear') 18 | #self.sigmoid = nn.Sigmoid() 19 | 20 | 21 | def forward(self, x): 22 | x = self.conv1(x) 23 | x = self.leaky_relu(x) 24 | x = self.conv2(x) 25 | x = self.leaky_relu(x) 26 | x = self.conv3(x) 27 | x = self.leaky_relu(x) 28 | x = self.conv4(x) 29 | x = self.leaky_relu(x) 30 | x = self.classifier(x) 31 | #x = self.up_sample(x) 32 | #x = self.sigmoid(x) 33 | 34 | return x 35 | -------------------------------------------------------------------------------- /pdf/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /pdf/Zheng-Yang2021_Article_RectifyingPseudoLabelLearningV.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/pdf/Zheng-Yang2021_Article_RectifyingPseudoLabelLearningV.pdf -------------------------------------------------------------------------------- /pdf/ijcai20.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/pdf/ijcai20.pdf -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding=utf-8 -*- 2 | # @Mail: beanocean@outlook.com 3 | # @D&T: Sat 07 Dec 2019 22:04:31 AEDT 4 | 5 | import os 6 | 7 | while(1): 8 | try: 9 | os.system('python train_ms.py --snapshot-dir ./snapshots/SE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5_aug_fp16 --drop 0.1 --warm-up 5000 --batch-size 2 --learning-rate 2e-4 --crop-size 1024,512 --lambda-seg 0.5 --lambda-adv-target1 0.0002 --lambda-adv-target2 0.001 --lambda-me-target 0 --lambda-kl-target 0.1 --norm-style gn --class-balance --only-hard-label 80 --max-value 7 --gpu-ids 0,1 --often-balance --use-se --autoaug') 10 | except: 11 | continue 12 | -------------------------------------------------------------------------------- /trainer_ms.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils 
import data, model_zoo 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from model.deeplab_multi import DeeplabMulti 6 | from model.discriminator import FCDiscriminator 7 | from model.ms_discriminator import MsImageDis 8 | import torch 9 | import torch.nn.init as init 10 | import copy 11 | import numpy as np 12 | #fp16 13 | try: 14 | import apex 15 | from apex import amp 16 | from apex.fp16_utils import * 17 | except ImportError: 18 | print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') 19 | 20 | def weights_init(init_type='gaussian'): 21 | def init_fun(m): 22 | classname = m.__class__.__name__ 23 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 24 | # print m.__class__.__name__ 25 | if init_type == 'gaussian': 26 | init.normal_(m.weight.data, 0.0, 0.02) 27 | elif init_type == 'xavier': 28 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 29 | elif init_type == 'kaiming': 30 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | elif init_type == 'orthogonal': 32 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 33 | elif init_type == 'default': 34 | pass 35 | else: 36 | assert 0, "Unsupported initialization: {}".format(init_type) 37 | if hasattr(m, 'bias') and m.bias is not None: 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | return init_fun 41 | 42 | def train_bn(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('BatchNorm') != -1: 45 | m.train() 46 | 47 | def inplace_relu(m): 48 | classname = m.__class__.__name__ 49 | if classname.find('ReLU') != -1: 50 | m.inplace=True 51 | 52 | def fliplr(img): 53 | '''flip horizontal''' 54 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 55 | img_flip = img.index_select(3,inv_idx) 56 | return img_flip 57 | 58 | class AD_Trainer(nn.Module): 59 | def __init__(self, args): 60 | super(AD_Trainer, self).__init__() 61 | self.fp16 = args.fp16 62 | self.class_balance = args.class_balance 63 | self.often_balance = args.often_balance 64 | self.num_classes = args.num_classes 65 | self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 66 | self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 67 | self.multi_gpu = args.multi_gpu 68 | self.only_hard_label = args.only_hard_label 69 | if args.model == 'DeepLab': 70 | self.G = DeeplabMulti(num_classes=args.num_classes, use_se = args.use_se, train_bn = args.train_bn, norm_style = args.norm_style, droprate = args.droprate) 71 | if args.restore_from[:4] == 'http' : 72 | saved_state_dict = model_zoo.load_url(args.restore_from) 73 | else: 74 | saved_state_dict = torch.load(args.restore_from) 75 | 76 | new_params = self.G.state_dict().copy() 77 | for i in saved_state_dict: 78 | # Scale.layer5.conv2d_list.3.weight 79 | i_parts = i.split('.') 80 | # print i_parts 81 | if args.restore_from[:4] == 'http' : 82 | if i_parts[1] !='fc' and i_parts[1] !='layer5': 83 | new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 84 | print('%s is loaded from pre-trained weight.\n'%i_parts[1:]) 85 | else: 86 | #new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 87 | if i_parts[0] =='module': 88 | new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 89 | print('%s is loaded from pre-trained weight.\n'%i_parts[1:]) 90 | else: 91 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i] 92 | print('%s is loaded from pre-trained 
weight.\n'%i_parts[0:]) 93 | self.G.load_state_dict(new_params) 94 | 95 | self.D1 = MsImageDis(input_dim = args.num_classes).cuda() 96 | self.D2 = MsImageDis(input_dim = args.num_classes).cuda() 97 | self.D1.apply(weights_init('gaussian')) 98 | self.D2.apply(weights_init('gaussian')) 99 | 100 | if self.multi_gpu and args.sync_bn: 101 | print("using apex synced BN") 102 | self.G = apex.parallel.convert_syncbn_model(self.G) 103 | 104 | self.gen_opt = optim.SGD(self.G.optim_parameters(args), 105 | lr=args.learning_rate, momentum=args.momentum, nesterov=True, weight_decay=args.weight_decay) 106 | 107 | self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 108 | 109 | self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 110 | 111 | self.seg_loss = nn.CrossEntropyLoss(ignore_index=255) 112 | self.kl_loss = nn.KLDivLoss(size_average=False) 113 | self.sm = torch.nn.Softmax(dim = 1) 114 | self.log_sm = torch.nn.LogSoftmax(dim = 1) 115 | self.G = self.G.cuda() 116 | self.D1 = self.D1.cuda() 117 | self.D2 = self.D2.cuda() 118 | self.interp = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True) 119 | self.interp_target = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True) 120 | self.lambda_seg = args.lambda_seg 121 | self.max_value = args.max_value 122 | self.lambda_me_target = args.lambda_me_target 123 | self.lambda_kl_target = args.lambda_kl_target 124 | self.lambda_adv_target1 = args.lambda_adv_target1 125 | self.lambda_adv_target2 = args.lambda_adv_target2 126 | self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 127 | if args.fp16: 128 | # Name the FP16_Optimizer instance to replace the existing optimizer 129 | assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 
130 | self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1") 131 | self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1") 132 | self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1") 133 | 134 | def update_class_criterion(self, labels): 135 | weight = torch.FloatTensor(self.num_classes).zero_().cuda() 136 | weight += 1 137 | count = torch.FloatTensor(self.num_classes).zero_().cuda() 138 | often = torch.FloatTensor(self.num_classes).zero_().cuda() 139 | often += 1 140 | print(labels.shape) 141 | n, h, w = labels.shape 142 | for i in range(self.num_classes): 143 | count[i] = torch.sum(labels==i) 144 | if count[i] < 64*64*n: #small objective 145 | weight[i] = self.max_value 146 | if self.often_balance: 147 | often[count == 0] = self.max_value 148 | 149 | self.often_weight = 0.9 * self.often_weight + 0.1 * often 150 | self.class_weight = weight * self.often_weight 151 | print(self.class_weight) 152 | return nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255) 153 | 154 | def update_label(self, labels, prediction): 155 | criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none') 156 | #criterion = self.seg_loss 157 | loss = criterion(prediction, labels) 158 | print('original loss: %f'% self.seg_loss(prediction, labels) ) 159 | #mm = torch.median(loss) 160 | loss_data = loss.data.cpu().numpy() 161 | mm = np.percentile(loss_data[:], self.only_hard_label) 162 | #print(m.data.cpu(), mm) 163 | labels[loss < mm] = 255 164 | return labels 165 | 166 | 167 | def gen_update(self, images, images_t, labels, labels_t, i_iter): 168 | self.gen_opt.zero_grad() 169 | 170 | pred1, pred2 = self.G(images) 171 | pred1 = self.interp(pred1) 172 | pred2 = self.interp(pred2) 173 | 174 | if self.class_balance: 175 | self.seg_loss = self.update_class_criterion(labels) 176 | 177 | if self.only_hard_label > 0: 178 | labels1 = self.update_label(labels.clone(), pred1) 179 | labels2 = self.update_label(labels.clone(), pred2) 180 | loss_seg1 = self.seg_loss(pred1, labels1) 181 | loss_seg2 = self.seg_loss(pred2, labels2) 182 | else: 183 | loss_seg1 = self.seg_loss(pred1, labels) 184 | loss_seg2 = self.seg_loss(pred2, labels) 185 | 186 | loss = loss_seg2 + self.lambda_seg * loss_seg1 187 | 188 | # target 189 | pred_target1, pred_target2 = self.G(images_t) 190 | pred_target1 = self.interp_target(pred_target1) 191 | pred_target2 = self.interp_target(pred_target2) 192 | 193 | if self.multi_gpu: 194 | #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0: 195 | loss_adv_target1 = self.D1.module.calc_gen_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1) ) 196 | loss_adv_target2 = self.D2.module.calc_gen_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1) ) 197 | #else: 198 | # print('skip the discriminator') 199 | # loss_adv_target1, loss_adv_target2 = 0, 0 200 | else: 201 | #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0: 202 | loss_adv_target1 = self.D1.calc_gen_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1) ) 203 | loss_adv_target2 = self.D2.calc_gen_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1) ) 204 | #else: 205 | #loss_adv_target1 = 0.0 #torch.tensor(0).cuda() 206 | #loss_adv_target2 = 0.0 #torch.tensor(0).cuda() 207 | 208 | loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2 209 | 210 | 211 | if i_iter < 15000: 212 | self.lambda_kl_target_copy = 0 213 | self.lambda_me_target_copy = 0 214 | else: 
215 | self.lambda_kl_target_copy = self.lambda_kl_target 216 | self.lambda_me_target_copy = self.lambda_me_target 217 | 218 | loss_me = 0.0 219 | if self.lambda_me_target_copy>0: 220 | confidence_map = torch.sum( self.sm(0.5*pred_target1 + pred_target2)**2, 1).detach() 221 | loss_me = -torch.mean( confidence_map * torch.sum( self.sm(0.5*pred_target1 + pred_target2) * self.log_sm(0.5*pred_target1 + pred_target2), 1) ) 222 | loss += self.lambda_me_target * loss_me 223 | 224 | loss_kl = 0.0 225 | if self.lambda_kl_target_copy>0: 226 | n, c, h, w = pred_target1.shape 227 | with torch.no_grad(): 228 | #pred_target1_flip, pred_target2_flip = self.G(fliplr(images_t)) 229 | #pred_target1_flip = self.interp_target(pred_target1_flip) 230 | #pred_target2_flip = self.interp_target(pred_target2_flip) 231 | mean_pred = self.sm(0.5*pred_target1 + pred_target2) #+ self.sm(fliplr(0.5*pred_target1_flip + pred_target2_flip)) ) /2 232 | loss_kl = ( self.kl_loss(self.log_sm(pred_target2) , mean_pred) + self.kl_loss(self.log_sm(pred_target1) , mean_pred))/(n*h*w) 233 | #loss_kl = (self.kl_loss(self.log_sm(pred_target2) , self.sm(pred_target1) ) ) / (n*h*w) + (self.kl_loss(self.log_sm(pred_target1) , self.sm(pred_target2)) ) / (n*h*w) 234 | print(loss_kl) 235 | loss += self.lambda_kl_target * loss_kl 236 | 237 | if self.fp16: 238 | with amp.scale_loss(loss, self.gen_opt) as scaled_loss: 239 | scaled_loss.backward() 240 | else: 241 | loss.backward() 242 | self.gen_opt.step() 243 | 244 | val_loss = self.seg_loss(pred_target2, labels_t) 245 | 246 | return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss 247 | 248 | def dis_update(self, pred1, pred2, pred_target1, pred_target2): 249 | self.dis1_opt.zero_grad() 250 | self.dis2_opt.zero_grad() 251 | pred1 = pred1.detach() 252 | pred2 = pred2.detach() 253 | pred_target1 = pred_target1.detach() 254 | pred_target2 = pred_target2.detach() 255 | 256 | if self.multi_gpu: 257 | loss_D1, reg1 = self.D1.module.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 258 | loss_D2, reg2 = self.D2.module.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 259 | else: 260 | loss_D1, reg1 = self.D1.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 261 | loss_D2, reg2 = self.D2.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 262 | 263 | loss = loss_D1 + loss_D2 264 | if self.fp16: 265 | with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss: 266 | scaled_loss.backward() 267 | else: 268 | loss.backward() 269 | 270 | self.dis1_opt.step() 271 | self.dis2_opt.step() 272 | return loss_D1, loss_D2 273 | -------------------------------------------------------------------------------- /trainer_ms_variance.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.utils import data, model_zoo 3 | import torch.optim as optim 4 | import torch.nn.functional as F 5 | from model.deeplab_multi import DeeplabMulti 6 | from model.discriminator import FCDiscriminator 7 | from model.ms_discriminator import MsImageDis 8 | import torch 9 | import torch.nn.init as init 10 | import copy 11 | import numpy as np 12 | #fp16 13 | try: 14 | import apex 15 | from apex import amp 16 
| from apex.fp16_utils import * 17 | except ImportError: 18 | print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') 19 | 20 | def weights_init(init_type='gaussian'): 21 | def init_fun(m): 22 | classname = m.__class__.__name__ 23 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 24 | # print m.__class__.__name__ 25 | if init_type == 'gaussian': 26 | init.normal_(m.weight.data, 0.0, 0.02) 27 | elif init_type == 'xavier': 28 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 29 | elif init_type == 'kaiming': 30 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 31 | elif init_type == 'orthogonal': 32 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 33 | elif init_type == 'default': 34 | pass 35 | else: 36 | assert 0, "Unsupported initialization: {}".format(init_type) 37 | if hasattr(m, 'bias') and m.bias is not None: 38 | init.constant_(m.bias.data, 0.0) 39 | 40 | return init_fun 41 | 42 | def train_bn(m): 43 | classname = m.__class__.__name__ 44 | if classname.find('BatchNorm') != -1: 45 | m.train() 46 | 47 | def inplace_relu(m): 48 | classname = m.__class__.__name__ 49 | if classname.find('ReLU') != -1: 50 | m.inplace=True 51 | 52 | def fliplr(img): 53 | '''flip horizontal''' 54 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 55 | img_flip = img.index_select(3,inv_idx) 56 | return img_flip 57 | 58 | class AD_Trainer(nn.Module): 59 | def __init__(self, args): 60 | super(AD_Trainer, self).__init__() 61 | self.fp16 = args.fp16 62 | self.class_balance = args.class_balance 63 | self.often_balance = args.often_balance 64 | self.num_classes = args.num_classes 65 | self.class_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 66 | self.often_weight = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 67 | self.multi_gpu = args.multi_gpu 68 | self.only_hard_label = args.only_hard_label 69 | if args.model == 'DeepLab': 70 | self.G = DeeplabMulti(num_classes=args.num_classes, use_se = args.use_se, train_bn = args.train_bn, norm_style = args.norm_style, droprate = args.droprate) 71 | if args.restore_from[:4] == 'http' : 72 | saved_state_dict = model_zoo.load_url(args.restore_from) 73 | else: 74 | saved_state_dict = torch.load(args.restore_from) 75 | 76 | new_params = self.G.state_dict().copy() 77 | for i in saved_state_dict: 78 | # Scale.layer5.conv2d_list.3.weight 79 | i_parts = i.split('.') 80 | # print i_parts 81 | if args.restore_from[:4] == 'http' : 82 | if i_parts[1] !='fc' and i_parts[1] !='layer5': 83 | new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 84 | print('%s is loaded from pre-trained weight.\n'%i_parts[1:]) 85 | else: 86 | #new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 87 | if i_parts[0] =='module': 88 | new_params['.'.join(i_parts[1:])] = saved_state_dict[i] 89 | print('%s is loaded from pre-trained weight.\n'%i_parts[1:]) 90 | else: 91 | new_params['.'.join(i_parts[0:])] = saved_state_dict[i] 92 | print('%s is loaded from pre-trained weight.\n'%i_parts[0:]) 93 | self.G.load_state_dict(new_params) 94 | 95 | self.D1 = MsImageDis(input_dim = args.num_classes).cuda() 96 | self.D2 = MsImageDis(input_dim = args.num_classes).cuda() 97 | self.D1.apply(weights_init('gaussian')) 98 | self.D2.apply(weights_init('gaussian')) 99 | 100 | if self.multi_gpu and args.sync_bn: 101 | print("using apex synced BN") 102 | self.G = 
apex.parallel.convert_syncbn_model(self.G) 103 | 104 | self.gen_opt = optim.SGD(self.G.optim_parameters(args), 105 | lr=args.learning_rate, momentum=args.momentum, nesterov=True, weight_decay=args.weight_decay) 106 | 107 | self.dis1_opt = optim.Adam(self.D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 108 | 109 | self.dis2_opt = optim.Adam(self.D2.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99)) 110 | 111 | self.seg_loss = nn.CrossEntropyLoss(ignore_index=255) 112 | self.kl_loss = nn.KLDivLoss(size_average=False) 113 | self.sm = torch.nn.Softmax(dim = 1) 114 | self.log_sm = torch.nn.LogSoftmax(dim = 1) 115 | self.G = self.G.cuda() 116 | self.D1 = self.D1.cuda() 117 | self.D2 = self.D2.cuda() 118 | self.interp = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True) 119 | self.interp_target = nn.Upsample(size= args.crop_size, mode='bilinear', align_corners=True) 120 | self.lambda_seg = args.lambda_seg 121 | self.max_value = args.max_value 122 | self.lambda_me_target = args.lambda_me_target 123 | self.lambda_kl_target = args.lambda_kl_target 124 | self.lambda_adv_target1 = args.lambda_adv_target1 125 | self.lambda_adv_target2 = args.lambda_adv_target2 126 | self.class_w = torch.FloatTensor(self.num_classes).zero_().cuda() + 1 127 | if args.fp16: 128 | # Name the FP16_Optimizer instance to replace the existing optimizer 129 | assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 130 | self.G, self.gen_opt = amp.initialize(self.G, self.gen_opt, opt_level="O1") 131 | self.D1, self.dis1_opt = amp.initialize(self.D1, self.dis1_opt, opt_level="O1") 132 | self.D2, self.dis2_opt = amp.initialize(self.D2, self.dis2_opt, opt_level="O1") 133 | 134 | def update_class_criterion(self, labels): 135 | weight = torch.FloatTensor(self.num_classes).zero_().cuda() 136 | weight += 1 137 | count = torch.FloatTensor(self.num_classes).zero_().cuda() 138 | often = torch.FloatTensor(self.num_classes).zero_().cuda() 139 | often += 1 140 | print(labels.shape) 141 | n, h, w = labels.shape 142 | for i in range(self.num_classes): 143 | count[i] = torch.sum(labels==i) 144 | if count[i] < 64*64*n: #small objective 145 | weight[i] = self.max_value 146 | if self.often_balance: 147 | often[count == 0] = self.max_value 148 | 149 | self.often_weight = 0.9 * self.often_weight + 0.1 * often 150 | self.class_weight = weight * self.often_weight 151 | print(self.class_weight) 152 | return nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255) 153 | 154 | def update_label(self, labels, prediction): 155 | criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none') 156 | #criterion = self.seg_loss 157 | loss = criterion(prediction, labels) 158 | print('original loss: %f'% self.seg_loss(prediction, labels) ) 159 | #mm = torch.median(loss) 160 | loss_data = loss.data.cpu().numpy() 161 | mm = np.percentile(loss_data[:], self.only_hard_label) 162 | #print(m.data.cpu(), mm) 163 | labels[loss < mm] = 255 164 | return labels 165 | 166 | def update_variance(self, labels, pred1, pred2): 167 | criterion = nn.CrossEntropyLoss(weight = self.class_weight, ignore_index=255, reduction = 'none') 168 | kl_distance = nn.KLDivLoss( reduction = 'none') 169 | loss = criterion(pred1, labels) 170 | 171 | #n, h, w = labels.shape 172 | #labels_onehot = torch.zeros(n, self.num_classes, h, w) 173 | #labels_onehot = labels_onehot.cuda() 174 | #labels_onehot.scatter_(1, labels.view(n,1,h,w), 1) 175 | 176 | variance = 
torch.sum(kl_distance(self.log_sm(pred1),self.sm(pred2)), dim=1) 177 | exp_variance = torch.exp(-variance) 178 | #variance = torch.log( 1 + (torch.mean((pred1-pred2)**2, dim=1))) 179 | #torch.mean( kl_distance(self.log_sm(pred1),pred2), dim=1) + 1e-6 180 | print(variance.shape) 181 | print('variance mean: %.4f'%torch.mean(exp_variance[:])) 182 | print('variance min: %.4f'%torch.min(exp_variance[:])) 183 | print('variance max: %.4f'%torch.max(exp_variance[:])) 184 | #loss = torch.mean(loss/variance) + torch.mean(variance) 185 | loss = torch.mean(loss*exp_variance) + torch.mean(variance) 186 | return loss 187 | 188 | def gen_update(self, images, images_t, labels, labels_t, i_iter): 189 | self.gen_opt.zero_grad() 190 | 191 | pred1, pred2 = self.G(images) 192 | pred1 = self.interp(pred1) 193 | pred2 = self.interp(pred2) 194 | 195 | if self.class_balance: 196 | self.seg_loss = self.update_class_criterion(labels) 197 | 198 | loss_seg1 = self.update_variance(labels, pred1, pred2) 199 | loss_seg2 = self.update_variance(labels, pred2, pred1) 200 | 201 | loss = loss_seg2 + self.lambda_seg * loss_seg1 202 | 203 | # target 204 | pred_target1, pred_target2 = self.G(images_t) 205 | pred_target1 = self.interp_target(pred_target1) 206 | pred_target2 = self.interp_target(pred_target2) 207 | 208 | if self.multi_gpu: 209 | #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0: 210 | loss_adv_target1 = self.D1.module.calc_gen_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1) ) 211 | loss_adv_target2 = self.D2.module.calc_gen_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1) ) 212 | #else: 213 | # print('skip the discriminator') 214 | # loss_adv_target1, loss_adv_target2 = 0, 0 215 | else: 216 | #if self.lambda_adv_target1 > 0 and self.lambda_adv_target2 > 0: 217 | loss_adv_target1 = self.D1.calc_gen_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1) ) 218 | loss_adv_target2 = self.D2.calc_gen_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1) ) 219 | #else: 220 | #loss_adv_target1 = 0.0 #torch.tensor(0).cuda() 221 | #loss_adv_target2 = 0.0 #torch.tensor(0).cuda() 222 | 223 | loss += self.lambda_adv_target1 * loss_adv_target1 + self.lambda_adv_target2 * loss_adv_target2 224 | 225 | 226 | if i_iter < 15000: 227 | self.lambda_kl_target_copy = 0 228 | self.lambda_me_target_copy = 0 229 | else: 230 | self.lambda_kl_target_copy = self.lambda_kl_target 231 | self.lambda_me_target_copy = self.lambda_me_target 232 | 233 | loss_me = 0.0 234 | if self.lambda_me_target_copy>0: 235 | confidence_map = torch.sum( self.sm(0.5*pred_target1 + pred_target2)**2, 1).detach() 236 | loss_me = -torch.mean( confidence_map * torch.sum( self.sm(0.5*pred_target1 + pred_target2) * self.log_sm(0.5*pred_target1 + pred_target2), 1) ) 237 | loss += self.lambda_me_target * loss_me 238 | 239 | loss_kl = 0.0 240 | if self.lambda_kl_target_copy>0: 241 | n, c, h, w = pred_target1.shape 242 | with torch.no_grad(): 243 | #pred_target1_flip, pred_target2_flip = self.G(fliplr(images_t)) 244 | #pred_target1_flip = self.interp_target(pred_target1_flip) 245 | #pred_target2_flip = self.interp_target(pred_target2_flip) 246 | mean_pred = self.sm(0.5*pred_target1 + pred_target2) #+ self.sm(fliplr(0.5*pred_target1_flip + pred_target2_flip)) ) /2 247 | loss_kl = ( self.kl_loss(self.log_sm(pred_target2) , mean_pred) + self.kl_loss(self.log_sm(pred_target1) , mean_pred))/(n*h*w) 248 | #loss_kl = (self.kl_loss(self.log_sm(pred_target2) , self.sm(pred_target1) ) ) / (n*h*w) + (self.kl_loss(self.log_sm(pred_target1) , 
self.sm(pred_target2)) ) / (n*h*w) 249 | print(loss_kl) 250 | loss += self.lambda_kl_target * loss_kl 251 | 252 | if self.fp16: 253 | with amp.scale_loss(loss, self.gen_opt) as scaled_loss: 254 | scaled_loss.backward() 255 | else: 256 | loss.backward() 257 | self.gen_opt.step() 258 | 259 | val_loss = self.seg_loss(pred_target2, labels_t) 260 | 261 | return loss_seg1, loss_seg2, loss_adv_target1, loss_adv_target2, loss_me, loss_kl, pred1, pred2, pred_target1, pred_target2, val_loss 262 | 263 | def dis_update(self, pred1, pred2, pred_target1, pred_target2): 264 | self.dis1_opt.zero_grad() 265 | self.dis2_opt.zero_grad() 266 | pred1 = pred1.detach() 267 | pred2 = pred2.detach() 268 | pred_target1 = pred_target1.detach() 269 | pred_target2 = pred_target2.detach() 270 | 271 | if self.multi_gpu: 272 | loss_D1, reg1 = self.D1.module.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 273 | loss_D2, reg2 = self.D2.module.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 274 | else: 275 | loss_D1, reg1 = self.D1.calc_dis_loss( self.D1, input_fake = F.softmax(pred_target1, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 276 | loss_D2, reg2 = self.D2.calc_dis_loss( self.D2, input_fake = F.softmax(pred_target2, dim=1), input_real = F.softmax(0.5*pred1 + pred2, dim=1) ) 277 | 278 | loss = loss_D1 + loss_D2 279 | if self.fp16: 280 | with amp.scale_loss(loss, [self.dis1_opt, self.dis2_opt]) as scaled_loss: 281 | scaled_loss.backward() 282 | else: 283 | loss.backward() 284 | 285 | self.dis1_opt.step() 286 | self.dis2_opt.step() 287 | return loss_D1, loss_D2 288 | -------------------------------------------------------------------------------- /try_run.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | while(True): 4 | os.system('python train_ms.py --snapshot-dir ./snapshots/SE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5_aug_fp16 --drop 0.1 --warm-up 5000 --batch-size 2 --learning-rate 2e-4 --crop-size 1024,512 --lambda-seg 0.5 --lambda-adv-target1 0.0002 --lambda-adv-target2 0.001 --lambda-me-target 0 --lambda-kl-target 0.1 --norm-style gn --class-balance --only-hard-label 80 --max-value 7 --gpu-ids 0,1 --often-balance --use-se --autoaug --fp16') 5 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/layumi/Seg-Uncertainty/6fce9eae141c2c0592b3e7c1b3e5f8ee7b1ce9a6/utils/__init__.py -------------------------------------------------------------------------------- /utils/autoaugment.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageEnhance, ImageOps 2 | import numpy as np 3 | import random 4 | 5 | 6 | class ImageNetPolicy(object): 7 | """ Randomly choose one of the best 24 Sub-policies on ImageNet. 
8 | 9 | Example: 10 | >>> policy = ImageNetPolicy() 11 | >>> transformed = policy(image) 12 | 13 | Example as a PyTorch Transform: 14 | >>> transform=transforms.Compose([ 15 | >>> transforms.Resize(256), 16 | >>> ImageNetPolicy(), 17 | >>> transforms.ToTensor()]) 18 | """ 19 | def __init__(self, fillcolor=(128, 128, 128)): 20 | self.policies = [ 21 | SubPolicy(0.4, "posterize", 8, 0.6, "rotate", 9, fillcolor), 22 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 23 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor), 24 | SubPolicy(0.6, "posterize", 7, 0.6, "posterize", 6, fillcolor), 25 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 26 | 27 | SubPolicy(0.4, "equalize", 4, 0.8, "rotate", 8, fillcolor), 28 | SubPolicy(0.6, "solarize", 3, 0.6, "equalize", 7, fillcolor), 29 | SubPolicy(0.8, "posterize", 5, 1.0, "equalize", 2, fillcolor), 30 | SubPolicy(0.2, "rotate", 3, 0.6, "solarize", 8, fillcolor), 31 | SubPolicy(0.6, "equalize", 8, 0.4, "posterize", 6, fillcolor), 32 | 33 | SubPolicy(0.8, "rotate", 8, 0.4, "color", 0, fillcolor), 34 | SubPolicy(0.4, "rotate", 9, 0.6, "equalize", 2, fillcolor), 35 | SubPolicy(0.0, "equalize", 7, 0.8, "equalize", 8, fillcolor), 36 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 37 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 38 | 39 | SubPolicy(0.8, "rotate", 8, 1.0, "color", 2, fillcolor), 40 | SubPolicy(0.8, "color", 8, 0.8, "solarize", 7, fillcolor), 41 | SubPolicy(0.4, "sharpness", 7, 0.6, "invert", 8, fillcolor), 42 | SubPolicy(0.6, "shearX", 5, 1.0, "equalize", 9, fillcolor), 43 | SubPolicy(0.4, "color", 0, 0.6, "equalize", 3, fillcolor), 44 | 45 | SubPolicy(0.4, "equalize", 7, 0.2, "solarize", 4, fillcolor), 46 | SubPolicy(0.6, "solarize", 5, 0.6, "autocontrast", 5, fillcolor), 47 | SubPolicy(0.6, "invert", 4, 1.0, "equalize", 8, fillcolor), 48 | SubPolicy(0.6, "color", 4, 1.0, "contrast", 8, fillcolor), 49 | SubPolicy(0.8, "equalize", 8, 0.6, "equalize", 3, fillcolor) 50 | ] 51 | 52 | 53 | def __call__(self, img): 54 | policy_idx = random.randint(0, len(self.policies) - 1) 55 | return self.policies[policy_idx](img) 56 | 57 | def __repr__(self): 58 | return "AutoAugment ImageNet Policy" 59 | 60 | 61 | class CIFAR10Policy(object): 62 | """ Randomly choose one of the best 25 Sub-policies on CIFAR10. 
63 | 64 | Example: 65 | >>> policy = CIFAR10Policy() 66 | >>> transformed = policy(image) 67 | 68 | Example as a PyTorch Transform: 69 | >>> transform=transforms.Compose([ 70 | >>> transforms.Resize(256), 71 | >>> CIFAR10Policy(), 72 | >>> transforms.ToTensor()]) 73 | """ 74 | def __init__(self, fillcolor=(128, 128, 128)): 75 | self.policies = [ 76 | SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor), 77 | SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor), 78 | SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor), 79 | SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor), 80 | SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor), 81 | 82 | SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor), 83 | SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor), 84 | SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor), 85 | SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor), 86 | SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor), 87 | 88 | SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor), 89 | SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor), 90 | SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor), 91 | SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor), 92 | SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor), 93 | 94 | SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor), 95 | SubPolicy(0.2, "equalize", 8, 0.8, "equalize", 4, fillcolor), 96 | SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor), 97 | SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor), 98 | SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor), 99 | 100 | SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor), 101 | SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor), 102 | SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor), 103 | SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor), 104 | SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor) 105 | ] 106 | 107 | 108 | def __call__(self, img): 109 | policy_idx = random.randint(0, len(self.policies) - 1) 110 | return self.policies[policy_idx](img) 111 | 112 | def __repr__(self): 113 | return "AutoAugment CIFAR10 Policy" 114 | 115 | 116 | class SVHNPolicy(object): 117 | """ Randomly choose one of the best 25 Sub-policies on SVHN. 
118 | 119 | Example: 120 | >>> policy = SVHNPolicy() 121 | >>> transformed = policy(image) 122 | 123 | Example as a PyTorch Transform: 124 | >>> transform=transforms.Compose([ 125 | >>> transforms.Resize(256), 126 | >>> SVHNPolicy(), 127 | >>> transforms.ToTensor()]) 128 | """ 129 | def __init__(self, fillcolor=(128, 128, 128)): 130 | self.policies = [ 131 | SubPolicy(0.9, "shearX", 4, 0.2, "invert", 3, fillcolor), 132 | SubPolicy(0.9, "shearY", 8, 0.7, "invert", 5, fillcolor), 133 | SubPolicy(0.6, "equalize", 5, 0.6, "solarize", 6, fillcolor), 134 | SubPolicy(0.9, "invert", 3, 0.6, "equalize", 3, fillcolor), 135 | SubPolicy(0.6, "equalize", 1, 0.9, "rotate", 3, fillcolor), 136 | 137 | SubPolicy(0.9, "shearX", 4, 0.8, "autocontrast", 3, fillcolor), 138 | SubPolicy(0.9, "shearY", 8, 0.4, "invert", 5, fillcolor), 139 | SubPolicy(0.9, "shearY", 5, 0.2, "solarize", 6, fillcolor), 140 | SubPolicy(0.9, "invert", 6, 0.8, "autocontrast", 1, fillcolor), 141 | SubPolicy(0.6, "equalize", 3, 0.9, "rotate", 3, fillcolor), 142 | 143 | SubPolicy(0.9, "shearX", 4, 0.3, "solarize", 3, fillcolor), 144 | SubPolicy(0.8, "shearY", 8, 0.7, "invert", 4, fillcolor), 145 | SubPolicy(0.9, "equalize", 5, 0.6, "translateY", 6, fillcolor), 146 | SubPolicy(0.9, "invert", 4, 0.6, "equalize", 7, fillcolor), 147 | SubPolicy(0.3, "contrast", 3, 0.8, "rotate", 4, fillcolor), 148 | 149 | SubPolicy(0.8, "invert", 5, 0.0, "translateY", 2, fillcolor), 150 | SubPolicy(0.7, "shearY", 6, 0.4, "solarize", 8, fillcolor), 151 | SubPolicy(0.6, "invert", 4, 0.8, "rotate", 4, fillcolor), 152 | SubPolicy(0.3, "shearY", 7, 0.9, "translateX", 3, fillcolor), 153 | SubPolicy(0.1, "shearX", 6, 0.6, "invert", 5, fillcolor), 154 | 155 | SubPolicy(0.7, "solarize", 2, 0.6, "translateY", 7, fillcolor), 156 | SubPolicy(0.8, "shearY", 4, 0.8, "invert", 8, fillcolor), 157 | SubPolicy(0.7, "shearX", 9, 0.8, "translateY", 3, fillcolor), 158 | SubPolicy(0.8, "shearY", 5, 0.7, "autocontrast", 3, fillcolor), 159 | SubPolicy(0.7, "shearX", 2, 0.1, "invert", 5, fillcolor) 160 | ] 161 | 162 | 163 | def __call__(self, img): 164 | policy_idx = random.randint(0, len(self.policies) - 1) 165 | return self.policies[policy_idx](img) 166 | 167 | def __repr__(self): 168 | return "AutoAugment SVHN Policy" 169 | 170 | 171 | class SubPolicy(object): 172 | def __init__(self, p1, operation1, magnitude_idx1, p2, operation2, magnitude_idx2, fillcolor=(128, 128, 128)): 173 | ranges = { 174 | "shearX": np.linspace(0, 0.3, 10), 175 | "shearY": np.linspace(0, 0.3, 10), 176 | "translateX": np.linspace(0, 150 / 331, 10), 177 | "translateY": np.linspace(0, 150 / 331, 10), 178 | "rotate": np.linspace(0, 30, 10), 179 | "color": np.linspace(0.0, 0.9, 10), 180 | "posterize": np.round(np.linspace(8, 4, 10), 0).astype(np.int), 181 | "solarize": np.linspace(256, 0, 10), 182 | "contrast": np.linspace(0.0, 0.9, 10), 183 | "sharpness": np.linspace(0.0, 0.9, 10), 184 | "brightness": np.linspace(0.0, 0.9, 10), 185 | "autocontrast": [0] * 10, 186 | "equalize": [0] * 10, 187 | "invert": [0] * 10 188 | } 189 | 190 | # from https://stackoverflow.com/questions/5252170/specify-image-filling-color-when-rotating-in-python-with-pil-and-setting-expand 191 | def rotate_with_fill(img, magnitude): 192 | rot = img.convert("RGBA").rotate(magnitude) 193 | return Image.composite(rot, Image.new("RGBA", rot.size, (128,) * 4), rot).convert(img.mode) 194 | 195 | func = { 196 | "shearX": lambda img, magnitude: img.transform( 197 | img.size, Image.AFFINE, (1, magnitude * random.choice([-1, 1]), 0, 0, 1, 0), 198 | 
Image.BICUBIC, fillcolor=fillcolor), 199 | "shearY": lambda img, magnitude: img.transform( 200 | img.size, Image.AFFINE, (1, 0, 0, magnitude * random.choice([-1, 1]), 1, 0), 201 | Image.BICUBIC, fillcolor=fillcolor), 202 | "translateX": lambda img, magnitude: img.transform( 203 | img.size, Image.AFFINE, (1, 0, magnitude * img.size[0] * random.choice([-1, 1]), 0, 1, 0), 204 | fillcolor=fillcolor), 205 | "translateY": lambda img, magnitude: img.transform( 206 | img.size, Image.AFFINE, (1, 0, 0, 0, 1, magnitude * img.size[1] * random.choice([-1, 1])), 207 | fillcolor=fillcolor), 208 | "rotate": lambda img, magnitude: rotate_with_fill(img, magnitude), 209 | # "rotate": lambda img, magnitude: img.rotate(magnitude * random.choice([-1, 1])), 210 | "color": lambda img, magnitude: ImageEnhance.Color(img).enhance(1 + magnitude * random.choice([-1, 1])), 211 | "posterize": lambda img, magnitude: ImageOps.posterize(img, magnitude), 212 | "solarize": lambda img, magnitude: ImageOps.solarize(img, magnitude), 213 | "contrast": lambda img, magnitude: ImageEnhance.Contrast(img).enhance( 214 | 1 + magnitude * random.choice([-1, 1])), 215 | "sharpness": lambda img, magnitude: ImageEnhance.Sharpness(img).enhance( 216 | 1 + magnitude * random.choice([-1, 1])), 217 | "brightness": lambda img, magnitude: ImageEnhance.Brightness(img).enhance( 218 | 1 + magnitude * random.choice([-1, 1])), 219 | "autocontrast": lambda img, magnitude: ImageOps.autocontrast(img), 220 | "equalize": lambda img, magnitude: ImageOps.equalize(img), 221 | "invert": lambda img, magnitude: ImageOps.invert(img) 222 | } 223 | 224 | # self.name = "{}_{:.2f}_and_{}_{:.2f}".format( 225 | # operation1, ranges[operation1][magnitude_idx1], 226 | # operation2, ranges[operation2][magnitude_idx2]) 227 | self.p1 = p1 228 | self.operation1 = func[operation1] 229 | self.magnitude1 = ranges[operation1][magnitude_idx1] 230 | self.p2 = p2 231 | self.operation2 = func[operation2] 232 | self.magnitude2 = ranges[operation2][magnitude_idx2] 233 | 234 | 235 | def __call__(self, img): 236 | if random.random() < self.p1: img = self.operation1(img, self.magnitude1) 237 | if random.random() < self.p2: img = self.operation2(img, self.magnitude2) 238 | return img 239 | -------------------------------------------------------------------------------- /utils/clear_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | root = '../snapshots/' 4 | nn = [] 5 | for f in os.listdir(root): 6 | for ff in os.listdir(root+f): 7 | dir_name = root+f 8 | for fff in os.listdir(dir_name): 9 | if fff =='opts.yaml': 10 | continue 11 | try: 12 | if int(fff[5:10])<25000 or int(fff[5:10])==30000 or int(fff[5:10])==35000 or int(fff[5:10])==40000 or int(fff[5:10])==45000 or int(fff[5:10])>70000: 13 | dst = dir_name+'/'+fff 14 | print(dst) 15 | os.remove(dst) 16 | except: 17 | continue 18 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | 7 | class CrossEntropy2d(nn.Module): 8 | 9 | def __init__(self, size_average=True, ignore_label=255): 10 | super(CrossEntropy2d, self).__init__() 11 | self.size_average = size_average 12 | self.ignore_label = ignore_label 13 | 14 | def forward(self, predict, target, weight=None): 15 | """ 16 | Args: 17 | predict:(n, c, h, w) 18 | target:(n, h, w) 19 | 
weight (Tensor, optional): a manual rescaling weight given to each class. 20 | If given, has to be a Tensor of size "nclasses" 21 | """ 22 | assert not target.requires_grad 23 | assert predict.dim() == 4 24 | assert target.dim() == 3 25 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 26 | assert predict.size(2) == target.size(1), "{0} vs {1} ".format(predict.size(2), target.size(1)) 27 | assert predict.size(3) == target.size(2), "{0} vs {1} ".format(predict.size(3), target.size(3)) 28 | n, c, h, w = predict.size() 29 | target_mask = (target >= 0) * (target != self.ignore_label) 30 | target = target[target_mask] 31 | if not target.data.dim(): 32 | return Variable(torch.zeros(1)) 33 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 34 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 35 | loss = F.cross_entropy(predict, target, weight=weight, size_average=self.size_average) 36 | return loss 37 | -------------------------------------------------------------------------------- /utils/tool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import time 3 | 4 | def lr_poly(base_lr, iter, max_iter, power): 5 | return base_lr * ((1 - float(iter) / max_iter) ** (power)) 6 | 7 | def lr_step(base_lr, iter): 8 | lr = base_lr 9 | if iter>40000: 10 | lr = base_lr * 0.5 11 | if iter>60000: 12 | lr = base_lr * 0.5 * 0.5 13 | if iter>70000: 14 | lr = base_lr * 0.5 * 0.5 * 0.5 15 | return lr 16 | 17 | def adjust_learning_rate(optimizer, i_iter, args): 18 | if i_iter < args.warm_up: 19 | lr = args.learning_rate * (0.1 + 0.9 * i_iter / args.warm_up) 20 | else: 21 | lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power) 22 | #lr = lr_step(args.learning_rate, i_iter) 23 | optimizer.param_groups[0]['lr'] = lr 24 | print('-------lr_G: %f-------'%lr) 25 | if len(optimizer.param_groups) > 1: 26 | optimizer.param_groups[1]['lr'] = lr * 10 27 | 28 | 29 | def adjust_learning_rate_D(optimizer, i_iter, args): 30 | lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power) 31 | optimizer.param_groups[0]['lr'] = lr 32 | #if len(optimizer.param_groups) > 1: 33 | # optimizer.param_groups[1]['lr'] = lr * 10 34 | 35 | 36 | def fliplr(img): 37 | '''flip horizontal''' 38 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 39 | img_flip = img.index_select(3,inv_idx) 40 | return img_flip 41 | 42 | class Timer: 43 | def __init__(self, msg): 44 | self.msg = msg 45 | self.start_time = None 46 | 47 | def __enter__(self): 48 | self.start_time = time.time() 49 | 50 | def __exit__(self, exc_type, exc_value, exc_tb): 51 | print(self.msg % (time.time() - self.start_time)) 52 | -------------------------------------------------------------------------------- /visualize_noisy_label.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import json 5 | 6 | devkit_dir = 'dataset/cityscapes_list' 7 | with open(devkit_dir+'/info.json', 'r') as fp: 8 | info = json.load(fp) 9 | num_classes = np.int(info['classes']) 10 | name_classes = np.array(info['label'], dtype=np.str) 11 | mapping = np.array(info['label2train'], dtype=np.int) 12 | 13 | def label_mapping(input, mapping): 14 | output = np.copy(input) 15 | for ind in range(len(mapping)): 16 | output[input == mapping[ind][0]] = mapping[ind][1] 17 | return np.array(output, 
dtype=np.int64) 18 | 19 | img1 = 'result/cityscapesSE_GN_batchsize2_1024x512_pp_ms_me0_classbalance7_kl0.1_lr2_drop0.1_seg0.5/frankfurt_000001_005898_leftImg8bit.png' 20 | img2 = 'data/Cityscapes/data/gtFine/val/frankfurt/frankfurt_000001_005898_gtFine_labelIds.png' 21 | 22 | img1 = np.asarray(Image.open(img1)) 23 | img2 = np.asarray(Image.open(img2)) 24 | img2 = label_mapping(img2, mapping) 25 | 26 | print(img1) 27 | print(img2) 28 | output = np.abs(img1-img2) 29 | output[output>200] = 0 30 | output[output>1] = 1 31 | 32 | fig = plt.figure() 33 | plt.axis('off') 34 | heatmap = plt.imshow(output, cmap='viridis') 35 | fig.colorbar(heatmap) 36 | fig.savefig('label_heatmap.png') 37 | --------------------------------------------------------------------------------
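
Note on the core rectified loss (editorial sketch, not a file from this repository): `trainer_ms_variance.py` weights the per-pixel segmentation loss by the prediction variance between the two classifier heads (see `update_variance` above). The snippet below is a minimal, self-contained illustration of that step; the function name `rectified_seg_loss` and the tensor shapes are assumptions made for the example only.

import torch
import torch.nn as nn

def rectified_seg_loss(pred1, pred2, labels, ignore_index=255):
    # pred1, pred2: (N, C, H, W) logits from the two classifier heads; labels: (N, H, W) ints.
    # Per-pixel cross-entropy on the first prediction (no reduction yet).
    ce = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction='none')(pred1, labels)
    # Prediction variance: channel-summed KL divergence between the two heads.
    kl = nn.KLDivLoss(reduction='none')(torch.log_softmax(pred1, dim=1),
                                        torch.softmax(pred2, dim=1))
    variance = kl.sum(dim=1)  # (N, H, W)
    # Down-weight pixels where the heads disagree and regularize the variance itself,
    # mirroring loss = mean(loss * exp(-variance)) + mean(variance) in update_variance.
    return (ce * torch.exp(-variance)).mean() + variance.mean()

The exponential weighting keeps the per-pixel factor in (0, 1], so consistent, confident pixels contribute fully while inconsistent ones are softly discounted rather than discarded, and the added mean-variance term discourages the trivial solution of making the two heads disagree everywhere.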