├── LICENSE ├── README.md ├── assets ├── model_architecture.png └── results.png ├── config.py ├── datasets ├── __init__.py ├── bdd100k.py ├── camvid.py ├── cityscapes.py ├── cityscapes_labels.py ├── gtav.py ├── kitti.py ├── mapillary.py ├── multi_loader.py ├── nullloader.py ├── sampler.py ├── synthia.py └── uniform.py ├── eval.py ├── infer.py ├── loss.py ├── network ├── Mobilenet.py ├── Resnet.py ├── SEresnext.py ├── Shufflenet.py ├── __init__.py ├── __pycache__ │ ├── Mobilenet.cpython-36.pyc │ ├── Mobilenet.cpython-37.pyc │ ├── Resnet.cpython-36.pyc │ ├── Resnet.cpython-37.pyc │ ├── Shufflenet.cpython-36.pyc │ ├── Shufflenet.cpython-37.pyc │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── cov_settings.cpython-36.pyc │ ├── cov_settings.cpython-37.pyc │ ├── cwcl.cpython-36.pyc │ ├── cwcl.cpython-37.pyc │ ├── deepv3.cpython-36.pyc │ ├── deepv3.cpython-37.pyc │ ├── edge_contrast.cpython-36.pyc │ ├── edge_contrast_batch.cpython-36.pyc │ ├── edge_contrast_v1_opt.cpython-36.pyc │ ├── edge_contrast_v1_opt.cpython-37.pyc │ ├── edge_contrast_v2.cpython-36.pyc │ ├── instance_whitening.cpython-36.pyc │ ├── instance_whitening.cpython-37.pyc │ ├── mynn.cpython-36.pyc │ ├── mynn.cpython-37.pyc │ ├── pixel_nce.cpython-36.pyc │ ├── pixel_nce.cpython-37.pyc │ ├── pixel_nce_batch.cpython-36.pyc │ ├── sdcl.cpython-36.pyc │ ├── sdcl.cpython-37.pyc │ ├── sync_switchwhiten.cpython-36.pyc │ └── sync_switchwhiten.cpython-37.pyc ├── bn_helper.py ├── cov_settings.py ├── cwcl.py ├── deepv3.py ├── instance_whitening.py ├── mynn.py ├── sdcl.py ├── switchwhiten.py ├── sync_switchwhiten.py └── wider_resnet.py ├── optimizer.py ├── scripts ├── blindnet_infer_r50os16.sh ├── blindnet_train_r50os16_gtav.sh └── blindnet_valid_r50os16_gtav.sh ├── split_data ├── gtav_split_test.txt ├── gtav_split_train.txt ├── gtav_split_val.txt ├── synthia_split_train.txt └── synthia_split_val.txt ├── train.py ├── transforms ├── __init__.py ├── __pycache__ │ ├── joint_transforms.cpython-36.pyc │ ├── joint_transforms.cpython-37.pyc │ ├── transforms.cpython-36.pyc │ └── transforms.cpython-37.pyc ├── joint_transforms.py └── transforms.py ├── utils ├── __init__.py ├── attr_dict.py ├── misc.py └── my_data_parallel.py └── valid.py /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2024, root0yang 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BlindNet (CVPR 2024) : Official Project Webpage 2 | This repository provides the official PyTorch implementation of the following paper: 3 | > [**Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning**](https://arxiv.org/abs/2403.06122)
4 | > Woo-Jin Ahn, Geun-Yeong Yang, Hyun-Duck Choi, Myo-Taeg Lim
5 | > Korea University, Chonnam National University 6 | 7 | > **Abstract:** 8 | > *Deep learning models for semantic segmentation often 9 | experience performance degradation when deployed to unseen target domains unidentified during the training phase. 10 | This is mainly due to variations in image texture (i.e. style) 11 | from different data sources. To tackle this challenge, existing domain generalized semantic segmentation (DGSS) 12 | methods attempt to remove style variations from the feature. However, these approaches struggle with the entanglement of style and content, which may lead to the unintentional removal of crucial content information, causing 13 | performance degradation. This study addresses this limitation by proposing BlindNet, a novel DGSS approach that 14 | blinds the style without external modules or datasets. The 15 | main idea behind our proposed approach is to alleviate the 16 | effect of style in the encoder whilst facilitating robust segmentation in the decoder. To achieve this, BlindNet comprises two key components: covariance alignment and semantic consistency contrastive learning. Specifically, the 17 | covariance alignment trains the encoder to uniformly recognize various styles and preserve the content information 18 | of the feature, rather than removing the style-sensitive factor. Meanwhile, semantic consistency contrastive learning 19 | enables the decoder to construct discriminative class embedding space and disentangles features that are vulnerable to misclassification. Through extensive experiments, 20 | our approach outperforms existing DGSS methods, exhibiting robustness and superior performance for semantic segmentation on unseen target domains.*
21 | ![Model architecture](assets/model_architecture.png) 22 | 
23 | ![Results](assets/results.png) 24 | 25 | 
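The abstract above describes two components, covariance alignment in the encoder and semantic consistency contrastive learning in the decoder; their actual implementations live in the modules under `network/`. The snippet below is only a minimal, illustrative sketch of the covariance-alignment idea (matching channel covariance between an image and a style-perturbed view of it); it is not the code used in this repository, and all names and shapes are assumptions.
```
# Illustrative sketch only; not the implementation in this repository (see network/).
import torch
import torch.nn.functional as F

def channel_covariance(feat):
    """feat: (B, C, H, W) -> per-image channel covariance of shape (B, C, C)."""
    b, c, h, w = feat.shape
    x = feat.reshape(b, c, h * w)
    x = x - x.mean(dim=2, keepdim=True)
    return torch.bmm(x, x.transpose(1, 2)) / (h * w - 1)

def covariance_alignment_loss(feat_clean, feat_styled):
    """Penalize covariance differences between an image and its style-perturbed view."""
    return F.mse_loss(channel_covariance(feat_styled), channel_covariance(feat_clean))

# Usage sketch (the encoder and the style augmentation are placeholders):
# f_clean = encoder(images)
# f_styled = encoder(color_jitter(images))
# loss = covariance_alignment_loss(f_clean, f_styled)
```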
26 | 27 | ## PyTorch Implementation 28 | 29 | Our PyTorch implementation is heavily derived from [RobustNet](https://github.com/shachoi/RobustNet) (CVPR 2021). If you use this code in your research, please also cite their work. 30 | [[link to license](https://github.com/shachoi/RobustNet/blob/main/LICENSE)] 31 | 32 | ### Installation 33 | Clone this repository. 34 | ``` 35 | git clone https://github.com/root0yang/BlindNet.git 36 | cd BlindNet 37 | ``` 38 | Install the following packages. 39 | ``` 40 | conda create --name blindnet python=3.6 41 | conda activate blindnet 42 | conda install pytorch==1.2.0 cudatoolkit==10.2 43 | conda install scipy==1.1.0 44 | conda install tqdm==4.46.0 45 | conda install scikit-image==0.16.2 46 | pip install tensorboardX==2.4 47 | pip install thop 48 | imageio_download_bin freeimage 49 | ``` 50 | 51 | ### How to Run BlindNet 52 | We evaluated the model on [Cityscapes](https://www.cityscapes-dataset.com/), [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/), [Synthia](https://synthia-dataset.net/downloads/) ([SYNTHIA-RAND-CITYSCAPES](http://synthia-dataset.net/download/808/)), [GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) and [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?pKey=2ix3yvnjy9fwqdzwum3t9g&lat=20&lng=0&z=1.5). 53 | 54 | We adopt the class uniform sampling proposed in [this paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhu_Improving_Semantic_Segmentation_via_Video_Propagation_and_Label_Relaxation_CVPR_2019_paper.pdf) to handle class imbalance problems. 55 | 56 | 
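The centroid extraction and epoch construction for this sampling live in `datasets/uniform.py`. The following is only a rough, self-contained sketch of the idea (tile each label map, record a centroid for every class present in a tile, then mix class-centred samples into each epoch); the function names and the simplified epoch mix are illustrative assumptions, not the repository code.
```
# Toy sketch of class uniform sampling; datasets/uniform.py contains the real version.
import random
from collections import defaultdict
import numpy as np

def image_class_centroids(image_fn, label_map, tile_size, num_classes):
    """For one image, record a centroid for every class present in each tile."""
    centroids = defaultdict(list)
    h, w = label_map.shape
    for y0 in range(0, h - tile_size + 1, tile_size):
        for x0 in range(0, w - tile_size + 1, tile_size):
            tile = label_map[y0:y0 + tile_size, x0:x0 + tile_size]
            for cls in np.unique(tile):
                if cls >= num_classes:          # skip the ignore label (e.g. 255)
                    continue
                ys, xs = np.nonzero(tile == cls)
                centroid = (x0 + int(xs.mean()), y0 + int(ys.mean()))
                centroids[int(cls)].append((image_fn, centroid, int(cls)))
    return centroids

def build_epoch(all_images, centroids, num_classes, uniform_pct=0.5):
    """Mix random images with class-centred samples so rare classes are seen."""
    n_uniform = int(len(all_images) * uniform_pct)
    epoch = random.sample(all_images, len(all_images) - n_uniform)
    per_class = max(1, n_uniform // num_classes)
    for cls in range(num_classes):
        if centroids[cls]:
            # centroid records tell the loader to crop around that point
            epoch += random.choices(centroids[cls], k=per_class)
    return epoch
```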
57 | 1. For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/ 58 | Unzip the files and arrange the directory structure as follows. 59 | ``` 60 | cityscapes 61 | └ leftImg8bit_trainvaltest 62 | └ leftImg8bit 63 | └ train 64 | └ val 65 | └ test 66 | └ gtFine_trainvaltest 67 | └ gtFine 68 | └ train 69 | └ val 70 | └ test 71 | ``` 72 | ``` 73 | bdd-100k 74 | └ images 75 | └ train 76 | └ val 77 | └ test 78 | └ labels 79 | └ train 80 | └ val 81 | ``` 82 | ``` 83 | mapillary 84 | └ training 85 | └ images 86 | └ labels 87 | └ validation 88 | └ images 89 | └ labels 90 | └ test 91 | └ images 92 | └ labels 93 | ``` 94 | 95 | 2. We used [GTAV_Split](https://download.visinf.tu-darmstadt.de/data/from_games/code/read_mapping.zip) to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data). 96 | 97 | ``` 98 | GTAV 99 | └ images 100 | └ train 101 | └ folder 102 | └ valid 103 | └ folder 104 | └ test 105 | └ folder 106 | └ labels 107 | └ train 108 | └ folder 109 | └ valid 110 | └ folder 111 | └ test 112 | └ folder 113 | ``` 114 | 115 | 3. We split the [Synthia dataset](http://synthia-dataset.net/download/808/) into train/val sets following [RobustNet](https://github.com/shachoi/RobustNet). Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data). 116 | 117 | ``` 118 | synthia 119 | └ RGB 120 | └ train 121 | └ val 122 | └ GT 123 | └ COLOR 124 | └ train 125 | └ val 126 | └ LABELS 127 | └ train 128 | └ val 129 | ``` 130 | 131 | 4. You should modify the paths in **"/config.py"** according to your dataset locations. 132 | ``` 133 | #Cityscapes Dir Location 134 | __C.DATASET.CITYSCAPES_DIR = 135 | #Mapillary Dataset Dir Location 136 | __C.DATASET.MAPILLARY_DIR = 137 | #GTAV Dataset Dir Location 138 | __C.DATASET.GTAV_DIR = 139 | #BDD-100K Dataset Dir Location 140 | __C.DATASET.BDD_DIR = 141 | #Synthia Dataset Dir Location 142 | __C.DATASET.SYNTHIA_DIR = 143 | ``` 144 | 5. You can train BlindNet with the following command. 145 | ``` 146 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_train_r50os16_gtav.sh 147 | ``` 148 | 149 | 6. You can download our ResNet-50 model trained on GTAV from [Google Drive](https://drive.google.com/file/d/1Kkdl_2xjE9iooA1Is5VWcjeRcuzaJ8Pi/view?usp=drive_link) and validate the pretrained model with the following command. 150 | ``` 151 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_valid_r50os16_gtav.sh 152 | ``` 153 | 154 | 7. You can infer segmentation results from your own images with the pretrained model using the following command. 155 | ``` 156 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_infer_r50os16.sh 157 | ``` 158 | 159 | ## Citation 160 | If you find this work useful in your research, please cite our paper: 161 | ``` 162 | @article{ahn2024style, 163 | title={Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning}, 164 | author={Ahn, Woo-Jin and Yang, Geun-Yeong and Choi, Hyun-Duck and Lim, Myo-Taeg}, 165 | journal={arXiv preprint arXiv:2403.06122}, 166 | year={2024} 167 | } 168 | ``` 169 | 170 | ## Terms of Use 171 | This software is for non-commercial use only. 
172 | The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) Licence 173 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details) 174 | -------------------------------------------------------------------------------- /assets/model_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/assets/model_architecture.png -------------------------------------------------------------------------------- /assets/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/assets/results.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | ############################################################################## 30 | # Config 31 | # ############################################################################# 32 | 33 | 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | from __future__ import unicode_literals 38 | 39 | 40 | import torch 41 | 42 | 43 | from utils.attr_dict import AttrDict 44 | 45 | 46 | __C = AttrDict() 47 | cfg = __C 48 | __C.ITER = 0 49 | __C.EPOCH = 0 50 | 51 | __C.RANDOM_SEED = 304 52 | # Use Class Uniform Sampling to give each class proper sampling 53 | __C.CLASS_UNIFORM_PCT = 0.0 54 | 55 | # Use class weighted loss per batch to increase loss for low pixel count classes per batch 56 | __C.BATCH_WEIGHTING = False 57 | 58 | # Border Relaxation Count 59 | __C.BORDER_WINDOW = 1 60 | # Number of epoch to use before turn off border restriction 61 | __C.REDUCE_BORDER_ITER = -1 62 | __C.REDUCE_BORDER_EPOCH = -1 63 | # Comma Seperated List of class id to relax 64 | __C.STRICTBORDERCLASS = None 65 | 66 | 67 | 68 | #Attribute Dictionary for Dataset 69 | __C.DATASET = AttrDict() 70 | #Cityscapes Dir Location 71 | __C.DATASET.CITYSCAPES_DIR = '/data3/yang/datasets/cityscapes' 72 | #SDC Augmented Cityscapes Dir Location 73 | __C.DATASET.CITYSCAPES_AUG_DIR = '' 74 | #Mapillary Dataset Dir Location 75 | __C.DATASET.MAPILLARY_DIR = '/data3/yang/datasets/mapillary' 76 | #GTAV, BDD100K Dataset Dir Location 77 | __C.DATASET.GTAV_DIR = '/data3/yang/datasets/gtav' 78 | __C.DATASET.BDD_DIR = '/data3/yang/datasets/bdd-100k' 79 | #Synthia Dataset Dir Location 80 | __C.DATASET.SYNTHIA_DIR = '/data3/yang/datasets/synthia' 81 | #Kitti Dataset Dir Location 82 | __C.DATASET.KITTI_DIR = '' 83 | #SDC Augmented Kitti Dataset Dir Location 84 | __C.DATASET.KITTI_AUG_DIR = '' 85 | #Camvid Dataset Dir Location 86 | __C.DATASET.CAMVID_DIR = '/home/nas_datasets/segmentation/SegNet-Tutorial/CamVid' 87 | #Number of splits to support 88 | __C.DATASET.CV_SPLITS = 3 89 | 90 | 91 | __C.MODEL = AttrDict() 92 | __C.MODEL.BN = 'pytorch-syncnorm' 93 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 94 | 95 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True): 96 | """Call this function in your script after you have finished setting all cfg 97 | values that are necessary (e.g., merging a config from a file, merging 98 | command line config options, etc.). By default, this function will also 99 | mark the global cfg as immutable to prevent changing the global cfg settings 100 | during script execution (which can lead to hard to debug errors or code 101 | that's harder to understand than is necessary). 
102 | """ 103 | 104 | if hasattr(args, 'syncbn') and args.syncbn: 105 | __C.MODEL.BN = 'pytorch-syncnorm' 106 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm 107 | print('Using pytorch sync batch norm') 108 | else: 109 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d 110 | print('Using regular batch norm') 111 | 112 | if not train_mode: 113 | cfg.immutable(True) 114 | return 115 | if args.class_uniform_pct: 116 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct 117 | 118 | if args.batch_weighting: 119 | __C.BATCH_WEIGHTING = True 120 | 121 | if args.jointwtborder: 122 | if args.strict_bdr_cls != '': 123 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")] 124 | if args.rlx_off_iter > -1: 125 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter 126 | 127 | if make_immutable: 128 | cfg.immutable(True) 129 | -------------------------------------------------------------------------------- /datasets/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | """ 2 | # File taken from https://github.com/mcordts/cityscapesScripts/ 3 | # License File Available at: 4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt 5 | 6 | # ---------------------- 7 | # The Cityscapes Dataset 8 | # ---------------------- 9 | # 10 | # 11 | # License agreement 12 | # ----------------- 13 | # 14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree: 15 | # 16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions. 17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website. 18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character. 19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain. 20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt). 21 | # 22 | # 23 | # Contact 24 | # ------- 25 | # 26 | # Marius Cordts, Mohamed Omran 27 | # www.cityscapes-dataset.net 28 | 29 | """ 30 | from collections import namedtuple 31 | 32 | 33 | #-------------------------------------------------------------------------------- 34 | # Definitions 35 | #-------------------------------------------------------------------------------- 36 | 37 | # a label and all meta information 38 | Label = namedtuple( 'Label' , [ 39 | 40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 41 | # We use them to uniquely name a class 42 | 43 | 'id' , # An integer ID that is associated with this label. 
44 | # The IDs are used to represent the label in ground truth images 45 | # An ID of -1 means that this label does not have an ID and thus 46 | # is ignored when creating ground truth images (e.g. license plate). 47 | # Do not modify these IDs, since exactly these IDs are expected by the 48 | # evaluation server. 49 | 50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 51 | # ground truth images with train IDs, using the tools provided in the 52 | # 'preparation' folder. However, make sure to validate or submit results 53 | # to our evaluation server using the regular IDs above! 54 | # For trainIds, multiple labels might have the same ID. Then, these labels 55 | # are mapped to the same class in the ground truth images. For the inverse 56 | # mapping, we use the label that is defined first in the list below. 57 | # For example, mapping all void-type classes to the same ID in training, 58 | # might make sense for some approaches. 59 | # Max value is 255! 60 | 61 | 'category' , # The name of the category that this label belongs to 62 | 63 | 'categoryId' , # The ID of this category. Used to create ground truth images 64 | # on category level. 65 | 66 | 'hasInstances', # Whether this label distinguishes between single instances or not 67 | 68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 69 | # during evaluations or not 70 | 71 | 'color' , # The color of this label 72 | ] ) 73 | 74 | 75 | #-------------------------------------------------------------------------------- 76 | # A list of all labels 77 | #-------------------------------------------------------------------------------- 78 | 79 | # Please adapt the train IDs as appropriate for you approach. 80 | # Note that you might want to ignore labels with ID 255 during training. 81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 82 | # Make sure to provide your results using the original IDs and not the training IDs. 83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 
84 | 85 | labels = [ 86 | # name id trainId category catId hasInstances ignoreInEval color 87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153) 106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142) 122 | ] 123 | 124 | 125 | #-------------------------------------------------------------------------------- 126 | # Create dictionaries for a fast lookup 127 | #-------------------------------------------------------------------------------- 128 | 129 | # Please refer to the main method below for example usages! 
130 | 131 | # name to label object 132 | name2label = { label.name : label for label in labels } 133 | # id to label object 134 | id2label = { label.id : label for label in labels } 135 | # trainId to label object 136 | trainId2label = { label.trainId : label for label in reversed(labels) } 137 | # label2trainid 138 | label2trainid = { label.id : label.trainId for label in labels } 139 | # trainId to label object 140 | trainId2name = { label.trainId : label.name for label in labels } 141 | trainId2color = { label.trainId : label.color for label in labels } 142 | 143 | color2trainId = { label.color : label.trainId for label in labels } 144 | 145 | trainId2trainId = { label.trainId : label.trainId for label in labels } 146 | 147 | # category to list of label objects 148 | category2labels = {} 149 | for label in labels: 150 | category = label.category 151 | if category in category2labels: 152 | category2labels[category].append(label) 153 | else: 154 | category2labels[category] = [label] 155 | 156 | #-------------------------------------------------------------------------------- 157 | # Assure single instance name 158 | #-------------------------------------------------------------------------------- 159 | 160 | # returns the label name that describes a single instance (if possible) 161 | # e.g. input | output 162 | # ---------------------- 163 | # car | car 164 | # cargroup | car 165 | # foo | None 166 | # foogroup | None 167 | # skygroup | None 168 | def assureSingleInstanceName( name ): 169 | # if the name is known, it is not a group 170 | if name in name2label: 171 | return name 172 | # test if the name actually denotes a group 173 | if not name.endswith("group"): 174 | return None 175 | # remove group 176 | name = name[:-len("group")] 177 | # test if the new name exists 178 | if not name in name2label: 179 | return None 180 | # test if the new name denotes a label that actually has instances 181 | if not name2label[name].hasInstances: 182 | return None 183 | # all good then 184 | return name 185 | 186 | #-------------------------------------------------------------------------------- 187 | # Main for testing 188 | #-------------------------------------------------------------------------------- 189 | 190 | # just a dummy main 191 | if __name__ == "__main__": 192 | # Print all the labels 193 | print("List of cityscapes labels:") 194 | print("") 195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' ))) 196 | print((" " + ('-' * 98))) 197 | for label in labels: 198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval ))) 199 | print("") 200 | 201 | print("Example usages:") 202 | 203 | # Map from name to label 204 | name = 'car' 205 | id = name2label[name].id 206 | print(("ID of label '{name}': {id}".format( name=name, id=id ))) 207 | 208 | # Map from ID to label 209 | category = id2label[id].category 210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category ))) 211 | 212 | # Map from trainID to label 213 | trainId = 0 214 | name = trainId2label[trainId].name 215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name ))) 216 | -------------------------------------------------------------------------------- /datasets/kitti.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | KITTI Dataset Loader 3 | """ 4 | 5 | import os 6 | import sys 7 | import numpy as np 8 | from PIL import Image 9 | from torch.utils import data 10 | import logging 11 | import datasets.uniform as uniform 12 | import datasets.cityscapes_labels as cityscapes_labels 13 | import json 14 | from config import cfg 15 | 16 | 17 | trainid_to_name = cityscapes_labels.trainId2name 18 | id_to_trainid = cityscapes_labels.label2trainid 19 | num_classes = 19 20 | ignore_label = 255 21 | root = cfg.DATASET.KITTI_DIR 22 | aug_root = cfg.DATASET.KITTI_AUG_DIR 23 | 24 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153, 25 | 153, 153, 153, 250, 170, 30, 26 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60, 27 | 255, 0, 0, 0, 0, 142, 0, 0, 70, 28 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32] 29 | zero_pad = 256 * 3 - len(palette) 30 | for i in range(zero_pad): 31 | palette.append(0) 32 | 33 | def colorize_mask(mask): 34 | # mask: numpy array of the mask 35 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P') 36 | new_mask.putpalette(palette) 37 | return new_mask 38 | 39 | def get_train_val(cv_split, all_items): 40 | 41 | # 90/10 train/val split, three random splits 42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198] 43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197] 44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199] 45 | 46 | train_set = [] 47 | val_set = [] 48 | 49 | if cv_split == 0: 50 | for i in range(200): 51 | if i in val_0: 52 | val_set.append(all_items[i]) 53 | else: 54 | train_set.append(all_items[i]) 55 | elif cv_split == 1: 56 | for i in range(200): 57 | if i in val_1: 58 | val_set.append(all_items[i]) 59 | else: 60 | train_set.append(all_items[i]) 61 | elif cv_split == 2: 62 | for i in range(200): 63 | if i in val_2: 64 | val_set.append(all_items[i]) 65 | else: 66 | train_set.append(all_items[i]) 67 | else: 68 | logging.info('Unknown cv_split {}'.format(cv_split)) 69 | sys.exit() 70 | 71 | return train_set, val_set 72 | 73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0): 74 | 75 | items = [] 76 | all_items = [] 77 | aug_items = [] 78 | 79 | assert quality == 'semantic' 80 | assert mode in ['train', 'val', 'trainval'] 81 | # note that train and val are randomly determined, no official split 82 | 83 | img_dir_name = "training" 84 | img_path = os.path.join(root, img_dir_name, 'image_2') 85 | mask_path = os.path.join(root, img_dir_name, 'semantic') 86 | 87 | c_items = os.listdir(img_path) 88 | c_items.sort() 89 | 90 | for it in c_items: 91 | item = (os.path.join(img_path, it), os.path.join(mask_path, it)) 92 | all_items.append(item) 93 | logging.info('KITTI has a total of {} images'.format(len(all_items))) 94 | 95 | # split into train/val 96 | train_set, val_set = get_train_val(cv_split, all_items) 97 | 98 | if mode == 'train': 99 | items = train_set 100 | elif mode == 'val': 101 | items = val_set 102 | elif mode == 'trainval': 103 | items = train_set + val_set 104 | else: 105 | logging.info('Unknown mode {}'.format(mode)) 106 | sys.exit() 107 | 108 | logging.info('KITTI-{}: {} images'.format(mode, len(items))) 109 | 110 | return items, aug_items 111 | 112 | class KITTI(data.Dataset): 113 | 114 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None, 115 | transform=None, target_transform=None, 
dump_images=False, 116 | class_uniform_pct=0, class_uniform_tile=0, test=False, 117 | cv_split=None, scf=None, hardnm=0): 118 | 119 | self.quality = quality 120 | self.mode = mode 121 | self.maxSkip = maxSkip 122 | self.joint_transform_list = joint_transform_list 123 | self.transform = transform 124 | self.target_transform = target_transform 125 | self.dump_images = dump_images 126 | self.class_uniform_pct = class_uniform_pct 127 | self.class_uniform_tile = class_uniform_tile 128 | self.scf = scf 129 | self.hardnm = hardnm 130 | 131 | if cv_split: 132 | self.cv_split = cv_split 133 | assert cv_split < cfg.DATASET.CV_SPLITS, \ 134 | 'expected cv_split {} to be < CV_SPLITS {}'.format( 135 | cv_split, cfg.DATASET.CV_SPLITS) 136 | else: 137 | self.cv_split = 0 138 | 139 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm) 140 | assert len(self.imgs), 'Found 0 images, please check the data set' 141 | # self.cal_shape(self.imgs) 142 | 143 | # Centroids for GT data 144 | if self.class_uniform_pct > 0: 145 | if self.scf: 146 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split) 147 | else: 148 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm) 149 | if os.path.isfile(json_fn): 150 | with open(json_fn, 'r') as json_data: 151 | centroids = json.load(json_data) 152 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 153 | else: 154 | if self.scf: 155 | self.centroids = kitti_uniform.class_centroids_all( 156 | self.imgs, 157 | num_classes, 158 | id2trainid=id_to_trainid, 159 | tile_size=class_uniform_tile) 160 | else: 161 | self.centroids = uniform.class_centroids_all( 162 | self.imgs, 163 | num_classes, 164 | id2trainid=id_to_trainid, 165 | tile_size=class_uniform_tile) 166 | with open(json_fn, 'w') as outfile: 167 | json.dump(self.centroids, outfile, indent=4) 168 | 169 | self.build_epoch() 170 | 171 | 172 | def cal_shape(self, imgs): 173 | 174 | for i in imgs: 175 | img_path, mask_path = i 176 | img = Image.open(img_path).convert('RGB') 177 | print(img.size) 178 | 179 | def build_epoch(self, cut=False): 180 | if self.class_uniform_pct > 0: 181 | self.imgs_uniform = uniform.build_epoch(self.imgs, 182 | self.centroids, 183 | num_classes, 184 | cfg.CLASS_UNIFORM_PCT) 185 | else: 186 | self.imgs_uniform = self.imgs 187 | 188 | def __getitem__(self, index): 189 | elem = self.imgs_uniform[index] 190 | centroid = None 191 | if len(elem) == 4: 192 | img_path, mask_path, centroid, class_id = elem 193 | else: 194 | img_path, mask_path = elem 195 | 196 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 197 | img_name = os.path.splitext(os.path.basename(img_path))[0] 198 | 199 | # kitti scale correction factor 200 | if self.mode == 'train' or self.mode == 'trainval': 201 | if self.scf: 202 | width, height = img.size 203 | img = img.resize((width*2, height*2), Image.BICUBIC) 204 | mask = mask.resize((width*2, height*2), Image.NEAREST) 205 | elif self.mode == 'val': 206 | width, height = 1242, 376 207 | img = img.resize((width, height), Image.BICUBIC) 208 | mask = mask.resize((width, height), Image.NEAREST) 209 | else: 210 | logging.info('Unknown mode {}'.format(mode)) 211 | sys.exit() 212 | 213 | mask = np.array(mask) 214 | mask_copy = mask.copy() 215 | 216 | for k, v in id_to_trainid.items(): 217 | mask_copy[mask == k] = v 218 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 219 | 220 | # Image Transformations 221 | if 
self.joint_transform_list is not None: 222 | for idx, xform in enumerate(self.joint_transform_list): 223 | if idx == 0 and centroid is not None: 224 | # HACK 225 | # We assume that the first transform is capable of taking 226 | # in a centroid 227 | img, mask = xform(img, mask, centroid) 228 | else: 229 | img, mask = xform(img, mask) 230 | 231 | # Debug 232 | if self.dump_images and centroid is not None: 233 | outdir = './dump_imgs_{}'.format(self.mode) 234 | os.makedirs(outdir, exist_ok=True) 235 | dump_img_name = trainid_to_name[class_id] + '_' + img_name 236 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 237 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 238 | mask_img = colorize_mask(np.array(mask)) 239 | img.save(out_img_fn) 240 | mask_img.save(out_msk_fn) 241 | 242 | if self.transform is not None: 243 | img = self.transform(img) 244 | if self.target_transform is not None: 245 | mask = self.target_transform(mask) 246 | 247 | return img, mask, img_name 248 | 249 | def __len__(self): 250 | return len(self.imgs_uniform) 251 | 252 | -------------------------------------------------------------------------------- /datasets/mapillary.py: -------------------------------------------------------------------------------- 1 | """ 2 | Mapillary Dataset Loader 3 | """ 4 | import logging 5 | import json 6 | import os 7 | import numpy as np 8 | from PIL import Image, ImageCms 9 | from skimage import color 10 | 11 | from torch.utils import data 12 | import torch 13 | import torchvision.transforms as transforms 14 | import datasets.uniform as uniform 15 | import datasets.cityscapes_labels as cityscapes_labels 16 | import transforms.transforms as extended_transforms 17 | import copy 18 | 19 | from config import cfg 20 | 21 | # Convert this dataset to have labels from cityscapes 22 | num_classes = 19 #65 23 | ignore_label = 255 #65 24 | root = cfg.DATASET.MAPILLARY_DIR 25 | config_fn = os.path.join(root, 'config.json') 26 | color_mapping = [] 27 | id_to_trainid = {} 28 | id_to_ignore_or_group = {} 29 | 30 | 31 | def gen_id_to_ignore(): 32 | global id_to_ignore_or_group 33 | for i in range(66): 34 | id_to_ignore_or_group[i] = ignore_label 35 | 36 | ### Convert each class to cityscapes one 37 | ### Road 38 | # Road 39 | id_to_ignore_or_group[13] = 0 40 | # Lane Marking - General 41 | id_to_ignore_or_group[24] = 0 42 | # Manhole 43 | id_to_ignore_or_group[41] = 0 44 | 45 | ### Sidewalk 46 | # Curb 47 | id_to_ignore_or_group[2] = 1 48 | # Sidewalk 49 | id_to_ignore_or_group[15] = 1 50 | 51 | ### Building 52 | # Building 53 | id_to_ignore_or_group[17] = 2 54 | 55 | ### Wall 56 | # Wall 57 | id_to_ignore_or_group[6] = 3 58 | 59 | ### Fence 60 | # Fence 61 | id_to_ignore_or_group[3] = 4 62 | 63 | ### Pole 64 | # Pole 65 | id_to_ignore_or_group[45] = 5 66 | # Utility Pole 67 | id_to_ignore_or_group[47] = 5 68 | 69 | ### Traffic Light 70 | # Traffic Light 71 | id_to_ignore_or_group[48] = 6 72 | 73 | ### Traffic Sign 74 | # Traffic Sign 75 | id_to_ignore_or_group[50] = 7 76 | 77 | ### Vegetation 78 | # Vegitation 79 | id_to_ignore_or_group[30] = 8 80 | 81 | ### Terrain 82 | # Terrain 83 | id_to_ignore_or_group[29] = 9 84 | 85 | ### Sky 86 | # Sky 87 | id_to_ignore_or_group[27] = 10 88 | 89 | ### Person 90 | # Person 91 | id_to_ignore_or_group[19] = 11 92 | 93 | ### Rider 94 | # Bicyclist 95 | id_to_ignore_or_group[20] = 12 96 | # Motorcyclist 97 | id_to_ignore_or_group[21] = 12 98 | # Other Rider 99 | id_to_ignore_or_group[22] = 12 100 | 101 | ### Car 102 | # Car 103 | 
id_to_ignore_or_group[55] = 13 104 | 105 | ### Truck 106 | # Truck 107 | id_to_ignore_or_group[61] = 14 108 | 109 | ### Bus 110 | # Bus 111 | id_to_ignore_or_group[54] = 15 112 | 113 | ### Train 114 | # On Rails 115 | id_to_ignore_or_group[58] = 16 116 | 117 | ### Motorcycle 118 | # Motorcycle 119 | id_to_ignore_or_group[57] = 17 120 | 121 | ### Bicycle 122 | # Bicycle 123 | id_to_ignore_or_group[52] = 18 124 | 125 | 126 | def colorize_mask(image_array): 127 | """ 128 | Colorize a segmentation mask 129 | """ 130 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P') 131 | new_mask.putpalette(color_mapping) 132 | return new_mask 133 | 134 | 135 | def make_dataset(quality, mode): 136 | """ 137 | Create File List 138 | """ 139 | assert (quality == 'semantic' and mode in ['train', 'val']) 140 | img_dir_name = None 141 | if quality == 'semantic': 142 | if mode == 'train': 143 | img_dir_name = 'training' 144 | if mode == 'val': 145 | img_dir_name = 'validation' 146 | mask_path = os.path.join(root, img_dir_name, 'labels') 147 | else: 148 | raise BaseException("Instance Segmentation Not support") 149 | 150 | img_path = os.path.join(root, img_dir_name, 'images') 151 | print(img_path) 152 | if quality != 'video': 153 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)]) 154 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)]) 155 | assert imgs == msks 156 | 157 | items = [] 158 | c_items = os.listdir(img_path) 159 | if '.DS_Store' in c_items: 160 | c_items.remove('.DS_Store') 161 | 162 | for it in c_items: 163 | if quality == 'video': 164 | item = (os.path.join(img_path, it), os.path.join(img_path, it)) 165 | else: 166 | item = (os.path.join(img_path, it), 167 | os.path.join(mask_path, it.replace(".jpg", ".png"))) 168 | items.append(item) 169 | return items 170 | 171 | 172 | def gen_colormap(): 173 | """ 174 | Get Color Map from file 175 | """ 176 | global color_mapping 177 | 178 | # load mapillary config 179 | with open(config_fn) as config_file: 180 | config = json.load(config_file) 181 | config_labels = config['labels'] 182 | 183 | # calculate label color mapping 184 | colormap = [] 185 | id2name = {} 186 | for i in range(0, len(config_labels)): 187 | colormap = colormap + config_labels[i]['color'] 188 | id2name[i] = config_labels[i]['readable'] 189 | color_mapping = colormap 190 | return id2name 191 | 192 | 193 | class Mapillary(data.Dataset): 194 | def __init__(self, quality, mode, joint_transform_list=None, 195 | transform=None, target_transform=None, target_aux_transform=None, 196 | image_in=False, dump_images=False, class_uniform_pct=0, 197 | class_uniform_tile=768, test=False): 198 | """ 199 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform. 200 | 0.0 means fully random. 
201 | class_uniform_tile_size = Class uniform tile size 202 | """ 203 | gen_id_to_ignore() 204 | self.quality = quality 205 | self.mode = mode 206 | self.joint_transform_list = joint_transform_list 207 | self.transform = transform 208 | self.target_transform = target_transform 209 | self.image_in = image_in 210 | self.target_aux_transform = target_aux_transform 211 | self.dump_images = dump_images 212 | self.class_uniform_pct = class_uniform_pct 213 | self.class_uniform_tile = class_uniform_tile 214 | self.id2name = gen_colormap() 215 | self.imgs_uniform = None 216 | 217 | 218 | # find all images 219 | self.imgs = make_dataset(quality, mode) 220 | if len(self.imgs) == 0: 221 | raise RuntimeError('Found 0 images, please check the data set') 222 | if test: 223 | np.random.shuffle(self.imgs) 224 | self.imgs = self.imgs[:200] 225 | 226 | if self.class_uniform_pct: 227 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile) 228 | if os.path.isfile(json_fn): 229 | with open(json_fn, 'r') as json_data: 230 | centroids = json.load(json_data) 231 | self.centroids = {int(idx): centroids[idx] for idx in centroids} 232 | else: 233 | # centroids is a dict (indexed by class) of lists of centroids 234 | self.centroids = uniform.class_centroids_all( 235 | self.imgs, 236 | num_classes, 237 | id2trainid=None, 238 | tile_size=self.class_uniform_tile) 239 | with open(json_fn, 'w') as outfile: 240 | json.dump(self.centroids, outfile, indent=4) 241 | else: 242 | self.centroids = [] 243 | self.build_epoch() 244 | 245 | def build_epoch(self): 246 | if self.class_uniform_pct != 0: 247 | self.imgs_uniform = uniform.build_epoch(self.imgs, 248 | self.centroids, 249 | num_classes, 250 | self.class_uniform_pct) 251 | else: 252 | self.imgs_uniform = self.imgs 253 | 254 | def __getitem__(self, index): 255 | if len(self.imgs_uniform[index]) == 2: 256 | img_path, mask_path = self.imgs_uniform[index] 257 | centroid = None 258 | class_id = None 259 | else: 260 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index] 261 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path) 262 | img_name = os.path.splitext(os.path.basename(img_path))[0] 263 | 264 | mask = np.array(mask) 265 | mask_copy = mask.copy() 266 | for k, v in id_to_ignore_or_group.items(): 267 | mask_copy[mask == k] = v 268 | mask = Image.fromarray(mask_copy.astype(np.uint8)) 269 | 270 | # Image Transformations 271 | if self.joint_transform_list is not None: 272 | for idx, xform in enumerate(self.joint_transform_list): 273 | if idx == 0 and centroid is not None: 274 | # HACK! 
Assume the first transform accepts a centroid 275 | img, mask = xform(img, mask, centroid) 276 | else: 277 | img, mask = xform(img, mask) 278 | 279 | if self.dump_images: 280 | outdir = 'dump_imgs_{}'.format(self.mode) 281 | os.makedirs(outdir, exist_ok=True) 282 | if centroid is not None: 283 | dump_img_name = self.id2name[class_id] + '_' + img_name 284 | else: 285 | dump_img_name = img_name 286 | out_img_fn = os.path.join(outdir, dump_img_name + '.png') 287 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png') 288 | mask_img = colorize_mask(np.array(mask)) 289 | img.save(out_img_fn) 290 | mask_img.save(out_msk_fn) 291 | 292 | if self.transform is not None: 293 | img = self.transform(img) 294 | 295 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 296 | img_gt = transforms.Normalize(*rgb_mean_std)(img) 297 | if self.image_in: 298 | eps = 1e-5 299 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])], 300 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps]) 301 | img = transforms.Normalize(*rgb_mean_std)(img) 302 | 303 | if self.target_aux_transform is not None: 304 | mask_aux = self.target_aux_transform(mask) 305 | else: 306 | mask_aux = torch.tensor([0]) 307 | if self.target_transform is not None: 308 | mask = self.target_transform(mask) 309 | 310 | mask = extended_transforms.MaskToTensor()(mask) 311 | return img, mask, img_name, mask_aux 312 | 313 | def __len__(self): 314 | return len(self.imgs_uniform) 315 | 316 | def calculate_weights(self): 317 | raise BaseException("not supported yet") 318 | -------------------------------------------------------------------------------- /datasets/multi_loader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom DomainUniformConcatDataset 3 | """ 4 | import numpy as np 5 | 6 | from torch.utils.data import Dataset 7 | import torch 8 | from config import cfg 9 | 10 | 11 | np.random.seed(cfg.RANDOM_SEED) 12 | 13 | 14 | class DomainUniformConcatDataset(Dataset): 15 | """ 16 | DomainUniformConcatDataset 17 | 18 | Sample images uniformly across the domains 19 | If bs_mul is n, this outputs # of domains * n images per batch 20 | """ 21 | @staticmethod 22 | def cumsum(sequence): 23 | r, s = [], 0 24 | for e in sequence: 25 | l = len(e) 26 | r.append(l + s) 27 | s += l 28 | return r 29 | 30 | def __init__(self, args, datasets): 31 | """ 32 | This dataset is to return sample image (source) 33 | and augmented sample image (target) 34 | Args: 35 | args: input config arguments 36 | datasets: list of datasets to concat 37 | """ 38 | super(DomainUniformConcatDataset, self).__init__() 39 | self.datasets = datasets 40 | self.lengths = [len(d) for d in datasets] 41 | self.offsets = self.cumsum(datasets) 42 | self.length = np.sum(self.lengths) 43 | 44 | print("# domains: {}, Total length: {}, 1 epoch: {}, offsets: {}".format( 45 | str(len(datasets)), str(self.length), str(len(self)), str(self.offsets))) 46 | 47 | 48 | def __len__(self): 49 | """ 50 | Returns: 51 | The number of images in a domain that has minimum image samples 52 | """ 53 | return min(self.lengths) 54 | 55 | 56 | def _get_batch_from_dataset(self, dataset, idx): 57 | """ 58 | Get batch from dataset 59 | New idx = idx + random integer 60 | Args: 61 | dataset: dataset class object 62 | idx: integer 63 | 64 | Returns: 65 | One batch from dataset 66 | """ 67 | p_index = idx + np.random.randint(len(dataset)) 68 | if p_index > len(dataset) - 1: 69 | p_index -= len(dataset) 70 | 71 | return 
dataset[p_index] 72 | 73 | 74 | def __getitem__(self, idx): 75 | """ 76 | Args: 77 | idx (int): Index 78 | 79 | Returns: 80 | images corresonding to the index from each domain 81 | """ 82 | imgs = [] 83 | masks = [] 84 | img_names = [] 85 | mask_auxs = [] 86 | 87 | for dataset in self.datasets: 88 | img, mask, img_name, mask_aux = self._get_batch_from_dataset(dataset, idx) 89 | imgs.append(img) 90 | masks.append(mask) 91 | img_names.append(img_name) 92 | mask_auxs.append(mask_aux) 93 | imgs, masks, mask_auxs = torch.stack(imgs, 0), torch.stack(masks, 0), torch.stack(mask_auxs, 0) 94 | 95 | return imgs, masks, img_names, mask_auxs 96 | 97 | -------------------------------------------------------------------------------- /datasets/nullloader.py: -------------------------------------------------------------------------------- 1 | """ 2 | Null Loader 3 | """ 4 | import numpy as np 5 | import torch 6 | from torch.utils import data 7 | 8 | num_classes = 19 9 | ignore_label = 255 10 | 11 | class NullLoader(data.Dataset): 12 | """ 13 | Null Dataset for Performance 14 | """ 15 | def __init__(self,crop_size): 16 | self.imgs = range(200) 17 | self.crop_size = crop_size 18 | 19 | def __getitem__(self, index): 20 | #Return img, mask, name 21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index) 22 | 23 | def __len__(self): 24 | return len(self.imgs) -------------------------------------------------------------------------------- /datasets/sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
34 | """ 35 | 36 | 37 | 38 | import math 39 | import torch 40 | from torch.distributed import get_world_size, get_rank 41 | from torch.utils.data import Sampler 42 | 43 | class DistributedSampler(Sampler): 44 | """Sampler that restricts data loading to a subset of the dataset. 45 | 46 | It is especially useful in conjunction with 47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 48 | process can pass a DistributedSampler instance as a DataLoader sampler, 49 | and load a subset of the original dataset that is exclusive to it. 50 | 51 | .. note:: 52 | Dataset is assumed to be of constant size. 53 | 54 | Arguments: 55 | dataset: Dataset used for sampling. 56 | num_replicas (optional): Number of processes participating in 57 | distributed training. 58 | rank (optional): Rank of the current process within num_replicas. 59 | """ 60 | 61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None): 62 | if num_replicas is None: 63 | num_replicas = get_world_size() 64 | if rank is None: 65 | rank = get_rank() 66 | self.dataset = dataset 67 | self.num_replicas = num_replicas 68 | self.rank = rank 69 | self.epoch = 0 70 | self.consecutive_sample = consecutive_sample 71 | self.permutation = permutation 72 | if pad: 73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 74 | else: 75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 76 | self.total_size = self.num_samples * self.num_replicas 77 | 78 | def __iter__(self): 79 | # deterministically shuffle based on epoch 80 | g = torch.Generator() 81 | g.manual_seed(self.epoch) 82 | 83 | if self.permutation: 84 | indices = list(torch.randperm(len(self.dataset), generator=g)) 85 | else: 86 | indices = list([x for x in range(len(self.dataset))]) 87 | 88 | # add extra samples to make it evenly divisible 89 | if self.total_size > len(indices): 90 | indices += indices[:(self.total_size - len(indices))] 91 | 92 | # subsample 93 | if self.consecutive_sample: 94 | offset = self.num_samples * self.rank 95 | indices = indices[offset:offset + self.num_samples] 96 | else: 97 | indices = indices[self.rank:self.total_size:self.num_replicas] 98 | assert len(indices) == self.num_samples 99 | 100 | return iter(indices) 101 | 102 | def __len__(self): 103 | return self.num_samples 104 | 105 | def set_epoch(self, epoch): 106 | self.epoch = epoch 107 | 108 | def set_num_samples(self): 109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 110 | self.total_size = self.num_samples * self.num_replicas -------------------------------------------------------------------------------- /datasets/uniform.py: -------------------------------------------------------------------------------- 1 | """ 2 | Uniform sampling of classes. 3 | For all images, for all classes, generate centroids around which to sample. 4 | 5 | All images are divided into tiles. 6 | For each tile, a class can be present or not. If it is 7 | present, calculate the centroid of the class and record it. 8 | 9 | We would like to thank Peter Kontschieder for the inspiration of this idea. 
10 | """ 11 | 12 | import logging 13 | from collections import defaultdict 14 | from PIL import Image 15 | import numpy as np 16 | from scipy import ndimage 17 | from tqdm import tqdm 18 | 19 | pbar = None 20 | 21 | class Point(): 22 | """ 23 | Point Class For X and Y Location 24 | """ 25 | def __init__(self, x, y): 26 | self.x = x 27 | self.y = y 28 | 29 | 30 | def calc_tile_locations(tile_size, image_size): 31 | """ 32 | Divide an image into tiles to help us cover classes that are spread out. 33 | tile_size: size of tile to distribute 34 | image_size: original image size 35 | return: locations of the tiles 36 | """ 37 | image_size_y, image_size_x = image_size 38 | locations = [] 39 | for y in range(image_size_y // tile_size): 40 | for x in range(image_size_x // tile_size): 41 | x_offs = x * tile_size 42 | y_offs = y * tile_size 43 | locations.append((x_offs, y_offs)) 44 | return locations 45 | 46 | 47 | def class_centroids_image(item, tile_size, num_classes, id2trainid): 48 | """ 49 | For one image, calculate centroids for all classes present in image. 50 | item: image, image_name 51 | tile_size: 52 | num_classes: 53 | id2trainid: mapping from original id to training ids 54 | return: Centroids are calculated for each tile. 55 | """ 56 | image_fn, label_fn = item 57 | centroids = defaultdict(list) 58 | mask = np.array(Image.open(label_fn)) 59 | if len(mask.shape) == 3: 60 | # Remove instance mask 61 | mask = mask[:,:,0] 62 | image_size = mask.shape 63 | tile_locations = calc_tile_locations(tile_size, image_size) 64 | 65 | mask_copy = mask.copy() 66 | if id2trainid: 67 | for k, v in id2trainid.items(): 68 | mask[mask_copy == k] = v 69 | 70 | for x_offs, y_offs in tile_locations: 71 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 72 | for class_id in range(num_classes): 73 | if class_id in patch: 74 | patch_class = (patch == class_id).astype(int) 75 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 76 | centroid_y = int(centroid_y) + y_offs 77 | centroid_x = int(centroid_x) + x_offs 78 | centroid = (centroid_x, centroid_y) 79 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 80 | pbar.update(1) 81 | return centroids 82 | 83 | import scipy.misc as m 84 | 85 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid): 86 | """ 87 | For one image, calculate centroids for all classes present in image. 88 | item: image, image_name 89 | tile_size: 90 | num_classes: 91 | id2trainid: mapping from original id to training ids 92 | return: Centroids are calculated for each tile. 
93 | """ 94 | image_fn, label_fn = item 95 | centroids = defaultdict(list) 96 | mask = m.imread(label_fn) 97 | image_size = mask[:,:,0].shape 98 | tile_locations = calc_tile_locations(tile_size, image_size) 99 | 100 | # mask = m.imread(label_fn) 101 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8) 102 | # for k, v in id2trainid.items(): 103 | # mask_copy[(mask == k)[:,:,0]] = v 104 | # mask = Image.fromarray(mask_copy.astype(np.uint8)) 105 | 106 | # mask_copy = mask.copy() 107 | # mask_copy = mask.copy() 108 | # if id2trainid: 109 | # for k, v in id2trainid.items(): 110 | # mask[mask_copy == k] = v 111 | 112 | mask_copy = np.full(image_size, 255, dtype=np.uint8) 113 | 114 | if id2trainid: 115 | for k, v in id2trainid.items(): 116 | # print("0", mask.shape) 117 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914 118 | # # print("2", mask == np.array(k)[:,:,0]) 119 | # break 120 | # if v != 255: 121 | # print(v) 122 | # if v == 2: 123 | # print(k, v, "num", np.count_nonzero(mask == np.array(k))) 124 | # break 125 | if v != 255 and v != -1: 126 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v 127 | mask = mask_copy 128 | 129 | # mask_copy = mask.copy() 130 | # if id2trainid: 131 | # for k, v in id2trainid.items(): 132 | # mask[mask_copy == k] = v 133 | 134 | for x_offs, y_offs in tile_locations: 135 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size] 136 | for class_id in range(num_classes): 137 | if class_id in patch: 138 | patch_class = (patch == class_id).astype(int) 139 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class) 140 | centroid_y = int(centroid_y) + y_offs 141 | centroid_x = int(centroid_x) + x_offs 142 | centroid = (centroid_x, centroid_y) 143 | centroids[class_id].append((image_fn, label_fn, centroid, class_id)) 144 | pbar.update(1) 145 | return centroids 146 | 147 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 148 | """ 149 | Calculate class centroids for all classes for all images for all tiles. 150 | items: list of (image_fn, label_fn) 151 | tile size: size of tile 152 | returns: dict that contains a list of centroids for each class 153 | """ 154 | from multiprocessing.dummy import Pool 155 | from functools import partial 156 | pool = Pool(32) 157 | global pbar 158 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 159 | class_centroids_item = partial(class_centroids_image_from_color, 160 | num_classes=num_classes, 161 | id2trainid=id2trainid, 162 | tile_size=tile_size) 163 | 164 | centroids = defaultdict(list) 165 | new_centroids = pool.map(class_centroids_item, items) 166 | pool.close() 167 | pool.join() 168 | 169 | # combine each image's items into a single global dict 170 | for image_items in new_centroids: 171 | for class_id in image_items: 172 | centroids[class_id].extend(image_items[class_id]) 173 | return centroids 174 | 175 | 176 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 177 | """ 178 | Calculate class centroids for all classes for all images for all tiles. 
179 | items: list of (image_fn, label_fn) 180 | tile size: size of tile 181 | returns: dict that contains a list of centroids for each class 182 | """ 183 | from multiprocessing.dummy import Pool 184 | from functools import partial 185 | pool = Pool(80) 186 | global pbar 187 | pbar = tqdm(total=len(items), desc='pooled centroid extraction') 188 | class_centroids_item = partial(class_centroids_image, 189 | num_classes=num_classes, 190 | id2trainid=id2trainid, 191 | tile_size=tile_size) 192 | 193 | centroids = defaultdict(list) 194 | new_centroids = pool.map(class_centroids_item, items) 195 | pool.close() 196 | pool.join() 197 | 198 | # combine each image's items into a single global dict 199 | for image_items in new_centroids: 200 | for class_id in image_items: 201 | centroids[class_id].extend(image_items[class_id]) 202 | return centroids 203 | 204 | 205 | def unpooled_class_centroids_all(items, num_classes, tile_size=1024): 206 | """ 207 | Calculate class centroids for all classes for all images for all tiles. 208 | items: list of (image_fn, label_fn) 209 | tile size: size of tile 210 | returns: dict that contains a list of centroids for each class 211 | """ 212 | centroids = defaultdict(list) 213 | global pbar 214 | pbar = tqdm(total=len(items), desc='centroid extraction') 215 | for image, label in items: 216 | new_centroids = class_centroids_image((image, label), 217 | tile_size, 218 | num_classes) 219 | for class_id in new_centroids: 220 | centroids[class_id].extend(new_centroids[class_id]) 221 | 222 | return centroids 223 | 224 | 225 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024): 226 | """ 227 | intermediate function to call pooled_class_centroid 228 | """ 229 | 230 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes, 231 | id2trainid, tile_size) 232 | return pooled_centroids 233 | 234 | 235 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024): 236 | """ 237 | intermediate function to call pooled_class_centroid 238 | """ 239 | 240 | pooled_centroids = pooled_class_centroids_all(items, num_classes, 241 | id2trainid, tile_size) 242 | return pooled_centroids 243 | 244 | 245 | def random_sampling(alist, num): 246 | """ 247 | Randomly sample num items from the list 248 | alist: list of centroids to sample from 249 | num: can be larger than the list and if so, then wrap around 250 | return: class uniform samples from the list 251 | """ 252 | sampling = [] 253 | len_list = len(alist) 254 | assert len_list, 'len_list is zero!' 255 | indices = np.arange(len_list) 256 | np.random.shuffle(indices) 257 | 258 | for i in range(num): 259 | item = alist[indices[i % len_list]] 260 | sampling.append(item) 261 | return sampling 262 | 263 | 264 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct): 265 | """ 266 | Generate an epochs-worth of crops using uniform sampling. 
Needs to be called every 267 | imgs: list of imgs 268 | centroids: 269 | num_classes: 270 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch ) 271 | """ 272 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct)) 273 | num_epoch = int(len(imgs)) 274 | 275 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch)) 276 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes) 277 | num_rand = num_epoch - num_per_class * num_classes 278 | # create random crops 279 | imgs_uniform = random_sampling(imgs, num_rand) 280 | 281 | # now add uniform sampling 282 | for class_id in range(num_classes): 283 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id])) 284 | logging.info(string_format) 285 | for class_id in range(num_classes): 286 | centroid_len = len(centroids[class_id]) 287 | if centroid_len == 0: 288 | pass 289 | else: 290 | class_centroids = random_sampling(centroids[class_id], num_per_class) 291 | imgs_uniform.extend(class_centroids) 292 | 293 | return imgs_uniform 294 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Loss.py 3 | """ 4 | 5 | import logging 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import datasets 11 | from config import cfg 12 | 13 | 14 | def get_loss(args): 15 | """ 16 | Get the criterion based on the loss function 17 | args: commandline arguments 18 | return: criterion, criterion_val 19 | """ 20 | if args.cls_wt_loss: 21 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 22 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 23 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 24 | else: 25 | ce_weight = None 26 | 27 | if args.img_wt_loss: 28 | criterion = ImageBasedCrossEntropyLoss2d( 29 | classes=datasets.num_classes, size_average=True, 30 | ignore_index=datasets.ignore_label, 31 | upper_bound=args.wt_bound).cuda() 32 | elif args.jointwtborder: 33 | criterion = ImgWtLossSoftNLL(classes=datasets.num_classes, 34 | ignore_index=datasets.ignore_label, 35 | upper_bound=args.wt_bound).cuda() 36 | else: 37 | print("standard cross entropy") 38 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 39 | ignore_index=datasets.ignore_label).cuda() 40 | 41 | criterion_val = nn.CrossEntropyLoss(reduction='mean', 42 | ignore_index=datasets.ignore_label).cuda() 43 | return criterion, criterion_val 44 | 45 | def get_loss_by_epoch(args): 46 | """ 47 | Get the criterion based on the loss function 48 | args: commandline arguments 49 | return: criterion, criterion_val 50 | """ 51 | 52 | if args.img_wt_loss: 53 | criterion = ImageBasedCrossEntropyLoss2d( 54 | classes=datasets.num_classes, size_average=True, 55 | ignore_index=datasets.ignore_label, 56 | upper_bound=args.wt_bound).cuda() 57 | elif args.jointwtborder: 58 | criterion = ImgWtLossSoftNLL_by_epoch(classes=datasets.num_classes, 59 | ignore_index=datasets.ignore_label, 60 | upper_bound=args.wt_bound).cuda() 61 | else: 62 | criterion = CrossEntropyLoss2d(size_average=True, 63 | ignore_index=datasets.ignore_label).cuda() 64 | 65 | criterion_val = CrossEntropyLoss2d(size_average=True, 66 | weight=None, 67 | ignore_index=datasets.ignore_label).cuda() 68 | return criterion, criterion_val 69 | 70 | 71 | def get_loss_aux(args): 72 | """ 73 | Get the criterion based on the loss function 74 | args: commandline 
arguments 75 | return: criterion, criterion_val 76 | """ 77 | if args.cls_wt_loss: 78 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 79 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 80 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 81 | else: 82 | ce_weight = None 83 | 84 | print("standard cross entropy") 85 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean', 86 | ignore_index=datasets.ignore_label).cuda() 87 | 88 | return criterion 89 | 90 | def get_loss_bcelogit(args): 91 | if args.cls_wt_loss: 92 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754, 93 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 94 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507]) 95 | else: 96 | pos_weight = None 97 | print("standard bce with logit cross entropy") 98 | criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda() 99 | 100 | return criterion 101 | 102 | def weighted_binary_cross_entropy(output, target): 103 | 104 | weights = torch.Tensor([0.1, 0.9]) 105 | 106 | loss = weights[1] * (target * torch.log(output)) + \ 107 | weights[0] * ((1 - target) * torch.log(1 - output)) 108 | 109 | return torch.neg(torch.mean(loss)) 110 | 111 | 112 | class L1Loss(nn.Module): 113 | def __init__(self): 114 | super(L1Loss, self).__init__() 115 | 116 | def __call__(self, in0, in1): 117 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True) 118 | 119 | 120 | class ImageBasedCrossEntropyLoss2d(nn.Module): 121 | """ 122 | Image Weighted Cross Entropy Loss 123 | """ 124 | 125 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255, 126 | norm=False, upper_bound=1.0): 127 | super(ImageBasedCrossEntropyLoss2d, self).__init__() 128 | logging.info("Using Per Image based weighted loss") 129 | self.num_classes = classes 130 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 131 | self.norm = norm 132 | self.upper_bound = upper_bound 133 | self.batch_weights = cfg.BATCH_WEIGHTING 134 | self.logsoftmax = nn.LogSoftmax(dim=1) 135 | 136 | def calculate_weights(self, target): 137 | """ 138 | Calculate weights of classes based on the training crop 139 | """ 140 | hist = np.histogram(target.flatten(), range( 141 | self.num_classes + 1), normed=True)[0] 142 | if self.norm: 143 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 144 | else: 145 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 146 | return hist 147 | 148 | def forward(self, inputs, targets): 149 | 150 | target_cpu = targets.data.cpu().numpy() 151 | if self.batch_weights: 152 | weights = self.calculate_weights(target_cpu) 153 | self.nll_loss.weight = torch.Tensor(weights).cuda() 154 | 155 | loss = 0.0 156 | for i in range(0, inputs.shape[0]): 157 | if not self.batch_weights: 158 | weights = self.calculate_weights(target_cpu[i]) 159 | self.nll_loss.weight = torch.Tensor(weights).cuda() 160 | 161 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)), 162 | targets[i].unsqueeze(0)) 163 | return loss 164 | 165 | 166 | 167 | class CrossEntropyLoss2d(nn.Module): 168 | """ 169 | Cross Entroply NLL Loss 170 | """ 171 | 172 | def __init__(self, weight=None, size_average=True, ignore_index=255): 173 | super(CrossEntropyLoss2d, self).__init__() 174 | logging.info("Using Cross Entropy Loss") 175 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index) 176 | self.logsoftmax = nn.LogSoftmax(dim=1) 177 | # self.weight = weight 178 | 179 | def forward(self, inputs, 
targets): 180 | return self.nll_loss(self.logsoftmax(inputs), targets) 181 | 182 | def customsoftmax(inp, multihotmask): 183 | """ 184 | Custom Softmax 185 | """ 186 | soft = F.softmax(inp, dim=1) 187 | # This takes the mask * softmax ( sums it up hence summing up the classes in border 188 | # then takes of summed up version vs no summed version 189 | return torch.log( 190 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True))) 191 | ) 192 | 193 | class ImgWtLossSoftNLL(nn.Module): 194 | """ 195 | Relax Loss 196 | """ 197 | 198 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 199 | norm=False): 200 | super(ImgWtLossSoftNLL, self).__init__() 201 | self.weights = weights 202 | self.num_classes = classes 203 | self.ignore_index = ignore_index 204 | self.upper_bound = upper_bound 205 | self.norm = norm 206 | self.batch_weights = cfg.BATCH_WEIGHTING 207 | 208 | def calculate_weights(self, target): 209 | """ 210 | Calculate weights of the classes based on training crop 211 | """ 212 | if len(target.shape) == 3: 213 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 214 | else: 215 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 216 | if self.norm: 217 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 218 | else: 219 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 220 | return hist[:-1] 221 | 222 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 223 | """ 224 | NLL Relaxed Loss Implementation 225 | """ 226 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 227 | border_weights = 1 / border_weights 228 | target[target > 1] = 1 229 | 230 | loss_matrix = (-1 / border_weights * 231 | (target[:, :-1, :, :].float() * 232 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 233 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 234 | (1. 
- mask.float()) 235 | 236 | # loss_matrix[border_weights > 1] = 0 237 | loss = loss_matrix.sum() 238 | 239 | # +1 to prevent division by 0 240 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 241 | return loss 242 | 243 | def forward(self, inputs, target): 244 | weights = target[:, :-1, :, :].sum(1).float() 245 | ignore_mask = (weights == 0) 246 | weights[ignore_mask] = 1 247 | 248 | loss = 0 249 | target_cpu = target.data.cpu().numpy() 250 | 251 | if self.batch_weights: 252 | class_weights = self.calculate_weights(target_cpu) 253 | 254 | for i in range(0, inputs.shape[0]): 255 | if not self.batch_weights: 256 | class_weights = self.calculate_weights(target_cpu[i]) 257 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 258 | target[i].unsqueeze(0), 259 | class_weights=torch.Tensor(class_weights).cuda(), 260 | border_weights=weights[i], mask=ignore_mask[i]) 261 | 262 | loss = loss / inputs.shape[0] 263 | return loss 264 | 265 | class ImgWtLossSoftNLL_by_epoch(nn.Module): 266 | """ 267 | Relax Loss 268 | """ 269 | 270 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0, 271 | norm=False): 272 | super(ImgWtLossSoftNLL_by_epoch, self).__init__() 273 | self.weights = weights 274 | self.num_classes = classes 275 | self.ignore_index = ignore_index 276 | self.upper_bound = upper_bound 277 | self.norm = norm 278 | self.batch_weights = cfg.BATCH_WEIGHTING 279 | self.fp16 = False 280 | 281 | 282 | def calculate_weights(self, target): 283 | """ 284 | Calculate weights of the classes based on training crop 285 | """ 286 | if len(target.shape) == 3: 287 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum() 288 | else: 289 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum() 290 | if self.norm: 291 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1 292 | else: 293 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1 294 | return hist[:-1] 295 | 296 | def custom_nll(self, inputs, target, class_weights, border_weights, mask): 297 | """ 298 | NLL Relaxed Loss Implementation 299 | """ 300 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH): 301 | border_weights = 1 / border_weights 302 | target[target > 1] = 1 303 | if self.fp16: 304 | loss_matrix = (-1 / border_weights * 305 | (target[:, :-1, :, :].half() * 306 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 307 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \ 308 | (1. - mask.half()) 309 | else: 310 | loss_matrix = (-1 / border_weights * 311 | (target[:, :-1, :, :].float() * 312 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) * 313 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \ 314 | (1. 
- mask.float()) 315 | 316 | # loss_matrix[border_weights > 1] = 0 317 | loss = loss_matrix.sum() 318 | 319 | # +1 to prevent division by 0 320 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1) 321 | return loss 322 | 323 | def forward(self, inputs, target): 324 | if self.fp16: 325 | weights = target[:, :-1, :, :].sum(1).half() 326 | else: 327 | weights = target[:, :-1, :, :].sum(1).float() 328 | ignore_mask = (weights == 0) 329 | weights[ignore_mask] = 1 330 | 331 | loss = 0 332 | target_cpu = target.data.cpu().numpy() 333 | 334 | if self.batch_weights: 335 | class_weights = self.calculate_weights(target_cpu) 336 | 337 | for i in range(0, inputs.shape[0]): 338 | if not self.batch_weights: 339 | class_weights = self.calculate_weights(target_cpu[i]) 340 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0), 341 | target[i].unsqueeze(0), 342 | class_weights=torch.Tensor(class_weights).cuda(), 343 | border_weights=weights, mask=ignore_mask[i]) 344 | 345 | return loss 346 | -------------------------------------------------------------------------------- /network/Mobilenet.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch import Tensor 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | from typing import Callable, Any, Optional, List 5 | from network.instance_whitening import InstanceWhitening 6 | from network.mynn import forgiving_state_restore 7 | 8 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 9 | 10 | 11 | model_urls = { 12 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth', 13 | } 14 | 15 | 16 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: 17 | """ 18 | This function is taken from the original tf repo. 19 | It ensures that all layers have a channel number that is divisible by 8 20 | It can be seen here: 21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py 22 | :param v: 23 | :param divisor: 24 | :param min_value: 25 | :return: 26 | """ 27 | if min_value is None: 28 | min_value = divisor 29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) 30 | # Make sure that round down does not go down by more than 10%. 
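# e.g. _make_divisible(32 * 1.4, 8) returns 48: 44.8 is rounded to 48, the nearest multiple of 8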
31 | if new_v < 0.9 * v: 32 | new_v += divisor 33 | return new_v 34 | 35 | 36 | class ConvBNReLU(nn.Sequential): 37 | def __init__( 38 | self, 39 | in_planes: int, 40 | out_planes: int, 41 | kernel_size: int = 3, 42 | stride: int = 1, 43 | groups: int = 1, 44 | norm_layer: Optional[Callable[..., nn.Module]] = None, 45 | iw: int = 0, 46 | ) -> None: 47 | 48 | padding = (kernel_size - 1) // 2 49 | if norm_layer is None: 50 | norm_layer = nn.BatchNorm2d 51 | 52 | self.iw = iw 53 | 54 | if iw == 1: 55 | instance_norm_layer = InstanceWhitening(out_planes) 56 | elif iw == 2: 57 | instance_norm_layer = InstanceWhitening(out_planes) 58 | elif iw == 3: 59 | instance_norm_layer = nn.InstanceNorm2d(out_planes, affine=False) 60 | elif iw == 4: 61 | instance_norm_layer = nn.InstanceNorm2d(out_planes, affine=True) 62 | else: 63 | instance_norm_layer = nn.Sequential() 64 | 65 | super(ConvBNReLU, self).__init__( 66 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 67 | norm_layer(out_planes), 68 | nn.ReLU6(inplace=True), 69 | instance_norm_layer 70 | ) 71 | 72 | 73 | def forward(self, x_tuple): 74 | if len(x_tuple) == 2: 75 | w_arr = x_tuple[1] 76 | x = x_tuple[0] 77 | else: 78 | print("error in BN forward path") 79 | return 80 | 81 | for i, module in enumerate(self): 82 | if i == len(self) - 1: 83 | if self.iw >= 1: 84 | if self.iw == 1 or self.iw == 2: 85 | x, w = self.instance_norm_layer(x) 86 | w_arr.append(w) 87 | else: 88 | x = self.instance_norm_layer(x) 89 | else: 90 | x = module(x) 91 | 92 | return [x, w_arr] 93 | 94 | 95 | class InvertedResidual(nn.Module): 96 | def __init__( 97 | self, 98 | inp: int, 99 | oup: int, 100 | stride: int, 101 | expand_ratio: int, 102 | norm_layer: Optional[Callable[..., nn.Module]] = None, 103 | iw: int = 0, 104 | ) -> None: 105 | super(InvertedResidual, self).__init__() 106 | self.stride = stride 107 | assert stride in [1, 2] 108 | if norm_layer is None: 109 | norm_layer = nn.BatchNorm2d 110 | self.expand_ratio = expand_ratio 111 | self.iw = iw 112 | 113 | if iw == 1: 114 | self.instance_norm_layer = InstanceWhitening(oup) 115 | elif iw == 2: 116 | self.instance_norm_layer = InstanceWhitening(oup) 117 | elif iw == 3: 118 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=False) 119 | elif iw == 4: 120 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=True) 121 | else: 122 | self.instance_norm_layer = nn.Sequential() 123 | 124 | hidden_dim = int(round(inp * expand_ratio)) 125 | self.use_res_connect = self.stride == 1 and inp == oup 126 | 127 | layers: List[nn.Module] = [] 128 | if expand_ratio != 1: 129 | # pw 130 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer)) 131 | layers.extend([ 132 | # dw 133 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer), 134 | # pw-linear 135 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 136 | norm_layer(oup), 137 | ]) 138 | self.conv = nn.Sequential(*layers) 139 | 140 | 141 | def forward(self, x_tuple): 142 | if len(x_tuple) == 2: 143 | x = x_tuple[0] 144 | else: 145 | print("error in invert residual forward path") 146 | return 147 | if self.expand_ratio != 1: 148 | x_tuple = self.conv[0](x_tuple) 149 | x_tuple = self.conv[1](x_tuple) 150 | conv_x = x_tuple[0] 151 | w_arr = x_tuple[1] 152 | conv_x = self.conv[2](conv_x) 153 | conv_x = self.conv[3](conv_x) 154 | else: 155 | x_tuple = self.conv[0](x_tuple) 156 | conv_x = x_tuple[0] 157 | w_arr = x_tuple[1] 158 | conv_x = self.conv[1](conv_x) 159 | 
conv_x = self.conv[2](conv_x) 160 | 161 | if self.use_res_connect: 162 | x = x + conv_x 163 | else: 164 | x = conv_x 165 | 166 | if self.iw >= 1: 167 | if self.iw == 1 or self.iw == 2: 168 | x, w = self.instance_norm_layer(x) 169 | w_arr.append(w) 170 | else: 171 | x = self.instance_norm_layer(x) 172 | 173 | return [x, w_arr] 174 | 175 | 176 | class MobileNetV2(nn.Module): 177 | def __init__( 178 | self, 179 | num_classes: int = 1000, 180 | width_mult: float = 1.0, 181 | inverted_residual_setting: Optional[List[List[int]]] = None, 182 | round_nearest: int = 8, 183 | block: Optional[Callable[..., nn.Module]] = None, 184 | norm_layer: Optional[Callable[..., nn.Module]] = None, 185 | iw: list = [0, 0, 0, 0, 0, 0, 0], 186 | ) -> None: 187 | """ 188 | MobileNet V2 main class 189 | Args: 190 | num_classes (int): Number of classes 191 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount 192 | inverted_residual_setting: Network structure 193 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number 194 | Set to 1 to turn off rounding 195 | block: Module specifying inverted residual building block for mobilenet 196 | norm_layer: Module specifying the normalization layer to use 197 | """ 198 | super(MobileNetV2, self).__init__() 199 | 200 | if block is None: 201 | block = InvertedResidual 202 | 203 | if norm_layer is None: 204 | norm_layer = nn.BatchNorm2d 205 | 206 | input_channel = 32 207 | last_channel = 1280 208 | 209 | if inverted_residual_setting is None: 210 | inverted_residual_setting = [ 211 | # t, c, n, s 212 | [1, 16, 1, 1], # feature 1 213 | [6, 24, 2, 2], # feature 2, 3 214 | [6, 32, 3, 2], # feature 4, 5, 6 215 | [6, 64, 4, 2], # feature 7, 8, 9, 10 216 | [6, 96, 3, 1], # feature 11, 12, 13 217 | [6, 160, 3, 2], # feature 14, 15, 16 218 | [6, 320, 1, 1], # feature 17 219 | ] 220 | 221 | # only check the first element, assuming user knows t,c,n,s are required 222 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: 223 | raise ValueError("inverted_residual_setting should be non-empty " 224 | "or a 4-element list, got {}".format(inverted_residual_setting)) 225 | 226 | # building first layer 227 | input_channel = _make_divisible(input_channel * width_mult, round_nearest) 228 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) 229 | # feature 0 230 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)] 231 | # building inverted residual blocks 232 | feature_count = 0 233 | iw_layer = [1, 6, 10, 17, 18] 234 | for t, c, n, s in inverted_residual_setting: 235 | output_channel = _make_divisible(c * width_mult, round_nearest) 236 | for i in range(n): 237 | feature_count += 1 238 | stride = s if i == 0 else 1 239 | if feature_count in iw_layer: 240 | layer = iw_layer.index(feature_count) 241 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, iw=iw[layer + 2])) 242 | else: 243 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, iw=0)) 244 | input_channel = output_channel 245 | # building last several layers 246 | # feature 18 247 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer)) 248 | # make it nn.Sequential 249 | self.features = nn.Sequential(*features) 250 | 251 | # building classifier 252 | self.classifier = nn.Sequential( 253 | nn.Dropout(0.2), 254 | 
nn.Linear(self.last_channel, num_classes), 255 | ) 256 | 257 | # weight initialization 258 | for m in self.modules(): 259 | if isinstance(m, nn.Conv2d): 260 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 261 | if m.bias is not None: 262 | nn.init.zeros_(m.bias) 263 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 264 | nn.init.ones_(m.weight) 265 | nn.init.zeros_(m.bias) 266 | elif isinstance(m, nn.Linear): 267 | nn.init.normal_(m.weight, 0, 0.01) 268 | nn.init.zeros_(m.bias) 269 | 270 | def _forward_impl(self, x: Tensor) -> Tensor: 271 | # This exists since TorchScript doesn't support inheritance, so the superclass method 272 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass 273 | x = self.features(x) 274 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0] 275 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1) 276 | x = self.classifier(x) 277 | return x 278 | 279 | def forward(self, x: Tensor) -> Tensor: 280 | return self._forward_impl(x) 281 | 282 | 283 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2: 284 | """ 285 | Constructs a MobileNetV2 architecture from 286 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" `_. 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | progress (bool): If True, displays a progress bar of the download to stderr 290 | """ 291 | model = MobileNetV2(**kwargs) 292 | if pretrained: 293 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 294 | progress=progress) 295 | #model.load_state_dict(state_dict) 296 | forgiving_state_restore(model, state_dict) 297 | return model 298 | -------------------------------------------------------------------------------- /network/Shufflenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 4 | from network.instance_whitening import InstanceWhitening 5 | import network.mynn as mynn 6 | 7 | 8 | __all__ = [ 9 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0', 10 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0' 11 | ] 12 | 13 | model_urls = { 14 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth', 15 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth', 16 | 'shufflenetv2_x1.5': None, 17 | 'shufflenetv2_x2.0': None, 18 | } 19 | 20 | 21 | def channel_shuffle(x, groups): 22 | # type: (torch.Tensor, int) -> torch.Tensor 23 | batchsize, num_channels, height, width = x.data.size() 24 | channels_per_group = num_channels // groups 25 | 26 | # reshape 27 | x = x.view(batchsize, groups, 28 | channels_per_group, height, width) 29 | 30 | x = torch.transpose(x, 1, 2).contiguous() 31 | 32 | # flatten 33 | x = x.view(batchsize, -1, height, width) 34 | 35 | return x 36 | 37 | 38 | class InvertedResidual(nn.Module): 39 | def __init__(self, inp, oup, stride, iw=0): 40 | super(InvertedResidual, self).__init__() 41 | 42 | if not (1 <= stride <= 3): 43 | raise ValueError('illegal stride value') 44 | self.stride = stride 45 | 46 | branch_features = oup // 2 47 | assert (self.stride != 1) or (inp == branch_features << 1) 48 | 49 | if self.stride > 1: 50 | self.branch1 = nn.Sequential( 51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1), 52 | nn.BatchNorm2d(inp), 53 | 
nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 54 | nn.BatchNorm2d(branch_features), 55 | nn.ReLU(inplace=True), 56 | ) 57 | else: 58 | self.branch1 = nn.Sequential() 59 | 60 | self.branch2 = nn.Sequential( 61 | nn.Conv2d(inp if (self.stride > 1) else branch_features, 62 | branch_features, kernel_size=1, stride=1, padding=0, bias=False), 63 | nn.BatchNorm2d(branch_features), 64 | nn.ReLU(inplace=True), 65 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1), 66 | nn.BatchNorm2d(branch_features), 67 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False), 68 | nn.BatchNorm2d(branch_features), 69 | nn.ReLU(inplace=True), 70 | ) 71 | self.iw = iw 72 | if iw == 1: 73 | self.instance_norm_layer = InstanceWhitening(oup) 74 | elif iw == 2: 75 | self.instance_norm_layer = InstanceWhitening(oup) 76 | elif iw == 3: 77 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=False) 78 | elif iw == 4: 79 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=True) 80 | else: 81 | self.instance_norm_layer = nn.Sequential() 82 | 83 | 84 | @staticmethod 85 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False): 86 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i) 87 | 88 | def forward(self, x_tuple): 89 | if len(x_tuple) == 2: 90 | w_arr = x_tuple[1] 91 | x = x_tuple[0] 92 | else: 93 | print("error in invert residual forward path") 94 | return 95 | 96 | if self.stride == 1: 97 | x1, x2 = x.chunk(2, dim=1) 98 | out = torch.cat((x1, self.branch2(x2)), dim=1) 99 | else: 100 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1) 101 | 102 | out = channel_shuffle(out, 2) 103 | 104 | if self.iw >= 1: 105 | if self.iw == 1 or self.iw == 2: 106 | out, w = self.instance_norm_layer(out) 107 | w_arr.append(w) 108 | else: 109 | out = self.instance_norm_layer(out) 110 | return [out, w_arr] 111 | 112 | 113 | class ShuffleNetV2(nn.Module): 114 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual, 115 | iw=[0, 0, 0, 0, 0, 0, 0]): 116 | super(ShuffleNetV2, self).__init__() 117 | 118 | if len(stages_repeats) != 3: 119 | raise ValueError('expected stages_repeats as list of 3 positive ints') 120 | if len(stages_out_channels) != 5: 121 | raise ValueError('expected stages_out_channels as list of 5 positive ints') 122 | self._stage_out_channels = stages_out_channels 123 | 124 | input_channels = 3 125 | output_channels = self._stage_out_channels[0] 126 | self.conv1 = nn.Sequential( 127 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False), 128 | nn.BatchNorm2d(output_channels), 129 | nn.ReLU(inplace=True), 130 | ) 131 | 132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 133 | 134 | iw_count = 2 135 | 136 | if iw[iw_count] == 1: 137 | self.instance_norm_layer1 = InstanceWhitening(output_channels) 138 | elif iw[iw_count] == 2: 139 | self.instance_norm_layer1 = InstanceWhitening(output_channels) 140 | elif iw[iw_count] == 3: 141 | self.instance_norm_layer1 = nn.InstanceNorm2d(output_channels, affine=False) 142 | elif iw[iw_count] == 4: 143 | self.instance_norm_layer1 = nn.InstanceNorm2d(output_channels, affine=True) 144 | else: 145 | self.instance_norm_layer1 = nn.Sequential() 146 | 147 | iw_count += 1 148 | 149 | input_channels = output_channels 150 | 151 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]] 152 | for name, repeats, output_channels in zip( 153 | 
stage_names, stages_repeats, self._stage_out_channels[1:]): 154 | seq = [inverted_residual(input_channels, output_channels, 2)] 155 | for i in range(repeats - 1): 156 | if i == repeats - 2: 157 | seq.append(inverted_residual(output_channels, output_channels, 1, iw=iw[iw_count])) 158 | iw_count += 1 159 | else: 160 | seq.append(inverted_residual(output_channels, output_channels, 1, iw=0)) 161 | setattr(self, name, nn.Sequential(*seq)) 162 | input_channels = output_channels 163 | 164 | output_channels = self._stage_out_channels[-1] 165 | 166 | self.conv5 = nn.Sequential( 167 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False), 168 | nn.BatchNorm2d(output_channels), 169 | nn.ReLU(inplace=True), 170 | ) 171 | if iw[iw_count] == 1: 172 | self.instance_norm_layer2 = InstanceWhitening(output_channels) 173 | elif iw[iw_count] == 2: 174 | self.instance_norm_layer2 = InstanceWhitening(output_channels) 175 | elif iw[iw_count] == 3: 176 | self.instance_norm_layer2 = nn.InstanceNorm2d(output_channels, affine=False) 177 | elif iw[iw_count] == 4: 178 | self.instance_norm_layer2 = nn.InstanceNorm2d(output_channels, affine=True) 179 | else: 180 | self.instance_norm_layer2 = nn.Sequential() 181 | 182 | self.fc = nn.Linear(output_channels, num_classes) 183 | 184 | 185 | def _forward_impl(self, x): 186 | # See note [TorchScript super()] 187 | """ 188 | x = self.conv1(x) 189 | x = self.maxpool(x) 190 | """ 191 | x = self.layer0(x) 192 | x = self.stage2(x) 193 | x = self.stage3(x) 194 | x = self.stage4(x) 195 | x = self.layer4(x) 196 | x = x.mean([2, 3]) # globalpool 197 | x = self.fc(x) 198 | return x 199 | 200 | def forward(self, x): 201 | return self._forward_impl(x) 202 | 203 | 204 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs): 205 | model = ShuffleNetV2(*args, **kwargs) 206 | 207 | if pretrained: 208 | model_url = model_urls[arch] 209 | if model_url is None: 210 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch)) 211 | else: 212 | state_dict = load_state_dict_from_url(model_url, progress=progress) 213 | mynn.forgiving_state_restore(model, state_dict) 214 | ### model.load_state_dict(state_dict) 215 | 216 | return model 217 | 218 | 219 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs): 220 | """ 221 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in 222 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 223 | `_. 224 | 225 | Args: 226 | pretrained (bool): If True, returns a model pre-trained on ImageNet 227 | progress (bool): If True, displays a progress bar of the download to stderr 228 | """ 229 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress, 230 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs) 231 | 232 | 233 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs): 234 | """ 235 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in 236 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 237 | `_. 
238 | 239 | Args: 240 | pretrained (bool): If True, returns a model pre-trained on ImageNet 241 | progress (bool): If True, displays a progress bar of the download to stderr 242 | """ 243 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress, 244 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs) 245 | 246 | 247 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs): 248 | """ 249 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in 250 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 251 | `_. 252 | 253 | Args: 254 | pretrained (bool): If True, returns a model pre-trained on ImageNet 255 | progress (bool): If True, displays a progress bar of the download to stderr 256 | """ 257 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress, 258 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs) 259 | 260 | 261 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs): 262 | """ 263 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in 264 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design" 265 | `_. 266 | 267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | progress (bool): If True, displays a progress bar of the download to stderr 270 | """ 271 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress, 272 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs) 273 | -------------------------------------------------------------------------------- /network/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Network Initializations 3 | """ 4 | 5 | import logging 6 | import importlib 7 | import torch 8 | import datasets 9 | 10 | 11 | 12 | def get_net(args, criterion, criterion_aux=None): 13 | """ 14 | Get Network Architecture based on arguments provided 15 | """ 16 | net = get_model(args=args, num_classes=datasets.num_classes, 17 | criterion=criterion, criterion_aux=criterion_aux) 18 | num_params = sum([param.nelement() for param in net.parameters()]) 19 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000)) 20 | 21 | net = net.cuda() 22 | return net 23 | 24 | 25 | def warp_network_in_dataparallel(net, gpuid): 26 | """ 27 | Wrap the network in Dataparallel 28 | """ 29 | # torch.cuda.set_device(gpuid) 30 | # net.cuda(gpuid) 31 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid], find_unused_parameters=True) 32 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True) 33 | return net 34 | 35 | 36 | def get_model(args, num_classes, criterion, criterion_aux=None): 37 | """ 38 | Fetch Network Function Pointer 39 | """ 40 | network = args.arch 41 | module = network[:network.rfind('.')] 42 | model = network[network.rfind('.') + 1:] 43 | mod = importlib.import_module(module) 44 | net_func = getattr(mod, model) 45 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux) 46 | return net 47 | -------------------------------------------------------------------------------- /network/__pycache__/Mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- 
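A note on the network/__init__.py helpers above: get_model splits the --arch string at its last dot into a module path and a builder-function name, imports that module, and calls the builder with num_classes and the training criteria. Below is a minimal sketch of only the resolution step, using a hypothetical architecture name (the actual builder functions are presumably defined in network/deepv3.py):

import importlib

def resolve_arch(arch):
    # arch is e.g. "network.deepv3.DeepR50V3PlusD" (hypothetical example value)
    module_name = arch[:arch.rfind('.')]         # -> "network.deepv3"
    builder_name = arch[arch.rfind('.') + 1:]    # -> "DeepR50V3PlusD"
    mod = importlib.import_module(module_name)   # import the module dynamically
    return getattr(mod, builder_name)            # return the builder callable

get_net then moves the built model to the GPU, and warp_network_in_dataparallel wraps it in torch DistributedDataParallel with find_unused_parameters=True for multi-GPU training.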
/network/__pycache__/Mobilenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Mobilenet.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/Resnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Resnet.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/Resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Resnet.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/Shufflenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Shufflenet.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/Shufflenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Shufflenet.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/cov_settings.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cov_settings.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/cov_settings.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cov_settings.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/cwcl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cwcl.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/cwcl.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cwcl.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/deepv3.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/deepv3.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/deepv3.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/deepv3.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/edge_contrast.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/edge_contrast_batch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_batch.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/edge_contrast_v1_opt.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v1_opt.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/edge_contrast_v1_opt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v1_opt.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/edge_contrast_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v2.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/instance_whitening.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/instance_whitening.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/instance_whitening.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/instance_whitening.cpython-37.pyc -------------------------------------------------------------------------------- 
/network/__pycache__/mynn.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/mynn.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/mynn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/mynn.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/pixel_nce.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/pixel_nce.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/pixel_nce_batch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce_batch.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/sdcl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sdcl.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/sdcl.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sdcl.cpython-37.pyc -------------------------------------------------------------------------------- /network/__pycache__/sync_switchwhiten.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sync_switchwhiten.cpython-36.pyc -------------------------------------------------------------------------------- /network/__pycache__/sync_switchwhiten.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sync_switchwhiten.cpython-37.pyc -------------------------------------------------------------------------------- /network/bn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | if torch.__version__.startswith('0'): 5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync 6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 7 | BatchNorm2d_class = InPlaceABNSync 8 | relu_inplace = False 9 | else: 10 | BatchNorm2d_class = 
BatchNorm2d = torch.nn.SyncBatchNorm 11 | relu_inplace = True -------------------------------------------------------------------------------- /network/cov_settings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from network.mynn import initialize_embedding 4 | import kmeans1d 5 | 6 | 7 | def make_cov_index_matrix(dim): # make symmetric matrix for embedding index 8 | matrix = torch.LongTensor() 9 | s_index = 0 10 | for i in range(dim): 11 | matrix = torch.cat([matrix, torch.arange(s_index, s_index + dim).unsqueeze(0)], dim=0) 12 | s_index += (dim - (2 + i)) 13 | return matrix.triu(diagonal=1).transpose(0, 1) + matrix.triu(diagonal=1) 14 | 15 | 16 | class CovMatrix_ISW: 17 | def __init__(self, dim, relax_denom=0, clusters=50): 18 | super(CovMatrix_ISW, self).__init__() 19 | 20 | self.dim = dim 21 | self.i = torch.eye(dim, dim).cuda() 22 | 23 | # print(torch.ones(16, 16).triu(diagonal=1)) 24 | self.reversal_i = torch.ones(dim, dim).triu(diagonal=1).cuda() 25 | 26 | # num_off_diagonal = ((dim * dim - dim) // 2) # number of off-diagonal 27 | self.num_off_diagonal = torch.sum(self.reversal_i) 28 | self.num_sensitive = 0 29 | self.var_matrix = None 30 | self.count_var_cov = 0 31 | self.mask_matrix = None 32 | self.clusters = clusters 33 | print("num_off_diagonal", self.num_off_diagonal) 34 | if relax_denom == 0: # kmeans1d clustering setting for ISW 35 | print("relax_denom == 0!!!!!") 36 | print("cluster == ", self.clusters) 37 | self.margin = 0 38 | else: # do not use 39 | self.margin = self.num_off_diagonal // relax_denom 40 | 41 | def get_eye_matrix(self): 42 | return self.i, self.reversal_i 43 | 44 | def get_mask_matrix(self, mask=True): 45 | if self.mask_matrix is None: 46 | self.set_mask_matrix() 47 | return self.i, self.mask_matrix, 0, self.num_sensitive 48 | 49 | def reset_mask_matrix(self): 50 | self.mask_matrix = None 51 | 52 | def set_mask_matrix(self): 53 | # torch.set_printoptions(threshold=500000) 54 | self.var_matrix = self.var_matrix / self.count_var_cov 55 | var_flatten = torch.flatten(self.var_matrix) 56 | 57 | if self.margin == 0: # kmeans1d clustering setting for ISW 58 | clusters, centroids = kmeans1d.cluster(var_flatten, self.clusters) # 50 clusters 59 | num_sensitive = var_flatten.size()[0] - clusters.count(0) # 1: Insensitive Cov, 2~50: Sensitive Cov 60 | print("num_sensitive, centroids =", num_sensitive, centroids) 61 | _, indices = torch.topk(var_flatten, k=int(num_sensitive)) 62 | else: # do not use 63 | num_sensitive = self.num_off_diagonal - self.margin 64 | print("num_sensitive = ", num_sensitive) 65 | _, indices = torch.topk(var_flatten, k=int(num_sensitive)) 66 | mask_matrix = torch.flatten(torch.zeros(self.dim, self.dim).cuda()) 67 | mask_matrix[indices] = 1 68 | 69 | if self.mask_matrix is not None: 70 | self.mask_matrix = (self.mask_matrix.int() & mask_matrix.view(self.dim, self.dim).int()).float() 71 | else: 72 | self.mask_matrix = mask_matrix.view(self.dim, self.dim) 73 | self.num_sensitive = torch.sum(self.mask_matrix) 74 | print("Check whether two ints are same", num_sensitive, self.num_sensitive) 75 | 76 | self.var_matrix = None 77 | self.count_var_cov = 0 78 | 79 | if torch.cuda.current_device() == 0: 80 | print("Covariance Info: (CXC Shape, Num_Off_Diagonal)", self.mask_matrix.shape, self.num_off_diagonal) 81 | print("Selective (Sensitive Covariance)", self.num_sensitive) 82 | 83 | 84 | def set_variance_of_covariance(self, var_cov): 85 | if self.var_matrix is None: 86 | 
self.var_matrix = var_cov 87 | else: 88 | self.var_matrix = self.var_matrix + var_cov 89 | self.count_var_cov += 1 90 | 91 | class CovMatrix_IRW: 92 | def __init__(self, dim, relax_denom=0): 93 | super(CovMatrix_IRW, self).__init__() 94 | 95 | self.dim = dim 96 | self.i = torch.eye(dim, dim).cuda() 97 | self.reversal_i = torch.ones(dim, dim).triu(diagonal=1).cuda() 98 | 99 | self.num_off_diagonal = torch.sum(self.reversal_i) 100 | if relax_denom == 0: 101 | print("relax_denom == 0!!!!!") 102 | self.margin = 0 103 | else: 104 | self.margin = self.num_off_diagonal // relax_denom 105 | 106 | def get_mask_matrix(self): 107 | return self.i, self.reversal_i, self.margin, self.num_off_diagonal 108 | -------------------------------------------------------------------------------- /network/cwcl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any 3 | from packaging import version 4 | 5 | from abc import ABC 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | 12 | class Class_PixelNCELoss(nn.Module): 13 | def __init__(self, args): 14 | super(Class_PixelNCELoss, self).__init__() 15 | 16 | self.args = args 17 | self.ignore_label = 255 18 | 19 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none') 20 | 21 | self.max_classes = self.args.contrast_max_classes 22 | self.max_views = self.args.contrast_max_views 23 | 24 | # reshape label or prediction 25 | def resize_label(self, labels, HW): 26 | labels = labels.unsqueeze(1).float().clone() 27 | labels = torch.nn.functional.interpolate(labels, 28 | HW, mode='nearest') 29 | labels = labels.squeeze(1).long() 30 | 31 | return labels 32 | 33 | def _hard_anchor_sampling(self, X_q, X_k, y_hat, y): 34 | # X : Feature map, shape:(B, h*w, C), y_hat : label, shape:(B, h*w), y : prediction, shape:(B, H*W?) 
35 | batch_size, feat_dim = X_q.shape[0], X_q.shape[-1] 36 | 37 | classes = [] 38 | num_classes = [] 39 | total_classes = 0 40 | # collect the classes present in each image's label within the batch 41 | for ii in range(batch_size): 42 | this_y = y_hat[ii] 43 | this_classes = torch.unique(this_y) # unique class ids present in this label map 44 | this_classes = [x for x in this_classes if x != self.ignore_label] # drop the ignore label 45 | this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views] # keep only classes with more than max_views pixels 46 | 47 | classes.append(this_classes) 48 | 49 | total_classes += len(this_classes) 50 | num_classes.append(len(this_classes)) 51 | 52 | # return none if there is no class in the image 53 | if total_classes == 0: 54 | return None, None, None 55 | 56 | n_view = self.max_views 57 | 58 | # output tensors 59 | X_q_ = torch.zeros((batch_size, self.max_classes, n_view, feat_dim), dtype=torch.float).cuda() 60 | X_k_ = torch.zeros((batch_size, self.max_classes, n_view, feat_dim), dtype=torch.float).cuda() 61 | 62 | for ii in range(batch_size): 63 | this_y_hat = y_hat[ii] 64 | this_y = y[ii] 65 | this_classes = classes[ii] 66 | this_indices = [] 67 | 68 | # if there is no class in the image, randomly sample patches 69 | if len(this_classes) == 0: 70 | indices = torch.arange(X_q.shape[1], device=X_q.device) 71 | perm = torch.randperm(X_q.shape[1], device=X_q.device) 72 | indices = indices[perm[:n_view * self.max_classes]] 73 | indices = indices.view(self.max_classes, -1) 74 | 75 | X_q_[ii, :, :, :] = X_q[ii, indices, :] 76 | X_k_[ii, :, :, :] = X_k[ii, indices, :] 77 | 78 | continue 79 | 80 | # reference : https://github.com/tfzhou/ContrastiveSeg/tree/main 81 | for n, cls_id in enumerate(this_classes): 82 | 83 | if n == self.max_classes: 84 | break 85 | 86 | # sample hard patches (wrong prediction) and easy patches (correct prediction) 87 | hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero() 88 | easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero() 89 | 90 | num_hard = hard_indices.shape[0] 91 | num_easy = easy_indices.shape[0] 92 | 93 | if num_hard >= n_view / 2 and num_easy >= n_view / 2: 94 | num_hard_keep = n_view // 2 95 | num_easy_keep = n_view - num_hard_keep 96 | elif num_hard >= n_view / 2: 97 | num_easy_keep = num_easy 98 | num_hard_keep = n_view - num_easy_keep 99 | elif num_easy >= n_view / 2: 100 | num_hard_keep = num_hard 101 | num_easy_keep = n_view - num_hard_keep 102 | 103 | perm = torch.randperm(num_hard) 104 | hard_indices = hard_indices[perm[:num_hard_keep]] 105 | perm = torch.randperm(num_easy) 106 | easy_indices = easy_indices[perm[:num_easy_keep]] 107 | indices = torch.cat((hard_indices, easy_indices), dim=0) 108 | 109 | X_q_[ii, n, :, :] = X_q[ii, indices, :].squeeze(1) 110 | X_k_[ii, n, :, :] = X_k[ii, indices, :].squeeze(1) 111 | 112 | this_indices.append(indices) 113 | 114 | # fill the spare space with random patches 115 | if len(this_classes) < self.max_classes: 116 | this_indices = torch.stack(this_indices) 117 | this_indices = this_indices.flatten(0, 1) 118 | 119 | num_remain = self.max_classes - len(this_classes) 120 | all_indices = torch.arange(X_q.shape[1], device=X_q[0].device) 121 | left_indices = torch.zeros(X_q.shape[1], device=X_q[0].device, dtype=torch.bool) # bool mask of locations already sampled above 122 | left_indices[this_indices] = 1 123 | left_indices = all_indices[~left_indices] 124 | 125 | perm = torch.randperm(len(left_indices), device=X_q[0].device) 126 | 127 | indices = left_indices[perm[:n_view * num_remain]] 128 | indices =
indices.view(num_remain, -1) 129 | 130 | X_q_[ii, n + 1:, :, :] = X_q[ii, indices, :] 131 | X_k_[ii, n + 1:, :, :] = X_k[ii, indices, :] 132 | 133 | return X_q_, X_k_, num_classes 134 | 135 | def _contrastive(self, feats_q_, feats_k_): 136 | # feats shape : (B, nc, N, C) 137 | batch_size, num_classes, n_view, patch_dim = feats_q_.shape 138 | num_patches = batch_size * num_classes * n_view 139 | 140 | # feats shape : (B*nc*N, 1, C) 141 | feats_q_ = feats_q_.contiguous().view(num_patches, -1, patch_dim) 142 | feats_k_ = feats_k_.contiguous().view(num_patches, -1, patch_dim) 143 | 144 | # logit_positive : same positive patches between key and query 145 | # shape : (B * nc * N , 1) 146 | l_pos = torch.bmm( 147 | feats_q_, feats_k_.transpose(2, 1) 148 | ) 149 | l_pos =l_pos.view(num_patches, 1) 150 | 151 | # feats shape : (B, nc*N, C) 152 | feats_q_ = feats_q_.contiguous().view(batch_size, -1, patch_dim) 153 | feats_k_ = feats_k_.contiguous().view(batch_size, -1, patch_dim) 154 | n_patches = feats_q_.shape[1] 155 | 156 | # logit negative shape : (B, nc*N, nc*N) 157 | l_neg_curbatch = torch.bmm(feats_q_, feats_k_.transpose(2, 1)) 158 | 159 | # exclude same class patches 160 | diag_block= torch.zeros((batch_size, n_patches, n_patches), device=feats_q_.device, dtype=torch.uint8) 161 | for i in range(num_classes): 162 | diag_block[:, i*n_view:(i+1)*n_view, i*n_view:(i+1)*n_view] = 1 163 | 164 | l_neg_curbatch = l_neg_curbatch[~diag_block].view(batch_size, n_patches, -1) 165 | 166 | # logit negative shape : (B*nc*N, nc*(N-1)) 167 | l_neg = l_neg_curbatch.view(num_patches, -1) 168 | 169 | out = torch.cat([l_pos, l_neg], dim=1) / self.args.nce_T 170 | 171 | loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, device=feats_q_.device)) 172 | 173 | return loss 174 | 175 | def forward(self, feats_q, feats_k, labels=None, predict=None): 176 | B, C, H, W = feats_q.shape 177 | 178 | # resize label and prediction 179 | labels = self.resize_label(labels, (H, W)) 180 | predict = self.resize_label(predict, (H, W)) 181 | 182 | labels = labels.contiguous().view(B, -1) 183 | predict = predict.contiguous().view(B, -1) 184 | 185 | # change axis 186 | feats_q = feats_q.permute(0, 2, 3, 1) 187 | feats_q = feats_q.contiguous().view(feats_q.shape[0], -1, feats_q.shape[-1]) 188 | 189 | feats_k = feats_k.detach() 190 | 191 | feats_k = feats_k.permute(0, 2, 3, 1) 192 | feats_k = feats_k.contiguous().view(feats_k.shape[0], -1, feats_k.shape[-1]) 193 | 194 | # sample patches 195 | feats_q_, feats_k_, num_classes = self._hard_anchor_sampling(feats_q, feats_k, labels, predict) 196 | 197 | if feats_q_ is None: 198 | loss = torch.FloatTensor([0]).cuda() 199 | return loss 200 | 201 | loss = self._contrastive(feats_q_, feats_k_) 202 | 203 | del labels 204 | 205 | return loss 206 | 207 | class Normalize(nn.Module): 208 | def __init__(self, power=2): 209 | super(Normalize, self).__init__() 210 | self.power = power 211 | 212 | def forward(self, x): 213 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. 
/ self.power) 214 | out = x.div(norm + 1e-7) 215 | return out 216 | 217 | 218 | class ProjectionHead(nn.Module): 219 | 220 | def __init__(self, dim_in, proj_dim=256): 221 | super(ProjectionHead, self).__init__() 222 | 223 | self.proj = nn.Sequential( 224 | nn.Conv2d(dim_in, dim_in, kernel_size=1), 225 | nn.ReLU(), 226 | #nn.BatchNorm2d(dim_in), 227 | nn.Conv2d(dim_in, proj_dim, kernel_size=1) 228 | ) 229 | self.l2norm = Normalize(2) 230 | 231 | def forward(self, x): 232 | return self.l2norm(self.proj(x)) 233 | #return F.normalize(self.proj(x), p=2, dim=1) 234 | 235 | -------------------------------------------------------------------------------- /network/instance_whitening.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class InstanceWhitening(nn.Module): 7 | 8 | def __init__(self, dim): 9 | super(InstanceWhitening, self).__init__() 10 | self.instance_standardization = nn.InstanceNorm2d(dim, affine=False) 11 | 12 | def forward(self, x): 13 | 14 | x = self.instance_standardization(x) 15 | w = x 16 | 17 | return x, w 18 | 19 | 20 | def get_covariance_matrix(f_map, eye=None): 21 | eps = 1e-5 22 | B, C, H, W = f_map.shape # i-th feature size (B X C X H X W) 23 | HW = H * W 24 | if eye is None: 25 | eye = torch.eye(C).cuda() 26 | f_map = f_map.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W) 27 | f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(HW-1) + (eps * eye) # C X C / HW 28 | 29 | return f_cor, B 30 | 31 | # Calcualate Cross Covarianc of two feature maps 32 | # reference : https://github.com/shachoi/RobustNet 33 | def get_cross_covariance_matrix(f_map1, f_map2, eye=None): 34 | eps = 1e-5 35 | assert f_map1.shape == f_map2.shape 36 | 37 | B, C, H, W = f_map1.shape 38 | HW = H*W 39 | 40 | if eye is None: 41 | eye = torch.eye(C).cuda() 42 | 43 | # feature map shape : (B,C,H,W) -> (B,C,HW) 44 | f_map1 = f_map1.contiguous().view(B, C, -1) 45 | f_map2 = f_map2.contiguous().view(B, C, -1) 46 | 47 | # f_cor shape : (B, C, C) 48 | f_cor = torch.bmm(f_map1, f_map2.transpose(1, 2)).div(HW-1) + (eps * eye) 49 | 50 | return f_cor, B 51 | 52 | def cross_whitening_loss(k_feat, q_feat): 53 | assert k_feat.shape == q_feat.shape 54 | 55 | f_cor, B = get_cross_covariance_matrix(k_feat, q_feat) 56 | diag_loss = torch.FloatTensor([0]).cuda() 57 | 58 | # get diagonal values of covariance matrix 59 | for cor in f_cor: 60 | diag = torch.diagonal(cor.squeeze(dim=0), 0) 61 | eye = torch.ones_like(diag).cuda() 62 | diag_loss = diag_loss + F.mse_loss(diag, eye) 63 | diag_loss = diag_loss / B 64 | 65 | return diag_loss 66 | -------------------------------------------------------------------------------- /network/mynn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization 3 | """ 4 | import torch.nn as nn 5 | import torch 6 | from config import cfg 7 | 8 | def Norm2d(in_channels): 9 | """ 10 | Custom Norm Function to allow flexible switching 11 | """ 12 | layer = getattr(cfg.MODEL, 'BNFUNC') 13 | normalization_layer = layer(in_channels) 14 | return normalization_layer 15 | 16 | 17 | def freeze_weights(*models): 18 | for model in models: 19 | for k in model.parameters(): 20 | k.requires_grad = False 21 | 22 | def unfreeze_weights(*models): 23 | for model in models: 24 | for k in model.parameters(): 25 | k.requires_grad = True 26 | 27 | def 
initialize_weights(*models): 28 | """ 29 | Initialize Model Weights 30 | """ 31 | for model in models: 32 | for module in model.modules(): 33 | if isinstance(module, (nn.Conv2d, nn.Linear)): 34 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 35 | if module.bias is not None: 36 | module.bias.data.zero_() 37 | elif isinstance(module, nn.Conv1d): 38 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu') 39 | if module.bias is not None: 40 | module.bias.data.zero_() 41 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \ 42 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm): 43 | module.weight.data.fill_(1) 44 | module.bias.data.zero_() 45 | 46 | def initialize_embedding(*models): 47 | """ 48 | Initialize Model Weights 49 | """ 50 | for model in models: 51 | for module in model.modules(): 52 | if isinstance(module, nn.Embedding): 53 | module.weight.data.zero_() #original 54 | 55 | 56 | 57 | def Upsample(x, size): 58 | """ 59 | Wrapper Around the Upsample Call 60 | """ 61 | return nn.functional.interpolate(x, size=size, mode='bilinear', 62 | align_corners=True) 63 | 64 | def forgiving_state_restore(net, loaded_dict): 65 | """ 66 | Handle partial loading when some tensors don't match up in size. 67 | Because we want to use models that were trained off a different 68 | number of classes. 69 | """ 70 | net_state_dict = net.state_dict() 71 | new_loaded_dict = {} 72 | for k in net_state_dict: 73 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 74 | new_loaded_dict[k] = loaded_dict[k] 75 | else: 76 | print("Skipped loading parameter", k) 77 | # logging.info("Skipped loading parameter %s", k) 78 | net_state_dict.update(new_loaded_dict) 79 | net.load_state_dict(net_state_dict) 80 | return net 81 | -------------------------------------------------------------------------------- /network/sdcl.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | from packaging import version 3 | 4 | from abc import ABC 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from network.mynn import initialize_weights 11 | import numpy as np 12 | import time 13 | 14 | 15 | class Disentangle_Contrast(nn.Module): 16 | def __init__(self, args): 17 | super(Disentangle_Contrast, self).__init__() 18 | 19 | self.args = args 20 | self.temperature = 0.1 21 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none') 22 | 23 | self.num_patch = self.args.num_patch 24 | self.num_classes = 19 25 | self.max_samples = 1000 26 | 27 | 28 | # reshape label or prediction 29 | def reshape_map(self, map, shape): 30 | 31 | map = map.unsqueeze(1).float().clone() 32 | map = torch.nn.functional.interpolate(map, shape, mode='nearest') 33 | map = map.squeeze(1).long() 34 | 35 | return map 36 | 37 | 38 | def _contrastive(self, pos_q, pos_k, neg): 39 | num_patch, _, patch_dim = pos_q.shape 40 | 41 | # l_pos shape : (num_patch, 1) 42 | l_pos = torch.bmm(pos_q, pos_k.transpose(2, 1)) 43 | l_pos = l_pos.view(num_patch, 1) 44 | 45 | # l_neg shape : (num_patch, negative_size) 46 | l_neg = torch.bmm(pos_q, neg.transpose(2, 1)) 47 | l_neg = l_neg.view(num_patch, -1) 48 | 49 | out = torch.cat([l_pos, l_neg], dim=1) / self.args.nce_T 50 | 51 | loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, device=pos_q.device)) 52 | 53 | return loss 54 | 55 | 56 | def Disentangle_Sampler(self, correct_maps, feats_q, feats_k, predicts, 
predicts_j, labels): 57 | B, HW, C = feats_q.shape 58 | start_ts = time.time() 59 | 60 | # X_pos_q : anchors, X_pos_k : positives, X_neg : negatives 61 | X_pos_q = [] 62 | X_pos_k = [] 63 | X_neg = [] 64 | 65 | for ii in range(B): 66 | img_sample_ts = time.time() 67 | M = correct_maps[ii] 68 | # indices : wrong prediction location in query features 69 | indices = (M == 1).nonzero() 70 | 71 | classes_labels = torch.unique(labels[ii]) 72 | classes_wrong = torch.unique(predicts_j[ii, indices]) 73 | 74 | pos_indices = [] 75 | neg_indices = [] 76 | 77 | # sample anchor, positive, negative for each wrong class 78 | for cls_id in classes_wrong: 79 | sampling_time = time.time() 80 | # cls_indices : anchor, positive indices 81 | cls_indices = ((M == 1) & (predicts_j[ii] == cls_id)).nonzero() 82 | 83 | # pass if wrong class doesn't exist in the image 84 | if cls_id not in classes_labels: 85 | continue 86 | else: 87 | neg_cls_indices = (labels[ii] == cls_id).nonzero() 88 | 89 | if neg_cls_indices.size(0) < self.num_patch: 90 | continue 91 | 92 | neg_sampled_indices = [neg_cls_indices[torch.randperm(neg_cls_indices.size(0))[:self.num_patch]].squeeze()] * cls_indices.size(0) 93 | neg_sampled_indices = torch.cat(neg_sampled_indices, dim=0) 94 | 95 | pos_indices.append(cls_indices) 96 | neg_indices.append(neg_sampled_indices) 97 | 98 | if not pos_indices: 99 | continue 100 | pos_indices = torch.cat(pos_indices, dim=0) 101 | neg_indices = torch.cat(neg_indices, dim=0) 102 | 103 | # anchor from query feature 104 | X_pos_q.append(feats_q[ii, pos_indices, :]) 105 | # positive from key feature 106 | X_pos_k.append(feats_k[ii, pos_indices, :]) 107 | # Negative from query feature 108 | X_neg.append(feats_q[ii, neg_indices, :].view(pos_indices.size(0), self.num_patch, C)) 109 | 110 | if not X_pos_q: 111 | return None, None, None 112 | # X_pos_q, X_pos_k shape : (num_patch, 1, C) 113 | # X_neg shape : (num_patch, negative_size, C) 114 | X_pos_q = torch.cat(X_pos_q, dim=0) 115 | X_pos_k = torch.cat(X_pos_k, dim=0) 116 | X_neg = torch.cat(X_neg, dim=0) 117 | 118 | if X_pos_q.shape[0] > B * self.max_samples: 119 | indices = torch.randperm(X_pos_q.size(0))[:B*self.max_samples] 120 | X_pos_q = X_pos_q[indices, :, :] 121 | X_pos_k = X_pos_k[indices, :, :] 122 | X_neg = X_neg[indices, :, :] 123 | 124 | return X_pos_q, X_pos_k, X_neg 125 | 126 | 127 | def forward(self, feats_q, feats_k, predicts, predicts_j, labels): 128 | B, C, H, W = feats_q.shape 129 | 130 | # reshape the labels and predictions to feature map's size 131 | labels = self.reshape_map(labels, (H, W)) 132 | predicts = self.reshape_map(predicts, (H, W)) 133 | predicts_j = self.reshape_map(predicts_j, (H, W)) 134 | 135 | # calculate Correction map 136 | correct_maps = torch.ones_like(predicts, device=feats_q[0].device) 137 | correct_maps[predicts == predicts_j] = 0 138 | correct_maps[labels == 255] = 0 139 | correct_maps[predicts != labels] = 0 140 | correct_maps = correct_maps.flatten(1, 2) 141 | 142 | predicts = predicts.flatten(1, 2) 143 | predicts_j = predicts_j.flatten(1, 2) 144 | labels = labels.flatten(1, 2) 145 | 146 | feats_k = feats_k.detach() 147 | 148 | feats_q_reshape = feats_q.permute(0, 2, 3, 1).flatten(1, 2) 149 | feats_k_reshape = feats_k.permute(0, 2, 3, 1).flatten(1, 2) 150 | 151 | # Sample the anchor and positives, negatives 152 | patches_q, patches_k, patches_neg = self.Disentangle_Sampler(correct_maps, feats_q_reshape, feats_k_reshape, 153 | predicts, predicts_j, labels) 154 | 155 | if patches_q is None: 156 | loss = 
torch.FloatTensor([0]).cuda() 157 | return loss 158 | 159 | loss = self._contrastive(patches_q, patches_k, patches_neg) 160 | 161 | return loss 162 | 163 | -------------------------------------------------------------------------------- /network/switchwhiten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.modules.module import Module 5 | 6 | 7 | class SwitchWhiten2d(Module): 8 | """Switchable Whitening. 9 | 10 | Args: 11 | num_features (int): Number of channels. 12 | num_pergroup (int): Number of channels for each whitening group. 13 | sw_type (int): Switchable whitening type, from {2, 3, 5}. 14 | sw_type = 2: BW + IW 15 | sw_type = 3: BW + IW + LN 16 | sw_type = 5: BW + IW + BN + IN + LN 17 | T (int): Number of iterations for iterative whitening. 18 | tie_weight (bool): Use the same importance weight for mean and 19 | covariance or not. 20 | """ 21 | 22 | def __init__(self, 23 | num_features, 24 | num_pergroup=16, 25 | sw_type=2, 26 | T=5, 27 | tie_weight=False, 28 | eps=1e-5, 29 | momentum=0.99, 30 | affine=True): 31 | super(SwitchWhiten2d, self).__init__() 32 | if sw_type not in [2, 3, 5]: 33 | raise ValueError('sw_type should be in [2, 3, 5], ' 34 | 'but got {}'.format(sw_type)) 35 | assert num_features % num_pergroup == 0 36 | self.num_features = num_features 37 | self.num_pergroup = num_pergroup 38 | self.num_groups = num_features // num_pergroup 39 | self.sw_type = sw_type 40 | self.T = T 41 | self.tie_weight = tie_weight 42 | self.eps = eps 43 | self.momentum = momentum 44 | self.affine = affine 45 | num_components = sw_type 46 | 47 | self.sw_mean_weight = Parameter(torch.ones(num_components)) 48 | if not self.tie_weight: 49 | self.sw_var_weight = Parameter(torch.ones(num_components)) 50 | else: 51 | self.register_parameter('sw_var_weight', None) 52 | 53 | if self.affine: 54 | self.weight = Parameter(torch.ones(num_features)) 55 | self.bias = Parameter(torch.zeros(num_features)) 56 | else: 57 | self.register_parameter('weight', None) 58 | self.register_parameter('bias', None) 59 | 60 | self.register_buffer('running_mean', 61 | torch.zeros(self.num_groups, num_pergroup, 1)) 62 | self.register_buffer( 63 | 'running_cov', 64 | torch.eye(num_pergroup).unsqueeze(0).repeat(self.num_groups, 1, 1)) 65 | 66 | self.reset_parameters() 67 | 68 | def reset_parameters(self): 69 | self.running_mean.zero_() 70 | self.running_cov.zero_() 71 | nn.init.ones_(self.sw_mean_weight) 72 | if not self.tie_weight: 73 | nn.init.ones_(self.sw_var_weight) 74 | if self.affine: 75 | nn.init.ones_(self.weight) 76 | nn.init.zeros_(self.bias) 77 | 78 | def __repr__(self): 79 | return ('{name}({num_features}, num_pergroup={num_pergroup}, ' 80 | 'sw_type={sw_type}, T={T}, tie_weight={tie_weight}, ' 81 | 'eps={eps}, momentum={momentum}, affine={affine})'.format( 82 | name=self.__class__.__name__, **self.__dict__)) 83 | 84 | def forward(self, x): 85 | N, C, H, W = x.size() 86 | c, g = self.num_pergroup, self.num_groups 87 | 88 | in_data_t = x.transpose(0, 1).contiguous() 89 | # g x c x (N x H x W) 90 | in_data_t = in_data_t.view(g, c, -1) 91 | 92 | # calculate batch mean and covariance 93 | if self.training: 94 | # g x c x 1 95 | mean_bn = in_data_t.mean(-1, keepdim=True) 96 | in_data_bn = in_data_t - mean_bn 97 | # g x c x c 98 | cov_bn = torch.bmm(in_data_bn, 99 | in_data_bn.transpose(1, 2)).div(H * W * N) 100 | 101 | self.running_mean.mul_(self.momentum) 102 | 
self.running_mean.add_((1 - self.momentum) * mean_bn.data) 103 | self.running_cov.mul_(self.momentum) 104 | self.running_cov.add_((1 - self.momentum) * cov_bn.data) 105 | else: 106 | mean_bn = torch.autograd.Variable(self.running_mean) 107 | cov_bn = torch.autograd.Variable(self.running_cov) 108 | 109 | mean_bn = mean_bn.view(1, g, c, 1).expand(N, g, c, 1).contiguous() 110 | mean_bn = mean_bn.view(N * g, c, 1) 111 | cov_bn = cov_bn.view(1, g, c, c).expand(N, g, c, c).contiguous() 112 | cov_bn = cov_bn.view(N * g, c, c) 113 | 114 | # (N x g) x c x (H x W) 115 | in_data = x.view(N * g, c, -1) 116 | 117 | eye = in_data.data.new().resize_(c, c) 118 | eye = torch.nn.init.eye_(eye).view(1, c, c).expand(N * g, c, c) 119 | 120 | # calculate other statistics 121 | # (N x g) x c x 1 122 | mean_in = in_data.mean(-1, keepdim=True) 123 | x_in = in_data - mean_in 124 | # (N x g) x c x c 125 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W) 126 | if self.sw_type in [3, 5]: 127 | x = x.view(N, -1) 128 | mean_ln = x.mean(-1, keepdim=True).view(N, 1, 1, 1) 129 | mean_ln = mean_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1) 130 | var_ln = x.var(-1, keepdim=True).view(N, 1, 1, 1) 131 | var_ln = var_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1) 132 | var_ln = var_ln * eye 133 | if self.sw_type == 5: 134 | var_bn = torch.diag_embed(torch.diagonal(cov_bn, dim1=-2, dim2=-1)) 135 | var_in = torch.diag_embed(torch.diagonal(cov_in, dim1=-2, dim2=-1)) 136 | 137 | # calculate weighted average of mean and covariance 138 | softmax = nn.Softmax(0) 139 | mean_weight = softmax(self.sw_mean_weight) 140 | if not self.tie_weight: 141 | var_weight = softmax(self.sw_var_weight) 142 | else: 143 | var_weight = mean_weight 144 | 145 | # BW + IW 146 | if self.sw_type == 2: 147 | # (N x g) x c x 1 148 | mean = mean_weight[0] * mean_bn + mean_weight[1] * mean_in 149 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 150 | self.eps * eye 151 | # BW + IW + LN 152 | elif self.sw_type == 3: 153 | mean = mean_weight[0] * mean_bn + \ 154 | mean_weight[1] * mean_in + mean_weight[2] * mean_ln 155 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 156 | var_weight[2] * var_ln + self.eps * eye 157 | # BW + IW + BN + IN + LN 158 | elif self.sw_type == 5: 159 | mean = (mean_weight[0] + mean_weight[2]) * mean_bn + \ 160 | (mean_weight[1] + mean_weight[3]) * mean_in + \ 161 | mean_weight[4] * mean_ln 162 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 163 | var_weight[0] * var_bn + var_weight[1] * var_in + \ 164 | var_weight[4] * var_ln + self.eps * eye 165 | 166 | # perform whitening using Newton's iteration 167 | Ng, c, _ = cov.size() 168 | P = torch.eye(c).to(cov).expand(Ng, c, c) 169 | # reciprocal of trace of covariance 170 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_() 171 | cov_N = cov * rTr 172 | for k in range(self.T): 173 | P = torch.baddbmm(1.5, P, -0.5, torch.matrix_power(P, 3), cov_N) 174 | # whiten matrix: the matrix inverse of covariance, i.e., cov^{-1/2} 175 | wm = P.mul_(rTr.sqrt()) 176 | 177 | x_hat = torch.bmm(wm, in_data - mean) 178 | x_hat = x_hat.view(N, C, H, W) 179 | if self.affine: 180 | x_hat = x_hat * self.weight.view(1, self.num_features, 1, 1) + \ 181 | self.bias.view(1, self.num_features, 1, 1) 182 | 183 | return x_hat 184 | -------------------------------------------------------------------------------- /network/sync_switchwhiten.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.distributed as dist 3 | import torch.nn as nn 4 | from torch.autograd import Function 5 | from torch.nn.modules.module import Module 6 | from torch.nn.parameter import Parameter 7 | 8 | 9 | class SyncMeanCov(Function): 10 | 11 | @staticmethod 12 | def forward(ctx, in_data, running_mean, running_cov, momentum, training): 13 | g, c, NHW = in_data.size() 14 | ctx.g = g 15 | ctx.c = c 16 | ctx.NHW = NHW 17 | ctx.training = training 18 | 19 | if training: 20 | mean_bn = in_data.mean(-1, keepdim=True) # g x c x 1 21 | dist.all_reduce(mean_bn) 22 | mean_bn /= dist.get_world_size() 23 | in_data_bn = in_data - mean_bn 24 | cov_bn = torch.bmm(in_data_bn, in_data_bn.transpose(1, 2)).div(NHW) 25 | dist.all_reduce(cov_bn) 26 | cov_bn /= dist.get_world_size() 27 | 28 | running_mean.mul_(momentum) 29 | running_mean.add_((1 - momentum) * mean_bn.data) 30 | running_cov.mul_(momentum) 31 | running_cov.add_((1 - momentum) * cov_bn.data) 32 | else: 33 | mean_bn = torch.autograd.Variable(running_mean) 34 | cov_bn = torch.autograd.Variable(running_cov) 35 | 36 | ctx.save_for_backward(in_data.data, mean_bn.data) 37 | return mean_bn, cov_bn 38 | 39 | @staticmethod 40 | def backward(ctx, grad_mean_out, grad_cov_out): 41 | in_data, mean_bn = ctx.saved_tensors 42 | 43 | if ctx.training: 44 | dist.all_reduce(grad_mean_out) 45 | dist.all_reduce(grad_cov_out) 46 | world_size = dist.get_world_size() 47 | else: 48 | world_size = 1 49 | 50 | grad_cov_out = (grad_cov_out + grad_cov_out.transpose(1, 2)) / 2 51 | grad_cov_in = 2 * torch.bmm(grad_cov_out, (in_data - mean_bn)) \ 52 | / (ctx.NHW*world_size) # g x c x (N x H x W) 53 | 54 | grad_mean_in = grad_mean_out / ctx.NHW / world_size 55 | inDiff = grad_mean_in + grad_cov_in 56 | return inDiff, None, None, None, None 57 | 58 | 59 | class SyncSwitchWhiten2d(Module): 60 | """Syncronized Switchable Whitening. 61 | 62 | Args: 63 | num_features (int): Number of channels. 64 | num_pergroup (int): Number of channels for each whitening group. 65 | sw_type (int): Switchable whitening type, from {2, 3, 4, 5}. 66 | sw_type = 2: BW + IW 67 | sw_type = 3: BW + IW + LN 68 | sw_type = 5: BW + IW + BN + IN + LN 69 | T (int): Number of iterations for iterative whitening. 70 | tie_weight (bool): Use the same importance weight for mean and 71 | covariance or not. 
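Both SwitchWhiten2d above and SyncSwitchWhiten2d below whiten features with the inverse square root of a mixed covariance, computed by Newton's iteration. A small standalone sketch of that iteration (illustrative only; the modules run it batched with torch.baddbmm):

import torch

c = 8
A = torch.randn(c, c)
cov = A @ A.t() / c + 1e-5 * torch.eye(c)   # a symmetric positive-definite "covariance"
P = torch.eye(c)
rTr = 1.0 / torch.trace(cov)                # trace-normalise so the iteration converges
cov_n = cov * rTr
for _ in range(5):                          # T = 5 iterations, as in the modules
    P = 1.5 * P - 0.5 * torch.matrix_power(P, 3) @ cov_n
wm = P * rTr.sqrt()                         # approximates cov^(-1/2)
print(torch.dist(wm @ cov @ wm, torch.eye(c)))  # approaches 0 as T increases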
72 | """ 73 | 74 | def __init__(self, 75 | num_features, 76 | num_pergroup=16, 77 | sw_type=2, 78 | T=5, 79 | tie_weight=False, 80 | eps=1e-5, 81 | momentum=0.99, 82 | affine=True): 83 | super(SyncSwitchWhiten2d, self).__init__() 84 | if sw_type not in [2, 3, 4, 5]: 85 | raise ValueError('sw_type should be in [2, 3, 4, 5], ' 86 | 'but got {}'.format(sw_type)) 87 | assert num_features % num_pergroup == 0 88 | self.num_features = num_features 89 | self.num_pergroup = num_pergroup 90 | self.num_groups = num_features // num_pergroup 91 | self.sw_type = sw_type 92 | self.T = T 93 | self.tie_weight = tie_weight 94 | self.eps = eps 95 | self.momentum = momentum 96 | self.tie_weight = tie_weight 97 | self.affine = affine 98 | num_components = sw_type 99 | 100 | self.sw_mean_weight = Parameter(torch.ones(num_components)) 101 | if not self.tie_weight: 102 | self.sw_var_weight = Parameter(torch.ones(num_components)) 103 | else: 104 | self.register_parameter('sw_var_weight', None) 105 | 106 | if self.affine: 107 | self.weight = Parameter(torch.ones(num_features)) 108 | self.bias = Parameter(torch.zeros(num_features)) 109 | else: 110 | self.register_parameter('weight', None) 111 | self.register_parameter('bias', None) 112 | 113 | self.register_buffer('running_mean', 114 | torch.zeros(self.num_groups, num_pergroup, 1)) 115 | self.register_buffer( 116 | 'running_cov', 117 | torch.eye(num_pergroup).unsqueeze(0).repeat(self.num_groups, 1, 1)) 118 | 119 | self.reset_parameters() 120 | 121 | def reset_parameters(self): 122 | self.running_mean.zero_() 123 | self.running_cov.zero_() 124 | nn.init.ones_(self.sw_mean_weight) 125 | if not self.tie_weight: 126 | nn.init.ones_(self.sw_var_weight) 127 | if self.affine: 128 | nn.init.ones_(self.weight) 129 | nn.init.zeros_(self.bias) 130 | 131 | def __repr__(self): 132 | return ('{name}({num_features}, num_pergroup={num_pergroup}, ' 133 | 'sw_type={sw_type}, T={T}, tie_weight={tie_weight}, ' 134 | 'eps={eps}, momentum={momentum}, affine={affine})'.format( 135 | name=self.__class__.__name__, **self.__dict__)) 136 | 137 | def forward(self, x): 138 | N, C, H, W = x.size() 139 | c, g = self.num_pergroup, self.num_groups 140 | 141 | in_data_t = x.transpose(0, 1).contiguous() 142 | # g x c x (N x H x W) 143 | in_data_t = in_data_t.view(g, c, -1) 144 | # calculate batch mean and covariance 145 | mean_bn, cov_bn = SyncMeanCov.apply(in_data_t, self.running_mean, 146 | self.running_cov, self.momentum, 147 | self.training) 148 | 149 | mean_bn = mean_bn.view(1, g, c, 1).expand(N, g, c, 1).contiguous() 150 | mean_bn = mean_bn.view(N * g, c, 1) 151 | cov_bn = cov_bn.view(1, g, c, c).expand(N, g, c, c).contiguous() 152 | cov_bn = cov_bn.view(N * g, c, c) 153 | 154 | # (N x g) x c x (H x W) 155 | in_data = x.view(N * g, c, -1) 156 | 157 | eye = in_data.data.new().resize_(c, c) 158 | eye = torch.nn.init.eye_(eye).view(1, c, c).expand(N * g, c, c) 159 | 160 | # calculate other statistics 161 | # (N x g) x c x 1 162 | mean_in = in_data.mean(-1, keepdim=True) 163 | x_in = in_data - mean_in 164 | # (N x g) x c x c 165 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W) 166 | if self.sw_type in [3, 5]: 167 | x = x.view(N, -1) 168 | mean_ln = x.mean(-1, keepdim=True).view(N, 1, 1, 1) 169 | mean_ln = mean_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1) 170 | var_ln = x.var(-1, keepdim=True).view(N, 1, 1, 1) 171 | var_ln = var_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1) 172 | var_ln = var_ln * eye 173 | if self.sw_type == 5: 174 | var_bn = 
torch.diag_embed(torch.diagonal(cov_bn, dim1=-2, dim2=-1)) 175 | var_in = torch.diag_embed(torch.diagonal(cov_in, dim1=-2, dim2=-1)) 176 | 177 | # calculate weighted average of mean and covariance 178 | softmax = nn.Softmax(0) 179 | mean_weight = softmax(self.sw_mean_weight) 180 | if not self.tie_weight: 181 | var_weight = softmax(self.sw_var_weight) 182 | else: 183 | var_weight = mean_weight 184 | 185 | # BW + IW 186 | if self.sw_type == 2: 187 | # (N x g) x c x 1 188 | mean = mean_weight[0] * mean_bn + mean_weight[1] * mean_in 189 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 190 | self.eps * eye 191 | # BW + IW + LN 192 | elif self.sw_type == 3: 193 | mean = mean_weight[0] * mean_bn + \ 194 | mean_weight[1] * mean_in + mean_weight[2] * mean_ln 195 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 196 | var_weight[2] * var_ln + self.eps * eye 197 | # BW + IW + BN + IN + LN 198 | elif self.sw_type == 5: 199 | mean = (mean_weight[0] + mean_weight[2]) * mean_bn + \ 200 | (mean_weight[1] + mean_weight[3]) * mean_in + \ 201 | mean_weight[4] * mean_ln 202 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \ 203 | var_weight[0] * var_bn + var_weight[1] * var_in + \ 204 | var_weight[4] * var_ln + self.eps * eye 205 | 206 | # perform whitening using Newton's iteration 207 | Ng, c, _ = cov.size() 208 | P = torch.eye(c).to(cov).expand(Ng, c, c) 209 | # reciprocal of trace of covariance 210 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_() 211 | cov_N = cov * rTr 212 | for k in range(self.T): 213 | P = torch.baddbmm(beta=1.5, input=P, alpha=-0.5, batch1=torch.matrix_power(P, 3), batch2=cov_N) 214 | # whiten matrix: the matrix inverse of covariance, i.e., cov^{-1/2} 215 | wm = P.mul_(rTr.sqrt()) 216 | 217 | x_hat = torch.bmm(wm, in_data - mean) 218 | x_hat = x_hat.view(N, C, H, W) 219 | if self.affine: 220 | x_hat = x_hat * self.weight.view(1, self.num_features, 1, 1) + \ 221 | self.bias.view(1, self.num_features, 1, 1) 222 | 223 | return x_hat 224 | -------------------------------------------------------------------------------- /network/wider_resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/mapillary/inplace_abn/ 4 | # 5 | # BSD 3-Clause License 6 | # 7 | # Copyright (c) 2017, mapillary 8 | # All rights reserved. 9 | # 10 | # Redistribution and use in source and binary forms, with or without 11 | # modification, are permitted provided that the following conditions are met: 12 | # 13 | # * Redistributions of source code must retain the above copyright notice, this 14 | # list of conditions and the following disclaimer. 15 | # 16 | # * Redistributions in binary form must reproduce the above copyright notice, 17 | # this list of conditions and the following disclaimer in the documentation 18 | # and/or other materials provided with the distribution. 19 | # 20 | # * Neither the name of the copyright holder nor the names of its 21 | # contributors may be used to endorse or promote products derived from 22 | # this software without specific prior written permission. 23 | # 24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 27 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 34 | """ 35 | import logging 36 | import sys 37 | from collections import OrderedDict 38 | from functools import partial 39 | import torch.nn as nn 40 | import torch 41 | import network.mynn as mynn 42 | 43 | def bnrelu(channels): 44 | """ 45 | Single Layer BN and Relui 46 | """ 47 | return nn.Sequential(mynn.Norm2d(channels), 48 | nn.ReLU(inplace=True)) 49 | 50 | class GlobalAvgPool2d(nn.Module): 51 | """ 52 | Global average pooling over the input's spatial dimensions 53 | """ 54 | 55 | def __init__(self): 56 | super(GlobalAvgPool2d, self).__init__() 57 | logging.info("Global Average Pooling Initialized") 58 | 59 | def forward(self, inputs): 60 | in_size = inputs.size() 61 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2) 62 | 63 | 64 | class IdentityResidualBlock(nn.Module): 65 | """ 66 | Identity Residual Block for WideResnet 67 | """ 68 | def __init__(self, 69 | in_channels, 70 | channels, 71 | stride=1, 72 | dilation=1, 73 | groups=1, 74 | norm_act=bnrelu, 75 | dropout=None, 76 | dist_bn=False 77 | ): 78 | """Configurable identity-mapping residual block 79 | 80 | Parameters 81 | ---------- 82 | in_channels : int 83 | Number of input channels. 84 | channels : list of int 85 | Number of channels in the internal feature maps. 86 | Can either have two or three elements: if three construct 87 | a residual block with two `3 x 3` convolutions, 88 | otherwise construct a bottleneck block with `1 x 1`, then 89 | `3 x 3` then `1 x 1` convolutions. 90 | stride : int 91 | Stride of the first `3 x 3` convolution 92 | dilation : int 93 | Dilation to apply to the `3 x 3` convolutions. 94 | groups : int 95 | Number of convolution groups. 96 | This is used to create ResNeXt-style blocks and is only compatible with 97 | bottleneck blocks. 98 | norm_act : callable 99 | Function to create normalization / activation Module. 100 | dropout: callable 101 | Function to create Dropout Module. 
102 | dist_bn: Boolean 103 | A variable to enable or disable use of distributed BN 104 | """ 105 | super(IdentityResidualBlock, self).__init__() 106 | self.dist_bn = dist_bn 107 | 108 | # Check if we are using distributed BN and use the nn from encoding.nn 109 | # library rather than using standard pytorch.nn 110 | 111 | 112 | # Check parameters for inconsistencies 113 | if len(channels) != 2 and len(channels) != 3: 114 | raise ValueError("channels must contain either two or three values") 115 | if len(channels) == 2 and groups != 1: 116 | raise ValueError("groups > 1 are only valid if len(channels) == 3") 117 | 118 | is_bottleneck = len(channels) == 3 119 | need_proj_conv = stride != 1 or in_channels != channels[-1] 120 | 121 | self.bn1 = norm_act(in_channels) 122 | if not is_bottleneck: 123 | layers = [ 124 | ("conv1", nn.Conv2d(in_channels, 125 | channels[0], 126 | 3, 127 | stride=stride, 128 | padding=dilation, 129 | bias=False, 130 | dilation=dilation)), 131 | ("bn2", norm_act(channels[0])), 132 | ("conv2", nn.Conv2d(channels[0], channels[1], 133 | 3, 134 | stride=1, 135 | padding=dilation, 136 | bias=False, 137 | dilation=dilation)) 138 | ] 139 | if dropout is not None: 140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:] 141 | else: 142 | layers = [ 143 | ("conv1", 144 | nn.Conv2d(in_channels, 145 | channels[0], 146 | 1, 147 | stride=stride, 148 | padding=0, 149 | bias=False)), 150 | ("bn2", norm_act(channels[0])), 151 | ("conv2", nn.Conv2d(channels[0], 152 | channels[1], 153 | 3, stride=1, 154 | padding=dilation, bias=False, 155 | groups=groups, 156 | dilation=dilation)), 157 | ("bn3", norm_act(channels[1])), 158 | ("conv3", nn.Conv2d(channels[1], channels[2], 159 | 1, stride=1, padding=0, bias=False)) 160 | ] 161 | if dropout is not None: 162 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:] 163 | self.convs = nn.Sequential(OrderedDict(layers)) 164 | 165 | if need_proj_conv: 166 | self.proj_conv = nn.Conv2d( 167 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False) 168 | 169 | def forward(self, x): 170 | """ 171 | This is the standard forward function for non-distributed batch norm 172 | """ 173 | if hasattr(self, "proj_conv"): 174 | bn1 = self.bn1(x) 175 | shortcut = self.proj_conv(bn1) 176 | else: 177 | shortcut = x.clone() 178 | bn1 = self.bn1(x) 179 | 180 | out = self.convs(bn1) 181 | out.add_(shortcut) 182 | return out 183 | 184 | 185 | 186 | 187 | class WiderResNet(nn.Module): 188 | """ 189 | WideResnet Global Module for Initialization 190 | """ 191 | def __init__(self, 192 | structure, 193 | norm_act=bnrelu, 194 | classes=0 195 | ): 196 | """Wider ResNet with pre-activation (identity mapping) blocks 197 | 198 | Parameters 199 | ---------- 200 | structure : list of int 201 | Number of residual blocks in each of the six modules of the network. 202 | norm_act : callable 203 | Function to create normalization / activation Module. 204 | classes : int 205 | If not `0` also include global average pooling and \ 206 | a fully-connected layer with `classes` outputs at the end 207 | of the network. 
208 | """ 209 | super(WiderResNet, self).__init__() 210 | self.structure = structure 211 | 212 | if len(structure) != 6: 213 | raise ValueError("Expected a structure with six values") 214 | 215 | # Initial layers 216 | self.mod1 = nn.Sequential(OrderedDict([ 217 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 218 | ])) 219 | 220 | # Groups of residual blocks 221 | in_channels = 64 222 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), 223 | (512, 1024, 2048), (1024, 2048, 4096)] 224 | for mod_id, num in enumerate(structure): 225 | # Create blocks for module 226 | blocks = [] 227 | for block_id in range(num): 228 | blocks.append(( 229 | "block%d" % (block_id + 1), 230 | IdentityResidualBlock(in_channels, channels[mod_id], 231 | norm_act=norm_act) 232 | )) 233 | 234 | # Update channels and p_keep 235 | in_channels = channels[mod_id][-1] 236 | 237 | # Create module 238 | if mod_id <= 4: 239 | self.add_module("pool%d" % 240 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 241 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 242 | 243 | # Pooling and predictor 244 | self.bn_out = norm_act(in_channels) 245 | if classes != 0: 246 | self.classifier = nn.Sequential(OrderedDict([ 247 | ("avg_pool", GlobalAvgPool2d()), 248 | ("fc", nn.Linear(in_channels, classes)) 249 | ])) 250 | 251 | def forward(self, img): 252 | out = self.mod1(img) 253 | out = self.mod2(self.pool2(out)) 254 | out = self.mod3(self.pool3(out)) 255 | out = self.mod4(self.pool4(out)) 256 | out = self.mod5(self.pool5(out)) 257 | out = self.mod6(self.pool6(out)) 258 | out = self.mod7(out) 259 | out = self.bn_out(out) 260 | 261 | if hasattr(self, "classifier"): 262 | out = self.classifier(out) 263 | 264 | return out 265 | 266 | 267 | class WiderResNetA2(nn.Module): 268 | """ 269 | Wider ResNet with pre-activation (identity mapping) blocks 270 | 271 | This variant uses down-sampling by max-pooling in the first two blocks and 272 | by strided convolution in the others. 273 | 274 | Parameters 275 | ---------- 276 | structure : list of int 277 | Number of residual blocks in each of the six modules of the network. 278 | norm_act : callable 279 | Function to create normalization / activation Module. 280 | classes : int 281 | If not `0` also include global average pooling and a fully-connected layer 282 | with `classes` outputs at the end 283 | of the network. 284 | dilation : bool 285 | If `True` apply dilation to the last three modules and change the 286 | down-sampling factor from 32 to 8. 
287 | """ 288 | def __init__(self, 289 | structure, 290 | norm_act=bnrelu, 291 | classes=0, 292 | dilation=False, 293 | dist_bn=False 294 | ): 295 | super(WiderResNetA2, self).__init__() 296 | self.dist_bn = dist_bn 297 | 298 | # If using distributed batch norm, use the encoding.nn as oppose to torch.nn 299 | 300 | 301 | nn.Dropout = nn.Dropout2d 302 | norm_act = bnrelu 303 | self.structure = structure 304 | self.dilation = dilation 305 | 306 | if len(structure) != 6: 307 | raise ValueError("Expected a structure with six values") 308 | 309 | # Initial layers 310 | self.mod1 = torch.nn.Sequential(OrderedDict([ 311 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)) 312 | ])) 313 | 314 | # Groups of residual blocks 315 | in_channels = 64 316 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048), 317 | (1024, 2048, 4096)] 318 | for mod_id, num in enumerate(structure): 319 | # Create blocks for module 320 | blocks = [] 321 | for block_id in range(num): 322 | if not dilation: 323 | dil = 1 324 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1 325 | else: 326 | if mod_id == 3: 327 | dil = 2 328 | elif mod_id > 3: 329 | dil = 4 330 | else: 331 | dil = 1 332 | stride = 2 if block_id == 0 and mod_id == 2 else 1 333 | 334 | if mod_id == 4: 335 | drop = partial(nn.Dropout, p=0.3) 336 | elif mod_id == 5: 337 | drop = partial(nn.Dropout, p=0.5) 338 | else: 339 | drop = None 340 | 341 | blocks.append(( 342 | "block%d" % (block_id + 1), 343 | IdentityResidualBlock(in_channels, 344 | channels[mod_id], norm_act=norm_act, 345 | stride=stride, dilation=dil, 346 | dropout=drop, dist_bn=self.dist_bn) 347 | )) 348 | 349 | # Update channels and p_keep 350 | in_channels = channels[mod_id][-1] 351 | 352 | # Create module 353 | if mod_id < 2: 354 | self.add_module("pool%d" % 355 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1)) 356 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks))) 357 | 358 | # Pooling and predictor 359 | self.bn_out = norm_act(in_channels) 360 | if classes != 0: 361 | self.classifier = nn.Sequential(OrderedDict([ 362 | ("avg_pool", GlobalAvgPool2d()), 363 | ("fc", nn.Linear(in_channels, classes)) 364 | ])) 365 | 366 | def forward(self, img): 367 | out = self.mod1(img) 368 | out = self.mod2(self.pool2(out)) 369 | out = self.mod3(self.pool3(out)) 370 | out = self.mod4(out) 371 | out = self.mod5(out) 372 | out = self.mod6(out) 373 | out = self.mod7(out) 374 | out = self.bn_out(out) 375 | 376 | if hasattr(self, "classifier"): 377 | return self.classifier(out) 378 | return out 379 | 380 | 381 | _NETS = { 382 | "16": {"structure": [1, 1, 1, 1, 1, 1]}, 383 | "20": {"structure": [1, 1, 1, 3, 1, 1]}, 384 | "38": {"structure": [3, 3, 6, 3, 1, 1]}, 385 | } 386 | 387 | __all__ = [] 388 | for name, params in _NETS.items(): 389 | net_name = "wider_resnet" + name 390 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params)) 391 | __all__.append(net_name) 392 | for name, params in _NETS.items(): 393 | net_name = "wider_resnet" + name + "_a2" 394 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params)) 395 | __all__.append(net_name) 396 | -------------------------------------------------------------------------------- /optimizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pytorch Optimizer and Scheduler Related Task 3 | """ 4 | import math 5 | import logging 6 | import torch 7 | from torch import optim 8 | from config import cfg 9 | 10 | 11 | def 
get_optimizer(args, net): 12 | """ 13 | Decide Optimizer (Adam or SGD) 14 | """ 15 | base_params = [] 16 | 17 | for name, param in net.named_parameters(): 18 | base_params.append(param) 19 | 20 | if args.sgd: 21 | optimizer = optim.SGD(base_params, 22 | lr=args.lr, 23 | weight_decay=5e-4, #args.weight_decay, 24 | momentum=args.momentum, 25 | nesterov=False) 26 | else: 27 | raise ValueError('Not a valid optimizer') 28 | 29 | if args.lr_schedule == 'scl-poly': 30 | if cfg.REDUCE_BORDER_ITER == -1: 31 | raise ValueError('ERROR Cannot Do Scale Poly') 32 | 33 | rescale_thresh = cfg.REDUCE_BORDER_ITER 34 | scale_value = args.rescale 35 | lambda1 = lambda iteration: \ 36 | math.pow(1 - iteration / args.max_iter, 37 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow( 38 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh), 39 | args.repoly) 40 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 41 | elif args.lr_schedule == 'poly': 42 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp) 43 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 44 | else: 45 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule)) 46 | 47 | return optimizer, scheduler 48 | 49 | 50 | def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False): 51 | """ 52 | Load weights from snapshot file 53 | """ 54 | logging.info("Loading weights from model %s", snapshot_file) 55 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file, 56 | restore_optimizer_bool) 57 | return epoch, mean_iu 58 | 59 | 60 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool): 61 | """ 62 | Restore weights and optimizer (if needed ) for resuming job. 63 | """ 64 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu')) 65 | logging.info("Checkpoint Load Compelete") 66 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool: 67 | optimizer.load_state_dict(checkpoint['optimizer']) 68 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool: 69 | scheduler.load_state_dict(checkpoint['scheduler']) 70 | 71 | if 'state_dict' in checkpoint: 72 | net = forgiving_state_restore(net, checkpoint['state_dict']) 73 | else: 74 | net = forgiving_state_restore(net, checkpoint) 75 | 76 | return net, optimizer, scheduler, checkpoint['epoch'], checkpoint['mean_iu'] 77 | 78 | 79 | def forgiving_state_restore(net, loaded_dict): 80 | """ 81 | Handle partial loading when some tensors don't match up in size. 82 | Because we want to use models that were trained off a different 83 | number of classes. 84 | """ 85 | net_state_dict = net.state_dict() 86 | new_loaded_dict = {} 87 | for k in net_state_dict: 88 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 89 | new_loaded_dict[k] = loaded_dict[k] 90 | else: 91 | print("Skipped loading parameter", k) 92 | # logging.info("Skipped loading parameter %s", k) 93 | net_state_dict.update(new_loaded_dict) 94 | net.load_state_dict(net_state_dict) 95 | return net 96 | 97 | def forgiving_state_copy(target_net, source_net): 98 | """ 99 | Handle partial loading when some tensors don't match up in size. 100 | Because we want to use models that were trained off a different 101 | number of classes. 
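As an aside (not part of optimizer.py), the 'poly' LambdaLR schedule defined earlier in this file multiplies the base learning rate by (1 - iteration / max_iter) ** poly_exp. A standalone sketch with the values used by the provided GTAV training script (lr 0.01, max_iter 40000, poly_exp 0.9):

import math

def poly_lr(iteration, base_lr=0.01, max_iter=40000, poly_exp=0.9):
    # same formula as the lr_lambda passed to optim.lr_scheduler.LambdaLR above
    return base_lr * math.pow(1 - iteration / max_iter, poly_exp)

for it in (0, 10000, 20000, 39999):
    print(it, round(poly_lr(it), 6))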
102 | """ 103 | net_state_dict = target_net.state_dict() 104 | loaded_dict = source_net.state_dict() 105 | new_loaded_dict = {} 106 | for k in net_state_dict: 107 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size(): 108 | new_loaded_dict[k] = loaded_dict[k] 109 | print("Matched", k) 110 | else: 111 | print("Skipped loading parameter ", k) 112 | # logging.info("Skipped loading parameter %s", k) 113 | net_state_dict.update(new_loaded_dict) 114 | target_net.load_state_dict(net_state_dict) 115 | return target_net 116 | -------------------------------------------------------------------------------- /scripts/blindnet_infer_r50os16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | echo "Saving Results :" ${2} 4 | 5 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=29401 infer.py \ 6 | --val_dataset cityscapes mapillary bdd-100k \ 7 | --arch network.deepv3.DeepR50V3PlusD \ 8 | --wt_layer 0 0 1 1 1 0 0 \ 9 | --mod blindnet \ 10 | --results ${2} \ 11 | --date 0101 \ 12 | --exp blindnet_r50os16_gtav \ 13 | --snapshot ${1} 14 | -------------------------------------------------------------------------------- /scripts/blindnet_train_r50os16_gtav.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | export NCCL_DEBUG=INFO 3 | 4 | python -m torch.distributed.launch --nproc_per_node=2 train.py \ 5 | --dataset gtav_jitter \ 6 | --covstat_val_dataset gtav \ 7 | --val_dataset cityscapes bdd100k mapillary \ 8 | --arch network.deepv3.DeepR50V3PlusD \ 9 | --city_mode 'train' \ 10 | --lr_schedule poly \ 11 | --lr 0.01 \ 12 | --poly_exp 0.9 \ 13 | --max_cu_epoch 10000 \ 14 | --class_uniform_pct 0.5 \ 15 | --class_uniform_tile 1024 \ 16 | --crop_size 768 \ 17 | --scale_min 0.5 \ 18 | --scale_max 2.0 \ 19 | --rrotate 0 \ 20 | --max_iter 40000 \ 21 | --bs_mult 4 \ 22 | --gblur \ 23 | --color_aug 0.5 \ 24 | --relax_denom 0.0 \ 25 | --wt_layer 0 0 1 1 1 0 0 \ 26 | --use_ca \ 27 | --use_cwcl \ 28 | --nce_T 0.07 \ 29 | --contrast_max_classes 15 \ 30 | --contrast_max_view 50 \ 31 | --jit_only \ 32 | --use_sdcl \ 33 | --w1 0.2 \ 34 | --w2 0.2 \ 35 | --w3 0.3 \ 36 | --w4 0.3 \ 37 | --num_patch 20 \ 38 | --date 0326 \ 39 | --exp BlindNet_r50os16_gtav \ 40 | --ckpt ./logs/ \ 41 | --tb_path ./logs/ 42 | -------------------------------------------------------------------------------- /scripts/blindnet_valid_r50os16_gtav.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | echo "Running inference on" ${1} 3 | 4 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=29400 valid.py \ 5 | --val_dataset cityscapes bdd100k mapillary gtav \ 6 | --arch network.deepv3.DeepR50V3PlusD \ 7 | --wt_layer 0 0 1 1 1 0 0 \ 8 | --date 0101 \ 9 | --exp r50os16_gtav_blindnet \ 10 | --snapshot ${1} 11 | -------------------------------------------------------------------------------- /transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__init__.py -------------------------------------------------------------------------------- /transforms/__pycache__/joint_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/joint_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /transforms/__pycache__/joint_transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/joint_transforms.cpython-37.pyc -------------------------------------------------------------------------------- /transforms/__pycache__/transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/transforms.cpython-36.pyc -------------------------------------------------------------------------------- /transforms/__pycache__/transforms.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/transforms.cpython-37.pyc -------------------------------------------------------------------------------- /transforms/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code borrowded from: 3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py 4 | # 5 | # 6 | # MIT License 7 | # 8 | # Copyright (c) 2017 ZijunDeng 9 | # 10 | # Permission is hereby granted, free of charge, to any person obtaining a copy 11 | # of this software and associated documentation files (the "Software"), to deal 12 | # in the Software without restriction, including without limitation the rights 13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 14 | # copies of the Software, and to permit persons to whom the Software is 15 | # furnished to do so, subject to the following conditions: 16 | # 17 | # The above copyright notice and this permission notice shall be included in all 18 | # copies or substantial portions of the Software. 19 | # 20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 26 | # SOFTWARE. 
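An aside, not part of transforms.py: RelaxedBoundaryLossToTensor below relies on a small NumPy one-hot trick (new_one_hot_converter). Traced standalone, it looks like this:

import numpy as np

a = np.array([[0, 1], [2, 1]])                 # a tiny 2x2 label map, num_classes = 3
ncols = 3 + 1                                  # one extra column for the ignore id
out = np.zeros((a.size, ncols), dtype=np.uint8)
out[np.arange(a.size), a.ravel()] = 1          # scatter a 1 into each pixel's class column
out.shape = a.shape + (ncols,)                 # back to (H, W, num_classes + 1)
print(out[0, 1])                               # [0 1 0 0] -> pixel (0, 1) is class 1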
27 | 28 | """ 29 | 30 | """ 31 | Standard Transform 32 | """ 33 | 34 | import random 35 | import numpy as np 36 | from skimage.filters import gaussian 37 | from skimage.restoration import denoise_bilateral 38 | import torch 39 | from PIL import Image, ImageEnhance 40 | import torchvision.transforms as torch_tr 41 | from config import cfg 42 | from scipy.ndimage.interpolation import shift 43 | 44 | from skimage.segmentation import find_boundaries 45 | from skimage.util import random_noise 46 | 47 | try: 48 | import accimage 49 | except ImportError: 50 | accimage = None 51 | 52 | 53 | class RandomVerticalFlip(object): 54 | def __call__(self, img): 55 | if random.random() < 0.5: 56 | return img.transpose(Image.FLIP_TOP_BOTTOM) 57 | return img 58 | 59 | 60 | class DeNormalize(object): 61 | def __init__(self, mean, std): 62 | self.mean = mean 63 | self.std = std 64 | 65 | def __call__(self, tensor): 66 | for t, m, s in zip(tensor, self.mean, self.std): 67 | t.mul_(s).add_(m) 68 | return tensor 69 | 70 | 71 | class MaskToTensor(object): 72 | def __call__(self, img): 73 | return torch.from_numpy(np.array(img, dtype=np.int32)).long() 74 | 75 | class RelaxedBoundaryLossToTensor(object): 76 | """ 77 | Boundary Relaxation 78 | """ 79 | def __init__(self,ignore_id, num_classes): 80 | self.ignore_id=ignore_id 81 | self.num_classes= num_classes 82 | 83 | 84 | def new_one_hot_converter(self,a): 85 | ncols = self.num_classes+1 86 | out = np.zeros( (a.size,ncols), dtype=np.uint8) 87 | out[np.arange(a.size),a.ravel()] = 1 88 | out.shape = a.shape + (ncols,) 89 | return out 90 | 91 | def __call__(self,img): 92 | 93 | img_arr = np.array(img) 94 | img_arr[img_arr==self.ignore_id]=self.num_classes 95 | 96 | if cfg.STRICTBORDERCLASS != None: 97 | one_hot_orig = self.new_one_hot_converter(img_arr) 98 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1])) 99 | for cls in cfg.STRICTBORDERCLASS: 100 | mask = np.logical_or(mask,(img_arr == cls)) 101 | one_hot = 0 102 | 103 | border = cfg.BORDER_WINDOW 104 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 105 | border = border // 2 106 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8) 107 | 108 | for i in range(-border,border+1): 109 | for j in range(-border, border+1): 110 | shifted= shift(img_arr,(i,j), cval=self.num_classes) 111 | one_hot += self.new_one_hot_converter(shifted) 112 | 113 | one_hot[one_hot>1] = 1 114 | 115 | if cfg.STRICTBORDERCLASS != None: 116 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot) 117 | 118 | one_hot = np.moveaxis(one_hot,-1,0) 119 | 120 | 121 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER): 122 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot) 123 | # print(one_hot.shape) 124 | return torch.from_numpy(one_hot).byte() 125 | 126 | class ResizeHeight(object): 127 | def __init__(self, size, interpolation=Image.BILINEAR): 128 | self.target_h = size 129 | self.interpolation = interpolation 130 | 131 | def __call__(self, img): 132 | w, h = img.size 133 | target_w = int(w / h * self.target_h) 134 | return img.resize((target_w, self.target_h), self.interpolation) 135 | 136 | 137 | class FreeScale(object): 138 | def __init__(self, size, interpolation=Image.BILINEAR): 139 | self.size = tuple(reversed(size)) # size: (h, w) 140 | self.interpolation = interpolation 141 | 142 | def __call__(self, img): 143 | return img.resize(self.size, self.interpolation) 144 | 145 | 146 | class FlipChannels(object): 147 | """ 148 | Flip around the x-axis 149 | 
""" 150 | def __call__(self, img): 151 | img = np.array(img)[:, :, ::-1] 152 | return Image.fromarray(img.astype(np.uint8)) 153 | 154 | 155 | class RandomGaussianBlur(object): 156 | """ 157 | Apply Gaussian Blur 158 | """ 159 | def __call__(self, img): 160 | sigma = 0.15 + random.random() * 1.15 161 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True) 162 | blurred_img *= 255 163 | return Image.fromarray(blurred_img.astype(np.uint8)) 164 | 165 | 166 | class RandomGaussianNoise(object): 167 | def __call__(self, img): 168 | noised_img = random_noise(np.array(img), mode='gaussian') 169 | noised_img *= 255 170 | return Image.fromarray(noised_img.astype(np.uint8)) 171 | 172 | 173 | class RandomBilateralBlur(object): 174 | """ 175 | Apply Bilateral Filtering 176 | 177 | """ 178 | def __call__(self, img): 179 | sigma = random.uniform(0.05,0.75) 180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True) 181 | blurred_img *= 255 182 | return Image.fromarray(blurred_img.astype(np.uint8)) 183 | 184 | def _is_pil_image(img): 185 | if accimage is not None: 186 | return isinstance(img, (Image.Image, accimage.Image)) 187 | else: 188 | return isinstance(img, Image.Image) 189 | 190 | 191 | def adjust_brightness(img, brightness_factor): 192 | """Adjust brightness of an Image. 193 | 194 | Args: 195 | img (PIL Image): PIL Image to be adjusted. 196 | brightness_factor (float): How much to adjust the brightness. Can be 197 | any non negative number. 0 gives a black image, 1 gives the 198 | original image while 2 increases the brightness by a factor of 2. 199 | 200 | Returns: 201 | PIL Image: Brightness adjusted image. 202 | """ 203 | if not _is_pil_image(img): 204 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 205 | 206 | enhancer = ImageEnhance.Brightness(img) 207 | img = enhancer.enhance(brightness_factor) 208 | return img 209 | 210 | 211 | def adjust_contrast(img, contrast_factor): 212 | """Adjust contrast of an Image. 213 | 214 | Args: 215 | img (PIL Image): PIL Image to be adjusted. 216 | contrast_factor (float): How much to adjust the contrast. Can be any 217 | non negative number. 0 gives a solid gray image, 1 gives the 218 | original image while 2 increases the contrast by a factor of 2. 219 | 220 | Returns: 221 | PIL Image: Contrast adjusted image. 222 | """ 223 | if not _is_pil_image(img): 224 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 225 | 226 | enhancer = ImageEnhance.Contrast(img) 227 | img = enhancer.enhance(contrast_factor) 228 | return img 229 | 230 | 231 | def adjust_saturation(img, saturation_factor): 232 | """Adjust color saturation of an image. 233 | 234 | Args: 235 | img (PIL Image): PIL Image to be adjusted. 236 | saturation_factor (float): How much to adjust the saturation. 0 will 237 | give a black and white image, 1 will give the original image while 238 | 2 will enhance the saturation by a factor of 2. 239 | 240 | Returns: 241 | PIL Image: Saturation adjusted image. 242 | """ 243 | if not _is_pil_image(img): 244 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 245 | 246 | enhancer = ImageEnhance.Color(img) 247 | img = enhancer.enhance(saturation_factor) 248 | return img 249 | 250 | 251 | def adjust_hue(img, hue_factor): 252 | """Adjust hue of an image. 253 | 254 | The image hue is adjusted by converting the image to HSV and 255 | cyclically shifting the intensities in the hue channel (H). 256 | The image is then converted back to original image mode. 
257 | 258 | `hue_factor` is the amount of shift in H channel and must be in the 259 | interval `[-0.5, 0.5]`. 260 | 261 | See https://en.wikipedia.org/wiki/Hue for more details on Hue. 262 | 263 | Args: 264 | img (PIL Image): PIL Image to be adjusted. 265 | hue_factor (float): How much to shift the hue channel. Should be in 266 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in 267 | HSV space in positive and negative direction respectively. 268 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image 269 | with complementary colors while 0 gives the original image. 270 | 271 | Returns: 272 | PIL Image: Hue adjusted image. 273 | """ 274 | if not(-0.5 <= hue_factor <= 0.5): 275 | raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor)) 276 | 277 | if not _is_pil_image(img): 278 | raise TypeError('img should be PIL Image. Got {}'.format(type(img))) 279 | input_mode = img.mode 280 | if input_mode in {'L', '1', 'I', 'F'}: 281 | return img 282 | 283 | h, s, v = img.convert('HSV').split() 284 | 285 | np_h = np.array(h, dtype=np.uint8) 286 | # uint8 addition take cares of rotation across boundaries 287 | with np.errstate(over='ignore'): 288 | np_h += np.uint8(hue_factor * 255) 289 | h = Image.fromarray(np_h, 'L') 290 | img = Image.merge('HSV', (h, s, v)).convert(input_mode) 291 | return img 292 | 293 | 294 | class ColorJitter(object): 295 | """Randomly change the brightness, contrast and saturation of an image. 296 | 297 | Args: 298 | brightness (float): How much to jitter brightness. brightness_factor 299 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]. 300 | contrast (float): How much to jitter contrast. contrast_factor 301 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]. 302 | saturation (float): How much to jitter saturation. saturation_factor 303 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]. 304 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from 305 | [-hue, hue]. Should be >=0 and <= 0.5. 306 | """ 307 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): 308 | self.brightness = brightness 309 | self.contrast = contrast 310 | self.saturation = saturation 311 | self.hue = hue 312 | 313 | @staticmethod 314 | def get_params(brightness, contrast, saturation, hue): 315 | """Get a randomized transform to be applied on image. 316 | 317 | Arguments are same as that of __init__. 318 | 319 | Returns: 320 | Transform which randomly adjusts brightness, contrast and 321 | saturation in a random order. 
322 | """ 323 | transforms = [] 324 | if brightness > 0: 325 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness) 326 | transforms.append( 327 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor))) 328 | 329 | if contrast > 0: 330 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast) 331 | transforms.append( 332 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor))) 333 | 334 | if saturation > 0: 335 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation) 336 | transforms.append( 337 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor))) 338 | 339 | if hue > 0: 340 | hue_factor = np.random.uniform(-hue, hue) 341 | transforms.append( 342 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor))) 343 | 344 | np.random.shuffle(transforms) 345 | transform = torch_tr.Compose(transforms) 346 | 347 | return transform 348 | 349 | def __call__(self, img): 350 | """ 351 | Args: 352 | img (PIL Image): Input image. 353 | 354 | Returns: 355 | PIL Image: Color jittered image. 356 | """ 357 | transform = self.get_params(self.brightness, self.contrast, 358 | self.saturation, self.hue) 359 | return transform(img) 360 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/utils/__init__.py -------------------------------------------------------------------------------- /utils/attr_dict.py: -------------------------------------------------------------------------------- 1 | """ 2 | # Code adapted from: 3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py 4 | 5 | Source License 6 | # Copyright (c) 2017-present, Facebook, Inc. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | ############################################################################## 20 | # 21 | # Based on: 22 | # -------------------------------------------------------- 23 | # Fast R-CNN 24 | # Copyright (c) 2015 Microsoft 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Ross Girshick 27 | # -------------------------------------------------------- 28 | """ 29 | 30 | class AttrDict(dict): 31 | 32 | IMMUTABLE = '__immutable__' 33 | 34 | def __init__(self, *args, **kwargs): 35 | super(AttrDict, self).__init__(*args, **kwargs) 36 | self.__dict__[AttrDict.IMMUTABLE] = False 37 | 38 | def __getattr__(self, name): 39 | if name in self.__dict__: 40 | return self.__dict__[name] 41 | elif name in self: 42 | return self[name] 43 | else: 44 | raise AttributeError(name) 45 | 46 | def __setattr__(self, name, value): 47 | if not self.__dict__[AttrDict.IMMUTABLE]: 48 | if name in self.__dict__: 49 | self.__dict__[name] = value 50 | else: 51 | self[name] = value 52 | else: 53 | raise AttributeError( 54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'. 55 | format(name, value) 56 | ) 57 | 58 | def immutable(self, is_immutable): 59 | """Set immutability to is_immutable and recursively apply the setting 60 | to all nested AttrDicts. 61 | """ 62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable 63 | # Recursively set immutable state 64 | for v in self.__dict__.values(): 65 | if isinstance(v, AttrDict): 66 | v.immutable(is_immutable) 67 | for v in self.values(): 68 | if isinstance(v, AttrDict): 69 | v.immutable(is_immutable) 70 | 71 | def is_immutable(self): 72 | return self.__dict__[AttrDict.IMMUTABLE] 73 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Miscellanous Functions 3 | """ 4 | 5 | import sys 6 | import re 7 | import os 8 | import shutil 9 | import torch 10 | from datetime import datetime 11 | import logging 12 | from subprocess import call 13 | import shlex 14 | from tensorboardX import SummaryWriter 15 | import datasets 16 | import numpy as np 17 | import torchvision.transforms as standard_transforms 18 | import torchvision.utils as vutils 19 | from config import cfg 20 | import random 21 | 22 | 23 | # Create unique output dir name based on non-default command line args 24 | def make_exp_name(args, parser): 25 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:]) 26 | dict_args = vars(args) 27 | 28 | # sort so that we get a consistent directory name 29 | argnames = sorted(dict_args) 30 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch', 31 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes', 32 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult', 33 | 'class_uniform_pct', 'class_uniform_tile'] 34 | # build experiment name with non-default args 35 | for argname in argnames: 36 | if dict_args[argname] != parser.get_default(argname): 37 | if argname in ignorelist: 38 | continue 39 | if argname == 'snapshot': 40 | arg_str = 'PT' 41 | argname = '' 42 | elif argname == 'nosave': 43 | arg_str = '' 44 | argname='' 45 | elif argname == 'freeze_trunk': 46 | argname = '' 47 | arg_str = 'ft' 48 | elif argname == 'syncbn': 49 | argname = '' 50 | arg_str = 'sbn' 51 | elif argname == 'jointwtborder': 52 | argname = '' 53 | arg_str = 'rlx_loss' 54 | elif 
isinstance(dict_args[argname], bool): 55 | arg_str = 'T' if dict_args[argname] else 'F' 56 | else: 57 | arg_str = str(dict_args[argname])[:7] 58 | if argname is not '': 59 | exp_name += '_{}_{}'.format(str(argname), arg_str) 60 | else: 61 | exp_name += '_{}'.format(arg_str) 62 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name) 63 | return exp_name 64 | 65 | def fast_hist(label_pred, label_true, num_classes): 66 | mask = (label_true >= 0) & (label_true < num_classes) 67 | hist = np.bincount( 68 | num_classes * label_true[mask].astype(int) + 69 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes) 70 | return hist 71 | 72 | def per_class_iu(hist): 73 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 74 | 75 | def save_log(prefix, output_dir, date_str, rank=0): 76 | fmt = '%(asctime)s.%(msecs)03d %(message)s' 77 | date_fmt = '%m-%d %H:%M:%S' 78 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log') 79 | print("Logging :", filename) 80 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt, 81 | filename=filename, filemode='w') 82 | console = logging.StreamHandler() 83 | console.setLevel(logging.INFO) 84 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt) 85 | console.setFormatter(formatter) 86 | if rank == 0: 87 | logging.getLogger('').addHandler(console) 88 | else: 89 | fh = logging.FileHandler(filename) 90 | logging.getLogger('').addHandler(fh) 91 | 92 | 93 | 94 | def prep_experiment(args, parser): 95 | """ 96 | Make output directories, setup logging, Tensorboard, snapshot code. 97 | """ 98 | ckpt_path = args.ckpt 99 | tb_path = args.tb_path 100 | exp_name = make_exp_name(args, parser) 101 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 102 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H'))) 103 | args.ngpu = torch.cuda.device_count() 104 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S')) 105 | args.best_record = {} 106 | # args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 107 | # 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 108 | args.last_record = {} 109 | if args.local_rank == 0: 110 | os.makedirs(args.exp_path, exist_ok=True) 111 | os.makedirs(args.tb_exp_path, exist_ok=True) 112 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank) 113 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write( 114 | str(args) + '\n\n') 115 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag) 116 | return writer 117 | return None 118 | 119 | def evaluate_eval_for_inference(hist, dataset=None): 120 | """ 121 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 122 | large dataset) Only applies to eval/eval.py 123 | """ 124 | # axis 0: gt, axis 1: prediction 125 | acc = np.diag(hist).sum() / hist.sum() 126 | acc_cls = np.diag(hist) / hist.sum(axis=1) 127 | acc_cls = np.nanmean(acc_cls) 128 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 129 | 130 | print_evaluate_results(hist, iu, dataset=dataset) 131 | freq = hist.sum(axis=1) / hist.sum() 132 | mean_iu = np.nanmean(iu) 133 | logging.info('mean {}'.format(mean_iu)) 134 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 135 | return acc, acc_cls, mean_iu, fwavacc 136 | 137 | 138 | 139 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, 
writer, epoch=0, dataset_name=None, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None, save_pth=True): 140 | """ 141 | Modified IOU mechanism for on-the-fly IOU calculations ( prevents memory overflow for 142 | large dataset) Only applies to eval/eval.py 143 | """ 144 | if val_loss is not None and hist is not None: 145 | # axis 0: gt, axis 1: prediction 146 | acc = np.diag(hist).sum() / hist.sum() 147 | acc_cls = np.diag(hist) / hist.sum(axis=1) 148 | acc_cls = np.nanmean(acc_cls) 149 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist)) 150 | 151 | print_evaluate_results(hist, iu, dataset_name=dataset_name, dataset=dataset) 152 | freq = hist.sum(axis=1) / hist.sum() 153 | mean_iu = np.nanmean(iu) 154 | logging.info('mean {}'.format(mean_iu)) 155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum() 156 | else: 157 | mean_iu = 0 158 | 159 | if dataset_name not in args.last_record.keys(): 160 | args.last_record[dataset_name] = {} 161 | 162 | if save_pth: 163 | # update latest snapshot 164 | if 'mean_iu' in args.last_record[dataset_name]: 165 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format( 166 | dataset_name, args.last_record[dataset_name]['epoch'], 167 | args.last_record[dataset_name]['mean_iu']) 168 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 169 | try: 170 | os.remove(last_snapshot) 171 | except OSError: 172 | pass 173 | 174 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(dataset_name, epoch, mean_iu) 175 | last_snapshot = os.path.join(args.exp_path, last_snapshot) 176 | args.last_record[dataset_name]['mean_iu'] = mean_iu 177 | args.last_record[dataset_name]['epoch'] = epoch 178 | 179 | torch.cuda.synchronize() 180 | 181 | if optimizer_at is not None: 182 | torch.save({ 183 | 'state_dict': net.state_dict(), 184 | 'optimizer': optimizer.state_dict(), 185 | 'optimizer_at': optimizer_at.state_dict(), 186 | 'scheduler': scheduler.state_dict(), 187 | 'scheduler_at': scheduler_at.state_dict(), 188 | 'epoch': epoch, 189 | 'mean_iu': mean_iu, 190 | 'command': ' '.join(sys.argv[1:]) 191 | }, last_snapshot) 192 | else: 193 | torch.save({ 194 | 'state_dict': net.state_dict(), 195 | 'optimizer': optimizer.state_dict(), 196 | 'scheduler': scheduler.state_dict(), 197 | 'epoch': epoch, 198 | 'mean_iu': mean_iu, 199 | 'command': ' '.join(sys.argv[1:]) 200 | }, last_snapshot) 201 | 202 | if val_loss is not None and hist is not None: 203 | if dataset_name not in args.best_record.keys(): 204 | args.best_record[dataset_name] = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 205 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0} 206 | # update best snapshot 207 | if mean_iu > args.best_record[dataset_name]['mean_iu'] : 208 | # remove old best snapshot 209 | if args.best_record[dataset_name]['epoch'] != -1: 210 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format( 211 | dataset_name, args.best_record[dataset_name]['epoch'], 212 | args.best_record[dataset_name]['mean_iu']) 213 | 214 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 215 | assert os.path.exists(best_snapshot), \ 216 | 'cant find old snapshot {}'.format(best_snapshot) 217 | os.remove(best_snapshot) 218 | 219 | # save new best 220 | args.best_record[dataset_name]['val_loss'] = val_loss.avg 221 | args.best_record[dataset_name]['epoch'] = epoch 222 | args.best_record[dataset_name]['acc'] = acc 223 | args.best_record[dataset_name]['acc_cls'] = acc_cls 224 | args.best_record[dataset_name]['mean_iu'] = mean_iu 225 | 
args.best_record[dataset_name]['fwavacc'] = fwavacc 226 | 227 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format( 228 | dataset_name, args.best_record[dataset_name]['epoch'], 229 | args.best_record[dataset_name]['mean_iu']) 230 | best_snapshot = os.path.join(args.exp_path, best_snapshot) 231 | shutil.copyfile(last_snapshot, best_snapshot) 232 | else: 233 | logging.info("Saved file to {}".format(last_snapshot)) 234 | 235 | if val_loss is not None and hist is not None: 236 | logging.info('-' * 107) 237 | fmt_str = '[epoch %d], [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 238 | '[mean_iu %.5f], [fwavacc %.5f]' 239 | logging.info(fmt_str % (epoch, dataset_name, val_loss.avg, acc, acc_cls, mean_iu, fwavacc)) 240 | if save_pth: 241 | fmt_str = 'best record: [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\ 242 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], ' 243 | logging.info(fmt_str % (dataset_name, 244 | args.best_record[dataset_name]['val_loss'], args.best_record[dataset_name]['acc'], 245 | args.best_record[dataset_name]['acc_cls'], args.best_record[dataset_name]['mean_iu'], 246 | args.best_record[dataset_name]['fwavacc'], args.best_record[dataset_name]['epoch'])) 247 | logging.info('-' * 107) 248 | 249 | if writer: 250 | # tensorboard logging of validation phase metrics 251 | writer.add_scalar('{}/acc'.format(dataset_name), acc, curr_iter) 252 | writer.add_scalar('{}/acc_cls'.format(dataset_name), acc_cls, curr_iter) 253 | writer.add_scalar('{}/mean_iu'.format(dataset_name), mean_iu, curr_iter) 254 | writer.add_scalar('{}/val_loss'.format(dataset_name), val_loss.avg, curr_iter) 255 | 256 | 257 | 258 | 259 | 260 | def print_evaluate_results(hist, iu, dataset_name=None, dataset=None): 261 | # fixme: Need to refactor this dict 262 | try: 263 | id2cat = dataset.id2cat 264 | except: 265 | id2cat = {i: i for i in range(datasets.num_classes)} 266 | iu_false_positive = hist.sum(axis=1) - np.diag(hist) 267 | iu_false_negative = hist.sum(axis=0) - np.diag(hist) 268 | iu_true_positive = np.diag(hist) 269 | 270 | logging.info('Dataset name: {}'.format(dataset_name)) 271 | logging.info('IoU:') 272 | logging.info('label_id label iU Precision Recall TP FP FN') 273 | for idx, i in enumerate(iu): 274 | # Format all of the strings: 275 | idx_string = "{:2d}".format(idx) 276 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else '' 277 | iu_string = '{:5.1f}'.format(i * 100) 278 | total_pixels = hist.sum() 279 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels) 280 | fp = '{:5.1f}'.format( 281 | iu_false_positive[idx] / iu_true_positive[idx]) 282 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx]) 283 | precision = '{:5.1f}'.format( 284 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx])) 285 | recall = '{:5.1f}'.format( 286 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx])) 287 | logging.info('{} {} {} {} {} {} {} {}'.format( 288 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn)) 289 | 290 | 291 | 292 | 293 | class AverageMeter(object): 294 | 295 | def __init__(self): 296 | self.reset() 297 | 298 | def reset(self): 299 | self.val = 0 300 | self.avg = 0 301 | self.sum = 0 302 | self.count = 0 303 | 304 | def update(self, val, n=1): 305 | self.val = val 306 | self.sum += val * n 307 | self.count += n 308 | self.avg = self.sum / self.count 309 | -------------------------------------------------------------------------------- 
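The metric helpers dumped above (fast_hist, per_class_iu, AverageMeter) accumulate a confusion matrix and turn it into per-class IoU. A minimal usage sketch follows; the 19-class label space and the random prediction/label arrays are stand-ins for illustration only, not values taken from the repository:

# Minimal usage sketch for the metric helpers in utils/misc.py.
# Predictions and labels are random stand-ins; in the real validation
# loop they come from the network output and the dataloader.
import numpy as np
from utils.misc import fast_hist, per_class_iu, AverageMeter

num_classes = 19                                   # assumed label space
hist = np.zeros((num_classes, num_classes))
val_loss = AverageMeter()

for _ in range(4):                                 # pretend validation batches
    pred = np.random.randint(0, num_classes, size=(2, 128, 128))
    gt = np.random.randint(0, num_classes, size=(2, 128, 128))
    # fast_hist drops labels outside [0, num_classes), e.g. an ignore id
    hist += fast_hist(pred.flatten(), gt.flatten(), num_classes)
    val_loss.update(0.5, n=pred.shape[0])          # dummy per-batch loss

iu = per_class_iu(hist)                            # per-class IoU from the confusion matrix
print('mean IoU: {:.4f}'.format(np.nanmean(iu)))
print('avg val loss: {:.4f}'.format(val_loss.avg))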
/utils/my_data_parallel.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | # Code adapted from: 4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py 5 | # 6 | # BSD 3-Clause License 7 | # 8 | # Copyright (c) 2017, 9 | # All rights reserved. 10 | # 11 | # Redistribution and use in source and binary forms, with or without 12 | # modification, are permitted provided that the following conditions are met: 13 | # 14 | # * Redistributions of source code must retain the above copyright notice, this 15 | # list of conditions and the following disclaimer. 16 | # 17 | # * Redistributions in binary form must reproduce the above copyright notice, 18 | # this list of conditions and the following disclaimer in the documentation 19 | # and/or other materials provided with the distribution. 20 | # 21 | # * Neither the name of the copyright holder nor the names of its 22 | # contributors may be used to endorse or promote products derived from 23 | # this software without specific prior written permission. 24 | # 25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.s 35 | """ 36 | 37 | 38 | import operator 39 | import torch 40 | import warnings 41 | from torch.nn.modules import Module 42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather 43 | from torch.nn.parallel.replicate import replicate 44 | from torch.nn.parallel.parallel_apply import parallel_apply 45 | 46 | 47 | def _check_balance(device_ids): 48 | imbalance_warn = """ 49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which 50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting 51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES 52 | environment variable.""" 53 | 54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids] 55 | 56 | def warn_imbalance(get_prop): 57 | values = [get_prop(props) for props in dev_props] 58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) 59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) 60 | if min_val / max_val < 0.75: 61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos])) 62 | return True 63 | return False 64 | 65 | if warn_imbalance(lambda props: props.total_memory): 66 | return 67 | if warn_imbalance(lambda props: props.multi_processor_count): 68 | return 69 | 70 | 71 | 72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True): 73 | """ 74 | Evaluates module(input) in parallel across the GPUs given in device_ids. 75 | This is the functional version of the DataParallel module. 
76 | Args: 77 | module: the module to evaluate in parallel 78 | inputs: inputs to the module 79 | device_ids: GPU ids on which to replicate module 80 | output_device: GPU location of the output Use -1 to indicate the CPU. 81 | (default: device_ids[0]) 82 | Returns: 83 | a Tensor containing the result of module(input) located on 84 | output_device 85 | """ 86 | if not isinstance(inputs, tuple): 87 | inputs = (inputs,) 88 | 89 | if device_ids is None: 90 | device_ids = list(range(torch.cuda.device_count())) 91 | 92 | if output_device is None: 93 | output_device = device_ids[0] 94 | 95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim) 96 | if len(device_ids) == 1: 97 | return module(*inputs[0], **module_kwargs[0]) 98 | used_device_ids = device_ids[:len(inputs)] 99 | replicas = replicate(module, used_device_ids) 100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids) 101 | if gather: 102 | return gather(outputs, output_device, dim) 103 | else: 104 | return outputs 105 | 106 | 107 | 108 | class MyDataParallel(Module): 109 | """ 110 | Implements data parallelism at the module level. 111 | This container parallelizes the application of the given module by 112 | splitting the input across the specified devices by chunking in the batch 113 | dimension. In the forward pass, the module is replicated on each device, 114 | and each replica handles a portion of the input. During the backwards 115 | pass, gradients from each replica are summed into the original module. 116 | The batch size should be larger than the number of GPUs used. 117 | See also: :ref:`cuda-nn-dataparallel-instead` 118 | Arbitrary positional and keyword inputs are allowed to be passed into 119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim 120 | specified (default 0). Primitive types will be broadcasted, but all 121 | other types will be a shallow copy and can be corrupted if written to in 122 | the model's forward pass. 123 | .. warning:: 124 | Forward and backward hooks defined on :attr:`module` and its submodules 125 | will be invoked ``len(device_ids)`` times, each with inputs located on 126 | a particular device. Particularly, the hooks are only guaranteed to be 127 | executed in correct order with respect to operations on corresponding 128 | devices. For example, it is not guaranteed that hooks set via 129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before 130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but 131 | that each such hook be executed before the corresponding 132 | :meth:`~torch.nn.Module.forward` call of that device. 133 | .. warning:: 134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in 135 | :func:`forward`, this wrapper will return a vector of length equal to 136 | number of devices used in data parallelism, containing the result from 137 | each device. 138 | .. note:: 139 | There is a subtlety in using the 140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a 141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`. 142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for 143 | details. 
144 | Args: 145 | module: module to be parallelized 146 | device_ids: CUDA devices (default: all devices) 147 | output_device: device location of output (default: device_ids[0]) 148 | Attributes: 149 | module (Module): the module to be parallelized 150 | Example:: 151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) 152 | >>> output = net(input_var) 153 | """ 154 | 155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well 156 | 157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True): 158 | super(MyDataParallel, self).__init__() 159 | 160 | if not torch.cuda.is_available(): 161 | self.module = module 162 | self.device_ids = [] 163 | return 164 | 165 | if device_ids is None: 166 | device_ids = list(range(torch.cuda.device_count())) 167 | if output_device is None: 168 | output_device = device_ids[0] 169 | self.dim = dim 170 | self.module = module 171 | self.device_ids = device_ids 172 | self.output_device = output_device 173 | self.gather_bool = gather 174 | 175 | _check_balance(self.device_ids) 176 | 177 | if len(self.device_ids) == 1: 178 | self.module.cuda(device_ids[0]) 179 | 180 | def forward(self, *inputs, **kwargs): 181 | if not self.device_ids: 182 | return self.module(*inputs, **kwargs) 183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) 184 | if len(self.device_ids) == 1: 185 | return [self.module(*inputs[0], **kwargs[0])] 186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 187 | outputs = self.parallel_apply(replicas, inputs, kwargs) 188 | if self.gather_bool: 189 | return self.gather(outputs, self.output_device) 190 | else: 191 | return outputs 192 | 193 | def replicate(self, module, device_ids): 194 | return replicate(module, device_ids) 195 | 196 | def scatter(self, inputs, kwargs, device_ids): 197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) 198 | 199 | def parallel_apply(self, replicas, inputs, kwargs): 200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) 201 | 202 | def gather(self, outputs, output_device): 203 | return gather(outputs, output_device, dim=self.dim) 204 | 205 | --------------------------------------------------------------------------------
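MyDataParallel above mirrors torch.nn.DataParallel but exposes a gather flag, so per-replica outputs can be returned as a list instead of being gathered onto one device. A small usage sketch, assuming a toy model and two visible GPUs (neither is part of the repository); with no GPU available the wrapper simply calls the wrapped module directly:

# Usage sketch for utils/my_data_parallel.py. The toy model, input shape and
# the two visible GPUs are assumptions, not taken from the repository.
import torch
import torch.nn as nn
from utils.my_data_parallel import MyDataParallel

model = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
if torch.cuda.is_available():
    model = model.cuda()        # replicas are created from the CUDA copy

# gather=True behaves like torch.nn.DataParallel (one output gathered onto
# output_device); gather=False keeps one output per replica in a list.
net = MyDataParallel(model, device_ids=[0, 1], gather=False)

x = torch.randn(4, 3, 64, 64)   # the batch is split across the replicas
if torch.cuda.is_available():
    x = x.cuda()
outputs = net(x)                # list of per-replica outputs on their devices;
                                # a plain forward call (single tensor) if no GPU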