├── .gitignore
├── .idea
│   ├── .gitignore
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── jsd_experiments.iml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── README.md
├── SegThor
│   ├── .idea
│   │   ├── .gitignore
│   │   ├── ADELE.iml
│   │   ├── inspectionProfiles
│   │   │   └── profiles_settings.xml
│   │   ├── misc.xml
│   │   ├── modules.xml
│   │   └── vcs.xml
│   ├── LICENSE
│   ├── brat
│   │   ├── __init__.py
│   │   ├── brat_util.py
│   │   ├── label_correction.py
│   │   ├── loading.py
│   │   ├── train_segthor.py
│   │   ├── unet_model.py
│   │   └── unet_parts.py
│   ├── lib
│   │   └── utils
│   │       ├── JSD_loss.py
│   │       └── iou_computation.py
│   └── requirements.txt
├── __init__.py
├── config.py
├── lib
│   ├── datasets
│   │   ├── BaseDataset.py
│   │   ├── BaseMultiwGTauginfoDataset.py
│   │   ├── VOCDataset.py
│   │   ├── VOCEvalDataset.py
│   │   ├── VOCTrainwsegDataset.py
│   │   ├── __init__.py
│   │   ├── generateData.py
│   │   ├── metric.py
│   │   ├── transform.py
│   │   ├── transformmultiGT.py
│   │   └── transformmultiGTauginfo.py
│   ├── net
│   │   ├── __init__.py
│   │   ├── backbone
│   │   │   ├── __init__.py
│   │   │   ├── builder.py
│   │   │   ├── resnet.py
│   │   │   ├── resnet38d.py
│   │   │   └── xception.py
│   │   ├── deeplabv1_wo_interp.py
│   │   ├── generateNet.py
│   │   ├── operators
│   │   │   ├── ASPP.py
│   │   │   ├── PPM.py
│   │   │   └── __init__.py
│   │   └── sync_batchnorm
│   │       ├── __init__.py
│   │       ├── batchnorm.py
│   │       ├── comm.py
│   │       ├── replicate.py
│   │       ├── sync_batchnorm
│   │       │   ├── __init__.py
│   │       │   ├── batchnorm.py
│   │       │   ├── batchnorm_reimpl.py
│   │       │   ├── comm.py
│   │       │   ├── replicate.py
│   │       │   └── unittest.py
│   │       ├── tests
│   │       │   ├── test_numeric_batchnorm.py
│   │       │   └── test_sync_batchnorm.py
│   │       └── unittest.py
│   └── utils
│       ├── DenseCRF.py
│       ├── JSD_loss.py
│       ├── __init__.py
│       ├── configuration.py
│       ├── eval_net_utils.py
│       ├── finalprocess.py
│       ├── imutils.py
│       ├── iou_computation.py
│       ├── logger.py
│       ├── registry.py
│       ├── test_utils.py
│       └── visualization.py
├── requirements.txt
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | sftp*.json
2 | run*.sh
3 | train_onebyone_eval_w_compact_dict.py
4 |
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/jsd_experiments.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Hibercraft
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADELE (Adaptive Early-Learning Correction for Segmentation from Noisy Annotations) (CVPR 2022 Oral)
2 |
3 |
4 | Sheng Liu*, Kangning Liu*, Weicheng Zhu, Yiqiu Shen, Carlos Fernandez-Granda
5 |
6 | (* The first two authors contributed equally; order decided by coin flip.)
7 |
8 |
9 |
10 |
11 | Official Implementation of [Adaptive Early-Learning Correction for Segmentation from Noisy Annotations](https://arxiv.org/abs/2110.03740) (CVPR 2022 Oral)
12 |
13 | ## PASCAL VOC dataset
14 | Thanks to the work of Yude Wang: the code in this repository borrows heavily from his SEAM repository, and we follow the same pipeline to verify the effectiveness of our ADELE.
15 | We use the same ImageNet pretrained ResNet38 model as SEAM, which can be downloaded from https://github.com/YudeWang/semantic-segmentation-codebase/tree/main/experiment/seamv1-pseudovoc
16 |
17 |
18 |
19 |
20 |
21 |
22 | The code related to PASCAL VOC is located in the main folder. We provide the trained model for SEAM+ADELE at the following link:
23 |
24 | https://drive.google.com/file/d/10cTOraETOmb2jOCJ4E0m_y9lrjrA3g2u/view?usp=sharing
25 |
26 | We use two NVIDIA Quadro RTX 8000 GPUs to train the model. If you encounter an out-of-memory issue, please consider decreasing the resolution of the input image.
27 |
28 | ### Installation
29 | - Install python dependencies.
30 | ```
31 | pip install -r requirements.txt
32 | ```
33 |
34 | Note that we use Comet to record statistics online. Comet is similar to TensorBoard; more information can be found at https://www.comet.ml/site/ .
35 | - Create a softlink to your dataset. Make sure that the dataset can be accessed at `$your_dataset_path/VOCdevkit/VOC2012...`
36 | ```
37 | ln -s $your_dataset_path data
38 | ```
39 |
40 |
41 |
42 |
43 | The inference code is the same as the official SEAM code; see the link provided by the SEAM author: https://github.com/YudeWang/semantic-segmentation-codebase/tree/main/experiment/seamv1-pseudovoc
44 |
45 |
46 |
47 |
48 | For the training code, an example script for ADELE is:
49 |
50 | ```
51 | python train.py \
52 | --EXP_NAME EXP_name \
53 | --Lambda1 1 --TRAIN_BATCHES 10 --TRAIN_LR 0.001 --mask_threshold 0.8 \
54 | --scale_index 0 --flip yes --CRF yes \
55 | --dict_save_scale_factor 1 --npl_metrics 0 \
56 | --api_key API_key \
57 | --r_threshold 0.9 --Reinit_dict yes \
58 | --DATA_PSEUDO_GT Initial_Pseudo_Label_Location
59 | ```
60 |
61 |
62 |
63 | We store default values for the arguments in the config.py file; those values are passed to the arguments as cfg.XXX. You may change the default values in config.py or override them via the arguments in the script, as the sketch below illustrates.
64 | It is especially important to assign the path of your initial pseudo annotation via --DATA_PSEUDO_GT, or to specify it in cfg.DATA_PSEUDO_GT in the config.py file. For the detailed method to obtain the initial pseudo annotation, please refer to related methods such as AffinityNet, SEAM, ICD, and NSROM.
65 |
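A minimal sketch (ours, not taken from train.py) of how a config.py default seeds an argparse argument and how the command line overrides it:

```python
import argparse
from types import SimpleNamespace

# stand-in for the cfg object built from config_dict in config.py
cfg = SimpleNamespace(TRAIN_LR=0.001, TRAIN_BATCHES=10)

parser = argparse.ArgumentParser()
parser.add_argument('--TRAIN_LR', type=float, default=cfg.TRAIN_LR)
parser.add_argument('--TRAIN_BATCHES', type=int, default=cfg.TRAIN_BATCHES)

# command-line values take precedence over the config.py defaults
args = parser.parse_args(['--TRAIN_LR', '0.0005'])
print(args.TRAIN_LR, args.TRAIN_BATCHES)  # -> 0.0005 10
```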
66 |
67 | The arguments represent:
68 | ```
69 | parser.add_argument("--EXP_NAME", type=str, default=cfg.EXP_NAME,
70 | help="the name of the experiment")
71 | parser.add_argument("--scale_factor", type=float, default=cfg.scale_factor,
72 | help="scale_factor of downsample the image")
73 | parser.add_argument("--scale_factor2", type=float, default=cfg.scale_factor2,
74 | help="scale_factor of upsample the image")
75 | parser.add_argument("--DATA_PSEUDO_GT", type=str, default=cfg.DATA_PSEUDO_GT,
76 | help="Data path for the main segmentation map")
77 | parser.add_argument("--TRAIN_CKPT", type=str, default=cfg.TRAIN_CKPT,
78 | help="Training path")
79 | parser.add_argument("--Lambda1", type=float, default=1,
80 | help="to balance the loss between CE and Consistency loss")
81 | parser.add_argument("--TRAIN_BATCHES", type=int, default=cfg.TRAIN_BATCHES,
82 | help="training batch szie")
83 | parser.add_argument('--threshold', type=float, default=0.8,
84 | help="threshold to select the mask for Consistency loss computation ")
85 | parser.add_argument('--DATA_WORKERS', type=int, default=cfg.DATA_WORKERS,
86 | help="number of workers in dataloader")
87 | parser.add_argument('--TRAIN_LR', type=float,
88 | default=cfg.TRAIN_LR,
89 | help="the training learning rate")
90 | parser.add_argument('--TRAIN_ITERATION', type=int,
91 | default=cfg.TRAIN_ITERATION,
92 | help="the training iteration number")
93 | parser.add_argument('--DATA_RANDOMCROP', type=int, default=cfg.DATA_RANDOMCROP,
94 | help="the resolution of random crop")
95 |
96 |
97 |
98 | # related to the pseudo label updating
99 | parser.add_argument('--mask_threshold', type=float, default=0.8,
100 | help="only the region with high probability and disagree with Pseudo label be updated")
101 | parser.add_argument('--update_interval', type=int, default=1,
102 | help="evaluate the prediction every 1 epoch")
103 | parser.add_argument('--npl_metrics', type=int, default=0,
104 | help="0: using the original cam to compute the npl similarity, 1: use the updated pseudo label to compute the npl")
105 | parser.add_argument('--r_threshold', type=float, default=0.9,
106 | help="the r threshold to decide if_update")
107 |
108 | # related to the eval mode
109 | parser.add_argument('--scale_index', type=int, default=2,
110 | help="0: scale [0.7, 1.0, 1.5] 1:[0.5, 1.0, 1.75], 2:[0.5, 0.75, 1.0, 1.25, 1.5, 1.75] ")
111 | parser.add_argument('--flip', type=str, default='yes',
112 | help="do not flip in the eval pred if no, else flip")
113 | parser.add_argument('--CRF', type=str, default='no',
114 | help="whether to use CRF, yes or no, default no")
115 | parser.add_argument('--dict_save_scale_factor', type=float, default=1,
116 | help="dict_save_scale_factor downsample_factor (in case the CPU memory is not enough)")
117 | parser.add_argument('--evaluate_interval', type=int, default=1,
118 | help="evaluate the prediction every 1 epoch, this is always set to one for PASCAL VOC dataset")
119 | parser.add_argument('--Reinit_dict', type=str2bool, nargs='?',
120 | const=True, default=False,
121 | help="whether to reinit the dict every epoch")
122 | parser.add_argument('--evaluate_aug_epoch', type=int, default=9,
123 | help="when to start aug the evaluate with CRF and flip, this can be used to save some time when updating the pseudo label, we did not find significant difference")
124 |
125 |
126 |
127 | # continue_training_related:
128 | parser.add_argument('--continue_train_epoch', type=int, default=0,
129 | help="load the trained model from which epoch, if 0, no continue training")
130 | parser.add_argument('--checkpoint_path', type=str, default='no',
131 | help="the checkpoint path to load the model")
132 | parser.add_argument('--dict_path', type=str,
133 | default='no',
134 | help="the dict path of seg path")
135 | parser.add_argument('--MODEL_BACKBONE_PRETRAIN', type=str2bool, nargs='?',
136 | const=True, default=True,
137 | help="Do not load pretrained model if false")
138 |
139 |
140 | # Comet
141 | parser.add_argument('--api_key', type=str,
142 | default='',
143 | help="The api_key of Comet; please refer to https://www.comet.ml/site/ for more information")
144 | parser.add_argument('--online', type=str2bool, nargs='?',
145 | const=True, default=True,
146 | help="False when use Comet offline")
147 | ```
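
Several flags above (--Reinit_dict, --MODEL_BACKBONE_PRETRAIN, --online) are parsed with a str2bool helper that is not shown in this excerpt; a typical implementation (a sketch, not necessarily the exact one in train.py) looks like:

```python
import argparse

def str2bool(v):
    # map common yes/no spellings onto a boolean for argparse flags
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
```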
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 | ## SegTHOR dataset
156 | The code related to SegTHOR is located in the SegThor folder; please go to the SegThor subdirectory.
157 | ### Installation
158 |
159 | - Install python dependencies.
160 | ```
161 | pip install -r requirements.txt
162 | ```
163 | - Download the SegTHOR dataset and conduct data preprocessing: resize all images to 256*256 using the linear interpolation of opencv_python (INTER_LINEAR).
164 |
165 | The details of public SegTHOR dataset can be found in [this link](https://competitions.codalab.org/competitions/21145).
166 |
167 | In this study, we randomly assign the patients in the original training set to the training, validation, and test sets using the following scheme:
168 |
169 | - training set: ['Patient_01', 'Patient_02', 'Patient_03', 'Patient_04',
170 | 'Patient_05', 'Patient_06', 'Patient_07', 'Patient_09',
171 | 'Patient_10', 'Patient_11', 'Patient_12', 'Patient_13',
172 | 'Patient_14', 'Patient_15', 'Patient_16', 'Patient_17',
173 | 'Patient_18', 'Patient_19', 'Patient_20', 'Patient_22',
174 | 'Patient_24', 'Patient_25', 'Patient_26', 'Patient_28',
175 | 'Patient_30', 'Patient_31', 'Patient_33', 'Patient_36',
176 | 'Patient_38', 'Patient_39', 'Patient_40']
177 | - validation set: ['Patient_21', 'Patient_23', 'Patient_27', 'Patient_29',
178 | 'Patient_37']
179 | - test set: ['Patient_08', 'Patient_27', 'Patient_32', 'Patient_34',
180 | 'Patient_35']
181 |
182 | We used only slices that contain a foreground class and downsampled all slices to 256 * 256 pixels using linear interpolation.
183 |
184 | ### Experiments
185 | Here is the example script of ADELE:
186 | ```
187 | python3 brat/train_segthor.py \
188 | --cache-dir DIR_OF_THE_DATA \
189 | --data-list DIR_OF_THE_DATALIST \
190 | --save-dir MODEL_SAVE_DIR \
191 | --model-name MODEL_NAME \
192 | --seed 0 \
193 | --jsd-lambda 1 \
194 | --rho 0.8 \
195 | --label-correction \
196 | --tau_fg 0.7 \
197 | --tau_bg 0.7 \
198 | --r 0.9
199 | ```
200 |
201 | where the arguments represent:
202 | * `cache-dir` - Parent dir of the cached data tr.pkl, val.pkl, and ts.pkl, which are the input data for the training, validation, and test sets.
203 | * `data-list` - Parent dir of the data_list.pkl file, which is the list of names of the input data.
204 | * `save-dir` - Folder where models and results will be saved.
205 | * `model-name` - Name of the model.
206 | * `seed` - the random seed of the noise realization; default 0.
207 | * `jsd-lambda` - the consistency strength; if set to 0, no consistency regularization is applied; default 1.
208 | * `rho` - consistency confidence threshold: the threshold on the confidence of the model's prediction that decides which examples receive consistency regularization.
209 | * `label-correction` - whether to conduct label correction; if this flag is set, the model performs label correction; default False.
210 | * `tau_fg, tau_bg` - label correction confidence thresholds for the foreground and background; in the main paper and all experiments we set these two values to be the same for simplicity; default 0.7.
211 | * `r` - curve fitting threshold that controls when a specific semantic category is corrected (see the sketch below); default 0.9.
212 |
213 |
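A minimal sketch (ours; assumes it is run from the SegThor directory) of the curve-fitting test behind `r`: `if_update` from brat/label_correction.py (shown below) fits the IoU-vs-epoch curve of one class and reports whether correction should begin:

```python
import numpy as np
from brat.label_correction import if_update

# toy IoU curve for one class: fast early rise, then a plateau (epochs 1..10)
iou_curve = list(1 - np.exp(-0.8 * np.arange(1, 11)))

# True once the fitted curve has flattened enough for correction to start
print(if_update(iou_curve, current_epoch=5, n_epoch=16, threshold=0.9))
```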
214 |
215 | Here is the example script of baseline:
216 | ```
217 | python3 brat/train_segthor.py \
218 | --cache-dir DIR_OF_THE_DATA \
219 | --data-list DIR_OF_THE_DATALIST \
220 | --save-dir MODEL_SAVE_DIR \
221 | --model-name MODEL_NAME \
222 | --seed 0 \
223 | --jsd-lambda 0
224 | ```
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 | ## Citation
234 |
235 | Please cite our paper if the code is helpful to your research.
236 | ```
237 | @article{liu2021adaptive,
238 | title={Adaptive Early-Learning Correction for Segmentation from Noisy Annotations},
239 | author={Liu, Sheng and Liu, Kangning and Zhu, Weicheng and Shen, Yiqiu and Fernandez-Granda, Carlos},
240 | journal={CVPR 2022},
241 | year={2022}
242 | }
243 | ```
244 |
245 |
--------------------------------------------------------------------------------
/SegThor/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/SegThor/.idea/ADELE.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Sheng Liu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/SegThor/brat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kangningthu/ADELE/7195bd0af39be79c533d67dd7eab7f9bfd6a4285/SegThor/brat/__init__.py
--------------------------------------------------------------------------------
/SegThor/brat/brat_util.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import time
4 |
5 |
6 | class DocumentUnit:
7 | """
8 | Object that document the output from the model in an epoch
9 | """
10 |
11 | def __init__(self, columns):
12 | self.data_dict = dict([(col, []) for col in columns])
13 | # skip case-level preds and labels since they have 1/2 the length
14 | self.skip_key = ["case_pred", "case_label", "left_case_pred",
15 | "right_case_pred", "fusion_case_pred", "left_right_case_pred"]
16 | # accumulator for localization
17 | self.localization_accumulator = None
18 |
19 | def update_accumulator(self, delta):
20 | """
21 | Method that accumulates localization pixel-wise values for metrics such as mIOU
22 | :param delta:
23 | :return:
24 | """
25 | if self.localization_accumulator is None:
26 | self.localization_accumulator = delta
27 | else:
28 | assert self.localization_accumulator.shape == delta.shape,\
29 | "self.localization_accumulator.shape {0} != delta.shape {1}".format(self.localization_accumulator.shape, delta.shape)
30 | self.localization_accumulator += delta
31 |
32 | def add_values(self, column, values, process_method=lambda x: x):
33 | """
34 | Method that adds values into the document unit
35 | :param column:
36 | :param values:
37 | :return:
38 | """
39 | for val in values:
40 | self.data_dict[column].append(process_method(val))
41 |
42 | def form_df(self):
43 | """
44 | Method that creates a dataframe out of stored data
45 | :return:
46 | """
47 | to_be_save_dict = {}
48 | for key in self.data_dict:
49 | if key not in self.skip_key and len(self.data_dict[key]) != 0:
50 | to_be_save_dict[key] = self.data_dict[key]
51 | df = pd.DataFrame(to_be_save_dict).reset_index()
52 | return df
53 |
54 | def get_latest_results(self):
55 | """
56 | Method that retrieves the latest results from the stored data
57 | :return:
58 | """
59 | to_be_save_dict = {}
60 | for key in self.data_dict:
61 | if key not in self.skip_key and len(self.data_dict[key]) != 0:
62 | to_be_save_dict[key] = self.data_dict[key][-1]
63 | return to_be_save_dict
64 |
65 | def to_csv(self, dir):
66 | """
67 | Export to csv
68 | :param dir:
69 | :return:
70 | """
71 | df = self.form_df()
72 | df.to_csv(dir, index=False)
73 |
74 |
75 |
76 | class RuntimeProfiler:
77 | """
78 | Object that documents run-time
79 | """
80 | def __init__(self):
81 | self.elpased_time_dict = {}
82 | self.current_time_point = None
83 |
84 | def tik(self, time_category=None):
85 | """
86 | Take a time point
87 | :param time_category:
88 | :return:
89 | """
90 | new_time = time.time()
91 | return_time = False
92 | if self.current_time_point is not None:
93 | return_time = True
94 | elapsed_time = new_time - self.current_time_point
95 | if time_category is not None:
96 | if time_category not in self.elpased_time_dict:
97 | self.elpased_time_dict[time_category] = []
98 | self.elpased_time_dict[time_category].append(elapsed_time)
99 | self.current_time_point = new_time
100 | if return_time:
101 | return elapsed_time
102 |
103 |
104 | def report_avg(self):
105 | """
106 | Generate a format string for average run-time statistics
107 | :return:
108 | """
109 | output_str = ""
110 | for time_category in self.elpased_time_dict:
111 | output_str += "category:{0}, avg_time:{1}, std_time:{2}, min_time:{3}, max_time:{4}, num_points:{5}\n".format(
112 | time_category, np.mean(self.elpased_time_dict[time_category]), np.std(self.elpased_time_dict[time_category]),
113 | np.min(self.elpased_time_dict[time_category]), np.max(self.elpased_time_dict[time_category]),
114 | len(self.elpased_time_dict[time_category])
115 | )
116 | return output_str
117 |
118 | def report_latest(self):
119 | """
120 | Generate a format string for the latest run-time statistics
121 | :return:
122 | """
123 | output_str = ""
124 | for time_category in self.elpased_time_dict:
125 | output_str += "category:{0}, runtime:{1} \n".format(time_category, self.elpased_time_dict[time_category][-1])
126 | return output_str
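
# Usage sketch (our addition, not part of the original file): bracket code
# regions with tik() calls, then print aggregate statistics per category.
if __name__ == "__main__":
    profiler = RuntimeProfiler()
    profiler.tik()                   # set the initial time point
    time.sleep(0.05)
    profiler.tik("data_loading")     # elapsed time is charged to "data_loading"
    time.sleep(0.02)
    profiler.tik("forward_pass")
    print(profiler.report_avg())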
--------------------------------------------------------------------------------
/SegThor/brat/label_correction.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import curve_fit
2 | import numpy as np
3 |
4 | def curve_func(x, a, b, c):
5 | return a * (1 - np.exp(-1 / c * x ** b))
6 |
7 |
8 | def fit(func, x, y):
9 | popt, pcov = curve_fit(func, x, y, p0=(1, 1, 1), method='trf', sigma=np.geomspace(1, .1, len(y)), absolute_sigma=True, bounds=([0, 0, 0], [1, 1, np.inf]))
10 | return tuple(popt)
11 |
12 |
13 | def derivation(x, a, b, c):
14 | x = x + 1e-6 # numerical robustness
15 | return a * b * 1/c * np.exp(-1/c * x**b) * (x**(b-1))
16 |
17 |
18 | def label_update_epoch(ydata_fit, n_epoch = 16, threshold = 0.9, eval_interval = 100, num_iter_per_epoch= 10581/10):
19 | xdata_fit = np.linspace(0, len(ydata_fit)*eval_interval/num_iter_per_epoch, len(ydata_fit))
20 | a, b, c = fit(curve_func, xdata_fit, ydata_fit)
21 | epoch = np.arange(1, n_epoch)
22 | y_hat = curve_func(epoch, a, b, c)
23 | relative_change = abs(abs(derivation(epoch, a, b, c)) - abs(derivation(1, a, b, c)))/ abs(derivation(1, a, b, c))
24 | relative_change[relative_change > 1] = 0
25 | update_epoch = np.sum(relative_change <= threshold) + 1
26 | return update_epoch#, a, b, c
27 |
28 | def if_update(iou_value, current_epoch, n_epoch = 16, threshold = 0.90, eval_interval=1, num_iter_per_epoch=1):
29 | # check iou_value
30 | start_iter = 0
31 | print("len(iou_value)=",len(iou_value))
32 | for k in range(len(iou_value)-1):
33 | if iou_value[k+1]-iou_value[k] < 0.1:
34 | start_iter = max(start_iter, k + 1)
35 | else:
36 | break
37 | shifted_epoch = start_iter*eval_interval/num_iter_per_epoch
38 | #cut out the first few entries
39 | iou_value = iou_value[start_iter: ]
40 | update_epoch = label_update_epoch(iou_value, n_epoch = n_epoch, threshold=threshold, eval_interval=eval_interval, num_iter_per_epoch=num_iter_per_epoch)
41 | # Shift back
42 | update_epoch = shifted_epoch + update_epoch
43 | return current_epoch >= update_epoch#, update_epoch
44 |
45 |
46 | def merge_labels_with_skip(original_labels, model_predictions, need_label_correction_dict, conf_threshold=0.8, logic_255=False, class_constraint=True, conf_threshold_bg=0.95):
47 |
48 |
49 | new_label_dict = {}
50 | update_list = []
51 | for c in need_label_correction_dict:
52 | if need_label_correction_dict[c]:
53 | update_list.append(c)
54 |
55 |
56 | for pid in model_predictions:
57 | pred_prob = model_predictions[pid]
58 | pred = np.argmax(pred_prob, axis=0)
59 | label = original_labels[pid]
60 |
61 | # print(np.unique(label))
62 | # print(update_list)
63 | # if the label contains no class that needs updating, skip the correction process below
64 | if set(np.unique(label)).isdisjoint(set(update_list)):
65 | new_label_dict[pid] = label
66 | continue
67 |
68 |
69 | # if the prediction is confident
70 | # confident = np.max(pred_prob, axis=0) > conf_threshold
71 |
72 | # if the prediction is confident
73 | # code support different threshold for foreground and background,
74 | # during the experiment, we always set them to be the same for simplicity
75 | confident = (np.max(pred_prob[1:], axis=0) > conf_threshold) |(pred_prob[0] > conf_threshold_bg)
76 |
77 | # before update: only class that need correction will be replaced
78 | belong_to_correction_class = label==0
79 | for c in need_label_correction_dict:
80 | if need_label_correction_dict[c]:
81 | belong_to_correction_class |= (label==c)
82 |
83 | # after update: only pixels that will be flipped to the allowed classes will be updated
84 | after_belong = pred==0
85 | for c in need_label_correction_dict:
86 | if need_label_correction_dict[c]:
87 | after_belong |= (pred==c)
88 |
89 | # combine all three masks together
90 | replace_flag = confident & belong_to_correction_class & after_belong
91 |
92 |
93 | # the class constraint
94 | if class_constraint:
95 | unique_class = np.unique(label)
96 | # print(unique_class)
97 | # indx = torch.zeros((h, w), dtype=torch.long)
98 | class_constraint_indx = (pred==0)
99 | for element in unique_class:
100 | class_constraint_indx = class_constraint_indx | (pred == element)
101 |
102 |
103 | replace_flag = replace_flag & (class_constraint_indx != 0)
104 |
105 |
106 | # replace with the new label
107 | next_label = np.where(replace_flag, pred, label).astype("int32")
108 |
109 | # logic 255:
110 | # - rule# 1: if label[i,j] != 0, and pred[i,j] = 0, then next_label[i,j] = 255
111 | # - rule# 2: if label[i,j] = 255 and pred[i,j] != 0 and confident, then next_label[i,j] = pred[i,j]
112 | # rule 2 is already enforced above, don't need additional code
113 | if logic_255:
114 | rule_1_flag = (label != 0) & (pred == 0)
115 | next_label = np.where(rule_1_flag, np.ones(next_label.shape)*255, next_label).astype("int32")
116 |
117 | new_label_dict[pid] = next_label
118 |
119 | return new_label_dict
120 |
121 |
122 |
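# Usage sketch (our addition, not part of the original file): correct the noisy
# labels of class 1 wherever the network is confident and the constraints allow.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    labels = {"case0": rng.randint(0, 2, size=(8, 8))}                 # noisy labels in {0, 1}
    probs = rng.dirichlet(np.ones(3), size=(8, 8)).transpose(2, 0, 1)  # (C, H, W) softmax output
    preds = {"case0": probs}
    corrected = merge_labels_with_skip(labels, preds, {1: True}, conf_threshold=0.5)
    print((corrected["case0"] != labels["case0"]).sum(), "pixels corrected")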
--------------------------------------------------------------------------------
/SegThor/brat/loading.py:
--------------------------------------------------------------------------------
1 | import os, torch, cv2, pickle, copy, random
2 | import numpy as np
3 | from scipy import misc
4 | from PIL import Image
5 | from torch.utils.data import Dataset, Sampler
6 | import torchvision.transforms as transforms
7 | import torchvision.transforms.functional as F
8 |
9 |
10 | def final_noise_function(mat):
11 | mode = np.random.choice(["under", "over"])
12 | iterations = np.random.choice(np.arange(2,5))
13 | return under_over_seg(mat, iterations, mode)
14 |
15 |
16 | def under_over_seg(mat, iteration=1, mode="under"):
17 | target_num = 1000
18 | mat = np.copy(mat)
19 | kernel = np.ones((3,3),np.uint8)
20 | for cls in [1,3,4,2]:
21 | binary_mat = mat==cls
22 | foreground_num = np.sum(binary_mat)
23 | if foreground_num != 0:
24 | # resize the image to match the foreground pixel number
25 | h, w = mat.shape
26 | ratio = np.sqrt(target_num/foreground_num)
27 | h_new = int(round( h * ratio))
28 | w_new = int(round( w * ratio))
29 | resized_img = cv2.resize(binary_mat.astype("uint8"), (w_new, h_new), interpolation=cv2.INTER_CUBIC) > 0
30 | # erosion or dilation
31 | if mode == "under":
32 | binary_mat_processed = cv2.erode(resized_img.astype("uint8"),kernel, iterations =iteration)
33 | elif mode == "over":
34 | binary_mat_processed = cv2.dilate(resized_img.astype("uint8"), kernel, iterations=iteration)
35 | # resize back to the original size
36 | binary_mat_processed_resized = cv2.resize(binary_mat_processed, (w, h), interpolation=cv2.INTER_CUBIC) > 0
37 | # fill in the gap
38 | if mode == "under":
39 | mat = np.where(binary_mat_processed_resized!=binary_mat, np.zeros(mat.shape), mat)
40 | elif mode == "over":
41 | mat = np.where(binary_mat_processed_resized & (mat==0), np.ones(mat.shape)*cls, mat)
42 | return mat
43 |
44 | def under_seg(mat):
45 | mat = np.copy(mat)
46 | kernel_small = np.ones((2,2),np.uint8)
47 | kernel_medium = np.ones((3,3),np.uint8)
48 | kernel_large = np.ones((5,5),np.uint8)
49 | for cls in [1,2,3,4]:
50 | binary_mat = mat==cls
51 | if cls in [1,3]:
52 | kernel_used = kernel_small
53 | iteration = 1
54 | elif cls == 2:
55 | kernel_used = kernel_large
56 | iteration = 2
57 | else:
58 | kernel_used = kernel_medium
59 | iteration = 2
60 | binary_mat_eroded = cv2.erode(binary_mat.astype("uint8"),kernel_used,iterations =iteration)
61 | mat = np.where(binary_mat_eroded!=binary_mat, np.zeros(mat.shape), mat)
62 | return mat
63 |
64 | def over_seg(mat):
65 | mat = np.copy(mat)
66 | kernel_small = np.ones((2,2),np.uint8)
67 | kernel_medium = np.ones((3,3),np.uint8)
68 | kernel_large = np.ones((5,5),np.uint8)
69 | for cls in [1,2,3,4]:
70 | if cls in [1,3]:
71 | kernel_used = kernel_small
72 | elif cls == 2:  # large kernel for class 2, matching under_seg (cls 3 is covered by the branch above)
73 | kernel_used = kernel_large
74 | else:
75 | kernel_used = kernel_medium
76 | binary_mat = mat==cls
77 | binary_mat_dilated = cv2.dilate(binary_mat.astype("uint8"), kernel_used, iterations=2)
78 | mat = np.where(binary_mat_dilated, np.ones(mat.shape)*cls, mat)
79 | return mat
80 |
81 | def wrong_seg(mat):
82 | mat_cp = np.copy(mat)
83 | channel_0 = np.random.choice([1,2])
84 | channel_1 = np.random.choice([0,2])
85 | channel_2 = np.random.choice([0,1])
86 | mat_cp[0,:,:] = mat[channel_0,:,:]
87 | mat_cp[1,:,:] = mat[channel_1,:,:]
88 | mat_cp[2,:,:] = mat[channel_2,:,:]
89 | return mat_cp
90 |
91 | def noise_seg(mat, noise_level=0.05):
92 | """
93 | P(out=0 | in=0) = 1-noise_level
94 | P(out=1234 | in=0) = noise_level/4
95 | P(out=0 | in=1234) = noise_level
96 | P(out=1234 | in=1234) = 1-noise_level
97 | """
98 | mat = np.copy(mat)
99 | fate = np.random.uniform(low=0, high=1, size=mat.shape)
100 | # deal with 0
101 | is_zero_indicator = mat == 0
102 | background_flip_to = np.random.choice([1,2,3,4], size=mat.shape)
103 | mat = np.where( (fate <= noise_level) & is_zero_indicator, background_flip_to, mat)
104 | # deal with 1,2,3,4
105 | mat = np.where( (fate <= noise_level) & (~is_zero_indicator), np.zeros(mat.shape), mat)
106 | return mat
107 |
108 | def mixed_seg(mat):
109 | fate = np.random.uniform(0,1)
110 | if fate < 0.33:
111 | return under_seg(mat)
112 | elif fate < 0.67:
113 | return over_seg(mat)
114 | else:
115 | return noise_seg(mat)
116 |
117 |
118 | NOISE_LABEL_DICT = {"under":under_seg, "over":over_seg, "wrong":wrong_seg, "noise":noise_seg,
119 | "mixed":mixed_seg, "final":final_noise_function}
120 |
121 | class StackedRandomAffine(transforms.RandomAffine):
122 | def __call__(self, imgs):
123 | """
124 | img (PIL Image): Image to be transformed.
125 | Returns:
126 | PIL Image: Affine transformed image.
127 | """
128 | ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, imgs[0].size)
129 | return [F.affine(x, *ret, resample=self.resample, fillcolor=self.fillcolor) for x in imgs]
130 |
131 |
132 | def standarize(img):
133 | return (img - img.mean()) / img.std()
134 |
135 | class BaseDataset(Dataset):
136 | def __init__(self, parameters, data_list, augmentation=False, noise_label=None, noise_level=None, cache_dir=None):
137 | self.data_list = data_list
138 | self.data_dir = parameters["data_dir"]
139 | self.img_dir = os.path.join(self.data_dir, "img")
140 | self.seg_dir = os.path.join(self.data_dir, "label")
141 |
142 | # reset seeds
143 | random.seed(parameters["seed"])
144 | torch.manual_seed(parameters["seed"])
145 | torch.cuda.manual_seed(parameters["seed"])
146 | np.random.seed(parameters["seed"])
147 |
148 | # load cached images and labels if necessary
149 | if cache_dir is None:
150 | self.cache_label = None
151 | self.cache_img = None
152 | else:
153 | with open(cache_dir, "rb") as f:
154 | self.cache_img, self.cache_label = pickle.load(f)
155 | self.cache_clean_label = copy.deepcopy(self.cache_label)
156 |
157 | # noise label functions
158 | self.noise_function = None if noise_label is None else NOISE_LABEL_DICT[noise_label]
159 | if self.noise_function is not None and noise_level is not None:
160 | noise_number = int(round(noise_level * len(self.data_list)))
161 | self.noise_index_list = np.random.permutation(np.arange(len(self.data_list)))[:noise_number]
162 | # add noise to the cached labels
163 | for i in range(len(self.data_list)):
164 | if i in self.noise_index_list:
165 | img_name = self.data_list[i]
166 | self.cache_label[img_name] = self.noise_function(self.cache_label[img_name])
167 | self.cache_noisy_label = copy.deepcopy(self.cache_label)
168 | else:
169 | self.cache_noisy_label = self.cache_clean_label
170 |
171 | # augmentation setting
172 | self.augmentation = augmentation
173 | self.augmentation_function = StackedRandomAffine(degrees=(-45, 45), translate=(0.1, 0.1), scale=(0.8, 1.5))
174 |
175 | # transformation setting
176 | transform_list = []
177 | if parameters["resize"] is not None:
178 | transform_list.append(transforms.Resize(size=(parameters["resize"], parameters["resize"]),
179 | interpolation=0))
180 | transform_list.append(transforms.ToTensor())
181 | self.transform = transforms.Compose(transform_list)
182 |
183 | def __len__(self):
184 | return len(self.data_list)
185 |
186 |
187 |
188 | class BraTSDataset(BaseDataset):
189 | def __init__(self, parameters, data_list, augmentation=False, noise_label=None):
190 | super(BraTSDataset, self).__init__(parameters, data_list, augmentation, noise_label)
191 |
192 | def __getitem__(self, index):
193 | img_name = self.data_list[index]
194 |
195 | # put up paths
196 | img_path = os.path.join(self.img_dir, img_name)
197 | seg_path = os.path.join(self.seg_dir, img_name)
198 |
199 | # load images and seg
200 | img = np.load(img_path).astype("int16")
201 | seg = np.load(seg_path).astype("int8")
202 | if self.noise_function is not None:
203 | seg = self.noise_function(seg)
204 |
205 | # convert to pil image
206 | img_channel_pils = [Image.fromarray(img[i,:,:].astype("int16")) for i in range(img.shape[0])]
207 | seg_channel_pils = [Image.fromarray(seg[i,:,:].astype("int8")) for i in range(seg.shape[0])]
208 |
209 | # augmentation
210 | if self.augmentation:
211 | aug_res = self.augmentation_function(img_channel_pils + seg_channel_pils)
212 | img_channel_pils = aug_res[:4]
213 | seg_channel_pils = aug_res[4:]
214 |
215 | # post-process
216 | img_channel_torch = [standarize(self.transform(x).float()) for x in img_channel_pils]  # self.transform is built in BaseDataset
217 | label_channel_torch = [self.transform(x) for x in seg_channel_pils]
218 | img_torch = torch.cat(img_channel_torch, dim=0)
219 | label_torch = torch.cat(label_channel_torch, dim=0)
220 | label_torch[label_torch > 0] = 1
221 |
222 | return img_torch.float(), label_torch.long(), img_name
223 |
224 | class SegTHORDataset(BaseDataset):
225 | def __init__(self, parameters, data_list, augmentation=False, noise_label=None, noise_level=None, cache_dir=None):
226 | super(SegTHORDataset, self).__init__(parameters, data_list, augmentation, noise_label, noise_level, cache_dir)
227 |
228 | def reset_labels(self, new_labels):
229 | self.cache_label = new_labels
230 |
231 | def __getitem__(self, index):
232 | img_name = self.data_list[index]
233 | # load image and the segmentation label
234 | if self.cache_img is None:
235 | img_path = os.path.join(self.img_dir, img_name)
236 | img = np.load(img_path).astype("int16")
237 | img -= img.min()
238 | else:
239 | img = self.cache_img[img_name]
240 | if self.cache_label is None:
241 | seg_path = os.path.join(self.seg_dir, img_name)
242 | seg = np.load(seg_path).astype("int8")
243 | # add noise to the label if needed
244 | if self.noise_function is not None and index in self.noise_index_list:
245 | seg = self.noise_function(seg)
246 | else:
247 | seg = self.cache_label[img_name]
248 | clean_seg = self.cache_clean_label[img_name]
249 | original_noisy_seg = self.cache_noisy_label[img_name]
250 |
251 | # convert to pil image
252 | img_pils = Image.fromarray(img)
253 | seg_pils = Image.fromarray(seg)
254 | clean_seg_pils = Image.fromarray(clean_seg)
255 | original_noisy_seg_pils = Image.fromarray(original_noisy_seg)
256 |
257 | # augmentation
258 | if self.augmentation:
259 | img_pils, seg_pils, clean_seg_pils, original_noisy_seg_pils = self.augmentation_function([img_pils, seg_pils, clean_seg_pils, original_noisy_seg_pils])
260 |
261 | # post-process
262 | img_torch = standarize(self.transform(img_pils).float())
263 | label_torch = self.transform(seg_pils)
264 | clean_label_torch = self.transform(clean_seg_pils)
265 | original_noisy_torch = self.transform(original_noisy_seg_pils)
266 | return img_torch.float(), label_torch.long(), original_noisy_torch.long(), clean_label_torch.long(), img_name
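
# Usage sketch (our addition, not part of the original file): corrupt a toy
# 4-class mask with the "final" noise model (random under-/over-segmentation).
if __name__ == "__main__":
    np.random.seed(0)
    toy = np.zeros((64, 64), dtype="uint8")
    toy[16:32, 16:32] = 1   # a square "organ" of class 1
    toy[40:56, 40:56] = 2   # and one of class 2
    noisy = final_noise_function(toy)
    print("changed pixels:", int((noisy != toy).sum()))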
--------------------------------------------------------------------------------
/SegThor/brat/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 |
3 | import torch.nn.functional as F
4 |
5 | from .unet_parts import *
6 |
7 |
8 | class UNet(nn.Module):
9 | def __init__(self, n_channels, n_classes, bilinear=True):
10 | super(UNet, self).__init__()
11 | self.n_channels = n_channels
12 | self.n_classes = n_classes
13 | self.bilinear = bilinear
14 |
15 | self.inc = DoubleConv(n_channels, 64)
16 | self.down1 = Down(64, 128)
17 | self.down2 = Down(128, 256)
18 | self.down3 = Down(256, 512)
19 | factor = 2 if bilinear else 1
20 | self.down4 = Down(512, 1024 // factor)
21 | self.up1 = Up(1024, 512 // factor, bilinear)
22 | self.up2 = Up(512, 256 // factor, bilinear)
23 | self.up3 = Up(256, 128 // factor, bilinear)
24 | self.up4 = Up(128, 64, bilinear)
25 | self.outc = OutConv(64, n_classes)
26 |
27 | def forward(self, x):
28 | x1 = self.inc(x)
29 | x2 = self.down1(x1)
30 | x3 = self.down2(x2)
31 | x4 = self.down3(x3)
32 | x5 = self.down4(x4)
33 | x = self.up1(x5, x4)
34 | x = self.up2(x, x3)
35 | x = self.up3(x, x2)
36 | x = self.up4(x, x1)
37 | logits = self.outc(x)
38 | return logits
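
# Usage sketch (our addition, not part of the original file). Because of the
# relative import above, run it as a module: python -m brat.unet_model
if __name__ == "__main__":
    net = UNet(n_channels=1, n_classes=5)   # 1 CT channel in, 5 logit maps out
    x = torch.randn(1, 1, 256, 256)
    print(net(x).shape)                     # -> torch.Size([1, 5, 256, 256])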
--------------------------------------------------------------------------------
/SegThor/brat/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 |
8 | class DoubleConv(nn.Module):
9 | """(convolution => [BN] => ReLU) * 2"""
10 |
11 | def __init__(self, in_channels, out_channels, mid_channels=None):
12 | super().__init__()
13 | if not mid_channels:
14 | mid_channels = out_channels
15 | self.double_conv = nn.Sequential(
16 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 | nn.BatchNorm2d(mid_channels),
18 | nn.ReLU(inplace=True),
19 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 | nn.BatchNorm2d(out_channels),
21 | nn.ReLU(inplace=True)
22 | )
23 |
24 | def forward(self, x):
25 | return self.double_conv(x)
26 |
27 |
28 | class Down(nn.Module):
29 | """Downscaling with maxpool then double conv"""
30 |
31 | def __init__(self, in_channels, out_channels):
32 | super().__init__()
33 | self.maxpool_conv = nn.Sequential(
34 | nn.MaxPool2d(2),
35 | DoubleConv(in_channels, out_channels)
36 | )
37 |
38 | def forward(self, x):
39 | return self.maxpool_conv(x)
40 |
41 |
42 | class Up(nn.Module):
43 | """Upscaling then double conv"""
44 |
45 | def __init__(self, in_channels, out_channels, bilinear=True):
46 | super().__init__()
47 |
48 | # if bilinear, use the normal convolutions to reduce the number of channels
49 | if bilinear:
50 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 | self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 | else:
53 | self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
54 | self.conv = DoubleConv(in_channels, out_channels)
55 |
56 |
57 | def forward(self, x1, x2):
58 | x1 = self.up(x1)
59 | # input is CHW
60 | diffY = x2.size()[2] - x1.size()[2]
61 | diffX = x2.size()[3] - x1.size()[3]
62 |
63 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 | diffY // 2, diffY - diffY // 2])
65 | # if you have padding issues, see
66 | # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 | # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 | x = torch.cat([x2, x1], dim=1)
69 | return self.conv(x)
70 |
71 |
72 | class OutConv(nn.Module):
73 | def __init__(self, in_channels, out_channels):
74 | super(OutConv, self).__init__()
75 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 |
77 | def forward(self, x):
78 | return self.conv(x)
--------------------------------------------------------------------------------
/SegThor/lib/utils/JSD_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def calc_jsd_multiscale(weight, labels1_a, pred1, pred2, pred3, threshold=0.8, Mask_label255_sign='no'):
8 |
9 | Mask_label255 = (labels1_a < 255).float()  # do not compute the area that is irrelevant (data aug); shape b,h,w
10 | weight_softmax = F.softmax(weight, dim=0)
11 |
12 | criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
13 | criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
14 | criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
15 |
16 | loss1 = criterion1(pred1 * weight_softmax[0], labels1_a) # * weight_softmax[0]
17 | loss2 = criterion2(pred2 * weight_softmax[1], labels1_a) # * weight_softmax[1]
18 | loss3 = criterion3(pred3 * weight_softmax[2], labels1_a) # * weight_softmax[2]
19 |
20 | loss = (loss1 + loss2 + loss3)
21 |
22 | probs = [F.softmax(logits, dim=1) for i, logits in enumerate([pred1, pred2, pred3])]
23 |
24 | weighted_probs = [weight_softmax[i] * prob for i, prob in enumerate(probs)] # weight_softmax[i]*
25 | mixture_label = (torch.stack(weighted_probs)).sum(axis=0)
26 | #mixture_label = torch.clamp(mixture_label, 1e-7, 1) # h,c,h,w
27 | mixture_label = torch.clamp(mixture_label, 1e-3, 1-1e-3) # h,c,h,w
28 |
29 | # add this code block for early torch version where torch.amax is not available
30 | if torch.__version__ == "1.5.0" or torch.__version__ == "1.6.0":
31 | max_probs, _ = torch.max(mixture_label*Mask_label255.unsqueeze(1), dim=-3, keepdim=True)  # take the values, not the indices
32 | max_probs, _ = torch.max(max_probs, dim=-2, keepdim=True)
33 | max_probs, _ = torch.max(max_probs, dim=-1, keepdim=True)
34 | else:
35 | max_probs = torch.amax(mixture_label*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True)
36 | mask = max_probs.ge(threshold).float()
37 |
38 |
39 | logp_mixture = mixture_label.log()
40 |
41 | log_probs = [torch.sum(F.kl_div(logp_mixture, prob, reduction='none') * mask, dim=1) for prob in probs]
42 | if Mask_label255_sign == 'yes':
43 | consistency = sum(log_probs)*Mask_label255
44 | else:
45 | consistency = sum(log_probs)
46 |
47 | return torch.mean(loss), torch.mean(consistency), consistency, mixture_label
48 |
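# Usage sketch (our addition, not part of the original file): three predictions
# at a shared resolution, a learnable per-scale weight, and (possibly noisy) labels.
if __name__ == "__main__":
    b, c, h, w = 2, 5, 64, 64
    weight = torch.ones(3)                    # per-scale weights, softmaxed inside
    labels = torch.randint(0, c, (b, h, w))   # 255 would mark ignored pixels
    preds = [torch.randn(b, c, h, w) for _ in range(3)]
    ce, cons, cons_map, mix = calc_jsd_multiscale(weight, labels, *preds)
    print(ce.item(), cons.item(), cons_map.shape, mix.shape)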
--------------------------------------------------------------------------------
/SegThor/lib/utils/iou_computation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 |
5 | def update_iou_stat(predict, gt, TP, P, T, num_classes = 21):
6 | """
7 | :param predict: the pred of each batch, should be numpy array, after take the argmax b,h,w
8 | :param gt: the gt label of the batch, should be numpy array b,h,w
9 | :param TP: True positive
10 | :param P: positive prediction
11 | :param T: True seg
12 | :param num_classes: number of classes in the dataset
13 | :return: TP, P, T
14 | """
15 | cal = gt < 255
16 |
17 | mask = (predict == gt) * cal
18 |
19 | for i in range(num_classes):
20 | P[i] += np.sum((predict == i) * cal)
21 | T[i] += np.sum((gt == i) * cal)
22 | TP[i] += np.sum((gt == i) * mask)
23 |
24 | return TP, P, T
25 |
26 |
27 | def iter_iou_stat(predict, gt, num_classes = 21):
28 | """
29 | :param predict: the pred of each batch, should be numpy array, after take the argmax b,h,w
30 | :param gt: the gt label of the batch, should be numpy array b,h,w
31 | (unlike update_iou_stat, the TP/P/T counts are computed fresh for this batch
32 | rather than accumulated into caller-provided arrays)
33 | :param num_classes: number of classes in the dataset
34 | :return: np.array([TP, P, T]) for this batch, where TP is the true-positive
35 | count, P the positive-prediction count, and T the ground-truth count per class
36 | """
37 | cal = gt < 255
38 |
39 | mask = (predict == gt) * cal
40 |
41 | TP = np.zeros(num_classes)
42 | P = np.zeros(num_classes)
43 | T = np.zeros(num_classes)
44 |
45 | for i in range(num_classes):
46 | P[i] = np.sum((predict == i) * cal)
47 | T[i] = np.sum((gt == i) * cal)
48 | TP[i] = np.sum((gt == i) * mask)
49 |
50 | return np.array([TP, P, T])
51 |
52 |
53 | def compute_iou(TP, P, T, num_classes = 21):
54 | """
55 | :param TP:
56 | :param P:
57 | :param T:
58 | :param num_classes: number of classes in the dataset
59 | :return: IoU
60 | """
61 | IoU = []
62 | for i in range(num_classes):
63 | IoU.append(TP[i] / (T[i] + P[i] - TP[i] + 1e-10))
64 | return IoU
65 |
66 |
67 | def update_fraction_batchwise(mask, gt, fraction, num_classes = 21):
68 | """
69 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
70 | :param gt: the gt label of the batch, numpy array
71 | :param fraction: fraction of pixels in the subgroup
72 | :param num_classes: number of classes in the dataset
73 | :return: updated fraction
74 | """
75 | cal = gt < 255
76 |
77 | for i in range(num_classes):
78 | fraction[i] += np.sum((mask * (gt == i) * cal))/np.sum((gt == i) * cal)
79 |
80 | return fraction
81 |
82 |
83 | def update_fraction_instancewise(mask, gt, fraction, num_classes = 21):
84 | """
85 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
86 | :param gt: the gt label of the batch, numpy array
87 | :param fraction: fraction of pixels in the subgroup
88 | :param num_classes: number of classes in the dataset
89 | :return: updated fraction
90 | """
91 | # caution: np.sum((gt == i) * cal, axis=(-2, -1)) can be zero, which makes the per-instance ratio nan
92 | cal = gt < 255
93 |
94 | for i in range(num_classes):
95 | fraction[i] += np.mean(np.sum((mask * (gt == i) * cal), axis= (-2,-1))/np.sum((gt == i) * cal, axis= (-2,-1)))
96 |
97 | return fraction
98 |
99 | def update_fraction_pixelwise(mask, gt, abs_num_and_total, num_classes = 21):
100 | """
101 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
102 | :param gt: the gt label of the batch, numpy array
103 | :param abs_num_and_total: the absolute number of pixel belong to the mask and the total num of pixels [abs_num, pixel_num]
104 | :param num_classes: number of classes in the dataset
105 | :return: updated fraction
106 | """
107 | cal = gt < 255
108 |
109 | for i in range(num_classes):
110 | abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
111 | abs_num_and_total[i][1] += np.sum((gt == i) * cal)
112 |
113 |
114 | return abs_num_and_total
115 |
116 | def iter_fraction_pixelwise(mask, gt, num_classes = 21):
117 | """
118 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
119 | :param gt: the gt label of the batch, numpy array
120 | :param num_classes: number of classes in the dataset
121 | :return: updated fraction
122 | """
123 | cal = gt < 255
124 |
125 | abs_num_and_total = np.zeros((num_classes,2))
126 |
127 | for i in range(num_classes):
128 | abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
129 | abs_num_and_total[i][1] += np.sum((gt == i) * cal)
130 |
131 |
132 | return abs_num_and_total
133 |
134 |
135 |
136 | def get_mask(gt_np, label_np, pred_np):
137 | """
138 |
139 | Args:
140 | gt_np: the GT label
141 | label_np: the CAM pseudo label
142 | pred_np: the prediction
143 |
144 | Returns: the mask of different type
145 |
146 | """
147 | wrong_mask_correct = (gt_np != label_np) & (pred_np == gt_np)
148 | wrong_mask_memorized = (gt_np != label_np) & (pred_np == label_np)
149 | wrong_mask_others = (gt_np != label_np) & (pred_np != gt_np) & (pred_np != label_np)
150 | clean_mask_correct = (gt_np == label_np) & (pred_np == gt_np)
151 | clean_mask_incorrect = (gt_np == label_np) & (pred_np != gt_np)
152 |
153 | return (wrong_mask_correct,wrong_mask_memorized,wrong_mask_others,clean_mask_correct,clean_mask_incorrect)
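
# Usage sketch (our addition, not part of the original file): accumulate TP/P/T
# over batches with update_iou_stat, then convert them to per-class IoU.
if __name__ == "__main__":
    num_classes = 5
    TP, P, T = [0] * num_classes, [0] * num_classes, [0] * num_classes
    pred = np.random.randint(0, num_classes, size=(2, 8, 8))
    gt = np.random.randint(0, num_classes, size=(2, 8, 8))
    TP, P, T = update_iou_stat(pred, gt, TP, P, T, num_classes=num_classes)
    print(["%.3f" % v for v in compute_iou(TP, P, T, num_classes=num_classes)])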
--------------------------------------------------------------------------------
/SegThor/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7.0
2 | torchvision>=0.8.1
3 | mxnet>=1.7.0.post1
4 | scipy>=1.5.1
5 | numpy>=1.19.4
6 | scikit_image>=0.17.2
7 | pydensecrf>=1.0rc3
8 | pandas>=1.0.5
9 | opencv_python>=4.3.0.36
10 | matplotlib>=3.3.0
11 | Pillow>=8.1.0
12 | tensorboardX>=2.1
13 | tqdm
15 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kangningthu/ADELE/7195bd0af39be79c533d67dd7eab7f9bfd6a4285/__init__.py
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | import torch
5 | import argparse
6 | import os
7 | import sys
8 | import cv2
9 | import time
10 |
11 | config_dict = {
12 | 'EXP_NAME': 'Experiment',
13 | 'GPUS': 2,
14 | 'TEST_GPUS': 1,
15 | 'DATA_NAME': 'VOCTrainwsegDataset',
16 | 'DATA_YEAR': 2012,
17 | 'DATA_AUG': True,
18 | 'DATA_WORKERS': 8,
19 | 'DATA_MEAN': [0.485, 0.456, 0.406],
20 | 'DATA_STD': [0.229, 0.224, 0.225],
21 | 'DATA_RANDOMCROP': 448,
22 | 'DATA_RANDOMSCALE': [0.5, 1.5],
23 | 'DATA_RANDOM_H': 10,
24 | 'DATA_RANDOM_S': 10,
25 | 'DATA_RANDOM_V': 10,
26 | 'DATA_RANDOMFLIP': 0.5,
27 | 'DATA_PSEUDO_GT': '/scratch/kl3141/seam/SEAM-master/results/aff_rw_aug',
28 | 'DATA_FEATURE_DIR':False,
29 |
30 | 'MODEL_NAME': 'deeplabv1',
31 | 'MODEL_BACKBONE': 'resnet38',
32 | 'MODEL_BACKBONE_PRETRAIN': True,
33 | 'MODEL_NUM_CLASSES': 21,
34 | 'MODEL_FREEZEBN': False,
35 |
36 | 'TRAIN_LR': 0.001,
37 | 'TRAIN_MOMENTUM': 0.9,
38 | 'TRAIN_WEIGHT_DECAY': 0.0005,
39 | 'TRAIN_BN_MOM': 0.0003,
40 | 'TRAIN_POWER': 0.9,
41 | 'TRAIN_BATCHES': 10,
42 | 'TRAIN_SHUFFLE': True,
43 | 'TRAIN_MINEPOCH': 0,
44 | 'TRAIN_ITERATION': 20000,
45 | 'TRAIN_TBLOG': True,
46 |
47 | 'TEST_MULTISCALE': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
48 | 'TEST_FLIP': True,
49 | 'TEST_CRF': True,
50 | 'TEST_BATCHES': 1,
51 |
52 | 'MODEL_BACKBONE_DILATED':False,
53 | 'MODEL_BACKBONE_MULTIGRID':False,
54 | 'MODEL_BACKBONE_DEEPBASE': False,
55 |
56 | 'scale_factor':0.7,
57 | 'scale_factor2': 1.5,
58 | 'lambda_seg': 0.5,
59 |
60 | }
61 |
62 | config_dict['ROOT_DIR'] = os.path.abspath(os.path.dirname(__file__))
63 | config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'],'model',config_dict['EXP_NAME'])
64 | config_dict['TRAIN_CKPT'] = None
65 | config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'],'log',config_dict['EXP_NAME'])
66 | sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib'))
67 |
--------------------------------------------------------------------------------
/lib/datasets/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | from __future__ import print_function, division
6 | import os
7 | import torch
8 | import pandas as pd
9 | import cv2
10 | import multiprocessing
11 | from skimage import io
12 | from PIL import Image
13 | import numpy as np
14 | from torch.utils.data import Dataset
15 | from datasets.transform import *
16 | from utils.imutils import *
17 | from utils.registry import DATASETS
18 |
19 | #@DATASETS.register_module
20 | class BaseDataset(Dataset):
21 | def __init__(self, cfg, period, transform='none'):
22 | super(BaseDataset, self).__init__()
23 | self.cfg = cfg
24 | self.period = period
25 | self.transform = transform
26 | if 'train' not in self.period:
27 | assert self.transform == 'none'
28 | self.num_categories = None
29 | self.totensor = ToTensor()
30 | self.imagenorm = ImageNorm(cfg.DATA_MEAN, cfg.DATA_STD)
31 |
32 | if self.transform != 'none':
33 | if cfg.DATA_RANDOMCROP > 0:
34 | self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP)
35 | if cfg.DATA_RANDOMSCALE != 1:
36 | self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE)
37 | if cfg.DATA_RANDOMFLIP > 0:
38 | self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP)
39 | if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0:
40 | self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V)
41 | else:
42 | self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE)
43 |
44 |
45 | def __getitem__(self, idx):
46 | sample = self.__sample_generate__(idx)
47 |
48 | if 'segmentation' in sample.keys():
49 | sample['mask'] = sample['segmentation'] < self.num_categories
50 | t = sample['segmentation'].copy()
51 | t[t >= self.num_categories] = 0
52 | sample['segmentation_onehot']=onehot(t,self.num_categories)
53 | return self.totensor(sample)
54 |
55 | def __sample_generate__(self, idx, split_idx=0):
56 | name = self.load_name(idx)
57 | image = self.load_image(idx)
58 | r,c,_ = image.shape
59 | sample = {'image': image, 'name': name, 'row': r, 'col': c}
60 |
61 | if 'test' in self.period:
62 | return self.__transform__(sample)
63 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period:
64 | segmentation = self.load_pseudo_segmentation(idx)
65 | else:
66 | segmentation = self.load_segmentation(idx)
67 | sample['segmentation'] = segmentation
68 | t = sample['segmentation'].copy()
69 | t[t >= self.num_categories] = 0
70 | sample['category'] = seg2cls(t,self.num_categories)
71 | sample['category_copypaste'] = np.zeros(sample['category'].shape)
72 |
73 | if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR:
74 | feature = self.load_feature(idx)
75 | sample['feature'] = feature
76 | return self.__transform__(sample)
77 |
78 | def __transform__(self, sample):
79 | if self.transform == 'weak':
80 | sample = self.__weak_augment__(sample)
81 | elif self.transform == 'strong':
82 | sample = self.__strong_augment__(sample)
83 | else:
84 | sample = self.imagenorm(sample)
85 | sample = self.multiscale(sample)
86 | return sample
87 |
88 | def __weak_augment__(self, sample):
89 | if self.cfg.DATA_RANDOM_H>0 or self.cfg.DATA_RANDOM_S>0 or self.cfg.DATA_RANDOM_V>0:
90 | sample = self.randomhsv(sample)
91 | if self.cfg.DATA_RANDOMFLIP > 0:
92 | sample = self.randomflip(sample)
93 | if self.cfg.DATA_RANDOMSCALE != 1:
94 | sample = self.randomscale(sample)
95 | sample = self.imagenorm(sample)
96 | if self.cfg.DATA_RANDOMCROP > 0:
97 | sample = self.randomcrop(sample)
98 | return sample
99 |
100 | def __strong_augment__(self, sample):
101 | raise NotImplementedError
102 |
103 | def __len__(self):
104 | raise NotImplementedError
105 |
106 | def load_name(self, idx):
107 | raise NotImplementedError
108 |
109 | def load_image(self, idx):
110 | raise NotImplementedError
111 |
112 | def load_segmentation(self, idx):
113 | raise NotImplementedError
114 |
115 | def load_pseudo_segmentation(self, idx):
116 | raise NotImplementedError
117 |
118 | def load_feature(self, idx):
119 | raise NotImplementedError
120 |
121 | def save_result(self, result_list, model_id):
122 | raise NotImplementedError
123 |
124 | def save_pseudo_gt(self, result_list, level=None):
125 | raise NotImplementedError
126 |
127 | def do_python_eval(self, model_id):
128 | raise NotImplementedError
129 |
--------------------------------------------------------------------------------
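BaseDataset is a template-method class: `__getitem__` and `__sample_generate__` drive the full pipeline (load, augment, normalize, one-hot encode, tensorize), while subclasses supply only the I/O hooks that raise NotImplementedError above. A hypothetical minimal subclass; the names and directory layout are illustrative, not part of this repo:

    import os
    import cv2
    import numpy as np
    from PIL import Image
    from datasets.BaseDataset import BaseDataset

    class ToyDataset(BaseDataset):
        def __init__(self, cfg, period, transform='none'):
            super(ToyDataset, self).__init__(cfg, period, transform)
            self.names = ['img_000', 'img_001']   # illustrative sample list
            self.num_categories = 21              # subclasses must set this

        def __len__(self):
            return len(self.names)

        def load_name(self, idx):
            return self.names[idx]

        def load_image(self, idx):
            img = cv2.imread(os.path.join('images', self.names[idx] + '.jpg'))
            return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # HxWx3 RGB, as the base class expects

        def load_segmentation(self, idx):
            return np.array(Image.open(os.path.join('labels', self.names[idx] + '.png')))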
/lib/datasets/BaseMultiwGTauginfoDataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import torch
4 | import pandas as pd
5 | import cv2
6 | import multiprocessing
7 | from skimage import io
8 | from PIL import Image
9 | import numpy as np
10 | from torch.utils.data import Dataset
11 | from datasets.transformmultiGTauginfo import *
12 | from utils.imutils import *
13 | from utils.registry import DATASETS
14 |
15 | #@DATASETS.register_module
16 | class BaseMultiwGTauginfoDataset(Dataset):
17 | def __init__(self, cfg, period, transform='none'):
18 | super(BaseMultiwGTauginfoDataset, self).__init__()
19 | self.cfg = cfg
20 | self.period = period
21 | self.transform = transform
22 | if 'train' not in self.period:
23 | assert self.transform == 'none'
24 | self.num_categories = None
25 | self.totensor = ToTensor()
26 | self.imagenorm = ImageNorm(cfg.DATA_MEAN, cfg.DATA_STD)
27 |
28 | if self.transform != 'none':
29 | if cfg.DATA_RANDOMCROP > 0:
30 | if self.transform == 'no':
31 | # self.randomcrop = RandomCrop(512)
32 | self.randomcrop = CenterCrop(512)
33 | else:
34 | self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP)
35 | if cfg.DATA_RANDOMSCALE != 1:
36 | self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE)
37 | if cfg.DATA_RANDOMFLIP > 0:
38 | self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP)
39 | if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0:
40 | self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V)
41 | else:
42 | self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE)
43 |
44 |
45 | def __getitem__(self, idx):
46 | sample = self.__sample_generate__(idx)
47 |
48 | if 'segmentation' in sample.keys():
49 | sample['mask'] = sample['segmentation'] < self.num_categories
50 | t = sample['segmentation'].copy()
51 | t[t >= self.num_categories] = 0
52 | sample['segmentation_onehot']=onehot(t,self.num_categories)
53 | return self.totensor(sample)
54 |
55 | def __sample_generate__(self, idx, split_idx=0):
56 | name = self.load_name(idx)
57 | image = self.load_image(idx)
58 | r,c,_ = image.shape
59 | sample = {'image': image, 'name': name, 'row': r, 'col': c}
60 |
61 | if 'test' in self.period:
62 | return self.__transform__(sample)
63 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period:
64 | segmentation = self.load_pseudo_segmentation(idx)
65 | else:
66 | segmentation = self.load_segmentation(idx)
67 | sample['segmentation'] = segmentation
68 | t = sample['segmentation'].copy()
69 | t[t >= self.num_categories] = 0
70 | sample['category'] = seg2cls(t,self.num_categories)
71 | sample['category_copypaste'] = np.zeros(sample['category'].shape)
72 |
73 | if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR:
74 | feature = self.load_feature(idx)
75 | sample['feature'] = feature
76 | return self.__transform__(sample)
77 |
78 | def __transform__(self, sample):
79 | if self.transform == 'weak':
80 | sample = self.__weak_augment__(sample)
81 | elif self.transform == 'strong':
82 | sample = self.__strong_augment__(sample)
83 | elif self.transform == 'no':
84 | sample = self.__dict_augment__(sample)
85 | else:
86 | sample = self.imagenorm(sample)
87 | sample = self.multiscale(sample)
88 | return sample
89 |
90 | def __weak_augment__(self, sample):
91 | if self.cfg.DATA_RANDOM_H>0 or self.cfg.DATA_RANDOM_S>0 or self.cfg.DATA_RANDOM_V>0:
92 | sample = self.randomhsv(sample)
93 | if self.cfg.DATA_RANDOMFLIP > 0:
94 | sample = self.randomflip(sample)
95 | if self.cfg.DATA_RANDOMSCALE != 1:
96 | sample = self.randomscale(sample)
97 | sample = self.imagenorm(sample)
98 | if self.cfg.DATA_RANDOMCROP > 0:
99 | sample = self.randomcrop(sample)
100 | return sample
101 |
102 | def __dict_augment__(self, sample):
103 | sample = self.imagenorm(sample)
104 | if self.cfg.DATA_RANDOMCROP > 0:
105 | sample = self.randomcrop(sample)
106 | return sample
107 | def __strong_augment__(self, sample):
108 | raise NotImplementedError
109 |
110 | def __len__(self):
111 | raise NotImplementedError
112 |
113 | def load_name(self, idx):
114 | raise NotImplementedError
115 |
116 | def load_image(self, idx):
117 | raise NotImplementedError
118 |
119 | def load_segmentation(self, idx):
120 | raise NotImplementedError
121 |
122 | def load_pseudo_segmentation(self, idx):
123 | raise NotImplementedError
124 |
125 | def load_feature(self, idx):
126 | raise NotImplementedError
127 |
128 | def save_result(self, result_list, model_id):
129 | raise NotImplementedError
130 |
131 | def save_pseudo_gt(self, result_list, level=None):
132 | raise NotImplementedError
133 |
134 | def do_python_eval(self, model_id):
135 | raise NotImplementedError
136 |
--------------------------------------------------------------------------------
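The addition relative to BaseDataset is the 'no' transform mode: the sample is normalized and then cropped with the fixed CenterCrop(512), so the same index yields the same pixels on every pass. That determinism is what allows predictions stored in one epoch to be compared against samples drawn in a later one. A usage sketch, assuming `cfg` is a populated config object and using VOCEvalDataset (a subclass, defined below):

    # Sketch: transform='no' => ImageNorm + CenterCrop(512), no randomness.
    from datasets.VOCEvalDataset import VOCEvalDataset

    eval_set = VOCEvalDataset(cfg, period='train', transform='no')
    sample = eval_set[0]      # identical pixels on every epoch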
/lib/datasets/VOCEvalDataset.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Evaluation dataset used for label correction
3 | # ----------------------------------------
4 |
5 | from __future__ import print_function, division
6 | import os
7 | import torch
8 | import pandas as pd
9 | import cv2
10 | import multiprocessing
11 | from skimage import io
12 | from PIL import Image
13 | import numpy as np
14 | from torch.utils.data import Dataset
15 | from datasets.transformmultiGT import *
16 | from utils.imutils import *
17 | from utils.registry import DATASETS
18 | from datasets.BaseMultiwGTauginfoDataset import BaseMultiwGTauginfoDataset
19 | import torch.nn.functional as F
20 | from utils.iou_computation import update_iou_stat, compute_iou
21 |
22 |
23 | @DATASETS.register_module
24 | class VOCEvalDataset(BaseMultiwGTauginfoDataset):
25 | def __init__(self, cfg, period, transform='none'):
26 | super(VOCEvalDataset, self).__init__(cfg, period, transform)
27 | self.dataset_name = 'VOC%d'%cfg.DATA_YEAR
28 | self.root_dir = os.path.join(cfg.ROOT_DIR,'data','VOCdevkit')
29 | self.dataset_dir = os.path.join(self.root_dir,self.dataset_name)
30 | self.rst_dir = os.path.join(self.root_dir,'results',self.dataset_name,'Segmentation')
31 | self.eval_dir = os.path.join(self.root_dir,'eval_result',self.dataset_name,'Segmentation')
32 | self.img_dir = os.path.join(self.dataset_dir, 'JPEGImages')
33 | # print(self.img_dir)
34 | self.ann_dir = os.path.join(self.dataset_dir, 'Annotations')
35 | self.seg_dir = os.path.join(self.dataset_dir, 'SegmentationClass')
36 | self.seg_dir_gt = os.path.join(self.dataset_dir, 'SegmentationClassAug')
37 | self.set_dir = os.path.join(self.dataset_dir, 'ImageSets', 'Segmentation')
38 | if cfg.DATA_PSEUDO_GT:
39 | self.pseudo_gt_dir = cfg.DATA_PSEUDO_GT
40 | # self.pseudo_gt_dir_2 = cfg.DATA_PSEUDO_GT_2
41 | # self.pseudo_gt_dir_3 = cfg.DATA_PSEUDO_GT_3
42 | else:
43 | self.pseudo_gt_dir = os.path.join(self.root_dir,'pseudo_gt',self.dataset_name,'Segmentation')
44 |
45 | file_name = None
46 | if cfg.DATA_AUG and 'train' in self.period:
47 | file_name = self.set_dir+'/'+period+'aug.txt'
48 | else:
49 | file_name = self.set_dir+'/'+period+'.txt'
50 | df = pd.read_csv(file_name, names=['filename'])
51 | self.name_list = df['filename'].values
52 | # print(self.name_list[1])
53 | if self.dataset_name == 'VOC2012':
54 | self.categories = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow',
55 | 'diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor']
56 | self.coco2voc = [[0],[5],[2],[16],[9],[44],[6],[3],[17],[62],
57 | [21],[67],[18],[19],[4],[1],[64],[20],[63],[7],[72]]
58 |
59 | self.num_categories = len(self.categories)+1
60 | self.cmap = self.__colormap(len(self.categories)+1)
61 |
62 | # to record the previous prediction
63 | self.prev_pred_dict = {}
64 |
65 | self.ori_indx_list =[]
66 |
67 | def __len__(self):
68 | return len(self.name_list)
69 |
70 |
71 | def __getitem__(self, idx):
72 | sample = self.__sample_generate__(idx)
73 | if 'segmentation' in sample.keys():
74 | sample['mask'] = sample['segmentation'] < self.num_categories
75 | t = sample['segmentation'].copy()
76 | t[t >= self.num_categories] = 0
77 | sample['segmentation_onehot']=onehot(t,self.num_categories)
78 | return self.totensor(sample)
79 |
80 | def __sample_generate__(self, idx, split_idx=0):
81 | name = self.load_name(idx)
82 | image = self.load_image(idx)
83 | r,c,_ = image.shape
84 | sample = {'image': image, 'name': name, 'row': r, 'col': c, 'batch_idx':idx }
85 |
86 | if 'test' in self.period:
87 | return self.__transform__(sample)
88 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period:
89 | segmentation, seg_gt = self.load_pseudo_segmentation(idx)
90 | else:
91 | segmentation = self.load_segmentation(idx)
92 | seg_gt = segmentation  # without pseudo labels the loaded annotation is the ground truth, so 'segmentationgt' below is always defined
93 | sample['segmentation'] = segmentation
94 | t = sample['segmentation'].copy()
95 | t[t >= self.num_categories] = 0
96 | sample['category'] = seg2cls(t,self.num_categories)
97 | sample['category_copypaste'] = np.zeros(sample['category'].shape)
98 |
99 | # if there is a previous prediction stored for this image
100 | if idx in self.prev_pred_dict.keys():
101 | # interpolate to the image resolution; self.prev_pred_dict[idx] has shape 1,c,h,w
102 | if torch.is_tensor(self.prev_pred_dict[idx]):
103 | # prev_pred = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='nearest')
104 | prev_pred = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='bilinear',align_corners=True,
105 | recompute_scale_factor=False)
106 | else:
107 | # prev_pred = F.interpolate(torch.tensor(self.prev_pred_dict[idx]), size=(r, c), mode='nearest')
108 | prev_pred = F.interpolate(torch.tensor(self.prev_pred_dict[idx]), size=(r, c), mode='bilinear',align_corners=True,
109 | recompute_scale_factor=False)
110 | sample['prev_prediction'] = prev_pred #1,c,h,w
111 |
112 | # the small scale case
113 | # sample['segmentation2'] = segmentation2
114 | # sample['segmentation3'] = segmentation3
115 |
116 |
117 | sample['segmentationgt'] = seg_gt
118 |
119 | if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR:
120 | feature = self.load_feature(idx)
121 | sample['feature'] = feature
122 | return self.__transform__(sample)
123 |
124 |
125 | def load_name(self, idx):
126 | name = self.name_list[idx]
127 | return name
128 |
129 | def load_image(self, idx):
130 | name = self.name_list[idx]
131 | img_file = self.img_dir + '/' + name + '.jpg'
132 | image = cv2.imread(img_file)
133 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
134 | return image_rgb
135 |
136 | def load_segmentation(self, idx):
137 | name = self.name_list[idx]
138 | seg_file = self.seg_dir + '/' + name + '.png'
139 | segmentation = np.array(Image.open(seg_file))
140 | return segmentation
141 |
142 | def load_pseudo_segmentation(self, idx):
143 | name = self.name_list[idx]
144 | seg_file = self.pseudo_gt_dir + '/' + name + '.png'
145 |
146 | segmentation1 = Image.open(seg_file)
147 | width, height = segmentation1.size
148 |
149 | segmentation1 = np.array(segmentation1)
150 |
151 | seg_gt_file = self.seg_dir_gt + '/' + name + '.png'
152 | seg_gt = np.array(Image.open(seg_gt_file).resize((width, height), Image.NEAREST))  # nearest-neighbour keeps label ids valid
153 |
154 | return segmentation1, seg_gt
155 |
156 | def __colormap(self, N):
157 | """Get the map from label index to color
158 |
159 | Args:
160 | N: number of classes
161 |
162 | return: an Nx3 matrix
163 |
164 | """
165 | cmap = np.zeros((N, 3), dtype = np.uint8)
166 |
167 | def uint82bin(n, count=8):
168 | """returns the binary of integer n, count refers to amount of bits"""
169 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
170 |
171 | for i in range(N):
172 | r = 0
173 | g = 0
174 | b = 0
175 | idx = i
176 | for j in range(7):
177 | str_id = uint82bin(idx)
178 | r = r ^ ( np.uint8(str_id[-1]) << (7-j))
179 | g = g ^ ( np.uint8(str_id[-2]) << (7-j))
180 | b = b ^ ( np.uint8(str_id[-3]) << (7-j))
181 | idx = idx >> 3
182 | cmap[i, 0] = r
183 | cmap[i, 1] = g
184 | cmap[i, 2] = b
185 | return cmap
186 |
187 | def load_ranked_namelist(self):
188 | df = self.read_rank_result()
189 | self.name_list = df['filename'].values
190 |
191 | def label2colormap(self, label):
192 | m = label.astype(np.uint8)
193 | r,c = m.shape
194 | cmap = np.zeros((r,c,3), dtype=np.uint8)
195 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3
196 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2
197 | cmap[:,:,2] = (m&4)<<5
198 | cmap[m==255] = [255,255,255]
199 | return cmap
200 |
201 | def save_result(self, result_list, model_id):
202 | """Save test results
203 |
204 | Args:
205 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...]
206 |
207 | """
208 | folder_path = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period))
209 | if not os.path.exists(folder_path):
210 | os.makedirs(folder_path)
211 |
212 | for sample in result_list:
213 | file_path = os.path.join(folder_path, '%s.png'%sample['name'])
214 | cv2.imwrite(file_path, sample['predict'])
215 |
216 | def save_pseudo_gt(self, result_list, folder_path=None):
217 | """Save pseudo gt
218 |
219 | Args:
220 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...]
221 |
222 | """
223 | i = 1
224 | folder_path = self.pseudo_gt_dir if folder_path is None else folder_path
225 | if not os.path.exists(folder_path):
226 | os.makedirs(folder_path)
227 | for sample in result_list:
228 | file_path = os.path.join(folder_path, '%s.png'%(sample['name']))
229 | cv2.imwrite(file_path, sample['predict'])
230 | i+=1
231 |
232 | def do_matlab_eval(self, model_id):
233 | import subprocess
234 | path = os.path.join(self.root_dir, 'VOCcode')
235 | eval_filename = os.path.join(self.eval_dir,'%s_result.mat'%model_id)
236 | cmd = 'cd {} && '.format(path)
237 | cmd += 'matlab -nodisplay -nodesktop '
238 | cmd += '-r "dbstop if error; VOCinit; '
239 | cmd += 'VOCevalseg(VOCopts,\'{:s}\');'.format(model_id)
240 | cmd += '[accuracies,avacc,conf,rawcounts] = VOCevalseg(VOCopts,\'{:s}\'); '.format(model_id)
241 | cmd += 'save(\'{:s}\',\'accuracies\',\'avacc\',\'conf\',\'rawcounts\'); '.format(eval_filename)
242 | cmd += 'quit;"'
243 |
244 | print('start subprocess for matlab evaluation...')
245 | print(cmd)
246 | subprocess.call(cmd, shell=True)
247 |
248 | def do_python_eval(self, model_id):
249 | predict_folder = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period))
250 | gt_folder = self.seg_dir
251 | TP = []
252 | P = []
253 | T = []
254 | for i in range(self.num_categories):
255 | TP.append(multiprocessing.Value('i', 0, lock=True))
256 | P.append(multiprocessing.Value('i', 0, lock=True))
257 | T.append(multiprocessing.Value('i', 0, lock=True))
258 |
259 | def compare(start,step,TP,P,T):
260 | for idx in range(start,len(self.name_list),step):
261 | #print('%d/%d'%(idx,len(self.name_list)))
262 | name = self.name_list[idx]
263 | predict_file = os.path.join(predict_folder,'%s.png'%name)
264 | gt_file = os.path.join(gt_folder,'%s.png'%name)
265 | predict = np.array(Image.open(predict_file)) #cv2.imread(predict_file)
266 | gt = np.array(Image.open(gt_file))
267 | cal = gt<255
268 | mask = (predict==gt) * cal
269 |
270 | for i in range(self.num_categories):
271 | P[i].acquire()
272 | P[i].value += np.sum((predict==i)*cal)
273 | P[i].release()
274 | T[i].acquire()
275 | T[i].value += np.sum((gt==i)*cal)
276 | T[i].release()
277 | TP[i].acquire()
278 | TP[i].value += np.sum((gt==i)*mask)
279 | TP[i].release()
280 | p_list = []
281 | for i in range(8):
282 | p = multiprocessing.Process(target=compare, args=(i,8,TP,P,T))
283 | p.start()
284 | p_list.append(p)
285 | for p in p_list:
286 | p.join()
287 | IoU = []
288 | for i in range(self.num_categories):
289 | IoU.append(TP[i].value/(T[i].value+P[i].value-TP[i].value+1e-10))
290 | loglist = {}
291 | for i in range(self.num_categories):
292 | if i == 0:
293 | print('%11s:%7.3f%%'%('background',IoU[i]*100),end='\t')
294 | loglist['background'] = IoU[i] * 100
295 | else:
296 | if i%2 != 1:
297 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100),end='\t')
298 | else:
299 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100))
300 | loglist[self.categories[i-1]] = IoU[i] * 100
301 |
302 | miou = np.mean(np.array(IoU))
303 | print('\n======================================================')
304 | print('%11s:%7.3f%%'%('mIoU',miou*100))
305 | loglist['mIoU'] = miou * 100
306 | return loglist
307 |
308 | def do_python_eval_batch_pseudo_one_process(self):
309 | self.seg_dir_gt = os.path.join(self.dataset_dir, 'SegmentationClassAug')
310 | gt_folder = self.seg_dir_gt
311 | TP_gt_epoch = [0] * 21
312 | P_gt_epoch = [0] * 21
313 | T_gt_epoch = [0] * 21
314 | loglist = {}
315 | for idx in range(len(self.name_list)):
316 | # print(idx)
317 | name = self.name_list[idx]
318 | gt_file = os.path.join(gt_folder, '%s.png' % name)
319 | gt = np.array(Image.open(gt_file))
320 | r, c = gt.shape
321 | # print(r)
322 | predict_tensor = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='bilinear', align_corners=True,
323 | recompute_scale_factor=False) # 1,c,h,w
324 | predict = predict_tensor[0].cpu().numpy() # c,h,w
325 | predict = np.argmax(predict, axis=0) # h,w
326 |
327 | TP_gt_epoch, P_gt_epoch, T_gt_epoch = update_iou_stat(predict, gt, TP_gt_epoch,
328 | P_gt_epoch, T_gt_epoch)
329 | IoU_gt_epoch = compute_iou(TP_gt_epoch, P_gt_epoch, T_gt_epoch)
330 | for indx, class_name in enumerate(
331 | ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
332 | 'cow',
333 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
334 | 'tvmonitor']):
335 | loglist[class_name] = IoU_gt_epoch[indx]
336 | mIoU_clean_epoch = np.mean(np.array(IoU_gt_epoch))
337 | loglist['mIoU'] = mIoU_clean_epoch
338 | return loglist
339 |
340 |
341 | def __coco2voc(self, m):
342 | r,c = m.shape
343 | result = np.zeros((r,c),dtype=np.uint8)
344 | for i in range(0,21):
345 | for j in self.coco2voc[i]:
346 | result[m==j] = i
347 | return result
348 |
349 |
350 |
--------------------------------------------------------------------------------
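What distinguishes this Eval dataset is the `prev_pred_dict` hook: a loop stores each image's class-probability map under its index, `__sample_generate__` re-attaches it (bilinearly resized to the image) as 'prev_prediction', and `do_python_eval_batch_pseudo_one_process` scores the stored maps against SegmentationClassAug. A rough sketch of such a loop, assuming a `model` producing 1xCxhxw logits and a populated `cfg` (both outside this file):

    import torch
    from datasets.VOCEvalDataset import VOCEvalDataset

    eval_set = VOCEvalDataset(cfg, period='train', transform='no')
    model.eval()
    with torch.no_grad():
        for idx in range(len(eval_set)):
            sample = eval_set[idx]
            image = sample['image'].unsqueeze(0).cuda()            # 1x3xHxW
            logits = model(image)                                  # 1xCxhxw (assumed)
            eval_set.prev_pred_dict[idx] = torch.softmax(logits, dim=1).cpu()

    log = eval_set.do_python_eval_batch_pseudo_one_process()
    print(log['mIoU'])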
/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .VOCDataset import *
2 | from .VOCEvalDataset import *
3 | from .VOCTrainwsegDataset import *
4 |
--------------------------------------------------------------------------------
/lib/datasets/generateData.py:
--------------------------------------------------------------------------------
1 | from utils.registry import DATASETS
2 |
3 | def generate_dataset(cfg, **kwargs):
4 | dataset = DATASETS.get(cfg.DATA_NAME)(cfg, **kwargs)
5 | return dataset
6 |
--------------------------------------------------------------------------------
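`generate_dataset` is a thin indirection over the registry: importing the `datasets` package executes the `@DATASETS.register_module` decorators, after which `cfg.DATA_NAME` selects a class by name and the remaining kwargs are forwarded to its constructor. Sketch, assuming `cfg` is the config object and DATA_NAME names a registered class:

    import datasets                              # side effect: registers the VOC*Dataset classes
    from datasets.generateData import generate_dataset

    cfg.DATA_NAME = 'VOCEvalDataset'
    train_set = generate_dataset(cfg, period='train', transform='weak')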
/lib/datasets/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import re
3 | import functools
4 |
5 | class AverageMeter(object):
6 | """Computes and stores the average and current value"""
7 | def __init__(self):
8 | self.initialized = False
9 | self.val = None
10 | self.avg = None
11 | self.sum = None
12 | self.count = None
13 |
14 | def initialize(self, val, weight):
15 | self.val = val
16 | self.avg = val
17 | self.sum = val * weight
18 | self.count = weight
19 | self.initialized = True
20 |
21 | def update(self, val, weight=1):
22 | if not self.initialized:
23 | self.initialize(val, weight)
24 | else:
25 | self.add(val, weight)
26 |
27 | def add(self, val, weight):
28 | self.val = val
29 | self.sum += val * weight
30 | self.count += weight
31 | self.avg = self.sum / self.count
32 |
33 | def value(self):
34 | return self.val
35 |
36 | def average(self):
37 | return self.avg
38 |
39 |
40 | def unique(ar, return_index=False, return_inverse=False, return_counts=False):
41 | ar = np.asanyarray(ar).flatten()
42 |
43 | optional_indices = return_index or return_inverse
44 | optional_returns = optional_indices or return_counts
45 |
46 | if ar.size == 0:
47 | if not optional_returns:
48 | ret = ar
49 | else:
50 | ret = (ar,)
51 | if return_index:
52 | ret += (np.empty(0, bool),)  # np.bool was removed from NumPy; the builtin bool is the replacement
53 | if return_inverse:
54 | ret += (np.empty(0, bool),)
55 | if return_counts:
56 | ret += (np.empty(0, np.intp),)
57 | return ret
58 | if optional_indices:
59 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort')
60 | aux = ar[perm]
61 | else:
62 | ar.sort()
63 | aux = ar
64 | flag = np.concatenate(([True], aux[1:] != aux[:-1]))
65 |
66 | if not optional_returns:
67 | ret = aux[flag]
68 | else:
69 | ret = (aux[flag],)
70 | if return_index:
71 | ret += (perm[flag],)
72 | if return_inverse:
73 | iflag = np.cumsum(flag) - 1
74 | inv_idx = np.empty(ar.shape, dtype=np.intp)
75 | inv_idx[perm] = iflag
76 | ret += (inv_idx,)
77 | if return_counts:
78 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],))
79 | ret += (np.diff(idx),)
80 | return ret
81 |
82 |
83 | def colorEncode(labelmap, colors, mode='BGR'):
84 | labelmap = labelmap.astype('int')
85 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3),
86 | dtype=np.uint8)
87 | for label in unique(labelmap):
88 | if label < 0:
89 | continue
90 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \
91 | np.tile(colors[label],
92 | (labelmap.shape[0], labelmap.shape[1], 1))
93 |
94 | if mode == 'BGR':
95 | return labelmap_rgb[:, :, ::-1]
96 | else:
97 | return labelmap_rgb
98 |
99 |
100 | def accuracy(preds, label):
101 | valid = (label >= 0)
102 | acc_sum = (valid * (preds == label)).sum()
103 | valid_sum = valid.sum()
104 | acc = float(acc_sum) / (valid_sum + 1e-10)
105 | return acc, valid_sum
106 |
107 |
108 | def intersectionAndUnion(imPred, imLab, numClass):
109 | imPred = np.asarray(imPred).copy()
110 | imLab = np.asarray(imLab).copy()
111 |
112 | imPred += 1
113 | imLab += 1
114 | # Remove classes from unlabeled pixels in gt image.
115 | # We should not penalize detections in unlabeled portions of the image.
116 | imPred = imPred * (imLab > 0)
117 |
118 | # Compute area intersection:
119 | intersection = imPred * (imPred == imLab)
120 | (area_intersection, _) = np.histogram(
121 | intersection, bins=numClass, range=(1, numClass))
122 |
123 | # Compute area union:
124 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass))
125 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass))
126 | area_union = area_pred + area_lab - area_intersection
127 |
128 | return (area_intersection, area_union)
129 |
130 |
131 | class NotSupportedCliException(Exception):
132 | pass
133 |
134 |
135 | def process_range(xpu, inp):
136 | start, end = map(int, inp)
137 | if start > end:
138 | end, start = start, end
139 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1))
140 |
141 |
142 | REGEX = [
143 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]),
144 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]),
145 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'),
146 | functools.partial(process_range, 'gpu')),
147 | (re.compile(r'^(\d+)-(\d+)$'),
148 | functools.partial(process_range, 'gpu')),
149 | ]
150 |
151 |
152 | def parse_devices(input_devices):
153 |
154 | """Parse user's devices input str to standard format.
155 | e.g. [gpu0, gpu1, ...]
156 |
157 | """
158 | ret = []
159 | for d in input_devices.split(','):
160 | for regex, func in REGEX:
161 | m = regex.match(d.lower().strip())
162 | if m:
163 | tmp = func(m.groups())
164 | # prevent duplicate
165 | for x in tmp:
166 | if x not in ret:
167 | ret.append(x)
168 | break
169 | else:
170 | raise NotSupportedCliException(
171 | 'Can not recognize device: "%s"' % d)
172 | return ret
173 |
--------------------------------------------------------------------------------
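`parse_devices` accepts bare indices, `gpuN` tokens, and ranges in either spelling, deduplicating while preserving order; unrecognized tokens raise NotSupportedCliException. AverageMeter keeps a running weighted mean. A few worked calls against the definitions above:

    # All three normalize to ['gpu0', 'gpu1', 'gpu2', 'gpu3'].
    parse_devices('0-3')
    parse_devices('gpu0-gpu3')
    parse_devices('gpu0,1,2,gpu3')
    # parse_devices('tpu0')  -> raises NotSupportedCliException

    meter = AverageMeter()
    meter.update(0.5)
    meter.update(1.0)
    assert meter.average() == 0.75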
/lib/datasets/transform.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | import random
5 | import PIL
6 | from PIL import Image, ImageOps, ImageFilter
7 |
8 | class RandomCrop(object):
9 | """Crop randomly the image in a sample.
10 |
11 | Args:
12 | output_size (tuple or int): Desired output size. If int, square crop
13 | is made.
14 | """
15 |
16 | def __init__(self, output_size):
17 | assert isinstance(output_size, (int, tuple))
18 | if isinstance(output_size, int):
19 | self.output_size = (output_size, output_size)
20 | else:
21 | assert len(output_size) == 2
22 | self.output_size = output_size
23 |
24 | def __call__(self, sample):
25 |
26 | h, w = sample['image'].shape[:2]
27 | ch = min(h, self.output_size[0])
28 | cw = min(w, self.output_size[1])
29 |
30 | h_space = h - self.output_size[0]
31 | w_space = w - self.output_size[1]
32 |
33 | if w_space > 0:
34 | cont_left = 0
35 | img_left = random.randrange(w_space+1)
36 | else:
37 | cont_left = random.randrange(-w_space+1)
38 | img_left = 0
39 |
40 | if h_space > 0:
41 | cont_top = 0
42 | img_top = random.randrange(h_space+1)
43 | else:
44 | cont_top = random.randrange(-h_space+1)
45 | img_top = 0
46 |
47 | key_list = sample.keys()
48 | for key in key_list:
49 | if 'image' in key:
50 | img = sample[key]
51 | img_crop = np.zeros((self.output_size[0], self.output_size[1], 3), np.float32)
52 | img_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
53 | img[img_top:img_top+ch, img_left:img_left+cw]
54 | #img_crop = img[img_top:img_top+ch, img_left:img_left+cw]
55 | sample[key] = img_crop
56 | elif 'segmentation' == key:
57 | seg = sample[key]
58 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255
59 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
60 | seg[img_top:img_top+ch, img_left:img_left+cw]
61 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw]
62 | sample[key] = seg_crop
63 | elif 'segmentation_pseudo' in key:
64 | seg_pseudo = sample[key]
65 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255
66 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
67 | seg_pseudo[img_top:img_top+ch, img_left:img_left+cw]
68 | #seg_crop = seg_pseudo[img_top:img_top+ch, img_left:img_left+cw]
69 | sample[key] = seg_crop
70 | return sample
71 |
72 | class RandomHSV(object):
73 | """Generate randomly the image in hsv space."""
74 | def __init__(self, h_r, s_r, v_r):
75 | self.h_r = h_r
76 | self.s_r = s_r
77 | self.v_r = v_r
78 |
79 | def __call__(self, sample):
80 | image = sample['image']
81 | hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
82 | h = hsv[:,:,0].astype(np.int32)
83 | s = hsv[:,:,1].astype(np.int32)
84 | v = hsv[:,:,2].astype(np.int32)
85 | delta_h = random.randint(-self.h_r,self.h_r)
86 | delta_s = random.randint(-self.s_r,self.s_r)
87 | delta_v = random.randint(-self.v_r,self.v_r)
88 | h = (h + delta_h)%180
89 | s = s + delta_s
90 | s[s>255] = 255
91 | s[s<0] = 0
92 | v = v + delta_v
93 | v[v>255] = 255
94 | v[v<0] = 0
95 | hsv = np.stack([h,s,v], axis=-1).astype(np.uint8)
96 | image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.uint8)
97 | sample['image'] = image
98 | return sample
99 |
100 | class RandomFlip(object):
101 | """Randomly flip image"""
102 | def __init__(self, threshold):
103 | self.flip_t = threshold
104 | def __call__(self, sample):
105 | if random.random() < self.flip_t:
106 | key_list = sample.keys()
107 | for key in key_list:
108 | if 'image' in key:
109 | img = sample[key]
110 | img = np.flip(img, axis=1)
111 | sample[key] = img
112 | elif 'segmentation' == key:
113 | seg = sample[key]
114 | seg = np.flip(seg, axis=1)
115 | sample[key] = seg
116 | elif 'segmentation_pseudo' in key:
117 | seg_pseudo = sample[key]
118 | seg_pseudo = np.flip(seg_pseudo, axis=1)
119 | sample[key] = seg_pseudo
120 | return sample
121 |
122 | class RandomScale(object):
123 | """Randomly scale image"""
124 | def __init__(self, scale_r, is_continuous=False):
125 | self.scale_r = scale_r
126 | self.seg_interpolation = cv2.INTER_CUBIC if is_continuous else cv2.INTER_NEAREST
127 |
128 | def __call__(self, sample):
129 | row, col, _ = sample['image'].shape
130 | rand_scale = random.random()*(self.scale_r[1] - self.scale_r[0]) + self.scale_r[0]
131 | key_list = sample.keys()
132 | for key in key_list:
133 | if 'image' in key:
134 | img = sample[key]
135 | img = cv2.resize(img, None, fx=rand_scale, fy=rand_scale, interpolation=cv2.INTER_CUBIC)
136 | sample[key] = img
137 | elif 'segmentation' == key:
138 | seg = sample[key]
139 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
140 | sample[key] = seg
141 | elif 'segmentation_pseudo' in key:
142 | seg_pseudo = sample[key]
143 | seg_pseudo = cv2.resize(seg_pseudo, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
144 | sample[key] = seg_pseudo
145 | return sample
146 |
147 | class ImageNorm(object):
148 | """Randomly scale image"""
149 | def __init__(self, mean=None, std=None):
150 | self.mean = mean
151 | self.std = std
152 | def __call__(self, sample):
153 | key_list = sample.keys()
154 | for key in key_list:
155 | if 'image' in key:
156 | image = sample[key].astype(np.float32)
157 | if self.mean is not None and self.std is not None:
158 | image[...,0] = (image[...,0]/255 - self.mean[0]) / self.std[0]
159 | image[...,1] = (image[...,1]/255 - self.mean[1]) / self.std[1]
160 | image[...,2] = (image[...,2]/255 - self.mean[2]) / self.std[2]
161 | else:
162 | image /= 255.0
163 | sample[key] = image
164 | return sample
165 |
166 | class Multiscale(object):
167 | def __init__(self, rate_list):
168 | self.rate_list = rate_list
169 |
170 | def __call__(self, sample):
171 | image = sample['image']
172 | row, col, _ = image.shape
173 | # rescaled copies are written directly into the sample dict below
174 | for rate in self.rate_list:
175 | rescaled_image = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
176 | sample['image_%f'%rate] = rescaled_image
177 | return sample
178 |
179 |
180 | class ToTensor(object):
181 | """Convert ndarrays in sample to Tensors."""
182 |
183 | def __call__(self, sample):
184 | key_list = sample.keys()
185 | for key in key_list:
186 | if 'image' in key:
187 | image = sample[key].astype(np.float32)
188 | # swap color axis because
189 | # numpy image: H x W x C
190 | # torch image: C X H X W
191 | image = image.transpose((2,0,1))
192 | sample[key] = torch.from_numpy(image)
193 | #sample[key] = torch.from_numpy(image.astype(np.float32)/128.0-1.0)
194 | elif 'edge' == key:
195 | edge = sample['edge']
196 | sample['edge'] = torch.from_numpy(edge.astype(np.float32))
197 | sample['edge'] = torch.unsqueeze(sample['edge'],0)
198 | elif 'segmentation' == key:
199 | segmentation = sample['segmentation']
200 | sample['segmentation'] = torch.from_numpy(segmentation.astype(np.int64))  # np.long was removed from NumPy; int64 matches torch.long
201 | elif 'segmentation_pseudo' in key:
202 | segmentation_pseudo = sample[key]
203 | sample[key] = torch.from_numpy(segmentation_pseudo.astype(np.float32))
204 | elif 'segmentation_onehot' == key:
205 | onehot = sample['segmentation_onehot'].transpose((2,0,1))
206 | sample['segmentation_onehot'] = torch.from_numpy(onehot.astype(np.float32))
207 | elif 'category' in key:
208 | sample[key] = torch.from_numpy(sample[key].astype(np.float32))
209 | elif 'mask' == key:
210 | mask = sample['mask']
211 | sample['mask'] = torch.from_numpy(mask.astype(np.float32))
212 | elif 'feature' == key:
213 | feature = sample['feature']
214 | sample['feature'] = torch.from_numpy(feature.astype(np.float32))
215 | return sample
216 |
217 |
--------------------------------------------------------------------------------
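All of these transforms share one contract: take the sample dict, rewrite the keys they understand ('image*', 'segmentation*', ...), and return the dict, so a pipeline is plain composition, which is exactly what BaseDataset.__weak_augment__ does. A self-contained sketch with illustrative parameter values:

    import numpy as np
    from datasets.transform import (RandomHSV, RandomFlip, RandomScale,
                                    ImageNorm, RandomCrop, ToTensor)

    sample = {
        'image': np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8),
        'segmentation': np.zeros((300, 400), dtype=np.uint8),
    }
    for t in [RandomHSV(10, 30, 30),
              RandomFlip(0.5),
              RandomScale((0.5, 1.5)),
              ImageNorm((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
              RandomCrop(321),
              ToTensor()]:
        sample = t(sample)

    print(sample['image'].shape)      # torch.Size([3, 321, 321])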
/lib/datasets/transformmultiGT.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # heavily borrowed from Yude Wang, modified by Kangning Liu
3 | # ----------------------------------------
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | import random
9 | import PIL
10 | from PIL import Image, ImageOps, ImageFilter
11 |
12 | class RandomCrop(object):
13 | """Crop randomly the image in a sample.
14 |
15 | Args:
16 | output_size (tuple or int): Desired output size. If int, square crop
17 | is made.
18 | """
19 |
20 | def __init__(self, output_size):
21 | assert isinstance(output_size, (int, tuple))
22 | if isinstance(output_size, int):
23 | self.output_size = (output_size, output_size)
24 | else:
25 | assert len(output_size) == 2
26 | self.output_size = output_size
27 |
28 | def __call__(self, sample):
29 |
30 | h, w = sample['image'].shape[:2]
31 | ch = min(h, self.output_size[0])
32 | cw = min(w, self.output_size[1])
33 |
34 | h_space = h - self.output_size[0]
35 | w_space = w - self.output_size[1]
36 |
37 | if w_space > 0:
38 | cont_left = 0
39 | img_left = random.randrange(w_space+1)
40 | else:
41 | cont_left = random.randrange(-w_space+1)
42 | img_left = 0
43 |
44 | if h_space > 0:
45 | cont_top = 0
46 | img_top = random.randrange(h_space+1)
47 | else:
48 | cont_top = random.randrange(-h_space+1)
49 | img_top = 0
50 |
51 | key_list = sample.keys()
52 | for key in key_list:
53 | if 'image' in key:
54 | img = sample[key]
55 | img_crop = np.zeros((self.output_size[0], self.output_size[1], 3), np.float32)
56 | img_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
57 | img[img_top:img_top+ch, img_left:img_left+cw]
58 | #img_crop = img[img_top:img_top+ch, img_left:img_left+cw]
59 | sample[key] = img_crop
60 | elif 'segmentation' == key:
61 | seg = sample[key]
62 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255
63 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
64 | seg[img_top:img_top+ch, img_left:img_left+cw]
65 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw]
66 | sample[key] = seg_crop
67 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key:
68 | seg = sample[key]
69 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255
70 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
71 | seg[img_top:img_top+ch, img_left:img_left+cw]
72 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw]
73 | sample[key] = seg_crop
74 | elif 'segmentation_pseudo' in key:
75 | seg_pseudo = sample[key]
76 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255
77 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \
78 | seg_pseudo[img_top:img_top+ch, img_left:img_left+cw]
79 | #seg_crop = seg_pseudo[img_top:img_top+ch, img_left:img_left+cw]
80 | sample[key] = seg_crop
81 | return sample
82 |
83 | class RandomHSV(object):
84 | """Generate randomly the image in hsv space."""
85 | def __init__(self, h_r, s_r, v_r):
86 | self.h_r = h_r
87 | self.s_r = s_r
88 | self.v_r = v_r
89 |
90 | def __call__(self, sample):
91 | image = sample['image']
92 | hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
93 | h = hsv[:,:,0].astype(np.int32)
94 | s = hsv[:,:,1].astype(np.int32)
95 | v = hsv[:,:,2].astype(np.int32)
96 | delta_h = random.randint(-self.h_r,self.h_r)
97 | delta_s = random.randint(-self.s_r,self.s_r)
98 | delta_v = random.randint(-self.v_r,self.v_r)
99 | h = (h + delta_h)%180
100 | s = s + delta_s
101 | s[s>255] = 255
102 | s[s<0] = 0
103 | v = v + delta_v
104 | v[v>255] = 255
105 | v[v<0] = 0
106 | hsv = np.stack([h,s,v], axis=-1).astype(np.uint8)
107 | image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.uint8)
108 | sample['image'] = image
109 | return sample
110 |
111 | class RandomFlip(object):
112 | """Randomly flip image"""
113 | def __init__(self, threshold):
114 | self.flip_t = threshold
115 | def __call__(self, sample):
116 | if random.random() < self.flip_t:
117 | key_list = sample.keys()
118 | for key in key_list:
119 | if 'image' in key:
120 | img = sample[key]
121 | img = np.flip(img, axis=1)
122 | sample[key] = img
123 | elif 'segmentation' == key:
124 | seg = sample[key]
125 | seg = np.flip(seg, axis=1)
126 | sample[key] = seg
127 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key:
128 | seg = sample[key]
129 | seg = np.flip(seg, axis=1)
130 | sample[key] = seg
131 | elif 'segmentation_pseudo' in key:
132 | seg_pseudo = sample[key]
133 | seg_pseudo = np.flip(seg_pseudo, axis=1)
134 | sample[key] = seg_pseudo
135 | return sample
136 |
137 | class RandomScale(object):
138 | """Randomly scale image"""
139 | def __init__(self, scale_r, is_continuous=False):
140 | self.scale_r = scale_r
141 | self.seg_interpolation = cv2.INTER_CUBIC if is_continuous else cv2.INTER_NEAREST
142 |
143 | def __call__(self, sample):
144 | row, col, _ = sample['image'].shape
145 | rand_scale = random.random()*(self.scale_r[1] - self.scale_r[0]) + self.scale_r[0]
146 | key_list = sample.keys()
147 | for key in key_list:
148 | if 'image' in key:
149 | img = sample[key]
150 | img = cv2.resize(img, None, fx=rand_scale, fy=rand_scale, interpolation=cv2.INTER_CUBIC)
151 | sample[key] = img
152 | elif 'segmentation' == key:
153 | seg = sample[key]
154 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
155 | sample[key] = seg
156 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key:
157 | seg = sample[key]
158 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
159 | sample[key] = seg
160 | elif 'segmentation_pseudo' in key:
161 | seg_pseudo = sample[key]
162 | seg_pseudo = cv2.resize(seg_pseudo, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation)
163 | sample[key] = seg_pseudo
164 | return sample
165 |
166 | class ImageNorm(object):
167 | """Randomly scale image"""
168 | def __init__(self, mean=None, std=None):
169 | self.mean = mean
170 | self.std = std
171 | def __call__(self, sample):
172 | key_list = sample.keys()
173 | for key in key_list:
174 | if 'image' in key:
175 | image = sample[key].astype(np.float32)
176 | if self.mean is not None and self.std is not None:
177 | image[...,0] = (image[...,0]/255 - self.mean[0]) / self.std[0]
178 | image[...,1] = (image[...,1]/255 - self.mean[1]) / self.std[1]
179 | image[...,2] = (image[...,2]/255 - self.mean[2]) / self.std[2]
180 | else:
181 | image /= 255.0
182 | sample[key] = image
183 | return sample
184 |
185 | class Multiscale(object):
186 | def __init__(self, rate_list):
187 | self.rate_list = rate_list
188 |
189 | def __call__(self, sample):
190 | image = sample['image']
191 | row, col, _ = image.shape
192 | # rescaled copies are written directly into the sample dict below
193 | for rate in self.rate_list:
194 | rescaled_image = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC)
195 | sample['image_%f'%rate] = rescaled_image
196 | return sample
197 |
198 |
199 | class ToTensor(object):
200 | """Convert ndarrays in sample to Tensors."""
201 |
202 | def __call__(self, sample):
203 | key_list = sample.keys()
204 | for key in key_list:
205 | if 'image' in key:
206 | image = sample[key].astype(np.float32)
207 | # swap color axis because
208 | # numpy image: H x W x C
209 | # torch image: C X H X W
210 | image = image.transpose((2,0,1))
211 | sample[key] = torch.from_numpy(image)
212 | #sample[key] = torch.from_numpy(image.astype(np.float32)/128.0-1.0)
213 | elif 'edge' == key:
214 | edge = sample['edge']
215 | sample['edge'] = torch.from_numpy(edge.astype(np.float32))
216 | sample['edge'] = torch.unsqueeze(sample['edge'],0)
217 | elif 'segmentation' == key:
218 | segmentation = sample['segmentation']
219 | sample['segmentation'] = torch.from_numpy(segmentation.astype(np.int64))  # np.long was removed from NumPy; int64 matches torch.long
220 |
221 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key:
222 | # segmentation = sample['segmentation2']
223 | # sample['segmentation2'] = torch.from_numpy(segmentation.astype(np.long))
224 | segmentation = sample[key]
225 | sample[key] = torch.from_numpy(segmentation.astype(np.int64))
226 |
227 | elif 'segmentation_pseudo' in key:
228 | segmentation_pseudo = sample[key]
229 | sample[key] = torch.from_numpy(segmentation_pseudo.astype(np.float32))
230 | elif 'segmentation_onehot' == key:
231 | onehot = sample['segmentation_onehot'].transpose((2,0,1))
232 | sample['segmentation_onehot'] = torch.from_numpy(onehot.astype(np.float32))
233 | elif 'category' in key:
234 | sample[key] = torch.from_numpy(sample[key].astype(np.float32))
235 | elif 'mask' == key:
236 | mask = sample['mask']
237 | sample['mask'] = torch.from_numpy(mask.astype(np.float32))
238 | elif 'feature' == key:
239 | feature = sample['feature']
240 | sample['feature'] = torch.from_numpy(feature.astype(np.float32))
241 | return sample
242 |
243 |
--------------------------------------------------------------------------------
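The only difference from transform.py is that the geometric ops and ToTensor also recognize 'segmentation2'/'segmentation3'/'segmentationgt', so every label variant stays pixel-aligned with the image under the same random draw. A small check against the classes above:

    import numpy as np
    from datasets.transformmultiGT import RandomFlip

    sample = {
        'image': np.random.rand(8, 8, 3).astype(np.float32),
        'segmentation': np.arange(64, dtype=np.float32).reshape(8, 8),
    }
    sample['segmentationgt'] = sample['segmentation'].copy()

    sample = RandomFlip(1.0)(sample)   # threshold 1.0: always flips
    assert (sample['segmentation'] == sample['segmentationgt']).all()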
/lib/net/__init__.py:
--------------------------------------------------------------------------------
1 | from .deeplabv1_wo_interp import *
--------------------------------------------------------------------------------
/lib/net/backbone/__init__.py:
--------------------------------------------------------------------------------
1 | from .builder import build_backbone
2 | from .resnet38d import *
3 | from .resnet import *
4 | from .xception import *
5 |
6 | __all__ = ['build_backbone']
7 |
--------------------------------------------------------------------------------
/lib/net/backbone/builder.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | from utils.registry import BACKBONES
6 |
7 | def build_backbone(backbone_name, pretrained=True, **kwargs):
8 | net = BACKBONES.get(backbone_name)(pretrained=pretrained, **kwargs)
9 | return net
10 |
--------------------------------------------------------------------------------
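`build_backbone` resolves a name registered through `@BACKBONES.register_module` (the resnet*/xception factories in this package) and forwards the remaining kwargs to that factory. Sketch, with kwargs matching the ResNet signature in resnet.py below:

    import torch.nn as nn
    from net.backbone import build_backbone

    net = build_backbone('resnet101', pretrained=False,
                         dilated=True, multi_grid=False, norm_layer=nn.BatchNorm2d)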
/lib/net/backbone/resnet.py:
--------------------------------------------------------------------------------
1 | """Dilated ResNet"""
2 | import math
3 | import torch
4 | import torch.utils.model_zoo as model_zoo
5 | import torch.nn as nn
6 | from net.sync_batchnorm import SynchronizedBatchNorm2d
7 | from utils.registry import BACKBONES
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'BasicBlock', 'Bottleneck']
11 | bn_mom = 0.1
12 | model_urls = {
13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
15 | 'resnet50': '~/.cache/torch/checkpoints/resnet50s-a75c83cf.pth',
16 | 'resnet101': '/home/wangyude/.cache/torch/checkpoints/resnet101s-03a0f310.pth',
17 | 'resnet152': '~/.cache/torch/checkpoints/resnet152s-36670e8b.pth',
18 | #'resnet50': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet50s-a75c83cf.zip',
19 | #'resnet101': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet101s-03a0f310.zip',
20 | #'resnet152': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet152s-36670e8b.zip'
21 | }
22 | mean = (0.485, 0.456, 0.406)
23 | std = (0.229, 0.224, 0.225)
24 |
25 | def conv3x3(in_planes, out_planes, stride=1):
26 | "3x3 convolution with padding"
27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
28 | padding=1, bias=False)
29 |
30 |
31 | class BasicBlock(nn.Module):
32 | """ResNet BasicBlock
33 | """
34 | expansion = 1
35 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1,
36 | norm_layer=None):
37 | super(BasicBlock, self).__init__()
38 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
39 | padding=dilation, dilation=dilation, bias=False)
40 | self.bn1 = norm_layer(planes, momentum=bn_mom, affine=True)
41 | self.relu = nn.ReLU(inplace=True)
42 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
43 | padding=previous_dilation, dilation=previous_dilation, bias=False)
44 | self.bn2 = norm_layer(planes, momentum=bn_mom, affine=True)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | """ResNet Bottleneck
69 | """
70 | # pylint: disable=unused-argument
71 | expansion = 4
72 | def __init__(self, inplanes, planes, stride=1, dilation=1,
73 | downsample=None, previous_dilation=1, norm_layer=None):
74 | super(Bottleneck, self).__init__()
75 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
76 | self.bn1 = norm_layer(planes, momentum=bn_mom, affine=True)
77 | self.conv2 = nn.Conv2d(
78 | planes, planes, kernel_size=3, stride=stride,
79 | padding=dilation, dilation=dilation, bias=False)
80 | self.bn2 = norm_layer(planes, momentum=bn_mom, affine=True)
81 | self.conv3 = nn.Conv2d(
82 | planes, planes * 4, kernel_size=1, bias=False)
83 | self.bn3 = norm_layer(planes * 4, momentum=bn_mom, affine=True)
84 | self.relu = nn.ReLU(inplace=True)
85 | self.downsample = downsample
86 | self.dilation = dilation
87 | self.stride = stride
88 |
89 | def _sum_each(self, x, y):
90 | assert(len(x) == len(y))
91 | z = []
92 | for i in range(len(x)):
93 | z.append(x[i]+y[i])
94 | return z
95 |
96 | def forward(self, x):
97 | residual = x
98 |
99 | out = self.conv1(x)
100 | out = self.bn1(out)
101 | out = self.relu(out)
102 |
103 | out = self.conv2(out)
104 | out = self.bn2(out)
105 | out = self.relu(out)
106 |
107 | out = self.conv3(out)
108 | out = self.bn3(out)
109 |
110 | if self.downsample is not None:
111 | residual = self.downsample(x)
112 |
113 | out += residual
114 | out = self.relu(out)
115 |
116 | return out
117 |
118 |
119 | class ResNet(nn.Module):
120 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5.
121 |
122 | Parameters
123 | ----------
124 | block : Block
125 | Class for the residual block. Options are BasicBlockV1, BottleneckV1.
126 | layers : list of int
127 | Numbers of layers in each block
128 | classes : int, default 1000
129 | Number of classification classes.
130 | dilated : bool, default False
131 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
132 | typically used in Semantic Segmentation.
133 | norm_layer : object
134 | Normalization layer used in the backbone network (default: :class:`torch.nn.BatchNorm2d`;
135 | use SynchronizedBatchNorm2d for synchronized cross-GPU batch normalization).
136 |
137 | Reference:
138 |
139 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
140 |
141 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
142 | """
143 | # pylint: disable=unused-variable
144 | def __init__(self, block, layers, dilated=True, multi_grid=False,
145 | deep_base=True, norm_layer=nn.BatchNorm2d):
146 | self.inplanes = 128 if deep_base else 64
147 | super(ResNet, self).__init__()
148 | if deep_base:
149 | self.conv1 = nn.Sequential(
150 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
151 | norm_layer(64, momentum=bn_mom, affine=True),
152 | nn.ReLU(inplace=True),
153 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
154 | norm_layer(64, momentum=bn_mom, affine=True),
155 | nn.ReLU(inplace=True),
156 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
157 | )
158 | else:
159 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
160 | bias=False)
161 | self.bn1 = norm_layer(self.inplanes, momentum=bn_mom, affine=True)
162 | self.relu = nn.ReLU(inplace=True)
163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
164 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer)
165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer)
166 | if dilated:
167 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
168 | dilation=2, norm_layer=norm_layer)
169 | if multi_grid:
170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
171 | dilation=4, norm_layer=norm_layer,
172 | multi_grid=True)
173 | else:
174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
175 | dilation=4, norm_layer=norm_layer)
176 | else:
177 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
178 | norm_layer=norm_layer)
179 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
180 | norm_layer=norm_layer)
181 | self.OUTPUT_DIM = 2048
182 | self.MIDDLE_DIM = 256
183 |
184 | for m in self.modules():
185 | if isinstance(m, nn.Conv2d):
186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
187 | m.weight.data.normal_(0, math.sqrt(2. / n))
188 | elif isinstance(m, norm_layer):
189 | m.weight.data.fill_(1)
190 | m.bias.data.zero_()
191 |
192 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False):
193 | downsample = None
194 | if stride != 1 or self.inplanes != planes * block.expansion:
195 | downsample = nn.Sequential(
196 | nn.Conv2d(self.inplanes, planes * block.expansion,
197 | kernel_size=1, stride=stride, bias=False),
198 | norm_layer(planes * block.expansion, momentum=bn_mom, affine=True),
199 | )
200 |
201 | layers = []
202 | #multi_dilations = [4, 8, 16]
203 | multi_dilations = [3, 4, 5]
204 | if multi_grid:
205 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilations[0],
206 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
207 | elif dilation == 1 or dilation == 2:
208 | layers.append(block(self.inplanes, planes, stride, dilation=1,
209 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
210 | elif dilation == 4:
211 | layers.append(block(self.inplanes, planes, stride, dilation=2,
212 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer))
213 | else:
214 | raise RuntimeError("=> unknown dilation size: {}".format(dilation))
215 |
216 | self.inplanes = planes * block.expansion
217 | for i in range(1, blocks):
218 | if multi_grid:
219 | layers.append(block(self.inplanes, planes, dilation=multi_dilations[i],
220 | previous_dilation=dilation, norm_layer=norm_layer))
221 | else:
222 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation,
223 | norm_layer=norm_layer))
224 |
225 | return nn.Sequential(*layers)
226 |
227 | def forward(self, x):
228 | x = self.conv1(x)
229 | x = self.bn1(x)
230 | x = self.relu(x)
231 | x = self.maxpool(x)
232 |
233 | l1 = self.layer1(x)
234 | l2 = self.layer2(l1)
235 | l3 = self.layer3(l2)
236 | l4 = self.layer4(l3)
237 | return [l1, l2, l3, l4]
238 |
239 | @BACKBONES.register_module
240 | def resnet18(pretrained=False, **kwargs):
241 | """Constructs a ResNet-18 model.
242 |
243 | Args:
244 | pretrained (bool): If True, returns a model pre-trained on ImageNet
245 | """
246 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
247 | if pretrained:
248 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
249 | return model
250 |
251 |
252 | @BACKBONES.register_module
253 | def resnet34(pretrained=False, **kwargs):
254 | """Constructs a ResNet-34 model.
255 |
256 | Args:
257 | pretrained (bool): If True, returns a model pre-trained on ImageNet
258 | """
259 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
260 | if pretrained:
261 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
262 | return model
263 |
264 |
265 | @BACKBONES.register_module
266 | def resnet50(pretrained=False, **kwargs):
267 | """Constructs a ResNet-50 model.
268 |
269 | Args:
270 | pretrained (bool): If True, returns a model pre-trained on ImageNet
271 | """
272 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
273 | if pretrained:
274 | old_dict = model_zoo.load_url(model_urls['resnet50'])
275 | model_dict = model.state_dict()
276 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)}
277 | model_dict.update(old_dict)
278 | model.load_state_dict(model_dict)
279 | print('%s loaded.'%model_urls['resnet50'])
280 | return model
281 |
282 |
283 | @BACKBONES.register_module
284 | def resnet101(pretrained=False, **kwargs):
285 | """Constructs a ResNet-101 model.
286 |
287 | Args:
288 | pretrained (bool): If True, returns a model pre-trained on ImageNet
289 | """
290 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
291 | if pretrained:
292 | old_dict = torch.load(model_urls['resnet101'])
293 | model_dict = model.state_dict()
294 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)}
295 | model_dict.update(old_dict)
296 | model.load_state_dict(model_dict)
297 | print('%s loaded.'%model_urls['resnet101'])
298 | return model
299 |
300 |
301 | @BACKBONES.register_module
302 | def resnet152(pretrained=False, **kwargs):
303 | """Constructs a ResNet-152 model.
304 |
305 | Args:
306 | pretrained (bool): If True, returns a model pre-trained on ImageNet
307 | """
308 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
309 | if pretrained:
310 | old_dict = model_zoo.load_url(model_urls['resnet152'])
311 | model_dict = model.state_dict()
312 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)}
313 | model_dict.update(old_dict)
314 | model.load_state_dict(model_dict)
315 | print('%s loaded.'%model_urls['resnet152'])
316 | return model
317 |
--------------------------------------------------------------------------------
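With dilated=True, layer3 and layer4 keep stride 1 and dilate their convolutions instead, so the deepest feature map sits at 1/8 of the input resolution rather than 1/32. A quick shape check against the classes above (random weights, 512x512 input assumed):

    import torch
    import torch.nn as nn
    from net.backbone.resnet import ResNet, Bottleneck

    net = ResNet(Bottleneck, [3, 4, 6, 3], dilated=True, deep_base=False,
                 norm_layer=nn.BatchNorm2d)
    net.eval()
    with torch.no_grad():
        l1, l2, l3, l4 = net(torch.randn(1, 3, 512, 512))
    print(l4.shape)   # torch.Size([1, 2048, 64, 64]) -> stride 8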
/lib/net/backbone/resnet38d.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from utils.registry import BACKBONES
6 |
7 | model_url='/scratch/kl3141/seam/SEAM-master/model_weight/ilsvrc-cls_rna-a1_cls1000_ep-0001.params'
8 | bn_mom = 0.0003
9 |
10 | class ResBlock(nn.Module):
11 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1, norm_layer=nn.BatchNorm2d):
12 | super(ResBlock, self).__init__()
13 | self.norm_layer = norm_layer
14 |
15 | self.same_shape = (in_channels == out_channels and stride == 1)
16 |
17 | if first_dilation == None: first_dilation = dilation
18 |
19 | self.bn_branch2a = self.norm_layer(in_channels, momentum=bn_mom, affine=True)
20 |
21 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride,
22 | padding=first_dilation, dilation=first_dilation, bias=False)
23 |
24 | self.bn_branch2b1 = self.norm_layer(mid_channels, momentum=bn_mom, affine=True)
25 |
26 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False)
27 |
28 | if not self.same_shape:
29 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
30 |
31 | def forward(self, x, get_x_bn_relu=False):
32 |
33 | branch2 = self.bn_branch2a(x)
34 | branch2 = F.relu(branch2)
35 |
36 | x_bn_relu = branch2
37 |
38 | if not self.same_shape:
39 | branch1 = self.conv_branch1(branch2)
40 | else:
41 | branch1 = x
42 |
43 | branch2 = self.conv_branch2a(branch2)
44 | branch2 = self.bn_branch2b1(branch2)
45 | branch2 = F.relu(branch2)
46 | branch2 = self.conv_branch2b1(branch2)
47 |
48 | x = branch1 + branch2
49 |
50 | if get_x_bn_relu:
51 | return x, x_bn_relu
52 |
53 | return x
54 |
55 |     def __call__(self, x, get_x_bn_relu=False):  # redundant: nn.Module.__call__ already forwards these args to forward()
56 | return self.forward(x, get_x_bn_relu=get_x_bn_relu)
57 |
58 | class ResBlock_bot(nn.Module):
59 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0., norm_layer=nn.BatchNorm2d):
60 | super(ResBlock_bot, self).__init__()
61 | self.norm_layer = norm_layer
62 |
63 | self.same_shape = (in_channels == out_channels and stride == 1)
64 |
65 | self.bn_branch2a = self.norm_layer(in_channels, momentum=bn_mom, affine=True)
66 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False)
67 |
68 | self.bn_branch2b1 = self.norm_layer(out_channels//4, momentum=bn_mom, affine=True)
69 | self.dropout_2b1 = torch.nn.Dropout2d(dropout)
70 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False)
71 |
72 | self.bn_branch2b2 = self.norm_layer(out_channels//2, momentum=bn_mom, affine=True)
73 | self.dropout_2b2 = torch.nn.Dropout2d(dropout)
74 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False)
75 |
76 | if not self.same_shape:
77 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False)
78 |
79 | def forward(self, x, get_x_bn_relu=False):
80 |
81 | branch2 = self.bn_branch2a(x)
82 | branch2 = F.relu(branch2)
83 | x_bn_relu = branch2
84 |
85 | branch1 = self.conv_branch1(branch2)
86 |
87 | branch2 = self.conv_branch2a(branch2)
88 |
89 | branch2 = self.bn_branch2b1(branch2)
90 | branch2 = F.relu(branch2)
91 | branch2 = self.dropout_2b1(branch2)
92 | branch2 = self.conv_branch2b1(branch2)
93 |
94 | branch2 = self.bn_branch2b2(branch2)
95 | branch2 = F.relu(branch2)
96 | branch2 = self.dropout_2b2(branch2)
97 | branch2 = self.conv_branch2b2(branch2)
98 |
99 | x = branch1 + branch2
100 |
101 | if get_x_bn_relu:
102 | return x, x_bn_relu
103 |
104 | return x
105 |
106 | def __call__(self, x, get_x_bn_relu=False):
107 | return self.forward(x, get_x_bn_relu=get_x_bn_relu)
108 |
109 | class Normalize():
110 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)):
111 |
112 | self.mean = mean
113 | self.std = std
114 |
115 | def __call__(self, img):
116 | imgarr = np.asarray(img)
117 | proc_img = np.empty_like(imgarr, np.float32)
118 |
119 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0]
120 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1]
121 | proc_img[..., 2] = (imgarr[..., 2] / 255. - self.mean[2]) / self.std[2]
122 |
123 | return proc_img
124 |
125 | class Net(nn.Module):
126 | def __init__(self, norm_layer=nn.BatchNorm2d):
127 | super(Net, self).__init__()
128 | self.norm_layer = norm_layer
129 |
130 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False)
131 |
132 | self.b2 = ResBlock(64, 128, 128, stride=2, norm_layer=self.norm_layer)
133 | self.b2_1 = ResBlock(128, 128, 128, norm_layer=self.norm_layer)
134 | self.b2_2 = ResBlock(128, 128, 128, norm_layer=self.norm_layer)
135 |
136 | self.b3 = ResBlock(128, 256, 256, stride=2, norm_layer=self.norm_layer)
137 | self.b3_1 = ResBlock(256, 256, 256, norm_layer=self.norm_layer)
138 | self.b3_2 = ResBlock(256, 256, 256, norm_layer=self.norm_layer)
139 |
140 | self.b4 = ResBlock(256, 512, 512, stride=2, norm_layer=self.norm_layer)
141 | self.b4_1 = ResBlock(512, 512, 512, norm_layer=self.norm_layer)
142 | self.b4_2 = ResBlock(512, 512, 512, norm_layer=self.norm_layer)
143 | self.b4_3 = ResBlock(512, 512, 512, norm_layer=self.norm_layer)
144 | self.b4_4 = ResBlock(512, 512, 512, norm_layer=self.norm_layer)
145 | self.b4_5 = ResBlock(512, 512, 512, norm_layer=self.norm_layer)
146 |
147 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2, norm_layer=self.norm_layer)
148 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2, norm_layer=self.norm_layer)
149 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2, norm_layer=self.norm_layer)
150 |
151 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3, norm_layer=self.norm_layer)
152 |
153 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5, norm_layer=self.norm_layer)
154 |
155 | self.bn7 = self.norm_layer(4096, momentum=bn_mom, affine=True)
156 |
157 | self.not_training = [self.conv1a]
158 |
159 | self.normalize = Normalize()
160 | self.OUTPUT_DIM = 4096
161 |
162 | def forward(self, x):
163 |
164 | x = self.conv1a(x)
165 |
166 | x = self.b2(x)
167 | x = self.b2_1(x)
168 | x = self.b2_2(x)
169 |
170 | x = self.b3(x)
171 | x = self.b3_1(x)
172 | x = self.b3_2(x)
173 |
174 | x = self.b4(x)
175 | x = self.b4_1(x)
176 | x = self.b4_2(x)
177 | x = self.b4_3(x)
178 | x = self.b4_4(x)
179 | x = self.b4_5(x)
180 |
181 | x, conv4 = self.b5(x, get_x_bn_relu=True)
182 | x = self.b5_1(x)
183 | x = self.b5_2(x)
184 |
185 | x, conv5 = self.b6(x, get_x_bn_relu=True)
186 |
187 | x = self.b7(x)
188 | conv6 = F.relu(self.bn7(x))
189 |
190 | return [conv4, conv5, conv6]
191 |
192 | def train(self, mode=True):
193 |
194 | super().train(mode)
195 |
196 | for layer in self.not_training:
197 |
198 | if isinstance(layer, torch.nn.Conv2d):
199 | layer.weight.requires_grad = False
200 |
201 | elif isinstance(layer, torch.nn.Module):
202 | for c in layer.children():
203 | c.weight.requires_grad = False
204 | if c.bias is not None:
205 | c.bias.requires_grad = False
206 |
207 | for layer in self.modules():
208 |
209 | if isinstance(layer, self.norm_layer):
210 | layer.eval()
211 | layer.bias.requires_grad = False
212 | layer.weight.requires_grad = False
213 |
214 |         return self
215 |
216 | def convert_mxnet_to_torch(filename):
217 | import mxnet
218 |
219 | save_dict = mxnet.nd.load(filename)
220 |
221 | renamed_dict = dict()
222 |
223 | bn_param_mx_pt = {'beta': 'bias', 'gamma': 'weight', 'mean': 'running_mean', 'var': 'running_var'}
224 |
225 | for k, v in save_dict.items():
226 |
227 | v = torch.from_numpy(v.asnumpy())
228 | toks = k.split('_')
229 |
230 | if 'conv1a' in toks[0]:
231 | renamed_dict['conv1a.weight'] = v
232 |
233 | elif 'linear1000' in toks[0]:
234 | pass
235 |
236 | elif 'branch' in toks[1]:
237 |
238 | pt_name = []
239 |
240 | if toks[0][-1] != 'a':
241 | pt_name.append('b' + toks[0][-3] + '_' + toks[0][-1])
242 | else:
243 | pt_name.append('b' + toks[0][-2])
244 |
245 | if 'res' in toks[0]:
246 | layer_type = 'conv'
247 | last_name = 'weight'
248 |
249 | else: # 'bn' in toks[0]:
250 | layer_type = 'bn'
251 | last_name = bn_param_mx_pt[toks[-1]]
252 |
253 | pt_name.append(layer_type + '_' + toks[1])
254 |
255 | pt_name.append(last_name)
256 |
257 | torch_name = '.'.join(pt_name)
258 | renamed_dict[torch_name] = v
259 |
260 | else:
261 | last_name = bn_param_mx_pt[toks[-1]]
262 | renamed_dict['bn7.' + last_name] = v
263 |
264 | return renamed_dict
265 |
266 | @BACKBONES.register_module
267 | def resnet38(pretrained=False, norm_layer=nn.BatchNorm2d, **kwargs):
268 | model = Net(norm_layer)
269 | if pretrained:
270 | weight_dict = convert_mxnet_to_torch(model_url)
271 | model.load_state_dict(weight_dict,strict=False)
272 | return model
273 |
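274 | # --- Editor's usage sketch (not part of the original file). Two worked
275 | # examples of the MXNet-to-PyTorch renaming performed above, assuming typical
276 | # MXNet parameter names:
277 | #   'res3b1_branch2a_weight' -> 'b3_1.conv_branch2a.weight'
278 | #   'bn3b1_branch2a_beta'    -> 'b3_1.bn_branch2a.bias'
279 | # And a from-scratch smoke test (no checkpoint needed): conv4/conv5 are the
280 | # BN+ReLU'd inputs of b5/b6, so a 224x224 image (overall stride 8 after b4)
281 | # should yield shapes [1,512,28,28], [1,1024,28,28] and [1,4096,28,28].
282 | if __name__ == '__main__':
283 |     net = Net()
284 |     net.eval()
285 |     with torch.no_grad():
286 |         conv4, conv5, conv6 = net(torch.randn(1, 3, 224, 224))
287 |     print(conv4.shape, conv5.shape, conv6.shape)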
--------------------------------------------------------------------------------
/lib/net/backbone/xception.py:
--------------------------------------------------------------------------------
1 | """
2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
3 | @author: tstandley
4 | Adapted by cadene
5 | Creates an Xception Model as defined in:
6 | Francois Chollet
7 | Xception: Deep Learning with Depthwise Separable Convolutions
8 | https://arxiv.org/pdf/1610.02357.pdf
9 | These weights were ported from the Keras implementation. They achieve the following performance on the ImageNet validation set:
10 | Loss: 0.9173  Prec@1: 78.892  Prec@5: 94.292
11 | REMEMBER to set your image size to 3x299x299 for both test and validation
12 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
13 | std=[0.5, 0.5, 0.5])
14 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
15 | """
16 | import math
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import torch.utils.model_zoo as model_zoo
21 | from torch.nn import init
22 | from net.sync_batchnorm import SynchronizedBatchNorm2d
23 | from utils.registry import BACKBONES
24 |
25 | bn_mom = 0.1
26 | __all__ = ['xception']
27 |
28 | model_urls = {
29 | 'xception': '/home/wangyude/.cache/torch/checkpoints/xception_pytorch_imagenet.pth'#'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth'
30 | }
31 |
32 | class SeparableConv2d(nn.Module):
33 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,activate_first=True,inplace=True,norm_layer=nn.BatchNorm2d):
34 | super(SeparableConv2d,self).__init__()
35 | self.norm_layer = norm_layer
36 | self.relu0 = nn.ReLU(inplace=inplace)
37 | self.depthwise = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
38 | self.bn1 = self.norm_layer(in_channels, momentum=bn_mom)
39 | self.relu1 = nn.ReLU(inplace=True)
40 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
41 | self.bn2 = self.norm_layer(out_channels, momentum=bn_mom)
42 | self.relu2 = nn.ReLU(inplace=True)
43 | self.activate_first = activate_first
44 | def forward(self,x):
45 | if self.activate_first:
46 | x = self.relu0(x)
47 | x = self.depthwise(x)
48 | x = self.bn1(x)
49 | if not self.activate_first:
50 | x = self.relu1(x)
51 | x = self.pointwise(x)
52 | x = self.bn2(x)
53 | if not self.activate_first:
54 | x = self.relu2(x)
55 | return x
56 |
57 |
58 | class Block(nn.Module):
59 | def __init__(self,in_filters,out_filters,strides=1,atrous=None,grow_first=True,activate_first=True,inplace=True,norm_layer=nn.BatchNorm2d):
60 | super(Block, self).__init__()
61 | self.norm_layer = norm_layer
62 |         if atrous is None:
63 | atrous = [1]*3
64 | elif isinstance(atrous, int):
65 | atrous_list = [atrous]*3
66 | atrous = atrous_list
67 | idx = 0
68 | self.head_relu = True
69 | if out_filters != in_filters or strides!=1:
70 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
71 | self.skipbn = self.norm_layer(out_filters, momentum=bn_mom)
72 | self.head_relu = False
73 | else:
74 | self.skip=None
75 |
76 | self.hook_layer = None
77 | if grow_first:
78 | filters = out_filters
79 | else:
80 | filters = in_filters
81 | self.sepconv1 = SeparableConv2d(in_filters,filters,3,stride=1,padding=1*atrous[0],dilation=atrous[0],bias=False,activate_first=activate_first,inplace=self.head_relu,norm_layer=self.norm_layer)
82 | self.sepconv2 = SeparableConv2d(filters,out_filters,3,stride=1,padding=1*atrous[1],dilation=atrous[1],bias=False,activate_first=activate_first,norm_layer=self.norm_layer)
83 | self.sepconv3 = SeparableConv2d(out_filters,out_filters,3,stride=strides,padding=1*atrous[2],dilation=atrous[2],bias=False,activate_first=activate_first,inplace=inplace,norm_layer=self.norm_layer)
84 |
85 | def forward(self,inp):
86 |
87 | if self.skip is not None:
88 | skip = self.skip(inp)
89 | skip = self.skipbn(skip)
90 | else:
91 | skip = inp
92 |
93 | x = self.sepconv1(inp)
94 | x = self.sepconv2(x)
95 | self.hook_layer = x
96 | x = self.sepconv3(x)
97 |
98 | x+=skip
99 | return x
100 |
101 |
102 | class Xception(nn.Module):
103 | """
104 | Xception optimized for the ImageNet dataset, as specified in
105 | https://arxiv.org/pdf/1610.02357.pdf
106 | """
107 | def __init__(self, os, norm_layer=nn.BatchNorm2d):
108 | """ Constructor
109 | Args:
110 |             os: output stride of the backbone, either 8 or 16
111 | """
112 | super(Xception, self).__init__()
113 | self.norm_layer = norm_layer
114 |
115 | stride_list = None
116 | if os == 8:
117 | stride_list = [2,1,1]
118 | elif os == 16:
119 | stride_list = [2,2,1]
120 | else:
121 | raise ValueError('xception.py: output stride=%d is not supported.'%os)
122 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False)
123 | self.bn1 = self.norm_layer(32, momentum=bn_mom)
124 | self.relu = nn.ReLU(inplace=True)
125 |
126 | self.conv2 = nn.Conv2d(32,64,3,1,1,bias=False)
127 | self.bn2 = self.norm_layer(64, momentum=bn_mom)
128 | #do relu here
129 |
130 | self.block1=Block(64,128,2,norm_layer=self.norm_layer)
131 | self.block2=Block(128,256,stride_list[0],inplace=False,norm_layer=self.norm_layer)
132 | self.block3=Block(256,728,stride_list[1],norm_layer=self.norm_layer)
133 |
134 | rate = 16//os
135 | self.block4=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
136 | self.block5=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
137 | self.block6=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
138 | self.block7=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
139 |
140 | self.block8=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
141 | self.block9=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
142 | self.block10=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
143 | self.block11=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
144 |
145 | self.block12=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
146 | self.block13=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
147 | self.block14=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
148 | self.block15=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer)
149 |
150 | self.block16=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer)
151 | self.block17=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer)
152 | self.block18=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer)
153 | self.block19=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer)
154 |
155 | self.block20=Block(728,1024,stride_list[2],atrous=rate,grow_first=False,norm_layer=self.norm_layer)
156 | #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)
157 |
158 | self.conv3 = SeparableConv2d(1024,1536,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer)
159 | # self.bn3 = SynchronizedBatchNorm2d(1536, momentum=bn_mom)
160 |
161 | self.conv4 = SeparableConv2d(1536,1536,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer)
162 | # self.bn4 = SynchronizedBatchNorm2d(1536, momentum=bn_mom)
163 |
164 | #do relu here
165 | self.conv5 = SeparableConv2d(1536,2048,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer)
166 | # self.bn5 = SynchronizedBatchNorm2d(2048, momentum=bn_mom)
167 | self.OUTPUT_DIM = 2048
168 | self.MIDDLE_DIM = 256
169 |
170 | #------- init weights --------
171 | for m in self.modules():
172 | if isinstance(m, nn.Conv2d):
173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
174 | m.weight.data.normal_(0, math.sqrt(2. / n))
175 | elif isinstance(m, self.norm_layer):
176 | m.weight.data.fill_(1)
177 | m.bias.data.zero_()
178 | #-----------------------------
179 |
180 | def forward(self, input):
181 | layers = []
182 | x = self.conv1(input)
183 | x = self.bn1(x)
184 | x = self.relu(x)
185 | #self.layers.append(x)
186 | x = self.conv2(x)
187 | x = self.bn2(x)
188 | x = self.relu(x)
189 |
190 | x = self.block1(x)
191 | x = self.block2(x)
192 | l1 = self.block2.hook_layer
193 | x = self.block3(x)
194 | l2 = self.block3.hook_layer
195 | x = self.block4(x)
196 | x = self.block5(x)
197 | x = self.block6(x)
198 | x = self.block7(x)
199 | x = self.block8(x)
200 | x = self.block9(x)
201 | x = self.block10(x)
202 | x = self.block11(x)
203 | x = self.block12(x)
204 | x = self.block13(x)
205 | x = self.block14(x)
206 | x = self.block15(x)
207 | x = self.block16(x)
208 | x = self.block17(x)
209 | x = self.block18(x)
210 | x = self.block19(x)
211 | x = self.block20(x)
212 | l3 = self.block20.hook_layer
213 |
214 | x = self.conv3(x)
215 | # x = self.bn3(x)
216 | # x = self.relu(x)
217 |
218 | x = self.conv4(x)
219 | # x = self.bn4(x)
220 | # x = self.relu(x)
221 |
222 | l4 = self.conv5(x)
223 | # x = self.bn5(x)
224 | # x = self.relu(x)
225 |
226 | #return layers
227 | return [l1,l2,l3,l4]
228 |
229 | @BACKBONES.register_module
230 | def xception(pretrained=True, os=8, **kwargs):
231 | model = Xception(os=os)
232 | if pretrained:
233 | old_dict = torch.load(model_urls['xception'])
234 | # old_dict = model_zoo.load_url(model_urls['xception'])
235 | # for name, weights in old_dict.items():
236 | # if 'pointwise' in name:
237 | # old_dict[name] = weights.unsqueeze(-1).unsqueeze(-1)
238 | model_dict = model.state_dict()
239 | old_dict = {k: v for k,v in old_dict.items() if ('itr' not in k and 'tmp' not in k and 'track' not in k)}
240 | model_dict.update(old_dict)
241 |
242 | model.load_state_dict(model_dict)
243 |
244 | return model
245 |
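246 | # --- Editor's usage sketch (not part of the original file): build the model
247 | # without the local checkpoint and inspect the four hooked feature maps. With
248 | # os=8, l1 is taken at stride 4 and l2/l3/l4 at stride 8, with 256, 728, 1024
249 | # and 2048 channels respectively:
250 | if __name__ == '__main__':
251 |     net = Xception(os=8)
252 |     net.eval()
253 |     with torch.no_grad():
254 |         feats = net(torch.randn(1, 3, 299, 299))
255 |     print([f.shape for f in feats])  # 75x75 for l1, 38x38 for the rest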
--------------------------------------------------------------------------------
/lib/net/deeplabv1_wo_interp.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch.nn import init
10 | from net.backbone import build_backbone
11 | from utils.registry import NETS
12 |
13 | @NETS.register_module
14 | class deeplabv1_wo_interp(nn.Module):
15 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs):
16 | super(deeplabv1_wo_interp, self).__init__()
17 | self.cfg = cfg
18 | self.batchnorm = batchnorm
19 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, os=self.cfg.MODEL_OUTPUT_STRIDE)
20 | self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=cfg.MODEL_BACKBONE_PRETRAIN, norm_layer=self.batchnorm, **kwargs)
21 | self.conv_fov = nn.Conv2d(self.backbone.OUTPUT_DIM, 512, 3, 1, padding=12, dilation=12, bias=False)
22 | self.bn_fov = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True)
23 | self.conv_fov2 = nn.Conv2d(512, 512, 1, 1, padding=0, bias=False)
24 | self.bn_fov2 = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True)
25 | self.dropout1 = nn.Dropout(0.5)
26 | self.cls_conv = nn.Conv2d(512, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0)
27 | self.__initial__()
28 | self.not_training = []#[self.backbone.conv1a, self.backbone.b2, self.backbone.b2_1, self.backbone.b2_2]
29 | #self.from_scratch_layers = [self.cls_conv]
30 | self.from_scratch_layers = [self.conv_fov, self.conv_fov2, self.cls_conv]
31 |
32 | def __initial__(self):
33 | for m in self.modules():
34 | if m not in self.backbone.modules():
35 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d):
36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
37 | elif isinstance(m, self.batchnorm):
38 | nn.init.constant_(m.weight, 1)
39 | nn.init.constant_(m.bias, 0)
40 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=self.cfg.MODEL_BACKBONE_PRETRAIN)
41 |
42 | def forward(self, x):
43 | n,c,h,w = x.size()
44 | x_bottom = self.backbone(x)[-1]
45 | feature = self.conv_fov(x_bottom)
46 | feature = self.bn_fov(feature)
47 | feature = F.relu(feature, inplace=True)
48 | feature = self.conv_fov2(feature)
49 | feature = self.bn_fov2(feature)
50 | feature = F.relu(feature, inplace=True)
51 | feature = self.dropout1(feature)
52 | result = self.cls_conv(feature)
53 | # result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True)
54 | return result
55 |
56 | def get_parameter_groups(self):
57 | groups = ([], [], [], [])
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | if m.weight.requires_grad:
61 | if m in self.from_scratch_layers:
62 | groups[2].append(m.weight)
63 | else:
64 | groups[0].append(m.weight)
65 |
66 | if m.bias is not None and m.bias.requires_grad:
67 |
68 | if m in self.from_scratch_layers:
69 | groups[3].append(m.bias)
70 | else:
71 | groups[1].append(m.bias)
72 | return groups
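73 | 
74 | # --- Editor's usage sketch (not part of the original file): the constructor
75 | # reads only four config attributes, so a SimpleNamespace is enough to build
76 | # the head (assuming the named backbone is registered in BACKBONES and the
77 | # registry decorator returns the class unchanged, as is usual):
78 | if __name__ == '__main__':
79 |     from types import SimpleNamespace
80 |     cfg = SimpleNamespace(MODEL_BACKBONE='resnet38',
81 |                           MODEL_BACKBONE_PRETRAIN=False,
82 |                           TRAIN_BN_MOM=0.1,
83 |                           MODEL_NUM_CLASSES=21)
84 |     net = deeplabv1_wo_interp(cfg)
85 |     net.eval()
86 |     with torch.no_grad():
87 |         out = net(torch.randn(1, 3, 321, 321))
88 |     print(out.shape)  # logits at the backbone's output stride; no upsampling here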
--------------------------------------------------------------------------------
/lib/net/generateNet.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | #from net.deeplabv3plus import deeplabv3plus
6 | #from net.deeplabv3 import deeplabv3, deeplabv3_noise, deeplabv3_feature, deeplabv3_glore
7 | #from net.deeplabv2 import deeplabv2, deeplabv2_caffe
8 | #from net.deeplabv1 import deeplabv1, deeplabv1_caffe
9 | #from net.clsnet import ClsNet
10 | #from net.fcn import FCN
11 | #from net.DFANet import DFANet
12 | from utils.registry import NETS
13 |
14 | def generate_net(cfg, **kwargs):
15 | net = NETS.get(cfg.MODEL_NAME)(cfg, **kwargs)
16 | return net
17 | #def generate_net(cfg):
18 | # if cfg.MODEL_NAME == 'deeplabv3plus' or cfg.MODEL_NAME == 'deeplabv3+':
19 | # return deeplabv3plus(cfg)
20 | # elif cfg.MODEL_NAME == 'deeplabv3':
21 | # return deeplabv3(cfg)
22 | # elif cfg.MODEL_NAME == 'deeplabv2':
23 | # return deeplabv2(cfg)
24 | # elif cfg.MODEL_NAME == 'deeplabv1':
25 | # return deeplabv1(cfg)
26 | # elif cfg.MODEL_NAME == 'deeplabv1_caffe':
27 | # return deeplabv1_caffe(cfg)
28 | # elif cfg.MODEL_NAME == 'deeplabv2_caffe':
29 | # return deeplabv2_caffe(cfg)
30 | # elif cfg.MODEL_NAME == 'clsnet' or cfg.MODEL_NAME == 'ClsNet':
31 | # return ClsNet(cfg)
32 | # elif cfg.MODEL_NAME == 'fcn' or cfg.MODEL_NAME == 'FCN':
33 | # return FCN(cfg)
34 | # elif cfg.MODEL_NAME == 'DFANet' or cfg.MODEL_NAME == 'dfanet':
35 | # return DFANet(cfg)
36 | # else:
37 | # raise ValueError('generateNet.py: network %s is not support yet'%cfg.MODEL_NAME)
38 |
--------------------------------------------------------------------------------
/lib/net/operators/ASPP.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 |
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | import torch.nn.functional as F
9 | from net.sync_batchnorm import SynchronizedBatchNorm2d
10 |
11 | class ASPP(nn.Module):
12 |
13 | def __init__(self, dim_in, dim_out, rate=[1,6,12,18], bn_mom=0.1, has_global=True, batchnorm=SynchronizedBatchNorm2d):
14 | super(ASPP, self).__init__()
15 | self.dim_in = dim_in
16 | self.dim_out = dim_out
17 | self.has_global = has_global
18 | if rate[0] == 0:
19 | self.branch1 = nn.Sequential(
20 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=1,bias=False),
21 | batchnorm(dim_out, momentum=bn_mom, affine=True),
22 | nn.ReLU(inplace=True),
23 | )
24 | else:
25 | self.branch1 = nn.Sequential(
26 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[0], dilation=rate[0],bias=False),
27 | batchnorm(dim_out, momentum=bn_mom, affine=True),
28 | nn.ReLU(inplace=True),
29 | )
30 | self.branch2 = nn.Sequential(
31 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[1], dilation=rate[1],bias=False),
32 | batchnorm(dim_out, momentum=bn_mom, affine=True),
33 | nn.ReLU(inplace=True),
34 | )
35 | self.branch3 = nn.Sequential(
36 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[2], dilation=rate[2],bias=False),
37 | batchnorm(dim_out, momentum=bn_mom, affine=True),
38 | nn.ReLU(inplace=True),
39 | )
40 | self.branch4 = nn.Sequential(
41 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[3], dilation=rate[3],bias=False),
42 | batchnorm(dim_out, momentum=bn_mom, affine=True),
43 | nn.ReLU(inplace=True),
44 | )
45 | if self.has_global:
46 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=False)
47 | self.branch5_bn = batchnorm(dim_out, momentum=bn_mom, affine=True)
48 | self.branch5_relu = nn.ReLU(inplace=True)
49 | self.conv_cat = nn.Sequential(
50 | nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=False),
51 | batchnorm(dim_out, momentum=bn_mom, affine=True),
52 | nn.ReLU(inplace=True),
53 | nn.Dropout(0.5)
54 | )
55 | else:
56 | self.conv_cat = nn.Sequential(
57 | nn.Conv2d(dim_out*4, dim_out, 1, 1, padding=0),
58 | batchnorm(dim_out, momentum=bn_mom, affine=True),
59 | nn.ReLU(inplace=True),
60 | nn.Dropout(0.5)
61 | )
62 | def forward(self, x):
63 | result = None
64 | [b,c,row,col] = x.size()
65 | conv1x1 = self.branch1(x)
66 | conv3x3_1 = self.branch2(x)
67 | conv3x3_2 = self.branch3(x)
68 | conv3x3_3 = self.branch4(x)
69 | if self.has_global:
70 | global_feature = F.adaptive_avg_pool2d(x, (1,1))
71 | global_feature = self.branch5_conv(global_feature)
72 | global_feature = self.branch5_bn(global_feature)
73 | global_feature = self.branch5_relu(global_feature)
74 | global_feature = F.interpolate(global_feature, (row,col), None, 'bilinear', align_corners=True)
75 |
76 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
77 | else:
78 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3], dim=1)
79 | result = self.conv_cat(feature_cat)
80 |
81 | return result
82 |
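83 | # --- Editor's usage sketch (not part of the original file): four parallel
84 | # dilated branches (plus the optional image-level branch) are concatenated
85 | # and projected back to dim_out channels, so the spatial size is unchanged:
86 | if __name__ == '__main__':
87 |     aspp = ASPP(dim_in=2048, dim_out=256, rate=[1, 6, 12, 18],
88 |                 batchnorm=nn.BatchNorm2d)
89 |     aspp.eval()
90 |     with torch.no_grad():
91 |         y = aspp(torch.randn(2, 2048, 33, 33))
92 |     print(y.shape)  # torch.Size([2, 256, 33, 33])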
--------------------------------------------------------------------------------
/lib/net/operators/PPM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class PPM(nn.Module):
6 | """
7 | Reference:
8 | Zhao, Hengshuang, et al. *"Pyramid scene parsing network."*
9 | """
10 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6), norm_layer=nn.BatchNorm2d):
11 | super(PPM, self).__init__()
12 |
13 |         # build one pooling stage per pyramid size
14 |         self.stages = nn.ModuleList([self._make_stage(features, out_features, size, norm_layer) for size in sizes])
15 | self.bottleneck = nn.Sequential(
16 | nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=1, padding=0, dilation=1, bias=False),
17 | norm_layer(out_features),
18 | nn.ReLU(),
19 | nn.Dropout2d(0.1)
20 | )
21 |
22 | def _make_stage(self, features, out_features, size, norm_layer):
23 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size))
24 | conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False)
25 | bn = norm_layer(out_features)
26 | return nn.Sequential(prior, conv, bn)
27 |
28 | def forward(self, feats):
29 | h, w = feats.size(2), feats.size(3)
30 |         priors = [F.interpolate(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats]
31 | bottle = self.bottleneck(torch.cat(priors, 1))
32 | return bottle
33 |
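34 | # --- Editor's usage sketch (not part of the original file): each stage pools
35 | # the input to SxS, projects it to out_features channels, and upsamples back,
36 | # so the bottleneck sees features + len(sizes)*out_features input channels:
37 | if __name__ == '__main__':
38 |     ppm = PPM(features=2048, out_features=512, sizes=(1, 2, 3, 6))
39 |     ppm.eval()
40 |     with torch.no_grad():
41 |         y = ppm(torch.randn(2, 2048, 60, 60))
42 |     print(y.shape)  # torch.Size([2, 512, 60, 60])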
--------------------------------------------------------------------------------
/lib/net/operators/__init__.py:
--------------------------------------------------------------------------------
1 | from .ASPP import *
2 | from .PPM import *
3 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel triggers a callback on each module, every slave device should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 | 
82 |         Args:
83 |             identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 | 
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
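133 | # --- Editor's usage sketch (not part of the original file): a single-slave
134 | # round trip. The callback receives [(id, msg), ...] with the master's own
135 | # message first, and must return one (id, reply) pair per device, master
136 | # first. Here the shared "statistic" is simply the sum of all messages:
137 | if __name__ == '__main__':
138 |     master = SyncMaster(lambda inter: [(i, sum(m for _, m in inter)) for i, _ in inter])
139 |     pipe = master.register_slave(identifier=1)
140 |     worker = threading.Thread(target=lambda: print('slave reply:', pipe.run_slave(2)))
141 |     worker.start()
142 |     print('master reply:', master.run_master(1))  # both replies are 3
143 |     worker.join()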
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 | 
31 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 | 
33 |     Note that, as all module replicas are isomorphic, we assign each sub-module a context
34 |     (shared among the multiple copies of that module on different devices).
35 |     Through this context, different copies can share information.
36 | 
37 |     We guarantee that the callback on the master copy (the first copy) is called before the callback
38 |     of any slave copy.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` of each module is invoked after it is created by the
55 |     original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 |     Monkey-patch an existing `DataParallel` object to add the replication callback.
73 |     Useful when you have a customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/sync_batchnorm/batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import collections
12 |
13 | import torch
14 | import torch.nn.functional as F
15 |
16 | from torch.nn.modules.batchnorm import _BatchNorm
17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
18 |
19 | from .comm import SyncMaster
20 |
21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
22 |
23 |
24 | def _sum_ft(tensor):
25 |     """sum over the first and last dimension"""
26 | return tensor.sum(dim=0).sum(dim=-1)
27 |
28 |
29 | def _unsqueeze_ft(tensor):
30 |     """add new dimensions at the front and the tail"""
31 | return tensor.unsqueeze(0).unsqueeze(-1)
32 |
33 |
34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
36 |
37 |
38 | class _SynchronizedBatchNorm(_BatchNorm):
39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
41 |
42 | self._sync_master = SyncMaster(self._data_parallel_master)
43 |
44 | self._is_parallel = False
45 | self._parallel_id = None
46 | self._slave_pipe = None
47 |
48 | def forward(self, input):
49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
50 | if not (self._is_parallel and self.training):
51 | return F.batch_norm(
52 | input, self.running_mean, self.running_var, self.weight, self.bias,
53 | self.training, self.momentum, self.eps)
54 |
55 | # Resize the input to (B, C, -1).
56 | input_shape = input.size()
57 | input = input.view(input.size(0), self.num_features, -1)
58 |
59 | # Compute the sum and square-sum.
60 | sum_size = input.size(0) * input.size(2)
61 | input_sum = _sum_ft(input)
62 | input_ssum = _sum_ft(input ** 2)
63 |
64 | # Reduce-and-broadcast the statistics.
65 | if self._parallel_id == 0:
66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
67 | else:
68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
69 |
70 | # Compute the output.
71 | if self.affine:
72 | # MJY:: Fuse the multiplication for speed.
73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
74 | else:
75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
76 |
77 | # Reshape it.
78 | return output.view(input_shape)
79 |
80 | def __data_parallel_replicate__(self, ctx, copy_id):
81 | self._is_parallel = True
82 | self._parallel_id = copy_id
83 |
84 | # parallel_id == 0 means master device.
85 | if self._parallel_id == 0:
86 | ctx.sync_master = self._sync_master
87 | else:
88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id)
89 |
90 | def _data_parallel_master(self, intermediates):
91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
92 |
93 | # Always using same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to:: Tete Xiao (http://tetexiao.com/)
95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 |
97 | to_reduce = [i[1][:2] for i in intermediates]
98 | to_reduce = [j for i in to_reduce for j in i] # flatten
99 | target_gpus = [i[1].sum.get_device() for i in intermediates]
100 |
101 | sum_size = sum([i[1].sum_size for i in intermediates])
102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 |
105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 |
107 | outputs = []
108 | for i, rec in enumerate(intermediates):
109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
110 |
111 | return outputs
112 |
113 | def _compute_mean_std(self, sum_, ssum, size):
114 | """Compute the mean and standard-deviation with sum and square-sum. This method
115 | also maintains the moving average on the master device."""
116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 | mean = sum_ / size
118 | sumvar = ssum - sum_ * mean
119 | unbias_var = sumvar / (size - 1)
120 | bias_var = sumvar / size
121 |
122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 |
125 | return mean, bias_var.clamp(self.eps) ** -0.5
126 |
127 |
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 | mini-batch.
131 |
132 | .. math::
133 |
134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
135 |
136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and
137 | standard-deviation are reduced across all devices during training.
138 |
139 | For example, when one uses `nn.DataParallel` to wrap the network during
140 |     training, PyTorch's implementation normalizes the tensor on each device using
141 |     only the statistics on that device, which speeds up the computation and
142 |     is easy to implement, but the statistics may be inaccurate.
143 |     Instead, in this synchronized version, the statistics are computed
144 |     over all training samples distributed across the devices.
145 | 
146 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
147 |     as the built-in PyTorch implementation.
148 |
149 | The mean and standard-deviation are calculated per-dimension over
150 | the mini-batches and gamma and beta are learnable parameter vectors
151 | of size C (where C is the input size).
152 |
153 | During training, this layer keeps a running estimate of its computed mean
154 | and variance. The running sum is kept with a default momentum of 0.1.
155 |
156 | During evaluation, this running mean/variance is used for normalization.
157 |
158 | Because the BatchNorm is done over the `C` dimension, computing statistics
159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
160 |
161 | Args:
162 | num_features: num_features from an expected input of size
163 | `batch_size x num_features [x width]`
164 | eps: a value added to the denominator for numerical stability.
165 | Default: 1e-5
166 | momentum: the value used for the running_mean and running_var
167 | computation. Default: 0.1
168 | affine: a boolean value that when set to ``True``, gives the layer learnable
169 | affine parameters. Default: ``True``
170 |
171 | Shape:
172 | - Input: :math:`(N, C)` or :math:`(N, C, L)`
173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
174 |
175 | Examples:
176 | >>> # With Learnable Parameters
177 | >>> m = SynchronizedBatchNorm1d(100)
178 | >>> # Without Learnable Parameters
179 | >>> m = SynchronizedBatchNorm1d(100, affine=False)
180 |     >>> input = torch.randn(20, 100)
181 | >>> output = m(input)
182 | """
183 |
184 | def _check_input_dim(self, input):
185 | if input.dim() != 2 and input.dim() != 3:
186 | raise ValueError('expected 2D or 3D input (got {}D input)'
187 | .format(input.dim()))
188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
189 |
190 |
191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
193 | of 3d inputs
194 |
195 | .. math::
196 |
197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
198 |
199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and
200 | standard-deviation are reduced across all devices during training.
201 |
202 | For example, when one uses `nn.DataParallel` to wrap the network during
203 |     training, PyTorch's implementation normalizes the tensor on each device using
204 |     only the statistics on that device, which speeds up the computation and
205 |     is easy to implement, but the statistics may be inaccurate.
206 |     Instead, in this synchronized version, the statistics are computed
207 |     over all training samples distributed across the devices.
208 | 
209 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
210 |     as the built-in PyTorch implementation.
211 |
212 | The mean and standard-deviation are calculated per-dimension over
213 | the mini-batches and gamma and beta are learnable parameter vectors
214 | of size C (where C is the input size).
215 |
216 | During training, this layer keeps a running estimate of its computed mean
217 | and variance. The running sum is kept with a default momentum of 0.1.
218 |
219 | During evaluation, this running mean/variance is used for normalization.
220 |
221 | Because the BatchNorm is done over the `C` dimension, computing statistics
222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
223 |
224 | Args:
225 | num_features: num_features from an expected input of
226 | size batch_size x num_features x height x width
227 | eps: a value added to the denominator for numerical stability.
228 | Default: 1e-5
229 | momentum: the value used for the running_mean and running_var
230 | computation. Default: 0.1
231 | affine: a boolean value that when set to ``True``, gives the layer learnable
232 | affine parameters. Default: ``True``
233 |
234 | Shape:
235 | - Input: :math:`(N, C, H, W)`
236 | - Output: :math:`(N, C, H, W)` (same shape as input)
237 |
238 | Examples:
239 | >>> # With Learnable Parameters
240 | >>> m = SynchronizedBatchNorm2d(100)
241 | >>> # Without Learnable Parameters
242 | >>> m = SynchronizedBatchNorm2d(100, affine=False)
243 |     >>> input = torch.randn(20, 100, 35, 45)
244 | >>> output = m(input)
245 | """
246 |
247 | def _check_input_dim(self, input):
248 | if input.dim() != 4:
249 | raise ValueError('expected 4D input (got {}D input)'
250 | .format(input.dim()))
251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
252 |
253 |
254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
256 | of 4d inputs
257 |
258 | .. math::
259 |
260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
261 |
262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and
263 | standard-deviation are reduced across all devices during training.
264 |
265 | For example, when one uses `nn.DataParallel` to wrap the network during
266 |     training, PyTorch's implementation normalizes the tensor on each device using
267 |     only the statistics on that device, which speeds up the computation and
268 |     is easy to implement, but the statistics may be inaccurate.
269 |     Instead, in this synchronized version, the statistics are computed
270 |     over all training samples distributed across the devices.
271 | 
272 |     Note that, in the one-GPU or CPU-only case, this module behaves exactly the same
273 |     as the built-in PyTorch implementation.
274 |
275 | The mean and standard-deviation are calculated per-dimension over
276 | the mini-batches and gamma and beta are learnable parameter vectors
277 | of size C (where C is the input size).
278 |
279 | During training, this layer keeps a running estimate of its computed mean
280 | and variance. The running sum is kept with a default momentum of 0.1.
281 |
282 | During evaluation, this running mean/variance is used for normalization.
283 |
284 | Because the BatchNorm is done over the `C` dimension, computing statistics
285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
286 | or Spatio-temporal BatchNorm
287 |
288 | Args:
289 | num_features: num_features from an expected input of
290 | size batch_size x num_features x depth x height x width
291 | eps: a value added to the denominator for numerical stability.
292 | Default: 1e-5
293 | momentum: the value used for the running_mean and running_var
294 | computation. Default: 0.1
295 | affine: a boolean value that when set to ``True``, gives the layer learnable
296 | affine parameters. Default: ``True``
297 |
298 | Shape:
299 | - Input: :math:`(N, C, D, H, W)`
300 | - Output: :math:`(N, C, D, H, W)` (same shape as input)
301 |
302 | Examples:
303 | >>> # With Learnable Parameters
304 | >>> m = SynchronizedBatchNorm3d(100)
305 | >>> # Without Learnable Parameters
306 | >>> m = SynchronizedBatchNorm3d(100, affine=False)
307 |     >>> input = torch.randn(20, 100, 35, 45, 10)
308 | >>> output = m(input)
309 | """
310 |
311 | def _check_input_dim(self, input):
312 | if input.dim() != 5:
313 | raise ValueError('expected 5D input (got {}D input)'
314 | .format(input.dim()))
315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
316 |
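317 | # --- Editor's sanity-check sketch (not part of the original file): outside
318 | # DataParallel, forward() falls through to F.batch_norm, so on a single
319 | # device the layer should match torch.nn.BatchNorm2d exactly (run this as a
320 | # module, e.g. `python -m sync_batchnorm.batchnorm`, so `.comm` resolves):
321 | if __name__ == '__main__':
322 |     sync_bn = SynchronizedBatchNorm2d(8)
323 |     ref_bn = torch.nn.BatchNorm2d(8)
324 |     ref_bn.load_state_dict(sync_bn.state_dict())
325 |     x = torch.randn(4, 8, 5, 5)
326 |     assert torch.allclose(sync_bn(x), ref_bn(x), atol=1e-6)
327 |     print('single-device output matches nn.BatchNorm2d')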
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/sync_batchnorm/batchnorm_reimpl.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # File : batchnorm_reimpl.py
4 | # Author : acgtyrant
5 | # Date : 11/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.init as init
14 |
15 | __all__ = ['BatchNorm2dReimpl']
16 |
17 |
18 | class BatchNorm2dReimpl(nn.Module):
19 | """
20 | A re-implementation of batch normalization, used for testing the numerical
21 | stability.
22 |
23 | Author: acgtyrant
24 | See also:
25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14
26 | """
27 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
28 | super().__init__()
29 |
30 | self.num_features = num_features
31 | self.eps = eps
32 | self.momentum = momentum
33 | self.weight = nn.Parameter(torch.empty(num_features))
34 | self.bias = nn.Parameter(torch.empty(num_features))
35 | self.register_buffer('running_mean', torch.zeros(num_features))
36 | self.register_buffer('running_var', torch.ones(num_features))
37 | self.reset_parameters()
38 |
39 | def reset_running_stats(self):
40 | self.running_mean.zero_()
41 | self.running_var.fill_(1)
42 |
43 | def reset_parameters(self):
44 | self.reset_running_stats()
45 | init.uniform_(self.weight)
46 | init.zeros_(self.bias)
47 |
48 | def forward(self, input_):
49 | batchsize, channels, height, width = input_.size()
50 | numel = batchsize * height * width
51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel)
52 | sum_ = input_.sum(1)
53 | sum_of_square = input_.pow(2).sum(1)
54 | mean = sum_ / numel
55 | sumvar = sum_of_square - sum_ * mean
56 |
57 | self.running_mean = (
58 | (1 - self.momentum) * self.running_mean
59 | + self.momentum * mean.detach()
60 | )
61 | unbias_var = sumvar / (numel - 1)
62 | self.running_var = (
63 | (1 - self.momentum) * self.running_var
64 | + self.momentum * unbias_var.detach()
65 | )
66 |
67 | bias_var = sumvar / numel
68 | inv_std = 1 / (bias_var + self.eps).pow(0.5)
69 | output = (
70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) *
71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1))
72 |
73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous()
74 |
75 |
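76 | # --- Editor's sanity-check sketch (not part of the original file): after
77 | # syncing the affine parameters, the reimplementation should agree with
78 | # nn.BatchNorm2d in training mode:
79 | if __name__ == '__main__':
80 |     reimpl = BatchNorm2dReimpl(8)
81 |     ref = nn.BatchNorm2d(8)
82 |     ref.weight.data.copy_(reimpl.weight.data)
83 |     ref.bias.data.copy_(reimpl.bias.data)
84 |     x = torch.randn(4, 8, 5, 5)
85 |     assert torch.allclose(reimpl(x), ref(x), atol=1e-5)
86 |     print('matches nn.BatchNorm2d in training mode')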
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During replication, as data parallel triggers a callback on each module, every slave device should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are collected
62 |     and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be passed
64 |     back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def __getstate__(self):
79 | return {'master_callback': self._master_callback}
80 |
81 | def __setstate__(self, state):
82 | self.__init__(state['master_callback'])
83 |
84 | def register_slave(self, identifier):
85 | """
86 |         Register a slave device.
87 | 
88 |         Args:
89 |             identifier: an identifier, usually the device id.
90 |
91 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
92 |
93 | """
94 | if self._activated:
95 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
96 | self._activated = False
97 | self._registry.clear()
98 | future = FutureResult()
99 | self._registry[identifier] = _MasterRegistry(future)
100 | return SlavePipe(identifier, self._queue, future)
101 |
102 | def run_master(self, master_msg):
103 | """
104 | Main entry for the master device in each forward pass.
105 |         The messages are first collected from each device (including the master device), and then
106 |         a callback is invoked to compute the message to be sent back to each device
107 |         (including the master device).
108 | 
109 |         Args:
110 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
112 |
113 | Returns: the message to be sent back to the master device.
114 |
115 | """
116 | self._activated = True
117 |
118 | intermediates = [(0, master_msg)]
119 | for i in range(self.nr_slaves):
120 | intermediates.append(self._queue.get())
121 |
122 | results = self._master_callback(intermediates)
123 |         assert results[0][0] == 0, 'The first result should belong to the master.'
124 |
125 | for i, res in results:
126 | if i == 0:
127 | continue
128 | self._registry[i].result.put(res)
129 |
130 | for i in range(self.nr_slaves):
131 | assert self._queue.get() is True
132 |
133 | return results[0][1]
134 |
135 | @property
136 | def nr_slaves(self):
137 | return len(self._registry)
138 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | from torch.nn.parallel.data_parallel import DataParallel
14 |
15 | __all__ = [
16 | 'CallbackContext',
17 | 'execute_replication_callbacks',
18 | 'DataParallelWithCallback',
19 | 'patch_replication_callback'
20 | ]
21 |
22 |
23 | class CallbackContext(object):
24 | pass
25 |
26 |
27 | def execute_replication_callbacks(modules):
28 | """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |
31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
32 |
33 |     Note that, as all replicas are isomorphic, we assign each sub-module a context
34 |     (shared among the copies of that module on different devices).
35 | Through this context, different copies can share some information.
36 |
37 |     We guarantee that the callback on the master copy (the first copy) is invoked before the callbacks
38 |     of any slave copies.
39 | """
40 | master_copy = modules[0]
41 | nr_modules = len(list(master_copy.modules()))
42 | ctxs = [CallbackContext() for _ in range(nr_modules)]
43 |
44 | for i, module in enumerate(modules):
45 | for j, m in enumerate(module.modules()):
46 | if hasattr(m, '__data_parallel_replicate__'):
47 | m.__data_parallel_replicate__(ctxs[j], i)
48 |
49 |
50 | class DataParallelWithCallback(DataParallel):
51 | """
52 | Data Parallel with a replication callback.
53 |
54 |     A replication callback `__data_parallel_replicate__` on each module is invoked after the module is created
55 |     by the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
57 |
58 | Examples:
59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """
63 |
64 | def replicate(self, module, device_ids):
65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
66 | execute_replication_callbacks(modules)
67 | return modules
68 |
69 |
70 | def patch_replication_callback(data_parallel):
71 | """
72 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
73 | Useful when you have customized `DataParallel` implementation.
74 |
75 | Examples:
76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
78 | > patch_replication_callback(sync_bn)
79 | # this is equivalent to
80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
82 | """
83 |
84 | assert isinstance(data_parallel, DataParallel)
85 |
86 | old_replicate = data_parallel.replicate
87 |
88 | @functools.wraps(old_replicate)
89 | def new_replicate(module, device_ids):
90 | modules = old_replicate(module, device_ids)
91 | execute_replication_callbacks(modules)
92 | return modules
93 |
94 | data_parallel.replicate = new_replicate
95 |
--------------------------------------------------------------------------------
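
Note: for a module to receive this callback, it only needs to define `__data_parallel_replicate__`. A minimal sketch follows; the class name and the attribute it records are illustrative only.

import torch.nn as nn

class TracksCopies(nn.Module):
    def __data_parallel_replicate__(self, ctx, copy_id):
        # ctx is shared by all replicas of this sub-module; the master
        # replica (copy_id == 0) is guaranteed to run first.
        if copy_id == 0:
            ctx.master_copy = self
        self.copy_id = copy_id
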
/lib/net/sync_batchnorm/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 | import torch
13 |
14 |
15 | class TorchTestCase(unittest.TestCase):
16 | def assertTensorClose(self, x, y):
17 | adiff = float((x - y).abs().max())
18 | if (y == 0).all():
19 | rdiff = 'NaN'
20 | else:
21 | rdiff = float((adiff / y).abs().max())
22 |
23 | message = (
24 | 'Tensor close check failed\n'
25 | 'adiff={}\n'
26 | 'rdiff={}\n'
27 | ).format(adiff, rdiff)
28 | self.assertTrue(torch.allclose(x, y), message)
29 |
30 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/tests/test_numeric_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_numeric_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm.unittest import TorchTestCase
16 |
17 |
18 | def handy_var(a, unbias=True):
19 | n = a.size(0)
20 | asum = a.sum(dim=0)
21 | as_sum = (a ** 2).sum(dim=0) # a square sum
22 | sumvar = as_sum - asum * asum / n
23 | if unbias:
24 | return sumvar / (n - 1)
25 | else:
26 | return sumvar / n
27 |
28 |
29 | class NumericTestCase(TorchTestCase):
30 | def testNumericBatchNorm(self):
31 | a = torch.rand(16, 10)
32 |         bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False)  # 1d: the input below is (N, C)
33 | bn.train()
34 |
35 | a_var1 = Variable(a, requires_grad=True)
36 | b_var1 = bn(a_var1)
37 | loss1 = b_var1.sum()
38 | loss1.backward()
39 |
40 | a_var2 = Variable(a, requires_grad=True)
41 | a_mean2 = a_var2.mean(dim=0, keepdim=True)
42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
44 | b_var2 = (a_var2 - a_mean2) / a_std2
45 | loss2 = b_var2.sum()
46 | loss2.backward()
47 |
48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0))
49 | self.assertTensorClose(bn.running_var, handy_var(a))
50 | self.assertTensorClose(a_var1.data, a_var2.data)
51 | self.assertTensorClose(b_var1.data, b_var2.data)
52 | self.assertTensorClose(a_var1.grad, a_var2.grad)
53 |
54 |
55 | if __name__ == '__main__':
56 | unittest.main()
57 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/tests/test_sync_batchnorm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : test_sync_batchnorm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 |
9 | import unittest
10 |
11 | import torch
12 | import torch.nn as nn
13 | from torch.autograd import Variable
14 |
15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
16 | from sync_batchnorm.unittest import TorchTestCase
17 |
18 |
19 | def handy_var(a, unbias=True):
20 | n = a.size(0)
21 | asum = a.sum(dim=0)
22 | as_sum = (a ** 2).sum(dim=0) # a square sum
23 | sumvar = as_sum - asum * asum / n
24 | if unbias:
25 | return sumvar / (n - 1)
26 | else:
27 | return sumvar / n
28 |
29 |
30 | def _find_bn(module):
31 | for m in module.modules():
32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
33 | return m
34 |
35 |
36 | class SyncTestCase(TorchTestCase):
37 | def _syncParameters(self, bn1, bn2):
38 | bn1.reset_parameters()
39 | bn2.reset_parameters()
40 | if bn1.affine and bn2.affine:
41 | bn2.weight.data.copy_(bn1.weight.data)
42 | bn2.bias.data.copy_(bn1.bias.data)
43 |
44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
45 | """Check the forward and backward for the customized batch normalization."""
46 | bn1.train(mode=is_train)
47 | bn2.train(mode=is_train)
48 |
49 | if cuda:
50 | input = input.cuda()
51 |
52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2))
53 |
54 | input1 = Variable(input, requires_grad=True)
55 | output1 = bn1(input1)
56 | output1.sum().backward()
57 | input2 = Variable(input, requires_grad=True)
58 | output2 = bn2(input2)
59 | output2.sum().backward()
60 |
61 | self.assertTensorClose(input1.data, input2.data)
62 | self.assertTensorClose(output1.data, output2.data)
63 | self.assertTensorClose(input1.grad, input2.grad)
64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)
66 |
67 | def testSyncBatchNormNormalTrain(self):
68 | bn = nn.BatchNorm1d(10)
69 | sync_bn = SynchronizedBatchNorm1d(10)
70 |
71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)
72 |
73 | def testSyncBatchNormNormalEval(self):
74 | bn = nn.BatchNorm1d(10)
75 | sync_bn = SynchronizedBatchNorm1d(10)
76 |
77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)
78 |
79 | def testSyncBatchNormSyncTrain(self):
80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
83 |
84 | bn.cuda()
85 | sync_bn.cuda()
86 |
87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)
88 |
89 | def testSyncBatchNormSyncEval(self):
90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
93 |
94 | bn.cuda()
95 | sync_bn.cuda()
96 |
97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)
98 |
99 | def testSyncBatchNorm2DSyncTrain(self):
100 | bn = nn.BatchNorm2d(10)
101 | sync_bn = SynchronizedBatchNorm2d(10)
102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
103 |
104 | bn.cuda()
105 | sync_bn.cuda()
106 |
107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)
108 |
109 |
110 | if __name__ == '__main__':
111 | unittest.main()
112 |
--------------------------------------------------------------------------------
/lib/net/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 |             np.allclose(npa, npb, atol=atol, rtol=rtol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
/lib/utils/DenseCRF.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pydensecrf.densecrf as dcrf
3 | from pydensecrf.utils import unary_from_softmax
4 |
5 | def dense_crf(probs, img=None, n_classes=21, n_iters=1, scale_factor=1):
6 | #probs = np.transpose(probs,(1,2,0)).copy(order='C')
7 | c,h,w = probs.shape
8 |
9 | if img is not None:
10 | assert(img.shape[1:3] == (h, w))
11 | img = np.transpose(img,(1,2,0)).copy(order='C')
12 |
13 | #probs = probs.transpose(2, 0, 1).copy(order='C') # Need a contiguous array.
14 |
15 | d = dcrf.DenseCRF2D(w, h, n_classes) # Define DenseCRF model.
16 |
17 | unary = unary_from_softmax(probs)
18 | unary = np.ascontiguousarray(unary)
19 | d.setUnaryEnergy(unary)
20 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3)
21 | #d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
22 | d.addPairwiseBilateral(sxy=32/scale_factor, srgb=13, rgbim=np.copy(img), compat=10)
23 | Q = d.inference(n_iters)
24 |
25 | # U = -np.log(probs) # Unary potential.
26 | # U = U.reshape((n_classes, -1)) # Needs to be flat.
27 | # d.setUnaryEnergy(U)
28 | # d.addPairwiseGaussian(sxy=sxy_gaussian, compat=compat_gaussian,
29 | # kernel=kernel_gaussian, normalization=normalisation_gaussian)
30 | # if img is not None:
31 | # assert(img.shape[1:3] == (h, w))
32 | # img = np.transpose(img,(1,2,0)).copy(order='C')
33 | # d.addPairwiseBilateral(sxy=sxy_bilateral, compat=compat_bilateral,
34 | # kernel=kernel_bilateral, normalization=normalisation_bilateral,
35 | # srgb=srgb_bilateral, rgbim=img)
36 | # Q = d.inference(n_iters)
37 | preds = np.array(Q, dtype=np.float32).reshape((n_classes, h, w))
38 | #return np.expand_dims(preds, 0)
39 | return preds
40 |
41 | def pro_crf(p, img, itr):
42 |     C, H, W = p.shape
43 |     p_bg = 1 - p
44 |     for i in range(C):
45 |         cat = np.stack([p[i, :, :], p_bg[i, :, :]], axis=0)  # (2, H, W): [prob_i, 1 - prob_i]
46 |         crf_pro = dense_crf(cat, img.astype(np.uint8), n_classes=2, n_iters=itr)
47 |         p[i, :, :] = crf_pro[0]
48 |     return p
49 |
--------------------------------------------------------------------------------
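
Note: a quick smoke test of `dense_crf` under the shape conventions used by its callers in this repo (probabilities are c,h,w after softmax; the image is c,h,w uint8). The random inputs and the import path are illustrative only.

import numpy as np

from DenseCRF import dense_crf  # adjust to your package layout

probs = np.random.rand(21, 64, 64).astype(np.float32)
probs /= probs.sum(axis=0, keepdims=True)   # valid per-pixel distribution over classes
img = np.random.randint(0, 256, (3, 64, 64), dtype=np.uint8)

refined = dense_crf(probs, img, n_classes=21, n_iters=1)
print(refined.shape)  # (21, 64, 64)
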
/lib/utils/JSD_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 |
8 | def calc_jsd_single(weight, labels1_a, pred, threshold=0.8, Mask_label255_sign='no'):
9 |
10 |     Mask_label255 = (labels1_a < 255).float()  # exclude the irrelevant (data-aug) area
11 | weight_softmax = F.softmax(weight, dim=0)
12 |
13 | criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
14 |
15 | loss = criterion(pred * weight_softmax[0], labels1_a) # * weight_softmax[0]
16 |
17 |
18 |
19 | prob = F.softmax(pred, dim=1)
20 | prob = torch.clamp(prob, 1e-7, 1)
21 |
22 | max_probs = torch.amax(prob*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True) # select according to the pred without the aug area
23 | mask = max_probs.ge(threshold).float()
24 |
25 |
26 | logp = prob.log()
27 |
28 | log_probs = torch.sum(F.kl_div(logp, prob, reduction='none') * mask, dim=1)
29 | if Mask_label255_sign == 'yes':
30 | consistency = sum(log_probs)*Mask_label255
31 | else:
32 | consistency = sum(log_probs)
33 |
34 | return torch.mean(loss), torch.mean(consistency), consistency, prob
35 |
36 |
37 | def calc_jsd_multiscale(weight, labels1_a, pred1, pred2, pred3, threshold=0.8, Mask_label255_sign='no'):
38 |
39 |     Mask_label255 = (labels1_a < 255).float()  # exclude the irrelevant (data-aug) area; b,h,w
40 | weight_softmax = F.softmax(weight, dim=0)
41 |
42 | criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
43 | criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
44 | criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
45 |
46 | loss1 = criterion1(pred1 * weight_softmax[0], labels1_a) # * weight_softmax[0]
47 | loss2 = criterion2(pred2 * weight_softmax[1], labels1_a) # * weight_softmax[1]
48 | loss3 = criterion3(pred3 * weight_softmax[2], labels1_a) # * weight_softmax[2]
49 |
50 | loss = (loss1 + loss2 + loss3)
51 |
52 |     probs = [F.softmax(logits, dim=1) for logits in (pred1, pred2, pred3)]
53 |
54 | weighted_probs = [weight_softmax[i] * prob for i, prob in enumerate(probs)] # weight_softmax[i]*
55 | mixture_label = (torch.stack(weighted_probs)).sum(axis=0)
56 | #mixture_label = torch.clamp(mixture_label, 1e-7, 1) # h,c,h,w
57 |     mixture_label = torch.clamp(mixture_label, 1e-3, 1-1e-3)  # b,c,h,w
58 |
59 |     # fallback for early torch versions where torch.amax is not available
60 |     if torch.__version__ == "1.5.0" or torch.__version__ == "1.6.0":
61 |         max_probs, _ = torch.max(mixture_label*Mask_label255.unsqueeze(1), dim=-3, keepdim=True)  # keep the values, not the indices
62 |         max_probs, _ = torch.max(max_probs, dim=-2, keepdim=True)
63 |         max_probs, _ = torch.max(max_probs, dim=-1, keepdim=True)
64 | else:
65 | max_probs = torch.amax(mixture_label*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True)
66 | mask = max_probs.ge(threshold).float()
67 |
68 |
69 | logp_mixture = mixture_label.log()
70 |
71 | log_probs = [torch.sum(F.kl_div(logp_mixture, prob, reduction='none') * mask, dim=1) for prob in probs]
72 | if Mask_label255_sign == 'yes':
73 | consistency = sum(log_probs)*Mask_label255
74 | else:
75 | consistency = sum(log_probs)
76 |
77 | return torch.mean(loss), torch.mean(consistency), consistency, mixture_label
78 |
79 |
80 |
81 | def calc_multiscale_backup(weight, seg_backup, pred1, pred2, pred3, mask_seg_prednan, seg_prediction, Lambda_back=0):
82 | b,_,h,w = pred1.size()
83 | seg_tempt = torch.zeros((b, 21, h, w), dtype=torch.float)
84 |
85 | seg_tempt[~mask_seg_prednan[:, :, :, :]] = F.softmax(seg_prediction, dim=1)[
86 | ~mask_seg_prednan[:, :, :, :]] # b,h,w
87 |         # need to mask out the irrelevant NaN region
88 |
89 | weight_softmax = F.softmax(weight, dim=0)
90 |
91 | criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
92 | criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
93 | criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
94 |
95 |     # let the loss depend on the confidence of the previous prediction for the background class seg_tempt[:,0,:,:]
96 | if Lambda_back == 0:
97 | loss1 = torch.mean(criterion1(pred1 * weight_softmax[0], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0))
98 | loss2 = torch.mean(criterion2(pred2 * weight_softmax[1], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0))
99 | loss3 = torch.mean(criterion3(pred3 * weight_softmax[2], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0))
100 | else:
101 | loss1 = torch.mean(criterion1(pred1 * weight_softmax[0], seg_backup.to(0))) * Lambda_back
102 | loss2 = torch.mean(criterion2(pred2 * weight_softmax[1], seg_backup.to(0))) * Lambda_back
103 | loss3 = torch.mean(criterion3(pred3 * weight_softmax[2], seg_backup.to(0))) * Lambda_back
104 |
105 | loss = (loss1 + loss2 + loss3)
106 |
107 |
108 |
109 |
110 |
111 |
112 | return loss
--------------------------------------------------------------------------------
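
Note: a shape-level smoke test of `calc_jsd_multiscale` under its assumed conventions (b,c,h,w logits at three scales, b,h,w labels with 255 = ignore, one learnable logit per scale). The random tensors and the import path are illustrative only.

import torch

from JSD_loss import calc_jsd_multiscale  # adjust to your package layout

b, c, h, w = 2, 21, 32, 32
weight = torch.zeros(3, requires_grad=True)      # one logit per scale
labels = torch.randint(0, c, (b, h, w))          # no 255s in this toy case
preds = [torch.randn(b, c, h, w, requires_grad=True) for _ in range(3)]

loss, consistency, _, mixture = calc_jsd_multiscale(weight, labels, *preds)
(loss + consistency).backward()                  # gradients flow to weight and all preds
print(loss.item(), consistency.item(), tuple(mixture.shape))  # mixture: (b, c, h, w)
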
/lib/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .registry import DATASETS, BACKBONES, NETS
2 |
3 | __all__ = ['DATASETS', 'BACKBONES', 'NETS']
4 |
--------------------------------------------------------------------------------
/lib/utils/configuration.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | import torch
5 | import os
6 | import sys
7 | import shutil
8 |
9 | class Configuration():
10 | def __init__(self, config_dict, clear=True):
11 | self.__dict__ = config_dict
12 | self.clear = clear
13 | self.__check()
14 |
15 | def __check(self):
16 |         if not torch.cuda.is_available():
17 |             raise ValueError('config.py: cuda is not available')
18 |         if self.GPUS == 0:
19 |             raise ValueError('config.py: the number of GPUs is 0')
20 |         if self.GPUS != torch.cuda.device_count():
21 |             raise ValueError('config.py: GPUS does not match the number of visible CUDA devices')
22 |
23 | if not os.path.isdir(self.LOG_DIR):
24 | os.mkdir(self.LOG_DIR)
25 | # elif self.clear:
26 | # shutil.rmtree(self.LOG_DIR)
27 | # os.mkdir(self.LOG_DIR)
28 | if not os.path.isdir(self.MODEL_SAVE_DIR):
29 | # os.makedirs(self.MODEL_SAVE_DIR)
30 | os.mkdir(self.MODEL_SAVE_DIR)
31 |
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
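
Note: `Configuration` lifts a plain dict into attributes and validates the GPU and output-dir fields, so construction looks like the sketch below. The keys shown are the ones `__check` touches; the paths are illustrative, and a CUDA machine is required.

import torch

from configuration import Configuration  # adjust to your package layout

cfg = Configuration({
    'GPUS': torch.cuda.device_count(),   # must match the visible device count
    'LOG_DIR': 'log',
    'MODEL_SAVE_DIR': 'log/model',       # note: mkdir, not makedirs, so the parent must exist
})
print(cfg.LOG_DIR)
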
/lib/utils/eval_net_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from tqdm import tqdm
4 | import numpy as np
5 | import cv2
6 | from .imutils import img_denorm
7 | from .DenseCRF import dense_crf
8 | import pickle
9 | import os
10 | import torch.multiprocessing as mp
11 |
12 | def eval_net_multiprocess(SpawnContext, net1, net2, IoU_npl_indx, train_dataloader, eval_dataloader1,
13 | eval_dataloader2, momentum=0.3, scale_index=0, flip='no',
14 | scalefactor=1.0, CRF_post='no', tempt_save_root='.',update_all_bg_img=True,t_eval=3):
15 | net1.eval()
16 | net2.eval()
17 | if torch.cuda.device_count() > 1:
18 |
19 |
20 | seg_dict_copy = train_dataloader.dataset.seg_dict.copy()
21 | p1 = SpawnContext.Process(target = eval_net_bs_one, args=(torch.device(0), net1, IoU_npl_indx, eval_dataloader1,seg_dict_copy, momentum, scale_index, flip, scalefactor, CRF_post, tempt_save_root, 'eval_dict_tempt1.npy',update_all_bg_img,t_eval))
22 | p2 = SpawnContext.Process(target = eval_net_bs_one, args=(torch.device(1), net2, IoU_npl_indx, eval_dataloader2,seg_dict_copy, momentum, scale_index, flip, scalefactor, CRF_post, tempt_save_root, 'eval_dict_tempt2.npy',update_all_bg_img,t_eval))
23 |
24 | p1.start()
25 | p2.start()
26 |
27 | p1.join()
28 | p2.join()
29 |
30 |
31 | tempt = np.load(os.path.join(tempt_save_root, 'eval_dict_tempt1.npy'), allow_pickle=True)
32 | prev_pred_dict = tempt[()]
33 |
34 | tempt2 = np.load(os.path.join(tempt_save_root, 'eval_dict_tempt2.npy'), allow_pickle=True)
35 | prev_pred_dict2 = tempt2[()]
36 |
37 | prev_pred_dict.update(prev_pred_dict2)
38 | train_dataloader.dataset.prev_pred_dict = prev_pred_dict
39 |
40 | os.remove(os.path.join(tempt_save_root, 'eval_dict_tempt1.npy'))
41 | os.remove(os.path.join(tempt_save_root, 'eval_dict_tempt2.npy'))
42 | del seg_dict_copy
43 |
44 |
45 |
46 | return
47 |
48 |
49 | def eval_net_bs_one(device, net, IoU_npl_indx, eval_dataloader, seg_dict_copy, momentum=0.3, scale_index=0, flip='no', scalefactor=1.0, CRF_post='no',tempt_save_root='.', save_name='eval_dict_tempt1.npy', update_all_bg_img=False, t_eval=1.0):
50 | # net.eval()
51 | #scale_index = 2 # currently only support this version, improve later
52 | if scale_index==0:
53 | TEST_MULTISCALE = [0.75, 1.0, 1.5]
54 | elif scale_index==1:
55 | TEST_MULTISCALE = [0.5, 1.0, 1.75]
56 | elif scale_index==2:
57 | TEST_MULTISCALE = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
58 | elif scale_index==3:
59 | TEST_MULTISCALE = [0.7, 1.0, 1.5]
60 | elif scale_index==4:
61 | TEST_MULTISCALE = [0.5, 0.75, 1.0, 1.25, 1.5]
62 | elif scale_index==5:
63 | TEST_MULTISCALE = [1.0]
64 | # print('eval_with_onebyone')
65 | prev_pred_dict = {}
66 | with tqdm(total=len(eval_dataloader)) as pbar:
67 | with torch.no_grad():
68 | for i_batch, sample in enumerate(eval_dataloader):
69 | # print(sample['batch_idx'])
70 | # seg_labels = sample['segmentation']
71 |
72 | seg_labels = seg_dict_copy[eval_dataloader.dataset.ori_indx_list[sample['batch_idx']]]
73 |                 # skip images whose pseudo label shares no class with IoU_npl_indx (the classes being corrected)
74 | if set(np.unique(seg_labels[0].cpu().numpy())).isdisjoint(set(IoU_npl_indx[1:])):
75 |
76 | if update_all_bg_img and not (set(np.unique(seg_labels[0].numpy())) - set(np.array([0, 255]))):
77 | # only the background in the pseudo label, then this picture will still be evaluated
78 | pass
79 | else:
80 | # skip this one
81 | continue
82 |
83 | inputs = sample['image']
84 | n, c, h, w = inputs.size() # 1,c,h,w
85 | result_list =[]
86 | image_multiscale = []
87 | for rate in TEST_MULTISCALE:
88 | inputs_batched = sample['image_%f' % rate]
89 | image_multiscale.append(inputs_batched)
90 | if flip!='no':
91 | image_multiscale.append(torch.flip(inputs_batched, [3]))
92 | for img in image_multiscale:
93 | result = net(img.to(device))
94 | result_list.append(result.cpu())
95 |                         # note: net() received a temporary .to(device) copy; img itself never left the CPU
96 |
97 | for i in range(len(result_list)):
98 | result_seg = F.interpolate(result_list[i], (h,w), mode='bilinear', align_corners=True)
99 | if i % 2 == 1 and flip!='no':
100 | result_seg = torch.flip(result_seg, [3])
101 | result_list[i] = result_seg
102 | prob_seg = torch.stack(result_list, dim=0) # 12, 1, c,h,w
103 | prob_seg = F.softmax(torch.mean(prob_seg/t_eval, dim=0, keepdim=False), dim=1) # 1,c,h,w
104 | #prob_seg = torch.clamp(prob_seg, 1e-7, 1)
105 | # do the CRF
106 | if CRF_post !='no':
107 | prob = prob_seg.cpu().numpy() # 1,c,h,w
108 | img_batched = img_denorm(sample['image'][0].numpy()).astype(np.uint8)
109 | prob = dense_crf(prob[0], img_batched, n_classes=21, n_iters=1)
110 | prob_seg = torch.from_numpy(prob.astype(np.float32))
111 | result = prob_seg.unsqueeze(dim=0) # 1,c,h,w
112 | else:
113 | result = prob_seg.cpu() # 1,c,h,w
114 |
115 |                     result_argmax = torch.argmax(result, dim=1)  # 1,h,w: the predicted argmax label
116 |                     result_max_prob, _ = torch.max(result, dim=1)  # 1,h,w: the max probability
117 | for batch_idx in sample['batch_idx'].numpy():
118 | # prev_pred_dict[batch_idx] = result
119 | prev_pred_dict[eval_dataloader.dataset.ori_indx_list[batch_idx]]= (result_argmax, result_max_prob)
120 | pbar.set_description("Correcting Labels ")
121 | pbar.update(1)
122 |
123 | np.save(os.path.join(tempt_save_root, save_name), prev_pred_dict)
124 |
125 |
--------------------------------------------------------------------------------
/lib/utils/finalprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 |
4 | def writelog(cfg, period, metric=None, commit=''):
5 | filepath = os.path.join(cfg.ROOT_DIR,'log','logfile.txt')
6 | logfile = open(filepath,'a')
7 | import time
8 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
9 | logfile.write('\t%s\n'%period)
10 | para_data_dict = {}
11 | para_model_dict = {}
12 | para_train_dict = {}
13 | para_test_dict = {}
14 | para_name = dir(cfg)
15 | for name in para_name:
16 | if 'DATA_' in name:
17 | v = getattr(cfg,name)
18 | para_data_dict[name] = v
19 | elif 'MODEL_' in name:
20 | v = getattr(cfg,name)
21 | para_model_dict[name] = v
22 | elif 'TRAIN_' in name:
23 | v = getattr(cfg,name)
24 | para_train_dict[name] = v
25 | elif 'TEST_' in name:
26 | v = getattr(cfg,name)
27 | para_test_dict[name] = v
28 | writedict(logfile, {'EXP_NAME': cfg.EXP_NAME})
29 | writedict(logfile, para_data_dict)
30 | writedict(logfile, para_model_dict)
31 | if 'train' in period:
32 | writedict(logfile, para_train_dict)
33 | else:
34 | writedict(logfile, para_test_dict)
35 | writedict(logfile, metric)
36 |
37 | logfile.write(commit)
38 | logfile.write('=====================================\n')
39 | logfile.close()
40 |
41 |
42 | def writelog_seperate(cfg, period, metric=None, commit=''):
43 | filepath = os.path.join(cfg.ROOT_DIR,'log',cfg.logfile)
44 | logfile = open(filepath,'a')
45 | import time
46 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
47 | logfile.write('\t%s\n'%period)
48 | para_data_dict = {}
49 | para_model_dict = {}
50 | para_train_dict = {}
51 | para_test_dict = {}
52 | para_name = dir(cfg)
53 | for name in para_name:
54 | if 'DATA_' in name:
55 | v = getattr(cfg,name)
56 | para_data_dict[name] = v
57 | elif 'MODEL_' in name:
58 | v = getattr(cfg,name)
59 | para_model_dict[name] = v
60 | elif 'TRAIN_' in name:
61 | v = getattr(cfg,name)
62 | para_train_dict[name] = v
63 | elif 'TEST_' in name:
64 | v = getattr(cfg,name)
65 | para_test_dict[name] = v
66 | writedict(logfile, {'EXP_NAME': cfg.EXP_NAME})
67 | writedict(logfile, para_data_dict)
68 | writedict(logfile, para_model_dict)
69 | if 'train' in period:
70 | writedict(logfile, para_train_dict)
71 | else:
72 | writedict(logfile, para_test_dict)
73 | writedict(logfile, metric)
74 |
75 | logfile.write(commit)
76 | logfile.write('=====================================\n')
77 | logfile.close()
78 |
79 |
80 |
81 | def writedict(file, dictionary):
82 | s = ''
83 | for key in dictionary.keys():
84 | sub = '%s:%s '%(key, dictionary[key])
85 | s += sub
86 | s += '\n'
87 | file.write(s)
88 |
89 |
90 |
--------------------------------------------------------------------------------
/lib/utils/imutils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 |
4 | def pseudo_erode(label, num, t=1):
5 | label_onehot = onehot(label, num)
6 | k = np.ones((15,15),np.uint8)
7 |     e = cv2.erode(label_onehot, k, iterations=t)  # the third positional arg of cv2.erode is dst, so pass iterations by name
8 | m = (e != label_onehot)
9 | m = np.max(m, axis=2)
10 | label[m] = 255
11 | return label
12 |
13 |
14 | def onehot(label, num):
15 | num = int(num)
16 | m = label.astype(np.int32)
17 | one_hot = np.eye(num)[m]
18 | return one_hot
19 |
20 | def seg2cls(label, num):
21 | cls = np.zeros(num)
22 | index = np.unique(label)
23 | cls[index] = 1
24 | #cls[0] = 0
25 | cls = cls.reshape((num,1,1))
26 | return cls
27 |
28 | def gamma_correction(img):
29 | gamma = np.mean(img)/128.0
30 | lookUpTable = np.empty((1,256), np.uint8)
31 | for i in range(256):
32 | lookUpTable[0,i] = np.clip(pow(i / 255.0, gamma) * 255.0, 0, 255)
33 | res_img = cv2.LUT(img, lookUpTable)
34 | return res_img
35 |
36 | def img_denorm(inputs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), mul=True):
37 | inputs = np.ascontiguousarray(inputs)
38 | if inputs.ndim == 3:
39 | inputs[0,:,:] = (inputs[0,:,:]*std[0] + mean[0])
40 | inputs[1,:,:] = (inputs[1,:,:]*std[1] + mean[1])
41 | inputs[2,:,:] = (inputs[2,:,:]*std[2] + mean[2])
42 | elif inputs.ndim == 4:
43 | n = inputs.shape[0]
44 | for i in range(n):
45 | inputs[i,0,:,:] = (inputs[i,0,:,:]*std[0] + mean[0])
46 | inputs[i,1,:,:] = (inputs[i,1,:,:]*std[1] + mean[1])
47 | inputs[i,2,:,:] = (inputs[i,2,:,:]*std[2] + mean[2])
48 |
49 | if mul:
50 | inputs = inputs*255
51 | inputs[inputs > 255] = 255
52 | inputs[inputs < 0] = 0
53 | inputs = inputs.astype(np.uint8)
54 | else:
55 | inputs[inputs > 1] = 1
56 | inputs[inputs < 0] = 0
57 | return inputs
58 |
--------------------------------------------------------------------------------
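
Note: `pseudo_erode` marks a boundary band (from the 15x15 erosion kernel) around each class region as 255, i.e. ignore. A toy round trip, with the import path and sizes illustrative only:

import numpy as np

from imutils import pseudo_erode  # adjust to your package layout

label = np.zeros((64, 64), dtype=np.int64)
label[16:48, 16:48] = 1                      # a square of class 1
out = pseudo_erode(label.copy(), num=2)      # the band around the square becomes 255
print(np.unique(out))                        # [0 1 255]
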
/lib/utils/iou_computation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 |
5 | def update_iou_stat(predict, gt, TP, P, T, num_classes = 21):
6 | """
7 |     :param predict: the prediction for the batch, a numpy array after taking the argmax; shape b,h,w
8 |     :param gt: the gt label of the batch, a numpy array of shape b,h,w
9 |     :param TP: running per-class true-positive pixel counts
10 |     :param P: running per-class predicted-positive pixel counts
11 |     :param T: running per-class ground-truth pixel counts
12 | :param num_classes: number of classes in the dataset
13 | :return: TP, P, T
14 | """
15 | cal = gt < 255
16 |
17 | mask = (predict == gt) * cal
18 |
19 | for i in range(num_classes):
20 | P[i] += np.sum((predict == i) * cal)
21 | T[i] += np.sum((gt == i) * cal)
22 | TP[i] += np.sum((gt == i) * mask)
23 |
24 | return TP, P, T
25 |
26 |
27 | def iter_iou_stat(predict, gt, num_classes = 21):
28 | """
29 |     :param predict: the prediction for the batch, a numpy array after taking the argmax; shape b,h,w
30 |     :param gt: the gt label of the batch, a numpy array of shape b,h,w
31 |     :param num_classes: number of classes in the dataset
32 |     :return: np.array([TP, P, T]) with per-class true-positive,
33 |         predicted-positive, and ground-truth pixel counts computed
34 |         from this batch alone (not accumulated into running totals,
35 |         unlike update_iou_stat)
36 | """
37 | cal = gt < 255
38 |
39 | mask = (predict == gt) * cal
40 |
41 | TP = np.zeros(num_classes)
42 | P = np.zeros(num_classes)
43 | T = np.zeros(num_classes)
44 |
45 | for i in range(num_classes):
46 | P[i] = np.sum((predict == i) * cal)
47 | T[i] = np.sum((gt == i) * cal)
48 | TP[i] = np.sum((gt == i) * mask)
49 |
50 | return np.array([TP, P, T])
51 |
52 |
53 | def compute_iou(TP, P, T, num_classes = 21):
54 | """
55 | :param TP:
56 | :param P:
57 | :param T:
58 | :param num_classes: number of classes in the dataset
59 | :return: IoU
60 | """
61 | IoU = []
62 | for i in range(num_classes):
63 | IoU.append(TP[i] / (T[i] + P[i] - TP[i] + 1e-10))
64 | return IoU
65 |
66 |
67 | def update_fraction_batchwise(mask, gt, fraction, num_classes = 21):
68 | """
69 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
70 | :param gt: the gt label of the batch, numpy array
71 | :param fraction: fraction of pixels in the subgroup
72 | :param num_classes: number of classes in the dataset
73 | :return: updated fraction
74 | """
75 | cal = gt < 255
76 |
77 | for i in range(num_classes):
78 | fraction[i] += np.sum((mask * (gt == i) * cal))/np.sum((gt == i) * cal)
79 |
80 | return fraction
81 |
82 |
83 | def update_fraction_instancewise(mask, gt, fraction, num_classes = 21):
84 | """
85 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
86 | :param gt: the gt label of the batch, numpy array
87 | :param fraction: fraction of pixels in the subgroup
88 | :param num_classes: number of classes in the dataset
89 | :return: updated fraction
90 | """
91 |     # np.sum((gt == i) * cal, axis=(-2,-1)) may be zero, so the per-instance division can produce nan
92 | cal = gt < 255
93 |
94 | for i in range(num_classes):
95 | fraction[i] += np.mean(np.sum((mask * (gt == i) * cal), axis= (-2,-1))/np.sum((gt == i) * cal, axis= (-2,-1)))
96 |
97 | return fraction
98 |
99 | def update_fraction_pixelwise(mask, gt, abs_num_and_total, num_classes = 21):
100 | """
101 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
102 | :param gt: the gt label of the batch, numpy array
103 | :param abs_num_and_total: the absolute number of pixel belong to the mask and the total num of pixels [abs_num, pixel_num]
104 | :param num_classes: number of classes in the dataset
105 | :return: updated fraction
106 | """
107 | cal = gt < 255
108 |
109 | for i in range(num_classes):
110 | abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
111 | abs_num_and_total[i][1] += np.sum((gt == i) * cal)
112 |
113 |
114 | return abs_num_and_total
115 |
116 | def iter_fraction_pixelwise(mask, gt, num_classes = 21):
117 | """
118 | :param mask: True when belong to subgroup (memorized, correct, others) which we want to calculate fraction on
119 | :param gt: the gt label of the batch, numpy array
120 | :param num_classes: number of classes in the dataset
121 | :return: updated fraction
122 | """
123 | cal = gt < 255
124 |
125 | abs_num_and_total = np.zeros((num_classes,2))
126 |
127 | for i in range(num_classes):
128 | abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
129 | abs_num_and_total[i][1] += np.sum((gt == i) * cal)
130 |
131 |
132 | return abs_num_and_total
133 |
134 |
135 |
136 | def get_mask(gt_np, label_np, pred_np):
137 | """
138 |
139 | Args:
140 | gt_np: the GT label
141 | label_np: the CAM pseudo label
142 | pred_np: the prediction
143 |
144 | Returns: the mask of different type
145 |
146 | """
147 | wrong_mask_correct = (gt_np != label_np) & (pred_np == gt_np)
148 | wrong_mask_memorized = (gt_np != label_np) & (pred_np == label_np)
149 | wrong_mask_others = (gt_np != label_np) & (pred_np != gt_np) & (pred_np != label_np)
150 | clean_mask_correct = (gt_np == label_np) & (pred_np == gt_np)
151 | clean_mask_incorrect = (gt_np == label_np) & (pred_np != gt_np)
152 |
153 | return (wrong_mask_correct,wrong_mask_memorized,wrong_mask_others,clean_mask_correct,clean_mask_incorrect)
--------------------------------------------------------------------------------
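
Note: typical bookkeeping with these helpers is to accumulate TP/P/T across batches and compute per-class IoU at the end. A toy example (the random tensors and the import path are illustrative):

import numpy as np

from iou_computation import update_iou_stat, compute_iou  # adjust to your package layout

num_classes = 21
TP, P, T = [0] * num_classes, [0] * num_classes, [0] * num_classes

pred = np.random.randint(0, num_classes, (2, 16, 16))
gt = pred.copy()
gt[0, 0, :] = 255                      # ignored pixels are excluded via the `cal` mask

TP, P, T = update_iou_stat(pred, gt, TP, P, T)
iou = compute_iou(TP, P, T)
print(np.round(iou, 2))                # 1.0 for classes present in the batch, 0.0 for absent ones
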
/lib/utils/logger.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | try:
3 | from comet_ml import Experiment as CometExperiment
4 | from comet_ml import OfflineExperiment as CometOfflineExperiment
5 | except ImportError: # pragma: no-cover
6 | _COMET_AVAILABLE = False
7 | else:
8 | _COMET_AVAILABLE = True
9 |
10 |
11 | import torch
12 | from torch import is_tensor
13 | from typing import Any, Dict, Optional, Union
14 | from datetime import datetime
15 |
16 | class Timer:
17 | def __init__(self):
18 | self.cache = datetime.now()
19 |
20 | def check(self):
21 | now = datetime.now()
22 | duration = now - self.cache
23 | self.cache = now
24 | return duration.total_seconds()
25 |
26 | def reset(self):
27 | self.cache = datetime.now()
28 |
29 | class CometWriter:
30 | def __init__(
31 | self,
32 | project_name: Optional[str] = None,
33 | experiment_name: Optional[str] = None,
34 | api_key: Optional[str] = None,
35 | log_dir: Optional[str] = None,
36 | offline: bool = False,
37 | **kwargs):
38 | if not _COMET_AVAILABLE:
39 | raise ImportError(
40 | "You want to use `comet_ml` logger which is not installed yet,"
41 | " install it with `pip install comet-ml`."
42 | )
43 |
44 | self.project_name = project_name
45 | self.experiment_name = experiment_name
46 | self.kwargs = kwargs
47 |
48 | self.timer = Timer()
49 |
50 |
51 | if (api_key is not None) and (log_dir is not None):
52 | self.mode = "offline" if offline else "online"
53 | self.api_key = api_key
54 | self.log_dir = log_dir
55 |
56 | elif api_key is not None:
57 | self.mode = "online"
58 | self.api_key = api_key
59 | self.log_dir = None
60 | elif log_dir is not None:
61 | self.mode = "offline"
62 | self.log_dir = log_dir
63 |         else:
64 |             raise ValueError("CometWriter requires either api_key or log_dir during initialization.")
65 |
66 | if self.mode == "online":
67 | self.experiment = CometExperiment(
68 | api_key=self.api_key,
69 | project_name = self.project_name,
70 | **self.kwargs,
71 | )
72 | else:
73 | self.experiment = CometOfflineExperiment(
74 | offline_directory=self.log_dir,
75 | project_name=self.project_name,
76 | **self.kwargs,
77 | )
78 |
79 | if self.experiment_name:
80 | self.experiment.set_name(self.experiment_name)
81 |
82 | def set_step(self, step, epoch = None, mode='train') -> None:
83 | self.mode = mode
84 | self.step = step
85 | self.epoch = epoch
86 | if step == 0:
87 | self.timer.reset()
88 | else:
89 | duration = self.timer.check()
90 | self.add_scalar({'steps_per_sec': 1 / duration})
91 |
92 | def log_hyperparams(self, params: Dict[str, Any]) -> None:
93 | self.experiment.log_parameters(params)
94 |
95 | def log_code(self, file_name = None, folder = 'models/') -> None:
96 | self.experiment.log_code(file_name=file_name, folder=folder)
97 |
98 |
99 | def add_scalar(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None, epoch: Optional[int] = None) -> None:
100 | metrics_renamed = {}
101 | for key, val in metrics.items():
102 | tag = '{}/{}'.format(key, self.mode)
103 | if is_tensor(val):
104 | metrics_renamed[tag] = val.cpu().detach()
105 | else:
106 | metrics_renamed[tag] = val
107 | if epoch is None and step is None:
108 | self.experiment.log_metrics(metrics_renamed, step = self.step, epoch = self.epoch)
109 | elif epoch is None and step is not None:
110 | self.experiment.log_metrics(metrics_renamed, step = step)
111 | elif epoch is not None and step is None:
112 | self.experiment.log_metrics(metrics_renamed, epoch = epoch)
113 | else:
114 | self.experiment.log_metrics(metrics_renamed, step = step, epoch = epoch)
115 |
116 | def add_plot(self, figure_name, figure):
117 | """
118 | Primarily for log gate plots
119 | """
120 | self.experiment.log_figure(figure_name = figure_name, figure = figure)
121 |
122 | def add_text(self, text, step):
123 | """
124 | Primarily for log gate plots
125 | """
126 | self.experiment.log_text(text, step = step)
127 |
128 | def add_hist3d(self, hist, name):
129 | """
130 | Primarily for log gate plots
131 | """
132 | self.experiment.log_histogram_3d(hist, name = name)
133 |
134 | def reset_experiment(self):
135 | self.experiment = None
136 |
137 | def finalize(self) -> None:
138 | self.experiment.end()
139 | self.reset_experiment()
140 |
--------------------------------------------------------------------------------
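
Note: the `Timer` half of this file is self-contained; `check()` returns the seconds elapsed since the previous `check()` (or construction) and re-arms itself. A quick sketch:

import time

from logger import Timer  # adjust to your package layout

t = Timer()
time.sleep(0.2)
print(round(t.check(), 1))  # ~0.2; the cache now points at this call
time.sleep(0.1)
print(round(t.check(), 1))  # ~0.1, measured from the previous check()
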
/lib/utils/registry.py:
--------------------------------------------------------------------------------
1 |
2 | class Registry(object):
3 | def __init__(self, name):
4 | super(Registry, self).__init__()
5 | self._name = name
6 | self._module_dict = dict()
7 |
8 | @property
9 | def name(self):
10 | return self._name
11 |
12 | @property
13 | def module_dict(self):
14 | return self._module_dict
15 |
16 | def __len__(self):
17 | return len(self.module_dict)
18 |
19 | def get(self, key):
20 | return self._module_dict[key]
21 |
22 | def register_module(self, module=None):
23 | if module is None:
24 |             raise TypeError('cannot register None in Registry {}'.format(self.name))
25 | module_name = module.__name__
26 | if module_name in self._module_dict:
27 |             raise KeyError('{} is already registered in Registry {}'.format(module_name, self.name))
28 | self._module_dict[module_name] = module
29 | return module
30 |
31 | DATASETS = Registry('dataset')
32 | BACKBONES = Registry('backbone')
33 | NETS = Registry('nets')
34 |
--------------------------------------------------------------------------------
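
Note: `register_module` returns the module it stores, so it doubles as a class decorator; this is how `DATASETS`/`BACKBONES`/`NETS` are meant to be populated. The class below is illustrative only.

from registry import NETS  # adjust to your package layout

@NETS.register_module          # stores the class under its __name__ and returns it
class TinyNet:
    pass

assert NETS.get('TinyNet') is TinyNet
print(len(NETS))  # 1
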
/lib/utils/test_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import torch
3 | from tqdm import tqdm
4 |
5 | def single_gpu_test(model, dataloader, prepare_func, inference_func, collect_func, save_step_func=None):
6 | model.eval()
7 | n_gpus = torch.cuda.device_count()
8 | #assert n_gpus == 1
9 | collect_list = []
10 | total_num = len(dataloader)
11 | with tqdm(total=total_num) as pbar:
12 | with torch.no_grad():
13 | for i_batch, sample in enumerate(dataloader):
14 | name = sample['name']
15 | image_msf = prepare_func(sample)
16 | result_list = []
17 | for img in image_msf:
18 | result = inference_func(model, img.cuda())
19 | result_list.append(result)
20 | result_item = collect_func(result_list, sample)
21 | result_sample = {'predict': result_item, 'name':name[0]}
22 | #print('%d/%d'%(i_batch,len(dataloader)))
23 | pbar.set_description('Processing')
24 | pbar.update(1)
25 | time.sleep(0.001)
26 |
27 | if save_step_func is not None:
28 | save_step_func(result_sample)
29 | else:
30 | collect_list.append(result_sample)
31 | return collect_list
32 |
33 |
34 | def single_gpu_multimodel_ensemble_test(model,model2, dataloader, prepare_func, inference_func, collect_func, save_step_func=None):
35 | model.eval()
36 | n_gpus = torch.cuda.device_count()
37 | #assert n_gpus == 1
38 | collect_list = []
39 | total_num = len(dataloader)
40 | with tqdm(total=total_num) as pbar:
41 | with torch.no_grad():
42 | for i_batch, sample in enumerate(dataloader):
43 | name = sample['name']
44 | image_msf = prepare_func(sample)
45 | result_list = []
46 | for img in image_msf:
47 | result = inference_func(model,model2, img.cuda())
48 | result_list.append(result)
49 | result_item = collect_func(result_list, sample)
50 | result_sample = {'predict': result_item, 'name':name[0]}
51 | #print('%d/%d'%(i_batch,len(dataloader)))
52 | pbar.set_description('Processing')
53 | pbar.update(1)
54 | time.sleep(0.001)
55 |
56 | if save_step_func is not None:
57 | save_step_func(result_sample)
58 | else:
59 | collect_list.append(result_sample)
60 | return collect_list
61 |
62 |
63 | def single_gpu_triplemodel_ensemble_test(model,model2,model3, dataloader, prepare_func, inference_func, collect_func, save_step_func=None):
64 | model.eval()
65 | n_gpus = torch.cuda.device_count()
66 | #assert n_gpus == 1
67 | collect_list = []
68 | total_num = len(dataloader)
69 | with tqdm(total=total_num) as pbar:
70 | with torch.no_grad():
71 | for i_batch, sample in enumerate(dataloader):
72 | name = sample['name']
73 | image_msf = prepare_func(sample)
74 | result_list = []
75 | for img in image_msf:
76 | result = inference_func(model,model2,model3, img.cuda())
77 | result_list.append(result)
78 | result_item = collect_func(result_list, sample)
79 | result_sample = {'predict': result_item, 'name':name[0]}
80 | #print('%d/%d'%(i_batch,len(dataloader)))
81 | pbar.set_description('Processing')
82 | pbar.update(1)
83 | time.sleep(0.001)
84 |
85 | if save_step_func is not None:
86 | save_step_func(result_sample)
87 | else:
88 | collect_list.append(result_sample)
89 | return collect_list
90 |
--------------------------------------------------------------------------------
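
Note: a toy end-to-end wiring of `single_gpu_test` with identity-style hooks. It assumes one CUDA device (the loop calls `img.cuda()`), and every name below is illustrative, not part of the repo.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from test_utils import single_gpu_test  # adjust to your package layout

class ToyDataset(Dataset):
    def __len__(self):
        return 4
    def __getitem__(self, i):
        return {'name': 'img_%d' % i, 'image': torch.randn(3, 8, 8)}

model = nn.Conv2d(3, 2, 1).cuda()
out = single_gpu_test(
    model, DataLoader(ToyDataset(), batch_size=1),
    prepare_func=lambda s: [s['image']],                          # a single "scale"
    inference_func=lambda m, img: m(img).cpu(),
    collect_func=lambda results, s: torch.stack(results).mean(0).argmax(1),
)
print(out[0]['name'], out[0]['predict'].shape)  # img_0 torch.Size([1, 8, 8])
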
/lib/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | import cv2
5 | from .DenseCRF import dense_crf  # relative import, matching the other lib/utils modules
6 | #from cv2.ximgproc import l0Smooth
7 |
8 | def color_pro(pro, img=None, mode='hwc'):
9 | H, W = pro.shape
10 | pro_255 = (pro*255).astype(np.uint8)
11 | pro_255 = np.expand_dims(pro_255,axis=2)
12 | color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET)
13 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB)
14 | if img is not None:
15 | rate = 0.5
16 | if mode == 'hwc':
17 | assert img.shape[0] == H and img.shape[1] == W
18 | color = cv2.addWeighted(img,rate,color,1-rate,0)
19 | elif mode == 'chw':
20 | assert img.shape[1] == H and img.shape[2] == W
21 | img = np.transpose(img,(1,2,0))
22 | color = cv2.addWeighted(img,rate,color,1-rate,0)
23 | color = np.transpose(color,(2,0,1))
24 | else:
25 | if mode == 'chw':
26 | color = np.transpose(color,(2,0,1))
27 | return color
28 |
29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True, crf=False):
30 | # All the input should be numpy.array
31 | # img should be 0-255 uint8
32 | C, H, W = p.shape
33 |
34 | if norm:
35 | prob = max_norm(p, 'numpy')
36 | else:
37 | prob = p
38 | if gt is not None:
39 | prob = prob * gt
40 | prob[prob<=0] = 1e-5
41 | if threshold is not None:
42 | prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4)
43 |
44 | CLS = ColorCLS(prob, func_label2color)
45 | CAM = ColorCAM(prob, img)
46 | if crf:
47 | prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1)
48 | CLS_crf = ColorCLS(prob_crf, func_label2color)
49 | CAM_crf = ColorCAM(prob_crf, img)
50 | return CLS, CAM, CLS_crf, CAM_crf
51 | else:
52 | return CLS, CAM
53 |
54 | def max_norm(p, version='torch', e=1e-5):
55 |     if version == 'torch':  # compare strings with ==, not identity
56 | if p.dim() == 3:
57 | C, H, W = p.size()
58 | p = F.relu(p, inplace=True)
59 | max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1)
60 | min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1)
61 | p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e)
62 | elif p.dim() == 4:
63 | N, C, H, W = p.size()
64 | p = F.relu(p, inplace=True)
65 | max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
66 | min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1)
67 | p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e)
68 |     elif version == 'numpy' or version == 'np':
69 | if p.ndim == 3:
70 | C, H, W = p.shape
71 |             p[p < 0] = 0
72 |             max_v = np.max(p, (1, 2), keepdims=True)
73 |             min_v = np.min(p, (1, 2), keepdims=True)
74 |             p = (p - min_v - e) / (max_v - min_v + e)
75 |             p[p < 0] = 0
76 |     return p
77 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7.0
2 | torchvision>=0.8.1
3 | mxnet>=1.7.0.post1
4 | scipy>=1.5.1
5 | numpy>=1.19.4
6 | scikit_image>=0.17.2
7 | pydensecrf>=1.0rc3
8 | pandas>=1.0.5
9 | opencv_python>=4.3.0.36
10 | matplotlib>=3.3.0
11 | Pillow>=8.1.0
12 | tensorboardX>=2.1
13 | tqdm
14 | comet-ml
15 | 
--------------------------------------------------------------------------------