├── LICENSE
├── README.md
├── assets
│   ├── model_architecture.png
│   └── results.png
├── config.py
├── datasets
│   ├── __init__.py
│   ├── bdd100k.py
│   ├── camvid.py
│   ├── cityscapes.py
│   ├── cityscapes_labels.py
│   ├── gtav.py
│   ├── kitti.py
│   ├── mapillary.py
│   ├── multi_loader.py
│   ├── nullloader.py
│   ├── sampler.py
│   ├── synthia.py
│   └── uniform.py
├── eval.py
├── infer.py
├── loss.py
├── network
│   ├── Mobilenet.py
│   ├── Resnet.py
│   ├── SEresnext.py
│   ├── Shufflenet.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── Mobilenet.cpython-36.pyc
│   │   ├── Mobilenet.cpython-37.pyc
│   │   ├── Resnet.cpython-36.pyc
│   │   ├── Resnet.cpython-37.pyc
│   │   ├── Shufflenet.cpython-36.pyc
│   │   ├── Shufflenet.cpython-37.pyc
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── cov_settings.cpython-36.pyc
│   │   ├── cov_settings.cpython-37.pyc
│   │   ├── cwcl.cpython-36.pyc
│   │   ├── cwcl.cpython-37.pyc
│   │   ├── deepv3.cpython-36.pyc
│   │   ├── deepv3.cpython-37.pyc
│   │   ├── edge_contrast.cpython-36.pyc
│   │   ├── edge_contrast_batch.cpython-36.pyc
│   │   ├── edge_contrast_v1_opt.cpython-36.pyc
│   │   ├── edge_contrast_v1_opt.cpython-37.pyc
│   │   ├── edge_contrast_v2.cpython-36.pyc
│   │   ├── instance_whitening.cpython-36.pyc
│   │   ├── instance_whitening.cpython-37.pyc
│   │   ├── mynn.cpython-36.pyc
│   │   ├── mynn.cpython-37.pyc
│   │   ├── pixel_nce.cpython-36.pyc
│   │   ├── pixel_nce.cpython-37.pyc
│   │   ├── pixel_nce_batch.cpython-36.pyc
│   │   ├── sdcl.cpython-36.pyc
│   │   ├── sdcl.cpython-37.pyc
│   │   ├── sync_switchwhiten.cpython-36.pyc
│   │   └── sync_switchwhiten.cpython-37.pyc
│   ├── bn_helper.py
│   ├── cov_settings.py
│   ├── cwcl.py
│   ├── deepv3.py
│   ├── instance_whitening.py
│   ├── mynn.py
│   ├── sdcl.py
│   ├── switchwhiten.py
│   ├── sync_switchwhiten.py
│   └── wider_resnet.py
├── optimizer.py
├── scripts
│   ├── blindnet_infer_r50os16.sh
│   ├── blindnet_train_r50os16_gtav.sh
│   └── blindnet_valid_r50os16_gtav.sh
├── split_data
│   ├── gtav_split_test.txt
│   ├── gtav_split_train.txt
│   ├── gtav_split_val.txt
│   ├── synthia_split_train.txt
│   └── synthia_split_val.txt
├── train.py
├── transforms
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── joint_transforms.cpython-36.pyc
│   │   ├── joint_transforms.cpython-37.pyc
│   │   ├── transforms.cpython-36.pyc
│   │   └── transforms.cpython-37.pyc
│   ├── joint_transforms.py
│   └── transforms.py
├── utils
│   ├── __init__.py
│   ├── attr_dict.py
│   ├── misc.py
│   └── my_data_parallel.py
└── valid.py
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2024, root0yang
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # BlindNet (CVPR 2024) : Official Project Webpage
2 | This repository provides the official PyTorch implementation of the following paper:
3 | > [**Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning**](https://arxiv.org/abs/2403.06122)
4 | > Woo-Jin Ahn, Geun-Yeong Yang, Hyun-Duck Choi, Myo-Taeg Lim
5 | > Korea University, Chonnam National University
6 |
7 | > **Abstract:**
8 | > *Deep learning models for semantic segmentation often
9 | experience performance degradation when deployed to unseen target domains unidentified during the training phase.
10 | This is mainly due to variations in image texture (i.e. style)
11 | from different data sources. To tackle this challenge, existing domain generalized semantic segmentation (DGSS)
12 | methods attempt to remove style variations from the feature. However, these approaches struggle with the entanglement of style and content, which may lead to the unintentional removal of crucial content information, causing
13 | performance degradation. This study addresses this limitation by proposing BlindNet, a novel DGSS approach that
14 | blinds the style without external modules or datasets. The
15 | main idea behind our proposed approach is to alleviate the
16 | effect of style in the encoder whilst facilitating robust segmentation in the decoder. To achieve this, BlindNet comprises two key components: covariance alignment and semantic consistency contrastive learning. Specifically, the
17 | covariance alignment trains the encoder to uniformly recognize various styles and preserve the content information
18 | of the feature, rather than removing the style-sensitive factor. Meanwhile, semantic consistency contrastive learning
19 | enables the decoder to construct discriminative class embedding space and disentangles features that are vulnerable to misclassification. Through extensive experiments,
20 | our approach outperforms existing DGSS methods, exhibiting robustness and superior performance for semantic segmentation on unseen target domains.*
21 |
22 |
23 |
24 |
25 |
26 |
27 | ## Pytorch Implementation
28 |
29 | Our pytorch implementation is heavily derived from [RobustNet](https://github.com/shachoi/RobustNet) (CVPR 2021). If you use this code in your research, please also cite their work.
30 | [[link to license](https://github.com/shachoi/RobustNet/blob/main/LICENSE)]
31 |
32 | ### Installation
33 | Clone this repository.
34 | ```
35 | git clone https://github.com/root0yang/BlindNet.git
36 | cd BlindNet
37 | ```
38 | Install the following packages.
39 | ```
40 | conda create --name blindnet python=3.6
41 | conda activate blindnet
42 | conda install pytorch==1.2.0 cudatoolkit==10.2
43 | conda install scipy==1.1.0
44 | conda install tqdm==4.46.0
45 | conda install scikit-image==0.16.2
46 | pip install tensorboardX==2.4
47 | pip install thop
48 | imageio_download_bin freeimage
49 | ```
50 |
51 | ### How to Run BlindNet
52 | We evaluated the model on [Cityscapes](https://www.cityscapes-dataset.com/), [BDD-100K](https://bair.berkeley.edu/blog/2018/05/30/bdd/), [Synthia](https://synthia-dataset.net/downloads/) ([SYNTHIA-RAND-CITYSCAPES](http://synthia-dataset.net/download/808/)), [GTAV](https://download.visinf.tu-darmstadt.de/data/from_games/) and [Mapillary Vistas](https://www.mapillary.com/dataset/vistas?pKey=2ix3yvnjy9fwqdzwum3t9g&lat=20&lng=0&z=1.5).
53 |
54 | We adopt class-uniform sampling proposed in [this paper](https://openaccess.thecvf.com/content_CVPR_2019/papers/Zhu_Improving_Semantic_Segmentation_via_Video_Propagation_and_Label_Relaxation_CVPR_2019_paper.pdf) to handle the class imbalance problem.
55 |
56 |
57 | 1. For the Cityscapes dataset, download "leftImg8bit_trainvaltest.zip" and "gtFine_trainvaltest.zip" from https://www.cityscapes-dataset.com/downloads/
58 | Unzip the files and organize the directory structures as follows (Cityscapes, BDD-100K, and Mapillary shown below).
59 | ```
60 | cityscapes
61 | └ leftImg8bit_trainvaltest
62 | └ leftImg8bit
63 | └ train
64 | └ val
65 | └ test
66 | └ gtFine_trainvaltest
67 | └ gtFine
68 | └ train
69 | └ val
70 | └ test
71 | ```
72 | ```
73 | bdd-100k
74 | └ images
75 | └ train
76 | └ val
77 | └ test
78 | └ labels
79 | └ train
80 | └ val
81 | ```
82 | ```
83 | mapillary
84 | └ training
85 | └ images
86 | └ labels
87 | └ validation
88 | └ images
89 | └ labels
90 | └ test
91 | └ images
92 | └ labels
93 | ```
94 |
95 | 2. We used [GTAV_Split](https://download.visinf.tu-darmstadt.de/data/from_games/code/read_mapping.zip) to split the GTAV dataset into training/validation/test sets. Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data).
96 |
97 | ```
98 | GTAV
99 | └ images
100 | └ train
101 | └ folder
102 | └ valid
103 | └ folder
104 | └ test
105 | └ folder
106 | └ labels
107 | └ train
108 | └ folder
109 | └ valid
110 | └ folder
111 | └ test
112 | └ folder
113 | ```
114 |
115 | 3. We split the [Synthia dataset](http://synthia-dataset.net/download/808/) into train/val sets following [RobustNet](https://github.com/shachoi/RobustNet). Please refer to the txt files in [split_data](https://github.com/suhyeonlee/WildNet/tree/main/split_data).
116 |
117 | ```
118 | synthia
119 | └ RGB
120 | └ train
121 | └ val
122 | └ GT
123 | └ COLOR
124 | └ train
125 | └ val
126 | └ LABELS
127 | └ train
128 | └ val
129 | ```
130 |
131 | 4. Modify the dataset paths in **"/config.py"** to match your local dataset locations.
132 | ```
133 | #Cityscapes Dir Location
134 | __C.DATASET.CITYSCAPES_DIR =
135 | #Mapillary Dataset Dir Location
136 | __C.DATASET.MAPILLARY_DIR =
137 | #GTAV Dataset Dir Location
138 | __C.DATASET.GTAV_DIR =
139 | #BDD-100K Dataset Dir Location
140 | __C.DATASET.BDD_DIR =
141 | #Synthia Dataset Dir Location
142 | __C.DATASET.SYNTHIA_DIR =
143 | ```
144 | 5. You can train BlindNet with the following command.
145 | ```
146 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_train_r50os16_gtav.sh
147 | ```
148 |
149 | 6. You can download our ResNet-50 model trained on GTAV from [Google Drive](https://drive.google.com/file/d/1Kkdl_2xjE9iooA1Is5VWcjeRcuzaJ8Pi/view?usp=drive_link) and validate the pretrained model with the following command.
150 | ```
151 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_valid_r50os16_gtav.sh
152 | ```
153 |
154 | 7. You can run inference on your own images with the pretrained model using the following command.
155 | ```
156 | $ CUDA_VISIBLE_DEVICES=0,1 ./scripts/blindnet_infer_r50os16.sh
157 | ```
158 |
159 | ## Citation
160 | If you find this work useful in your research, please cite our paper:
161 | ```
162 | @article{ahn2024style,
163 | title={Style Blind Domain Generalized Semantic Segmentation via Covariance Alignment and Semantic Consistence Contrastive Learning},
164 | author={Ahn, Woo-Jin and Yang, Geun-Yeong and Choi, Hyun-Duck and Lim, Myo-Taeg},
165 | journal={arXiv preprint arXiv:2403.06122},
166 | year={2024}
167 | }
168 | ```
169 |
170 | ## Terms of Use
171 | This software is for non-commercial use only.
172 | The source code is released under the Attribution-NonCommercial-ShareAlike (CC BY-NC-SA) License
173 | (see [this](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) for details).
174 |
--------------------------------------------------------------------------------
/assets/model_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/assets/model_architecture.png
--------------------------------------------------------------------------------
/assets/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/assets/results.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/core/config.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 | ##############################################################################
30 | # Config
31 | # #############################################################################
32 |
33 |
34 | from __future__ import absolute_import
35 | from __future__ import division
36 | from __future__ import print_function
37 | from __future__ import unicode_literals
38 |
39 |
40 | import torch
41 |
42 |
43 | from utils.attr_dict import AttrDict
44 |
45 |
46 | __C = AttrDict()
47 | cfg = __C
48 | __C.ITER = 0
49 | __C.EPOCH = 0
50 |
51 | __C.RANDOM_SEED = 304
52 | # Use Class Uniform Sampling to give each class proper sampling
53 | __C.CLASS_UNIFORM_PCT = 0.0
54 |
55 | # Use class-weighted loss per batch to increase the loss for low-pixel-count classes
56 | __C.BATCH_WEIGHTING = False
57 |
58 | # Border Relaxation Count
59 | __C.BORDER_WINDOW = 1
60 | # Number of epochs to use before turning off the border restriction
61 | __C.REDUCE_BORDER_ITER = -1
62 | __C.REDUCE_BORDER_EPOCH = -1
63 | # Comma-separated list of class ids to relax
64 | __C.STRICTBORDERCLASS = None
65 |
66 |
67 |
68 | #Attribute Dictionary for Dataset
69 | __C.DATASET = AttrDict()
70 | #Cityscapes Dir Location
71 | __C.DATASET.CITYSCAPES_DIR = '/data3/yang/datasets/cityscapes'
72 | #SDC Augmented Cityscapes Dir Location
73 | __C.DATASET.CITYSCAPES_AUG_DIR = ''
74 | #Mapillary Dataset Dir Location
75 | __C.DATASET.MAPILLARY_DIR = '/data3/yang/datasets/mapillary'
76 | #GTAV, BDD100K Dataset Dir Location
77 | __C.DATASET.GTAV_DIR = '/data3/yang/datasets/gtav'
78 | __C.DATASET.BDD_DIR = '/data3/yang/datasets/bdd-100k'
79 | #Synthia Dataset Dir Location
80 | __C.DATASET.SYNTHIA_DIR = '/data3/yang/datasets/synthia'
81 | #Kitti Dataset Dir Location
82 | __C.DATASET.KITTI_DIR = ''
83 | #SDC Augmented Kitti Dataset Dir Location
84 | __C.DATASET.KITTI_AUG_DIR = ''
85 | #Camvid Dataset Dir Location
86 | __C.DATASET.CAMVID_DIR = '/home/nas_datasets/segmentation/SegNet-Tutorial/CamVid'
87 | #Number of splits to support
88 | __C.DATASET.CV_SPLITS = 3
89 |
90 |
91 | __C.MODEL = AttrDict()
92 | __C.MODEL.BN = 'pytorch-syncnorm'
93 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
94 |
95 | def assert_and_infer_cfg(args, make_immutable=True, train_mode=True):
96 | """Call this function in your script after you have finished setting all cfg
97 | values that are necessary (e.g., merging a config from a file, merging
98 | command line config options, etc.). By default, this function will also
99 | mark the global cfg as immutable to prevent changing the global cfg settings
100 | during script execution (which can lead to hard to debug errors or code
101 | that's harder to understand than is necessary).
102 | """
103 |
104 | if hasattr(args, 'syncbn') and args.syncbn:
105 | __C.MODEL.BN = 'pytorch-syncnorm'
106 | __C.MODEL.BNFUNC = torch.nn.SyncBatchNorm
107 | print('Using pytorch sync batch norm')
108 | else:
109 | __C.MODEL.BNFUNC = torch.nn.BatchNorm2d
110 | print('Using regular batch norm')
111 |
112 | if not train_mode:
113 | cfg.immutable(True)
114 | return
115 | if args.class_uniform_pct:
116 | cfg.CLASS_UNIFORM_PCT = args.class_uniform_pct
117 |
118 | if args.batch_weighting:
119 | __C.BATCH_WEIGHTING = True
120 |
121 | if args.jointwtborder:
122 | if args.strict_bdr_cls != '':
123 | __C.STRICTBORDERCLASS = [int(i) for i in args.strict_bdr_cls.split(",")]
124 | if args.rlx_off_iter > -1:
125 | __C.REDUCE_BORDER_ITER = args.rlx_off_iter
126 |
127 | if make_immutable:
128 | cfg.immutable(True)
129 |
--------------------------------------------------------------------------------
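A minimal usage sketch (not part of the repository) for the config module above: how a training script might finalize the global cfg after argument parsing, as the `assert_and_infer_cfg` docstring describes. The argparse flag names below are assumptions chosen to match the attributes that function reads; the repository's own train.py may define them differently.

```
import argparse

from config import cfg, assert_and_infer_cfg

# Hypothetical flags; assert_and_infer_cfg only looks for these attribute names on args.
parser = argparse.ArgumentParser()
parser.add_argument('--syncbn', action='store_true')
parser.add_argument('--class_uniform_pct', type=float, default=0.5)
parser.add_argument('--batch_weighting', action='store_true')
parser.add_argument('--jointwtborder', action='store_true')
parser.add_argument('--strict_bdr_cls', type=str, default='')
parser.add_argument('--rlx_off_iter', type=int, default=-1)
args = parser.parse_args()

# Apply the command-line overrides and freeze the global cfg.
assert_and_infer_cfg(args, make_immutable=True, train_mode=True)
print(cfg.CLASS_UNIFORM_PCT, cfg.MODEL.BN)
```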
/datasets/cityscapes_labels.py:
--------------------------------------------------------------------------------
1 | """
2 | # File taken from https://github.com/mcordts/cityscapesScripts/
3 | # License File Available at:
4 | # https://github.com/mcordts/cityscapesScripts/blob/master/license.txt
5 |
6 | # ----------------------
7 | # The Cityscapes Dataset
8 | # ----------------------
9 | #
10 | #
11 | # License agreement
12 | # -----------------
13 | #
14 | # This dataset is made freely available to academic and non-academic entities for non-commercial purposes such as academic research, teaching, scientific publications, or personal experimentation. Permission is granted to use the data given that you agree:
15 | #
16 | # 1. That the dataset comes "AS IS", without express or implied warranty. Although every effort has been made to ensure accuracy, we (Daimler AG, MPI Informatics, TU Darmstadt) do not accept any responsibility for errors or omissions.
17 | # 2. That you include a reference to the Cityscapes Dataset in any work that makes use of the dataset. For research papers, cite our preferred publication as listed on our website; for other media cite our preferred publication as listed on our website or link to the Cityscapes website.
18 | # 3. That you do not distribute this dataset or modified versions. It is permissible to distribute derivative works in as far as they are abstract representations of this dataset (such as models trained on it or additional annotations that do not directly include any of our data) and do not allow to recover the dataset or something similar in character.
19 | # 4. That you may not use the dataset or any derivative work for commercial purposes as, for example, licensing or selling the data, or using the data with a purpose to procure a commercial gain.
20 | # 5. That all rights not expressly granted to you are reserved by us (Daimler AG, MPI Informatics, TU Darmstadt).
21 | #
22 | #
23 | # Contact
24 | # -------
25 | #
26 | # Marius Cordts, Mohamed Omran
27 | # www.cityscapes-dataset.net
28 |
29 | """
30 | from collections import namedtuple
31 |
32 |
33 | #--------------------------------------------------------------------------------
34 | # Definitions
35 | #--------------------------------------------------------------------------------
36 |
37 | # a label and all meta information
38 | Label = namedtuple( 'Label' , [
39 |
40 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... .
41 | # We use them to uniquely name a class
42 |
43 | 'id' , # An integer ID that is associated with this label.
44 | # The IDs are used to represent the label in ground truth images
45 | # An ID of -1 means that this label does not have an ID and thus
46 | # is ignored when creating ground truth images (e.g. license plate).
47 | # Do not modify these IDs, since exactly these IDs are expected by the
48 | # evaluation server.
49 |
50 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create
51 | # ground truth images with train IDs, using the tools provided in the
52 | # 'preparation' folder. However, make sure to validate or submit results
53 | # to our evaluation server using the regular IDs above!
54 | # For trainIds, multiple labels might have the same ID. Then, these labels
55 | # are mapped to the same class in the ground truth images. For the inverse
56 | # mapping, we use the label that is defined first in the list below.
57 | # For example, mapping all void-type classes to the same ID in training,
58 | # might make sense for some approaches.
59 | # Max value is 255!
60 |
61 | 'category' , # The name of the category that this label belongs to
62 |
63 | 'categoryId' , # The ID of this category. Used to create ground truth images
64 | # on category level.
65 |
66 | 'hasInstances', # Whether this label distinguishes between single instances or not
67 |
68 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored
69 | # during evaluations or not
70 |
71 | 'color' , # The color of this label
72 | ] )
73 |
74 |
75 | #--------------------------------------------------------------------------------
76 | # A list of all labels
77 | #--------------------------------------------------------------------------------
78 |
79 | # Please adapt the train IDs as appropriate for your approach.
80 | # Note that you might want to ignore labels with ID 255 during training.
81 | # Further note that the current train IDs are only a suggestion. You can use whatever you like.
82 | # Make sure to provide your results using the original IDs and not the training IDs.
83 | # Note that many IDs are ignored in evaluation and thus you never need to predict these!
84 |
85 | labels = [
86 | # name id trainId category catId hasInstances ignoreInEval color
87 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
88 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
89 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
90 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
91 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ),
92 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ),
93 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ),
94 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ),
95 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ),
96 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ),
97 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ),
98 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ),
99 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ),
100 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ),
101 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ),
102 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ),
103 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ),
104 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ),
105 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,154) ), # (153,153,153)
106 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ),
107 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ),
108 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ),
109 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ),
110 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ),
111 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ),
112 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ),
113 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ),
114 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ),
115 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ),
116 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ),
117 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ),
118 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ),
119 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ),
120 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ),
121 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,143) ), # ( 0, 0,142)
122 | ]
123 |
124 |
125 | #--------------------------------------------------------------------------------
126 | # Create dictionaries for a fast lookup
127 | #--------------------------------------------------------------------------------
128 |
129 | # Please refer to the main method below for example usages!
130 |
131 | # name to label object
132 | name2label = { label.name : label for label in labels }
133 | # id to label object
134 | id2label = { label.id : label for label in labels }
135 | # trainId to label object
136 | trainId2label = { label.trainId : label for label in reversed(labels) }
137 | # label2trainid
138 | label2trainid = { label.id : label.trainId for label in labels }
139 | # trainId to label object
140 | trainId2name = { label.trainId : label.name for label in labels }
141 | trainId2color = { label.trainId : label.color for label in labels }
142 |
143 | color2trainId = { label.color : label.trainId for label in labels }
144 |
145 | trainId2trainId = { label.trainId : label.trainId for label in labels }
146 |
147 | # category to list of label objects
148 | category2labels = {}
149 | for label in labels:
150 | category = label.category
151 | if category in category2labels:
152 | category2labels[category].append(label)
153 | else:
154 | category2labels[category] = [label]
155 |
156 | #--------------------------------------------------------------------------------
157 | # Assure single instance name
158 | #--------------------------------------------------------------------------------
159 |
160 | # returns the label name that describes a single instance (if possible)
161 | # e.g. input | output
162 | # ----------------------
163 | # car | car
164 | # cargroup | car
165 | # foo | None
166 | # foogroup | None
167 | # skygroup | None
168 | def assureSingleInstanceName( name ):
169 | # if the name is known, it is not a group
170 | if name in name2label:
171 | return name
172 | # test if the name actually denotes a group
173 | if not name.endswith("group"):
174 | return None
175 | # remove group
176 | name = name[:-len("group")]
177 | # test if the new name exists
178 | if not name in name2label:
179 | return None
180 | # test if the new name denotes a label that actually has instances
181 | if not name2label[name].hasInstances:
182 | return None
183 | # all good then
184 | return name
185 |
186 | #--------------------------------------------------------------------------------
187 | # Main for testing
188 | #--------------------------------------------------------------------------------
189 |
190 | # just a dummy main
191 | if __name__ == "__main__":
192 | # Print all the labels
193 | print("List of cityscapes labels:")
194 | print("")
195 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )))
196 | print((" " + ('-' * 98)))
197 | for label in labels:
198 | print((" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )))
199 | print("")
200 |
201 | print("Example usages:")
202 |
203 | # Map from name to label
204 | name = 'car'
205 | id = name2label[name].id
206 | print(("ID of label '{name}': {id}".format( name=name, id=id )))
207 |
208 | # Map from ID to label
209 | category = id2label[id].category
210 | print(("Category of label with ID '{id}': {category}".format( id=id, category=category )))
211 |
212 | # Map from trainID to label
213 | trainId = 0
214 | name = trainId2label[trainId].name
215 | print(("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )))
216 |
--------------------------------------------------------------------------------
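The dictionaries above are what the dataset loaders use to convert raw Cityscapes label IDs into the 19 train IDs. A minimal sketch (not from the repository) of that remapping on a toy mask:

```
import numpy as np

from datasets.cityscapes_labels import label2trainid

# Toy ID mask: road, sidewalk, car on the first row; unlabeled, person, bicycle on the second.
id_mask = np.array([[7, 8, 26],
                    [0, 24, 33]], dtype=np.uint8)

train_mask = id_mask.copy()
for label_id, train_id in label2trainid.items():
    train_mask[id_mask == label_id] = train_id

print(train_mask)  # [[  0   1  13]
                   #  [255  11  18]]
```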
/datasets/kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | KITTI Dataset Loader
3 | """
4 |
5 | import os
6 | import sys
7 | import numpy as np
8 | from PIL import Image
9 | from torch.utils import data
10 | import logging
11 | import datasets.uniform as uniform
12 | import datasets.cityscapes_labels as cityscapes_labels
13 | import json
14 | from config import cfg
15 |
16 |
17 | trainid_to_name = cityscapes_labels.trainId2name
18 | id_to_trainid = cityscapes_labels.label2trainid
19 | num_classes = 19
20 | ignore_label = 255
21 | root = cfg.DATASET.KITTI_DIR
22 | aug_root = cfg.DATASET.KITTI_AUG_DIR
23 |
24 | palette = [128, 64, 128, 244, 35, 232, 70, 70, 70, 102, 102, 156, 190, 153, 153,
25 | 153, 153, 153, 250, 170, 30,
26 | 220, 220, 0, 107, 142, 35, 152, 251, 152, 70, 130, 180, 220, 20, 60,
27 | 255, 0, 0, 0, 0, 142, 0, 0, 70,
28 | 0, 60, 100, 0, 80, 100, 0, 0, 230, 119, 11, 32]
29 | zero_pad = 256 * 3 - len(palette)
30 | for i in range(zero_pad):
31 | palette.append(0)
32 |
33 | def colorize_mask(mask):
34 | # mask: numpy array of the mask
35 | new_mask = Image.fromarray(mask.astype(np.uint8)).convert('P')
36 | new_mask.putpalette(palette)
37 | return new_mask
38 |
39 | def get_train_val(cv_split, all_items):
40 |
41 | # 90/10 train/val split, three random splits
42 | val_0 = [1,5,11,29,35,49,57,68,72,82,93,115,119,130,145,154,156,167,169,189,198]
43 | val_1 = [0,12,24,31,42,50,63,71,84,96,101,112,121,133,141,155,164,171,187,191,197]
44 | val_2 = [3,6,13,21,41,54,61,73,88,91,110,121,126,131,142,149,150,163,173,183,199]
45 |
46 | train_set = []
47 | val_set = []
48 |
49 | if cv_split == 0:
50 | for i in range(200):
51 | if i in val_0:
52 | val_set.append(all_items[i])
53 | else:
54 | train_set.append(all_items[i])
55 | elif cv_split == 1:
56 | for i in range(200):
57 | if i in val_1:
58 | val_set.append(all_items[i])
59 | else:
60 | train_set.append(all_items[i])
61 | elif cv_split == 2:
62 | for i in range(200):
63 | if i in val_2:
64 | val_set.append(all_items[i])
65 | else:
66 | train_set.append(all_items[i])
67 | else:
68 | logging.info('Unknown cv_split {}'.format(cv_split))
69 | sys.exit()
70 |
71 | return train_set, val_set
72 |
73 | def make_dataset(quality, mode, maxSkip=0, cv_split=0, hardnm=0):
74 |
75 | items = []
76 | all_items = []
77 | aug_items = []
78 |
79 | assert quality == 'semantic'
80 | assert mode in ['train', 'val', 'trainval']
81 | # note that train and val are randomly determined, no official split
82 |
83 | img_dir_name = "training"
84 | img_path = os.path.join(root, img_dir_name, 'image_2')
85 | mask_path = os.path.join(root, img_dir_name, 'semantic')
86 |
87 | c_items = os.listdir(img_path)
88 | c_items.sort()
89 |
90 | for it in c_items:
91 | item = (os.path.join(img_path, it), os.path.join(mask_path, it))
92 | all_items.append(item)
93 | logging.info('KITTI has a total of {} images'.format(len(all_items)))
94 |
95 | # split into train/val
96 | train_set, val_set = get_train_val(cv_split, all_items)
97 |
98 | if mode == 'train':
99 | items = train_set
100 | elif mode == 'val':
101 | items = val_set
102 | elif mode == 'trainval':
103 | items = train_set + val_set
104 | else:
105 | logging.info('Unknown mode {}'.format(mode))
106 | sys.exit()
107 |
108 | logging.info('KITTI-{}: {} images'.format(mode, len(items)))
109 |
110 | return items, aug_items
111 |
112 | class KITTI(data.Dataset):
113 |
114 | def __init__(self, quality, mode, maxSkip=0, joint_transform_list=None,
115 | transform=None, target_transform=None, dump_images=False,
116 | class_uniform_pct=0, class_uniform_tile=0, test=False,
117 | cv_split=None, scf=None, hardnm=0):
118 |
119 | self.quality = quality
120 | self.mode = mode
121 | self.maxSkip = maxSkip
122 | self.joint_transform_list = joint_transform_list
123 | self.transform = transform
124 | self.target_transform = target_transform
125 | self.dump_images = dump_images
126 | self.class_uniform_pct = class_uniform_pct
127 | self.class_uniform_tile = class_uniform_tile
128 | self.scf = scf
129 | self.hardnm = hardnm
130 |
131 | if cv_split:
132 | self.cv_split = cv_split
133 | assert cv_split < cfg.DATASET.CV_SPLITS, \
134 | 'expected cv_split {} to be < CV_SPLITS {}'.format(
135 | cv_split, cfg.DATASET.CV_SPLITS)
136 | else:
137 | self.cv_split = 0
138 |
139 | self.imgs, _ = make_dataset(quality, mode, self.maxSkip, cv_split=self.cv_split, hardnm=self.hardnm)
140 | assert len(self.imgs), 'Found 0 images, please check the data set'
141 | # self.cal_shape(self.imgs)
142 |
143 | # Centroids for GT data
144 | if self.class_uniform_pct > 0:
145 | if self.scf:
146 | json_fn = 'kitti_tile{}_cv{}_scf.json'.format(self.class_uniform_tile, self.cv_split)
147 | else:
148 | json_fn = 'kitti_tile{}_cv{}_{}_hardnm{}.json'.format(self.class_uniform_tile, self.cv_split, self.mode, self.hardnm)
149 | if os.path.isfile(json_fn):
150 | with open(json_fn, 'r') as json_data:
151 | centroids = json.load(json_data)
152 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
153 | else:
154 | if self.scf:
155 | self.centroids = kitti_uniform.class_centroids_all(
156 | self.imgs,
157 | num_classes,
158 | id2trainid=id_to_trainid,
159 | tile_size=class_uniform_tile)
160 | else:
161 | self.centroids = uniform.class_centroids_all(
162 | self.imgs,
163 | num_classes,
164 | id2trainid=id_to_trainid,
165 | tile_size=class_uniform_tile)
166 | with open(json_fn, 'w') as outfile:
167 | json.dump(self.centroids, outfile, indent=4)
168 |
169 | self.build_epoch()
170 |
171 |
172 | def cal_shape(self, imgs):
173 |
174 | for i in imgs:
175 | img_path, mask_path = i
176 | img = Image.open(img_path).convert('RGB')
177 | print(img.size)
178 |
179 | def build_epoch(self, cut=False):
180 | if self.class_uniform_pct > 0:
181 | self.imgs_uniform = uniform.build_epoch(self.imgs,
182 | self.centroids,
183 | num_classes,
184 | cfg.CLASS_UNIFORM_PCT)
185 | else:
186 | self.imgs_uniform = self.imgs
187 |
188 | def __getitem__(self, index):
189 | elem = self.imgs_uniform[index]
190 | centroid = None
191 | if len(elem) == 4:
192 | img_path, mask_path, centroid, class_id = elem
193 | else:
194 | img_path, mask_path = elem
195 |
196 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
197 | img_name = os.path.splitext(os.path.basename(img_path))[0]
198 |
199 | # kitti scale correction factor
200 | if self.mode == 'train' or self.mode == 'trainval':
201 | if self.scf:
202 | width, height = img.size
203 | img = img.resize((width*2, height*2), Image.BICUBIC)
204 | mask = mask.resize((width*2, height*2), Image.NEAREST)
205 | elif self.mode == 'val':
206 | width, height = 1242, 376
207 | img = img.resize((width, height), Image.BICUBIC)
208 | mask = mask.resize((width, height), Image.NEAREST)
209 | else:
210 | logging.info('Unknown mode {}'.format(self.mode))
211 | sys.exit()
212 |
213 | mask = np.array(mask)
214 | mask_copy = mask.copy()
215 |
216 | for k, v in id_to_trainid.items():
217 | mask_copy[mask == k] = v
218 | mask = Image.fromarray(mask_copy.astype(np.uint8))
219 |
220 | # Image Transformations
221 | if self.joint_transform_list is not None:
222 | for idx, xform in enumerate(self.joint_transform_list):
223 | if idx == 0 and centroid is not None:
224 | # HACK
225 | # We assume that the first transform is capable of taking
226 | # in a centroid
227 | img, mask = xform(img, mask, centroid)
228 | else:
229 | img, mask = xform(img, mask)
230 |
231 | # Debug
232 | if self.dump_images and centroid is not None:
233 | outdir = './dump_imgs_{}'.format(self.mode)
234 | os.makedirs(outdir, exist_ok=True)
235 | dump_img_name = trainid_to_name[class_id] + '_' + img_name
236 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
237 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
238 | mask_img = colorize_mask(np.array(mask))
239 | img.save(out_img_fn)
240 | mask_img.save(out_msk_fn)
241 |
242 | if self.transform is not None:
243 | img = self.transform(img)
244 | if self.target_transform is not None:
245 | mask = self.target_transform(mask)
246 |
247 | return img, mask, img_name
248 |
249 | def __len__(self):
250 | return len(self.imgs_uniform)
251 |
252 |
--------------------------------------------------------------------------------
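A minimal sketch (not from the repository) of constructing the KITTI loader above for validation. It assumes cfg.DATASET.KITTI_DIR points at the KITTI semantics benchmark (training/image_2 and training/semantic) and that transforms.transforms provides MaskToTensor, as the other loaders use.

```
import torchvision.transforms as standard_transforms

import transforms.transforms as extended_transforms
from datasets.kitti import KITTI

mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
input_transform = standard_transforms.Compose([
    standard_transforms.ToTensor(),
    standard_transforms.Normalize(*mean_std),
])

# cv_split=0 selects the first of the three hard-coded 90/10 splits.
val_set = KITTI('semantic', 'val', 0,
                transform=input_transform,
                target_transform=extended_transforms.MaskToTensor(),
                cv_split=0)

img, mask, img_name = val_set[0]  # img: [3, 376, 1242], mask: [376, 1242] train IDs
```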
/datasets/mapillary.py:
--------------------------------------------------------------------------------
1 | """
2 | Mapillary Dataset Loader
3 | """
4 | import logging
5 | import json
6 | import os
7 | import numpy as np
8 | from PIL import Image, ImageCms
9 | from skimage import color
10 |
11 | from torch.utils import data
12 | import torch
13 | import torchvision.transforms as transforms
14 | import datasets.uniform as uniform
15 | import datasets.cityscapes_labels as cityscapes_labels
16 | import transforms.transforms as extended_transforms
17 | import copy
18 |
19 | from config import cfg
20 |
21 | # Convert this dataset to have labels from cityscapes
22 | num_classes = 19 #65
23 | ignore_label = 255 #65
24 | root = cfg.DATASET.MAPILLARY_DIR
25 | config_fn = os.path.join(root, 'config.json')
26 | color_mapping = []
27 | id_to_trainid = {}
28 | id_to_ignore_or_group = {}
29 |
30 |
31 | def gen_id_to_ignore():
32 | global id_to_ignore_or_group
33 | for i in range(66):
34 | id_to_ignore_or_group[i] = ignore_label
35 |
36 | ### Convert each class to cityscapes one
37 | ### Road
38 | # Road
39 | id_to_ignore_or_group[13] = 0
40 | # Lane Marking - General
41 | id_to_ignore_or_group[24] = 0
42 | # Manhole
43 | id_to_ignore_or_group[41] = 0
44 |
45 | ### Sidewalk
46 | # Curb
47 | id_to_ignore_or_group[2] = 1
48 | # Sidewalk
49 | id_to_ignore_or_group[15] = 1
50 |
51 | ### Building
52 | # Building
53 | id_to_ignore_or_group[17] = 2
54 |
55 | ### Wall
56 | # Wall
57 | id_to_ignore_or_group[6] = 3
58 |
59 | ### Fence
60 | # Fence
61 | id_to_ignore_or_group[3] = 4
62 |
63 | ### Pole
64 | # Pole
65 | id_to_ignore_or_group[45] = 5
66 | # Utility Pole
67 | id_to_ignore_or_group[47] = 5
68 |
69 | ### Traffic Light
70 | # Traffic Light
71 | id_to_ignore_or_group[48] = 6
72 |
73 | ### Traffic Sign
74 | # Traffic Sign
75 | id_to_ignore_or_group[50] = 7
76 |
77 | ### Vegetation
78 | # Vegetation
79 | id_to_ignore_or_group[30] = 8
80 |
81 | ### Terrain
82 | # Terrain
83 | id_to_ignore_or_group[29] = 9
84 |
85 | ### Sky
86 | # Sky
87 | id_to_ignore_or_group[27] = 10
88 |
89 | ### Person
90 | # Person
91 | id_to_ignore_or_group[19] = 11
92 |
93 | ### Rider
94 | # Bicyclist
95 | id_to_ignore_or_group[20] = 12
96 | # Motorcyclist
97 | id_to_ignore_or_group[21] = 12
98 | # Other Rider
99 | id_to_ignore_or_group[22] = 12
100 |
101 | ### Car
102 | # Car
103 | id_to_ignore_or_group[55] = 13
104 |
105 | ### Truck
106 | # Truck
107 | id_to_ignore_or_group[61] = 14
108 |
109 | ### Bus
110 | # Bus
111 | id_to_ignore_or_group[54] = 15
112 |
113 | ### Train
114 | # On Rails
115 | id_to_ignore_or_group[58] = 16
116 |
117 | ### Motorcycle
118 | # Motorcycle
119 | id_to_ignore_or_group[57] = 17
120 |
121 | ### Bicycle
122 | # Bicycle
123 | id_to_ignore_or_group[52] = 18
124 |
125 |
126 | def colorize_mask(image_array):
127 | """
128 | Colorize a segmentation mask
129 | """
130 | new_mask = Image.fromarray(image_array.astype(np.uint8)).convert('P')
131 | new_mask.putpalette(color_mapping)
132 | return new_mask
133 |
134 |
135 | def make_dataset(quality, mode):
136 | """
137 | Create File List
138 | """
139 | assert (quality == 'semantic' and mode in ['train', 'val'])
140 | img_dir_name = None
141 | if quality == 'semantic':
142 | if mode == 'train':
143 | img_dir_name = 'training'
144 | if mode == 'val':
145 | img_dir_name = 'validation'
146 | mask_path = os.path.join(root, img_dir_name, 'labels')
147 | else:
148 | raise BaseException("Instance segmentation not supported")
149 |
150 | img_path = os.path.join(root, img_dir_name, 'images')
151 | print(img_path)
152 | if quality != 'video':
153 | imgs = sorted([os.path.splitext(f)[0] for f in os.listdir(img_path)])
154 | msks = sorted([os.path.splitext(f)[0] for f in os.listdir(mask_path)])
155 | assert imgs == msks
156 |
157 | items = []
158 | c_items = os.listdir(img_path)
159 | if '.DS_Store' in c_items:
160 | c_items.remove('.DS_Store')
161 |
162 | for it in c_items:
163 | if quality == 'video':
164 | item = (os.path.join(img_path, it), os.path.join(img_path, it))
165 | else:
166 | item = (os.path.join(img_path, it),
167 | os.path.join(mask_path, it.replace(".jpg", ".png")))
168 | items.append(item)
169 | return items
170 |
171 |
172 | def gen_colormap():
173 | """
174 | Get Color Map from file
175 | """
176 | global color_mapping
177 |
178 | # load mapillary config
179 | with open(config_fn) as config_file:
180 | config = json.load(config_file)
181 | config_labels = config['labels']
182 |
183 | # calculate label color mapping
184 | colormap = []
185 | id2name = {}
186 | for i in range(0, len(config_labels)):
187 | colormap = colormap + config_labels[i]['color']
188 | id2name[i] = config_labels[i]['readable']
189 | color_mapping = colormap
190 | return id2name
191 |
192 |
193 | class Mapillary(data.Dataset):
194 | def __init__(self, quality, mode, joint_transform_list=None,
195 | transform=None, target_transform=None, target_aux_transform=None,
196 | image_in=False, dump_images=False, class_uniform_pct=0,
197 | class_uniform_tile=768, test=False):
198 | """
199 | class_uniform_pct = Percent of class uniform samples. 1.0 means fully uniform.
200 | 0.0 means fully random.
201 | class_uniform_tile_size = Class uniform tile size
202 | """
203 | gen_id_to_ignore()
204 | self.quality = quality
205 | self.mode = mode
206 | self.joint_transform_list = joint_transform_list
207 | self.transform = transform
208 | self.target_transform = target_transform
209 | self.image_in = image_in
210 | self.target_aux_transform = target_aux_transform
211 | self.dump_images = dump_images
212 | self.class_uniform_pct = class_uniform_pct
213 | self.class_uniform_tile = class_uniform_tile
214 | self.id2name = gen_colormap()
215 | self.imgs_uniform = None
216 |
217 |
218 | # find all images
219 | self.imgs = make_dataset(quality, mode)
220 | if len(self.imgs) == 0:
221 | raise RuntimeError('Found 0 images, please check the data set')
222 | if test:
223 | np.random.shuffle(self.imgs)
224 | self.imgs = self.imgs[:200]
225 |
226 | if self.class_uniform_pct:
227 | json_fn = 'mapillary_tile{}.json'.format(self.class_uniform_tile)
228 | if os.path.isfile(json_fn):
229 | with open(json_fn, 'r') as json_data:
230 | centroids = json.load(json_data)
231 | self.centroids = {int(idx): centroids[idx] for idx in centroids}
232 | else:
233 | # centroids is a dict (indexed by class) of lists of centroids
234 | self.centroids = uniform.class_centroids_all(
235 | self.imgs,
236 | num_classes,
237 | id2trainid=None,
238 | tile_size=self.class_uniform_tile)
239 | with open(json_fn, 'w') as outfile:
240 | json.dump(self.centroids, outfile, indent=4)
241 | else:
242 | self.centroids = []
243 | self.build_epoch()
244 |
245 | def build_epoch(self):
246 | if self.class_uniform_pct != 0:
247 | self.imgs_uniform = uniform.build_epoch(self.imgs,
248 | self.centroids,
249 | num_classes,
250 | self.class_uniform_pct)
251 | else:
252 | self.imgs_uniform = self.imgs
253 |
254 | def __getitem__(self, index):
255 | if len(self.imgs_uniform[index]) == 2:
256 | img_path, mask_path = self.imgs_uniform[index]
257 | centroid = None
258 | class_id = None
259 | else:
260 | img_path, mask_path, centroid, class_id = self.imgs_uniform[index]
261 | img, mask = Image.open(img_path).convert('RGB'), Image.open(mask_path)
262 | img_name = os.path.splitext(os.path.basename(img_path))[0]
263 |
264 | mask = np.array(mask)
265 | mask_copy = mask.copy()
266 | for k, v in id_to_ignore_or_group.items():
267 | mask_copy[mask == k] = v
268 | mask = Image.fromarray(mask_copy.astype(np.uint8))
269 |
270 | # Image Transformations
271 | if self.joint_transform_list is not None:
272 | for idx, xform in enumerate(self.joint_transform_list):
273 | if idx == 0 and centroid is not None:
274 | # HACK! Assume the first transform accepts a centroid
275 | img, mask = xform(img, mask, centroid)
276 | else:
277 | img, mask = xform(img, mask)
278 |
279 | if self.dump_images:
280 | outdir = 'dump_imgs_{}'.format(self.mode)
281 | os.makedirs(outdir, exist_ok=True)
282 | if centroid is not None:
283 | dump_img_name = self.id2name[class_id] + '_' + img_name
284 | else:
285 | dump_img_name = img_name
286 | out_img_fn = os.path.join(outdir, dump_img_name + '.png')
287 | out_msk_fn = os.path.join(outdir, dump_img_name + '_mask.png')
288 | mask_img = colorize_mask(np.array(mask))
289 | img.save(out_img_fn)
290 | mask_img.save(out_msk_fn)
291 |
292 | if self.transform is not None:
293 | img = self.transform(img)
294 |
295 | rgb_mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
296 | img_gt = transforms.Normalize(*rgb_mean_std)(img)
297 | if self.image_in:
298 | eps = 1e-5
299 | rgb_mean_std = ([torch.mean(img[0]), torch.mean(img[1]), torch.mean(img[2])],
300 | [torch.std(img[0])+eps, torch.std(img[1])+eps, torch.std(img[2])+eps])
301 | img = transforms.Normalize(*rgb_mean_std)(img)
302 |
303 | if self.target_aux_transform is not None:
304 | mask_aux = self.target_aux_transform(mask)
305 | else:
306 | mask_aux = torch.tensor([0])
307 | if self.target_transform is not None:
308 | mask = self.target_transform(mask)
309 |
310 | mask = extended_transforms.MaskToTensor()(mask)
311 | return img, mask, img_name, mask_aux
312 |
313 | def __len__(self):
314 | return len(self.imgs_uniform)
315 |
316 | def calculate_weights(self):
317 | raise BaseException("not supported yet")
318 |
--------------------------------------------------------------------------------
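A minimal sketch (not from the repository) of loading one Mapillary validation sample with the loader above. It assumes cfg.DATASET.MAPILLARY_DIR points at a Mapillary Vistas download containing config.json plus the training/validation image and label folders.

```
import torchvision.transforms as standard_transforms

from datasets.mapillary import Mapillary

# With the default class_uniform_pct=0, no centroid JSON needs to be precomputed.
val_set = Mapillary('semantic', 'val',
                    transform=standard_transforms.ToTensor(),
                    test=False)

img, mask, img_name, mask_aux = val_set[0]
# img is an un-normalized float tensor; mask holds Cityscapes train IDs (0-18, 255 for ignore).
```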
/datasets/multi_loader.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom DomainUniformConcatDataset
3 | """
4 | import numpy as np
5 |
6 | from torch.utils.data import Dataset
7 | import torch
8 | from config import cfg
9 |
10 |
11 | np.random.seed(cfg.RANDOM_SEED)
12 |
13 |
14 | class DomainUniformConcatDataset(Dataset):
15 | """
16 | DomainUniformConcatDataset
17 |
18 | Sample images uniformly across the domains
19 | If bs_mul is n, this outputs # of domains * n images per batch
20 | """
21 | @staticmethod
22 | def cumsum(sequence):
23 | r, s = [], 0
24 | for e in sequence:
25 | l = len(e)
26 | r.append(l + s)
27 | s += l
28 | return r
29 |
30 | def __init__(self, args, datasets):
31 | """
32 | This dataset is to return sample image (source)
33 | and augmented sample image (target)
34 | Args:
35 | args: input config arguments
36 | datasets: list of datasets to concat
37 | """
38 | super(DomainUniformConcatDataset, self).__init__()
39 | self.datasets = datasets
40 | self.lengths = [len(d) for d in datasets]
41 | self.offsets = self.cumsum(datasets)
42 | self.length = np.sum(self.lengths)
43 |
44 | print("# domains: {}, Total length: {}, 1 epoch: {}, offsets: {}".format(
45 | str(len(datasets)), str(self.length), str(len(self)), str(self.offsets)))
46 |
47 |
48 | def __len__(self):
49 | """
50 | Returns:
51 | The number of images in a domain that has minimum image samples
52 | """
53 | return min(self.lengths)
54 |
55 |
56 | def _get_batch_from_dataset(self, dataset, idx):
57 | """
58 | Get batch from dataset
59 | New idx = idx + random integer
60 | Args:
61 | dataset: dataset class object
62 | idx: integer
63 |
64 | Returns:
65 | One batch from dataset
66 | """
67 | p_index = idx + np.random.randint(len(dataset))
68 | if p_index > len(dataset) - 1:
69 | p_index -= len(dataset)
70 |
71 | return dataset[p_index]
72 |
73 |
74 | def __getitem__(self, idx):
75 | """
76 | Args:
77 | idx (int): Index
78 |
79 | Returns:
80 | images corresponding to the index from each domain
81 | """
82 | imgs = []
83 | masks = []
84 | img_names = []
85 | mask_auxs = []
86 |
87 | for dataset in self.datasets:
88 | img, mask, img_name, mask_aux = self._get_batch_from_dataset(dataset, idx)
89 | imgs.append(img)
90 | masks.append(mask)
91 | img_names.append(img_name)
92 | mask_auxs.append(mask_aux)
93 | imgs, masks, mask_auxs = torch.stack(imgs, 0), torch.stack(masks, 0), torch.stack(mask_auxs, 0)
94 |
95 | return imgs, masks, img_names, mask_auxs
96 |
97 |
--------------------------------------------------------------------------------
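A self-contained sketch (not from the repository) of how DomainUniformConcatDataset pairs one sample per domain at every index. The _Dummy datasets below are hypothetical stand-ins for the real loaders, which return (img, mask, img_name, mask_aux).

```
import torch
from torch.utils.data import Dataset, DataLoader

from datasets.multi_loader import DomainUniformConcatDataset

class _Dummy(Dataset):
    """Stand-in domain mimicking the (img, mask, name, mask_aux) output of the real loaders."""
    def __init__(self, length, tag):
        self.length, self.tag = length, tag
    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        img = torch.zeros(3, 64, 64)
        mask = torch.zeros(64, 64, dtype=torch.long)
        return img, mask, '{}_{}'.format(self.tag, idx), torch.tensor([0])

multi_set = DomainUniformConcatDataset(args=None, datasets=[_Dummy(10, 'gtav'), _Dummy(8, 'city')])
loader = DataLoader(multi_set, batch_size=2)

imgs, masks, names, mask_auxs = next(iter(loader))
# imgs: [2, 2, 3, 64, 64] -> batch of 2, with one image per domain stacked on dim 1.
```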
/datasets/nullloader.py:
--------------------------------------------------------------------------------
1 | """
2 | Null Loader
3 | """
4 | import numpy as np
5 | import torch
6 | from torch.utils import data
7 |
8 | num_classes = 19
9 | ignore_label = 255
10 |
11 | class NullLoader(data.Dataset):
12 | """
13 | Null Dataset for Performance
14 | """
15 | def __init__(self,crop_size):
16 | self.imgs = range(200)
17 | self.crop_size = crop_size
18 |
19 | def __getitem__(self, index):
20 | #Return img, mask, name
21 | return torch.FloatTensor(np.zeros((3,self.crop_size,self.crop_size))), torch.LongTensor(np.zeros((self.crop_size,self.crop_size))), 'img' + str(index)
22 |
23 | def __len__(self):
24 | return len(self.imgs)
--------------------------------------------------------------------------------
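A minimal sketch (not from the repository) of using NullLoader to exercise the input pipeline without touching the disk; the crop size of 768 is an arbitrary choice.

```
from torch.utils.data import DataLoader

from datasets.nullloader import NullLoader

null_set = NullLoader(crop_size=768)
loader = DataLoader(null_set, batch_size=4)

img, mask, name = next(iter(loader))  # img: [4, 3, 768, 768], mask: [4, 768, 768]
```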
/datasets/sampler.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/pytorch/pytorch/blob/master/torch/utils/data/distributed.py
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017,
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 |
36 |
37 |
38 | import math
39 | import torch
40 | from torch.distributed import get_world_size, get_rank
41 | from torch.utils.data import Sampler
42 |
43 | class DistributedSampler(Sampler):
44 | """Sampler that restricts data loading to a subset of the dataset.
45 |
46 | It is especially useful in conjunction with
47 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
48 | process can pass a DistributedSampler instance as a DataLoader sampler,
49 | and load a subset of the original dataset that is exclusive to it.
50 |
51 | .. note::
52 | Dataset is assumed to be of constant size.
53 |
54 | Arguments:
55 | dataset: Dataset used for sampling.
56 | num_replicas (optional): Number of processes participating in
57 | distributed training.
58 | rank (optional): Rank of the current process within num_replicas.
59 | """
60 |
61 | def __init__(self, dataset, pad=False, consecutive_sample=False, permutation=False, num_replicas=None, rank=None):
62 | if num_replicas is None:
63 | num_replicas = get_world_size()
64 | if rank is None:
65 | rank = get_rank()
66 | self.dataset = dataset
67 | self.num_replicas = num_replicas
68 | self.rank = rank
69 | self.epoch = 0
70 | self.consecutive_sample = consecutive_sample
71 | self.permutation = permutation
72 | if pad:
73 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
74 | else:
75 | self.num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas))
76 | self.total_size = self.num_samples * self.num_replicas
77 |
78 | def __iter__(self):
79 | # deterministically shuffle based on epoch
80 | g = torch.Generator()
81 | g.manual_seed(self.epoch)
82 |
83 | if self.permutation:
84 | indices = list(torch.randperm(len(self.dataset), generator=g))
85 | else:
86 | indices = list([x for x in range(len(self.dataset))])
87 |
88 | # add extra samples to make it evenly divisible
89 | if self.total_size > len(indices):
90 | indices += indices[:(self.total_size - len(indices))]
91 |
92 | # subsample
93 | if self.consecutive_sample:
94 | offset = self.num_samples * self.rank
95 | indices = indices[offset:offset + self.num_samples]
96 | else:
97 | indices = indices[self.rank:self.total_size:self.num_replicas]
98 | assert len(indices) == self.num_samples
99 |
100 | return iter(indices)
101 |
102 | def __len__(self):
103 | return self.num_samples
104 |
105 | def set_epoch(self, epoch):
106 | self.epoch = epoch
107 |
108 | def set_num_samples(self):
109 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
110 | self.total_size = self.num_samples * self.num_replicas
--------------------------------------------------------------------------------
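A minimal sketch (not from the repository) of driving the DistributedSampler above. Passing num_replicas and rank explicitly lets it run without initializing torch.distributed; in real multi-GPU training those come from the process group, and set_epoch must be called every epoch to get a fresh deterministic shuffle.

```
import torch
from torch.utils.data import DataLoader, TensorDataset

from datasets.sampler import DistributedSampler

dataset = TensorDataset(torch.arange(10).float())
sampler = DistributedSampler(dataset, pad=True, permutation=True,
                             num_replicas=2, rank=0)
loader = DataLoader(dataset, batch_size=2, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
    for (batch,) in loader:
        pass  # this rank sees 5 of the 10 samples per epoch
```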
/datasets/uniform.py:
--------------------------------------------------------------------------------
1 | """
2 | Uniform sampling of classes.
3 | For all images, for all classes, generate centroids around which to sample.
4 |
5 | All images are divided into tiles.
6 | For each tile, a class can be present or not. If it is
7 | present, calculate the centroid of the class and record it.
8 |
9 | We would like to thank Peter Kontschieder for the inspiration of this idea.
10 | """
11 |
12 | import logging
13 | from collections import defaultdict
14 | from PIL import Image
15 | import numpy as np
16 | from scipy import ndimage
17 | from tqdm import tqdm
18 |
19 | pbar = None
20 |
21 | class Point():
22 | """
23 | Point Class For X and Y Location
24 | """
25 | def __init__(self, x, y):
26 | self.x = x
27 | self.y = y
28 |
29 |
30 | def calc_tile_locations(tile_size, image_size):
31 | """
32 | Divide an image into tiles to help us cover classes that are spread out.
33 | tile_size: size of tile to distribute
34 | image_size: original image size
35 | return: locations of the tiles
36 | """
37 | image_size_y, image_size_x = image_size
38 | locations = []
39 | for y in range(image_size_y // tile_size):
40 | for x in range(image_size_x // tile_size):
41 | x_offs = x * tile_size
42 | y_offs = y * tile_size
43 | locations.append((x_offs, y_offs))
44 | return locations
45 |
46 |
47 | def class_centroids_image(item, tile_size, num_classes, id2trainid):
48 | """
49 | For one image, calculate centroids for all classes present in image.
50 | item: image, image_name
51 | tile_size:
52 | num_classes:
53 | id2trainid: mapping from original id to training ids
54 | return: Centroids are calculated for each tile.
55 | """
56 | image_fn, label_fn = item
57 | centroids = defaultdict(list)
58 | mask = np.array(Image.open(label_fn))
59 | if len(mask.shape) == 3:
60 | # Remove instance mask
61 | mask = mask[:,:,0]
62 | image_size = mask.shape
63 | tile_locations = calc_tile_locations(tile_size, image_size)
64 |
65 | mask_copy = mask.copy()
66 | if id2trainid:
67 | for k, v in id2trainid.items():
68 | mask[mask_copy == k] = v
69 |
70 | for x_offs, y_offs in tile_locations:
71 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
72 | for class_id in range(num_classes):
73 | if class_id in patch:
74 | patch_class = (patch == class_id).astype(int)
75 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
76 | centroid_y = int(centroid_y) + y_offs
77 | centroid_x = int(centroid_x) + x_offs
78 | centroid = (centroid_x, centroid_y)
79 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
80 | pbar.update(1)
81 | return centroids
82 |
83 | import scipy.misc as m
84 |
85 | def class_centroids_image_from_color(item, tile_size, num_classes, id2trainid):
86 | """
87 | For one image, calculate centroids for all classes present in image.
88 |     item: (image_fn, label_fn) pair
89 |     tile_size: size of each square tile
90 |     num_classes: number of classes
91 |     id2trainid: mapping from original color labels to training ids
92 |     return: dict mapping class_id to a list of (image_fn, label_fn, centroid, class_id) records, one per tile where the class is present
93 | """
94 | image_fn, label_fn = item
95 | centroids = defaultdict(list)
96 | mask = m.imread(label_fn)
97 | image_size = mask[:,:,0].shape
98 | tile_locations = calc_tile_locations(tile_size, image_size)
99 |
100 | # mask = m.imread(label_fn)
101 | # mask_copy = np.full((img.size[1], img.size[0]), 255, dtype=np.uint8)
102 | # for k, v in id2trainid.items():
103 | # mask_copy[(mask == k)[:,:,0]] = v
104 | # mask = Image.fromarray(mask_copy.astype(np.uint8))
105 |
106 | # mask_copy = mask.copy()
107 | # mask_copy = mask.copy()
108 | # if id2trainid:
109 | # for k, v in id2trainid.items():
110 | # mask[mask_copy == k] = v
111 |
112 | mask_copy = np.full(image_size, 255, dtype=np.uint8)
113 |
114 | if id2trainid:
115 | for k, v in id2trainid.items():
116 | # print("0", mask.shape)
117 | # print("1", ((mask == np.array(k))[:,:,0]).shape) # 1052, 1914
118 | # # print("2", mask == np.array(k)[:,:,0])
119 | # break
120 | # if v != 255:
121 | # print(v)
122 | # if v == 2:
123 | # print(k, v, "num", np.count_nonzero(mask == np.array(k)))
124 | # break
125 | if v != 255 and v != -1:
126 | mask_copy[(mask == np.array(k))[:,:,0] & (mask == np.array(k))[:,:,1] & (mask == np.array(k))[:,:,2]] = v
127 | mask = mask_copy
128 |
129 | # mask_copy = mask.copy()
130 | # if id2trainid:
131 | # for k, v in id2trainid.items():
132 | # mask[mask_copy == k] = v
133 |
134 | for x_offs, y_offs in tile_locations:
135 | patch = mask[y_offs:y_offs + tile_size, x_offs:x_offs + tile_size]
136 | for class_id in range(num_classes):
137 | if class_id in patch:
138 | patch_class = (patch == class_id).astype(int)
139 | centroid_y, centroid_x = ndimage.measurements.center_of_mass(patch_class)
140 | centroid_y = int(centroid_y) + y_offs
141 | centroid_x = int(centroid_x) + x_offs
142 | centroid = (centroid_x, centroid_y)
143 | centroids[class_id].append((image_fn, label_fn, centroid, class_id))
144 | pbar.update(1)
145 | return centroids
146 |
147 | def pooled_class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
148 | """
149 | Calculate class centroids for all classes for all images for all tiles.
150 | items: list of (image_fn, label_fn)
151 | tile size: size of tile
152 | returns: dict that contains a list of centroids for each class
153 | """
154 | from multiprocessing.dummy import Pool
155 | from functools import partial
156 | pool = Pool(32)
157 | global pbar
158 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
159 | class_centroids_item = partial(class_centroids_image_from_color,
160 | num_classes=num_classes,
161 | id2trainid=id2trainid,
162 | tile_size=tile_size)
163 |
164 | centroids = defaultdict(list)
165 | new_centroids = pool.map(class_centroids_item, items)
166 | pool.close()
167 | pool.join()
168 |
169 | # combine each image's items into a single global dict
170 | for image_items in new_centroids:
171 | for class_id in image_items:
172 | centroids[class_id].extend(image_items[class_id])
173 | return centroids
174 |
175 |
176 | def pooled_class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
177 | """
178 | Calculate class centroids for all classes for all images for all tiles.
179 | items: list of (image_fn, label_fn)
180 | tile size: size of tile
181 | returns: dict that contains a list of centroids for each class
182 | """
183 | from multiprocessing.dummy import Pool
184 | from functools import partial
185 | pool = Pool(80)
186 | global pbar
187 | pbar = tqdm(total=len(items), desc='pooled centroid extraction')
188 | class_centroids_item = partial(class_centroids_image,
189 | num_classes=num_classes,
190 | id2trainid=id2trainid,
191 | tile_size=tile_size)
192 |
193 | centroids = defaultdict(list)
194 | new_centroids = pool.map(class_centroids_item, items)
195 | pool.close()
196 | pool.join()
197 |
198 | # combine each image's items into a single global dict
199 | for image_items in new_centroids:
200 | for class_id in image_items:
201 | centroids[class_id].extend(image_items[class_id])
202 | return centroids
203 |
204 |
205 | def unpooled_class_centroids_all(items, num_classes, id2trainid=None, tile_size=1024):
206 |     """
207 |     Calculate class centroids for all classes for all images for all tiles.
208 |     items: list of (image_fn, label_fn)
209 |     tile_size: size of tile
210 |     returns: dict that contains a list of centroids for each class
211 |     """
212 |     centroids = defaultdict(list)
213 |     global pbar
214 |     pbar = tqdm(total=len(items), desc='centroid extraction')
215 |     for image, label in items:
216 |         new_centroids = class_centroids_image((image, label),
217 |                                               tile_size,
218 |                                               num_classes, id2trainid)
219 | for class_id in new_centroids:
220 | centroids[class_id].extend(new_centroids[class_id])
221 |
222 | return centroids
223 |
224 |
225 | def class_centroids_all_from_color(items, num_classes, id2trainid, tile_size=1024):
226 | """
227 | intermediate function to call pooled_class_centroid
228 | """
229 |
230 | pooled_centroids = pooled_class_centroids_all_from_color(items, num_classes,
231 | id2trainid, tile_size)
232 | return pooled_centroids
233 |
234 |
235 | def class_centroids_all(items, num_classes, id2trainid, tile_size=1024):
236 | """
237 | intermediate function to call pooled_class_centroid
238 | """
239 |
240 | pooled_centroids = pooled_class_centroids_all(items, num_classes,
241 | id2trainid, tile_size)
242 | return pooled_centroids
243 |
244 |
245 | def random_sampling(alist, num):
246 | """
247 | Randomly sample num items from the list
248 | alist: list of centroids to sample from
249 | num: can be larger than the list and if so, then wrap around
250 | return: class uniform samples from the list
251 | """
252 | sampling = []
253 | len_list = len(alist)
254 | assert len_list, 'len_list is zero!'
255 | indices = np.arange(len_list)
256 | np.random.shuffle(indices)
257 |
258 | for i in range(num):
259 | item = alist[indices[i % len_list]]
260 | sampling.append(item)
261 | return sampling
262 |
263 |
264 | def build_epoch(imgs, centroids, num_classes, class_uniform_pct):
265 | """
266 |     Generate an epoch's worth of crops using uniform sampling. Needs to be called every epoch.
267 |     imgs: list of imgs
268 |     centroids: dict of class centroids produced by class_centroids_all()
269 |     num_classes: number of classes
270 | class_uniform_pct: class uniform sampling percent ( % of uniform images in one epoch )
271 | """
272 | logging.info("Class Uniform Percentage: %s", str(class_uniform_pct))
273 | num_epoch = int(len(imgs))
274 |
275 | logging.info('Class Uniform items per Epoch:%s', str(num_epoch))
276 | num_per_class = int((num_epoch * class_uniform_pct) / num_classes)
277 | num_rand = num_epoch - num_per_class * num_classes
278 | # create random crops
279 | imgs_uniform = random_sampling(imgs, num_rand)
280 |
281 | # now add uniform sampling
282 | for class_id in range(num_classes):
283 | string_format = "cls %d len %d"% (class_id, len(centroids[class_id]))
284 | logging.info(string_format)
285 | for class_id in range(num_classes):
286 | centroid_len = len(centroids[class_id])
287 | if centroid_len == 0:
288 | pass
289 | else:
290 | class_centroids = random_sampling(centroids[class_id], num_per_class)
291 | imgs_uniform.extend(class_centroids)
292 |
293 | return imgs_uniform
294 |
--------------------------------------------------------------------------------
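As a rough usage sketch of how the functions above fit together (the file names, id2trainid mapping and class count below are placeholders for illustration, not values taken from this repo):

    # Hypothetical (image, label) pairs; the real lists come from the dataset loaders.
    items = [('img_0.png', 'lbl_0.png'), ('img_1.png', 'lbl_1.png')]
    id2trainid = {7: 0, 8: 1, 26: 13}   # raw label id -> train id (illustrative subset)
    num_classes = 19

    # One-off pass: tile every label map and record one centroid per (tile, class),
    # using class_centroids_all from datasets/uniform.py above.
    centroids = class_centroids_all(items, num_classes, id2trainid, tile_size=1024)

    # Every epoch: mix plain random samples with class-uniform samples so that rare
    # classes, each represented by its recorded centroids, are seen often enough.
    epoch_items = build_epoch(items, centroids, num_classes, class_uniform_pct=0.5)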
/loss.py:
--------------------------------------------------------------------------------
1 | """
2 | Loss.py
3 | """
4 |
5 | import logging
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import datasets
11 | from config import cfg
12 |
13 |
14 | def get_loss(args):
15 | """
16 | Get the criterion based on the loss function
17 | args: commandline arguments
18 | return: criterion, criterion_val
19 | """
20 | if args.cls_wt_loss:
21 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
22 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
23 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
24 | else:
25 | ce_weight = None
26 |
27 | if args.img_wt_loss:
28 | criterion = ImageBasedCrossEntropyLoss2d(
29 | classes=datasets.num_classes, size_average=True,
30 | ignore_index=datasets.ignore_label,
31 | upper_bound=args.wt_bound).cuda()
32 | elif args.jointwtborder:
33 | criterion = ImgWtLossSoftNLL(classes=datasets.num_classes,
34 | ignore_index=datasets.ignore_label,
35 | upper_bound=args.wt_bound).cuda()
36 | else:
37 | print("standard cross entropy")
38 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
39 | ignore_index=datasets.ignore_label).cuda()
40 |
41 | criterion_val = nn.CrossEntropyLoss(reduction='mean',
42 | ignore_index=datasets.ignore_label).cuda()
43 | return criterion, criterion_val
44 |
45 | def get_loss_by_epoch(args):
46 | """
47 | Get the criterion based on the loss function
48 | args: commandline arguments
49 | return: criterion, criterion_val
50 | """
51 |
52 | if args.img_wt_loss:
53 | criterion = ImageBasedCrossEntropyLoss2d(
54 | classes=datasets.num_classes, size_average=True,
55 | ignore_index=datasets.ignore_label,
56 | upper_bound=args.wt_bound).cuda()
57 | elif args.jointwtborder:
58 | criterion = ImgWtLossSoftNLL_by_epoch(classes=datasets.num_classes,
59 | ignore_index=datasets.ignore_label,
60 | upper_bound=args.wt_bound).cuda()
61 | else:
62 | criterion = CrossEntropyLoss2d(size_average=True,
63 | ignore_index=datasets.ignore_label).cuda()
64 |
65 | criterion_val = CrossEntropyLoss2d(size_average=True,
66 | weight=None,
67 | ignore_index=datasets.ignore_label).cuda()
68 | return criterion, criterion_val
69 |
70 |
71 | def get_loss_aux(args):
72 | """
73 | Get the criterion based on the loss function
74 | args: commandline arguments
75 | return: criterion, criterion_val
76 | """
77 | if args.cls_wt_loss:
78 | ce_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
79 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
80 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
81 | else:
82 | ce_weight = None
83 |
84 | print("standard cross entropy")
85 | criterion = nn.CrossEntropyLoss(weight=ce_weight, reduction='mean',
86 | ignore_index=datasets.ignore_label).cuda()
87 |
88 | return criterion
89 |
90 | def get_loss_bcelogit(args):
91 | if args.cls_wt_loss:
92 | pos_weight = torch.Tensor([0.8373, 0.9180, 0.8660, 1.0345, 1.0166, 0.9969, 0.9754,
93 | 1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
94 | 1.0865, 1.0955, 1.0865, 1.1529, 1.0507])
95 | else:
96 | pos_weight = None
97 | print("standard bce with logit cross entropy")
98 |     criterion = nn.BCEWithLogitsLoss(reduction='mean').cuda()  # note: pos_weight computed above is not passed to the loss here
99 |
100 | return criterion
101 |
102 | def weighted_binary_cross_entropy(output, target):
103 |
104 | weights = torch.Tensor([0.1, 0.9])
105 |
106 | loss = weights[1] * (target * torch.log(output)) + \
107 | weights[0] * ((1 - target) * torch.log(1 - output))
108 |
109 | return torch.neg(torch.mean(loss))
110 |
111 |
112 | class L1Loss(nn.Module):
113 | def __init__(self):
114 | super(L1Loss, self).__init__()
115 |
116 | def __call__(self, in0, in1):
117 | return torch.sum(torch.abs(in0 - in1), dim=1, keepdim=True)
118 |
119 |
120 | class ImageBasedCrossEntropyLoss2d(nn.Module):
121 | """
122 | Image Weighted Cross Entropy Loss
123 | """
124 |
125 | def __init__(self, classes, weight=None, size_average=True, ignore_index=255,
126 | norm=False, upper_bound=1.0):
127 | super(ImageBasedCrossEntropyLoss2d, self).__init__()
128 | logging.info("Using Per Image based weighted loss")
129 | self.num_classes = classes
130 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
131 | self.norm = norm
132 | self.upper_bound = upper_bound
133 | self.batch_weights = cfg.BATCH_WEIGHTING
134 | self.logsoftmax = nn.LogSoftmax(dim=1)
135 |
136 | def calculate_weights(self, target):
137 | """
138 | Calculate weights of classes based on the training crop
139 | """
140 | hist = np.histogram(target.flatten(), range(
141 |             self.num_classes + 1), density=True)[0]  # 'normed' was removed from np.histogram; density is equivalent here
142 | if self.norm:
143 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
144 | else:
145 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
146 | return hist
147 |
148 | def forward(self, inputs, targets):
149 |
150 | target_cpu = targets.data.cpu().numpy()
151 | if self.batch_weights:
152 | weights = self.calculate_weights(target_cpu)
153 | self.nll_loss.weight = torch.Tensor(weights).cuda()
154 |
155 | loss = 0.0
156 | for i in range(0, inputs.shape[0]):
157 | if not self.batch_weights:
158 | weights = self.calculate_weights(target_cpu[i])
159 | self.nll_loss.weight = torch.Tensor(weights).cuda()
160 |
161 | loss += self.nll_loss(self.logsoftmax(inputs[i].unsqueeze(0)),
162 | targets[i].unsqueeze(0))
163 | return loss
164 |
165 |
166 |
167 | class CrossEntropyLoss2d(nn.Module):
168 | """
169 |     Cross Entropy NLL Loss
170 | """
171 |
172 | def __init__(self, weight=None, size_average=True, ignore_index=255):
173 | super(CrossEntropyLoss2d, self).__init__()
174 | logging.info("Using Cross Entropy Loss")
175 | self.nll_loss = nn.NLLLoss(weight=weight, reduction='mean', ignore_index=ignore_index)
176 | self.logsoftmax = nn.LogSoftmax(dim=1)
177 | # self.weight = weight
178 |
179 | def forward(self, inputs, targets):
180 | return self.nll_loss(self.logsoftmax(inputs), targets)
181 |
182 | def customsoftmax(inp, multihotmask):
183 | """
184 | Custom Softmax
185 | """
186 | soft = F.softmax(inp, dim=1)
187 |     # Sum the masked softmax over the channel dimension (pooling the probabilities of the classes
188 |     # allowed by the multi-hot border mask), then take the element-wise max of that pooled value and the original softmax.
189 | return torch.log(
190 | torch.max(soft, (multihotmask * (soft * multihotmask).sum(1, keepdim=True)))
191 | )
192 |
193 | class ImgWtLossSoftNLL(nn.Module):
194 | """
195 | Relax Loss
196 | """
197 |
198 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
199 | norm=False):
200 | super(ImgWtLossSoftNLL, self).__init__()
201 | self.weights = weights
202 | self.num_classes = classes
203 | self.ignore_index = ignore_index
204 | self.upper_bound = upper_bound
205 | self.norm = norm
206 | self.batch_weights = cfg.BATCH_WEIGHTING
207 |
208 | def calculate_weights(self, target):
209 | """
210 | Calculate weights of the classes based on training crop
211 | """
212 | if len(target.shape) == 3:
213 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
214 | else:
215 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
216 | if self.norm:
217 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
218 | else:
219 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
220 | return hist[:-1]
221 |
222 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
223 | """
224 | NLL Relaxed Loss Implementation
225 | """
226 | if (cfg.REDUCE_BORDER_ITER != -1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
227 | border_weights = 1 / border_weights
228 | target[target > 1] = 1
229 |
230 | loss_matrix = (-1 / border_weights *
231 | (target[:, :-1, :, :].float() *
232 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
233 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
234 | (1. - mask.float())
235 |
236 | # loss_matrix[border_weights > 1] = 0
237 | loss = loss_matrix.sum()
238 |
239 | # +1 to prevent division by 0
240 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
241 | return loss
242 |
243 | def forward(self, inputs, target):
244 | weights = target[:, :-1, :, :].sum(1).float()
245 | ignore_mask = (weights == 0)
246 | weights[ignore_mask] = 1
247 |
248 | loss = 0
249 | target_cpu = target.data.cpu().numpy()
250 |
251 | if self.batch_weights:
252 | class_weights = self.calculate_weights(target_cpu)
253 |
254 | for i in range(0, inputs.shape[0]):
255 | if not self.batch_weights:
256 | class_weights = self.calculate_weights(target_cpu[i])
257 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
258 | target[i].unsqueeze(0),
259 | class_weights=torch.Tensor(class_weights).cuda(),
260 | border_weights=weights[i], mask=ignore_mask[i])
261 |
262 | loss = loss / inputs.shape[0]
263 | return loss
264 |
265 | class ImgWtLossSoftNLL_by_epoch(nn.Module):
266 | """
267 | Relax Loss
268 | """
269 |
270 | def __init__(self, classes, ignore_index=255, weights=None, upper_bound=1.0,
271 | norm=False):
272 | super(ImgWtLossSoftNLL_by_epoch, self).__init__()
273 | self.weights = weights
274 | self.num_classes = classes
275 | self.ignore_index = ignore_index
276 | self.upper_bound = upper_bound
277 | self.norm = norm
278 | self.batch_weights = cfg.BATCH_WEIGHTING
279 | self.fp16 = False
280 |
281 |
282 | def calculate_weights(self, target):
283 | """
284 | Calculate weights of the classes based on training crop
285 | """
286 | if len(target.shape) == 3:
287 | hist = np.sum(target, axis=(1, 2)) * 1.0 / target.sum()
288 | else:
289 | hist = np.sum(target, axis=(0, 2, 3)) * 1.0 / target.sum()
290 | if self.norm:
291 | hist = ((hist != 0) * self.upper_bound * (1 / hist)) + 1
292 | else:
293 | hist = ((hist != 0) * self.upper_bound * (1 - hist)) + 1
294 | return hist[:-1]
295 |
296 | def custom_nll(self, inputs, target, class_weights, border_weights, mask):
297 | """
298 | NLL Relaxed Loss Implementation
299 | """
300 | if (cfg.REDUCE_BORDER_EPOCH != -1 and cfg.EPOCH > cfg.REDUCE_BORDER_EPOCH):
301 | border_weights = 1 / border_weights
302 | target[target > 1] = 1
303 | if self.fp16:
304 | loss_matrix = (-1 / border_weights *
305 | (target[:, :-1, :, :].half() *
306 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
307 | customsoftmax(inputs, target[:, :-1, :, :].half())).sum(1)) * \
308 | (1. - mask.half())
309 | else:
310 | loss_matrix = (-1 / border_weights *
311 | (target[:, :-1, :, :].float() *
312 | class_weights.unsqueeze(0).unsqueeze(2).unsqueeze(3) *
313 | customsoftmax(inputs, target[:, :-1, :, :].float())).sum(1)) * \
314 | (1. - mask.float())
315 |
316 | # loss_matrix[border_weights > 1] = 0
317 | loss = loss_matrix.sum()
318 |
319 | # +1 to prevent division by 0
320 | loss = loss / (target.shape[0] * target.shape[2] * target.shape[3] - mask.sum().item() + 1)
321 | return loss
322 |
323 | def forward(self, inputs, target):
324 | if self.fp16:
325 | weights = target[:, :-1, :, :].sum(1).half()
326 | else:
327 | weights = target[:, :-1, :, :].sum(1).float()
328 | ignore_mask = (weights == 0)
329 | weights[ignore_mask] = 1
330 |
331 | loss = 0
332 | target_cpu = target.data.cpu().numpy()
333 |
334 | if self.batch_weights:
335 | class_weights = self.calculate_weights(target_cpu)
336 |
337 | for i in range(0, inputs.shape[0]):
338 | if not self.batch_weights:
339 | class_weights = self.calculate_weights(target_cpu[i])
340 | loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
341 | target[i].unsqueeze(0),
342 | class_weights=torch.Tensor(class_weights).cuda(),
343 | border_weights=weights, mask=ignore_mask[i])
344 |
345 | return loss
346 |
--------------------------------------------------------------------------------
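For intuition on ImageBasedCrossEntropyLoss2d.calculate_weights, the weighting can be worked out by hand on a toy 2x2 crop with two classes (upper_bound=1.0, norm=False, matching the defaults above); the rarer class in the crop ends up with the larger NLL weight:

    import numpy as np

    target = np.array([[0, 0], [0, 1]])   # toy crop: class 0 appears three times, class 1 once
    num_classes, upper_bound = 2, 1.0

    hist = np.histogram(target.flatten(), range(num_classes + 1), density=True)[0]
    # hist == [0.75, 0.25], the per-class frequency within this crop

    weights = ((hist != 0) * upper_bound * (1 - hist)) + 1
    # weights == [1.25, 1.75]: the rare class 1 is up-weighted relative to class 0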
/network/Mobilenet.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch import Tensor
3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
4 | from typing import Callable, Any, Optional, List
5 | from network.instance_whitening import InstanceWhitening
6 | from network.mynn import forgiving_state_restore
7 |
8 | __all__ = ['MobileNetV2', 'mobilenet_v2']
9 |
10 |
11 | model_urls = {
12 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
13 | }
14 |
15 |
16 | def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
17 | """
18 | This function is taken from the original tf repo.
19 | It ensures that all layers have a channel number that is divisible by 8
20 | It can be seen here:
21 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
22 | :param v:
23 | :param divisor:
24 | :param min_value:
25 | :return:
26 | """
27 | if min_value is None:
28 | min_value = divisor
29 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
30 | # Make sure that round down does not go down by more than 10%.
31 | if new_v < 0.9 * v:
32 | new_v += divisor
33 | return new_v
34 |
35 |
36 | class ConvBNReLU(nn.Sequential):
37 | def __init__(
38 | self,
39 | in_planes: int,
40 | out_planes: int,
41 | kernel_size: int = 3,
42 | stride: int = 1,
43 | groups: int = 1,
44 | norm_layer: Optional[Callable[..., nn.Module]] = None,
45 | iw: int = 0,
46 | ) -> None:
47 |
48 | padding = (kernel_size - 1) // 2
49 | if norm_layer is None:
50 | norm_layer = nn.BatchNorm2d
51 |
52 | self.iw = iw
53 |
54 | if iw == 1:
55 | instance_norm_layer = InstanceWhitening(out_planes)
56 | elif iw == 2:
57 | instance_norm_layer = InstanceWhitening(out_planes)
58 | elif iw == 3:
59 | instance_norm_layer = nn.InstanceNorm2d(out_planes, affine=False)
60 | elif iw == 4:
61 | instance_norm_layer = nn.InstanceNorm2d(out_planes, affine=True)
62 | else:
63 | instance_norm_layer = nn.Sequential()
64 |
65 | super(ConvBNReLU, self).__init__(
66 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
67 | norm_layer(out_planes),
68 | nn.ReLU6(inplace=True),
69 | instance_norm_layer
70 | )
71 |
72 |
73 | def forward(self, x_tuple):
74 | if len(x_tuple) == 2:
75 | w_arr = x_tuple[1]
76 | x = x_tuple[0]
77 | else:
78 | print("error in BN forward path")
79 | return
80 |
81 | for i, module in enumerate(self):
82 | if i == len(self) - 1:
83 | if self.iw >= 1:
84 | if self.iw == 1 or self.iw == 2:
85 |                     x, w = module(x)  # the last module in the Sequential is the InstanceWhitening layer
86 | w_arr.append(w)
87 | else:
88 |                     x = module(x)  # the last module in the Sequential is the InstanceNorm2d (or identity) layer
89 | else:
90 | x = module(x)
91 |
92 | return [x, w_arr]
93 |
94 |
95 | class InvertedResidual(nn.Module):
96 | def __init__(
97 | self,
98 | inp: int,
99 | oup: int,
100 | stride: int,
101 | expand_ratio: int,
102 | norm_layer: Optional[Callable[..., nn.Module]] = None,
103 | iw: int = 0,
104 | ) -> None:
105 | super(InvertedResidual, self).__init__()
106 | self.stride = stride
107 | assert stride in [1, 2]
108 | if norm_layer is None:
109 | norm_layer = nn.BatchNorm2d
110 | self.expand_ratio = expand_ratio
111 | self.iw = iw
112 |
113 | if iw == 1:
114 | self.instance_norm_layer = InstanceWhitening(oup)
115 | elif iw == 2:
116 | self.instance_norm_layer = InstanceWhitening(oup)
117 | elif iw == 3:
118 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=False)
119 | elif iw == 4:
120 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=True)
121 | else:
122 | self.instance_norm_layer = nn.Sequential()
123 |
124 | hidden_dim = int(round(inp * expand_ratio))
125 | self.use_res_connect = self.stride == 1 and inp == oup
126 |
127 | layers: List[nn.Module] = []
128 | if expand_ratio != 1:
129 | # pw
130 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, norm_layer=norm_layer))
131 | layers.extend([
132 | # dw
133 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, norm_layer=norm_layer),
134 | # pw-linear
135 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
136 | norm_layer(oup),
137 | ])
138 | self.conv = nn.Sequential(*layers)
139 |
140 |
141 | def forward(self, x_tuple):
142 | if len(x_tuple) == 2:
143 | x = x_tuple[0]
144 | else:
145 |             print("error in inverted residual forward path")
146 | return
147 | if self.expand_ratio != 1:
148 | x_tuple = self.conv[0](x_tuple)
149 | x_tuple = self.conv[1](x_tuple)
150 | conv_x = x_tuple[0]
151 | w_arr = x_tuple[1]
152 | conv_x = self.conv[2](conv_x)
153 | conv_x = self.conv[3](conv_x)
154 | else:
155 | x_tuple = self.conv[0](x_tuple)
156 | conv_x = x_tuple[0]
157 | w_arr = x_tuple[1]
158 | conv_x = self.conv[1](conv_x)
159 | conv_x = self.conv[2](conv_x)
160 |
161 | if self.use_res_connect:
162 | x = x + conv_x
163 | else:
164 | x = conv_x
165 |
166 | if self.iw >= 1:
167 | if self.iw == 1 or self.iw == 2:
168 | x, w = self.instance_norm_layer(x)
169 | w_arr.append(w)
170 | else:
171 | x = self.instance_norm_layer(x)
172 |
173 | return [x, w_arr]
174 |
175 |
176 | class MobileNetV2(nn.Module):
177 | def __init__(
178 | self,
179 | num_classes: int = 1000,
180 | width_mult: float = 1.0,
181 | inverted_residual_setting: Optional[List[List[int]]] = None,
182 | round_nearest: int = 8,
183 | block: Optional[Callable[..., nn.Module]] = None,
184 | norm_layer: Optional[Callable[..., nn.Module]] = None,
185 | iw: list = [0, 0, 0, 0, 0, 0, 0],
186 | ) -> None:
187 | """
188 | MobileNet V2 main class
189 | Args:
190 | num_classes (int): Number of classes
191 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
192 | inverted_residual_setting: Network structure
193 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
194 | Set to 1 to turn off rounding
195 | block: Module specifying inverted residual building block for mobilenet
196 | norm_layer: Module specifying the normalization layer to use
197 | """
198 | super(MobileNetV2, self).__init__()
199 |
200 | if block is None:
201 | block = InvertedResidual
202 |
203 | if norm_layer is None:
204 | norm_layer = nn.BatchNorm2d
205 |
206 | input_channel = 32
207 | last_channel = 1280
208 |
209 | if inverted_residual_setting is None:
210 | inverted_residual_setting = [
211 | # t, c, n, s
212 | [1, 16, 1, 1], # feature 1
213 | [6, 24, 2, 2], # feature 2, 3
214 | [6, 32, 3, 2], # feature 4, 5, 6
215 | [6, 64, 4, 2], # feature 7, 8, 9, 10
216 | [6, 96, 3, 1], # feature 11, 12, 13
217 | [6, 160, 3, 2], # feature 14, 15, 16
218 | [6, 320, 1, 1], # feature 17
219 | ]
220 |
221 | # only check the first element, assuming user knows t,c,n,s are required
222 | if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
223 |             raise ValueError("inverted_residual_setting should be a non-empty "
224 |                              "list of 4-element lists, got {}".format(inverted_residual_setting))
225 |
226 | # building first layer
227 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
228 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
229 | # feature 0
230 | features: List[nn.Module] = [ConvBNReLU(3, input_channel, stride=2, norm_layer=norm_layer)]
231 | # building inverted residual blocks
232 | feature_count = 0
233 | iw_layer = [1, 6, 10, 17, 18]
234 | for t, c, n, s in inverted_residual_setting:
235 | output_channel = _make_divisible(c * width_mult, round_nearest)
236 | for i in range(n):
237 | feature_count += 1
238 | stride = s if i == 0 else 1
239 | if feature_count in iw_layer:
240 | layer = iw_layer.index(feature_count)
241 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, iw=iw[layer + 2]))
242 | else:
243 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, norm_layer=norm_layer, iw=0))
244 | input_channel = output_channel
245 | # building last several layers
246 | # feature 18
247 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, norm_layer=norm_layer))
248 | # make it nn.Sequential
249 | self.features = nn.Sequential(*features)
250 |
251 | # building classifier
252 | self.classifier = nn.Sequential(
253 | nn.Dropout(0.2),
254 | nn.Linear(self.last_channel, num_classes),
255 | )
256 |
257 | # weight initialization
258 | for m in self.modules():
259 | if isinstance(m, nn.Conv2d):
260 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
261 | if m.bias is not None:
262 | nn.init.zeros_(m.bias)
263 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
264 | nn.init.ones_(m.weight)
265 | nn.init.zeros_(m.bias)
266 | elif isinstance(m, nn.Linear):
267 | nn.init.normal_(m.weight, 0, 0.01)
268 | nn.init.zeros_(m.bias)
269 |
270 | def _forward_impl(self, x: Tensor) -> Tensor:
271 | # This exists since TorchScript doesn't support inheritance, so the superclass method
272 | # (this one) needs to have a name other than `forward` that can be accessed in a subclass
273 | x = self.features(x)
274 | # Cannot use "squeeze" as batch-size can be 1 => must use reshape with x.shape[0]
275 | x = nn.functional.adaptive_avg_pool2d(x, (1, 1)).reshape(x.shape[0], -1)
276 | x = self.classifier(x)
277 | return x
278 |
279 | def forward(self, x: Tensor) -> Tensor:
280 | return self._forward_impl(x)
281 |
282 |
283 | def mobilenet_v2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV2:
284 | """
285 | Constructs a MobileNetV2 architecture from
286 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
287 | Args:
288 | pretrained (bool): If True, returns a model pre-trained on ImageNet
289 | progress (bool): If True, displays a progress bar of the download to stderr
290 | """
291 | model = MobileNetV2(**kwargs)
292 | if pretrained:
293 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
294 | progress=progress)
295 | #model.load_state_dict(state_dict)
296 | forgiving_state_restore(model, state_dict)
297 | return model
298 |
--------------------------------------------------------------------------------
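The behaviour of _make_divisible is easiest to see on a few concrete values (chosen only for illustration); it snaps a scaled channel count to the nearest multiple of the divisor but never lets the result drop more than 10% below the input:

    # using _make_divisible defined in network/Mobilenet.py above
    assert _make_divisible(32 * 0.75, 8) == 24   # 24.0 is already a multiple of 8
    assert _make_divisible(96 * 0.35, 8) == 32   # 33.6 rounds to 32, which is within 10% of 33.6
    assert _make_divisible(24 * 0.10, 8) == 8    # 2.4 would round down to 0, so it is bumped up to the divisor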
/network/Shufflenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.model_zoo import load_url as load_state_dict_from_url
4 | from network.instance_whitening import InstanceWhitening
5 | import network.mynn as mynn
6 |
7 |
8 | __all__ = [
9 | 'ShuffleNetV2', 'shufflenet_v2_x0_5', 'shufflenet_v2_x1_0',
10 | 'shufflenet_v2_x1_5', 'shufflenet_v2_x2_0'
11 | ]
12 |
13 | model_urls = {
14 | 'shufflenetv2_x0.5': 'https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth',
15 | 'shufflenetv2_x1.0': 'https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth',
16 | 'shufflenetv2_x1.5': None,
17 | 'shufflenetv2_x2.0': None,
18 | }
19 |
20 |
21 | def channel_shuffle(x, groups):
22 | # type: (torch.Tensor, int) -> torch.Tensor
23 | batchsize, num_channels, height, width = x.data.size()
24 | channels_per_group = num_channels // groups
25 |
26 | # reshape
27 | x = x.view(batchsize, groups,
28 | channels_per_group, height, width)
29 |
30 | x = torch.transpose(x, 1, 2).contiguous()
31 |
32 | # flatten
33 | x = x.view(batchsize, -1, height, width)
34 |
35 | return x
36 |
37 |
38 | class InvertedResidual(nn.Module):
39 | def __init__(self, inp, oup, stride, iw=0):
40 | super(InvertedResidual, self).__init__()
41 |
42 | if not (1 <= stride <= 3):
43 | raise ValueError('illegal stride value')
44 | self.stride = stride
45 |
46 | branch_features = oup // 2
47 | assert (self.stride != 1) or (inp == branch_features << 1)
48 |
49 | if self.stride > 1:
50 | self.branch1 = nn.Sequential(
51 | self.depthwise_conv(inp, inp, kernel_size=3, stride=self.stride, padding=1),
52 | nn.BatchNorm2d(inp),
53 | nn.Conv2d(inp, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
54 | nn.BatchNorm2d(branch_features),
55 | nn.ReLU(inplace=True),
56 | )
57 | else:
58 | self.branch1 = nn.Sequential()
59 |
60 | self.branch2 = nn.Sequential(
61 | nn.Conv2d(inp if (self.stride > 1) else branch_features,
62 | branch_features, kernel_size=1, stride=1, padding=0, bias=False),
63 | nn.BatchNorm2d(branch_features),
64 | nn.ReLU(inplace=True),
65 | self.depthwise_conv(branch_features, branch_features, kernel_size=3, stride=self.stride, padding=1),
66 | nn.BatchNorm2d(branch_features),
67 | nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
68 | nn.BatchNorm2d(branch_features),
69 | nn.ReLU(inplace=True),
70 | )
71 | self.iw = iw
72 | if iw == 1:
73 | self.instance_norm_layer = InstanceWhitening(oup)
74 | elif iw == 2:
75 | self.instance_norm_layer = InstanceWhitening(oup)
76 | elif iw == 3:
77 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=False)
78 | elif iw == 4:
79 | self.instance_norm_layer = nn.InstanceNorm2d(oup, affine=True)
80 | else:
81 | self.instance_norm_layer = nn.Sequential()
82 |
83 |
84 | @staticmethod
85 | def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
86 | return nn.Conv2d(i, o, kernel_size, stride, padding, bias=bias, groups=i)
87 |
88 | def forward(self, x_tuple):
89 | if len(x_tuple) == 2:
90 | w_arr = x_tuple[1]
91 | x = x_tuple[0]
92 | else:
93 |             print("error in inverted residual forward path")
94 | return
95 |
96 | if self.stride == 1:
97 | x1, x2 = x.chunk(2, dim=1)
98 | out = torch.cat((x1, self.branch2(x2)), dim=1)
99 | else:
100 | out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
101 |
102 | out = channel_shuffle(out, 2)
103 |
104 | if self.iw >= 1:
105 | if self.iw == 1 or self.iw == 2:
106 | out, w = self.instance_norm_layer(out)
107 | w_arr.append(w)
108 | else:
109 | out = self.instance_norm_layer(out)
110 | return [out, w_arr]
111 |
112 |
113 | class ShuffleNetV2(nn.Module):
114 | def __init__(self, stages_repeats, stages_out_channels, num_classes=1000, inverted_residual=InvertedResidual,
115 | iw=[0, 0, 0, 0, 0, 0, 0]):
116 | super(ShuffleNetV2, self).__init__()
117 |
118 | if len(stages_repeats) != 3:
119 | raise ValueError('expected stages_repeats as list of 3 positive ints')
120 | if len(stages_out_channels) != 5:
121 | raise ValueError('expected stages_out_channels as list of 5 positive ints')
122 | self._stage_out_channels = stages_out_channels
123 |
124 | input_channels = 3
125 | output_channels = self._stage_out_channels[0]
126 | self.conv1 = nn.Sequential(
127 | nn.Conv2d(input_channels, output_channels, 3, 2, 1, bias=False),
128 | nn.BatchNorm2d(output_channels),
129 | nn.ReLU(inplace=True),
130 | )
131 |
132 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
133 |
134 | iw_count = 2
135 |
136 | if iw[iw_count] == 1:
137 | self.instance_norm_layer1 = InstanceWhitening(output_channels)
138 | elif iw[iw_count] == 2:
139 | self.instance_norm_layer1 = InstanceWhitening(output_channels)
140 | elif iw[iw_count] == 3:
141 | self.instance_norm_layer1 = nn.InstanceNorm2d(output_channels, affine=False)
142 | elif iw[iw_count] == 4:
143 | self.instance_norm_layer1 = nn.InstanceNorm2d(output_channels, affine=True)
144 | else:
145 | self.instance_norm_layer1 = nn.Sequential()
146 |
147 | iw_count += 1
148 |
149 | input_channels = output_channels
150 |
151 | stage_names = ['stage{}'.format(i) for i in [2, 3, 4]]
152 | for name, repeats, output_channels in zip(
153 | stage_names, stages_repeats, self._stage_out_channels[1:]):
154 | seq = [inverted_residual(input_channels, output_channels, 2)]
155 | for i in range(repeats - 1):
156 | if i == repeats - 2:
157 | seq.append(inverted_residual(output_channels, output_channels, 1, iw=iw[iw_count]))
158 | iw_count += 1
159 | else:
160 | seq.append(inverted_residual(output_channels, output_channels, 1, iw=0))
161 | setattr(self, name, nn.Sequential(*seq))
162 | input_channels = output_channels
163 |
164 | output_channels = self._stage_out_channels[-1]
165 |
166 | self.conv5 = nn.Sequential(
167 | nn.Conv2d(input_channels, output_channels, 1, 1, 0, bias=False),
168 | nn.BatchNorm2d(output_channels),
169 | nn.ReLU(inplace=True),
170 | )
171 | if iw[iw_count] == 1:
172 | self.instance_norm_layer2 = InstanceWhitening(output_channels)
173 | elif iw[iw_count] == 2:
174 | self.instance_norm_layer2 = InstanceWhitening(output_channels)
175 | elif iw[iw_count] == 3:
176 | self.instance_norm_layer2 = nn.InstanceNorm2d(output_channels, affine=False)
177 | elif iw[iw_count] == 4:
178 | self.instance_norm_layer2 = nn.InstanceNorm2d(output_channels, affine=True)
179 | else:
180 | self.instance_norm_layer2 = nn.Sequential()
181 |
182 | self.fc = nn.Linear(output_channels, num_classes)
183 |
184 |
185 | def _forward_impl(self, x):
186 | # See note [TorchScript super()]
187 | """
188 | x = self.conv1(x)
189 | x = self.maxpool(x)
190 | """
191 | x = self.layer0(x)
192 | x = self.stage2(x)
193 | x = self.stage3(x)
194 | x = self.stage4(x)
195 | x = self.layer4(x)
196 | x = x.mean([2, 3]) # globalpool
197 | x = self.fc(x)
198 | return x
199 |
200 | def forward(self, x):
201 | return self._forward_impl(x)
202 |
203 |
204 | def _shufflenetv2(arch, pretrained, progress, *args, **kwargs):
205 | model = ShuffleNetV2(*args, **kwargs)
206 |
207 | if pretrained:
208 | model_url = model_urls[arch]
209 | if model_url is None:
210 | raise NotImplementedError('pretrained {} is not supported as of now'.format(arch))
211 | else:
212 | state_dict = load_state_dict_from_url(model_url, progress=progress)
213 | mynn.forgiving_state_restore(model, state_dict)
214 | ### model.load_state_dict(state_dict)
215 |
216 | return model
217 |
218 |
219 | def shufflenet_v2_x0_5(pretrained=False, progress=True, **kwargs):
220 | """
221 | Constructs a ShuffleNetV2 with 0.5x output channels, as described in
222 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
223 |     <https://arxiv.org/abs/1807.11164>`_.
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | progress (bool): If True, displays a progress bar of the download to stderr
228 | """
229 | return _shufflenetv2('shufflenetv2_x0.5', pretrained, progress,
230 | [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs)
231 |
232 |
233 | def shufflenet_v2_x1_0(pretrained=False, progress=True, **kwargs):
234 | """
235 | Constructs a ShuffleNetV2 with 1.0x output channels, as described in
236 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
237 |     <https://arxiv.org/abs/1807.11164>`_.
238 |
239 | Args:
240 | pretrained (bool): If True, returns a model pre-trained on ImageNet
241 | progress (bool): If True, displays a progress bar of the download to stderr
242 | """
243 | return _shufflenetv2('shufflenetv2_x1.0', pretrained, progress,
244 | [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs)
245 |
246 |
247 | def shufflenet_v2_x1_5(pretrained=False, progress=True, **kwargs):
248 | """
249 | Constructs a ShuffleNetV2 with 1.5x output channels, as described in
250 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
251 |     <https://arxiv.org/abs/1807.11164>`_.
252 |
253 | Args:
254 | pretrained (bool): If True, returns a model pre-trained on ImageNet
255 | progress (bool): If True, displays a progress bar of the download to stderr
256 | """
257 | return _shufflenetv2('shufflenetv2_x1.5', pretrained, progress,
258 | [4, 8, 4], [24, 176, 352, 704, 1024], **kwargs)
259 |
260 |
261 | def shufflenet_v2_x2_0(pretrained=False, progress=True, **kwargs):
262 | """
263 | Constructs a ShuffleNetV2 with 2.0x output channels, as described in
264 | `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
265 |     <https://arxiv.org/abs/1807.11164>`_.
266 |
267 | Args:
268 | pretrained (bool): If True, returns a model pre-trained on ImageNet
269 | progress (bool): If True, displays a progress bar of the download to stderr
270 | """
271 | return _shufflenetv2('shufflenetv2_x2.0', pretrained, progress,
272 | [4, 8, 4], [24, 244, 488, 976, 2048], **kwargs)
273 |
--------------------------------------------------------------------------------
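A quick sanity check of channel_shuffle above: with two groups, the channels of the two halves are interleaved so information mixes across branches (the tensor shape here is chosen only for illustration):

    import torch

    # using channel_shuffle defined in network/Shufflenet.py above
    x = torch.arange(8).view(1, 8, 1, 1)           # channels labelled 0..7
    y = channel_shuffle(x, groups=2)
    # the two groups (channels 0..3 and 4..7) are interleaved channel-wise
    assert y.view(-1).tolist() == [0, 4, 1, 5, 2, 6, 3, 7]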
/network/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Network Initializations
3 | """
4 |
5 | import logging
6 | import importlib
7 | import torch
8 | import datasets
9 |
10 |
11 |
12 | def get_net(args, criterion, criterion_aux=None):
13 | """
14 | Get Network Architecture based on arguments provided
15 | """
16 | net = get_model(args=args, num_classes=datasets.num_classes,
17 | criterion=criterion, criterion_aux=criterion_aux)
18 | num_params = sum([param.nelement() for param in net.parameters()])
19 | logging.info('Model params = {:2.3f}M'.format(num_params / 1000000))
20 |
21 | net = net.cuda()
22 | return net
23 |
24 |
25 | def warp_network_in_dataparallel(net, gpuid):
26 | """
27 |     Wrap the network in DistributedDataParallel
28 | """
29 | # torch.cuda.set_device(gpuid)
30 | # net.cuda(gpuid)
31 | net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid], find_unused_parameters=True)
32 | # net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[gpuid])#, find_unused_parameters=True)
33 | return net
34 |
35 |
36 | def get_model(args, num_classes, criterion, criterion_aux=None):
37 | """
38 | Fetch Network Function Pointer
39 | """
40 | network = args.arch
41 | module = network[:network.rfind('.')]
42 | model = network[network.rfind('.') + 1:]
43 | mod = importlib.import_module(module)
44 | net_func = getattr(mod, model)
45 | net = net_func(args=args, num_classes=num_classes, criterion=criterion, criterion_aux=criterion_aux)
46 | return net
47 |
--------------------------------------------------------------------------------
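get_model treats args.arch as a dotted "module.attribute" path and resolves it with importlib; the snippet below walks through the same lookup by hand (the arch string is a hypothetical example, the actual names depend on what network/deepv3.py exports):

    import importlib

    arch = 'network.deepv3.DeepR50V3PlusD'        # hypothetical value for args.arch
    module_name = arch[:arch.rfind('.')]          # 'network.deepv3'
    func_name = arch[arch.rfind('.') + 1:]        # 'DeepR50V3PlusD'

    mod = importlib.import_module(module_name)    # imports network/deepv3.py
    net_func = getattr(mod, func_name)            # fetches the model-building callable
    # net = net_func(args=args, num_classes=..., criterion=..., criterion_aux=...)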
/network/__pycache__/Mobilenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Mobilenet.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/Mobilenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Mobilenet.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/Resnet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Resnet.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/Resnet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Resnet.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/Shufflenet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Shufflenet.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/Shufflenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/Shufflenet.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/cov_settings.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cov_settings.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/cov_settings.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cov_settings.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/cwcl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cwcl.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/cwcl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/cwcl.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/deepv3.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/deepv3.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/deepv3.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/deepv3.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/edge_contrast.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/edge_contrast_batch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_batch.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/edge_contrast_v1_opt.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v1_opt.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/edge_contrast_v1_opt.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v1_opt.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/edge_contrast_v2.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/edge_contrast_v2.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/instance_whitening.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/instance_whitening.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/instance_whitening.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/instance_whitening.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/mynn.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/mynn.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/mynn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/mynn.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/pixel_nce.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/pixel_nce.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/pixel_nce_batch.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/pixel_nce_batch.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/sdcl.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sdcl.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/sdcl.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sdcl.cpython-37.pyc
--------------------------------------------------------------------------------
/network/__pycache__/sync_switchwhiten.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sync_switchwhiten.cpython-36.pyc
--------------------------------------------------------------------------------
/network/__pycache__/sync_switchwhiten.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/network/__pycache__/sync_switchwhiten.cpython-37.pyc
--------------------------------------------------------------------------------
/network/bn_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 |
4 | if torch.__version__.startswith('0'):
5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync
6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
7 | BatchNorm2d_class = InPlaceABNSync
8 | relu_inplace = False
9 | else:
10 | BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
11 | relu_inplace = True
--------------------------------------------------------------------------------
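This shim lets the rest of the code build normalization layers without caring which backend is active; for example, on any torch >= 1.0 the alias resolves to SyncBatchNorm:

    # using BatchNorm2d from network/bn_helper.py above:
    # on torch >= 1.0 this builds a torch.nn.SyncBatchNorm over 64 channels; under the
    # legacy 0.x branch the same call would build an InPlaceABNSync with activation='none'.
    norm = BatchNorm2d(64)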
/network/cov_settings.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from network.mynn import initialize_embedding
4 | import kmeans1d
5 |
6 |
7 | def make_cov_index_matrix(dim): # make symmetric matrix for embedding index
8 | matrix = torch.LongTensor()
9 | s_index = 0
10 | for i in range(dim):
11 | matrix = torch.cat([matrix, torch.arange(s_index, s_index + dim).unsqueeze(0)], dim=0)
12 | s_index += (dim - (2 + i))
13 | return matrix.triu(diagonal=1).transpose(0, 1) + matrix.triu(diagonal=1)
14 |
15 |
16 | class CovMatrix_ISW:
17 | def __init__(self, dim, relax_denom=0, clusters=50):
18 | super(CovMatrix_ISW, self).__init__()
19 |
20 | self.dim = dim
21 | self.i = torch.eye(dim, dim).cuda()
22 |
23 | # print(torch.ones(16, 16).triu(diagonal=1))
24 | self.reversal_i = torch.ones(dim, dim).triu(diagonal=1).cuda()
25 |
26 | # num_off_diagonal = ((dim * dim - dim) // 2) # number of off-diagonal
27 | self.num_off_diagonal = torch.sum(self.reversal_i)
28 | self.num_sensitive = 0
29 | self.var_matrix = None
30 | self.count_var_cov = 0
31 | self.mask_matrix = None
32 | self.clusters = clusters
33 | print("num_off_diagonal", self.num_off_diagonal)
34 | if relax_denom == 0: # kmeans1d clustering setting for ISW
35 | print("relax_denom == 0!!!!!")
36 | print("cluster == ", self.clusters)
37 | self.margin = 0
38 | else: # do not use
39 | self.margin = self.num_off_diagonal // relax_denom
40 |
41 | def get_eye_matrix(self):
42 | return self.i, self.reversal_i
43 |
44 | def get_mask_matrix(self, mask=True):
45 | if self.mask_matrix is None:
46 | self.set_mask_matrix()
47 | return self.i, self.mask_matrix, 0, self.num_sensitive
48 |
49 | def reset_mask_matrix(self):
50 | self.mask_matrix = None
51 |
52 | def set_mask_matrix(self):
53 | # torch.set_printoptions(threshold=500000)
54 | self.var_matrix = self.var_matrix / self.count_var_cov
55 | var_flatten = torch.flatten(self.var_matrix)
56 |
57 | if self.margin == 0: # kmeans1d clustering setting for ISW
58 | clusters, centroids = kmeans1d.cluster(var_flatten, self.clusters) # 50 clusters
59 |             num_sensitive = var_flatten.size()[0] - clusters.count(0) # cluster 0: insensitive covariances, clusters 1~49: sensitive covariances
60 | print("num_sensitive, centroids =", num_sensitive, centroids)
61 | _, indices = torch.topk(var_flatten, k=int(num_sensitive))
62 | else: # do not use
63 | num_sensitive = self.num_off_diagonal - self.margin
64 | print("num_sensitive = ", num_sensitive)
65 | _, indices = torch.topk(var_flatten, k=int(num_sensitive))
66 | mask_matrix = torch.flatten(torch.zeros(self.dim, self.dim).cuda())
67 | mask_matrix[indices] = 1
68 |
69 | if self.mask_matrix is not None:
70 | self.mask_matrix = (self.mask_matrix.int() & mask_matrix.view(self.dim, self.dim).int()).float()
71 | else:
72 | self.mask_matrix = mask_matrix.view(self.dim, self.dim)
73 | self.num_sensitive = torch.sum(self.mask_matrix)
74 | print("Check whether two ints are same", num_sensitive, self.num_sensitive)
75 |
76 | self.var_matrix = None
77 | self.count_var_cov = 0
78 |
79 | if torch.cuda.current_device() == 0:
80 | print("Covariance Info: (CXC Shape, Num_Off_Diagonal)", self.mask_matrix.shape, self.num_off_diagonal)
81 | print("Selective (Sensitive Covariance)", self.num_sensitive)
82 |
83 |
84 | def set_variance_of_covariance(self, var_cov):
85 | if self.var_matrix is None:
86 | self.var_matrix = var_cov
87 | else:
88 | self.var_matrix = self.var_matrix + var_cov
89 | self.count_var_cov += 1
90 |
91 | class CovMatrix_IRW:
92 | def __init__(self, dim, relax_denom=0):
93 | super(CovMatrix_IRW, self).__init__()
94 |
95 | self.dim = dim
96 | self.i = torch.eye(dim, dim).cuda()
97 | self.reversal_i = torch.ones(dim, dim).triu(diagonal=1).cuda()
98 |
99 | self.num_off_diagonal = torch.sum(self.reversal_i)
100 | if relax_denom == 0:
101 | print("relax_denom == 0!!!!!")
102 | self.margin = 0
103 | else:
104 | self.margin = self.num_off_diagonal // relax_denom
105 |
106 | def get_mask_matrix(self):
107 | return self.i, self.reversal_i, self.margin, self.num_off_diagonal
108 |
--------------------------------------------------------------------------------
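make_cov_index_matrix assigns one shared index to each symmetric pair of off-diagonal covariance entries while the diagonal stays 0, so (i, j) and (j, i) point at the same embedding slot. For dim=4 it evaluates to:

    # using make_cov_index_matrix defined in network/cov_settings.py above
    print(make_cov_index_matrix(4))
    # tensor([[0, 1, 2, 3],
    #         [1, 0, 4, 5],
    #         [2, 4, 0, 6],
    #         [3, 5, 6, 0]])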
/network/cwcl.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any
3 | from packaging import version
4 |
5 | from abc import ABC
6 |
7 | import torch
8 | from torch import nn
9 | import torch.nn.functional as F
10 |
11 |
12 | class Class_PixelNCELoss(nn.Module):
13 | def __init__(self, args):
14 | super(Class_PixelNCELoss, self).__init__()
15 |
16 | self.args = args
17 | self.ignore_label = 255
18 |
19 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
20 |
21 | self.max_classes = self.args.contrast_max_classes
22 | self.max_views = self.args.contrast_max_views
23 |
24 | # reshape label or prediction
25 | def resize_label(self, labels, HW):
26 | labels = labels.unsqueeze(1).float().clone()
27 | labels = torch.nn.functional.interpolate(labels,
28 | HW, mode='nearest')
29 | labels = labels.squeeze(1).long()
30 |
31 | return labels
32 |
33 | def _hard_anchor_sampling(self, X_q, X_k, y_hat, y):
34 |         # X : feature map, shape (B, h*w, C); y_hat : label, shape (B, h*w); y : prediction, shape (B, h*w)
35 | batch_size, feat_dim = X_q.shape[0], X_q.shape[-1]
36 |
37 | classes = []
38 | num_classes = []
39 | total_classes = 0
40 |         # collect the classes that are present in the labels of each image in the batch
41 | for ii in range(batch_size):
42 | this_y = y_hat[ii]
43 |             this_classes = torch.unique(this_y)  # unique class ids present in this label map
44 |             this_classes = [x for x in this_classes if x != self.ignore_label]  # drop the ignore label
45 |             this_classes = [x for x in this_classes if (this_y == x).nonzero().shape[0] > self.max_views]  # keep only classes with more than max_views pixels
46 |
47 | classes.append(this_classes)
48 |
49 | total_classes += len(this_classes)
50 | num_classes.append(len(this_classes))
51 |
52 | # return none if there is no class in the image
53 | if total_classes == 0:
54 | return None, None, None
55 |
56 | n_view = self.max_views
57 |
58 | # output tensors
59 | X_q_ = torch.zeros((batch_size, self.max_classes, n_view, feat_dim), dtype=torch.float).cuda()
60 | X_k_ = torch.zeros((batch_size, self.max_classes, n_view, feat_dim), dtype=torch.float).cuda()
61 |
62 | for ii in range(batch_size):
63 | this_y_hat = y_hat[ii]
64 | this_y = y[ii]
65 | this_classes = classes[ii]
66 | this_indices = []
67 |
68 |             # if no valid class is present in the image, randomly sample patches
69 | if len(this_classes) == 0:
70 | indices = torch.arange(X_q.shape[1], device=X_q.device)
71 | perm = torch.randperm(X_q.shape[1], device=X_q.device)
72 | indices = indices[perm[:n_view * self.max_classes]]
73 | indices = indices.view(self.max_classes, -1)
74 |
75 | X_q_[ii, :, :, :] = X_q[ii, indices, :]
76 | X_k_[ii, :, :, :] = X_k[ii, indices, :]
77 |
78 | continue
79 |
80 |             # reference : https://github.com/tfzhou/ContrastiveSeg/tree/main
81 | for n, cls_id in enumerate(this_classes):
82 |
83 | if n == self.max_classes:
84 | break
85 |
86 |                 # sample hard patches (wrong prediction) and easy patches (correct prediction)
87 | hard_indices = ((this_y_hat == cls_id) & (this_y != cls_id)).nonzero()
88 | easy_indices = ((this_y_hat == cls_id) & (this_y == cls_id)).nonzero()
89 |
90 | num_hard = hard_indices.shape[0]
91 | num_easy = easy_indices.shape[0]
92 |
93 | if num_hard >= n_view / 2 and num_easy >= n_view / 2:
94 | num_hard_keep = n_view // 2
95 | num_easy_keep = n_view - num_hard_keep
96 | elif num_hard >= n_view / 2:
97 | num_easy_keep = num_easy
98 | num_hard_keep = n_view - num_easy_keep
99 | elif num_easy >= n_view / 2:
100 | num_hard_keep = num_hard
101 | num_easy_keep = n_view - num_hard_keep
102 |
103 | perm = torch.randperm(num_hard)
104 | hard_indices = hard_indices[perm[:num_hard_keep]]
105 | perm = torch.randperm(num_easy)
106 | easy_indices = easy_indices[perm[:num_easy_keep]]
107 | indices = torch.cat((hard_indices, easy_indices), dim=0)
108 |
109 | X_q_[ii, n, :, :] = X_q[ii, indices, :].squeeze(1)
110 | X_k_[ii, n, :, :] = X_k[ii, indices, :].squeeze(1)
111 |
112 | this_indices.append(indices)
113 |
114 |             # fill the spare space with random patches
115 | if len(this_classes) < self.max_classes:
116 | this_indices = torch.stack(this_indices)
117 | this_indices = this_indices.flatten(0, 1)
118 |
119 | num_remain = self.max_classes - len(this_classes)
120 | all_indices = torch.arange(X_q.shape[1], device=X_q[0].device)
121 |                 left_indices = torch.zeros(X_q.shape[1], device=X_q[0].device, dtype=torch.bool)
122 |                 left_indices[this_indices] = True
123 | left_indices = all_indices[~left_indices]
124 |
125 | perm = torch.randperm(len(left_indices), device=X_q[0].device)
126 |
127 | indices = left_indices[perm[:n_view * num_remain]]
128 | indices = indices.view(num_remain, -1)
129 |
130 | X_q_[ii, n + 1:, :, :] = X_q[ii, indices, :]
131 | X_k_[ii, n + 1:, :, :] = X_k[ii, indices, :]
132 |
133 | return X_q_, X_k_, num_classes
134 |
135 | def _contrastive(self, feats_q_, feats_k_):
136 | # feats shape : (B, nc, N, C)
137 | batch_size, num_classes, n_view, patch_dim = feats_q_.shape
138 | num_patches = batch_size * num_classes * n_view
139 |
140 | # feats shape : (B*nc*N, 1, C)
141 | feats_q_ = feats_q_.contiguous().view(num_patches, -1, patch_dim)
142 | feats_k_ = feats_k_.contiguous().view(num_patches, -1, patch_dim)
143 |
144 | # logit_positive : same positive patches between key and query
145 | # shape : (B * nc * N , 1)
146 | l_pos = torch.bmm(
147 | feats_q_, feats_k_.transpose(2, 1)
148 | )
149 |         l_pos = l_pos.view(num_patches, 1)
150 |
151 | # feats shape : (B, nc*N, C)
152 | feats_q_ = feats_q_.contiguous().view(batch_size, -1, patch_dim)
153 | feats_k_ = feats_k_.contiguous().view(batch_size, -1, patch_dim)
154 | n_patches = feats_q_.shape[1]
155 |
156 | # logit negative shape : (B, nc*N, nc*N)
157 | l_neg_curbatch = torch.bmm(feats_q_, feats_k_.transpose(2, 1))
158 |
159 | # exclude same class patches
160 | diag_block= torch.zeros((batch_size, n_patches, n_patches), device=feats_q_.device, dtype=torch.uint8)
161 | for i in range(num_classes):
162 | diag_block[:, i*n_view:(i+1)*n_view, i*n_view:(i+1)*n_view] = 1
163 |
164 | l_neg_curbatch = l_neg_curbatch[~diag_block].view(batch_size, n_patches, -1)
165 |
166 | # logit negative shape : (B*nc*N, nc*(N-1))
167 | l_neg = l_neg_curbatch.view(num_patches, -1)
168 |
169 | out = torch.cat([l_pos, l_neg], dim=1) / self.args.nce_T
170 |
171 | loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, device=feats_q_.device))
172 |
173 | return loss
174 |
175 | def forward(self, feats_q, feats_k, labels=None, predict=None):
176 | B, C, H, W = feats_q.shape
177 |
178 | # resize label and prediction
179 | labels = self.resize_label(labels, (H, W))
180 | predict = self.resize_label(predict, (H, W))
181 |
182 | labels = labels.contiguous().view(B, -1)
183 | predict = predict.contiguous().view(B, -1)
184 |
185 | # change axis
186 | feats_q = feats_q.permute(0, 2, 3, 1)
187 | feats_q = feats_q.contiguous().view(feats_q.shape[0], -1, feats_q.shape[-1])
188 |
189 | feats_k = feats_k.detach()
190 |
191 | feats_k = feats_k.permute(0, 2, 3, 1)
192 | feats_k = feats_k.contiguous().view(feats_k.shape[0], -1, feats_k.shape[-1])
193 |
194 | # sample patches
195 | feats_q_, feats_k_, num_classes = self._hard_anchor_sampling(feats_q, feats_k, labels, predict)
196 |
197 | if feats_q_ is None:
198 | loss = torch.FloatTensor([0]).cuda()
199 | return loss
200 |
201 | loss = self._contrastive(feats_q_, feats_k_)
202 |
203 | del labels
204 |
205 | return loss
206 |
207 | class Normalize(nn.Module):
208 | def __init__(self, power=2):
209 | super(Normalize, self).__init__()
210 | self.power = power
211 |
212 | def forward(self, x):
213 | norm = x.pow(self.power).sum(1, keepdim=True).pow(1. / self.power)
214 | out = x.div(norm + 1e-7)
215 | return out
216 |
217 |
218 | class ProjectionHead(nn.Module):
219 |
220 | def __init__(self, dim_in, proj_dim=256):
221 | super(ProjectionHead, self).__init__()
222 |
223 | self.proj = nn.Sequential(
224 | nn.Conv2d(dim_in, dim_in, kernel_size=1),
225 | nn.ReLU(),
226 | #nn.BatchNorm2d(dim_in),
227 | nn.Conv2d(dim_in, proj_dim, kernel_size=1)
228 | )
229 | self.l2norm = Normalize(2)
230 |
231 | def forward(self, x):
232 | return self.l2norm(self.proj(x))
233 | #return F.normalize(self.proj(x), p=2, dim=1)
234 |
235 |
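A minimal usage sketch for Class_PixelNCELoss follows. It is illustrative only: the option values are placeholders, only the attributes the loss actually reads (contrast_max_classes, contrast_max_views, nce_T) are supplied, and a CUDA device is assumed because the sampling buffers are allocated with .cuda().

import argparse
import torch
from network.cwcl import Class_PixelNCELoss

args = argparse.Namespace(contrast_max_classes=4, contrast_max_views=16, nce_T=0.07)
criterion = Class_PixelNCELoss(args)

B, C, H, W = 2, 256, 48, 48
feats_q = torch.randn(B, C, H, W).cuda()              # projected query features
feats_k = torch.randn(B, C, H, W).cuda()              # projected key features (detached inside forward)
labels  = torch.randint(0, 19, (B, 384, 384)).cuda()  # full-resolution ground-truth class map
predict = torch.randint(0, 19, (B, 384, 384)).cuda()  # full-resolution per-pixel argmax prediction

# The internal cross-entropy uses reduction='none', so reduce the per-anchor vector explicitly.
loss = criterion(feats_q, feats_k, labels=labels, predict=predict).mean()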
--------------------------------------------------------------------------------
/network/instance_whitening.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class InstanceWhitening(nn.Module):
7 |
8 | def __init__(self, dim):
9 | super(InstanceWhitening, self).__init__()
10 | self.instance_standardization = nn.InstanceNorm2d(dim, affine=False)
11 |
12 | def forward(self, x):
13 |
14 | x = self.instance_standardization(x)
15 | w = x
16 |
17 | return x, w
18 |
19 |
20 | def get_covariance_matrix(f_map, eye=None):
21 | eps = 1e-5
22 | B, C, H, W = f_map.shape # i-th feature size (B X C X H X W)
23 | HW = H * W
24 | if eye is None:
25 | eye = torch.eye(C).cuda()
26 | f_map = f_map.contiguous().view(B, C, -1) # B X C X H X W > B X C X (H X W)
27 | f_cor = torch.bmm(f_map, f_map.transpose(1, 2)).div(HW-1) + (eps * eye) # C X C / HW
28 |
29 | return f_cor, B
30 |
31 | # Calculate the cross-covariance of two feature maps
32 | # reference : https://github.com/shachoi/RobustNet
33 | def get_cross_covariance_matrix(f_map1, f_map2, eye=None):
34 | eps = 1e-5
35 | assert f_map1.shape == f_map2.shape
36 |
37 | B, C, H, W = f_map1.shape
38 | HW = H*W
39 |
40 | if eye is None:
41 | eye = torch.eye(C).cuda()
42 |
43 | # feature map shape : (B,C,H,W) -> (B,C,HW)
44 | f_map1 = f_map1.contiguous().view(B, C, -1)
45 | f_map2 = f_map2.contiguous().view(B, C, -1)
46 |
47 | # f_cor shape : (B, C, C)
48 | f_cor = torch.bmm(f_map1, f_map2.transpose(1, 2)).div(HW-1) + (eps * eye)
49 |
50 | return f_cor, B
51 |
52 | def cross_whitening_loss(k_feat, q_feat):
53 | assert k_feat.shape == q_feat.shape
54 |
55 | f_cor, B = get_cross_covariance_matrix(k_feat, q_feat)
56 | diag_loss = torch.FloatTensor([0]).cuda()
57 |
58 | # get diagonal values of covariance matrix
59 | for cor in f_cor:
60 | diag = torch.diagonal(cor.squeeze(dim=0), 0)
61 | eye = torch.ones_like(diag).cuda()
62 | diag_loss = diag_loss + F.mse_loss(diag, eye)
63 | diag_loss = diag_loss / B
64 |
65 | return diag_loss
66 |
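A short sketch of cross_whitening_loss on two feature maps of the same shape; a CUDA device is assumed since get_cross_covariance_matrix builds its identity with .cuda(). The loss penalizes the diagonal of the C x C cross-covariance for deviating from one, so it approaches zero when the two (standardized) feature maps are channel-wise aligned.

import torch
from network.instance_whitening import cross_whitening_loss

B, C, H, W = 2, 64, 32, 32
k_feat = torch.randn(B, C, H, W).cuda()
q_feat = torch.randn(B, C, H, W).cuda()

loss = cross_whitening_loss(k_feat, q_feat)  # shape (1,), averaged over the batch
print(loss.item())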
--------------------------------------------------------------------------------
/network/mynn.py:
--------------------------------------------------------------------------------
1 | """
2 | Custom Norm wrappers to enable sync BN, regular BN and for weight initialization
3 | """
4 | import torch.nn as nn
5 | import torch
6 | from config import cfg
7 |
8 | def Norm2d(in_channels):
9 | """
10 | Custom Norm Function to allow flexible switching
11 | """
12 | layer = getattr(cfg.MODEL, 'BNFUNC')
13 | normalization_layer = layer(in_channels)
14 | return normalization_layer
15 |
16 |
17 | def freeze_weights(*models):
18 | for model in models:
19 | for k in model.parameters():
20 | k.requires_grad = False
21 |
22 | def unfreeze_weights(*models):
23 | for model in models:
24 | for k in model.parameters():
25 | k.requires_grad = True
26 |
27 | def initialize_weights(*models):
28 | """
29 | Initialize Model Weights
30 | """
31 | for model in models:
32 | for module in model.modules():
33 | if isinstance(module, (nn.Conv2d, nn.Linear)):
34 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
35 | if module.bias is not None:
36 | module.bias.data.zero_()
37 | elif isinstance(module, nn.Conv1d):
38 | nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
39 | if module.bias is not None:
40 | module.bias.data.zero_()
41 | elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d) or \
42 | isinstance(module, nn.GroupNorm) or isinstance(module, nn.SyncBatchNorm):
43 | module.weight.data.fill_(1)
44 | module.bias.data.zero_()
45 |
46 | def initialize_embedding(*models):
47 | """
48 |     Initialize Embedding Weights
49 | """
50 | for model in models:
51 | for module in model.modules():
52 | if isinstance(module, nn.Embedding):
53 | module.weight.data.zero_() #original
54 |
55 |
56 |
57 | def Upsample(x, size):
58 | """
59 | Wrapper Around the Upsample Call
60 | """
61 | return nn.functional.interpolate(x, size=size, mode='bilinear',
62 | align_corners=True)
63 |
64 | def forgiving_state_restore(net, loaded_dict):
65 | """
66 | Handle partial loading when some tensors don't match up in size.
67 | Because we want to use models that were trained off a different
68 | number of classes.
69 | """
70 | net_state_dict = net.state_dict()
71 | new_loaded_dict = {}
72 | for k in net_state_dict:
73 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
74 | new_loaded_dict[k] = loaded_dict[k]
75 | else:
76 | print("Skipped loading parameter", k)
77 | # logging.info("Skipped loading parameter %s", k)
78 | net_state_dict.update(new_loaded_dict)
79 | net.load_state_dict(net_state_dict)
80 | return net
81 |
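A small self-contained sketch of forgiving_state_restore: only parameters whose names and shapes match are copied, which is what allows reusing a snapshot trained with a different number of output classes. The two toy networks below are placeholders, not models from this repository.

import torch.nn as nn
from network.mynn import forgiving_state_restore

src = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 19, 1))  # e.g. trained with 19 classes
dst = nn.Sequential(nn.Conv2d(3, 8, 3), nn.Conv2d(8, 11, 1))  # e.g. target with 11 classes

# The shared trunk transfers; the mismatched classifier head is reported and skipped.
dst = forgiving_state_restore(dst, src.state_dict())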
--------------------------------------------------------------------------------
/network/sdcl.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 | from packaging import version
3 |
4 | from abc import ABC
5 |
6 | import torch
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from network.mynn import initialize_weights
11 | import numpy as np
12 | import time
13 |
14 |
15 | class Disentangle_Contrast(nn.Module):
16 | def __init__(self, args):
17 | super(Disentangle_Contrast, self).__init__()
18 |
19 | self.args = args
20 | self.temperature = 0.1
21 | self.cross_entropy_loss = torch.nn.CrossEntropyLoss(reduction='none')
22 |
23 | self.num_patch = self.args.num_patch
24 | self.num_classes = 19
25 | self.max_samples = 1000
26 |
27 |
28 | # reshape label or prediction
29 | def reshape_map(self, map, shape):
30 |
31 | map = map.unsqueeze(1).float().clone()
32 | map = torch.nn.functional.interpolate(map, shape, mode='nearest')
33 | map = map.squeeze(1).long()
34 |
35 | return map
36 |
37 |
38 | def _contrastive(self, pos_q, pos_k, neg):
39 | num_patch, _, patch_dim = pos_q.shape
40 |
41 | # l_pos shape : (num_patch, 1)
42 | l_pos = torch.bmm(pos_q, pos_k.transpose(2, 1))
43 | l_pos = l_pos.view(num_patch, 1)
44 |
45 | # l_neg shape : (num_patch, negative_size)
46 | l_neg = torch.bmm(pos_q, neg.transpose(2, 1))
47 | l_neg = l_neg.view(num_patch, -1)
48 |
49 | out = torch.cat([l_pos, l_neg], dim=1) / self.args.nce_T
50 |
51 | loss = self.cross_entropy_loss(out, torch.zeros(out.size(0), dtype=torch.long, device=pos_q.device))
52 |
53 | return loss
54 |
55 |
56 | def Disentangle_Sampler(self, correct_maps, feats_q, feats_k, predicts, predicts_j, labels):
57 | B, HW, C = feats_q.shape
58 | start_ts = time.time()
59 |
60 | # X_pos_q : anchors, X_pos_k : positives, X_neg : negatives
61 | X_pos_q = []
62 | X_pos_k = []
63 | X_neg = []
64 |
65 | for ii in range(B):
66 | img_sample_ts = time.time()
67 | M = correct_maps[ii]
68 | # indices : wrong prediction location in query features
69 | indices = (M == 1).nonzero()
70 |
71 | classes_labels = torch.unique(labels[ii])
72 | classes_wrong = torch.unique(predicts_j[ii, indices])
73 |
74 | pos_indices = []
75 | neg_indices = []
76 |
77 | # sample anchor, positive, negative for each wrong class
78 | for cls_id in classes_wrong:
79 | sampling_time = time.time()
80 | # cls_indices : anchor, positive indices
81 | cls_indices = ((M == 1) & (predicts_j[ii] == cls_id)).nonzero()
82 |
83 |                 # skip if the wrongly predicted class does not appear in the ground-truth labels
84 | if cls_id not in classes_labels:
85 | continue
86 | else:
87 | neg_cls_indices = (labels[ii] == cls_id).nonzero()
88 |
89 | if neg_cls_indices.size(0) < self.num_patch:
90 | continue
91 |
92 | neg_sampled_indices = [neg_cls_indices[torch.randperm(neg_cls_indices.size(0))[:self.num_patch]].squeeze()] * cls_indices.size(0)
93 | neg_sampled_indices = torch.cat(neg_sampled_indices, dim=0)
94 |
95 | pos_indices.append(cls_indices)
96 | neg_indices.append(neg_sampled_indices)
97 |
98 | if not pos_indices:
99 | continue
100 | pos_indices = torch.cat(pos_indices, dim=0)
101 | neg_indices = torch.cat(neg_indices, dim=0)
102 |
103 | # anchor from query feature
104 | X_pos_q.append(feats_q[ii, pos_indices, :])
105 | # positive from key feature
106 | X_pos_k.append(feats_k[ii, pos_indices, :])
107 | # Negative from query feature
108 | X_neg.append(feats_q[ii, neg_indices, :].view(pos_indices.size(0), self.num_patch, C))
109 |
110 | if not X_pos_q:
111 | return None, None, None
112 | # X_pos_q, X_pos_k shape : (num_patch, 1, C)
113 | # X_neg shape : (num_patch, negative_size, C)
114 | X_pos_q = torch.cat(X_pos_q, dim=0)
115 | X_pos_k = torch.cat(X_pos_k, dim=0)
116 | X_neg = torch.cat(X_neg, dim=0)
117 |
118 | if X_pos_q.shape[0] > B * self.max_samples:
119 | indices = torch.randperm(X_pos_q.size(0))[:B*self.max_samples]
120 | X_pos_q = X_pos_q[indices, :, :]
121 | X_pos_k = X_pos_k[indices, :, :]
122 | X_neg = X_neg[indices, :, :]
123 |
124 | return X_pos_q, X_pos_k, X_neg
125 |
126 |
127 | def forward(self, feats_q, feats_k, predicts, predicts_j, labels):
128 | B, C, H, W = feats_q.shape
129 |
130 | # reshape the labels and predictions to feature map's size
131 | labels = self.reshape_map(labels, (H, W))
132 | predicts = self.reshape_map(predicts, (H, W))
133 | predicts_j = self.reshape_map(predicts_j, (H, W))
134 |
135 |         # correction map: 1 where the prediction matches the label but differs from the jittered-view prediction
136 | correct_maps = torch.ones_like(predicts, device=feats_q[0].device)
137 | correct_maps[predicts == predicts_j] = 0
138 | correct_maps[labels == 255] = 0
139 | correct_maps[predicts != labels] = 0
140 | correct_maps = correct_maps.flatten(1, 2)
141 |
142 | predicts = predicts.flatten(1, 2)
143 | predicts_j = predicts_j.flatten(1, 2)
144 | labels = labels.flatten(1, 2)
145 |
146 | feats_k = feats_k.detach()
147 |
148 | feats_q_reshape = feats_q.permute(0, 2, 3, 1).flatten(1, 2)
149 | feats_k_reshape = feats_k.permute(0, 2, 3, 1).flatten(1, 2)
150 |
151 | # Sample the anchor and positives, negatives
152 | patches_q, patches_k, patches_neg = self.Disentangle_Sampler(correct_maps, feats_q_reshape, feats_k_reshape,
153 | predicts, predicts_j, labels)
154 |
155 | if patches_q is None:
156 | loss = torch.FloatTensor([0]).cuda()
157 | return loss
158 |
159 | loss = self._contrastive(patches_q, patches_k, patches_neg)
160 |
161 | return loss
162 |
163 |
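A minimal usage sketch for Disentangle_Contrast. The values are illustrative, only the options the module reads (num_patch, nce_T) are supplied, a CUDA device is assumed, and predicts / predicts_j stand for per-pixel predictions of the same image from two views (the correction map keeps locations that are correct under predicts but disagree with predicts_j).

import argparse
import torch
from network.sdcl import Disentangle_Contrast

args = argparse.Namespace(num_patch=8, nce_T=0.07)  # placeholder values
criterion = Disentangle_Contrast(args)

B, C, H, W = 2, 256, 48, 48
feats_q = torch.randn(B, C, H, W).cuda()
feats_k = torch.randn(B, C, H, W).cuda()
labels     = torch.randint(0, 19, (B, 384, 384)).cuda()  # ground-truth class map
predicts   = torch.randint(0, 19, (B, 384, 384)).cuda()  # prediction on one view
predicts_j = torch.randint(0, 19, (B, 384, 384)).cuda()  # prediction on the other view

loss = criterion(feats_q, feats_k, predicts, predicts_j, labels).mean()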
--------------------------------------------------------------------------------
/network/switchwhiten.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.parameter import Parameter
4 | from torch.nn.modules.module import Module
5 |
6 |
7 | class SwitchWhiten2d(Module):
8 | """Switchable Whitening.
9 |
10 | Args:
11 | num_features (int): Number of channels.
12 | num_pergroup (int): Number of channels for each whitening group.
13 | sw_type (int): Switchable whitening type, from {2, 3, 5}.
14 | sw_type = 2: BW + IW
15 | sw_type = 3: BW + IW + LN
16 | sw_type = 5: BW + IW + BN + IN + LN
17 | T (int): Number of iterations for iterative whitening.
18 | tie_weight (bool): Use the same importance weight for mean and
19 | covariance or not.
20 | """
21 |
22 | def __init__(self,
23 | num_features,
24 | num_pergroup=16,
25 | sw_type=2,
26 | T=5,
27 | tie_weight=False,
28 | eps=1e-5,
29 | momentum=0.99,
30 | affine=True):
31 | super(SwitchWhiten2d, self).__init__()
32 | if sw_type not in [2, 3, 5]:
33 | raise ValueError('sw_type should be in [2, 3, 5], '
34 | 'but got {}'.format(sw_type))
35 | assert num_features % num_pergroup == 0
36 | self.num_features = num_features
37 | self.num_pergroup = num_pergroup
38 | self.num_groups = num_features // num_pergroup
39 | self.sw_type = sw_type
40 | self.T = T
41 | self.tie_weight = tie_weight
42 | self.eps = eps
43 | self.momentum = momentum
44 | self.affine = affine
45 | num_components = sw_type
46 |
47 | self.sw_mean_weight = Parameter(torch.ones(num_components))
48 | if not self.tie_weight:
49 | self.sw_var_weight = Parameter(torch.ones(num_components))
50 | else:
51 | self.register_parameter('sw_var_weight', None)
52 |
53 | if self.affine:
54 | self.weight = Parameter(torch.ones(num_features))
55 | self.bias = Parameter(torch.zeros(num_features))
56 | else:
57 | self.register_parameter('weight', None)
58 | self.register_parameter('bias', None)
59 |
60 | self.register_buffer('running_mean',
61 | torch.zeros(self.num_groups, num_pergroup, 1))
62 | self.register_buffer(
63 | 'running_cov',
64 | torch.eye(num_pergroup).unsqueeze(0).repeat(self.num_groups, 1, 1))
65 |
66 | self.reset_parameters()
67 |
68 | def reset_parameters(self):
69 | self.running_mean.zero_()
70 | self.running_cov.zero_()
71 | nn.init.ones_(self.sw_mean_weight)
72 | if not self.tie_weight:
73 | nn.init.ones_(self.sw_var_weight)
74 | if self.affine:
75 | nn.init.ones_(self.weight)
76 | nn.init.zeros_(self.bias)
77 |
78 | def __repr__(self):
79 | return ('{name}({num_features}, num_pergroup={num_pergroup}, '
80 | 'sw_type={sw_type}, T={T}, tie_weight={tie_weight}, '
81 | 'eps={eps}, momentum={momentum}, affine={affine})'.format(
82 | name=self.__class__.__name__, **self.__dict__))
83 |
84 | def forward(self, x):
85 | N, C, H, W = x.size()
86 | c, g = self.num_pergroup, self.num_groups
87 |
88 | in_data_t = x.transpose(0, 1).contiguous()
89 | # g x c x (N x H x W)
90 | in_data_t = in_data_t.view(g, c, -1)
91 |
92 | # calculate batch mean and covariance
93 | if self.training:
94 | # g x c x 1
95 | mean_bn = in_data_t.mean(-1, keepdim=True)
96 | in_data_bn = in_data_t - mean_bn
97 | # g x c x c
98 | cov_bn = torch.bmm(in_data_bn,
99 | in_data_bn.transpose(1, 2)).div(H * W * N)
100 |
101 | self.running_mean.mul_(self.momentum)
102 | self.running_mean.add_((1 - self.momentum) * mean_bn.data)
103 | self.running_cov.mul_(self.momentum)
104 | self.running_cov.add_((1 - self.momentum) * cov_bn.data)
105 | else:
106 | mean_bn = torch.autograd.Variable(self.running_mean)
107 | cov_bn = torch.autograd.Variable(self.running_cov)
108 |
109 | mean_bn = mean_bn.view(1, g, c, 1).expand(N, g, c, 1).contiguous()
110 | mean_bn = mean_bn.view(N * g, c, 1)
111 | cov_bn = cov_bn.view(1, g, c, c).expand(N, g, c, c).contiguous()
112 | cov_bn = cov_bn.view(N * g, c, c)
113 |
114 | # (N x g) x c x (H x W)
115 | in_data = x.view(N * g, c, -1)
116 |
117 | eye = in_data.data.new().resize_(c, c)
118 | eye = torch.nn.init.eye_(eye).view(1, c, c).expand(N * g, c, c)
119 |
120 | # calculate other statistics
121 | # (N x g) x c x 1
122 | mean_in = in_data.mean(-1, keepdim=True)
123 | x_in = in_data - mean_in
124 | # (N x g) x c x c
125 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W)
126 | if self.sw_type in [3, 5]:
127 | x = x.view(N, -1)
128 | mean_ln = x.mean(-1, keepdim=True).view(N, 1, 1, 1)
129 | mean_ln = mean_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1)
130 | var_ln = x.var(-1, keepdim=True).view(N, 1, 1, 1)
131 | var_ln = var_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1)
132 | var_ln = var_ln * eye
133 | if self.sw_type == 5:
134 | var_bn = torch.diag_embed(torch.diagonal(cov_bn, dim1=-2, dim2=-1))
135 | var_in = torch.diag_embed(torch.diagonal(cov_in, dim1=-2, dim2=-1))
136 |
137 | # calculate weighted average of mean and covariance
138 | softmax = nn.Softmax(0)
139 | mean_weight = softmax(self.sw_mean_weight)
140 | if not self.tie_weight:
141 | var_weight = softmax(self.sw_var_weight)
142 | else:
143 | var_weight = mean_weight
144 |
145 | # BW + IW
146 | if self.sw_type == 2:
147 | # (N x g) x c x 1
148 | mean = mean_weight[0] * mean_bn + mean_weight[1] * mean_in
149 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
150 | self.eps * eye
151 | # BW + IW + LN
152 | elif self.sw_type == 3:
153 | mean = mean_weight[0] * mean_bn + \
154 | mean_weight[1] * mean_in + mean_weight[2] * mean_ln
155 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
156 | var_weight[2] * var_ln + self.eps * eye
157 | # BW + IW + BN + IN + LN
158 | elif self.sw_type == 5:
159 | mean = (mean_weight[0] + mean_weight[2]) * mean_bn + \
160 | (mean_weight[1] + mean_weight[3]) * mean_in + \
161 | mean_weight[4] * mean_ln
162 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
163 | var_weight[0] * var_bn + var_weight[1] * var_in + \
164 | var_weight[4] * var_ln + self.eps * eye
165 |
166 | # perform whitening using Newton's iteration
167 | Ng, c, _ = cov.size()
168 | P = torch.eye(c).to(cov).expand(Ng, c, c)
169 | # reciprocal of trace of covariance
170 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_()
171 | cov_N = cov * rTr
172 | for k in range(self.T):
173 |             P = torch.baddbmm(P, torch.matrix_power(P, 3), cov_N, beta=1.5, alpha=-0.5)
174 | # whiten matrix: the matrix inverse of covariance, i.e., cov^{-1/2}
175 | wm = P.mul_(rTr.sqrt())
176 |
177 | x_hat = torch.bmm(wm, in_data - mean)
178 | x_hat = x_hat.view(N, C, H, W)
179 | if self.affine:
180 | x_hat = x_hat * self.weight.view(1, self.num_features, 1, 1) + \
181 | self.bias.view(1, self.num_features, 1, 1)
182 |
183 | return x_hat
184 |
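SwitchWhiten2d can be dropped in wherever a 2-D normalization layer is expected; the sketch below runs on CPU with illustrative sizes (num_features must be divisible by num_pergroup).

import torch
from network.switchwhiten import SwitchWhiten2d

sw = SwitchWhiten2d(num_features=64, num_pergroup=16, sw_type=2, T=5)
x = torch.randn(4, 64, 32, 32)
y = sw(x)        # whitened and, with affine=True, re-scaled activations
print(y.shape)   # torch.Size([4, 64, 32, 32])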
--------------------------------------------------------------------------------
/network/sync_switchwhiten.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.nn as nn
4 | from torch.autograd import Function
5 | from torch.nn.modules.module import Module
6 | from torch.nn.parameter import Parameter
7 |
8 |
9 | class SyncMeanCov(Function):
10 |
11 | @staticmethod
12 | def forward(ctx, in_data, running_mean, running_cov, momentum, training):
13 | g, c, NHW = in_data.size()
14 | ctx.g = g
15 | ctx.c = c
16 | ctx.NHW = NHW
17 | ctx.training = training
18 |
19 | if training:
20 | mean_bn = in_data.mean(-1, keepdim=True) # g x c x 1
21 | dist.all_reduce(mean_bn)
22 | mean_bn /= dist.get_world_size()
23 | in_data_bn = in_data - mean_bn
24 | cov_bn = torch.bmm(in_data_bn, in_data_bn.transpose(1, 2)).div(NHW)
25 | dist.all_reduce(cov_bn)
26 | cov_bn /= dist.get_world_size()
27 |
28 | running_mean.mul_(momentum)
29 | running_mean.add_((1 - momentum) * mean_bn.data)
30 | running_cov.mul_(momentum)
31 | running_cov.add_((1 - momentum) * cov_bn.data)
32 | else:
33 | mean_bn = torch.autograd.Variable(running_mean)
34 | cov_bn = torch.autograd.Variable(running_cov)
35 |
36 | ctx.save_for_backward(in_data.data, mean_bn.data)
37 | return mean_bn, cov_bn
38 |
39 | @staticmethod
40 | def backward(ctx, grad_mean_out, grad_cov_out):
41 | in_data, mean_bn = ctx.saved_tensors
42 |
43 | if ctx.training:
44 | dist.all_reduce(grad_mean_out)
45 | dist.all_reduce(grad_cov_out)
46 | world_size = dist.get_world_size()
47 | else:
48 | world_size = 1
49 |
50 | grad_cov_out = (grad_cov_out + grad_cov_out.transpose(1, 2)) / 2
51 | grad_cov_in = 2 * torch.bmm(grad_cov_out, (in_data - mean_bn)) \
52 | / (ctx.NHW*world_size) # g x c x (N x H x W)
53 |
54 | grad_mean_in = grad_mean_out / ctx.NHW / world_size
55 | inDiff = grad_mean_in + grad_cov_in
56 | return inDiff, None, None, None, None
57 |
58 |
59 | class SyncSwitchWhiten2d(Module):
60 |     """Synchronized Switchable Whitening.
61 |
62 | Args:
63 | num_features (int): Number of channels.
64 | num_pergroup (int): Number of channels for each whitening group.
65 | sw_type (int): Switchable whitening type, from {2, 3, 4, 5}.
66 | sw_type = 2: BW + IW
67 | sw_type = 3: BW + IW + LN
68 | sw_type = 5: BW + IW + BN + IN + LN
69 | T (int): Number of iterations for iterative whitening.
70 | tie_weight (bool): Use the same importance weight for mean and
71 | covariance or not.
72 | """
73 |
74 | def __init__(self,
75 | num_features,
76 | num_pergroup=16,
77 | sw_type=2,
78 | T=5,
79 | tie_weight=False,
80 | eps=1e-5,
81 | momentum=0.99,
82 | affine=True):
83 | super(SyncSwitchWhiten2d, self).__init__()
84 | if sw_type not in [2, 3, 4, 5]:
85 | raise ValueError('sw_type should be in [2, 3, 4, 5], '
86 | 'but got {}'.format(sw_type))
87 | assert num_features % num_pergroup == 0
88 | self.num_features = num_features
89 | self.num_pergroup = num_pergroup
90 | self.num_groups = num_features // num_pergroup
91 | self.sw_type = sw_type
92 | self.T = T
93 | self.tie_weight = tie_weight
94 | self.eps = eps
95 | self.momentum = momentum
96 | self.tie_weight = tie_weight
97 | self.affine = affine
98 | num_components = sw_type
99 |
100 | self.sw_mean_weight = Parameter(torch.ones(num_components))
101 | if not self.tie_weight:
102 | self.sw_var_weight = Parameter(torch.ones(num_components))
103 | else:
104 | self.register_parameter('sw_var_weight', None)
105 |
106 | if self.affine:
107 | self.weight = Parameter(torch.ones(num_features))
108 | self.bias = Parameter(torch.zeros(num_features))
109 | else:
110 | self.register_parameter('weight', None)
111 | self.register_parameter('bias', None)
112 |
113 | self.register_buffer('running_mean',
114 | torch.zeros(self.num_groups, num_pergroup, 1))
115 | self.register_buffer(
116 | 'running_cov',
117 | torch.eye(num_pergroup).unsqueeze(0).repeat(self.num_groups, 1, 1))
118 |
119 | self.reset_parameters()
120 |
121 | def reset_parameters(self):
122 | self.running_mean.zero_()
123 | self.running_cov.zero_()
124 | nn.init.ones_(self.sw_mean_weight)
125 | if not self.tie_weight:
126 | nn.init.ones_(self.sw_var_weight)
127 | if self.affine:
128 | nn.init.ones_(self.weight)
129 | nn.init.zeros_(self.bias)
130 |
131 | def __repr__(self):
132 | return ('{name}({num_features}, num_pergroup={num_pergroup}, '
133 | 'sw_type={sw_type}, T={T}, tie_weight={tie_weight}, '
134 | 'eps={eps}, momentum={momentum}, affine={affine})'.format(
135 | name=self.__class__.__name__, **self.__dict__))
136 |
137 | def forward(self, x):
138 | N, C, H, W = x.size()
139 | c, g = self.num_pergroup, self.num_groups
140 |
141 | in_data_t = x.transpose(0, 1).contiguous()
142 | # g x c x (N x H x W)
143 | in_data_t = in_data_t.view(g, c, -1)
144 | # calculate batch mean and covariance
145 | mean_bn, cov_bn = SyncMeanCov.apply(in_data_t, self.running_mean,
146 | self.running_cov, self.momentum,
147 | self.training)
148 |
149 | mean_bn = mean_bn.view(1, g, c, 1).expand(N, g, c, 1).contiguous()
150 | mean_bn = mean_bn.view(N * g, c, 1)
151 | cov_bn = cov_bn.view(1, g, c, c).expand(N, g, c, c).contiguous()
152 | cov_bn = cov_bn.view(N * g, c, c)
153 |
154 | # (N x g) x c x (H x W)
155 | in_data = x.view(N * g, c, -1)
156 |
157 | eye = in_data.data.new().resize_(c, c)
158 | eye = torch.nn.init.eye_(eye).view(1, c, c).expand(N * g, c, c)
159 |
160 | # calculate other statistics
161 | # (N x g) x c x 1
162 | mean_in = in_data.mean(-1, keepdim=True)
163 | x_in = in_data - mean_in
164 | # (N x g) x c x c
165 | cov_in = torch.bmm(x_in, torch.transpose(x_in, 1, 2)).div(H * W)
166 | if self.sw_type in [3, 5]:
167 | x = x.view(N, -1)
168 | mean_ln = x.mean(-1, keepdim=True).view(N, 1, 1, 1)
169 | mean_ln = mean_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1)
170 | var_ln = x.var(-1, keepdim=True).view(N, 1, 1, 1)
171 | var_ln = var_ln.expand(N, g, 1, 1).contiguous().view(N * g, 1, 1)
172 | var_ln = var_ln * eye
173 | if self.sw_type == 5:
174 | var_bn = torch.diag_embed(torch.diagonal(cov_bn, dim1=-2, dim2=-1))
175 | var_in = torch.diag_embed(torch.diagonal(cov_in, dim1=-2, dim2=-1))
176 |
177 | # calculate weighted average of mean and covariance
178 | softmax = nn.Softmax(0)
179 | mean_weight = softmax(self.sw_mean_weight)
180 | if not self.tie_weight:
181 | var_weight = softmax(self.sw_var_weight)
182 | else:
183 | var_weight = mean_weight
184 |
185 | # BW + IW
186 | if self.sw_type == 2:
187 | # (N x g) x c x 1
188 | mean = mean_weight[0] * mean_bn + mean_weight[1] * mean_in
189 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
190 | self.eps * eye
191 | # BW + IW + LN
192 | elif self.sw_type == 3:
193 | mean = mean_weight[0] * mean_bn + \
194 | mean_weight[1] * mean_in + mean_weight[2] * mean_ln
195 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
196 | var_weight[2] * var_ln + self.eps * eye
197 | # BW + IW + BN + IN + LN
198 | elif self.sw_type == 5:
199 | mean = (mean_weight[0] + mean_weight[2]) * mean_bn + \
200 | (mean_weight[1] + mean_weight[3]) * mean_in + \
201 | mean_weight[4] * mean_ln
202 | cov = var_weight[0] * cov_bn + var_weight[1] * cov_in + \
203 | var_weight[0] * var_bn + var_weight[1] * var_in + \
204 | var_weight[4] * var_ln + self.eps * eye
205 |
206 | # perform whitening using Newton's iteration
207 | Ng, c, _ = cov.size()
208 | P = torch.eye(c).to(cov).expand(Ng, c, c)
209 | # reciprocal of trace of covariance
210 | rTr = (cov * P).sum((1, 2), keepdim=True).reciprocal_()
211 | cov_N = cov * rTr
212 | for k in range(self.T):
213 | P = torch.baddbmm(beta=1.5, input=P, alpha=-0.5, batch1=torch.matrix_power(P, 3), batch2=cov_N)
214 | # whiten matrix: the matrix inverse of covariance, i.e., cov^{-1/2}
215 | wm = P.mul_(rTr.sqrt())
216 |
217 | x_hat = torch.bmm(wm, in_data - mean)
218 | x_hat = x_hat.view(N, C, H, W)
219 | if self.affine:
220 | x_hat = x_hat * self.weight.view(1, self.num_features, 1, 1) + \
221 | self.bias.view(1, self.num_features, 1, 1)
222 |
223 | return x_hat
224 |
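SyncSwitchWhiten2d takes the same constructor arguments but all-reduces the batch mean and covariance across ranks, so it only makes sense inside an initialized torch.distributed process group (one process per GPU, e.g. launched with torch.distributed.launch as in the scripts below). A sketch of the intended setting:

import torch
import torch.distributed as dist
from network.sync_switchwhiten import SyncSwitchWhiten2d

# Assumes the launcher has already initialized the default process group.
if dist.is_available() and dist.is_initialized():
    sw = SyncSwitchWhiten2d(num_features=64, num_pergroup=16, sw_type=2, T=5).cuda()
    y = sw(torch.randn(4, 64, 32, 32).cuda())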
--------------------------------------------------------------------------------
/network/wider_resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/mapillary/inplace_abn/
4 | #
5 | # BSD 3-Clause License
6 | #
7 | # Copyright (c) 2017, mapillary
8 | # All rights reserved.
9 | #
10 | # Redistribution and use in source and binary forms, with or without
11 | # modification, are permitted provided that the following conditions are met:
12 | #
13 | # * Redistributions of source code must retain the above copyright notice, this
14 | # list of conditions and the following disclaimer.
15 | #
16 | # * Redistributions in binary form must reproduce the above copyright notice,
17 | # this list of conditions and the following disclaimer in the documentation
18 | # and/or other materials provided with the distribution.
19 | #
20 | # * Neither the name of the copyright holder nor the names of its
21 | # contributors may be used to endorse or promote products derived from
22 | # this software without specific prior written permission.
23 | #
24 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
25 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
26 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
27 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
28 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
29 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
30 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
31 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
32 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
33 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
34 | """
35 | import logging
36 | import sys
37 | from collections import OrderedDict
38 | from functools import partial
39 | import torch.nn as nn
40 | import torch
41 | import network.mynn as mynn
42 |
43 | def bnrelu(channels):
44 | """
45 |     Single layer of BN followed by ReLU
46 | """
47 | return nn.Sequential(mynn.Norm2d(channels),
48 | nn.ReLU(inplace=True))
49 |
50 | class GlobalAvgPool2d(nn.Module):
51 | """
52 | Global average pooling over the input's spatial dimensions
53 | """
54 |
55 | def __init__(self):
56 | super(GlobalAvgPool2d, self).__init__()
57 | logging.info("Global Average Pooling Initialized")
58 |
59 | def forward(self, inputs):
60 | in_size = inputs.size()
61 | return inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
62 |
63 |
64 | class IdentityResidualBlock(nn.Module):
65 | """
66 | Identity Residual Block for WideResnet
67 | """
68 | def __init__(self,
69 | in_channels,
70 | channels,
71 | stride=1,
72 | dilation=1,
73 | groups=1,
74 | norm_act=bnrelu,
75 | dropout=None,
76 | dist_bn=False
77 | ):
78 | """Configurable identity-mapping residual block
79 |
80 | Parameters
81 | ----------
82 | in_channels : int
83 | Number of input channels.
84 | channels : list of int
85 | Number of channels in the internal feature maps.
86 | Can either have two or three elements: if three construct
87 | a residual block with two `3 x 3` convolutions,
88 | otherwise construct a bottleneck block with `1 x 1`, then
89 | `3 x 3` then `1 x 1` convolutions.
90 | stride : int
91 | Stride of the first `3 x 3` convolution
92 | dilation : int
93 | Dilation to apply to the `3 x 3` convolutions.
94 | groups : int
95 | Number of convolution groups.
96 | This is used to create ResNeXt-style blocks and is only compatible with
97 | bottleneck blocks.
98 | norm_act : callable
99 | Function to create normalization / activation Module.
100 | dropout: callable
101 | Function to create Dropout Module.
102 | dist_bn: Boolean
103 | A variable to enable or disable use of distributed BN
104 | """
105 | super(IdentityResidualBlock, self).__init__()
106 | self.dist_bn = dist_bn
107 |
108 | # Check if we are using distributed BN and use the nn from encoding.nn
109 | # library rather than using standard pytorch.nn
110 |
111 |
112 | # Check parameters for inconsistencies
113 | if len(channels) != 2 and len(channels) != 3:
114 | raise ValueError("channels must contain either two or three values")
115 | if len(channels) == 2 and groups != 1:
116 | raise ValueError("groups > 1 are only valid if len(channels) == 3")
117 |
118 | is_bottleneck = len(channels) == 3
119 | need_proj_conv = stride != 1 or in_channels != channels[-1]
120 |
121 | self.bn1 = norm_act(in_channels)
122 | if not is_bottleneck:
123 | layers = [
124 | ("conv1", nn.Conv2d(in_channels,
125 | channels[0],
126 | 3,
127 | stride=stride,
128 | padding=dilation,
129 | bias=False,
130 | dilation=dilation)),
131 | ("bn2", norm_act(channels[0])),
132 | ("conv2", nn.Conv2d(channels[0], channels[1],
133 | 3,
134 | stride=1,
135 | padding=dilation,
136 | bias=False,
137 | dilation=dilation))
138 | ]
139 | if dropout is not None:
140 | layers = layers[0:2] + [("dropout", dropout())] + layers[2:]
141 | else:
142 | layers = [
143 | ("conv1",
144 | nn.Conv2d(in_channels,
145 | channels[0],
146 | 1,
147 | stride=stride,
148 | padding=0,
149 | bias=False)),
150 | ("bn2", norm_act(channels[0])),
151 | ("conv2", nn.Conv2d(channels[0],
152 | channels[1],
153 | 3, stride=1,
154 | padding=dilation, bias=False,
155 | groups=groups,
156 | dilation=dilation)),
157 | ("bn3", norm_act(channels[1])),
158 | ("conv3", nn.Conv2d(channels[1], channels[2],
159 | 1, stride=1, padding=0, bias=False))
160 | ]
161 | if dropout is not None:
162 | layers = layers[0:4] + [("dropout", dropout())] + layers[4:]
163 | self.convs = nn.Sequential(OrderedDict(layers))
164 |
165 | if need_proj_conv:
166 | self.proj_conv = nn.Conv2d(
167 | in_channels, channels[-1], 1, stride=stride, padding=0, bias=False)
168 |
169 | def forward(self, x):
170 | """
171 | This is the standard forward function for non-distributed batch norm
172 | """
173 | if hasattr(self, "proj_conv"):
174 | bn1 = self.bn1(x)
175 | shortcut = self.proj_conv(bn1)
176 | else:
177 | shortcut = x.clone()
178 | bn1 = self.bn1(x)
179 |
180 | out = self.convs(bn1)
181 | out.add_(shortcut)
182 | return out
183 |
184 |
185 |
186 |
187 | class WiderResNet(nn.Module):
188 | """
189 | WideResnet Global Module for Initialization
190 | """
191 | def __init__(self,
192 | structure,
193 | norm_act=bnrelu,
194 | classes=0
195 | ):
196 | """Wider ResNet with pre-activation (identity mapping) blocks
197 |
198 | Parameters
199 | ----------
200 | structure : list of int
201 | Number of residual blocks in each of the six modules of the network.
202 | norm_act : callable
203 | Function to create normalization / activation Module.
204 | classes : int
205 | If not `0` also include global average pooling and \
206 | a fully-connected layer with `classes` outputs at the end
207 | of the network.
208 | """
209 | super(WiderResNet, self).__init__()
210 | self.structure = structure
211 |
212 | if len(structure) != 6:
213 | raise ValueError("Expected a structure with six values")
214 |
215 | # Initial layers
216 | self.mod1 = nn.Sequential(OrderedDict([
217 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
218 | ]))
219 |
220 | # Groups of residual blocks
221 | in_channels = 64
222 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024),
223 | (512, 1024, 2048), (1024, 2048, 4096)]
224 | for mod_id, num in enumerate(structure):
225 | # Create blocks for module
226 | blocks = []
227 | for block_id in range(num):
228 | blocks.append((
229 | "block%d" % (block_id + 1),
230 | IdentityResidualBlock(in_channels, channels[mod_id],
231 | norm_act=norm_act)
232 | ))
233 |
234 | # Update channels and p_keep
235 | in_channels = channels[mod_id][-1]
236 |
237 | # Create module
238 | if mod_id <= 4:
239 | self.add_module("pool%d" %
240 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
241 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
242 |
243 | # Pooling and predictor
244 | self.bn_out = norm_act(in_channels)
245 | if classes != 0:
246 | self.classifier = nn.Sequential(OrderedDict([
247 | ("avg_pool", GlobalAvgPool2d()),
248 | ("fc", nn.Linear(in_channels, classes))
249 | ]))
250 |
251 | def forward(self, img):
252 | out = self.mod1(img)
253 | out = self.mod2(self.pool2(out))
254 | out = self.mod3(self.pool3(out))
255 | out = self.mod4(self.pool4(out))
256 | out = self.mod5(self.pool5(out))
257 | out = self.mod6(self.pool6(out))
258 | out = self.mod7(out)
259 | out = self.bn_out(out)
260 |
261 | if hasattr(self, "classifier"):
262 | out = self.classifier(out)
263 |
264 | return out
265 |
266 |
267 | class WiderResNetA2(nn.Module):
268 | """
269 | Wider ResNet with pre-activation (identity mapping) blocks
270 |
271 | This variant uses down-sampling by max-pooling in the first two blocks and
272 | by strided convolution in the others.
273 |
274 | Parameters
275 | ----------
276 | structure : list of int
277 | Number of residual blocks in each of the six modules of the network.
278 | norm_act : callable
279 | Function to create normalization / activation Module.
280 | classes : int
281 | If not `0` also include global average pooling and a fully-connected layer
282 | with `classes` outputs at the end
283 | of the network.
284 | dilation : bool
285 | If `True` apply dilation to the last three modules and change the
286 | down-sampling factor from 32 to 8.
287 | """
288 | def __init__(self,
289 | structure,
290 | norm_act=bnrelu,
291 | classes=0,
292 | dilation=False,
293 | dist_bn=False
294 | ):
295 | super(WiderResNetA2, self).__init__()
296 | self.dist_bn = dist_bn
297 |
298 |         # If using distributed batch norm, use encoding.nn as opposed to torch.nn
299 |
300 |
301 | nn.Dropout = nn.Dropout2d
302 | norm_act = bnrelu
303 | self.structure = structure
304 | self.dilation = dilation
305 |
306 | if len(structure) != 6:
307 | raise ValueError("Expected a structure with six values")
308 |
309 | # Initial layers
310 | self.mod1 = torch.nn.Sequential(OrderedDict([
311 | ("conv1", nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False))
312 | ]))
313 |
314 | # Groups of residual blocks
315 | in_channels = 64
316 | channels = [(128, 128), (256, 256), (512, 512), (512, 1024), (512, 1024, 2048),
317 | (1024, 2048, 4096)]
318 | for mod_id, num in enumerate(structure):
319 | # Create blocks for module
320 | blocks = []
321 | for block_id in range(num):
322 | if not dilation:
323 | dil = 1
324 | stride = 2 if block_id == 0 and 2 <= mod_id <= 4 else 1
325 | else:
326 | if mod_id == 3:
327 | dil = 2
328 | elif mod_id > 3:
329 | dil = 4
330 | else:
331 | dil = 1
332 | stride = 2 if block_id == 0 and mod_id == 2 else 1
333 |
334 | if mod_id == 4:
335 | drop = partial(nn.Dropout, p=0.3)
336 | elif mod_id == 5:
337 | drop = partial(nn.Dropout, p=0.5)
338 | else:
339 | drop = None
340 |
341 | blocks.append((
342 | "block%d" % (block_id + 1),
343 | IdentityResidualBlock(in_channels,
344 | channels[mod_id], norm_act=norm_act,
345 | stride=stride, dilation=dil,
346 | dropout=drop, dist_bn=self.dist_bn)
347 | ))
348 |
349 | # Update channels and p_keep
350 | in_channels = channels[mod_id][-1]
351 |
352 | # Create module
353 | if mod_id < 2:
354 | self.add_module("pool%d" %
355 | (mod_id + 2), nn.MaxPool2d(3, stride=2, padding=1))
356 | self.add_module("mod%d" % (mod_id + 2), nn.Sequential(OrderedDict(blocks)))
357 |
358 | # Pooling and predictor
359 | self.bn_out = norm_act(in_channels)
360 | if classes != 0:
361 | self.classifier = nn.Sequential(OrderedDict([
362 | ("avg_pool", GlobalAvgPool2d()),
363 | ("fc", nn.Linear(in_channels, classes))
364 | ]))
365 |
366 | def forward(self, img):
367 | out = self.mod1(img)
368 | out = self.mod2(self.pool2(out))
369 | out = self.mod3(self.pool3(out))
370 | out = self.mod4(out)
371 | out = self.mod5(out)
372 | out = self.mod6(out)
373 | out = self.mod7(out)
374 | out = self.bn_out(out)
375 |
376 | if hasattr(self, "classifier"):
377 | return self.classifier(out)
378 | return out
379 |
380 |
381 | _NETS = {
382 | "16": {"structure": [1, 1, 1, 1, 1, 1]},
383 | "20": {"structure": [1, 1, 1, 3, 1, 1]},
384 | "38": {"structure": [3, 3, 6, 3, 1, 1]},
385 | }
386 |
387 | __all__ = []
388 | for name, params in _NETS.items():
389 | net_name = "wider_resnet" + name
390 | setattr(sys.modules[__name__], net_name, partial(WiderResNet, **params))
391 | __all__.append(net_name)
392 | for name, params in _NETS.items():
393 | net_name = "wider_resnet" + name + "_a2"
394 | setattr(sys.modules[__name__], net_name, partial(WiderResNetA2, **params))
395 | __all__.append(net_name)
396 |
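The two loops above register factory functions such as wider_resnet38_a2 on this module, with the structure list pre-bound via partial. The sketch below builds the lightest registered A2 variant; note that bnrelu resolves mynn.Norm2d, so cfg.MODEL.BNFUNC must already point at a norm layer. Assigning nn.BatchNorm2d here is an assumption about how the config is prepared, made only to keep the sketch self-contained.

import torch
import torch.nn as nn
from config import cfg
from network.wider_resnet import wider_resnet16_a2

cfg.MODEL.BNFUNC = nn.BatchNorm2d  # assumption: Norm2d() looks this attribute up on cfg
net = wider_resnet16_a2(classes=0, dilation=True)

out = net(torch.randn(1, 3, 64, 64))
print(out.shape)  # (1, 4096, 8, 8): output stride 8 when dilation=True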
--------------------------------------------------------------------------------
/optimizer.py:
--------------------------------------------------------------------------------
1 | """
2 | Pytorch Optimizer and Scheduler Related Task
3 | """
4 | import math
5 | import logging
6 | import torch
7 | from torch import optim
8 | from config import cfg
9 |
10 |
11 | def get_optimizer(args, net):
12 | """
13 | Decide Optimizer (Adam or SGD)
14 | """
15 | base_params = []
16 |
17 | for name, param in net.named_parameters():
18 | base_params.append(param)
19 |
20 | if args.sgd:
21 | optimizer = optim.SGD(base_params,
22 | lr=args.lr,
23 | weight_decay=5e-4, #args.weight_decay,
24 | momentum=args.momentum,
25 | nesterov=False)
26 | else:
27 | raise ValueError('Not a valid optimizer')
28 |
29 | if args.lr_schedule == 'scl-poly':
30 | if cfg.REDUCE_BORDER_ITER == -1:
31 | raise ValueError('ERROR Cannot Do Scale Poly')
32 |
33 | rescale_thresh = cfg.REDUCE_BORDER_ITER
34 | scale_value = args.rescale
35 | lambda1 = lambda iteration: \
36 | math.pow(1 - iteration / args.max_iter,
37 | args.poly_exp) if iteration < rescale_thresh else scale_value * math.pow(
38 | 1 - (iteration - rescale_thresh) / (args.max_iter - rescale_thresh),
39 | args.repoly)
40 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
41 | elif args.lr_schedule == 'poly':
42 | lambda1 = lambda iteration: math.pow(1 - iteration / args.max_iter, args.poly_exp)
43 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1)
44 | else:
45 | raise ValueError('unknown lr schedule {}'.format(args.lr_schedule))
46 |
47 | return optimizer, scheduler
48 |
49 |
50 | def load_weights(net, optimizer, scheduler, snapshot_file, restore_optimizer_bool=False):
51 | """
52 | Load weights from snapshot file
53 | """
54 | logging.info("Loading weights from model %s", snapshot_file)
55 | net, optimizer, scheduler, epoch, mean_iu = restore_snapshot(net, optimizer, scheduler, snapshot_file,
56 | restore_optimizer_bool)
57 | return epoch, mean_iu
58 |
59 |
60 | def restore_snapshot(net, optimizer, scheduler, snapshot, restore_optimizer_bool):
61 | """
62 |     Restore weights and optimizer (if needed) for resuming a job.
63 | """
64 | checkpoint = torch.load(snapshot, map_location=torch.device('cpu'))
65 |     logging.info("Checkpoint Load Complete")
66 | if optimizer is not None and 'optimizer' in checkpoint and restore_optimizer_bool:
67 | optimizer.load_state_dict(checkpoint['optimizer'])
68 | if scheduler is not None and 'scheduler' in checkpoint and restore_optimizer_bool:
69 | scheduler.load_state_dict(checkpoint['scheduler'])
70 |
71 | if 'state_dict' in checkpoint:
72 | net = forgiving_state_restore(net, checkpoint['state_dict'])
73 | else:
74 | net = forgiving_state_restore(net, checkpoint)
75 |
76 | return net, optimizer, scheduler, checkpoint['epoch'], checkpoint['mean_iu']
77 |
78 |
79 | def forgiving_state_restore(net, loaded_dict):
80 | """
81 | Handle partial loading when some tensors don't match up in size.
82 | Because we want to use models that were trained off a different
83 | number of classes.
84 | """
85 | net_state_dict = net.state_dict()
86 | new_loaded_dict = {}
87 | for k in net_state_dict:
88 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
89 | new_loaded_dict[k] = loaded_dict[k]
90 | else:
91 | print("Skipped loading parameter", k)
92 | # logging.info("Skipped loading parameter %s", k)
93 | net_state_dict.update(new_loaded_dict)
94 | net.load_state_dict(net_state_dict)
95 | return net
96 |
97 | def forgiving_state_copy(target_net, source_net):
98 | """
99 | Handle partial loading when some tensors don't match up in size.
100 | Because we want to use models that were trained off a different
101 | number of classes.
102 | """
103 | net_state_dict = target_net.state_dict()
104 | loaded_dict = source_net.state_dict()
105 | new_loaded_dict = {}
106 | for k in net_state_dict:
107 | if k in loaded_dict and net_state_dict[k].size() == loaded_dict[k].size():
108 | new_loaded_dict[k] = loaded_dict[k]
109 | print("Matched", k)
110 | else:
111 | print("Skipped loading parameter ", k)
112 | # logging.info("Skipped loading parameter %s", k)
113 | net_state_dict.update(new_loaded_dict)
114 | target_net.load_state_dict(net_state_dict)
115 | return target_net
116 |
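A minimal sketch of the 'poly' schedule wiring in get_optimizer. The namespace below is a placeholder: lr, poly_exp and max_iter follow blindnet_train_r50os16_gtav.sh, momentum is illustrative, and only the attributes get_optimizer reads are provided.

import argparse
import torch.nn as nn
from optimizer import get_optimizer

args = argparse.Namespace(sgd=True, lr=0.01, momentum=0.9,
                          lr_schedule='poly', poly_exp=0.9, max_iter=40000)
net = nn.Conv2d(3, 19, 1)  # stand-in for the segmentation network

optimizer, scheduler = get_optimizer(args, net)
for it in range(3):
    optimizer.step()
    scheduler.step()  # lr = args.lr * (1 - iteration / max_iter) ** poly_exp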
--------------------------------------------------------------------------------
/scripts/blindnet_infer_r50os16.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 | echo "Saving Results :" ${2}
4 |
5 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=29401 infer.py \
6 | --val_dataset cityscapes mapillary bdd-100k \
7 | --arch network.deepv3.DeepR50V3PlusD \
8 | --wt_layer 0 0 1 1 1 0 0 \
9 | --mod blindnet \
10 | --results ${2} \
11 | --date 0101 \
12 | --exp blindnet_r50os16_gtav \
13 | --snapshot ${1}
14 |
--------------------------------------------------------------------------------
/scripts/blindnet_train_r50os16_gtav.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | export NCCL_DEBUG=INFO
3 |
4 | python -m torch.distributed.launch --nproc_per_node=2 train.py \
5 | --dataset gtav_jitter \
6 | --covstat_val_dataset gtav \
7 | --val_dataset cityscapes bdd100k mapillary \
8 | --arch network.deepv3.DeepR50V3PlusD \
9 | --city_mode 'train' \
10 | --lr_schedule poly \
11 | --lr 0.01 \
12 | --poly_exp 0.9 \
13 | --max_cu_epoch 10000 \
14 | --class_uniform_pct 0.5 \
15 | --class_uniform_tile 1024 \
16 | --crop_size 768 \
17 | --scale_min 0.5 \
18 | --scale_max 2.0 \
19 | --rrotate 0 \
20 | --max_iter 40000 \
21 | --bs_mult 4 \
22 | --gblur \
23 | --color_aug 0.5 \
24 | --relax_denom 0.0 \
25 | --wt_layer 0 0 1 1 1 0 0 \
26 | --use_ca \
27 | --use_cwcl \
28 | --nce_T 0.07 \
29 | --contrast_max_classes 15 \
30 | --contrast_max_view 50 \
31 | --jit_only \
32 | --use_sdcl \
33 | --w1 0.2 \
34 | --w2 0.2 \
35 | --w3 0.3 \
36 | --w4 0.3 \
37 | --num_patch 20 \
38 | --date 0326 \
39 | --exp BlindNet_r50os16_gtav \
40 | --ckpt ./logs/ \
41 | --tb_path ./logs/
42 |
--------------------------------------------------------------------------------
/scripts/blindnet_valid_r50os16_gtav.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | echo "Running inference on" ${1}
3 |
4 | python -m torch.distributed.launch --nproc_per_node=1 --master_port=29400 valid.py \
5 | --val_dataset cityscapes bdd100k mapillary gtav \
6 | --arch network.deepv3.DeepR50V3PlusD \
7 | --wt_layer 0 0 1 1 1 0 0 \
8 | --date 0101 \
9 | --exp r50os16_gtav_blindnet \
10 | --snapshot ${1}
11 |
--------------------------------------------------------------------------------
/transforms/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__init__.py
--------------------------------------------------------------------------------
/transforms/__pycache__/joint_transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/joint_transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/transforms/__pycache__/joint_transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/joint_transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/transforms/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/transforms/__pycache__/transforms.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/transforms/__pycache__/transforms.cpython-37.pyc
--------------------------------------------------------------------------------
/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code borrowed from:
3 | # https://github.com/zijundeng/pytorch-semantic-segmentation/blob/master/utils/transforms.py
4 | #
5 | #
6 | # MIT License
7 | #
8 | # Copyright (c) 2017 ZijunDeng
9 | #
10 | # Permission is hereby granted, free of charge, to any person obtaining a copy
11 | # of this software and associated documentation files (the "Software"), to deal
12 | # in the Software without restriction, including without limitation the rights
13 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14 | # copies of the Software, and to permit persons to whom the Software is
15 | # furnished to do so, subject to the following conditions:
16 | #
17 | # The above copyright notice and this permission notice shall be included in all
18 | # copies or substantial portions of the Software.
19 | #
20 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26 | # SOFTWARE.
27 |
28 | """
29 |
30 | """
31 | Standard Transform
32 | """
33 |
34 | import random
35 | import numpy as np
36 | from skimage.filters import gaussian
37 | from skimage.restoration import denoise_bilateral
38 | import torch
39 | from PIL import Image, ImageEnhance
40 | import torchvision.transforms as torch_tr
41 | from config import cfg
42 | from scipy.ndimage.interpolation import shift
43 |
44 | from skimage.segmentation import find_boundaries
45 | from skimage.util import random_noise
46 |
47 | try:
48 | import accimage
49 | except ImportError:
50 | accimage = None
51 |
52 |
53 | class RandomVerticalFlip(object):
54 | def __call__(self, img):
55 | if random.random() < 0.5:
56 | return img.transpose(Image.FLIP_TOP_BOTTOM)
57 | return img
58 |
59 |
60 | class DeNormalize(object):
61 | def __init__(self, mean, std):
62 | self.mean = mean
63 | self.std = std
64 |
65 | def __call__(self, tensor):
66 | for t, m, s in zip(tensor, self.mean, self.std):
67 | t.mul_(s).add_(m)
68 | return tensor
69 |
70 |
71 | class MaskToTensor(object):
72 | def __call__(self, img):
73 | return torch.from_numpy(np.array(img, dtype=np.int32)).long()
74 |
75 | class RelaxedBoundaryLossToTensor(object):
76 | """
77 | Boundary Relaxation
78 | """
79 | def __init__(self,ignore_id, num_classes):
80 | self.ignore_id=ignore_id
81 | self.num_classes= num_classes
82 |
83 |
84 | def new_one_hot_converter(self,a):
85 | ncols = self.num_classes+1
86 | out = np.zeros( (a.size,ncols), dtype=np.uint8)
87 | out[np.arange(a.size),a.ravel()] = 1
88 | out.shape = a.shape + (ncols,)
89 | return out
90 |
91 | def __call__(self,img):
92 |
93 | img_arr = np.array(img)
94 | img_arr[img_arr==self.ignore_id]=self.num_classes
95 |
96 |         if cfg.STRICTBORDERCLASS is not None:
97 | one_hot_orig = self.new_one_hot_converter(img_arr)
98 | mask = np.zeros((img_arr.shape[0],img_arr.shape[1]))
99 | for cls in cfg.STRICTBORDERCLASS:
100 | mask = np.logical_or(mask,(img_arr == cls))
101 | one_hot = 0
102 |
103 | border = cfg.BORDER_WINDOW
104 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
105 | border = border // 2
106 | border_prediction = find_boundaries(img_arr, mode='thick').astype(np.uint8)
107 |
108 | for i in range(-border,border+1):
109 | for j in range(-border, border+1):
110 | shifted= shift(img_arr,(i,j), cval=self.num_classes)
111 | one_hot += self.new_one_hot_converter(shifted)
112 |
113 | one_hot[one_hot>1] = 1
114 |
115 |         if cfg.STRICTBORDERCLASS is not None:
116 | one_hot = np.where(np.expand_dims(mask,2), one_hot_orig, one_hot)
117 |
118 | one_hot = np.moveaxis(one_hot,-1,0)
119 |
120 |
121 | if (cfg.REDUCE_BORDER_ITER !=-1 and cfg.ITER > cfg.REDUCE_BORDER_ITER):
122 | one_hot = np.where(border_prediction,2*one_hot,1*one_hot)
123 | # print(one_hot.shape)
124 | return torch.from_numpy(one_hot).byte()
125 |
126 | class ResizeHeight(object):
127 | def __init__(self, size, interpolation=Image.BILINEAR):
128 | self.target_h = size
129 | self.interpolation = interpolation
130 |
131 | def __call__(self, img):
132 | w, h = img.size
133 | target_w = int(w / h * self.target_h)
134 | return img.resize((target_w, self.target_h), self.interpolation)
135 |
136 |
137 | class FreeScale(object):
138 | def __init__(self, size, interpolation=Image.BILINEAR):
139 | self.size = tuple(reversed(size)) # size: (h, w)
140 | self.interpolation = interpolation
141 |
142 | def __call__(self, img):
143 | return img.resize(self.size, self.interpolation)
144 |
145 |
146 | class FlipChannels(object):
147 | """
148 |     Reverse the channel order (e.g. RGB -> BGR)
149 | """
150 | def __call__(self, img):
151 | img = np.array(img)[:, :, ::-1]
152 | return Image.fromarray(img.astype(np.uint8))
153 |
154 |
155 | class RandomGaussianBlur(object):
156 | """
157 | Apply Gaussian Blur
158 | """
159 | def __call__(self, img):
160 | sigma = 0.15 + random.random() * 1.15
161 | blurred_img = gaussian(np.array(img), sigma=sigma, multichannel=True)
162 | blurred_img *= 255
163 | return Image.fromarray(blurred_img.astype(np.uint8))
164 |
165 |
166 | class RandomGaussianNoise(object):
167 | def __call__(self, img):
168 | noised_img = random_noise(np.array(img), mode='gaussian')
169 | noised_img *= 255
170 | return Image.fromarray(noised_img.astype(np.uint8))
171 |
172 |
173 | class RandomBilateralBlur(object):
174 | """
175 | Apply Bilateral Filtering
176 |
177 | """
178 | def __call__(self, img):
179 | sigma = random.uniform(0.05,0.75)
180 | blurred_img = denoise_bilateral(np.array(img), sigma_spatial=sigma, multichannel=True)
181 | blurred_img *= 255
182 | return Image.fromarray(blurred_img.astype(np.uint8))
183 |
184 | def _is_pil_image(img):
185 | if accimage is not None:
186 | return isinstance(img, (Image.Image, accimage.Image))
187 | else:
188 | return isinstance(img, Image.Image)
189 |
190 |
191 | def adjust_brightness(img, brightness_factor):
192 | """Adjust brightness of an Image.
193 |
194 | Args:
195 | img (PIL Image): PIL Image to be adjusted.
196 | brightness_factor (float): How much to adjust the brightness. Can be
197 | any non negative number. 0 gives a black image, 1 gives the
198 | original image while 2 increases the brightness by a factor of 2.
199 |
200 | Returns:
201 | PIL Image: Brightness adjusted image.
202 | """
203 | if not _is_pil_image(img):
204 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
205 |
206 | enhancer = ImageEnhance.Brightness(img)
207 | img = enhancer.enhance(brightness_factor)
208 | return img
209 |
210 |
211 | def adjust_contrast(img, contrast_factor):
212 | """Adjust contrast of an Image.
213 |
214 | Args:
215 | img (PIL Image): PIL Image to be adjusted.
216 | contrast_factor (float): How much to adjust the contrast. Can be any
217 | non negative number. 0 gives a solid gray image, 1 gives the
218 | original image while 2 increases the contrast by a factor of 2.
219 |
220 | Returns:
221 | PIL Image: Contrast adjusted image.
222 | """
223 | if not _is_pil_image(img):
224 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
225 |
226 | enhancer = ImageEnhance.Contrast(img)
227 | img = enhancer.enhance(contrast_factor)
228 | return img
229 |
230 |
231 | def adjust_saturation(img, saturation_factor):
232 | """Adjust color saturation of an image.
233 |
234 | Args:
235 | img (PIL Image): PIL Image to be adjusted.
236 | saturation_factor (float): How much to adjust the saturation. 0 will
237 | give a black and white image, 1 will give the original image while
238 | 2 will enhance the saturation by a factor of 2.
239 |
240 | Returns:
241 | PIL Image: Saturation adjusted image.
242 | """
243 | if not _is_pil_image(img):
244 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
245 |
246 | enhancer = ImageEnhance.Color(img)
247 | img = enhancer.enhance(saturation_factor)
248 | return img
249 |
250 |
251 | def adjust_hue(img, hue_factor):
252 | """Adjust hue of an image.
253 |
254 | The image hue is adjusted by converting the image to HSV and
255 | cyclically shifting the intensities in the hue channel (H).
256 | The image is then converted back to original image mode.
257 |
258 | `hue_factor` is the amount of shift in H channel and must be in the
259 | interval `[-0.5, 0.5]`.
260 |
261 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
262 |
263 | Args:
264 | img (PIL Image): PIL Image to be adjusted.
265 | hue_factor (float): How much to shift the hue channel. Should be in
266 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
267 | HSV space in positive and negative direction respectively.
268 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
269 | with complementary colors while 0 gives the original image.
270 |
271 | Returns:
272 | PIL Image: Hue adjusted image.
273 | """
274 | if not(-0.5 <= hue_factor <= 0.5):
275 |         raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
276 |
277 | if not _is_pil_image(img):
278 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
279 | input_mode = img.mode
280 | if input_mode in {'L', '1', 'I', 'F'}:
281 | return img
282 |
283 | h, s, v = img.convert('HSV').split()
284 |
285 | np_h = np.array(h, dtype=np.uint8)
286 |     # uint8 addition takes care of rotation across boundaries
287 | with np.errstate(over='ignore'):
288 | np_h += np.uint8(hue_factor * 255)
289 | h = Image.fromarray(np_h, 'L')
290 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
291 | return img
292 |
293 |
294 | class ColorJitter(object):
295 | """Randomly change the brightness, contrast and saturation of an image.
296 |
297 | Args:
298 | brightness (float): How much to jitter brightness. brightness_factor
299 | is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
300 | contrast (float): How much to jitter contrast. contrast_factor
301 | is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
302 | saturation (float): How much to jitter saturation. saturation_factor
303 | is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
304 | hue(float): How much to jitter hue. hue_factor is chosen uniformly from
305 | [-hue, hue]. Should be >=0 and <= 0.5.
306 | """
307 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
308 | self.brightness = brightness
309 | self.contrast = contrast
310 | self.saturation = saturation
311 | self.hue = hue
312 |
313 | @staticmethod
314 | def get_params(brightness, contrast, saturation, hue):
315 | """Get a randomized transform to be applied on image.
316 |
317 | Arguments are same as that of __init__.
318 |
319 | Returns:
320 | Transform which randomly adjusts brightness, contrast and
321 | saturation in a random order.
322 | """
323 | transforms = []
324 | if brightness > 0:
325 | brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
326 | transforms.append(
327 | torch_tr.Lambda(lambda img: adjust_brightness(img, brightness_factor)))
328 |
329 | if contrast > 0:
330 | contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
331 | transforms.append(
332 | torch_tr.Lambda(lambda img: adjust_contrast(img, contrast_factor)))
333 |
334 | if saturation > 0:
335 | saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
336 | transforms.append(
337 | torch_tr.Lambda(lambda img: adjust_saturation(img, saturation_factor)))
338 |
339 | if hue > 0:
340 | hue_factor = np.random.uniform(-hue, hue)
341 | transforms.append(
342 | torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))
343 |
344 | np.random.shuffle(transforms)
345 | transform = torch_tr.Compose(transforms)
346 |
347 | return transform
348 |
349 | def __call__(self, img):
350 | """
351 | Args:
352 | img (PIL Image): Input image.
353 |
354 | Returns:
355 | PIL Image: Color jittered image.
356 | """
357 | transform = self.get_params(self.brightness, self.contrast,
358 | self.saturation, self.hue)
359 | return transform(img)
360 |
--------------------------------------------------------------------------------
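A minimal usage sketch of the transforms above (not part of the repository; it assumes the repo root is on PYTHONPATH and the pinned skimage/scipy/torch dependencies are installed): photometric jitter followed by Gaussian blur on a dummy PIL image.

    import numpy as np
    from PIL import Image
    from transforms.transforms import ColorJitter, RandomGaussianBlur

    # Random 64x64 RGB image standing in for a training sample.
    img = Image.fromarray(np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

    # Brightness/contrast/saturation factors are drawn uniformly around 1 and
    # applied in a random order (see ColorJitter.get_params above).
    jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4)
    out = RandomGaussianBlur()(jitter(img))

    print(out.size, out.mode)   # (64, 64) RGB

--------------------------------------------------------------------------------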
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/root0yang/BlindNet/1359fa6e9d1e9b011c416f43ec265e31cbde7d9a/utils/__init__.py
--------------------------------------------------------------------------------
/utils/attr_dict.py:
--------------------------------------------------------------------------------
1 | """
2 | # Code adapted from:
3 | # https://github.com/facebookresearch/Detectron/blob/master/detectron/utils/collections.py
4 |
5 | Source License
6 | # Copyright (c) 2017-present, Facebook, Inc.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 | ##############################################################################
20 | #
21 | # Based on:
22 | # --------------------------------------------------------
23 | # Fast R-CNN
24 | # Copyright (c) 2015 Microsoft
25 | # Licensed under The MIT License [see LICENSE for details]
26 | # Written by Ross Girshick
27 | # --------------------------------------------------------
28 | """
29 |
30 | class AttrDict(dict):
31 |
32 | IMMUTABLE = '__immutable__'
33 |
34 | def __init__(self, *args, **kwargs):
35 | super(AttrDict, self).__init__(*args, **kwargs)
36 | self.__dict__[AttrDict.IMMUTABLE] = False
37 |
38 | def __getattr__(self, name):
39 | if name in self.__dict__:
40 | return self.__dict__[name]
41 | elif name in self:
42 | return self[name]
43 | else:
44 | raise AttributeError(name)
45 |
46 | def __setattr__(self, name, value):
47 | if not self.__dict__[AttrDict.IMMUTABLE]:
48 | if name in self.__dict__:
49 | self.__dict__[name] = value
50 | else:
51 | self[name] = value
52 | else:
53 | raise AttributeError(
54 | 'Attempted to set "{}" to "{}", but AttrDict is immutable'.
55 | format(name, value)
56 | )
57 |
58 | def immutable(self, is_immutable):
59 | """Set immutability to is_immutable and recursively apply the setting
60 | to all nested AttrDicts.
61 | """
62 | self.__dict__[AttrDict.IMMUTABLE] = is_immutable
63 | # Recursively set immutable state
64 | for v in self.__dict__.values():
65 | if isinstance(v, AttrDict):
66 | v.immutable(is_immutable)
67 | for v in self.values():
68 | if isinstance(v, AttrDict):
69 | v.immutable(is_immutable)
70 |
71 | def is_immutable(self):
72 | return self.__dict__[AttrDict.IMMUTABLE]
73 |
--------------------------------------------------------------------------------
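A minimal sketch (not from the repository) of the attribute-style access and the recursive immutability switch that AttrDict provides; it only needs the file above on the import path.

    from utils.attr_dict import AttrDict

    cfg = AttrDict()
    cfg.MODEL = AttrDict()
    cfg.MODEL.ARCH = 'resnet50'              # attribute writes fall through to dict keys
    assert cfg['MODEL']['ARCH'] == 'resnet50'

    cfg.immutable(True)                      # recursively freezes nested AttrDicts
    try:
        cfg.MODEL.ARCH = 'resnet101'
    except AttributeError as err:
        print(err)                           # Attempted to set "ARCH" ... but AttrDict is immutable

--------------------------------------------------------------------------------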
/utils/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Miscellanous Functions
3 | """
4 |
5 | import sys
6 | import re
7 | import os
8 | import shutil
9 | import torch
10 | from datetime import datetime
11 | import logging
12 | from subprocess import call
13 | import shlex
14 | from tensorboardX import SummaryWriter
15 | import datasets
16 | import numpy as np
17 | import torchvision.transforms as standard_transforms
18 | import torchvision.utils as vutils
19 | from config import cfg
20 | import random
21 |
22 |
23 | # Create unique output dir name based on non-default command line args
24 | def make_exp_name(args, parser):
25 | exp_name = '{}-{}'.format(args.dataset[:4], args.arch[:])
26 | dict_args = vars(args)
27 |
28 | # sort so that we get a consistent directory name
29 | argnames = sorted(dict_args)
30 | ignorelist = ['date', 'exp', 'arch','prev_best_filepath', 'lr_schedule', 'max_cu_epoch', 'max_epoch',
31 | 'strict_bdr_cls', 'world_size', 'tb_path','best_record', 'test_mode', 'ckpt', 'coarse_boost_classes',
32 | 'crop_size', 'dist_url', 'syncbn', 'max_iter', 'color_aug', 'scale_max', 'scale_min', 'bs_mult',
33 | 'class_uniform_pct', 'class_uniform_tile']
34 | # build experiment name with non-default args
35 | for argname in argnames:
36 | if dict_args[argname] != parser.get_default(argname):
37 | if argname in ignorelist:
38 | continue
39 | if argname == 'snapshot':
40 | arg_str = 'PT'
41 | argname = ''
42 | elif argname == 'nosave':
43 | arg_str = ''
44 | argname=''
45 | elif argname == 'freeze_trunk':
46 | argname = ''
47 | arg_str = 'ft'
48 | elif argname == 'syncbn':
49 | argname = ''
50 | arg_str = 'sbn'
51 | elif argname == 'jointwtborder':
52 | argname = ''
53 | arg_str = 'rlx_loss'
54 | elif isinstance(dict_args[argname], bool):
55 | arg_str = 'T' if dict_args[argname] else 'F'
56 | else:
57 | arg_str = str(dict_args[argname])[:7]
58 |             if argname != '':
59 | exp_name += '_{}_{}'.format(str(argname), arg_str)
60 | else:
61 | exp_name += '_{}'.format(arg_str)
62 | # clean special chars out exp_name = re.sub(r'[^A-Za-z0-9_\-]+', '', exp_name)
63 | return exp_name
64 |
65 | def fast_hist(label_pred, label_true, num_classes):
66 | mask = (label_true >= 0) & (label_true < num_classes)
67 | hist = np.bincount(
68 | num_classes * label_true[mask].astype(int) +
69 | label_pred[mask], minlength=num_classes ** 2).reshape(num_classes, num_classes)
70 | return hist
71 |
72 | def per_class_iu(hist):
73 | return np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
74 |
75 | def save_log(prefix, output_dir, date_str, rank=0):
76 | fmt = '%(asctime)s.%(msecs)03d %(message)s'
77 | date_fmt = '%m-%d %H:%M:%S'
78 | filename = os.path.join(output_dir, prefix + '_' + date_str +'_rank_' + str(rank) +'.log')
79 | print("Logging :", filename)
80 | logging.basicConfig(level=logging.INFO, format=fmt, datefmt=date_fmt,
81 | filename=filename, filemode='w')
82 | console = logging.StreamHandler()
83 | console.setLevel(logging.INFO)
84 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt)
85 | console.setFormatter(formatter)
86 | if rank == 0:
87 | logging.getLogger('').addHandler(console)
88 | else:
89 | fh = logging.FileHandler(filename)
90 | logging.getLogger('').addHandler(fh)
91 |
92 |
93 |
94 | def prep_experiment(args, parser):
95 | """
96 | Make output directories, setup logging, Tensorboard, snapshot code.
97 | """
98 | ckpt_path = args.ckpt
99 | tb_path = args.tb_path
100 | exp_name = make_exp_name(args, parser)
101 | args.exp_path = os.path.join(ckpt_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
102 | args.tb_exp_path = os.path.join(tb_path, args.date, args.exp, str(datetime.now().strftime('%m_%d_%H')))
103 | args.ngpu = torch.cuda.device_count()
104 | args.date_str = str(datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
105 | args.best_record = {}
106 | # args.best_record = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
107 | # 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
108 | args.last_record = {}
109 | if args.local_rank == 0:
110 | os.makedirs(args.exp_path, exist_ok=True)
111 | os.makedirs(args.tb_exp_path, exist_ok=True)
112 | save_log('log', args.exp_path, args.date_str, rank=args.local_rank)
113 | open(os.path.join(args.exp_path, args.date_str + '.txt'), 'w').write(
114 | str(args) + '\n\n')
115 | writer = SummaryWriter(log_dir=args.tb_exp_path, comment=args.tb_tag)
116 | return writer
117 | return None
118 |
119 | def evaluate_eval_for_inference(hist, dataset=None):
120 | """
121 |     Modified IOU mechanism for on-the-fly IOU calculations (prevents memory overflow for
122 |     large datasets). Only applies to eval/eval.py
123 | """
124 | # axis 0: gt, axis 1: prediction
125 | acc = np.diag(hist).sum() / hist.sum()
126 | acc_cls = np.diag(hist) / hist.sum(axis=1)
127 | acc_cls = np.nanmean(acc_cls)
128 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
129 |
130 | print_evaluate_results(hist, iu, dataset=dataset)
131 | freq = hist.sum(axis=1) / hist.sum()
132 | mean_iu = np.nanmean(iu)
133 | logging.info('mean {}'.format(mean_iu))
134 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
135 | return acc, acc_cls, mean_iu, fwavacc
136 |
137 |
138 |
139 | def evaluate_eval(args, net, optimizer, scheduler, val_loss, hist, dump_images, writer, epoch=0, dataset_name=None, dataset=None, curr_iter=0, optimizer_at=None, scheduler_at=None, save_pth=True):
140 | """
141 |     Modified IOU mechanism for on-the-fly IOU calculations (prevents memory overflow for
142 |     large datasets). Only applies to eval/eval.py
143 | """
144 | if val_loss is not None and hist is not None:
145 | # axis 0: gt, axis 1: prediction
146 | acc = np.diag(hist).sum() / hist.sum()
147 | acc_cls = np.diag(hist) / hist.sum(axis=1)
148 | acc_cls = np.nanmean(acc_cls)
149 | iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
150 |
151 | print_evaluate_results(hist, iu, dataset_name=dataset_name, dataset=dataset)
152 | freq = hist.sum(axis=1) / hist.sum()
153 | mean_iu = np.nanmean(iu)
154 | logging.info('mean {}'.format(mean_iu))
155 | fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
156 | else:
157 | mean_iu = 0
158 |
159 | if dataset_name not in args.last_record.keys():
160 | args.last_record[dataset_name] = {}
161 |
162 | if save_pth:
163 | # update latest snapshot
164 | if 'mean_iu' in args.last_record[dataset_name]:
165 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
166 | dataset_name, args.last_record[dataset_name]['epoch'],
167 | args.last_record[dataset_name]['mean_iu'])
168 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
169 | try:
170 | os.remove(last_snapshot)
171 | except OSError:
172 | pass
173 |
174 | last_snapshot = 'last_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(dataset_name, epoch, mean_iu)
175 | last_snapshot = os.path.join(args.exp_path, last_snapshot)
176 | args.last_record[dataset_name]['mean_iu'] = mean_iu
177 | args.last_record[dataset_name]['epoch'] = epoch
178 |
179 | torch.cuda.synchronize()
180 |
181 | if optimizer_at is not None:
182 | torch.save({
183 | 'state_dict': net.state_dict(),
184 | 'optimizer': optimizer.state_dict(),
185 | 'optimizer_at': optimizer_at.state_dict(),
186 | 'scheduler': scheduler.state_dict(),
187 | 'scheduler_at': scheduler_at.state_dict(),
188 | 'epoch': epoch,
189 | 'mean_iu': mean_iu,
190 | 'command': ' '.join(sys.argv[1:])
191 | }, last_snapshot)
192 | else:
193 | torch.save({
194 | 'state_dict': net.state_dict(),
195 | 'optimizer': optimizer.state_dict(),
196 | 'scheduler': scheduler.state_dict(),
197 | 'epoch': epoch,
198 | 'mean_iu': mean_iu,
199 | 'command': ' '.join(sys.argv[1:])
200 | }, last_snapshot)
201 |
202 | if val_loss is not None and hist is not None:
203 | if dataset_name not in args.best_record.keys():
204 | args.best_record[dataset_name] = {'epoch': -1, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
205 | 'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
206 | # update best snapshot
207 | if mean_iu > args.best_record[dataset_name]['mean_iu'] :
208 | # remove old best snapshot
209 | if args.best_record[dataset_name]['epoch'] != -1:
210 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
211 | dataset_name, args.best_record[dataset_name]['epoch'],
212 | args.best_record[dataset_name]['mean_iu'])
213 |
214 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
215 | assert os.path.exists(best_snapshot), \
216 | 'cant find old snapshot {}'.format(best_snapshot)
217 | os.remove(best_snapshot)
218 |
219 | # save new best
220 | args.best_record[dataset_name]['val_loss'] = val_loss.avg
221 | args.best_record[dataset_name]['epoch'] = epoch
222 | args.best_record[dataset_name]['acc'] = acc
223 | args.best_record[dataset_name]['acc_cls'] = acc_cls
224 | args.best_record[dataset_name]['mean_iu'] = mean_iu
225 | args.best_record[dataset_name]['fwavacc'] = fwavacc
226 |
227 | best_snapshot = 'best_{}_epoch_{}_mean-iu_{:.5f}.pth'.format(
228 | dataset_name, args.best_record[dataset_name]['epoch'],
229 | args.best_record[dataset_name]['mean_iu'])
230 | best_snapshot = os.path.join(args.exp_path, best_snapshot)
231 | shutil.copyfile(last_snapshot, best_snapshot)
232 | else:
233 | logging.info("Saved file to {}".format(last_snapshot))
234 |
235 | if val_loss is not None and hist is not None:
236 | logging.info('-' * 107)
237 | fmt_str = '[epoch %d], [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
238 | '[mean_iu %.5f], [fwavacc %.5f]'
239 | logging.info(fmt_str % (epoch, dataset_name, val_loss.avg, acc, acc_cls, mean_iu, fwavacc))
240 | if save_pth:
241 | fmt_str = 'best record: [dataset name %s], [val loss %.5f], [acc %.5f], [acc_cls %.5f], ' +\
242 | '[mean_iu %.5f], [fwavacc %.5f], [epoch %d], '
243 | logging.info(fmt_str % (dataset_name,
244 | args.best_record[dataset_name]['val_loss'], args.best_record[dataset_name]['acc'],
245 | args.best_record[dataset_name]['acc_cls'], args.best_record[dataset_name]['mean_iu'],
246 | args.best_record[dataset_name]['fwavacc'], args.best_record[dataset_name]['epoch']))
247 | logging.info('-' * 107)
248 |
249 | if writer:
250 | # tensorboard logging of validation phase metrics
251 | writer.add_scalar('{}/acc'.format(dataset_name), acc, curr_iter)
252 | writer.add_scalar('{}/acc_cls'.format(dataset_name), acc_cls, curr_iter)
253 | writer.add_scalar('{}/mean_iu'.format(dataset_name), mean_iu, curr_iter)
254 | writer.add_scalar('{}/val_loss'.format(dataset_name), val_loss.avg, curr_iter)
255 |
256 |
257 |
258 |
259 |
260 | def print_evaluate_results(hist, iu, dataset_name=None, dataset=None):
261 | # fixme: Need to refactor this dict
262 | try:
263 | id2cat = dataset.id2cat
264 |     except AttributeError:
265 | id2cat = {i: i for i in range(datasets.num_classes)}
266 | iu_false_positive = hist.sum(axis=1) - np.diag(hist)
267 | iu_false_negative = hist.sum(axis=0) - np.diag(hist)
268 | iu_true_positive = np.diag(hist)
269 |
270 | logging.info('Dataset name: {}'.format(dataset_name))
271 | logging.info('IoU:')
272 | logging.info('label_id label iU Precision Recall TP FP FN')
273 | for idx, i in enumerate(iu):
274 | # Format all of the strings:
275 | idx_string = "{:2d}".format(idx)
276 | class_name = "{:>13}".format(id2cat[idx]) if idx in id2cat else ''
277 | iu_string = '{:5.1f}'.format(i * 100)
278 | total_pixels = hist.sum()
279 | tp = '{:5.1f}'.format(100 * iu_true_positive[idx] / total_pixels)
280 | fp = '{:5.1f}'.format(
281 | iu_false_positive[idx] / iu_true_positive[idx])
282 | fn = '{:5.1f}'.format(iu_false_negative[idx] / iu_true_positive[idx])
283 | precision = '{:5.1f}'.format(
284 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_positive[idx]))
285 | recall = '{:5.1f}'.format(
286 | iu_true_positive[idx] / (iu_true_positive[idx] + iu_false_negative[idx]))
287 | logging.info('{} {} {} {} {} {} {} {}'.format(
288 | idx_string, class_name, iu_string, precision, recall, tp, fp, fn))
289 |
290 |
291 |
292 |
293 | class AverageMeter(object):
294 |
295 | def __init__(self):
296 | self.reset()
297 |
298 | def reset(self):
299 | self.val = 0
300 | self.avg = 0
301 | self.sum = 0
302 | self.count = 0
303 |
304 | def update(self, val, n=1):
305 | self.val = val
306 | self.sum += val * n
307 | self.count += n
308 | self.avg = self.sum / self.count
309 |
--------------------------------------------------------------------------------
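To make the IoU bookkeeping above concrete, here is a small standalone sketch (not part of the repository) that restates the fast_hist confusion-matrix trick and the per_class_iu formula on a toy label/prediction pair, so it runs without the repo's tensorboardX/datasets imports.

    import numpy as np

    def fast_hist(label_pred, label_true, num_classes):
        # Same computation as utils/misc.py: joint (gt, pred) ids counted with bincount.
        mask = (label_true >= 0) & (label_true < num_classes)
        return np.bincount(
            num_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=num_classes ** 2).reshape(num_classes, num_classes)

    num_classes = 3
    label_true = np.array([0, 0, 1, 1, 2, 255])   # 255 falls outside [0, num_classes) and is ignored
    label_pred = np.array([0, 1, 1, 1, 2, 0])

    hist = fast_hist(label_pred, label_true, num_classes)    # rows: ground truth, cols: prediction
    iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist))
    print(iu, np.nanmean(iu))    # [0.5, 0.667, 1.0] -> mean IoU ~= 0.722

--------------------------------------------------------------------------------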
/utils/my_data_parallel.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | # Code adapted from:
4 | # https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/data_parallel.py
5 | #
6 | # BSD 3-Clause License
7 | #
8 | # Copyright (c) 2017,
9 | # All rights reserved.
10 | #
11 | # Redistribution and use in source and binary forms, with or without
12 | # modification, are permitted provided that the following conditions are met:
13 | #
14 | # * Redistributions of source code must retain the above copyright notice, this
15 | # list of conditions and the following disclaimer.
16 | #
17 | # * Redistributions in binary form must reproduce the above copyright notice,
18 | # this list of conditions and the following disclaimer in the documentation
19 | # and/or other materials provided with the distribution.
20 | #
21 | # * Neither the name of the copyright holder nor the names of its
22 | # contributors may be used to endorse or promote products derived from
23 | # this software without specific prior written permission.
24 | #
25 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
26 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
27 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
28 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
29 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
30 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
31 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
32 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
33 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
34 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
35 | """
36 |
37 |
38 | import operator
39 | import torch
40 | import warnings
41 | from torch.nn.modules import Module
42 | from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
43 | from torch.nn.parallel.replicate import replicate
44 | from torch.nn.parallel.parallel_apply import parallel_apply
45 |
46 |
47 | def _check_balance(device_ids):
48 | imbalance_warn = """
49 | There is an imbalance between your GPUs. You may want to exclude GPU {} which
50 | has less than 75% of the memory or cores of GPU {}. You can do so by setting
51 | the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
52 | environment variable."""
53 |
54 | dev_props = [torch.cuda.get_device_properties(i) for i in device_ids]
55 |
56 | def warn_imbalance(get_prop):
57 | values = [get_prop(props) for props in dev_props]
58 | min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1))
59 | max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1))
60 | if min_val / max_val < 0.75:
61 | warnings.warn(imbalance_warn.format(device_ids[min_pos], device_ids[max_pos]))
62 | return True
63 | return False
64 |
65 | if warn_imbalance(lambda props: props.total_memory):
66 | return
67 | if warn_imbalance(lambda props: props.multi_processor_count):
68 | return
69 |
70 |
71 |
72 | def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None, gather=True):
73 | """
74 | Evaluates module(input) in parallel across the GPUs given in device_ids.
75 | This is the functional version of the DataParallel module.
76 | Args:
77 | module: the module to evaluate in parallel
78 | inputs: inputs to the module
79 | device_ids: GPU ids on which to replicate module
80 |         output_device: GPU location of the output. Use -1 to indicate the CPU.
81 | (default: device_ids[0])
82 | Returns:
83 | a Tensor containing the result of module(input) located on
84 | output_device
85 | """
86 | if not isinstance(inputs, tuple):
87 | inputs = (inputs,)
88 |
89 | if device_ids is None:
90 | device_ids = list(range(torch.cuda.device_count()))
91 |
92 | if output_device is None:
93 | output_device = device_ids[0]
94 |
95 | inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
96 | if len(device_ids) == 1:
97 | return module(*inputs[0], **module_kwargs[0])
98 | used_device_ids = device_ids[:len(inputs)]
99 | replicas = replicate(module, used_device_ids)
100 | outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
101 | if gather:
102 |         return torch.nn.parallel.gather(outputs, output_device, dim)  # avoid the 'gather' bool argument shadowing the imported gather()
103 | else:
104 | return outputs
105 |
106 |
107 |
108 | class MyDataParallel(Module):
109 | """
110 | Implements data parallelism at the module level.
111 | This container parallelizes the application of the given module by
112 | splitting the input across the specified devices by chunking in the batch
113 | dimension. In the forward pass, the module is replicated on each device,
114 | and each replica handles a portion of the input. During the backwards
115 | pass, gradients from each replica are summed into the original module.
116 | The batch size should be larger than the number of GPUs used.
117 | See also: :ref:`cuda-nn-dataparallel-instead`
118 | Arbitrary positional and keyword inputs are allowed to be passed into
119 | DataParallel EXCEPT Tensors. All tensors will be scattered on dim
120 | specified (default 0). Primitive types will be broadcasted, but all
121 | other types will be a shallow copy and can be corrupted if written to in
122 | the model's forward pass.
123 | .. warning::
124 | Forward and backward hooks defined on :attr:`module` and its submodules
125 | will be invoked ``len(device_ids)`` times, each with inputs located on
126 | a particular device. Particularly, the hooks are only guaranteed to be
127 | executed in correct order with respect to operations on corresponding
128 | devices. For example, it is not guaranteed that hooks set via
129 | :meth:`~torch.nn.Module.register_forward_pre_hook` be executed before
130 | `all` ``len(device_ids)`` :meth:`~torch.nn.Module.forward` calls, but
131 | that each such hook be executed before the corresponding
132 | :meth:`~torch.nn.Module.forward` call of that device.
133 | .. warning::
134 | When :attr:`module` returns a scalar (i.e., 0-dimensional tensor) in
135 | :func:`forward`, this wrapper will return a vector of length equal to
136 | number of devices used in data parallelism, containing the result from
137 | each device.
138 | .. note::
139 | There is a subtlety in using the
140 | ``pack sequence -> recurrent network -> unpack sequence`` pattern in a
141 | :class:`~torch.nn.Module` wrapped in :class:`~torch.nn.DataParallel`.
142 | See :ref:`pack-rnn-unpack-with-data-parallelism` section in FAQ for
143 | details.
144 | Args:
145 | module: module to be parallelized
146 | device_ids: CUDA devices (default: all devices)
147 | output_device: device location of output (default: device_ids[0])
148 | Attributes:
149 | module (Module): the module to be parallelized
150 | Example::
151 | >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
152 | >>> output = net(input_var)
153 | """
154 |
155 | # TODO: update notes/cuda.rst when this class handles 8+ GPUs well
156 |
157 | def __init__(self, module, device_ids=None, output_device=None, dim=0, gather=True):
158 | super(MyDataParallel, self).__init__()
159 |
160 | if not torch.cuda.is_available():
161 | self.module = module
162 | self.device_ids = []
163 | return
164 |
165 | if device_ids is None:
166 | device_ids = list(range(torch.cuda.device_count()))
167 | if output_device is None:
168 | output_device = device_ids[0]
169 | self.dim = dim
170 | self.module = module
171 | self.device_ids = device_ids
172 | self.output_device = output_device
173 | self.gather_bool = gather
174 |
175 | _check_balance(self.device_ids)
176 |
177 | if len(self.device_ids) == 1:
178 | self.module.cuda(device_ids[0])
179 |
180 | def forward(self, *inputs, **kwargs):
181 | if not self.device_ids:
182 | return self.module(*inputs, **kwargs)
183 | inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
184 | if len(self.device_ids) == 1:
185 | return [self.module(*inputs[0], **kwargs[0])]
186 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
187 | outputs = self.parallel_apply(replicas, inputs, kwargs)
188 | if self.gather_bool:
189 | return self.gather(outputs, self.output_device)
190 | else:
191 | return outputs
192 |
193 | def replicate(self, module, device_ids):
194 | return replicate(module, device_ids)
195 |
196 | def scatter(self, inputs, kwargs, device_ids):
197 | return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
198 |
199 | def parallel_apply(self, replicas, inputs, kwargs):
200 | return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
201 |
202 | def gather(self, outputs, output_device):
203 | return gather(outputs, output_device, dim=self.dim)
204 |
205 |
--------------------------------------------------------------------------------
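A minimal sketch (not part of the repository, and assuming at least one visible CUDA device with the repo root on PYTHONPATH) of the gather=False path, which is what distinguishes MyDataParallel from torch.nn.DataParallel: the per-replica outputs come back as a list instead of being concatenated on the output device.

    import torch
    import torch.nn as nn
    from utils.my_data_parallel import MyDataParallel

    model = nn.Linear(8, 2).cuda()                 # module must already live on device_ids[0]
    wrapped = MyDataParallel(model, gather=False)

    outputs = wrapped(torch.randn(4, 8).cuda())    # list with one output tensor per replica
    print(len(outputs), outputs[0].shape)

--------------------------------------------------------------------------------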