├── .gitignore
├── .idea
├── .gitignore
├── inspectionProfiles
│   └── profiles_settings.xml
├── jsd_experiments.iml
├── misc.xml
├── modules.xml
└── vcs.xml
├── LICENSE
├── README.md
├── SegThor
├── .idea
│   ├── .gitignore
│   ├── ADELE.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── LICENSE
├── brat
│   ├── __init__.py
│   ├── brat_util.py
│   ├── label_correction.py
│   ├── loading.py
│   ├── train_segthor.py
│   ├── unet_model.py
│   └── unet_parts.py
├── lib
│   └── utils
│   │   ├── JSD_loss.py
│   │   └── iou_computation.py
└── requirements.txt
├── __init__.py
├── config.py
├── lib
├── datasets
│   ├── BaseDataset.py
│   ├── BaseMultiwGTauginfoDataset.py
│   ├── VOCDataset.py
│   ├── VOCEvalDataset.py
│   ├── VOCTrainwsegDataset.py
│   ├── __init__.py
│   ├── generateData.py
│   ├── metric.py
│   ├── transform.py
│   ├── transformmultiGT.py
│   └── transformmultiGTauginfo.py
├── net
│   ├── __init__.py
│   ├── backbone
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── resnet.py
│   │   ├── resnet38d.py
│   │   └── xception.py
│   ├── deeplabv1_wo_interp.py
│   ├── generateNet.py
│   ├── operators
│   │   ├── ASPP.py
│   │   ├── PPM.py
│   │   └── __init__.py
│   └── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── comm.py
│   │   ├── replicate.py
│   │   ├── sync_batchnorm
│   │   ├── __init__.py
│   │   ├── batchnorm.py
│   │   ├── batchnorm_reimpl.py
│   │   ├── comm.py
│   │   ├── replicate.py
│   │   └── unittest.py
│   │   ├── tests
│   │   ├── test_numeric_batchnorm.py
│   │   └── test_sync_batchnorm.py
│   │   └── unittest.py
└── utils
│   ├── DenseCRF.py
│   ├── JSD_loss.py
│   ├── __init__.py
│   ├── configuration.py
│   ├── eval_net_utils.py
│   ├── finalprocess.py
│   ├── imutils.py
│   ├── iou_computation.py
│   ├── logger.py
│   ├── registry.py
│   ├── test_utils.py
│   └── visualization.py
├── requirements.txt
└── train.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | sftp*.json
2 | run*.sh
3 | train_onebyone_eval_w_compact_dict.py
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/jsd_experiments.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Hibercraft
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ADELE (Adaptive Early-Learning Correction for Segmentation from Noisy Annotations) (CVPR 2022 Oral)
2 | 
3 | 
4 | Sheng Liu*, Kangning Liu*, Weicheng Zhu, Yiqiu Shen, Carlos Fernandez-Granda
5 | 
6 | (* The first two authors contributed equally; the order was decided by a coin flip.)
7 | 
8 | 
9 | 
10 | 
11 | Official Implementation of [Adaptive Early-Learning Correction for Segmentation from Noisy Annotations](https://arxiv.org/abs/2110.03740) (CVPR 2022 Oral)
12 | 
13 | ## PASCAL VOC dataset
14 | Thanks to the work of Yude Wang, the code of this repository borrows heavily from his SEAM repository, and we follow the same pipeline to verify the effectiveness of our ADELE.
15 | We use the same ImageNet-pretrained ResNet38 model as SEAM, which can be downloaded from https://github.com/YudeWang/semantic-segmentation-codebase/tree/main/experiment/seamv1-pseudovoc
16 | 
17 | 
18 | 
19 | 
20 | 
21 | 
22 | The code related to PASCAL VOC is located in the main folder. We provide the trained model for SEAM+ADELE at the following link:
23 | 
24 | https://drive.google.com/file/d/10cTOraETOmb2jOCJ4E0m_y9lrjrA3g2u/view?usp=sharing
25 | 
26 | We use two NVIDIA Quadro RTX 8000 GPUs to train the model. If you encounter an out-of-memory issue, please consider decreasing the resolution of the input image.
27 | 
28 | ### Installation
29 | - Install Python dependencies.
30 | ```
31 | pip install -r requirements.txt
32 | ```
33 | 
34 | Note that we use Comet to record the statistics online. Comet is similar to TensorBoard; more information can be found at https://www.comet.ml/site/ .
35 | - Create a soft link to your dataset. Make sure that the dataset can be accessed by `$your_dataset_path/VOCdevkit/VOC2012...`
36 | ```
37 | ln -s $your_dataset_path data
38 | ```
39 | 
40 | 
41 | 
42 | 
43 | Inference code is the same as the official code for SEAM. The code link provided by the SEAM author: https://github.com/YudeWang/semantic-segmentation-codebase/tree/main/experiment/seamv1-pseudovoc
44 | 
45 | 
46 | 
47 | 
48 | For the training code, an example script for ADELE would be:
49 | 
50 | ```
51 | python train.py \
52 |     --EXP_NAME EXP_name \
53 |     --Lambda1 1 --TRAIN_BATCHES 10 --TRAIN_LR 0.001 --mask_threshold 0.8 \
54 |     --scale_index 0 --flip yes --CRF yes \
55 |     --dict_save_scale_factor 1 --npl_metrics 0 \
56 |     --api_key API_key \
57 |     --r_threshold 0.9 --Reinit_dict yes \
58 |     --DATA_PSEUDO_GT Initial_Pseudo_Label_Location
59 | ```
60 | 
61 | 
62 | 
63 | We store default values for the arguments in the config.py file; those values are passed to the arguments as cfg.XXX. You may change the default values in config.py or override them via arguments in the script.
64 | It is especially important to assign the path of your initial pseudo annotation via --DATA_PSEUDO_GT or to specify it in cfg.DATA_PSEUDO_GT in the config.py file. For the detailed method to obtain the initial pseudo annotation, please refer to related methods such as AffinityNet, SEAM, ICD, NSROM, etc.
65 | 
66 | 
67 | The arguments represent:
68 | 
69 |     parser.add_argument("--EXP_NAME", type=str, default=cfg.EXP_NAME,
70 |                         help="the name of the experiment")
71 |     parser.add_argument("--scale_factor", type=float, default=cfg.scale_factor,
72 |                         help="scale factor to downsample the image")
73 |     parser.add_argument("--scale_factor2", type=float, default=cfg.scale_factor2,
74 |                         help="scale factor to upsample the image")
75 |     parser.add_argument("--DATA_PSEUDO_GT", type=str, default=cfg.DATA_PSEUDO_GT,
76 |                         help="data path of the main segmentation map")
77 |     parser.add_argument("--TRAIN_CKPT", type=str, default=cfg.TRAIN_CKPT,
78 |                         help="path of the training checkpoint")
79 |     parser.add_argument("--Lambda1", type=float, default=1,
80 |                         help="to balance the loss between CE and the consistency loss")
81 |     parser.add_argument("--TRAIN_BATCHES", type=int, default=cfg.TRAIN_BATCHES,
82 |                         help="training batch size")
83 |     parser.add_argument('--threshold', type=float, default=0.8,
84 |                         help="threshold to select the mask for consistency loss computation")
85 |     parser.add_argument('--DATA_WORKERS', type=int, default=cfg.DATA_WORKERS,
86 |                         help="number of workers in the dataloader")
87 |     parser.add_argument('--TRAIN_LR', type=float,
88 |                         default=cfg.TRAIN_LR,
89 |                         help="the learning rate for training")
90 |     parser.add_argument('--TRAIN_ITERATION', type=int,
91 |                         default=cfg.TRAIN_ITERATION,
92 |                         help="the number of training iterations")
93 |     parser.add_argument('--DATA_RANDOMCROP', type=int, default=cfg.DATA_RANDOMCROP,
94 |                         help="the resolution of the random crop")
95 | 
96 | 
97 | 
98 |     # related to the pseudo label updating
99 |     parser.add_argument('--mask_threshold', type=float, default=0.8,
100 |                         help="only regions with high probability that disagree with the pseudo label will be updated")
101 |     parser.add_argument('--update_interval', type=int, default=1,
102 |                         help="evaluate the prediction every 1 epoch")
103 |     parser.add_argument('--npl_metrics', type=int, default=0,
104 |                         help="0: use the original CAM to compute the npl similarity; 1: use the updated pseudo label to compute the npl")
105 |     parser.add_argument('--r_threshold', type=float, default=0.9,
106 |                         help="the r threshold used to decide if_update")
107 | 
108 |     # related to the eval mode
109 |     parser.add_argument('--scale_index', type=int, default=2,
110 |                         help="0: scales [0.7, 1.0, 1.5]; 1: [0.5, 1.0, 1.75]; 2: [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]")
111 |     parser.add_argument('--flip', type=str, default='yes',
112 |                         help="do not flip in the eval prediction if 'no', else flip")
113 |     parser.add_argument('--CRF', type=str, default='no',
114 |                         help="whether to use CRF, yes or no, default no")
115 |     parser.add_argument('--dict_save_scale_factor', type=float, default=1,
116 |                         help="downsample factor for the saved dict (in case the CPU memory is not enough)")
117 |     parser.add_argument('--evaluate_interval', type=int, default=1,
118 |                         help="evaluate the prediction every 1 epoch; this is always set to one for the PASCAL VOC dataset")
119 |     parser.add_argument('--Reinit_dict', type=str2bool, nargs='?',
120 |                         const=True, default=False,
121 |                         help="whether to reinit the dict every epoch")
122 |     parser.add_argument('--evaluate_aug_epoch', type=int, default=9,
123 |                         help="when to start augmenting the evaluation with CRF and flip; this can be used to save some time when updating the pseudo label, and we did not find a significant difference")
124 | 
125 | 
126 | 
127 |     # continue_training_related:
128 |     parser.add_argument('--continue_train_epoch', type=int, default=0,
129 |                         help="the epoch from which to load the trained model; if 0, no continued training")
130 |     parser.add_argument('--checkpoint_path', type=str, default='no',
131 |                         help="the checkpoint path from which to load the model")
132 |     parser.add_argument('--dict_path', type=str,
133 |                         default='no',
134 |                         help="the path of the saved segmentation dict")
135 |     parser.add_argument('--MODEL_BACKBONE_PRETRAIN', type=str2bool, nargs='?',
136 |                         const=True, default=True,
137 |                         help="do not load the pretrained model if false")
138 | 
139 | 
140 |     # Comet
141 |     parser.add_argument('--api_key', type=str,
142 |                         default='',
143 |                         help="the api_key of Comet, please refer to https://www.comet.ml/site/ for more information")
144 |     parser.add_argument('--online', type=str2bool, nargs='?',
145 |                         const=True, default=True,
146 |                         help="set to False to use Comet offline")
147 | 
148 | 
149 | 
150 | 
151 | 
152 | 
153 | 
154 | 
155 | ## SegTHOR dataset
156 | The code related to SegTHOR is located in the folder SegThor; please go to the subdirectory SegThor.
157 | ### Installation
158 | 
159 | - Install Python dependencies.
160 | ```
161 | pip install -r requirements.txt
162 | ```
163 | - Download the SegTHOR dataset and conduct data preprocessing: resize all images to the size of 256*256 using the linear interpolation of opencv_python (INTER_LINEAR).
164 | 
165 | The details of the public SegTHOR dataset can be found at [this link](https://competitions.codalab.org/competitions/21145).
166 | 
167 | In this study, we randomly assigned the patients in the original training set to training, validation, and test sets using the following scheme:
168 | 
169 | - training set: ['Patient_01', 'Patient_02', 'Patient_03', 'Patient_04',
170 |        'Patient_05', 'Patient_06', 'Patient_07', 'Patient_09',
171 |        'Patient_10', 'Patient_11', 'Patient_12', 'Patient_13',
172 |        'Patient_14', 'Patient_15', 'Patient_16', 'Patient_17',
173 |        'Patient_18', 'Patient_19', 'Patient_20', 'Patient_22',
174 |        'Patient_24', 'Patient_25', 'Patient_26', 'Patient_28',
175 |        'Patient_30', 'Patient_31', 'Patient_33', 'Patient_36',
176 |        'Patient_38', 'Patient_39', 'Patient_40']
177 | - validation set: ['Patient_21', 'Patient_23', 'Patient_27', 'Patient_29',
178 |        'Patient_37']
179 | - test set: ['Patient_08', 'Patient_27', 'Patient_32', 'Patient_34',
180 |        'Patient_35']
181 | 
182 | We used only slices that contain a foreground class and downsampled all slices to 256 * 256 pixels using linear interpolation.
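For reference, here is a minimal sketch of this resizing step (the helper name `resize_slice` and the nearest-neighbor interpolation for the label maps are assumptions for illustration; only the 256*256 INTER_LINEAR resize of the image slices is taken from the step above — label maps are resized with nearest-neighbor here so that class indices are not mixed):

```
import cv2

def resize_slice(img, seg, size=256):
    # image slice: linear interpolation, as described above
    img_r = cv2.resize(img.astype("float32"), (size, size),
                       interpolation=cv2.INTER_LINEAR)
    # label map: nearest-neighbor (assumption), so class ids stay integral
    seg_r = cv2.resize(seg.astype("uint8"), (size, size),
                       interpolation=cv2.INTER_NEAREST)
    return img_r, seg_r
```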
183 | 
184 | ### Experiments
185 | Here is an example script for ADELE:
186 | ```
187 | python3 brat/train_segthor.py \
188 |     --cache-dir DIR_OF_THE_DATA \
189 |     --data-list DIR_OF_THE_DATALIST \
190 |     --save-dir MODEL_SAVE_DIR \
191 |     --model-name MODEL_NAME \
192 |     --seed 0 \
193 |     --jsd-lambda 1 \
194 |     --rho 0.8 \
195 |     --label-correction \
196 |     --tau_fg 0.7 \
197 |     --tau_bg 0.7 \
198 |     --r 0.9
199 | ```
200 | 
201 | where the arguments represent:
202 | * `cache-dir` - Parent dir of tr.pkl, val.pkl, and ts.pkl, which are the input data for the training, validation, and test sets.
203 | * `data-list` - Parent dir of the data_list.pkl file, which is the list of names for the input data.
204 | * `save-dir` - Folder where models and results will be saved.
205 | * `model-name` - Name of the model.
206 | * `seed` - The random seed of the noise realization; default 0.
207 | * `jsd-lambda` - The consistency strength; if set to 0, no consistency regularization is applied. Default 1.
208 | * `rho` - Consistency confidence threshold: the threshold on the confidence of the model's prediction used to decide which examples receive consistency regularization.
209 | * `label-correction` - Whether to conduct label correction; if this argument is set, the model performs label correction. Default False.
210 | * `tau_fg, tau_bg` - Label correction confidence thresholds for foreground and background; in the main paper and all experiments, we set these two values to be the same for simplicity. Default 0.7.
211 | * `r` - Curve-fitting threshold that controls when a specific semantic category is corrected; default 0.9 (see the sketch below).
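For intuition on `r`: it is the `threshold` consumed by `if_update` in `brat/label_correction.py` (included below in this repository). The training-IoU curve of each semantic category is fitted with a*(1 - exp(-x**b / c)), and correction for that category is triggered once the relative change of the fitted curve's derivative exceeds this threshold. A hedged usage sketch — the IoU history values are made up, and the import assumes you run from the SegThor directory:

```
from brat.label_correction import if_update

# hypothetical per-epoch training IoU of one semantic category
iou_history = [0.35, 0.52, 0.61, 0.66, 0.69, 0.70, 0.71]

# True once the fitted IoU curve has flattened enough for this category,
# i.e. label correction should start; threshold is the --r argument
start_correction = if_update(iou_history, current_epoch=7,
                             n_epoch=16, threshold=0.9)
```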
212 | 
213 | 
214 | 
215 | Here is an example script for the baseline:
216 | ```
217 | python3 brat/train_segthor.py \
218 |     --cache-dir DIR_OF_THE_DATA \
219 |     --data-list DIR_OF_THE_DATALIST \
220 |     --save-dir MODEL_SAVE_DIR \
221 |     --model-name MODEL_NAME \
222 |     --seed 0 \
223 |     --jsd-lambda 0
224 | ```
225 | 
226 | 
227 | 
228 | 
229 | 
230 | 
231 | 
232 | 
233 | ## Citation
234 | 
235 | Please cite our paper if the code is helpful to your research.
236 | ```
237 | @article{liu2021adaptive,
238 |   title={Adaptive Early-Learning Correction for Segmentation from Noisy Annotations},
239 |   author={Liu, Sheng and Liu, Kangning and Zhu, Weicheng and Shen, Yiqiu and Fernandez-Granda, Carlos},
240 |   journal={CVPR 2022},
241 |   year={2022}
242 | }
243 | ```
244 | 
245 | 
--------------------------------------------------------------------------------
/SegThor/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP Client requests
8 | /httpRequests/
--------------------------------------------------------------------------------
/SegThor/.idea/ADELE.iml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/misc.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/modules.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/.idea/vcs.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/SegThor/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Sheng Liu
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/SegThor/brat/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kangningthu/ADELE/7195bd0af39be79c533d67dd7eab7f9bfd6a4285/SegThor/brat/__init__.py
--------------------------------------------------------------------------------
/SegThor/brat/brat_util.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import time
4 | 
5 | 
6 | class DocumentUnit:
7 |     """
8 |     Object that documents the output from the model in an epoch
9 |     """
10 | 
11 |     def __init__(self, columns):
12 |         self.data_dict = dict([(col, []) for col in columns])
13 |         # skip case-level preds and labels since they have 1/2 length
14 |         self.skip_key = ["case_pred", "case_label", "left_case_pred",
15 |                          "right_case_pred", "fusion_case_pred", "left_right_case_pred"]
16 |         # accumulator for localization
17 |         self.localization_accumulator = None
18 | 
19 |     def update_accumulator(self, delta):
20 |         """
21 |         Method that accumulates localization pixel-wise values for metrics such as mIOU
22 |         :param delta:
23 |         :return:
24 |         """
25 |         if self.localization_accumulator is None:
26 |             self.localization_accumulator = delta
27 |         else:
28 |             assert self.localization_accumulator.shape == delta.shape,\
29 |                 "self.localization_accumulator.shape {0} != delta.shape {1}".format(self.localization_accumulator.shape, delta.shape)
30 |             self.localization_accumulator += delta
31 | 
32 |     def add_values(self, column, values, process_method=lambda x: x):
33 |         """
34 |         Method that adds values into the document unit
35 |         :param column:
36 |         :param values:
37 |         :return:
38 |         """
39 |         for val in values:
40 |             self.data_dict[column].append(process_method(val))
41 | 
42 |     def form_df(self):
43 |         """
44 |         Method that creates a dataframe out of stored data
45 |         :return:
46 |         """
47 |         to_be_save_dict = {}
48 |         for key in self.data_dict:
49 |             if key not in self.skip_key and len(self.data_dict[key]) != 0:
50 |                 to_be_save_dict[key] = self.data_dict[key]
51 |         df = pd.DataFrame(to_be_save_dict).reset_index()
52 |         return df
53 | 
54 |     def get_latest_results(self):
55 |         """
56 |         Method that retrieves the latest results from the stored data
57 |         :return:
58 |         """
59 |         to_be_save_dict = {}
60 |         for key in self.data_dict:
61 |             if key not in self.skip_key and len(self.data_dict[key]) != 0:
62 |                 to_be_save_dict[key] = self.data_dict[key][-1]
63 |         return to_be_save_dict
64 | 
65 |     def to_csv(self, dir):
66 |         """
67 |         Export to csv
68 |         :param dir:
69 |         :return:
70 |         """
71 |         df = self.form_df()
72 |         df.to_csv(dir, index=False)
73 | 
74 | 
75 | 
76 | class RuntimeProfiler:
77 |     """
78 |     Object that documents run-time
79 |     """
80 |     def __init__(self):
81 |         self.elpased_time_dict = {}
82 |         self.current_time_point = None
83 | 
84 |     def tik(self, time_category=None):
85 |         """
86 |         Take a time point
87 |         :param time_category:
88 |         :return:
89 |         """
90 |         new_time = time.time()
91 |         return_time = False
92 |         if self.current_time_point is not None:
93 |             return_time = True
94 |             elapsed_time = new_time - self.current_time_point
95 |             if time_category is not None:
96 |                 if time_category not in self.elpased_time_dict:
97 |                     self.elpased_time_dict[time_category] = []
98 |                 self.elpased_time_dict[time_category].append(elapsed_time)
99 |         self.current_time_point = new_time
100 |         if return_time:
101 |             return elapsed_time
102 | 
103 | 
104 |     def report_avg(self):
105 |         """
106 |         Generate a format string for average run-time statistics
107 |         :return:
108 |         """
109 |         output_str = ""
110 |         for time_category in self.elpased_time_dict:
111 |             output_str += "category:{0}, avg_time:{1}, std_time:{2}, min_time:{3}, max_time:{4}, num_points:{5}\n".format(
112 |                 time_category, np.mean(self.elpased_time_dict[time_category]), np.std(self.elpased_time_dict[time_category]),
113 |                 np.min(self.elpased_time_dict[time_category]), np.max(self.elpased_time_dict[time_category]),
114 |                 len(self.elpased_time_dict[time_category])
115 |             )
116 |         return output_str
117 | 
118 |     def report_latest(self):
119 |         """
120 |         Generate a format string for the latest run-time statistics
121 |         :return:
122 |         """
123 |         output_str = ""
124 |         for time_category in self.elpased_time_dict:
125 |             output_str += "category:{0}, runtime:{1} \n".format(time_category, self.elpased_time_dict[time_category][-1])
126 |         return output_str
--------------------------------------------------------------------------------
/SegThor/brat/label_correction.py:
--------------------------------------------------------------------------------
1 | from scipy.optimize import curve_fit
2 | import numpy as np
3 | 
4 | def curve_func(x, a, b, c):
5 |     return a * (1 - np.exp(-1/c * x**b))
6 | 
7 | 
8 | def fit(func, x, y):
9 |     popt, pcov = curve_fit(func, x, y, p0=(1, 1, 1), method='trf', sigma=np.geomspace(1, .1, len(y)), absolute_sigma=True, bounds=([0, 0, 0], [1, 1, np.inf]))
10 |     return tuple(popt)
11 | 
12 | 
13 | def derivation(x, a, b, c):
14 |     x = x + 1e-6  # numerical robustness
15 |     return a * b * 1/c * np.exp(-1/c * x**b) * (x**(b-1))
16 | 
17 | 
18 | def label_update_epoch(ydata_fit, n_epoch=16, threshold=0.9, eval_interval=100, num_iter_per_epoch=10581/10):
19 |     xdata_fit = np.linspace(0, len(ydata_fit)*eval_interval/num_iter_per_epoch, len(ydata_fit))
20 |     a, b, c = fit(curve_func, xdata_fit, ydata_fit)
21 |     epoch = np.arange(1, n_epoch)
22 |     y_hat = curve_func(epoch, a, b, c)
23 |     relative_change = abs(abs(derivation(epoch, a, b, c)) - abs(derivation(1, a, b, c))) / abs(derivation(1, a, b, c))
24 |     relative_change[relative_change > 1] = 0
25 |     update_epoch = np.sum(relative_change <= threshold) + 1
26 |     return update_epoch  # , a, b, c
27 | 
28 | def if_update(iou_value, current_epoch, n_epoch=16, threshold=0.90, eval_interval=1, num_iter_per_epoch=1):
29 |     # check iou_value
30 |     start_iter = 0
31 |     print("len(iou_value)=", len(iou_value))
32 |     for k in range(len(iou_value)-1):
33 |         if iou_value[k+1] - iou_value[k] < 0.1:
34 |             start_iter = max(start_iter, k + 1)
35 |         else:
36 |             break
37 |     shifted_epoch = start_iter*eval_interval/num_iter_per_epoch
38 |     # cut out the first few entries
39 |     iou_value = iou_value[start_iter:]
40 |     update_epoch = label_update_epoch(iou_value, n_epoch=n_epoch, threshold=threshold, eval_interval=eval_interval, num_iter_per_epoch=num_iter_per_epoch)
41 |     # Shift back
42 |     update_epoch = shifted_epoch + update_epoch
43 |     return current_epoch >= update_epoch  # , update_epoch
44 | 
45 | 
46 | def merge_labels_with_skip(original_labels, model_predictions, need_label_correction_dict, conf_threshold=0.8, logic_255=False, class_constraint=True, conf_threshold_bg=0.95):
47 | 
48 | 
49 |     new_label_dict = {}
50 |     update_list = []
51 |     for c in need_label_correction_dict:
52 |         if need_label_correction_dict[c]:
53 |             update_list.append(c)
54 | 
55 | 
56 |     for pid in model_predictions:
57 |         pred_prob = model_predictions[pid]
58 |         pred = np.argmax(pred_prob, axis=0)
59 |         label = original_labels[pid]
60 | 
61 |         # print(np.unique(label))
62 |         # print(update_list)
63 |         # if the label contains none of the classes that need to be updated, skip the update process below
64 |         if set(np.unique(label)).isdisjoint(set(update_list)):
65 |             new_label_dict[pid] = label
66 |             continue
67 | 
68 | 
69 | 
70 |         # if the prediction is confident
71 |         # confident = np.max(pred_prob, axis=0) > conf_threshold
72 | 
73 |         # if the prediction is confident
74 |         # the code supports different thresholds for foreground and background;
75 |         # during the experiments, we always set them to be the same for simplicity
76 |         confident = (np.max(pred_prob[1:], axis=0) > conf_threshold) | (pred_prob[0] > conf_threshold_bg)
77 | 
78 |         # before update: only classes that need correction will be replaced
79 |         belong_to_correction_class = label == 0
80 |         for c in need_label_correction_dict:
81 |             if need_label_correction_dict[c]:
82 |                 belong_to_correction_class |= (label == c)
83 | 
84 |         # after update: only pixels that will be flipped to the allowed classes will be updated
85 |         after_belong = pred == 0
86 |         for c in need_label_correction_dict:
87 |             if need_label_correction_dict[c]:
88 |                 after_belong |= (pred == c)
89 | 
90 |         # combine all three masks together
91 |         replace_flag = confident & belong_to_correction_class & after_belong
92 | 
93 | 
94 |         # the class constraint
95 |         if class_constraint:
96 |             unique_class = np.unique(label)
97 |             # print(unique_class)
98 |             # indx = torch.zeros((h, w), dtype=torch.long)
99 |             class_constraint_indx = (pred == 0)
100 |             for element in unique_class:
101 |                 class_constraint_indx = class_constraint_indx | (pred == element)
102 | 
103 | 
104 |             replace_flag = replace_flag & (class_constraint_indx != 0)
105 | 
106 | 
107 |         # replace with the new label
108 |         next_label = np.where(replace_flag, pred, label).astype("int32")
109 | 
110 |         # logic 255:
111 |         # - rule #1: if label[i,j] != 0 and pred[i,j] == 0, then next_label[i,j] = 255
112 |         # - rule #2: if label[i,j] == 255 and pred[i,j] != 0 and confident, then next_label[i,j] = pred[i,j]
113 |         #   rule #2 is already enforced above, so no additional code is needed
114 |         if logic_255:
115 |             rule_1_flag = (label != 0) & (pred == 0)
116 |             next_label = np.where(rule_1_flag, np.ones(next_label.shape)*255, next_label).astype("int32")
117 | 
118 |         new_label_dict[pid] = next_label
119 | 
120 |     return new_label_dict
121 | 
122 | 
--------------------------------------------------------------------------------
/SegThor/brat/loading.py:
--------------------------------------------------------------------------------
1 | import os, torch, cv2, pickle, copy, random
2 | import numpy as np
3 | from scipy import misc
4 | from PIL import Image
5 | from torch.utils.data import Dataset, Sampler
6 | import torchvision.transforms as transforms
7 | import torchvision.transforms.functional as F
8 | 
9 | 
10 | def final_noise_function(mat):
11 |     mode = np.random.choice(["under", "over"])
12 |     iterations = np.random.choice(np.arange(2, 5))
13 |     return under_over_seg(mat, iterations, mode)
14 | 
15 | 
16 | def under_over_seg(mat, iteration=1, mode="under"):
17 |     target_num = 1000
18 |     mat = np.copy(mat)
19 |     kernel = np.ones((3, 3), np.uint8)
20 |     for cls in [1, 3, 4, 2]:
21 |         binary_mat = mat == cls
22 |         foreground_num = np.sum(binary_mat)
23 |         if foreground_num != 0:
24 |             # resize the image to match the foreground pixel number
25 |             h, w = mat.shape
26 |             ratio = np.sqrt(target_num/foreground_num)
27 |             h_new = int(round(h * ratio))
28 |             w_new = int(round(w * ratio))
29 |             resized_img = cv2.resize(binary_mat.astype("uint8"), (w_new, h_new), interpolation=cv2.INTER_CUBIC) > 0
30 |             # erosion or dilation
31 |             if mode == "under":
32 |                 binary_mat_processed = cv2.erode(resized_img.astype("uint8"), kernel, iterations=iteration)
33 |             elif mode == "over":
34 |                 binary_mat_processed = cv2.dilate(resized_img.astype("uint8"), kernel, iterations=iteration)
35 |             # resize back to the original size
36 |             binary_mat_processed_resized = cv2.resize(binary_mat_processed, (w, h), interpolation=cv2.INTER_CUBIC) > 0
37 |             # fill in the gap
38 |             if mode == "under":
39 |                 mat = np.where(binary_mat_processed_resized != binary_mat, np.zeros(mat.shape), mat)
40 |             elif mode == "over":
41 |                 mat = np.where(binary_mat_processed_resized & (mat == 0), np.ones(mat.shape)*cls, mat)
42 |     return mat
43 | 
44 | def under_seg(mat):
45 |     mat = np.copy(mat)
46 |     kernel_small = np.ones((2, 2), np.uint8)
47 |     kernel_medium = np.ones((3, 3), np.uint8)
48 |     kernel_large = np.ones((5, 5), np.uint8)
49 |     for cls in [1, 2, 3, 4]:
50 |         binary_mat = mat == cls
51 |         if cls in [1, 3]:
52 |             kernel_used = kernel_small
53 |             iteration = 1
54 |         elif cls == 2:
55 |             kernel_used = kernel_large
56 |             iteration = 2
57 |         else:
58 |             kernel_used = kernel_medium
59 |             iteration = 2
60 |         binary_mat_eroded = cv2.erode(binary_mat.astype("uint8"), kernel_used, iterations=iteration)
61 |         mat = np.where(binary_mat_eroded != binary_mat, np.zeros(mat.shape), mat)
62 |     return mat
63 | 
64 | def over_seg(mat):
65 |     mat = np.copy(mat)
66 |     kernel_small = np.ones((2, 2), np.uint8)
67 |     kernel_medium = np.ones((3, 3), np.uint8)
68 |     kernel_large = np.ones((5, 5), np.uint8)
69 |     for cls in [1, 2, 3, 4]:
70 |         if cls in [1, 3]:
71 |             kernel_used = kernel_small
72 |         elif cls == 2:  # large kernel for class 2, mirroring under_seg
73 |             kernel_used = kernel_large
74 |         else:
75 |             kernel_used = kernel_medium
76 |         binary_mat = mat == cls
77 |         binary_mat_dilated = cv2.dilate(binary_mat.astype("uint8"), kernel_used, iterations=2)
78 |         mat = np.where(binary_mat_dilated, np.ones(mat.shape)*cls, mat)
79 |     return mat
80 | 
81 | def wrong_seg(mat):
82 |     mat_cp = np.copy(mat)
83 |     channel_0 = np.random.choice([1, 2])
84 |     channel_1 = np.random.choice([0, 2])
85 |     channel_2 = np.random.choice([0, 1])
86 |     mat_cp[0, :, :] = mat[channel_0, :, :]
87 |     mat_cp[1, :, :] = mat[channel_1, :, :]
88 |     mat_cp[2, :, :] = mat[channel_2, :, :]
89 |     return mat_cp
90 | 
91 | def noise_seg(mat, noise_level=0.05):
92 |     """
93 |     P(out=0 | in=0) = 1-noise_level
94 |     P(out=1234 | in=0) = noise_level/4
95 |     P(out=0 | in=1234) = noise_level
96 |     P(out=1234 | in=1234) = 1-noise_level
97 |     """
98 |     mat = np.copy(mat)
99 |     fate = np.random.uniform(low=0, high=1, size=mat.shape)
100 |     # deal with 0
101 |     is_zero_indicator = mat == 0
102 |     background_flip_to = np.random.choice([1, 2, 3, 4], size=mat.shape)
103 |     mat = np.where((fate <= noise_level) & is_zero_indicator, background_flip_to, mat)
104 |     # deal with 1,2,3,4
105 |     mat = np.where((fate <= noise_level) & (~is_zero_indicator), np.zeros(mat.shape), mat)
106 |     return mat
107 | 
108 | def mixed_seg(mat):
109 |     fate = np.random.uniform(0, 1)
110 |     if fate < 0.33:
111 |         return under_seg(mat)
112 |     elif fate < 0.67:
113 |         return over_seg(mat)
114 |     else:
115 |         return noise_seg(mat)
116 | 
117 | 
118 | NOISE_LABEL_DICT = {"under": under_seg, "over": over_seg, "wrong": wrong_seg, "noise": noise_seg,
119 |                     "mixed": mixed_seg, "final": final_noise_function}
120 | 
121 | class StackedRandomAffine(transforms.RandomAffine):
122 |     def __call__(self, imgs):
123 |         """
124 |         img (PIL Image): Image to be transformed.
125 |         Returns:
126 |             PIL Image: Affine transformed image.
127 |         """
128 |         ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, imgs[0].size)
129 |         return [F.affine(x, *ret, resample=self.resample, fillcolor=self.fillcolor) for x in imgs]
130 | 
131 | 
132 | def standarize(img):
133 |     return (img - img.mean()) / img.std()
134 | 
135 | class BaseDataset(Dataset):
136 |     def __init__(self, parameters, data_list, augmentation=False, noise_label=None, noise_level=None, cache_dir=None):
137 |         self.data_list = data_list
138 |         self.data_dir = parameters["data_dir"]
139 |         self.img_dir = os.path.join(self.data_dir, "img")
140 |         self.seg_dir = os.path.join(self.data_dir, "label")
141 | 
142 |         # reset seeds
143 |         random.seed(parameters["seed"])
144 |         torch.manual_seed(parameters["seed"])
145 |         torch.cuda.manual_seed(parameters["seed"])
146 |         np.random.seed(parameters["seed"])
147 | 
148 |         # load cached images and labels if necessary
149 |         if cache_dir is None:
150 |             self.cache_label = None
151 |             self.cache_img = None
152 |         else:
153 |             with open(cache_dir, "rb") as f:
154 |                 self.cache_img, self.cache_label = pickle.load(f)
155 |             self.cache_clean_label = copy.deepcopy(self.cache_label)
156 | 
157 |         # noise label functions
158 |         self.noise_function = None if noise_label is None else NOISE_LABEL_DICT[noise_label]
159 |         if self.noise_function is not None and noise_level is not None:
160 |             noise_number = int(round(noise_level * len(self.data_list)))
161 |             self.noise_index_list = np.random.permutation(np.arange(len(self.data_list)))[:noise_number]
162 |             # add noise to the cached labels
163 |             for i in range(len(self.data_list)):
164 |                 if i in self.noise_index_list:
165 |                     img_name = self.data_list[i]
166 |                     self.cache_label[img_name] = self.noise_function(self.cache_label[img_name])
167 |             self.cache_noisy_label = copy.deepcopy(self.cache_label)
168 |         else:
169 |             self.cache_noisy_label = self.cache_clean_label
170 | 
171 |         # augmentation setting
172 |         self.augmentation = augmentation
173 |         self.augmentation_function = StackedRandomAffine(degrees=(-45, 45), translate=(0.1, 0.1), scale=(0.8, 1.5))
174 | 
175 |         # transformation setting
176 |         transform_list = []
177 |         if parameters["resize"] is not None:
178 |             transform_list.append(transforms.Resize(size=(parameters["resize"], parameters["resize"]),
179 |                                                     interpolation=0))
180 |         transform_list.append(transforms.ToTensor())
181 |         self.transform = transforms.Compose(transform_list)
182 | 
183 |     def __len__(self):
184 |         return len(self.data_list)
185 | 
186 | 
187 | 
188 | class BraTSDataset(BaseDataset):
189 |     def __init__(self, parameters, data_list, augmentation=False, noise_label=None):
190 |         super(BraTSDataset, self).__init__(parameters, data_list, augmentation, noise_label)
191 | 
192 |     def __getitem__(self, index):
193 |         img_name = self.data_list[index]
194 | 
195 |         # put up paths
196 |         img_path = os.path.join(self.img_dir, img_name)
197 |         seg_path = os.path.join(self.seg_dir, img_name)
198 | 
199 |         # load images and seg
200 |         img = np.load(img_path).astype("int16")
201 |         seg = np.load(seg_path).astype("int8")
202 |         if self.noise_function is not None:
203 |             seg = self.noise_function(seg)
204 | 
205 |         # convert to pil image
206 |         img_channel_pils = [Image.fromarray(img[i, :, :].astype("int16")) for i in range(img.shape[0])]
207 |         seg_channel_pils = [Image.fromarray(seg[i, :, :].astype("int8")) for i in range(seg.shape[0])]
208 | 
209 |         # augmentation
210 |         if self.augmentation:
211 |             aug_res = self.augmentation_function(img_channel_pils + seg_channel_pils)
212 |             img_channel_pils = aug_res[:4]
213 |             seg_channel_pils = aug_res[4:]
214 | 
215 |         # post-process
216 |         img_channel_torch = [standarize(self.to_tensor(x).float()) for x in img_channel_pils]
217 |         label_channel_torch = [self.to_tensor(x) for x in seg_channel_pils]
218 |         img_torch = torch.cat(img_channel_torch, dim=0)
219 |         label_torch = torch.cat(label_channel_torch, dim=0)
220 |         label_torch[label_torch > 0] = 1
221 | 
222 |         return img_torch.float(), label_torch.long(), img_name
223 | 
224 | class SegTHORDataset(BaseDataset):
225 |     def __init__(self, parameters, data_list, augmentation=False, noise_label=None, noise_level=None, cache_dir=None):
226 |         super(SegTHORDataset, self).__init__(parameters, data_list, augmentation, noise_label, noise_level, cache_dir)
227 | 
228 |     def reset_labels(self, new_labels):
229 |         self.cache_label = new_labels
230 | 
231 |     def __getitem__(self, index):
232 |         img_name = self.data_list[index]
233 |         # load image and the segmentation label
234 |         if self.cache_img is None:
235 |             img_path = os.path.join(self.img_dir, img_name)
236 |             img = np.load(img_path).astype("int16")
237 |             img -= img.min()
238 |         else:
239 |             img = self.cache_img[img_name]
240 |         if self.cache_label is None:
241 |             seg_path = os.path.join(self.seg_dir, img_name)
242 |             seg = np.load(seg_path).astype("int8")
243 |             # add noise to the label if needed
244 |             if self.noise_function is not None and index in self.noise_index_list:
245 |                 seg = self.noise_function(seg)
246 |         else:
247 |             seg = self.cache_label[img_name]
248 |         clean_seg = self.cache_clean_label[img_name]
249 |         original_noisy_seg = self.cache_noisy_label[img_name]
250 | 
251 |         # convert to pil image
252 |         img_pils = Image.fromarray(img)
253 |         seg_pils = Image.fromarray(seg)
254 |         clean_seg_pils = Image.fromarray(clean_seg)
255 |         original_noisy_seg_pils = Image.fromarray(original_noisy_seg)
256 | 
257 |         # augmentation
258 |         if self.augmentation:
259 |             img_pils, seg_pils, clean_seg_pils, original_noisy_seg_pils = self.augmentation_function([img_pils, seg_pils, clean_seg_pils, original_noisy_seg_pils])
260 | 
261 |         # post-process
262 |         img_torch = standarize(self.transform(img_pils).float())
263 |         label_torch = self.transform(seg_pils)
264 |         clean_label_torch = self.transform(clean_seg_pils)
265 |         original_noisy_torch = self.transform(original_noisy_seg_pils)
266 |         return img_torch.float(), label_torch.long(), original_noisy_torch.long(), clean_label_torch.long(), img_name
--------------------------------------------------------------------------------
/SegThor/brat/unet_model.py:
--------------------------------------------------------------------------------
1 | """ Full assembly of the parts to form the complete network """
2 | 
3 | import torch.nn.functional as F
4 | 
5 | from .unet_parts import *
6 | 
7 | 
8 | class UNet(nn.Module):
9 |     def __init__(self, n_channels, n_classes, bilinear=True):
10 |         super(UNet, self).__init__()
11 |         self.n_channels = n_channels
12 |         self.n_classes = n_classes
13 |         self.bilinear = bilinear
14 | 
15 |         self.inc = DoubleConv(n_channels, 64)
16 |         self.down1 = Down(64, 128)
17 |         self.down2 = Down(128, 256)
18 |         self.down3 = Down(256, 512)
19 |         factor = 2 if bilinear else 1
20 |         self.down4 = Down(512, 1024 // factor)
21 |         self.up1 = Up(1024, 512 // factor, bilinear)
22 |         self.up2 = Up(512, 256 // factor, bilinear)
23 |         self.up3 = Up(256, 128 // factor, bilinear)
24 |         self.up4 = Up(128, 64, bilinear)
25 |         self.outc = OutConv(64, n_classes)
26 | 
27 |     def forward(self, x):
28 |         x1 = self.inc(x)
29 |         x2 = self.down1(x1)
30 |         x3 = self.down2(x2)
31 |         x4 = self.down3(x3)
32 |         x5 = self.down4(x4)
33 |         x = self.up1(x5, x4)
34 |         x = self.up2(x, x3)
35 |         x = self.up3(x, x2)
36 |         x = self.up4(x, x1)
37 |         logits = self.outc(x)
38 |         return logits
--------------------------------------------------------------------------------
/SegThor/brat/unet_parts.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model """
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class DoubleConv(nn.Module):
9 |     """(convolution => [BN] => ReLU) * 2"""
10 | 
11 |     def __init__(self, in_channels, out_channels, mid_channels=None):
12 |         super().__init__()
13 |         if not mid_channels:
14 |             mid_channels = out_channels
15 |         self.double_conv = nn.Sequential(
16 |             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
17 |             nn.BatchNorm2d(mid_channels),
18 |             nn.ReLU(inplace=True),
19 |             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
20 |             nn.BatchNorm2d(out_channels),
21 |             nn.ReLU(inplace=True)
22 |         )
23 | 
24 |     def forward(self, x):
25 |         return self.double_conv(x)
26 | 
27 | 
28 | class Down(nn.Module):
29 |     """Downscaling with maxpool then double conv"""
30 | 
31 |     def __init__(self, in_channels, out_channels):
32 |         super().__init__()
33 |         self.maxpool_conv = nn.Sequential(
34 |             nn.MaxPool2d(2),
35 |             DoubleConv(in_channels, out_channels)
36 |         )
37 | 
38 |     def forward(self, x):
39 |         return self.maxpool_conv(x)
40 | 
41 | 
42 | class Up(nn.Module):
43 |     """Upscaling then double conv"""
44 | 
45 |     def __init__(self, in_channels, out_channels, bilinear=True):
46 |         super().__init__()
47 | 
48 |         # if bilinear, use the normal convolutions to reduce the number of channels
49 |         if bilinear:
50 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52 |         else:
53 |             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54 |             self.conv = DoubleConv(in_channels, out_channels)
55 | 
56 | 
57 |     def forward(self, x1, x2):
58 |         x1 = self.up(x1)
59 |         # input is CHW
60 |         diffY = x2.size()[2] - x1.size()[2]
61 |         diffX = x2.size()[3] - x1.size()[3]
62 | 
63 |         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 |                         diffY // 2, diffY - diffY // 2])
65 |         # if you have padding issues, see
66 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 |         x = torch.cat([x2, x1], dim=1)
69 |         return self.conv(x)
70 | 
71 | 
72 | class OutConv(nn.Module):
73 |     def __init__(self, in_channels, out_channels):
74 |         super(OutConv, self).__init__()
75 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 | 
77 |     def forward(self, x):
78 |         return self.conv(x)
--------------------------------------------------------------------------------
/SegThor/lib/utils/JSD_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | 
6 | 
7 | def calc_jsd_multiscale(weight, labels1_a, pred1, pred2, pred3, threshold=0.8, Mask_label255_sign='no'):
8 | 
9 |     Mask_label255 = (labels1_a < 255).float()  # do not compute the area that is irrelevant (data aug) b,h,w
10 |     weight_softmax = F.softmax(weight, dim=0)
11 | 
12 |     criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
13 |     criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
14 |     criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none')
15 | 
16 |     loss1 = criterion1(pred1 * weight_softmax[0], labels1_a)  # * weight_softmax[0]
17 |     loss2 = criterion2(pred2 * weight_softmax[1], labels1_a)  # * weight_softmax[1]
18 |     loss3 = criterion3(pred3 * weight_softmax[2], labels1_a)  # * weight_softmax[2]
19 | 
20 |     loss = (loss1 + loss2 + loss3)
21 | 
22 |     probs = [F.softmax(logits, dim=1) for i, logits in enumerate([pred1, pred2, pred3])]
23 | 
24 |     weighted_probs = [weight_softmax[i] * prob for i, prob in enumerate(probs)]  # weight_softmax[i]*
25 |     mixture_label = (torch.stack(weighted_probs)).sum(axis=0)
26 |     # mixture_label = torch.clamp(mixture_label, 1e-7, 1)  # b,c,h,w
27 |     mixture_label = torch.clamp(mixture_label, 1e-3, 1-1e-3)  # b,c,h,w
28 | 
29 |     # add this code block for early torch versions where torch.amax is not available
30 |     if torch.__version__ == "1.5.0" or torch.__version__ == "1.6.0":
31 |         max_probs, _ = torch.max(mixture_label*Mask_label255.unsqueeze(1), dim=-3, keepdim=True)
32 |         max_probs, _ = torch.max(max_probs, dim=-2, keepdim=True)
33 |         max_probs, _ = torch.max(max_probs, dim=-1, keepdim=True)
34 |     else:
35 |         max_probs = torch.amax(mixture_label*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True)
36 |     mask = max_probs.ge(threshold).float()
37 | 
38 | 
39 |     logp_mixture = mixture_label.log()
40 | 
41 |     log_probs = [torch.sum(F.kl_div(logp_mixture, prob, reduction='none') * mask, dim=1) for prob in probs]
42 |     if Mask_label255_sign == 'yes':
43 |         consistency = sum(log_probs)*Mask_label255
44 |     else:
45 |         consistency = sum(log_probs)
46 | 
47 |     return torch.mean(loss), torch.mean(consistency), consistency, mixture_label
48 | 
--------------------------------------------------------------------------------
/SegThor/lib/utils/iou_computation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | 
4 | 
5 | def update_iou_stat(predict, gt, TP, P, T, num_classes=21):
6 |     """
7 |     :param predict: the pred of each batch, should be a numpy array after taking the argmax, b,h,w
8 |     :param gt: the gt label of the batch, should be a numpy array, b,h,w
9 |     :param TP: true positives
10 |     :param P: positive predictions
11 |     :param T: true segmentation
12 |     :param num_classes: number of classes in the dataset
13 |     :return: TP, P, T
14 |     """
15 |     cal = gt < 255
16 | 
17 |     mask = (predict == gt) * cal
18 | 
19 |     for i in range(num_classes):
20 |         P[i] += np.sum((predict == i) * cal)
21 |         T[i] += np.sum((gt == i) * cal)
22 |         TP[i] += np.sum((gt == i) * mask)
23 | 
24 |     return TP, P, T
25 | 
26 | 
27 | def iter_iou_stat(predict, gt, num_classes=21):
28 |     """
29 |     :param predict: the pred of each batch, should be a numpy array after taking the argmax, b,h,w
30 |     :param gt: the gt label of the batch, should be a numpy array, b,h,w
31 |     :param TP: true positives
32 |     :param P: positive predictions
33 |     :param T: true segmentation
34 |     :param num_classes: number of classes in the dataset
35 |     :return: np.array([TP, P, T])
36 |     """
37 |     cal = gt < 255
38 | 
39 |     mask = (predict == gt) * cal
40 | 
41 |     TP = np.zeros(num_classes)
42 |     P = np.zeros(num_classes)
43 |     T = np.zeros(num_classes)
44 | 
45 |     for i in range(num_classes):
46 |         P[i] = np.sum((predict == i) * cal)
47 |         T[i] = np.sum((gt == i) * cal)
48 |         TP[i] = np.sum((gt == i) * mask)
49 | 
50 |     return np.array([TP, P, T])
51 | 
52 | 
53 | def compute_iou(TP, P, T, num_classes=21):
54 |     """
55 |     :param TP:
56 |     :param P:
57 |     :param T:
58 |     :param num_classes: number of classes in the dataset
59 |     :return: IoU
60 |     """
61 |     IoU = []
62 |     for i in range(num_classes):
63 |         IoU.append(TP[i] / (T[i] + P[i] - TP[i] + 1e-10))
64 |     return IoU
65 | 
66 | 
67 | def update_fraction_batchwise(mask, gt, fraction, num_classes=21):
68 |     """
69 |     :param mask: True when belonging to the subgroup (memorized, correct, others) on which we want to calculate the fraction
70 |     :param gt: the gt label of the batch, numpy array
71 |     :param fraction: fraction of pixels in the subgroup
72 |     :param num_classes: number of classes in the dataset
73 |     :return: updated fraction
74 |     """
75 |     cal = gt < 255
76 | 
77 |     for i in range(num_classes):
78 |         fraction[i] += np.sum((mask * (gt == i) * cal))/np.sum((gt == i) * cal)
79 | 
80 |     return fraction
81 | 
82 | 
83 | def update_fraction_instancewise(mask, gt, fraction, num_classes=21):
84 |     """
85 |     :param mask: True when belonging to the subgroup (memorized, correct, others) on which we want to calculate the fraction
86 |     :param gt: the gt label of the batch, numpy array
87 |     :param fraction: fraction of pixels in the subgroup
88 |     :param num_classes: number of classes in the dataset
89 |     :return: updated fraction
90 |     """
91 |     # np.sum((gt == i) * cal, axis=(-2,-1)) may be zero, so the per-instance fraction can be nan
92 |     cal = gt < 255
93 | 
94 |     for i in range(num_classes):
95 |         fraction[i] += np.mean(np.sum((mask * (gt == i) * cal), axis=(-2, -1))/np.sum((gt == i) * cal, axis=(-2, -1)))
96 | 
97 |     return fraction
98 | 
99 | def update_fraction_pixelwise(mask, gt, abs_num_and_total, num_classes=21):
100 |     """
101 |     :param mask: True when belonging to the subgroup (memorized, correct, others) on which we want to calculate the fraction
102 |     :param gt: the gt label of the batch, numpy array
103 |     :param abs_num_and_total: the absolute number of pixels belonging to the mask and the total number of pixels [abs_num, pixel_num]
104 |     :param num_classes: number of classes in the dataset
105 |     :return: updated fraction
106 |     """
107 |     cal = gt < 255
108 | 
109 |     for i in range(num_classes):
110 |         abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
111 |         abs_num_and_total[i][1] += np.sum((gt == i) * cal)
112 | 
113 | 
114 |     return abs_num_and_total
115 | 
116 | def iter_fraction_pixelwise(mask, gt, num_classes=21):
117 |     """
118 |     :param mask: True when belonging to the subgroup (memorized, correct, others) on which we want to calculate the fraction
119 |     :param gt: the gt label of the batch, numpy array
120 |     :param num_classes: number of classes in the dataset
121 |     :return: updated fraction
122 |     """
123 |     cal = gt < 255
124 | 
125 |     abs_num_and_total = np.zeros((num_classes, 2))
126 | 
127 |     for i in range(num_classes):
128 |         abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal)
129 |         abs_num_and_total[i][1] += np.sum((gt == i) * cal)
130 | 
131 | 
132 |     return abs_num_and_total
133 | 
134 | 
135 | 
136 | def get_mask(gt_np, label_np, pred_np):
137 |     """
138 | 
139 |     Args:
140 |         gt_np: the GT label
141 |         label_np: the CAM pseudo label
142 |         pred_np: the prediction
143 | 
144 |     Returns: the masks of different types
145 | 
146 |     """
147 |     wrong_mask_correct = (gt_np != label_np) & (pred_np == gt_np)
148 |     wrong_mask_memorized = (gt_np != label_np) & (pred_np == label_np)
149 |     wrong_mask_others = (gt_np != label_np) & (pred_np != gt_np) & (pred_np != label_np)
150 |     clean_mask_correct = (gt_np == label_np) & (pred_np == gt_np)
151 |     clean_mask_incorrect = (gt_np == label_np) & (pred_np != gt_np)
152 | 
153 |     return (wrong_mask_correct, wrong_mask_memorized, wrong_mask_others, clean_mask_correct, clean_mask_incorrect)
--------------------------------------------------------------------------------
/SegThor/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.7.0
2 | torchvision>=0.8.1
3 | mxnet>=1.7.0.post1
4 | scipy>=1.5.1
5 | numpy>=1.19.4
6 | scikit_image>=0.17.2
7 | pydensecrf>=1.0rc3
8 | pandas>=1.0.5
9 | opencv_python>=4.3.0.36
10 | matplotlib>=3.3.0
11 | Pillow>=8.1.0
12 | tensorboardX>=2.1
13 | tqdm
14 | scikit-image
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kangningthu/ADELE/7195bd0af39be79c533d67dd7eab7f9bfd6a4285/__init__.py
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | import torch
5 | import argparse
6 | import os
7 | import sys
8 | import cv2
9 | import time
10 | 
11 | config_dict = {
12 |     'EXP_NAME': 'Experiment',
13 |     'GPUS': 2,
14 |     'TEST_GPUS': 1,
15 |     'DATA_NAME': 'VOCTrainwsegDataset',
16 |     'DATA_YEAR': 2012,
17 |     'DATA_AUG': True,
18 |     'DATA_WORKERS': 8,
19 |     'DATA_MEAN': [0.485, 0.456, 0.406],
20 |     'DATA_STD': [0.229, 0.224, 0.225],
21 |     'DATA_RANDOMCROP': 448,
22 |     'DATA_RANDOMSCALE': [0.5, 1.5],
23 |     'DATA_RANDOM_H': 10,
24 |     'DATA_RANDOM_S': 10,
25 |     'DATA_RANDOM_V': 10,
26 |     'DATA_RANDOMFLIP': 0.5,
27 |     'DATA_PSEUDO_GT': '/scratch/kl3141/seam/SEAM-master/results/aff_rw_aug',
28 |     'DATA_FEATURE_DIR': False,
29 | 
30 |     'MODEL_NAME': 'deeplabv1',
31 |     'MODEL_BACKBONE': 'resnet38',
32 |     'MODEL_BACKBONE_PRETRAIN': True,
33 |     'MODEL_NUM_CLASSES': 21,
34 |     'MODEL_FREEZEBN': False,
35 | 
36 |     'TRAIN_LR': 0.001,
37 |     'TRAIN_MOMENTUM': 0.9,
38 |     'TRAIN_WEIGHT_DECAY': 0.0005,
39 |     'TRAIN_BN_MOM': 0.0003,
40 |     'TRAIN_POWER': 0.9,
41 |     'TRAIN_BATCHES': 10,
42 |     'TRAIN_SHUFFLE': True,
43 |     'TRAIN_MINEPOCH': 0,
44 |     'TRAIN_ITERATION': 20000,
45 |     'TRAIN_TBLOG': True,
46 | 
47 |     'TEST_MULTISCALE': [0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
48 |     'TEST_FLIP': True,
49 |     'TEST_CRF': True,
50 |     'TEST_BATCHES': 1,
51 | 
52 |     'MODEL_BACKBONE_DILATED': False,
53 |     'MODEL_BACKBONE_MULTIGRID': False,
54 |     'MODEL_BACKBONE_DEEPBASE': False,
55 | 
56 |     'scale_factor': 0.7,
57 |     'scale_factor2': 1.5,
58 |     'lambda_seg': 0.5,
59 | 
60 | }
61 | 
62 | config_dict['ROOT_DIR'] = os.path.abspath(os.path.dirname("__file__"))
63 | config_dict['MODEL_SAVE_DIR'] = os.path.join(config_dict['ROOT_DIR'], 'model', config_dict['EXP_NAME'])
64 | config_dict['TRAIN_CKPT'] = None
65 | config_dict['LOG_DIR'] = os.path.join(config_dict['ROOT_DIR'], 'log', config_dict['EXP_NAME'])
66 | sys.path.insert(0, os.path.join(config_dict['ROOT_DIR'], 'lib'))
--------------------------------------------------------------------------------
/lib/datasets/BaseDataset.py:
--------------------------------------------------------------------------------
1 | # ----------------------------------------
2 | # Written by Yude Wang
3 | # ----------------------------------------
4 | 
5 | from __future__ import print_function, division
6 | import os
7 | import torch
8 | import pandas as pd
9 | import cv2
10 | import multiprocessing
11 | from skimage import io
12 | from PIL import Image
13 | import numpy as np
14 | from torch.utils.data import Dataset
15 | from datasets.transform import *
16 | from utils.imutils import *
17 | from utils.registry import DATASETS
18 | 
19 | #@DATASETS.register_module
20 | class BaseDataset(Dataset):
21 |     def __init__(self, cfg, period, transform='none'):
22 |         super(BaseDataset, self).__init__()
23 |         self.cfg = cfg
24 |         self.period = period
25 |         self.transform = transform
26 |         if 'train' not in self.period:
27 |             assert self.transform == 'none'
28 |         self.num_categories = None
29 |         self.totensor = ToTensor()
30 |         self.imagenorm = ImageNorm(cfg.DATA_MEAN, cfg.DATA_STD)
31 | 
32 |         if self.transform != 'none':
33 |             if cfg.DATA_RANDOMCROP > 0:
34 |                 self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP)
35 |             if cfg.DATA_RANDOMSCALE != 1:
36 |                 self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE)
37 |             if cfg.DATA_RANDOMFLIP > 0:
38 |                 self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP)
39 |             if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0:
40 |                 self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V)
41 |         else:
42 |             self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE)
43 | 
44 | 
45 |     def __getitem__(self, idx):
46 |         sample = self.__sample_generate__(idx)
47 | 
48 |         if 'segmentation' in sample.keys():
49 |             sample['mask'] = sample['segmentation'] < self.num_categories
50 |             t = sample['segmentation'].copy()
51 |             t[t >= self.num_categories] = 0
52 |             sample['segmentation_onehot'] = onehot(t, self.num_categories)
53 |         return self.totensor(sample)
54 | 
55 |     def __sample_generate__(self, idx, split_idx=0):
56 |         name = self.load_name(idx)
57 |         image = self.load_image(idx)
58 |         r, c, _ = image.shape
59 |         sample = {'image': image, 'name': name, 'row': r, 'col': c}
60 | 
61 |         if 'test' in self.period:
62 |             return self.__transform__(sample)
63 |         elif self.cfg.DATA_PSEUDO_GT and idx >= split_idx and 'train' in self.period:
64 |             segmentation = self.load_pseudo_segmentation(idx)
65 |         else:
66 |             segmentation = self.load_segmentation(idx)
67 |         sample['segmentation'] = segmentation
68 |         t = sample['segmentation'].copy()
69 |         t[t >= self.num_categories] = 0
70 |         sample['category'] = seg2cls(t, self.num_categories)
71 |         sample['category_copypaste'] = np.zeros(sample['category'].shape)
72 | 
73 |         if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR:
74 |             feature = self.load_feature(idx)
75 |             sample['feature'] = feature
76 |         return self.__transform__(sample)
77 | 
78 |     def __transform__(self, sample):
79 |         if self.transform == 'weak':
80 |             sample = self.__weak_augment__(sample)
81 |         elif self.transform == 'strong':
82 |             sample = self.__strong_augment__(sample)
83 |         else:
84 |             sample = self.imagenorm(sample)
85 |             sample = self.multiscale(sample)
86 |         return sample
87 | 
88 |     def __weak_augment__(self, sample):
89 |         if self.cfg.DATA_RANDOM_H > 0 or self.cfg.DATA_RANDOM_S > 0 or self.cfg.DATA_RANDOM_V > 0:
90 |             sample = self.randomhsv(sample)
91 |         if self.cfg.DATA_RANDOMFLIP > 0:
92 |             sample = self.randomflip(sample)
93 |         if self.cfg.DATA_RANDOMSCALE != 1:
94 |             sample = self.randomscale(sample)
95 |         sample = self.imagenorm(sample)
96 |         if self.cfg.DATA_RANDOMCROP > 0:
97 |             sample = self.randomcrop(sample)
98 |         return sample
99 | 
100 |     def __strong_augment__(self, sample):
101 |         raise NotImplementedError
102 | 
103 |     def __len__(self):
104 |         raise NotImplementedError
105 | 
106 |     def load_name(self, idx):
107 |         raise NotImplementedError
108 | 
109 |     def load_image(self, idx):
110 |         raise NotImplementedError
111 | 
112 |     def load_segmentation(self, idx):
raise NotImplementedError 114 | 115 | def load_pseudo_segmentation(self, idx): 116 | raise NotImplementedError 117 | 118 | def load_feature(self, idx): 119 | raise NotImplementedError 120 | 121 | def save_result(self, result_list, model_id): 122 | raise NotImplementedError 123 | 124 | def save_pseudo_gt(self, result_list, level=None): 125 | raise NotImplementedError 126 | 127 | def do_python_eval(self, model_id): 128 | raise NotImplementedError 129 | -------------------------------------------------------------------------------- /lib/datasets/BaseMultiwGTauginfoDataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import torch 4 | import pandas as pd 5 | import cv2 6 | import multiprocessing 7 | from skimage import io 8 | from PIL import Image 9 | import numpy as np 10 | from torch.utils.data import Dataset 11 | from datasets.transformmultiGTauginfo import * 12 | from utils.imutils import * 13 | from utils.registry import DATASETS 14 | 15 | #@DATASETS.register_module 16 | class BaseMultiwGTauginfoDataset(Dataset): 17 | def __init__(self, cfg, period, transform='none'): 18 | super(BaseMultiwGTauginfoDataset, self).__init__() 19 | self.cfg = cfg 20 | self.period = period 21 | self.transform = transform 22 | if 'train' not in self.period: 23 | assert self.transform == 'none' 24 | self.num_categories = None 25 | self.totensor = ToTensor() 26 | self.imagenorm = ImageNorm(cfg.DATA_MEAN, cfg.DATA_STD) 27 | 28 | if self.transform != 'none': 29 | if cfg.DATA_RANDOMCROP > 0: 30 | if self.transform == 'no': 31 | # self.randomcrop = RandomCrop(512) 32 | self.randomcrop = CenterCrop(512) 33 | else: 34 | self.randomcrop = RandomCrop(cfg.DATA_RANDOMCROP) 35 | if cfg.DATA_RANDOMSCALE != 1: 36 | self.randomscale = RandomScale(cfg.DATA_RANDOMSCALE) 37 | if cfg.DATA_RANDOMFLIP > 0: 38 | self.randomflip = RandomFlip(cfg.DATA_RANDOMFLIP) 39 | if cfg.DATA_RANDOM_H > 0 or cfg.DATA_RANDOM_S > 0 or cfg.DATA_RANDOM_V > 0: 40 | self.randomhsv = RandomHSV(cfg.DATA_RANDOM_H, cfg.DATA_RANDOM_S, cfg.DATA_RANDOM_V) 41 | else: 42 | self.multiscale = Multiscale(self.cfg.TEST_MULTISCALE) 43 | 44 | 45 | def __getitem__(self, idx): 46 | sample = self.__sample_generate__(idx) 47 | 48 | if 'segmentation' in sample.keys(): 49 | sample['mask'] = sample['segmentation'] < self.num_categories 50 | t = sample['segmentation'].copy() 51 | t[t >= self.num_categories] = 0 52 | sample['segmentation_onehot']=onehot(t,self.num_categories) 53 | return self.totensor(sample) 54 | 55 | def __sample_generate__(self, idx, split_idx=0): 56 | name = self.load_name(idx) 57 | image = self.load_image(idx) 58 | r,c,_ = image.shape 59 | sample = {'image': image, 'name': name, 'row': r, 'col': c} 60 | 61 | if 'test' in self.period: 62 | return self.__transform__(sample) 63 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period: 64 | segmentation = self.load_pseudo_segmentation(idx) 65 | else: 66 | segmentation = self.load_segmentation(idx) 67 | sample['segmentation'] = segmentation 68 | t = sample['segmentation'].copy() 69 | t[t >= self.num_categories] = 0 70 | sample['category'] = seg2cls(t,self.num_categories) 71 | sample['category_copypaste'] = np.zeros(sample['category'].shape) 72 | 73 | if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR: 74 | feature = self.load_feature(idx) 75 | sample['feature'] = feature 76 | return self.__transform__(sample) 77 | 78 | def __transform__(self, sample): 79 | if self.transform == 
'weak': 80 | sample = self.__weak_augment__(sample) 81 | elif self.transform == 'strong': 82 | sample = self.__strong_augment__(sample) 83 | elif self.transform == 'no': 84 | sample = self.__dict_augment__(sample) 85 | else: 86 | sample = self.imagenorm(sample) 87 | sample = self.multiscale(sample) 88 | return sample 89 | 90 | def __weak_augment__(self, sample): 91 | if self.cfg.DATA_RANDOM_H>0 or self.cfg.DATA_RANDOM_S>0 or self.cfg.DATA_RANDOM_V>0: 92 | sample = self.randomhsv(sample) 93 | if self.cfg.DATA_RANDOMFLIP > 0: 94 | sample = self.randomflip(sample) 95 | if self.cfg.DATA_RANDOMSCALE != 1: 96 | sample = self.randomscale(sample) 97 | sample = self.imagenorm(sample) 98 | if self.cfg.DATA_RANDOMCROP > 0: 99 | sample = self.randomcrop(sample) 100 | return sample 101 | 102 | def __dict_augment__(self, sample): 103 | sample = self.imagenorm(sample) 104 | if self.cfg.DATA_RANDOMCROP > 0: 105 | sample = self.randomcrop(sample) 106 | return sample 107 | def __strong_augment__(self, sample): 108 | raise NotImplementedError 109 | 110 | def __len__(self): 111 | raise NotImplementedError 112 | 113 | def load_name(self, idx): 114 | raise NotImplementedError 115 | 116 | def load_image(self, idx): 117 | raise NotImplementedError 118 | 119 | def load_segmentation(self, idx): 120 | raise NotImplementedError 121 | 122 | def load_pseudo_segmentation(self, idx): 123 | raise NotImplementedError 124 | 125 | def load_feature(self, idx): 126 | raise NotImplementedError 127 | 128 | def save_result(self, result_list, model_id): 129 | raise NotImplementedError 130 | 131 | def save_pseudo_gt(self, result_list, level=None): 132 | raise NotImplementedError 133 | 134 | def do_python_eval(self, model_id): 135 | raise NotImplementedError 136 | -------------------------------------------------------------------------------- /lib/datasets/VOCEvalDataset.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # The dataset for Eval that is used for label correction 3 | # ---------------------------------------- 4 | 5 | from __future__ import print_function, division 6 | import os 7 | import torch 8 | import pandas as pd 9 | import cv2 10 | import multiprocessing 11 | from skimage import io 12 | from PIL import Image 13 | import numpy as np 14 | from torch.utils.data import Dataset 15 | from datasets.transformmultiGT import * 16 | from utils.imutils import * 17 | from utils.registry import DATASETS 18 | from datasets.BaseMultiwGTauginfoDataset import BaseMultiwGTauginfoDataset 19 | import torch.nn.functional as F 20 | from utils.iou_computation import update_iou_stat, compute_iou 21 | 22 | 23 | @DATASETS.register_module 24 | class VOCEvalDataset(BaseMultiwGTauginfoDataset): 25 | def __init__(self, cfg, period, transform='none'): 26 | super(VOCEvalDataset, self).__init__(cfg, period, transform) 27 | self.dataset_name = 'VOC%d'%cfg.DATA_YEAR 28 | self.root_dir = os.path.join(cfg.ROOT_DIR,'data','VOCdevkit') 29 | self.dataset_dir = os.path.join(self.root_dir,self.dataset_name) 30 | self.rst_dir = os.path.join(self.root_dir,'results',self.dataset_name,'Segmentation') 31 | self.eval_dir = os.path.join(self.root_dir,'eval_result',self.dataset_name,'Segmentation') 32 | self.img_dir = os.path.join(self.dataset_dir, 'JPEGImages') 33 | # print(self.img_dir) 34 | self.ann_dir = os.path.join(self.dataset_dir, 'Annotations') 35 | self.seg_dir = os.path.join(self.dataset_dir, 'SegmentationClass') 36 | self.seg_dir_gt = os.path.join(self.dataset_dir, 
'SegmentationClassAug') 37 | self.set_dir = os.path.join(self.dataset_dir, 'ImageSets', 'Segmentation') 38 | if cfg.DATA_PSEUDO_GT: 39 | self.pseudo_gt_dir = cfg.DATA_PSEUDO_GT 40 | # self.pseudo_gt_dir_2 = cfg.DATA_PSEUDO_GT_2 41 | # self.pseudo_gt_dir_3 = cfg.DATA_PSEUDO_GT_3 42 | else: 43 | self.pseudo_gt_dir = os.path.join(self.root_dir,'pseudo_gt',self.dataset_name,'Segmentation') 44 | 45 | file_name = None 46 | if cfg.DATA_AUG and 'train' in self.period: 47 | file_name = self.set_dir+'/'+period+'aug.txt' 48 | else: 49 | file_name = self.set_dir+'/'+period+'.txt' 50 | df = pd.read_csv(file_name, names=['filename']) 51 | self.name_list = df['filename'].values 52 | # print(self.name_list[1]) 53 | if self.dataset_name == 'VOC2012': 54 | self.categories = ['aeroplane','bicycle','bird','boat','bottle','bus','car','cat','chair','cow', 55 | 'diningtable','dog','horse','motorbike','person','pottedplant','sheep','sofa','train','tvmonitor'] 56 | self.coco2voc = [[0],[5],[2],[16],[9],[44],[6],[3],[17],[62], 57 | [21],[67],[18],[19],[4],[1],[64],[20],[63],[7],[72]] 58 | 59 | self.num_categories = len(self.categories)+1 60 | self.cmap = self.__colormap(len(self.categories)+1) 61 | 62 | # to record the previous prediction 63 | self.prev_pred_dict = {} 64 | 65 | self.ori_indx_list =[] 66 | 67 | def __len__(self): 68 | return len(self.name_list) 69 | 70 | 71 | def __getitem__(self, idx): 72 | sample = self.__sample_generate__(idx) 73 | if 'segmentation' in sample.keys(): 74 | sample['mask'] = sample['segmentation'] < self.num_categories 75 | t = sample['segmentation'].copy() 76 | t[t >= self.num_categories] = 0 77 | sample['segmentation_onehot']=onehot(t,self.num_categories) 78 | return self.totensor(sample) 79 | 80 | def __sample_generate__(self, idx, split_idx=0): 81 | name = self.load_name(idx) 82 | image = self.load_image(idx) 83 | r,c,_ = image.shape 84 | sample = {'image': image, 'name': name, 'row': r, 'col': c, 'batch_idx':idx } 85 | 86 | if 'test' in self.period: 87 | return self.__transform__(sample) 88 | elif self.cfg.DATA_PSEUDO_GT and idx>=split_idx and 'train' in self.period: 89 | segmentation, seg_gt = self.load_pseudo_segmentation(idx) 90 | else: 91 | segmentation = self.load_segmentation(idx) 92 | 93 | sample['segmentation'] = segmentation 94 | t = sample['segmentation'].copy() 95 | t[t >= self.num_categories] = 0 96 | sample['category'] = seg2cls(t,self.num_categories) 97 | sample['category_copypaste'] = np.zeros(sample['category'].shape) 98 | 99 | # if there is previous prediction for this video 100 | if idx in self.prev_pred_dict.keys(): 101 | # interpolate to the image spatial resolution self.prev_pred_dict[idx] size 1,c,h,w 102 | if torch.is_tensor(self.prev_pred_dict[idx]): 103 | # prev_pred = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='nearest') 104 | prev_pred = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='bilinear',align_corners=True, 105 | recompute_scale_factor=False) 106 | else: 107 | # prev_pred = F.interpolate(torch.tensor(self.prev_pred_dict[idx]), size=(r, c), mode='nearest') 108 | prev_pred = F.interpolate(torch.tensor(self.prev_pred_dict[idx]), size=(r, c), mode='bilinear',align_corners=True, 109 | recompute_scale_factor=False) 110 | sample['prev_prediction'] = prev_pred #1,c,h,w 111 | 112 | # the small scale case 113 | # sample['segmentation2'] = segmentation2 114 | # sample['segmentation3'] = segmentation3 115 | 116 | 117 | sample['segmentationgt'] = seg_gt 118 | 119 | if self.transform == 'none' and self.cfg.DATA_FEATURE_DIR: 120 | 
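# --------------------------------------------------------------------------
# Editor's sketch (not part of the original file): prev_pred_dict maps a
# sample index to a 1 x C x h x w score tensor cached from an earlier
# evaluation pass; __sample_generate__ resizes it bilinearly back to the
# image resolution and exposes it as sample['prev_prediction'] for label
# correction. A hypothetical way to populate the cache (the `model` and
# `eval_loader` names and the softmax choice are assumptions):
#
#   eval_dataset.prev_pred_dict = {}
#   model.eval()
#   with torch.no_grad():
#       for batch in eval_loader:
#           logits = model(batch['image'].cuda())        # N x C x h x w
#           probs = torch.softmax(logits, dim=1).cpu()
#           for i, idx in enumerate(batch['batch_idx']):
#               # keep a leading batch dim so F.interpolate receives 4D input
#               eval_dataset.prev_pred_dict[int(idx)] = probs[i:i + 1]
#
# Caching low-resolution probabilities and resizing lazily keeps the cost
# to one small tensor per image.
# --------------------------------------------------------------------------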
feature = self.load_feature(idx) 121 | sample['feature'] = feature 122 | return self.__transform__(sample) 123 | 124 | 125 | def load_name(self, idx): 126 | name = self.name_list[idx] 127 | return name 128 | 129 | def load_image(self, idx): 130 | name = self.name_list[idx] 131 | img_file = self.img_dir + '/' + name + '.jpg' 132 | image = cv2.imread(img_file) 133 | image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 134 | return image_rgb 135 | 136 | def load_segmentation(self, idx): 137 | name = self.name_list[idx] 138 | seg_file = self.seg_dir + '/' + name + '.png' 139 | segmentation = np.array(Image.open(seg_file)) 140 | return segmentation 141 | 142 | def load_pseudo_segmentation(self, idx): 143 | name = self.name_list[idx] 144 | seg_file = self.pseudo_gt_dir + '/' + name + '.png' 145 | 146 | segmentation1 = Image.open(seg_file) 147 | width, height = segmentation1.size 148 | 149 | segmentation1 = np.array(segmentation1) 150 | 151 | seg_gt_file = self.seg_dir_gt + '/' + name + '.png' 152 | seg_gt = np.array(Image.open(seg_gt_file).resize((width, height))) 153 | 154 | return segmentation1, seg_gt 155 | 156 | def __colormap(self, N): 157 | """Get the map from label index to color 158 | 159 | Args: 160 | N: number of class 161 | 162 | return: a Nx3 matrix 163 | 164 | """ 165 | cmap = np.zeros((N, 3), dtype = np.uint8) 166 | 167 | def uint82bin(n, count=8): 168 | """returns the binary of integer n, count refers to amount of bits""" 169 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 170 | 171 | for i in range(N): 172 | r = 0 173 | g = 0 174 | b = 0 175 | idx = i 176 | for j in range(7): 177 | str_id = uint82bin(idx) 178 | r = r ^ ( np.uint8(str_id[-1]) << (7-j)) 179 | g = g ^ ( np.uint8(str_id[-2]) << (7-j)) 180 | b = b ^ ( np.uint8(str_id[-3]) << (7-j)) 181 | idx = idx >> 3 182 | cmap[i, 0] = r 183 | cmap[i, 1] = g 184 | cmap[i, 2] = b 185 | return cmap 186 | 187 | def load_ranked_namelist(self): 188 | df = self.read_rank_result() 189 | self.name_list = df['filename'].values 190 | 191 | def label2colormap(self, label): 192 | m = label.astype(np.uint8) 193 | r,c = m.shape 194 | cmap = np.zeros((r,c,3), dtype=np.uint8) 195 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 196 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 197 | cmap[:,:,2] = (m&4)<<5 198 | cmap[m==255] = [255,255,255] 199 | return cmap 200 | 201 | def save_result(self, result_list, model_id): 202 | """Save test results 203 | 204 | Args: 205 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 206 | 207 | """ 208 | folder_path = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period)) 209 | if not os.path.exists(folder_path): 210 | os.makedirs(folder_path) 211 | 212 | for sample in result_list: 213 | file_path = os.path.join(folder_path, '%s.png'%sample['name']) 214 | cv2.imwrite(file_path, sample['predict']) 215 | 216 | def save_pseudo_gt(self, result_list, folder_path=None): 217 | """Save pseudo gt 218 | 219 | Args: 220 | result_list(list of dict): [{'name':name1, 'predict':predict_seg1},{...},...] 
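# --------------------------------------------------------------------------
# Editor's note (not part of the original file): __colormap and
# label2colormap reproduce the standard PASCAL VOC palette by interleaving
# the bits of the label index into the R/G/B channels. A quick sanity
# check on the first few entries (values are the canonical VOC colors):
#
#   dataset.label2colormap(np.array([[0, 1, 2, 3]]))
#   # -> [[(  0,   0,   0),     0: background
#   #      (128,   0,   0),     1: aeroplane
#   #      (  0, 128,   0),     2: bicycle
#   #      (128, 128,   0)]]    3: bird
#
# Pixels labeled 255 (the VOC ignore value) are rendered white.
# --------------------------------------------------------------------------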
221 | 222 | """ 223 | i = 1 224 | folder_path = self.pseudo_gt_dir if folder_path is None else folder_path 225 | if not os.path.exists(folder_path): 226 | os.makedirs(folder_path) 227 | for sample in result_list: 228 | file_path = os.path.join(folder_path, '%s.png'%(sample['name'])) 229 | cv2.imwrite(file_path, sample['predict']) 230 | i+=1 231 | 232 | def do_matlab_eval(self, model_id): 233 | import subprocess 234 | path = os.path.join(self.root_dir, 'VOCcode') 235 | eval_filename = os.path.join(self.eval_dir,'%s_result.mat'%model_id) 236 | cmd = 'cd {} && '.format(path) 237 | cmd += 'matlab -nodisplay -nodesktop ' 238 | cmd += '-r "dbstop if error; VOCinit; ' 239 | cmd += 'VOCevalseg(VOCopts,\'{:s}\');'.format(model_id) 240 | cmd += 'accuracies,avacc,conf,rawcounts = VOCevalseg(VOCopts,\'{:s}\'); '.format(model_id) 241 | cmd += 'save(\'{:s}\',\'accuracies\',\'avacc\',\'conf\',\'rawcounts\'); '.format(eval_filename) 242 | cmd += 'quit;"' 243 | 244 | print('start subprocess for matlab evaluation...') 245 | print(cmd) 246 | subprocess.call(cmd, shell=True) 247 | 248 | def do_python_eval(self, model_id): 249 | predict_folder = os.path.join(self.rst_dir,'%s_%s'%(model_id,self.period)) 250 | gt_folder = self.seg_dir 251 | TP = [] 252 | P = [] 253 | T = [] 254 | for i in range(self.num_categories): 255 | TP.append(multiprocessing.Value('i', 0, lock=True)) 256 | P.append(multiprocessing.Value('i', 0, lock=True)) 257 | T.append(multiprocessing.Value('i', 0, lock=True)) 258 | 259 | def compare(start,step,TP,P,T): 260 | for idx in range(start,len(self.name_list),step): 261 | #print('%d/%d'%(idx,len(self.name_list))) 262 | name = self.name_list[idx] 263 | predict_file = os.path.join(predict_folder,'%s.png'%name) 264 | gt_file = os.path.join(gt_folder,'%s.png'%name) 265 | predict = np.array(Image.open(predict_file)) #cv2.imread(predict_file) 266 | gt = np.array(Image.open(gt_file)) 267 | cal = gt<255 268 | mask = (predict==gt) * cal 269 | 270 | for i in range(self.num_categories): 271 | P[i].acquire() 272 | P[i].value += np.sum((predict==i)*cal) 273 | P[i].release() 274 | T[i].acquire() 275 | T[i].value += np.sum((gt==i)*cal) 276 | T[i].release() 277 | TP[i].acquire() 278 | TP[i].value += np.sum((gt==i)*mask) 279 | TP[i].release() 280 | p_list = [] 281 | for i in range(8): 282 | p = multiprocessing.Process(target=compare, args=(i,8,TP,P,T)) 283 | p.start() 284 | p_list.append(p) 285 | for p in p_list: 286 | p.join() 287 | IoU = [] 288 | for i in range(self.num_categories): 289 | IoU.append(TP[i].value/(T[i].value+P[i].value-TP[i].value+1e-10)) 290 | loglist = {} 291 | for i in range(self.num_categories): 292 | if i == 0: 293 | print('%11s:%7.3f%%'%('background',IoU[i]*100),end='\t') 294 | loglist['background'] = IoU[i] * 100 295 | else: 296 | if i%2 != 1: 297 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100),end='\t') 298 | else: 299 | print('%11s:%7.3f%%'%(self.categories[i-1],IoU[i]*100)) 300 | loglist[self.categories[i-1]] = IoU[i] * 100 301 | 302 | miou = np.mean(np.array(IoU)) 303 | print('\n======================================================') 304 | print('%11s:%7.3f%%'%('mIoU',miou*100)) 305 | loglist['mIoU'] = miou * 100 306 | return loglist 307 | 308 | def do_python_eval_batch_pseudo_one_process(self): 309 | self.seg_dir_gt = os.path.join(self.dataset_dir, 'SegmentationClassAug') 310 | gt_folder = self.seg_dir_gt 311 | TP_gt_epoch = [0] * 21 312 | P_gt_epoch = [0] * 21 313 | T_gt_epoch = [0] * 21 314 | loglist = {} 315 | for idx in range(len(self.name_list)): 316 | # print(idx) 317 
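# --------------------------------------------------------------------------
# Editor's note (not part of the original file): both evaluation paths
# accumulate, per class i, TP[i] (pixels predicted i and labeled i),
# P[i] (pixels predicted i) and T[i] (pixels labeled i), skipping pixels
# whose ground truth is 255. The per-class score is then
#
#   IoU[i] = TP[i] / (T[i] + P[i] - TP[i] + eps)
#
# i.e. intersection over union with a small eps guarding empty classes,
# and mIoU is the unweighted mean over the 21 classes. Worked example: if
# class 5 has T = 100 labeled pixels, P = 90 predicted pixels and TP = 80,
# then IoU = 80 / (100 + 90 - 80) ~= 0.727.
# --------------------------------------------------------------------------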
| name = self.name_list[idx] 318 | gt_file = os.path.join(gt_folder, '%s.png' % name) 319 | gt = np.array(Image.open(gt_file)) 320 | r, c = gt.shape 321 | # print(r) 322 | predict_tensor = F.interpolate(self.prev_pred_dict[idx], size=(r, c), mode='bilinear', align_corners=True, 323 | recompute_scale_factor=False) # 1,c,h,w 324 | predict = predict_tensor[0].cpu().numpy() # c,h,w 325 | predict = np.argmax(predict, axis=0) # h,w 326 | 327 | TP_gt_epoch, P_gt_epoch, T_gt_epoch = update_iou_stat(predict, gt, TP_gt_epoch, 328 | P_gt_epoch, T_gt_epoch) 329 | IoU_gt_epoch = compute_iou(TP_gt_epoch, P_gt_epoch, T_gt_epoch) 330 | for indx, class_name in enumerate( 331 | ['background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 332 | 'cow', 333 | 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 334 | 'tvmonitor']): 335 | loglist[class_name] = IoU_gt_epoch[indx] 336 | mIoU_clean_epoch = np.mean(np.array(IoU_gt_epoch)) 337 | loglist['mIoU'] = mIoU_clean_epoch 338 | return loglist 339 | 340 | 341 | def __coco2voc(self, m): 342 | r,c = m.shape 343 | result = np.zeros((r,c),dtype=np.uint8) 344 | for i in range(0,21): 345 | for j in self.coco2voc[i]: 346 | result[m==j] = i 347 | return result 348 | 349 | 350 | -------------------------------------------------------------------------------- /lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .VOCDataset import * 2 | from .VOCEvalDataset import * 3 | from .VOCTrainwsegDataset import * 4 | -------------------------------------------------------------------------------- /lib/datasets/generateData.py: -------------------------------------------------------------------------------- 1 | from utils.registry import DATASETS 2 | 3 | def generate_dataset(cfg, **kwargs): 4 | dataset = DATASETS.get(cfg.DATA_NAME)(cfg, **kwargs) 5 | return dataset 6 | -------------------------------------------------------------------------------- /lib/datasets/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import functools 4 | 5 | class AverageMeter(object): 6 | """Computes and stores the average and current value""" 7 | def __init__(self): 8 | self.initialized = False 9 | self.val = None 10 | self.avg = None 11 | self.sum = None 12 | self.count = None 13 | 14 | def initialize(self, val, weight): 15 | self.val = val 16 | self.avg = val 17 | self.sum = val * weight 18 | self.count = weight 19 | self.initialized = True 20 | 21 | def update(self, val, weight=1): 22 | if not self.initialized: 23 | self.initialize(val, weight) 24 | else: 25 | self.add(val, weight) 26 | 27 | def add(self, val, weight): 28 | self.val = val 29 | self.sum += val * weight 30 | self.count += weight 31 | self.avg = self.sum / self.count 32 | 33 | def value(self): 34 | return self.val 35 | 36 | def average(self): 37 | return self.avg 38 | 39 | 40 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 41 | ar = np.asanyarray(ar).flatten() 42 | 43 | optional_indices = return_index or return_inverse 44 | optional_returns = optional_indices or return_counts 45 | 46 | if ar.size == 0: 47 | if not optional_returns: 48 | ret = ar 49 | else: 50 | ret = (ar,) 51 | if return_index: 52 | ret += (np.empty(0, np.bool),) 53 | if return_inverse: 54 | ret += (np.empty(0, np.bool),) 55 | if return_counts: 56 | ret += (np.empty(0, np.intp),) 57 | return ret 58 | if 
optional_indices: 59 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 60 | aux = ar[perm] 61 | else: 62 | ar.sort() 63 | aux = ar 64 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 65 | 66 | if not optional_returns: 67 | ret = aux[flag] 68 | else: 69 | ret = (aux[flag],) 70 | if return_index: 71 | ret += (perm[flag],) 72 | if return_inverse: 73 | iflag = np.cumsum(flag) - 1 74 | inv_idx = np.empty(ar.shape, dtype=np.intp) 75 | inv_idx[perm] = iflag 76 | ret += (inv_idx,) 77 | if return_counts: 78 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 79 | ret += (np.diff(idx),) 80 | return ret 81 | 82 | 83 | def colorEncode(labelmap, colors, mode='BGR'): 84 | labelmap = labelmap.astype('int') 85 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 86 | dtype=np.uint8) 87 | for label in unique(labelmap): 88 | if label < 0: 89 | continue 90 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 91 | np.tile(colors[label], 92 | (labelmap.shape[0], labelmap.shape[1], 1)) 93 | 94 | if mode == 'BGR': 95 | return labelmap_rgb[:, :, ::-1] 96 | else: 97 | return labelmap_rgb 98 | 99 | 100 | def accuracy(preds, label): 101 | valid = (label >= 0) 102 | acc_sum = (valid * (preds == label)).sum() 103 | valid_sum = valid.sum() 104 | acc = float(acc_sum) / (valid_sum + 1e-10) 105 | return acc, valid_sum 106 | 107 | 108 | def intersectionAndUnion(imPred, imLab, numClass): 109 | imPred = np.asarray(imPred).copy() 110 | imLab = np.asarray(imLab).copy() 111 | 112 | imPred += 1 113 | imLab += 1 114 | # Remove classes from unlabeled pixels in gt image. 115 | # We should not penalize detections in unlabeled portions of the image. 116 | imPred = imPred * (imLab > 0) 117 | 118 | # Compute area intersection: 119 | intersection = imPred * (imPred == imLab) 120 | (area_intersection, _) = np.histogram( 121 | intersection, bins=numClass, range=(1, numClass)) 122 | 123 | # Compute area union: 124 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 125 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 126 | area_union = area_pred + area_lab - area_intersection 127 | 128 | return (area_intersection, area_union) 129 | 130 | 131 | class NotSupportedCliException(Exception): 132 | pass 133 | 134 | 135 | def process_range(xpu, inp): 136 | start, end = map(int, inp) 137 | if start > end: 138 | end, start = start, end 139 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 140 | 141 | 142 | REGEX = [ 143 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 144 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 145 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 146 | functools.partial(process_range, 'gpu')), 147 | (re.compile(r'^(\d+)-(\d+)$'), 148 | functools.partial(process_range, 'gpu')), 149 | ] 150 | 151 | 152 | def parse_devices(input_devices): 153 | 154 | """Parse user's devices input str to standard format. 155 | e.g. [gpu0, gpu1, ...] 
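# --------------------------------------------------------------------------
# Editor's sketch (not part of the original file): usage notes for two
# helpers in this module. parse_devices normalizes comma-separated GPU
# specs to 'gpu<N>' strings; ranges are inclusive, reversed ranges are
# reordered, and duplicates are dropped:
#
#   parse_devices('0-2,gpu5')      # -> ['gpu0', 'gpu1', 'gpu2', 'gpu5']
#   parse_devices('gpu3-gpu1')     # -> ['gpu1', 'gpu2', 'gpu3']
#
# intersectionAndUnion returns per-class pixel counts from which IoU
# follows directly (toy arrays below are illustrative only):
#
#   pred = np.array([[0, 1, 1], [2, 2, 0]])
#   gt   = np.array([[0, 1, 2], [2, 2, 0]])
#   inter, union = intersectionAndUnion(pred, gt, 3)
#   iou = inter / (union + 1e-10)  # -> [1.0, 0.5, 0.667] for classes 0..2
# --------------------------------------------------------------------------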
156 | 157 | """ 158 | ret = [] 159 | for d in input_devices.split(','): 160 | for regex, func in REGEX: 161 | m = regex.match(d.lower().strip()) 162 | if m: 163 | tmp = func(m.groups()) 164 | # prevent duplicate 165 | for x in tmp: 166 | if x not in ret: 167 | ret.append(x) 168 | break 169 | else: 170 | raise NotSupportedCliException( 171 | 'Can not recognize device: "%s"' % d) 172 | return ret 173 | -------------------------------------------------------------------------------- /lib/datasets/transform.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | import random 5 | import PIL 6 | from PIL import Image, ImageOps, ImageFilter 7 | 8 | class RandomCrop(object): 9 | """Crop randomly the image in a sample. 10 | 11 | Args: 12 | output_size (tuple or int): Desired output size. If int, square crop 13 | is made. 14 | """ 15 | 16 | def __init__(self, output_size): 17 | assert isinstance(output_size, (int, tuple)) 18 | if isinstance(output_size, int): 19 | self.output_size = (output_size, output_size) 20 | else: 21 | assert len(output_size) == 2 22 | self.output_size = output_size 23 | 24 | def __call__(self, sample): 25 | 26 | h, w = sample['image'].shape[:2] 27 | ch = min(h, self.output_size[0]) 28 | cw = min(w, self.output_size[1]) 29 | 30 | h_space = h - self.output_size[0] 31 | w_space = w - self.output_size[1] 32 | 33 | if w_space > 0: 34 | cont_left = 0 35 | img_left = random.randrange(w_space+1) 36 | else: 37 | cont_left = random.randrange(-w_space+1) 38 | img_left = 0 39 | 40 | if h_space > 0: 41 | cont_top = 0 42 | img_top = random.randrange(h_space+1) 43 | else: 44 | cont_top = random.randrange(-h_space+1) 45 | img_top = 0 46 | 47 | key_list = sample.keys() 48 | for key in key_list: 49 | if 'image' in key: 50 | img = sample[key] 51 | img_crop = np.zeros((self.output_size[0], self.output_size[1], 3), np.float32) 52 | img_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 53 | img[img_top:img_top+ch, img_left:img_left+cw] 54 | #img_crop = img[img_top:img_top+ch, img_left:img_left+cw] 55 | sample[key] = img_crop 56 | elif 'segmentation' == key: 57 | seg = sample[key] 58 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 59 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 60 | seg[img_top:img_top+ch, img_left:img_left+cw] 61 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw] 62 | sample[key] = seg_crop 63 | elif 'segmentation_pseudo' in key: 64 | seg_pseudo = sample[key] 65 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 66 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 67 | seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 68 | #seg_crop = seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 69 | sample[key] = seg_crop 70 | return sample 71 | 72 | class RandomHSV(object): 73 | """Generate randomly the image in hsv space.""" 74 | def __init__(self, h_r, s_r, v_r): 75 | self.h_r = h_r 76 | self.s_r = s_r 77 | self.v_r = v_r 78 | 79 | def __call__(self, sample): 80 | image = sample['image'] 81 | hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 82 | h = hsv[:,:,0].astype(np.int32) 83 | s = hsv[:,:,1].astype(np.int32) 84 | v = hsv[:,:,2].astype(np.int32) 85 | delta_h = random.randint(-self.h_r,self.h_r) 86 | delta_s = random.randint(-self.s_r,self.s_r) 87 | delta_v = random.randint(-self.v_r,self.v_r) 88 | h = (h + delta_h)%180 89 | s = s + delta_s 90 | s[s>255] = 255 91 | s[s<0] = 0 92 | v = v + 
delta_v 93 | v[v>255] = 255 94 | v[v<0] = 0 95 | hsv = np.stack([h,s,v], axis=-1).astype(np.uint8) 96 | image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.uint8) 97 | sample['image'] = image 98 | return sample 99 | 100 | class RandomFlip(object): 101 | """Randomly flip image""" 102 | def __init__(self, threshold): 103 | self.flip_t = threshold 104 | def __call__(self, sample): 105 | if random.random() < self.flip_t: 106 | key_list = sample.keys() 107 | for key in key_list: 108 | if 'image' in key: 109 | img = sample[key] 110 | img = np.flip(img, axis=1) 111 | sample[key] = img 112 | elif 'segmentation' == key: 113 | seg = sample[key] 114 | seg = np.flip(seg, axis=1) 115 | sample[key] = seg 116 | elif 'segmentation_pseudo' in key: 117 | seg_pseudo = sample[key] 118 | seg_pseudo = np.flip(seg_pseudo, axis=1) 119 | sample[key] = seg_pseudo 120 | return sample 121 | 122 | class RandomScale(object): 123 | """Randomly scale image""" 124 | def __init__(self, scale_r, is_continuous=False): 125 | self.scale_r = scale_r 126 | self.seg_interpolation = cv2.INTER_CUBIC if is_continuous else cv2.INTER_NEAREST 127 | 128 | def __call__(self, sample): 129 | row, col, _ = sample['image'].shape 130 | rand_scale = random.random()*(self.scale_r[1] - self.scale_r[0]) + self.scale_r[0] 131 | key_list = sample.keys() 132 | for key in key_list: 133 | if 'image' in key: 134 | img = sample[key] 135 | img = cv2.resize(img, None, fx=rand_scale, fy=rand_scale, interpolation=cv2.INTER_CUBIC) 136 | sample[key] = img 137 | elif 'segmentation' == key: 138 | seg = sample[key] 139 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation) 140 | sample[key] = seg 141 | elif 'segmentation_pseudo' in key: 142 | seg_pseudo = sample[key] 143 | seg_pseudo = cv2.resize(seg_pseudo, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation) 144 | sample[key] = seg_pseudo 145 | return sample 146 | 147 | class ImageNorm(object): 148 | """Randomly scale image""" 149 | def __init__(self, mean=None, std=None): 150 | self.mean = mean 151 | self.std = std 152 | def __call__(self, sample): 153 | key_list = sample.keys() 154 | for key in key_list: 155 | if 'image' in key: 156 | image = sample[key].astype(np.float32) 157 | if self.mean is not None and self.std is not None: 158 | image[...,0] = (image[...,0]/255 - self.mean[0]) / self.std[0] 159 | image[...,1] = (image[...,1]/255 - self.mean[1]) / self.std[1] 160 | image[...,2] = (image[...,2]/255 - self.mean[2]) / self.std[2] 161 | else: 162 | image /= 255.0 163 | sample[key] = image 164 | return sample 165 | 166 | class Multiscale(object): 167 | def __init__(self, rate_list): 168 | self.rate_list = rate_list 169 | 170 | def __call__(self, sample): 171 | image = sample['image'] 172 | row, col, _ = image.shape 173 | image_multiscale = [] 174 | for rate in self.rate_list: 175 | rescaled_image = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC) 176 | sample['image_%f'%rate] = rescaled_image 177 | return sample 178 | 179 | 180 | class ToTensor(object): 181 | """Convert ndarrays in sample to Tensors.""" 182 | 183 | def __call__(self, sample): 184 | key_list = sample.keys() 185 | for key in key_list: 186 | if 'image' in key: 187 | image = sample[key].astype(np.float32) 188 | # swap color axis because 189 | # numpy image: H x W x C 190 | # torch image: C X H X W 191 | image = image.transpose((2,0,1)) 192 | sample[key] = torch.from_numpy(image) 193 | #sample[key] = torch.from_numpy(image.astype(np.float32)/128.0-1.0) 194 
| elif 'edge' == key: 195 | edge = sample['edge'] 196 | sample['edge'] = torch.from_numpy(edge.astype(np.float32)) 197 | sample['edge'] = torch.unsqueeze(sample['edge'],0) 198 | elif 'segmentation' == key: 199 | segmentation = sample['segmentation'] 200 | sample['segmentation'] = torch.from_numpy(segmentation.astype(np.long)) 201 | elif 'segmentation_pseudo' in key: 202 | segmentation_pseudo = sample[key] 203 | sample[key] = torch.from_numpy(segmentation_pseudo.astype(np.float32)) 204 | elif 'segmentation_onehot' == key: 205 | onehot = sample['segmentation_onehot'].transpose((2,0,1)) 206 | sample['segmentation_onehot'] = torch.from_numpy(onehot.astype(np.float32)) 207 | elif 'category' in key: 208 | sample[key] = torch.from_numpy(sample[key].astype(np.float32)) 209 | elif 'mask' == key: 210 | mask = sample['mask'] 211 | sample['mask'] = torch.from_numpy(mask.astype(np.float32)) 212 | elif 'feature' == key: 213 | feature = sample['feature'] 214 | sample['feature'] = torch.from_numpy(feature.astype(np.float32)) 215 | return sample 216 | 217 | -------------------------------------------------------------------------------- /lib/datasets/transformmultiGT.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # heavily borrowed from Yude Wang, modified by Kangning Liu 3 | # ---------------------------------------- 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | import random 9 | import PIL 10 | from PIL import Image, ImageOps, ImageFilter 11 | 12 | class RandomCrop(object): 13 | """Crop randomly the image in a sample. 14 | 15 | Args: 16 | output_size (tuple or int): Desired output size. If int, square crop 17 | is made. 18 | """ 19 | 20 | def __init__(self, output_size): 21 | assert isinstance(output_size, (int, tuple)) 22 | if isinstance(output_size, int): 23 | self.output_size = (output_size, output_size) 24 | else: 25 | assert len(output_size) == 2 26 | self.output_size = output_size 27 | 28 | def __call__(self, sample): 29 | 30 | h, w = sample['image'].shape[:2] 31 | ch = min(h, self.output_size[0]) 32 | cw = min(w, self.output_size[1]) 33 | 34 | h_space = h - self.output_size[0] 35 | w_space = w - self.output_size[1] 36 | 37 | if w_space > 0: 38 | cont_left = 0 39 | img_left = random.randrange(w_space+1) 40 | else: 41 | cont_left = random.randrange(-w_space+1) 42 | img_left = 0 43 | 44 | if h_space > 0: 45 | cont_top = 0 46 | img_top = random.randrange(h_space+1) 47 | else: 48 | cont_top = random.randrange(-h_space+1) 49 | img_top = 0 50 | 51 | key_list = sample.keys() 52 | for key in key_list: 53 | if 'image' in key: 54 | img = sample[key] 55 | img_crop = np.zeros((self.output_size[0], self.output_size[1], 3), np.float32) 56 | img_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 57 | img[img_top:img_top+ch, img_left:img_left+cw] 58 | #img_crop = img[img_top:img_top+ch, img_left:img_left+cw] 59 | sample[key] = img_crop 60 | elif 'segmentation' == key: 61 | seg = sample[key] 62 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 63 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 64 | seg[img_top:img_top+ch, img_left:img_left+cw] 65 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw] 66 | sample[key] = seg_crop 67 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key: 68 | seg = sample[key] 69 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 70 | seg_crop[cont_top:cont_top+ch, 
cont_left:cont_left+cw] = \ 71 | seg[img_top:img_top+ch, img_left:img_left+cw] 72 | #seg_crop = seg[img_top:img_top+ch, img_left:img_left+cw] 73 | sample[key] = seg_crop 74 | elif 'segmentation_pseudo' in key: 75 | seg_pseudo = sample[key] 76 | seg_crop = np.ones((self.output_size[0], self.output_size[1]), np.float32)*255 77 | seg_crop[cont_top:cont_top+ch, cont_left:cont_left+cw] = \ 78 | seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 79 | #seg_crop = seg_pseudo[img_top:img_top+ch, img_left:img_left+cw] 80 | sample[key] = seg_crop 81 | return sample 82 | 83 | class RandomHSV(object): 84 | """Generate randomly the image in hsv space.""" 85 | def __init__(self, h_r, s_r, v_r): 86 | self.h_r = h_r 87 | self.s_r = s_r 88 | self.v_r = v_r 89 | 90 | def __call__(self, sample): 91 | image = sample['image'] 92 | hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 93 | h = hsv[:,:,0].astype(np.int32) 94 | s = hsv[:,:,1].astype(np.int32) 95 | v = hsv[:,:,2].astype(np.int32) 96 | delta_h = random.randint(-self.h_r,self.h_r) 97 | delta_s = random.randint(-self.s_r,self.s_r) 98 | delta_v = random.randint(-self.v_r,self.v_r) 99 | h = (h + delta_h)%180 100 | s = s + delta_s 101 | s[s>255] = 255 102 | s[s<0] = 0 103 | v = v + delta_v 104 | v[v>255] = 255 105 | v[v<0] = 0 106 | hsv = np.stack([h,s,v], axis=-1).astype(np.uint8) 107 | image = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB).astype(np.uint8) 108 | sample['image'] = image 109 | return sample 110 | 111 | class RandomFlip(object): 112 | """Randomly flip image""" 113 | def __init__(self, threshold): 114 | self.flip_t = threshold 115 | def __call__(self, sample): 116 | if random.random() < self.flip_t: 117 | key_list = sample.keys() 118 | for key in key_list: 119 | if 'image' in key: 120 | img = sample[key] 121 | img = np.flip(img, axis=1) 122 | sample[key] = img 123 | elif 'segmentation' == key: 124 | seg = sample[key] 125 | seg = np.flip(seg, axis=1) 126 | sample[key] = seg 127 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key: 128 | seg = sample[key] 129 | seg = np.flip(seg, axis=1) 130 | sample[key] = seg 131 | elif 'segmentation_pseudo' in key: 132 | seg_pseudo = sample[key] 133 | seg_pseudo = np.flip(seg_pseudo, axis=1) 134 | sample[key] = seg_pseudo 135 | return sample 136 | 137 | class RandomScale(object): 138 | """Randomly scale image""" 139 | def __init__(self, scale_r, is_continuous=False): 140 | self.scale_r = scale_r 141 | self.seg_interpolation = cv2.INTER_CUBIC if is_continuous else cv2.INTER_NEAREST 142 | 143 | def __call__(self, sample): 144 | row, col, _ = sample['image'].shape 145 | rand_scale = random.random()*(self.scale_r[1] - self.scale_r[0]) + self.scale_r[0] 146 | key_list = sample.keys() 147 | for key in key_list: 148 | if 'image' in key: 149 | img = sample[key] 150 | img = cv2.resize(img, None, fx=rand_scale, fy=rand_scale, interpolation=cv2.INTER_CUBIC) 151 | sample[key] = img 152 | elif 'segmentation' == key: 153 | seg = sample[key] 154 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation) 155 | sample[key] = seg 156 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key: 157 | seg = sample[key] 158 | seg = cv2.resize(seg, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation) 159 | sample[key] = seg 160 | elif 'segmentation_pseudo' in key: 161 | seg_pseudo = sample[key] 162 | seg_pseudo = cv2.resize(seg_pseudo, None, fx=rand_scale, fy=rand_scale, interpolation=self.seg_interpolation) 163 | sample[key] = 
seg_pseudo 164 | return sample 165 | 166 | class ImageNorm(object): 167 | """Randomly scale image""" 168 | def __init__(self, mean=None, std=None): 169 | self.mean = mean 170 | self.std = std 171 | def __call__(self, sample): 172 | key_list = sample.keys() 173 | for key in key_list: 174 | if 'image' in key: 175 | image = sample[key].astype(np.float32) 176 | if self.mean is not None and self.std is not None: 177 | image[...,0] = (image[...,0]/255 - self.mean[0]) / self.std[0] 178 | image[...,1] = (image[...,1]/255 - self.mean[1]) / self.std[1] 179 | image[...,2] = (image[...,2]/255 - self.mean[2]) / self.std[2] 180 | else: 181 | image /= 255.0 182 | sample[key] = image 183 | return sample 184 | 185 | class Multiscale(object): 186 | def __init__(self, rate_list): 187 | self.rate_list = rate_list 188 | 189 | def __call__(self, sample): 190 | image = sample['image'] 191 | row, col, _ = image.shape 192 | image_multiscale = [] 193 | for rate in self.rate_list: 194 | rescaled_image = cv2.resize(image, None, fx=rate, fy=rate, interpolation=cv2.INTER_CUBIC) 195 | sample['image_%f'%rate] = rescaled_image 196 | return sample 197 | 198 | 199 | class ToTensor(object): 200 | """Convert ndarrays in sample to Tensors.""" 201 | 202 | def __call__(self, sample): 203 | key_list = sample.keys() 204 | for key in key_list: 205 | if 'image' in key: 206 | image = sample[key].astype(np.float32) 207 | # swap color axis because 208 | # numpy image: H x W x C 209 | # torch image: C X H X W 210 | image = image.transpose((2,0,1)) 211 | sample[key] = torch.from_numpy(image) 212 | #sample[key] = torch.from_numpy(image.astype(np.float32)/128.0-1.0) 213 | elif 'edge' == key: 214 | edge = sample['edge'] 215 | sample['edge'] = torch.from_numpy(edge.astype(np.float32)) 216 | sample['edge'] = torch.unsqueeze(sample['edge'],0) 217 | elif 'segmentation' == key: 218 | segmentation = sample['segmentation'] 219 | sample['segmentation'] = torch.from_numpy(segmentation.astype(np.long)) 220 | 221 | elif 'segmentation2' == key or 'segmentation3' == key or 'segmentationgt' == key: 222 | # segmentation = sample['segmentation2'] 223 | # sample['segmentation2'] = torch.from_numpy(segmentation.astype(np.long)) 224 | segmentation = sample[key] 225 | sample[key] = torch.from_numpy(segmentation.astype(np.long)) 226 | 227 | elif 'segmentation_pseudo' in key: 228 | segmentation_pseudo = sample[key] 229 | sample[key] = torch.from_numpy(segmentation_pseudo.astype(np.float32)) 230 | elif 'segmentation_onehot' == key: 231 | onehot = sample['segmentation_onehot'].transpose((2,0,1)) 232 | sample['segmentation_onehot'] = torch.from_numpy(onehot.astype(np.float32)) 233 | elif 'category' in key: 234 | sample[key] = torch.from_numpy(sample[key].astype(np.float32)) 235 | elif 'mask' == key: 236 | mask = sample['mask'] 237 | sample['mask'] = torch.from_numpy(mask.astype(np.float32)) 238 | elif 'feature' == key: 239 | feature = sample['feature'] 240 | sample['feature'] = torch.from_numpy(feature.astype(np.float32)) 241 | return sample 242 | 243 | -------------------------------------------------------------------------------- /lib/net/__init__.py: -------------------------------------------------------------------------------- 1 | from .deeplabv1_wo_interp import * -------------------------------------------------------------------------------- /lib/net/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .builder import build_backbone 2 | from .resnet38d import * 3 | from .resnet import * 4 | from 
.xception import * 5 | 6 | __all__ = ['build_backbone'] 7 | -------------------------------------------------------------------------------- /lib/net/backbone/builder.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | from utils.registry import BACKBONES 6 | 7 | def build_backbone(backbone_name, pretrained=True, **kwargs): 8 | net = BACKBONES.get(backbone_name)(pretrained=pretrained, **kwargs) 9 | return net 10 | -------------------------------------------------------------------------------- /lib/net/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | """Dilated ResNet""" 2 | import math 3 | import torch 4 | import torch.utils.model_zoo as model_zoo 5 | import torch.nn as nn 6 | from net.sync_batchnorm import SynchronizedBatchNorm2d 7 | from utils.registry import BACKBONES 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'BasicBlock', 'Bottleneck'] 11 | bn_mom = 0.1 12 | model_urls = { 13 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 14 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 15 | 'resnet50': '~/.cache/torch/checkpoints/resnet50s-a75c83cf.pth', 16 | 'resnet101': '/home/wangyude/.cache/torch/checkpoints/resnet101s-03a0f310.pth', 17 | 'resnet152': '~/.cache/torch/checkpoints/resnet152s-36670e8b.pth', 18 | #'resnet50': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet50s-a75c83cf.zip', 19 | #'resnet101': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet101s-03a0f310.zip', 20 | #'resnet152': 'https://s3.us-west-1.wasabisys.com/encoding/models/resnet152s-36670e8b.zip' 21 | } 22 | mean = (0.485, 0.456, 0.406) 23 | std = (0.229, 0.224, 0.225) 24 | 25 | def conv3x3(in_planes, out_planes, stride=1): 26 | "3x3 convolution with padding" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 28 | padding=1, bias=False) 29 | 30 | 31 | class BasicBlock(nn.Module): 32 | """ResNet BasicBlock 33 | """ 34 | expansion = 1 35 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 36 | norm_layer=None): 37 | super(BasicBlock, self).__init__() 38 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 39 | padding=dilation, dilation=dilation, bias=False) 40 | self.bn1 = norm_layer(planes, momentum=bn_mom, affine=True) 41 | self.relu = nn.ReLU(inplace=True) 42 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 43 | padding=previous_dilation, dilation=previous_dilation, bias=False) 44 | self.bn2 = norm_layer(planes, momentum=bn_mom, affine=True) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | """ResNet Bottleneck 69 | """ 70 | # pylint: disable=unused-argument 71 | expansion = 4 72 | def __init__(self, inplanes, planes, stride=1, dilation=1, 73 | downsample=None, previous_dilation=1, norm_layer=None): 74 | super(Bottleneck, self).__init__() 75 | self.conv1 = nn.Conv2d(inplanes, planes, 
kernel_size=1, bias=False) 76 | self.bn1 = norm_layer(planes, momentum=bn_mom, affine=True) 77 | self.conv2 = nn.Conv2d( 78 | planes, planes, kernel_size=3, stride=stride, 79 | padding=dilation, dilation=dilation, bias=False) 80 | self.bn2 = norm_layer(planes, momentum=bn_mom, affine=True) 81 | self.conv3 = nn.Conv2d( 82 | planes, planes * 4, kernel_size=1, bias=False) 83 | self.bn3 = norm_layer(planes * 4, momentum=bn_mom, affine=True) 84 | self.relu = nn.ReLU(inplace=True) 85 | self.downsample = downsample 86 | self.dilation = dilation 87 | self.stride = stride 88 | 89 | def _sum_each(self, x, y): 90 | assert(len(x) == len(y)) 91 | z = [] 92 | for i in range(len(x)): 93 | z.append(x[i]+y[i]) 94 | return z 95 | 96 | def forward(self, x): 97 | residual = x 98 | 99 | out = self.conv1(x) 100 | out = self.bn1(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv2(out) 104 | out = self.bn2(out) 105 | out = self.relu(out) 106 | 107 | out = self.conv3(out) 108 | out = self.bn3(out) 109 | 110 | if self.downsample is not None: 111 | residual = self.downsample(x) 112 | 113 | out += residual 114 | out = self.relu(out) 115 | 116 | return out 117 | 118 | 119 | class ResNet(nn.Module): 120 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5. 121 | 122 | Parameters 123 | ---------- 124 | block : Block 125 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 126 | layers : list of int 127 | Numbers of layers in each block 128 | classes : int, default 1000 129 | Number of classification classes. 130 | dilated : bool, default False 131 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 132 | typically used in Semantic Segmentation. 133 | norm_layer : object 134 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 135 | for Synchronized Cross-GPU BachNormalization). 136 | 137 | Reference: 138 | 139 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 140 | 141 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 
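# --------------------------------------------------------------------------
# Editor's sketch (not part of the original file): with dilated=True the
# strides of layer3/layer4 are replaced by dilation rates 2/4, so the
# feature map stops shrinking after layer2 and conv5 features keep output
# stride 8 instead of the classification network's stride 32. A quick
# shape check (assumes a fresh, non-pretrained build with plain
# nn.BatchNorm2d):
#
#   net = ResNet(Bottleneck, [3, 4, 6, 3], dilated=True)
#   l1, l2, l3, l4 = net(torch.randn(1, 3, 224, 224))
#   # l1: 1 x  256 x 56 x 56   (stride 4)
#   # l2: 1 x  512 x 28 x 28   (stride 8)
#   # l3: 1 x 1024 x 28 x 28   (stride 8, dilation 2)
#   # l4: 1 x 2048 x 28 x 28   (stride 8, dilation 4)
# --------------------------------------------------------------------------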
142 | """ 143 | # pylint: disable=unused-variable 144 | def __init__(self, block, layers, dilated=True, multi_grid=False, 145 | deep_base=True, norm_layer=nn.BatchNorm2d): 146 | self.inplanes = 128 if deep_base else 64 147 | super(ResNet, self).__init__() 148 | if deep_base: 149 | self.conv1 = nn.Sequential( 150 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 151 | norm_layer(64, momentum=bn_mom, affine=True), 152 | nn.ReLU(inplace=True), 153 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 154 | norm_layer(64, momentum=bn_mom, affine=True), 155 | nn.ReLU(inplace=True), 156 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False), 157 | ) 158 | else: 159 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 160 | bias=False) 161 | self.bn1 = norm_layer(self.inplanes, momentum=bn_mom, affine=True) 162 | self.relu = nn.ReLU(inplace=True) 163 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 164 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 165 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 166 | if dilated: 167 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 168 | dilation=2, norm_layer=norm_layer) 169 | if multi_grid: 170 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 171 | dilation=4, norm_layer=norm_layer, 172 | multi_grid=True) 173 | else: 174 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 175 | dilation=4, norm_layer=norm_layer) 176 | else: 177 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 178 | norm_layer=norm_layer) 179 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 180 | norm_layer=norm_layer) 181 | self.OUTPUT_DIM = 2048 182 | self.MIDDLE_DIM = 256 183 | 184 | for m in self.modules(): 185 | if isinstance(m, nn.Conv2d): 186 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 187 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 188 | elif isinstance(m, norm_layer): 189 | m.weight.data.fill_(1) 190 | m.bias.data.zero_() 191 | 192 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None, multi_grid=False): 193 | downsample = None 194 | if stride != 1 or self.inplanes != planes * block.expansion: 195 | downsample = nn.Sequential( 196 | nn.Conv2d(self.inplanes, planes * block.expansion, 197 | kernel_size=1, stride=stride, bias=False), 198 | norm_layer(planes * block.expansion, momentum=bn_mom, affine=True), 199 | ) 200 | 201 | layers = [] 202 | #multi_dilations = [4, 8, 16] 203 | multi_dilations = [3, 4, 5] 204 | if multi_grid: 205 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilations[0], 206 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 207 | elif dilation == 1 or dilation == 2: 208 | layers.append(block(self.inplanes, planes, stride, dilation=1, 209 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 210 | elif dilation == 4: 211 | layers.append(block(self.inplanes, planes, stride, dilation=2, 212 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 213 | else: 214 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 215 | 216 | self.inplanes = planes * block.expansion 217 | for i in range(1, blocks): 218 | if multi_grid: 219 | layers.append(block(self.inplanes, planes, dilation=multi_dilations[i], 220 | previous_dilation=dilation, norm_layer=norm_layer)) 221 | else: 222 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, 223 | norm_layer=norm_layer)) 224 | 225 | return nn.Sequential(*layers) 226 | 227 | def forward(self, x): 228 | x = self.conv1(x) 229 | x = self.bn1(x) 230 | x = self.relu(x) 231 | x = self.maxpool(x) 232 | 233 | l1 = self.layer1(x) 234 | l2 = self.layer2(l1) 235 | l3 = self.layer3(l2) 236 | l4 = self.layer4(l3) 237 | return [l1, l2, l3, l4] 238 | 239 | @BACKBONES.register_module 240 | def resnet18(pretrained=False, **kwargs): 241 | """Constructs a ResNet-18 model. 242 | 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 247 | if pretrained: 248 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 249 | return model 250 | 251 | 252 | @BACKBONES.register_module 253 | def resnet34(pretrained=False, **kwargs): 254 | """Constructs a ResNet-34 model. 255 | 256 | Args: 257 | pretrained (bool): If True, returns a model pre-trained on ImageNet 258 | """ 259 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 260 | if pretrained: 261 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 262 | return model 263 | 264 | 265 | @BACKBONES.register_module 266 | def resnet50(pretrained=False, **kwargs): 267 | """Constructs a ResNet-50 model. 268 | 269 | Args: 270 | pretrained (bool): If True, returns a model pre-trained on ImageNet 271 | """ 272 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 273 | if pretrained: 274 | old_dict = model_zoo.load_url(model_urls['resnet50']) 275 | model_dict = model.state_dict() 276 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)} 277 | model_dict.update(old_dict) 278 | model.load_state_dict(model_dict) 279 | print('%s loaded.'%model_urls['resnet50']) 280 | return model 281 | 282 | 283 | @BACKBONES.register_module 284 | def resnet101(pretrained=False, **kwargs): 285 | """Constructs a ResNet-101 model. 
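# --------------------------------------------------------------------------
# Editor's note (not part of the original file): resnet50/101/152 in this
# module load pretrained weights via a filtered state-dict update rather
# than a strict load, presumably because the deep-base stem and dilated
# layers do not match a vanilla checkpoint key-for-key. The same pattern
# works for any partially compatible checkpoint; the shape check is an
# addition over the code here (which filters on key names only) and
# `checkpoint_path` is a placeholder:
#
#   old_dict = torch.load(checkpoint_path)
#   model_dict = model.state_dict()
#   compatible = {k: v for k, v in old_dict.items()
#                 if k in model_dict and v.shape == model_dict[k].shape}
#   model_dict.update(compatible)
#   model.load_state_dict(model_dict)
# --------------------------------------------------------------------------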
286 | 287 | Args: 288 | pretrained (bool): If True, returns a model pre-trained on ImageNet 289 | """ 290 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 291 | if pretrained: 292 | old_dict = torch.load(model_urls['resnet101']) 293 | model_dict = model.state_dict() 294 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)} 295 | model_dict.update(old_dict) 296 | model.load_state_dict(model_dict) 297 | print('%s loaded.'%model_urls['resnet101']) 298 | return model 299 | 300 | 301 | @BACKBONES.register_module 302 | def resnet152(pretrained=False, **kwargs): 303 | """Constructs a ResNet-152 model. 304 | 305 | Args: 306 | pretrained (bool): If True, returns a model pre-trained on ImageNet 307 | """ 308 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 309 | if pretrained: 310 | old_dict = model_zoo.load_url(model_urls['resnet152']) 311 | model_dict = model.state_dict() 312 | old_dict = {k: v for k,v in old_dict.items() if (k in model_dict)} 313 | model_dict.update(old_dict) 314 | model.load_state_dict(model_dict) 315 | print('%s loaded.'%model_urls['resnet152']) 316 | return model 317 | -------------------------------------------------------------------------------- /lib/net/backbone/resnet38d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from utils.registry import BACKBONES 6 | 7 | model_url='/scratch/kl3141/seam/SEAM-master/model_weight/ilsvrc-cls_rna-a1_cls1000_ep-0001.params' 8 | bn_mom = 0.0003 9 | 10 | class ResBlock(nn.Module): 11 | def __init__(self, in_channels, mid_channels, out_channels, stride=1, first_dilation=None, dilation=1, norm_layer=nn.BatchNorm2d): 12 | super(ResBlock, self).__init__() 13 | self.norm_layer = norm_layer 14 | 15 | self.same_shape = (in_channels == out_channels and stride == 1) 16 | 17 | if first_dilation == None: first_dilation = dilation 18 | 19 | self.bn_branch2a = self.norm_layer(in_channels, momentum=bn_mom, affine=True) 20 | 21 | self.conv_branch2a = nn.Conv2d(in_channels, mid_channels, 3, stride, 22 | padding=first_dilation, dilation=first_dilation, bias=False) 23 | 24 | self.bn_branch2b1 = self.norm_layer(mid_channels, momentum=bn_mom, affine=True) 25 | 26 | self.conv_branch2b1 = nn.Conv2d(mid_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False) 27 | 28 | if not self.same_shape: 29 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 30 | 31 | def forward(self, x, get_x_bn_relu=False): 32 | 33 | branch2 = self.bn_branch2a(x) 34 | branch2 = F.relu(branch2) 35 | 36 | x_bn_relu = branch2 37 | 38 | if not self.same_shape: 39 | branch1 = self.conv_branch1(branch2) 40 | else: 41 | branch1 = x 42 | 43 | branch2 = self.conv_branch2a(branch2) 44 | branch2 = self.bn_branch2b1(branch2) 45 | branch2 = F.relu(branch2) 46 | branch2 = self.conv_branch2b1(branch2) 47 | 48 | x = branch1 + branch2 49 | 50 | if get_x_bn_relu: 51 | return x, x_bn_relu 52 | 53 | return x 54 | 55 | def __call__(self, x, get_x_bn_relu=False): 56 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 57 | 58 | class ResBlock_bot(nn.Module): 59 | def __init__(self, in_channels, out_channels, stride=1, dilation=1, dropout=0., norm_layer=nn.BatchNorm2d): 60 | super(ResBlock_bot, self).__init__() 61 | self.norm_layer = norm_layer 62 | 63 | self.same_shape = (in_channels == out_channels and stride == 1) 64 | 65 | self.bn_branch2a = self.norm_layer(in_channels, momentum=bn_mom, 
affine=True) 66 | self.conv_branch2a = nn.Conv2d(in_channels, out_channels//4, 1, stride, bias=False) 67 | 68 | self.bn_branch2b1 = self.norm_layer(out_channels//4, momentum=bn_mom, affine=True) 69 | self.dropout_2b1 = torch.nn.Dropout2d(dropout) 70 | self.conv_branch2b1 = nn.Conv2d(out_channels//4, out_channels//2, 3, padding=dilation, dilation=dilation, bias=False) 71 | 72 | self.bn_branch2b2 = self.norm_layer(out_channels//2, momentum=bn_mom, affine=True) 73 | self.dropout_2b2 = torch.nn.Dropout2d(dropout) 74 | self.conv_branch2b2 = nn.Conv2d(out_channels//2, out_channels, 1, bias=False) 75 | 76 | if not self.same_shape: 77 | self.conv_branch1 = nn.Conv2d(in_channels, out_channels, 1, stride, bias=False) 78 | 79 | def forward(self, x, get_x_bn_relu=False): 80 | 81 | branch2 = self.bn_branch2a(x) 82 | branch2 = F.relu(branch2) 83 | x_bn_relu = branch2 84 | 85 | branch1 = self.conv_branch1(branch2) 86 | 87 | branch2 = self.conv_branch2a(branch2) 88 | 89 | branch2 = self.bn_branch2b1(branch2) 90 | branch2 = F.relu(branch2) 91 | branch2 = self.dropout_2b1(branch2) 92 | branch2 = self.conv_branch2b1(branch2) 93 | 94 | branch2 = self.bn_branch2b2(branch2) 95 | branch2 = F.relu(branch2) 96 | branch2 = self.dropout_2b2(branch2) 97 | branch2 = self.conv_branch2b2(branch2) 98 | 99 | x = branch1 + branch2 100 | 101 | if get_x_bn_relu: 102 | return x, x_bn_relu 103 | 104 | return x 105 | 106 | def __call__(self, x, get_x_bn_relu=False): 107 | return self.forward(x, get_x_bn_relu=get_x_bn_relu) 108 | 109 | class Normalize(): 110 | def __init__(self, mean = (0.485, 0.456, 0.406), std = (0.229, 0.224, 0.225)): 111 | 112 | self.mean = mean 113 | self.std = std 114 | 115 | def __call__(self, img): 116 | imgarr = np.asarray(img) 117 | proc_img = np.empty_like(imgarr, np.float32) 118 | 119 | proc_img[..., 0] = (imgarr[..., 0] / 255. - self.mean[0]) / self.std[0] 120 | proc_img[..., 1] = (imgarr[..., 1] / 255. - self.mean[1]) / self.std[1] 121 | proc_img[..., 2] = (imgarr[..., 2] / 255. 
- self.mean[2]) / self.std[2] 122 | 123 | return proc_img 124 | 125 | class Net(nn.Module): 126 | def __init__(self, norm_layer=nn.BatchNorm2d): 127 | super(Net, self).__init__() 128 | self.norm_layer = norm_layer 129 | 130 | self.conv1a = nn.Conv2d(3, 64, 3, padding=1, bias=False) 131 | 132 | self.b2 = ResBlock(64, 128, 128, stride=2, norm_layer=self.norm_layer) 133 | self.b2_1 = ResBlock(128, 128, 128, norm_layer=self.norm_layer) 134 | self.b2_2 = ResBlock(128, 128, 128, norm_layer=self.norm_layer) 135 | 136 | self.b3 = ResBlock(128, 256, 256, stride=2, norm_layer=self.norm_layer) 137 | self.b3_1 = ResBlock(256, 256, 256, norm_layer=self.norm_layer) 138 | self.b3_2 = ResBlock(256, 256, 256, norm_layer=self.norm_layer) 139 | 140 | self.b4 = ResBlock(256, 512, 512, stride=2, norm_layer=self.norm_layer) 141 | self.b4_1 = ResBlock(512, 512, 512, norm_layer=self.norm_layer) 142 | self.b4_2 = ResBlock(512, 512, 512, norm_layer=self.norm_layer) 143 | self.b4_3 = ResBlock(512, 512, 512, norm_layer=self.norm_layer) 144 | self.b4_4 = ResBlock(512, 512, 512, norm_layer=self.norm_layer) 145 | self.b4_5 = ResBlock(512, 512, 512, norm_layer=self.norm_layer) 146 | 147 | self.b5 = ResBlock(512, 512, 1024, stride=1, first_dilation=1, dilation=2, norm_layer=self.norm_layer) 148 | self.b5_1 = ResBlock(1024, 512, 1024, dilation=2, norm_layer=self.norm_layer) 149 | self.b5_2 = ResBlock(1024, 512, 1024, dilation=2, norm_layer=self.norm_layer) 150 | 151 | self.b6 = ResBlock_bot(1024, 2048, stride=1, dilation=4, dropout=0.3, norm_layer=self.norm_layer) 152 | 153 | self.b7 = ResBlock_bot(2048, 4096, dilation=4, dropout=0.5, norm_layer=self.norm_layer) 154 | 155 | self.bn7 = self.norm_layer(4096, momentum=bn_mom, affine=True) 156 | 157 | self.not_training = [self.conv1a] 158 | 159 | self.normalize = Normalize() 160 | self.OUTPUT_DIM = 4096 161 | 162 | def forward(self, x): 163 | 164 | x = self.conv1a(x) 165 | 166 | x = self.b2(x) 167 | x = self.b2_1(x) 168 | x = self.b2_2(x) 169 | 170 | x = self.b3(x) 171 | x = self.b3_1(x) 172 | x = self.b3_2(x) 173 | 174 | x = self.b4(x) 175 | x = self.b4_1(x) 176 | x = self.b4_2(x) 177 | x = self.b4_3(x) 178 | x = self.b4_4(x) 179 | x = self.b4_5(x) 180 | 181 | x, conv4 = self.b5(x, get_x_bn_relu=True) 182 | x = self.b5_1(x) 183 | x = self.b5_2(x) 184 | 185 | x, conv5 = self.b6(x, get_x_bn_relu=True) 186 | 187 | x = self.b7(x) 188 | conv6 = F.relu(self.bn7(x)) 189 | 190 | return [conv4, conv5, conv6] 191 | 192 | def train(self, mode=True): 193 | 194 | super().train(mode) 195 | 196 | for layer in self.not_training: 197 | 198 | if isinstance(layer, torch.nn.Conv2d): 199 | layer.weight.requires_grad = False 200 | 201 | elif isinstance(layer, torch.nn.Module): 202 | for c in layer.children(): 203 | c.weight.requires_grad = False 204 | if c.bias is not None: 205 | c.bias.requires_grad = False 206 | 207 | for layer in self.modules(): 208 | 209 | if isinstance(layer, self.norm_layer): 210 | layer.eval() 211 | layer.bias.requires_grad = False 212 | layer.weight.requires_grad = False 213 | 214 | return 215 | 216 | def convert_mxnet_to_torch(filename): 217 | import mxnet 218 | 219 | save_dict = mxnet.nd.load(filename) 220 | 221 | renamed_dict = dict() 222 | 223 | bn_param_mx_pt = {'beta': 'bias', 'gamma': 'weight', 'mean': 'running_mean', 'var': 'running_var'} 224 | 225 | for k, v in save_dict.items(): 226 | 227 | v = torch.from_numpy(v.asnumpy()) 228 | toks = k.split('_') 229 | 230 | if 'conv1a' in toks[0]: 231 | renamed_dict['conv1a.weight'] = v 232 | 233 | elif 'linear1000' in 
toks[0]: 234 | pass 235 | 236 | elif 'branch' in toks[1]: 237 | 238 | pt_name = [] 239 | 240 | if toks[0][-1] != 'a': 241 | pt_name.append('b' + toks[0][-3] + '_' + toks[0][-1]) 242 | else: 243 | pt_name.append('b' + toks[0][-2]) 244 | 245 | if 'res' in toks[0]: 246 | layer_type = 'conv' 247 | last_name = 'weight' 248 | 249 | else: # 'bn' in toks[0]: 250 | layer_type = 'bn' 251 | last_name = bn_param_mx_pt[toks[-1]] 252 | 253 | pt_name.append(layer_type + '_' + toks[1]) 254 | 255 | pt_name.append(last_name) 256 | 257 | torch_name = '.'.join(pt_name) 258 | renamed_dict[torch_name] = v 259 | 260 | else: 261 | last_name = bn_param_mx_pt[toks[-1]] 262 | renamed_dict['bn7.' + last_name] = v 263 | 264 | return renamed_dict 265 | 266 | @BACKBONES.register_module 267 | def resnet38(pretrained=False, norm_layer=nn.BatchNorm2d, **kwargs): 268 | model = Net(norm_layer) 269 | if pretrained: 270 | weight_dict = convert_mxnet_to_torch(model_url) 271 | model.load_state_dict(weight_dict,strict=False) 272 | return model 273 | -------------------------------------------------------------------------------- /lib/net/backbone/xception.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch) 3 | @author: tstandley 4 | Adapted by cadene 5 | Creates an Xception Model as defined in: 6 | Francois Chollet 7 | Xception: Deep Learning with Depthwise Separable Convolutions 8 | https://arxiv.org/pdf/1610.02357.pdf 9 | This weights ported from the Keras implementation. Achieves the following performance on the validation set: 10 | Loss:0.9173 Prec@1:78.892 Prec@5:94.292 11 | REMEMBER to set your image size to 3x299x299 for both test and validation 12 | normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], 13 | std=[0.5, 0.5, 0.5]) 14 | The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 15 | """ 16 | import math 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | import torch.utils.model_zoo as model_zoo 21 | from torch.nn import init 22 | from net.sync_batchnorm import SynchronizedBatchNorm2d 23 | from utils.registry import BACKBONES 24 | 25 | bn_mom = 0.1 26 | __all__ = ['xception'] 27 | 28 | model_urls = { 29 | 'xception': '/home/wangyude/.cache/torch/checkpoints/xception_pytorch_imagenet.pth'#'http://data.lip6.fr/cadene/pretrainedmodels/xception-b5690688.pth' 30 | } 31 | 32 | class SeparableConv2d(nn.Module): 33 | def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False,activate_first=True,inplace=True,norm_layer=nn.BatchNorm2d): 34 | super(SeparableConv2d,self).__init__() 35 | self.norm_layer = norm_layer 36 | self.relu0 = nn.ReLU(inplace=inplace) 37 | self.depthwise = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias) 38 | self.bn1 = self.norm_layer(in_channels, momentum=bn_mom) 39 | self.relu1 = nn.ReLU(inplace=True) 40 | self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias) 41 | self.bn2 = self.norm_layer(out_channels, momentum=bn_mom) 42 | self.relu2 = nn.ReLU(inplace=True) 43 | self.activate_first = activate_first 44 | def forward(self,x): 45 | if self.activate_first: 46 | x = self.relu0(x) 47 | x = self.depthwise(x) 48 | x = self.bn1(x) 49 | if not self.activate_first: 50 | x = self.relu1(x) 51 | x = self.pointwise(x) 52 | x = self.bn2(x) 53 | if not self.activate_first: 54 | x = 
self.relu2(x) 55 | return x 56 | 57 | 58 | class Block(nn.Module): 59 | def __init__(self,in_filters,out_filters,strides=1,atrous=None,grow_first=True,activate_first=True,inplace=True,norm_layer=nn.BatchNorm2d): 60 | super(Block, self).__init__() 61 | self.norm_layer = norm_layer 62 | if atrous == None: 63 | atrous = [1]*3 64 | elif isinstance(atrous, int): 65 | atrous_list = [atrous]*3 66 | atrous = atrous_list 67 | idx = 0 68 | self.head_relu = True 69 | if out_filters != in_filters or strides!=1: 70 | self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False) 71 | self.skipbn = self.norm_layer(out_filters, momentum=bn_mom) 72 | self.head_relu = False 73 | else: 74 | self.skip=None 75 | 76 | self.hook_layer = None 77 | if grow_first: 78 | filters = out_filters 79 | else: 80 | filters = in_filters 81 | self.sepconv1 = SeparableConv2d(in_filters,filters,3,stride=1,padding=1*atrous[0],dilation=atrous[0],bias=False,activate_first=activate_first,inplace=self.head_relu,norm_layer=self.norm_layer) 82 | self.sepconv2 = SeparableConv2d(filters,out_filters,3,stride=1,padding=1*atrous[1],dilation=atrous[1],bias=False,activate_first=activate_first,norm_layer=self.norm_layer) 83 | self.sepconv3 = SeparableConv2d(out_filters,out_filters,3,stride=strides,padding=1*atrous[2],dilation=atrous[2],bias=False,activate_first=activate_first,inplace=inplace,norm_layer=self.norm_layer) 84 | 85 | def forward(self,inp): 86 | 87 | if self.skip is not None: 88 | skip = self.skip(inp) 89 | skip = self.skipbn(skip) 90 | else: 91 | skip = inp 92 | 93 | x = self.sepconv1(inp) 94 | x = self.sepconv2(x) 95 | self.hook_layer = x 96 | x = self.sepconv3(x) 97 | 98 | x+=skip 99 | return x 100 | 101 | 102 | class Xception(nn.Module): 103 | """ 104 | Xception optimized for the ImageNet dataset, as specified in 105 | https://arxiv.org/pdf/1610.02357.pdf 106 | """ 107 | def __init__(self, os, norm_layer=nn.BatchNorm2d): 108 | """ Constructor 109 | Args: 110 | num_classes: number of classes 111 | """ 112 | super(Xception, self).__init__() 113 | self.norm_layer = norm_layer 114 | 115 | stride_list = None 116 | if os == 8: 117 | stride_list = [2,1,1] 118 | elif os == 16: 119 | stride_list = [2,2,1] 120 | else: 121 | raise ValueError('xception.py: output stride=%d is not supported.'%os) 122 | self.conv1 = nn.Conv2d(3, 32, 3, 2, 1, bias=False) 123 | self.bn1 = self.norm_layer(32, momentum=bn_mom) 124 | self.relu = nn.ReLU(inplace=True) 125 | 126 | self.conv2 = nn.Conv2d(32,64,3,1,1,bias=False) 127 | self.bn2 = self.norm_layer(64, momentum=bn_mom) 128 | #do relu here 129 | 130 | self.block1=Block(64,128,2,norm_layer=self.norm_layer) 131 | self.block2=Block(128,256,stride_list[0],inplace=False,norm_layer=self.norm_layer) 132 | self.block3=Block(256,728,stride_list[1],norm_layer=self.norm_layer) 133 | 134 | rate = 16//os 135 | self.block4=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 136 | self.block5=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 137 | self.block6=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 138 | self.block7=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 139 | 140 | self.block8=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 141 | self.block9=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 142 | self.block10=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 143 | self.block11=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 144 | 145 | self.block12=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 146 | 
self.block13=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 147 | self.block14=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 148 | self.block15=Block(728,728,1,atrous=rate,norm_layer=self.norm_layer) 149 | 150 | self.block16=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer) 151 | self.block17=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer) 152 | self.block18=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer) 153 | self.block19=Block(728,728,1,atrous=[1*rate,1*rate,1*rate],norm_layer=self.norm_layer) 154 | 155 | self.block20=Block(728,1024,stride_list[2],atrous=rate,grow_first=False,norm_layer=self.norm_layer) 156 | #self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False) 157 | 158 | self.conv3 = SeparableConv2d(1024,1536,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer) 159 | # self.bn3 = SynchronizedBatchNorm2d(1536, momentum=bn_mom) 160 | 161 | self.conv4 = SeparableConv2d(1536,1536,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer) 162 | # self.bn4 = SynchronizedBatchNorm2d(1536, momentum=bn_mom) 163 | 164 | #do relu here 165 | self.conv5 = SeparableConv2d(1536,2048,3,1,1*rate,dilation=rate,activate_first=False,norm_layer=self.norm_layer) 166 | # self.bn5 = SynchronizedBatchNorm2d(2048, momentum=bn_mom) 167 | self.OUTPUT_DIM = 2048 168 | self.MIDDLE_DIM = 256 169 | 170 | #------- init weights -------- 171 | for m in self.modules(): 172 | if isinstance(m, nn.Conv2d): 173 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 174 | m.weight.data.normal_(0, math.sqrt(2. / n)) 175 | elif isinstance(m, self.norm_layer): 176 | m.weight.data.fill_(1) 177 | m.bias.data.zero_() 178 | #----------------------------- 179 | 180 | def forward(self, input): 181 | layers = [] 182 | x = self.conv1(input) 183 | x = self.bn1(x) 184 | x = self.relu(x) 185 | #self.layers.append(x) 186 | x = self.conv2(x) 187 | x = self.bn2(x) 188 | x = self.relu(x) 189 | 190 | x = self.block1(x) 191 | x = self.block2(x) 192 | l1 = self.block2.hook_layer 193 | x = self.block3(x) 194 | l2 = self.block3.hook_layer 195 | x = self.block4(x) 196 | x = self.block5(x) 197 | x = self.block6(x) 198 | x = self.block7(x) 199 | x = self.block8(x) 200 | x = self.block9(x) 201 | x = self.block10(x) 202 | x = self.block11(x) 203 | x = self.block12(x) 204 | x = self.block13(x) 205 | x = self.block14(x) 206 | x = self.block15(x) 207 | x = self.block16(x) 208 | x = self.block17(x) 209 | x = self.block18(x) 210 | x = self.block19(x) 211 | x = self.block20(x) 212 | l3 = self.block20.hook_layer 213 | 214 | x = self.conv3(x) 215 | # x = self.bn3(x) 216 | # x = self.relu(x) 217 | 218 | x = self.conv4(x) 219 | # x = self.bn4(x) 220 | # x = self.relu(x) 221 | 222 | l4 = self.conv5(x) 223 | # x = self.bn5(x) 224 | # x = self.relu(x) 225 | 226 | #return layers 227 | return [l1,l2,l3,l4] 228 | 229 | @BACKBONES.register_module 230 | def xception(pretrained=True, os=8, **kwargs): 231 | model = Xception(os=os) 232 | if pretrained: 233 | old_dict = torch.load(model_urls['xception']) 234 | # old_dict = model_zoo.load_url(model_urls['xception']) 235 | # for name, weights in old_dict.items(): 236 | # if 'pointwise' in name: 237 | # old_dict[name] = weights.unsqueeze(-1).unsqueeze(-1) 238 | model_dict = model.state_dict() 239 | old_dict = {k: v for k,v in old_dict.items() if ('itr' not in k and 'tmp' not in k and 'track' not in k)} 240 | model_dict.update(old_dict) 241 | 242 | 
model.load_state_dict(model_dict) 243 | 244 | return model 245 | -------------------------------------------------------------------------------- /lib/net/deeplabv1_wo_interp.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch.nn import init 10 | from net.backbone import build_backbone 11 | from utils.registry import NETS 12 | 13 | @NETS.register_module 14 | class deeplabv1_wo_interp(nn.Module): 15 | def __init__(self, cfg, batchnorm=nn.BatchNorm2d, **kwargs): 16 | super(deeplabv1_wo_interp, self).__init__() 17 | self.cfg = cfg 18 | self.batchnorm = batchnorm 19 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, os=self.cfg.MODEL_OUTPUT_STRIDE) 20 | self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=cfg.MODEL_BACKBONE_PRETRAIN, norm_layer=self.batchnorm, **kwargs) 21 | self.conv_fov = nn.Conv2d(self.backbone.OUTPUT_DIM, 512, 3, 1, padding=12, dilation=12, bias=False) 22 | self.bn_fov = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True) 23 | self.conv_fov2 = nn.Conv2d(512, 512, 1, 1, padding=0, bias=False) 24 | self.bn_fov2 = batchnorm(512, momentum=cfg.TRAIN_BN_MOM, affine=True) 25 | self.dropout1 = nn.Dropout(0.5) 26 | self.cls_conv = nn.Conv2d(512, cfg.MODEL_NUM_CLASSES, 1, 1, padding=0) 27 | self.__initial__() 28 | self.not_training = []#[self.backbone.conv1a, self.backbone.b2, self.backbone.b2_1, self.backbone.b2_2] 29 | #self.from_scratch_layers = [self.cls_conv] 30 | self.from_scratch_layers = [self.conv_fov, self.conv_fov2, self.cls_conv] 31 | 32 | def __initial__(self): 33 | for m in self.modules(): 34 | if m not in self.backbone.modules(): 35 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d): 36 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 37 | elif isinstance(m, self.batchnorm): 38 | nn.init.constant_(m.weight, 1) 39 | nn.init.constant_(m.bias, 0) 40 | #self.backbone = build_backbone(self.cfg.MODEL_BACKBONE, pretrained=self.cfg.MODEL_BACKBONE_PRETRAIN) 41 | 42 | def forward(self, x): 43 | n,c,h,w = x.size() 44 | x_bottom = self.backbone(x)[-1] 45 | feature = self.conv_fov(x_bottom) 46 | feature = self.bn_fov(feature) 47 | feature = F.relu(feature, inplace=True) 48 | feature = self.conv_fov2(feature) 49 | feature = self.bn_fov2(feature) 50 | feature = F.relu(feature, inplace=True) 51 | feature = self.dropout1(feature) 52 | result = self.cls_conv(feature) 53 | # result = F.interpolate(result,(h,w),mode='bilinear', align_corners=True) 54 | return result 55 | 56 | def get_parameter_groups(self): 57 | groups = ([], [], [], []) 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | if m.weight.requires_grad: 61 | if m in self.from_scratch_layers: 62 | groups[2].append(m.weight) 63 | else: 64 | groups[0].append(m.weight) 65 | 66 | if m.bias is not None and m.bias.requires_grad: 67 | 68 | if m in self.from_scratch_layers: 69 | groups[3].append(m.bias) 70 | else: 71 | groups[1].append(m.bias) 72 | return groups -------------------------------------------------------------------------------- /lib/net/generateNet.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | #from 
net.deeplabv3plus import deeplabv3plus 6 | #from net.deeplabv3 import deeplabv3, deeplabv3_noise, deeplabv3_feature, deeplabv3_glore 7 | #from net.deeplabv2 import deeplabv2, deeplabv2_caffe 8 | #from net.deeplabv1 import deeplabv1, deeplabv1_caffe 9 | #from net.clsnet import ClsNet 10 | #from net.fcn import FCN 11 | #from net.DFANet import DFANet 12 | from utils.registry import NETS 13 | 14 | def generate_net(cfg, **kwargs): 15 | net = NETS.get(cfg.MODEL_NAME)(cfg, **kwargs) 16 | return net 17 | #def generate_net(cfg): 18 | # if cfg.MODEL_NAME == 'deeplabv3plus' or cfg.MODEL_NAME == 'deeplabv3+': 19 | # return deeplabv3plus(cfg) 20 | # elif cfg.MODEL_NAME == 'deeplabv3': 21 | # return deeplabv3(cfg) 22 | # elif cfg.MODEL_NAME == 'deeplabv2': 23 | # return deeplabv2(cfg) 24 | # elif cfg.MODEL_NAME == 'deeplabv1': 25 | # return deeplabv1(cfg) 26 | # elif cfg.MODEL_NAME == 'deeplabv1_caffe': 27 | # return deeplabv1_caffe(cfg) 28 | # elif cfg.MODEL_NAME == 'deeplabv2_caffe': 29 | # return deeplabv2_caffe(cfg) 30 | # elif cfg.MODEL_NAME == 'clsnet' or cfg.MODEL_NAME == 'ClsNet': 31 | # return ClsNet(cfg) 32 | # elif cfg.MODEL_NAME == 'fcn' or cfg.MODEL_NAME == 'FCN': 33 | # return FCN(cfg) 34 | # elif cfg.MODEL_NAME == 'DFANet' or cfg.MODEL_NAME == 'dfanet': 35 | # return DFANet(cfg) 36 | # else: 37 | # raise ValueError('generateNet.py: network %s is not support yet'%cfg.MODEL_NAME) 38 | -------------------------------------------------------------------------------- /lib/net/operators/ASPP.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | import torch.nn.functional as F 9 | from net.sync_batchnorm import SynchronizedBatchNorm2d 10 | 11 | class ASPP(nn.Module): 12 | 13 | def __init__(self, dim_in, dim_out, rate=[1,6,12,18], bn_mom=0.1, has_global=True, batchnorm=SynchronizedBatchNorm2d): 14 | super(ASPP, self).__init__() 15 | self.dim_in = dim_in 16 | self.dim_out = dim_out 17 | self.has_global = has_global 18 | if rate[0] == 0: 19 | self.branch1 = nn.Sequential( 20 | nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=1,bias=False), 21 | batchnorm(dim_out, momentum=bn_mom, affine=True), 22 | nn.ReLU(inplace=True), 23 | ) 24 | else: 25 | self.branch1 = nn.Sequential( 26 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[0], dilation=rate[0],bias=False), 27 | batchnorm(dim_out, momentum=bn_mom, affine=True), 28 | nn.ReLU(inplace=True), 29 | ) 30 | self.branch2 = nn.Sequential( 31 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[1], dilation=rate[1],bias=False), 32 | batchnorm(dim_out, momentum=bn_mom, affine=True), 33 | nn.ReLU(inplace=True), 34 | ) 35 | self.branch3 = nn.Sequential( 36 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[2], dilation=rate[2],bias=False), 37 | batchnorm(dim_out, momentum=bn_mom, affine=True), 38 | nn.ReLU(inplace=True), 39 | ) 40 | self.branch4 = nn.Sequential( 41 | nn.Conv2d(dim_in, dim_out, 3, 1, padding=rate[3], dilation=rate[3],bias=False), 42 | batchnorm(dim_out, momentum=bn_mom, affine=True), 43 | nn.ReLU(inplace=True), 44 | ) 45 | if self.has_global: 46 | self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0,bias=False) 47 | self.branch5_bn = batchnorm(dim_out, momentum=bn_mom, affine=True) 48 | self.branch5_relu = nn.ReLU(inplace=True) 49 | self.conv_cat = nn.Sequential( 50 | nn.Conv2d(dim_out*5, dim_out, 1, 1, padding=0,bias=False), 
51 | batchnorm(dim_out, momentum=bn_mom, affine=True), 52 | nn.ReLU(inplace=True), 53 | nn.Dropout(0.5) 54 | ) 55 | else: 56 | self.conv_cat = nn.Sequential( 57 | nn.Conv2d(dim_out*4, dim_out, 1, 1, padding=0), 58 | batchnorm(dim_out, momentum=bn_mom, affine=True), 59 | nn.ReLU(inplace=True), 60 | nn.Dropout(0.5) 61 | ) 62 | def forward(self, x): 63 | result = None 64 | [b,c,row,col] = x.size() 65 | conv1x1 = self.branch1(x) 66 | conv3x3_1 = self.branch2(x) 67 | conv3x3_2 = self.branch3(x) 68 | conv3x3_3 = self.branch4(x) 69 | if self.has_global: 70 | global_feature = F.adaptive_avg_pool2d(x, (1,1)) 71 | global_feature = self.branch5_conv(global_feature) 72 | global_feature = self.branch5_bn(global_feature) 73 | global_feature = self.branch5_relu(global_feature) 74 | global_feature = F.interpolate(global_feature, (row,col), None, 'bilinear', align_corners=True) 75 | 76 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1) 77 | else: 78 | feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3], dim=1) 79 | result = self.conv_cat(feature_cat) 80 | 81 | return result 82 | -------------------------------------------------------------------------------- /lib/net/operators/PPM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class PPM(nn.Module): 6 | """ 7 | Reference: 8 | Zhao, Hengshuang, et al. *"Pyramid scene parsing network."* 9 | """ 10 | def __init__(self, features, out_features=512, sizes=(1, 2, 3, 6), norm_layer=nn.BatchNorm2d): 11 | super(PPM, self).__init__() 12 | 13 | self.stages = [] 14 | self.stages = nn.ModuleList([self._make_stage(features, out_features, size, norm_layer) for size in sizes]) 15 | self.bottleneck = nn.Sequential( 16 | nn.Conv2d(features+len(sizes)*out_features, out_features, kernel_size=1, padding=0, dilation=1, bias=False), 17 | norm_layer(out_features), 18 | nn.ReLU(), 19 | nn.Dropout2d(0.1) 20 | ) 21 | 22 | def _make_stage(self, features, out_features, size, norm_layer): 23 | prior = nn.AdaptiveAvgPool2d(output_size=(size, size)) 24 | conv = nn.Conv2d(features, out_features, kernel_size=1, bias=False) 25 | bn = norm_layer(out_features) 26 | return nn.Sequential(prior, conv, bn) 27 | 28 | def forward(self, feats): 29 | h, w = feats.size(2), feats.size(3) 30 | priors = [F.upsample(input=stage(feats), size=(h, w), mode='bilinear', align_corners=True) for stage in self.stages] + [feats] 31 | bottle = self.bottleneck(torch.cat(priors, 1)) 32 | return bottle 33 | -------------------------------------------------------------------------------- /lib/net/operators/__init__.py: -------------------------------------------------------------------------------- 1 | from .ASPP import * 2 | from .PPM import * 3 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
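#
# Usage sketch (an added note, not part of the upstream file; `build_model` is a
# hypothetical helper and at least two visible CUDA devices are assumed):
#
#   from net.sync_batchnorm import SynchronizedBatchNorm2d, DataParallelWithCallback
#   model = build_model(batchnorm=SynchronizedBatchNorm2d)  # any nn.Module built with sync BN layers
#   model = DataParallelWithCallback(model.cuda(), device_ids=[0, 1])
#   # each forward pass now reduces BN statistics across both replicas
#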
10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register a slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass.
99 | The messages are first collected from each device (including the master device), and then 100 | a callback is invoked to compute the message to be sent back to each device 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master wants to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belong to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function.
56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
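#
# Added note (not in the upstream file): each replica contributes a per-device
# (sum, sum-of-squares, count) message; the master reduces them into global
# statistics, roughly
#
#   mean    = total_sum / total_count
#   var     = total_ssum / total_count - mean ** 2   # biased variance
#   inv_std = (var + eps) ** -0.5
#
# and broadcasts (mean, inv_std) back to every replica. This is what
# `_data_parallel_master` and `_compute_mean_std` below implement.
#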
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimension""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dimensions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using the same "device order" makes the ReduceAdd operation faster.
94 | # Thanks to: Tete Xiao (http://tetexiao.com/) 95 | intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device()) 96 | 97 | to_reduce = [i[1][:2] for i in intermediates] 98 | to_reduce = [j for i in to_reduce for j in i] # flatten 99 | target_gpus = [i[1].sum.get_device() for i in intermediates] 100 | 101 | sum_size = sum([i[1].sum_size for i in intermediates]) 102 | sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce) 103 | mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size) 104 | 105 | broadcasted = Broadcast.apply(target_gpus, mean, inv_std) 106 | 107 | outputs = [] 108 | for i, rec in enumerate(intermediates): 109 | outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2]))) 110 | 111 | return outputs 112 | 113 | def _compute_mean_std(self, sum_, ssum, size): 114 | """Compute the mean and standard-deviation from the sum and square-sum. This method 115 | also maintains the moving average on the master device.""" 116 | assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.' 117 | mean = sum_ / size 118 | sumvar = ssum - sum_ * mean 119 | unbias_var = sumvar / (size - 1) 120 | bias_var = sumvar / size 121 | 122 | self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data 123 | self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data 124 | 125 | return mean, bias_var.clamp(self.eps) ** -0.5 126 | 127 | 128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm): 129 | r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a 130 | mini-batch. 131 | 132 | .. math:: 133 | 134 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 135 | 136 | This module differs from the built-in PyTorch BatchNorm1d as the mean and 137 | standard-deviation are reduced across all devices during training. 138 | 139 | For example, when one uses `nn.DataParallel` to wrap the network during 140 | training, PyTorch's implementation normalizes the tensor on each device using 141 | the statistics only on that device, which accelerates the computation and 142 | is also easy to implement, but the statistics might be inaccurate. 143 | Instead, in this synchronized version, the statistics will be computed 144 | over all training samples distributed on multiple devices. 145 | 146 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 147 | as the built-in PyTorch implementation. 148 | 149 | The mean and standard-deviation are calculated per-dimension over 150 | the mini-batches and gamma and beta are learnable parameter vectors 151 | of size C (where C is the input size). 152 | 153 | During training, this layer keeps a running estimate of its computed mean 154 | and variance. The running sum is kept with a default momentum of 0.1. 155 | 156 | During evaluation, this running mean/variance is used for normalization. 157 | 158 | Because the BatchNorm is done over the `C` dimension, computing statistics 159 | on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm. 160 | 161 | Args: 162 | num_features: num_features from an expected input of size 163 | `batch_size x num_features [x width]` 164 | eps: a value added to the denominator for numerical stability. 165 | Default: 1e-5 166 | momentum: the value used for the running_mean and running_var 167 | computation. Default: 0.1 168 | affine: a boolean value that when set to ``True``, gives the layer learnable 169 | affine parameters.
Default: ``True`` 170 | 171 | Shape: 172 | - Input: :math:`(N, C)` or :math:`(N, C, L)` 173 | - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input) 174 | 175 | Examples: 176 | >>> # With Learnable Parameters 177 | >>> m = SynchronizedBatchNorm1d(100) 178 | >>> # Without Learnable Parameters 179 | >>> m = SynchronizedBatchNorm1d(100, affine=False) 180 | >>> input = torch.autograd.Variable(torch.randn(20, 100)) 181 | >>> output = m(input) 182 | """ 183 | 184 | def _check_input_dim(self, input): 185 | if input.dim() != 2 and input.dim() != 3: 186 | raise ValueError('expected 2D or 3D input (got {}D input)' 187 | .format(input.dim())) 188 | super(SynchronizedBatchNorm1d, self)._check_input_dim(input) 189 | 190 | 191 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm): 192 | r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch 193 | of 3d inputs 194 | 195 | .. math:: 196 | 197 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 198 | 199 | This module differs from the built-in PyTorch BatchNorm2d as the mean and 200 | standard-deviation are reduced across all devices during training. 201 | 202 | For example, when one uses `nn.DataParallel` to wrap the network during 203 | training, PyTorch's implementation normalizes the tensor on each device using 204 | the statistics only on that device, which accelerates the computation and 205 | is also easy to implement, but the statistics might be inaccurate. 206 | Instead, in this synchronized version, the statistics will be computed 207 | over all training samples distributed on multiple devices. 208 | 209 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 210 | as the built-in PyTorch implementation. 211 | 212 | The mean and standard-deviation are calculated per-dimension over 213 | the mini-batches and gamma and beta are learnable parameter vectors 214 | of size C (where C is the input size). 215 | 216 | During training, this layer keeps a running estimate of its computed mean 217 | and variance. The running sum is kept with a default momentum of 0.1. 218 | 219 | During evaluation, this running mean/variance is used for normalization. 220 | 221 | Because the BatchNorm is done over the `C` dimension, computing statistics 222 | on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm. 223 | 224 | Args: 225 | num_features: num_features from an expected input of 226 | size batch_size x num_features x height x width 227 | eps: a value added to the denominator for numerical stability. 228 | Default: 1e-5 229 | momentum: the value used for the running_mean and running_var 230 | computation. Default: 0.1 231 | affine: a boolean value that when set to ``True``, gives the layer learnable 232 | affine parameters.
Default: ``True`` 233 | 234 | Shape: 235 | - Input: :math:`(N, C, H, W)` 236 | - Output: :math:`(N, C, H, W)` (same shape as input) 237 | 238 | Examples: 239 | >>> # With Learnable Parameters 240 | >>> m = SynchronizedBatchNorm2d(100) 241 | >>> # Without Learnable Parameters 242 | >>> m = SynchronizedBatchNorm2d(100, affine=False) 243 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45)) 244 | >>> output = m(input) 245 | """ 246 | 247 | def _check_input_dim(self, input): 248 | if input.dim() != 4: 249 | raise ValueError('expected 4D input (got {}D input)' 250 | .format(input.dim())) 251 | super(SynchronizedBatchNorm2d, self)._check_input_dim(input) 252 | 253 | 254 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm): 255 | r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch 256 | of 4d inputs 257 | 258 | .. math:: 259 | 260 | y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta 261 | 262 | This module differs from the built-in PyTorch BatchNorm3d as the mean and 263 | standard-deviation are reduced across all devices during training. 264 | 265 | For example, when one uses `nn.DataParallel` to wrap the network during 266 | training, PyTorch's implementation normalizes the tensor on each device using 267 | the statistics only on that device, which accelerates the computation and 268 | is also easy to implement, but the statistics might be inaccurate. 269 | Instead, in this synchronized version, the statistics will be computed 270 | over all training samples distributed on multiple devices. 271 | 272 | Note that, for the one-GPU or CPU-only case, this module behaves exactly the same 273 | as the built-in PyTorch implementation. 274 | 275 | The mean and standard-deviation are calculated per-dimension over 276 | the mini-batches and gamma and beta are learnable parameter vectors 277 | of size C (where C is the input size). 278 | 279 | During training, this layer keeps a running estimate of its computed mean 280 | and variance. The running sum is kept with a default momentum of 0.1. 281 | 282 | During evaluation, this running mean/variance is used for normalization. 283 | 284 | Because the BatchNorm is done over the `C` dimension, computing statistics 285 | on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm 286 | or Spatio-temporal BatchNorm. 287 | 288 | Args: 289 | num_features: num_features from an expected input of 290 | size batch_size x num_features x depth x height x width 291 | eps: a value added to the denominator for numerical stability. 292 | Default: 1e-5 293 | momentum: the value used for the running_mean and running_var 294 | computation. Default: 0.1 295 | affine: a boolean value that when set to ``True``, gives the layer learnable 296 | affine parameters.
Default: ``True`` 297 | 298 | Shape: 299 | - Input: :math:`(N, C, D, H, W)` 300 | - Output: :math:`(N, C, D, H, W)` (same shape as input) 301 | 302 | Examples: 303 | >>> # With Learnable Parameters 304 | >>> m = SynchronizedBatchNorm3d(100) 305 | >>> # Without Learnable Parameters 306 | >>> m = SynchronizedBatchNorm3d(100, affine=False) 307 | >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10)) 308 | >>> output = m(input) 309 | """ 310 | 311 | def _check_input_dim(self, input): 312 | if input.dim() != 5: 313 | raise ValueError('expected 5D input (got {}D input)' 314 | .format(input.dim())) 315 | super(SynchronizedBatchNorm3d, self)._check_input_dim(input) 316 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/batchnorm_reimpl.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # File : batchnorm_reimpl.py 4 | # Author : acgtyrant 5 | # Date : 11/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.init as init 14 | 15 | __all__ = ['BatchNorm2dReimpl'] 16 | 17 | 18 | class BatchNorm2dReimpl(nn.Module): 19 | """ 20 | A re-implementation of batch normalization, used for testing the numerical 21 | stability. 22 | 23 | Author: acgtyrant 24 | See also: 25 | https://github.com/vacancy/Synchronized-BatchNorm-PyTorch/issues/14 26 | """ 27 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 28 | super().__init__() 29 | 30 | self.num_features = num_features 31 | self.eps = eps 32 | self.momentum = momentum 33 | self.weight = nn.Parameter(torch.empty(num_features)) 34 | self.bias = nn.Parameter(torch.empty(num_features)) 35 | self.register_buffer('running_mean', torch.zeros(num_features)) 36 | self.register_buffer('running_var', torch.ones(num_features)) 37 | self.reset_parameters() 38 | 39 | def reset_running_stats(self): 40 | self.running_mean.zero_() 41 | self.running_var.fill_(1) 42 | 43 | def reset_parameters(self): 44 | self.reset_running_stats() 45 | init.uniform_(self.weight) 46 | init.zeros_(self.bias) 47 | 48 | def forward(self, input_): 49 | batchsize, channels, height, width = input_.size() 50 | numel = batchsize * height * width 51 | input_ = input_.permute(1, 0, 2, 3).contiguous().view(channels, numel) 52 | sum_ = input_.sum(1) 53 | sum_of_square = input_.pow(2).sum(1) 54 | mean = sum_ / numel 55 | sumvar = sum_of_square - sum_ * mean 56 | 57 | self.running_mean = ( 58 | (1 - self.momentum) * self.running_mean 59 | + self.momentum * mean.detach() 60 | ) 61 | unbias_var = sumvar / (numel - 1) 62 | self.running_var = ( 63 | (1 - self.momentum) * self.running_var 64 | + self.momentum * unbias_var.detach() 65 | ) 66 | 67 | bias_var = sumvar / numel 68 | inv_std = 1 / (bias_var + self.eps).pow(0.5) 69 | output = ( 70 | (input_ - mean.unsqueeze(1)) * inv_std.unsqueeze(1) * 71 | self.weight.unsqueeze(1) + self.bias.unsqueeze(1)) 72 | 73 | return output.view(channels, batchsize, height, width).permute(1, 0, 2, 3).contiguous() 74 | 75 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | #
Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as a one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result hasn\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During replication, as the data parallel will trigger a callback on each module, all slave devices should 60 | call `register(id)` and obtain a `SlavePipe` to communicate with the master. 61 | - During the forward pass, the master device invokes `run_master`; all messages from slave devices will be collected 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine the message to be passed 64 | back to each slave device. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def __getstate__(self): 79 | return {'master_callback': self._master_callback} 80 | 81 | def __setstate__(self, state): 82 | self.__init__(state['master_callback']) 83 | 84 | def register_slave(self, identifier): 85 | """ 86 | Register a slave device. 87 | 88 | Args: 89 | identifier: an identifier, usually the device id. 90 | 91 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 92 | 93 | """ 94 | if self._activated: 95 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 96 | self._activated = False 97 | self._registry.clear() 98 | future = FutureResult() 99 | self._registry[identifier] = _MasterRegistry(future) 100 | return SlavePipe(identifier, self._queue, future) 101 | 102 | def run_master(self, master_msg): 103 | """ 104 | Main entry for the master device in each forward pass. 105 | The messages are first collected from each device (including the master device), and then 106 | a callback is invoked to compute the message to be sent back to each device 107 | (including the master device). 108 | 109 | Args: 110 | master_msg: the message that the master wants to send to itself.
This will be placed as the first 111 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 112 | 113 | Returns: the message to be sent back to the master device. 114 | 115 | """ 116 | self._activated = True 117 | 118 | intermediates = [(0, master_msg)] 119 | for i in range(self.nr_slaves): 120 | intermediates.append(self._queue.get()) 121 | 122 | results = self._master_callback(intermediates) 123 | assert results[0][0] == 0, 'The first result should belong to the master.' 124 | 125 | for i, res in results: 126 | if i == 0: 127 | continue 128 | self._registry[i].result.put(res) 129 | 130 | for i in range(self.nr_slaves): 131 | assert self._queue.get() is True 132 | 133 | return results[0][1] 134 | 135 | @property 136 | def nr_slaves(self): 137 | return len(self._registry) 138 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all modules are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called ahead of the callback 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 55 | the original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked.
62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have customized `DataParallel` implementation. 74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | import torch 13 | 14 | 15 | class TorchTestCase(unittest.TestCase): 16 | def assertTensorClose(self, x, y): 17 | adiff = float((x - y).abs().max()) 18 | if (y == 0).all(): 19 | rdiff = 'NaN' 20 | else: 21 | rdiff = float((adiff / y).abs().max()) 22 | 23 | message = ( 24 | 'Tensor close check failed\n' 25 | 'adiff={}\n' 26 | 'rdiff={}\n' 27 | ).format(adiff, rdiff) 28 | self.assertTrue(torch.allclose(x, y), message) 29 | 30 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm1d(10, momentum=1, eps=1e-5, affine=False) # BatchNorm1d: the input below is 2D (N, C); BatchNorm2d would reject it 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- /lib/net/sync_batchnorm/unittest.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol, rtol=rtol), # pass rtol through instead of silently dropping it 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /lib/utils/DenseCRF.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pydensecrf.densecrf as dcrf 3 | from pydensecrf.utils import unary_from_softmax 4 | 5 | def dense_crf(probs, img=None, n_classes=21, n_iters=1, scale_factor=1): 6 | #probs = np.transpose(probs,(1,2,0)).copy(order='C') 7 | c,h,w = probs.shape 8 | 9 | if img is not None: 10 | assert(img.shape[1:3] == (h, w)) 11 | img = np.transpose(img,(1,2,0)).copy(order='C') 12 | 13 | #probs = probs.transpose(2, 0, 1).copy(order='C') # Need a contiguous array. 14 | 15 | d = dcrf.DenseCRF2D(w, h, n_classes) # Define DenseCRF model. 16 | 17 | unary = unary_from_softmax(probs) 18 | unary = np.ascontiguousarray(unary) 19 | d.setUnaryEnergy(unary) 20 | d.addPairwiseGaussian(sxy=3/scale_factor, compat=3) 21 | #d.addPairwiseBilateral(sxy=80/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 22 | d.addPairwiseBilateral(sxy=32/scale_factor, srgb=13, rgbim=np.copy(img), compat=10) 23 | Q = d.inference(n_iters) 24 | 25 | # U = -np.log(probs) # Unary potential. 26 | # U = U.reshape((n_classes, -1)) # Needs to be flat.
27 | # d.setUnaryEnergy(U) 28 | # d.addPairwiseGaussian(sxy=sxy_gaussian, compat=compat_gaussian, 29 | # kernel=kernel_gaussian, normalization=normalisation_gaussian) 30 | # if img is not None: 31 | # assert(img.shape[1:3] == (h, w)) 32 | # img = np.transpose(img,(1,2,0)).copy(order='C') 33 | # d.addPairwiseBilateral(sxy=sxy_bilateral, compat=compat_bilateral, 34 | # kernel=kernel_bilateral, normalization=normalisation_bilateral, 35 | # srgb=srgb_bilateral, rgbim=img) 36 | # Q = d.inference(n_iters) 37 | preds = np.array(Q, dtype=np.float32).reshape((n_classes, h, w)) 38 | #return np.expand_dims(preds, 0) 39 | return preds 40 | 41 | def pro_crf(p, img, itr): 42 | C, H, W = p.shape 43 | p_bg = 1-p 44 | for i in range(C): 45 | cat = np.stack([p[i,:,:], p_bg[i,:,:]], axis=0) # stack to (2, H, W); np.concatenate on the 2D slices would yield (2H, W) and break dense_crf 46 | crf_pro = dense_crf(cat, img.astype(np.uint8), n_classes=2, n_iters=itr) # two classes here: class i and its complement 47 | p[i,:,:] = crf_pro[0] 48 | return p 49 | -------------------------------------------------------------------------------- /lib/utils/JSD_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | 8 | def calc_jsd_single(weight, labels1_a, pred, threshold=0.8, Mask_label255_sign='no'): 9 | 10 | Mask_label255 = (labels1_a < 255).float() # do not compute the area that is irrelevant (dataaug) 11 | weight_softmax = F.softmax(weight, dim=0) 12 | 13 | criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 14 | 15 | loss = criterion(pred * weight_softmax[0], labels1_a) # * weight_softmax[0] 16 | 17 | 18 | 19 | prob = F.softmax(pred, dim=1) 20 | prob = torch.clamp(prob, 1e-7, 1) 21 | 22 | max_probs = torch.amax(prob*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True) # select according to the pred without the aug area 23 | mask = max_probs.ge(threshold).float() 24 | 25 | 26 | logp = prob.log() 27 | 28 | log_probs = torch.sum(F.kl_div(logp, prob, reduction='none') * mask, dim=1) 29 | if Mask_label255_sign == 'yes': 30 | consistency = sum(log_probs)*Mask_label255 31 | else: 32 | consistency = sum(log_probs) 33 | 34 | return torch.mean(loss), torch.mean(consistency), consistency, prob 35 | 36 | 37 | def calc_jsd_multiscale(weight, labels1_a, pred1, pred2, pred3, threshold=0.8, Mask_label255_sign='no'): 38 | 39 | Mask_label255 = (labels1_a < 255).float() # do not compute the area that is irrelevant (dataaug) b,h,w 40 | weight_softmax = F.softmax(weight, dim=0) 41 | 42 | criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 43 | criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 44 | criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 45 | 46 | loss1 = criterion1(pred1 * weight_softmax[0], labels1_a) # * weight_softmax[0] 47 | loss2 = criterion2(pred2 * weight_softmax[1], labels1_a) # * weight_softmax[1] 48 | loss3 = criterion3(pred3 * weight_softmax[2], labels1_a) # * weight_softmax[2] 49 | 50 | loss = (loss1 + loss2 + loss3) 51 | 52 | probs = [F.softmax(logits, dim=1) for i, logits in enumerate([pred1, pred2, pred3])] 53 | 54 | weighted_probs = [weight_softmax[i] * prob for i, prob in enumerate(probs)] # weight_softmax[i]* 55 | mixture_label = (torch.stack(weighted_probs)).sum(axis=0) 56 | #mixture_label = torch.clamp(mixture_label, 1e-7, 1) # b,c,h,w 57 | mixture_label = torch.clamp(mixture_label, 1e-3, 1-1e-3) # b,c,h,w 58 | 59 | # add this code block for early torch version where torch.amax is not available 60 | if
torch.__version__=="1.5.0" or torch.__version__=="1.6.0": 61 | max_probs, _ = torch.max(mixture_label*Mask_label255.unsqueeze(1), dim=-3, keepdim=True) # keep the max values, not the indices, when torch.amax is unavailable 62 | max_probs, _ = torch.max(max_probs, dim=-2, keepdim=True) 63 | max_probs, _ = torch.max(max_probs, dim=-1, keepdim=True) 64 | else: 65 | max_probs = torch.amax(mixture_label*Mask_label255.unsqueeze(1), dim=(-3, -2, -1), keepdim=True) 66 | mask = max_probs.ge(threshold).float() 67 | 68 | 69 | logp_mixture = mixture_label.log() 70 | 71 | log_probs = [torch.sum(F.kl_div(logp_mixture, prob, reduction='none') * mask, dim=1) for prob in probs] 72 | if Mask_label255_sign == 'yes': 73 | consistency = sum(log_probs)*Mask_label255 74 | else: 75 | consistency = sum(log_probs) 76 | 77 | return torch.mean(loss), torch.mean(consistency), consistency, mixture_label 78 | 79 | 80 | 81 | def calc_multiscale_backup(weight, seg_backup, pred1, pred2, pred3, mask_seg_prednan, seg_prediction, Lambda_back=0): 82 | b,_,h,w = pred1.size() 83 | seg_tempt = torch.zeros((b, 21, h, w), dtype=torch.float) 84 | 85 | seg_tempt[~mask_seg_prednan[:, :, :, :]] = F.softmax(seg_prediction, dim=1)[ 86 | ~mask_seg_prednan[:, :, :, :]] # b,c,h,w 87 | # need to mask out the irrelevant region NaN 88 | 89 | weight_softmax = F.softmax(weight, dim=0) 90 | 91 | criterion1 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 92 | criterion2 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 93 | criterion3 = nn.CrossEntropyLoss(ignore_index=255, reduction='none') 94 | 95 | # to let the loss depend on the confidence of the previous prediction on the background class seg_tempt[:,0,:,:] 96 | if Lambda_back == 0: 97 | loss1 = torch.mean(criterion1(pred1 * weight_softmax[0], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0)) 98 | loss2 = torch.mean(criterion2(pred2 * weight_softmax[1], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0)) 99 | loss3 = torch.mean(criterion3(pred3 * weight_softmax[2], seg_backup.to(0)) * seg_tempt[:, 0, :, :].to(0)) 100 | else: 101 | loss1 = torch.mean(criterion1(pred1 * weight_softmax[0], seg_backup.to(0))) * Lambda_back 102 | loss2 = torch.mean(criterion2(pred2 * weight_softmax[1], seg_backup.to(0))) * Lambda_back 103 | loss3 = torch.mean(criterion3(pred3 * weight_softmax[2], seg_backup.to(0))) * Lambda_back 104 | 105 | loss = (loss1 + loss2 + loss3) 106 | 107 | 108 | 109 | 110 | 111 | 112 | return loss -------------------------------------------------------------------------------- /lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .registry import DATASETS, BACKBONES, NETS 2 | 3 | __all__ = ['DATASETS', 'BACKBONES', 'NETS'] 4 | -------------------------------------------------------------------------------- /lib/utils/configuration.py: -------------------------------------------------------------------------------- 1 | # ---------------------------------------- 2 | # Written by Yude Wang 3 | # ---------------------------------------- 4 | import torch 5 | import os 6 | import sys 7 | import shutil 8 | 9 | class Configuration(): 10 | def __init__(self, config_dict, clear=True): 11 | self.__dict__ = config_dict 12 | self.clear = clear 13 | self.__check() 14 | 15 | def __check(self): 16 | if not torch.cuda.is_available(): 17 | raise ValueError('config.py: cuda is not available') 18 | if self.GPUS == 0: 19 | raise ValueError('config.py: the number of GPUs is 0') 20 | if self.GPUS != torch.cuda.device_count(): 21 | raise ValueError('config.py: GPU number does not match') 22 | 23 | if not
os.path.isdir(self.LOG_DIR): 24 | os.mkdir(self.LOG_DIR) 25 | # elif self.clear: 26 | # shutil.rmtree(self.LOG_DIR) 27 | # os.mkdir(self.LOG_DIR) 28 | if not os.path.isdir(self.MODEL_SAVE_DIR): 29 | # os.makedirs(self.MODEL_SAVE_DIR) 30 | os.mkdir(self.MODEL_SAVE_DIR) 31 | 32 | 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /lib/utils/eval_net_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm import tqdm 4 | import numpy as np 5 | import cv2 6 | from .imutils import img_denorm 7 | from .DenseCRF import dense_crf 8 | import pickle 9 | import os 10 | import torch.multiprocessing as mp 11 | 12 | def eval_net_multiprocess(SpawnContext, net1, net2, IoU_npl_indx, train_dataloader, eval_dataloader1, 13 | eval_dataloader2, momentum=0.3, scale_index=0, flip='no', 14 | scalefactor=1.0, CRF_post='no', tempt_save_root='.',update_all_bg_img=True,t_eval=3): 15 | net1.eval() 16 | net2.eval() 17 | if torch.cuda.device_count() > 1: 18 | 19 | 20 | seg_dict_copy = train_dataloader.dataset.seg_dict.copy() 21 | p1 = SpawnContext.Process(target = eval_net_bs_one, args=(torch.device(0), net1, IoU_npl_indx, eval_dataloader1,seg_dict_copy, momentum, scale_index, flip, scalefactor, CRF_post, tempt_save_root, 'eval_dict_tempt1.npy',update_all_bg_img,t_eval)) 22 | p2 = SpawnContext.Process(target = eval_net_bs_one, args=(torch.device(1), net2, IoU_npl_indx, eval_dataloader2,seg_dict_copy, momentum, scale_index, flip, scalefactor, CRF_post, tempt_save_root, 'eval_dict_tempt2.npy',update_all_bg_img,t_eval)) 23 | 24 | p1.start() 25 | p2.start() 26 | 27 | p1.join() 28 | p2.join() 29 | 30 | 31 | tempt = np.load(os.path.join(tempt_save_root, 'eval_dict_tempt1.npy'), allow_pickle=True) 32 | prev_pred_dict = tempt[()] 33 | 34 | tempt2 = np.load(os.path.join(tempt_save_root, 'eval_dict_tempt2.npy'), allow_pickle=True) 35 | prev_pred_dict2 = tempt2[()] 36 | 37 | prev_pred_dict.update(prev_pred_dict2) 38 | train_dataloader.dataset.prev_pred_dict = prev_pred_dict 39 | 40 | os.remove(os.path.join(tempt_save_root, 'eval_dict_tempt1.npy')) 41 | os.remove(os.path.join(tempt_save_root, 'eval_dict_tempt2.npy')) 42 | del seg_dict_copy 43 | 44 | 45 | 46 | return 47 | 48 | 49 | def eval_net_bs_one(device, net, IoU_npl_indx, eval_dataloader, seg_dict_copy, momentum=0.3, scale_index=0, flip='no', scalefactor=1.0, CRF_post='no',tempt_save_root='.', save_name='eval_dict_tempt1.npy', update_all_bg_img=False, t_eval=1.0): 50 | # net.eval() 51 | #scale_index = 2 # currently only supports this version; improve later 52 | if scale_index==0: 53 | TEST_MULTISCALE = [0.75, 1.0, 1.5] 54 | elif scale_index==1: 55 | TEST_MULTISCALE = [0.5, 1.0, 1.75] 56 | elif scale_index==2: 57 | TEST_MULTISCALE = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 58 | elif scale_index==3: 59 | TEST_MULTISCALE = [0.7, 1.0, 1.5] 60 | elif scale_index==4: 61 | TEST_MULTISCALE = [0.5, 0.75, 1.0, 1.25, 1.5] 62 | elif scale_index==5: 63 | TEST_MULTISCALE = [1.0] 64 | # print('eval_with_onebyone') 65 | prev_pred_dict = {} 66 | with tqdm(total=len(eval_dataloader)) as pbar: 67 | with torch.no_grad(): 68 | for i_batch, sample in enumerate(eval_dataloader): 69 | # print(sample['batch_idx']) 70 | # seg_labels = sample['segmentation'] 71 | 72 | seg_labels = seg_dict_copy[eval_dataloader.dataset.ori_indx_list[sample['batch_idx']]] 73 | # if they are not disjoint, we should evaluate it 74 | if
set(np.unique(seg_labels[0].cpu().numpy())).isdisjoint(set(IoU_npl_indx[1:])): 75 | 76 | if update_all_bg_img and not (set(np.unique(seg_labels[0].numpy())) - set(np.array([0, 255]))): 77 | # the pseudo label contains only background, so this image is still evaluated 78 | pass 79 | else: 80 | # skip this one 81 | continue 82 | 83 | inputs = sample['image'] 84 | n, c, h, w = inputs.size() # 1,c,h,w 85 | result_list = [] 86 | image_multiscale = [] 87 | for rate in TEST_MULTISCALE: 88 | inputs_batched = sample['image_%f' % rate] 89 | image_multiscale.append(inputs_batched) 90 | if flip!='no': 91 | image_multiscale.append(torch.flip(inputs_batched, [3])) 92 | for img in image_multiscale: 93 | result = net(img.to(device)) 94 | result_list.append(result.cpu()) 95 | img.cpu() 96 | 97 | for i in range(len(result_list)): 98 | result_seg = F.interpolate(result_list[i], (h,w), mode='bilinear', align_corners=True) 99 | if i % 2 == 1 and flip!='no': 100 | result_seg = torch.flip(result_seg, [3]) 101 | result_list[i] = result_seg 102 | prob_seg = torch.stack(result_list, dim=0) # 12, 1, c,h,w 103 | prob_seg = F.softmax(torch.mean(prob_seg/t_eval, dim=0, keepdim=False), dim=1) # 1,c,h,w 104 | #prob_seg = torch.clamp(prob_seg, 1e-7, 1) 105 | # do the CRF 106 | if CRF_post !='no': 107 | prob = prob_seg.cpu().numpy() # 1,c,h,w 108 | img_batched = img_denorm(sample['image'][0].numpy()).astype(np.uint8) 109 | prob = dense_crf(prob[0], img_batched, n_classes=21, n_iters=1) 110 | prob_seg = torch.from_numpy(prob.astype(np.float32)) 111 | result = prob_seg.unsqueeze(dim=0) # 1,c,h,w 112 | else: 113 | result = prob_seg.cpu() # 1,c,h,w 114 | 115 | result_argmax = torch.argmax(result,dim=1) # 1,h,w the pred argmax label 116 | result_max_prob, _ = torch.max(result, dim=1) # 1,h,w the max probability 117 | for batch_idx in sample['batch_idx'].numpy(): 118 | # prev_pred_dict[batch_idx] = result 119 | prev_pred_dict[eval_dataloader.dataset.ori_indx_list[batch_idx]]= (result_argmax, result_max_prob) 120 | pbar.set_description("Correcting Labels ") 121 | pbar.update(1) 122 | 123 | np.save(os.path.join(tempt_save_root, save_name), prev_pred_dict) 124 | 125 | -------------------------------------------------------------------------------- /lib/utils/finalprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | 4 | def writelog(cfg, period, metric=None, commit=''): 5 | filepath = os.path.join(cfg.ROOT_DIR,'log','logfile.txt') 6 | logfile = open(filepath,'a') 7 | import time 8 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 9 | logfile.write('\t%s\n'%period) 10 | para_data_dict = {} 11 | para_model_dict = {} 12 | para_train_dict = {} 13 | para_test_dict = {} 14 | para_name = dir(cfg) 15 | for name in para_name: 16 | if 'DATA_' in name: 17 | v = getattr(cfg,name) 18 | para_data_dict[name] = v 19 | elif 'MODEL_' in name: 20 | v = getattr(cfg,name) 21 | para_model_dict[name] = v 22 | elif 'TRAIN_' in name: 23 | v = getattr(cfg,name) 24 | para_train_dict[name] = v 25 | elif 'TEST_' in name: 26 | v = getattr(cfg,name) 27 | para_test_dict[name] = v 28 | writedict(logfile, {'EXP_NAME': cfg.EXP_NAME}) 29 | writedict(logfile, para_data_dict) 30 | writedict(logfile, para_model_dict) 31 | if 'train' in period: 32 | writedict(logfile, para_train_dict) 33 | else: 34 | writedict(logfile, para_test_dict) 35 | writedict(logfile, metric) 36 | 37 | logfile.write(commit) 38 | logfile.write('=====================================\n') 39 |
logfile.close() 40 | 41 | 42 | def writelog_seperate(cfg, period, metric=None, commit=''): 43 | filepath = os.path.join(cfg.ROOT_DIR,'log',cfg.logfile) 44 | logfile = open(filepath,'a') 45 | import time 46 | logfile.write(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())) 47 | logfile.write('\t%s\n'%period) 48 | para_data_dict = {} 49 | para_model_dict = {} 50 | para_train_dict = {} 51 | para_test_dict = {} 52 | para_name = dir(cfg) 53 | for name in para_name: 54 | if 'DATA_' in name: 55 | v = getattr(cfg,name) 56 | para_data_dict[name] = v 57 | elif 'MODEL_' in name: 58 | v = getattr(cfg,name) 59 | para_model_dict[name] = v 60 | elif 'TRAIN_' in name: 61 | v = getattr(cfg,name) 62 | para_train_dict[name] = v 63 | elif 'TEST_' in name: 64 | v = getattr(cfg,name) 65 | para_test_dict[name] = v 66 | writedict(logfile, {'EXP_NAME': cfg.EXP_NAME}) 67 | writedict(logfile, para_data_dict) 68 | writedict(logfile, para_model_dict) 69 | if 'train' in period: 70 | writedict(logfile, para_train_dict) 71 | else: 72 | writedict(logfile, para_test_dict) 73 | writedict(logfile, metric) 74 | 75 | logfile.write(commit) 76 | logfile.write('=====================================\n') 77 | logfile.close() 78 | 79 | 80 | 81 | def writedict(file, dictionary): 82 | s = '' 83 | for key in dictionary.keys(): 84 | sub = '%s:%s '%(key, dictionary[key]) 85 | s += sub 86 | s += '\n' 87 | file.write(s) 88 | 89 | 90 | -------------------------------------------------------------------------------- /lib/utils/imutils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | def pseudo_erode(label, num, t=1): 5 | label_onehot = onehot(label, num) 6 | k = np.ones((15,15),np.uint8) 7 | e = cv2.erode(label_onehot, k, iterations=t) # pass iterations by keyword: the third positional argument of cv2.erode is dst 8 | m = (e != label_onehot) 9 | m = np.max(m, axis=2) 10 | label[m] = 255 11 | return label 12 | 13 | 14 | def onehot(label, num): 15 | num = int(num) 16 | m = label.astype(np.int32) 17 | one_hot = np.eye(num)[m] 18 | return one_hot 19 | 20 | def seg2cls(label, num): 21 | cls = np.zeros(num) 22 | index = np.unique(label) 23 | cls[index] = 1 24 | #cls[0] = 0 25 | cls = cls.reshape((num,1,1)) 26 | return cls 27 | 28 | def gamma_correction(img): 29 | gamma = np.mean(img)/128.0 30 | lookUpTable = np.empty((1,256), np.uint8) 31 | for i in range(256): 32 | lookUpTable[0,i] = np.clip(pow(i / 255.0, gamma) * 255.0, 0, 255) 33 | res_img = cv2.LUT(img, lookUpTable) 34 | return res_img 35 | 36 | def img_denorm(inputs, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), mul=True): 37 | inputs = np.ascontiguousarray(inputs) 38 | if inputs.ndim == 3: 39 | inputs[0,:,:] = (inputs[0,:,:]*std[0] + mean[0]) 40 | inputs[1,:,:] = (inputs[1,:,:]*std[1] + mean[1]) 41 | inputs[2,:,:] = (inputs[2,:,:]*std[2] + mean[2]) 42 | elif inputs.ndim == 4: 43 | n = inputs.shape[0] 44 | for i in range(n): 45 | inputs[i,0,:,:] = (inputs[i,0,:,:]*std[0] + mean[0]) 46 | inputs[i,1,:,:] = (inputs[i,1,:,:]*std[1] + mean[1]) 47 | inputs[i,2,:,:] = (inputs[i,2,:,:]*std[2] + mean[2]) 48 | 49 | if mul: 50 | inputs = inputs*255 51 | inputs[inputs > 255] = 255 52 | inputs[inputs < 0] = 0 53 | inputs = inputs.astype(np.uint8) 54 | else: 55 | inputs[inputs > 1] = 1 56 | inputs[inputs < 0] = 0 57 | return inputs 58 | -------------------------------------------------------------------------------- /lib/utils/iou_computation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | 5 | def update_iou_stat(predict, gt, TP,
P, T, num_classes = 21): 6 | """ 7 | :param predict: per-batch predictions after taking the argmax, numpy array of shape b,h,w 8 | :param gt: the gt label of the batch, numpy array of shape b,h,w 9 | :param TP: running count of true-positive pixels per class 10 | :param P: running count of predicted-positive pixels per class 11 | :param T: running count of ground-truth pixels per class 12 | :param num_classes: number of classes in the dataset 13 | :return: updated TP, P, T 14 | """ 15 | cal = gt < 255 16 | 17 | mask = (predict == gt) * cal 18 | 19 | for i in range(num_classes): 20 | P[i] += np.sum((predict == i) * cal) 21 | T[i] += np.sum((gt == i) * cal) 22 | TP[i] += np.sum((gt == i) * mask) 23 | 24 | return TP, P, T 25 | 26 | 27 | def iter_iou_stat(predict, gt, num_classes = 21): 28 | """ 29 | :param predict: per-batch predictions after taking the argmax, numpy array of shape b,h,w 30 | :param gt: the gt label of the batch, numpy array of shape b,h,w 31 | TP: true-positive pixel counts per class (computed here, not passed in) 32 | P: predicted-positive pixel counts per class 33 | T: ground-truth pixel counts per class 34 | :param num_classes: number of classes in the dataset 35 | :return: np.array([TP, P, T]) for this batch alone 36 | """ 37 | cal = gt < 255 38 | 39 | mask = (predict == gt) * cal 40 | 41 | TP = np.zeros(num_classes) 42 | P = np.zeros(num_classes) 43 | T = np.zeros(num_classes) 44 | 45 | for i in range(num_classes): 46 | P[i] = np.sum((predict == i) * cal) 47 | T[i] = np.sum((gt == i) * cal) 48 | TP[i] = np.sum((gt == i) * mask) 49 | 50 | return np.array([TP, P, T]) 51 | 52 | 53 | def compute_iou(TP, P, T, num_classes = 21): 54 | """ 55 | :param TP: true-positive pixel counts per class 56 | :param P: predicted-positive pixel counts per class 57 | :param T: ground-truth pixel counts per class 58 | :param num_classes: number of classes in the dataset 59 | :return: IoU 60 | """ 61 | IoU = [] 62 | for i in range(num_classes): 63 | IoU.append(TP[i] / (T[i] + P[i] - TP[i] + 1e-10)) 64 | return IoU 65 | 66 | 67 | def update_fraction_batchwise(mask, gt, fraction, num_classes = 21): 68 | """ 69 | :param mask: True where the pixel belongs to the subgroup (memorized, correct, others) whose fraction we want to compute 70 | :param gt: the gt label of the batch, numpy array 71 | :param fraction: fraction of pixels in the subgroup 72 | :param num_classes: number of classes in the dataset 73 | :return: updated fraction 74 | """ 75 | cal = gt < 255 76 | 77 | for i in range(num_classes): 78 | fraction[i] += np.sum((mask * (gt == i) * cal))/np.sum((gt == i) * cal) 79 | 80 | return fraction 81 | 82 | 83 | def update_fraction_instancewise(mask, gt, fraction, num_classes = 21): 84 | """ 85 | :param mask: True where the pixel belongs to the subgroup (memorized, correct, others) whose fraction we want to compute 86 | :param gt: the gt label of the batch, numpy array 87 | :param fraction: fraction of pixels in the subgroup 88 | :param num_classes: number of classes in the dataset 89 | :return: updated fraction 90 | """ 91 | # note: np.sum((gt == i) * cal, axis=(-2,-1)) can be zero, so the division below may produce NaN 92 | cal = gt < 255 93 | 94 | for i in range(num_classes): 95 | fraction[i] += np.mean(np.sum((mask * (gt == i) * cal), axis= (-2,-1))/np.sum((gt == i) * cal, axis= (-2,-1))) 96 | 97 | return fraction 98 | 99 | def update_fraction_pixelwise(mask, gt, abs_num_and_total, num_classes = 21): 100 | """ 101 | :param mask: True where the pixel belongs to the subgroup (memorized, correct, others) whose fraction we want to compute 102 | :param gt: the gt label of the batch, numpy array 103 | :param abs_num_and_total: the absolute number of pixels belonging to the mask and the total number of pixels, as [abs_num, pixel_num] 104 | :param num_classes: number of classes in the dataset 105 | :return: updated fraction 106 | """ 107 | cal = gt < 255 108 | 109 | for i in range(num_classes): 110 |
abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal) 111 | abs_num_and_total[i][1] += np.sum((gt == i) * cal) 112 | 113 | 114 | return abs_num_and_total 115 | 116 | def iter_fraction_pixelwise(mask, gt, num_classes = 21): 117 | """ 118 | :param mask: True where the pixel belongs to the subgroup (memorized, correct, others) whose fraction we want to compute 119 | :param gt: the gt label of the batch, numpy array 120 | :param num_classes: number of classes in the dataset 121 | :return: abs_num_and_total for this batch 122 | """ 123 | cal = gt < 255 124 | 125 | abs_num_and_total = np.zeros((num_classes,2)) 126 | 127 | for i in range(num_classes): 128 | abs_num_and_total[i][0] += np.sum(mask * (gt == i) * cal) 129 | abs_num_and_total[i][1] += np.sum((gt == i) * cal) 130 | 131 | 132 | return abs_num_and_total 133 | 134 | 135 | 136 | def get_mask(gt_np, label_np, pred_np): 137 | """ 138 | 139 | Args: 140 | gt_np: the GT label 141 | label_np: the CAM pseudo label 142 | pred_np: the prediction 143 | 144 | Returns: the masks for the different pixel types 145 | 146 | """ 147 | wrong_mask_correct = (gt_np != label_np) & (pred_np == gt_np) 148 | wrong_mask_memorized = (gt_np != label_np) & (pred_np == label_np) 149 | wrong_mask_others = (gt_np != label_np) & (pred_np != gt_np) & (pred_np != label_np) 150 | clean_mask_correct = (gt_np == label_np) & (pred_np == gt_np) 151 | clean_mask_incorrect = (gt_np == label_np) & (pred_np != gt_np) 152 | 153 | return (wrong_mask_correct,wrong_mask_memorized,wrong_mask_others,clean_mask_correct,clean_mask_incorrect) -------------------------------------------------------------------------------- /lib/utils/logger.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | try: 3 | from comet_ml import Experiment as CometExperiment 4 | from comet_ml import OfflineExperiment as CometOfflineExperiment 5 | except ImportError: # pragma: no-cover 6 | _COMET_AVAILABLE = False 7 | else: 8 | _COMET_AVAILABLE = True 9 | 10 | 11 | import torch 12 | from torch import is_tensor 13 | from typing import Any, Dict, Optional, Union 14 | from datetime import datetime 15 | 16 | class Timer: 17 | def __init__(self): 18 | self.cache = datetime.now() 19 | 20 | def check(self): 21 | now = datetime.now() 22 | duration = now - self.cache 23 | self.cache = now 24 | return duration.total_seconds() 25 | 26 | def reset(self): 27 | self.cache = datetime.now() 28 | 29 | class CometWriter: 30 | def __init__( 31 | self, 32 | project_name: Optional[str] = None, 33 | experiment_name: Optional[str] = None, 34 | api_key: Optional[str] = None, 35 | log_dir: Optional[str] = None, 36 | offline: bool = False, 37 | **kwargs): 38 | if not _COMET_AVAILABLE: 39 | raise ImportError( 40 | "The `comet_ml` logger is not installed yet;" 41 | " install it with `pip install comet-ml`."
42 | ) 43 | 44 | self.project_name = project_name 45 | self.experiment_name = experiment_name 46 | self.kwargs = kwargs 47 | 48 | self.timer = Timer() 49 | 50 | 51 | if (api_key is not None) and (log_dir is not None): 52 | self.mode = "offline" if offline else "online" 53 | self.api_key = api_key 54 | self.log_dir = log_dir 55 | 56 | elif api_key is not None: 57 | self.mode = "online" 58 | self.api_key = api_key 59 | self.log_dir = None 60 | elif log_dir is not None: 61 | self.mode = "offline" 62 | self.log_dir = log_dir 63 | else: 64 | raise ValueError("CometWriter requires either api_key or log_dir during initialization.") # raise instead of print: self.mode would otherwise be undefined below 65 | 66 | if self.mode == "online": 67 | self.experiment = CometExperiment( 68 | api_key=self.api_key, 69 | project_name = self.project_name, 70 | **self.kwargs, 71 | ) 72 | else: 73 | self.experiment = CometOfflineExperiment( 74 | offline_directory=self.log_dir, 75 | project_name=self.project_name, 76 | **self.kwargs, 77 | ) 78 | 79 | if self.experiment_name: 80 | self.experiment.set_name(self.experiment_name) 81 | 82 | def set_step(self, step, epoch = None, mode='train') -> None: 83 | self.mode = mode 84 | self.step = step 85 | self.epoch = epoch 86 | if step == 0: 87 | self.timer.reset() 88 | else: 89 | duration = self.timer.check() 90 | self.add_scalar({'steps_per_sec': 1 / duration}) 91 | 92 | def log_hyperparams(self, params: Dict[str, Any]) -> None: 93 | self.experiment.log_parameters(params) 94 | 95 | def log_code(self, file_name = None, folder = 'models/') -> None: 96 | self.experiment.log_code(file_name=file_name, folder=folder) 97 | 98 | 99 | def add_scalar(self, metrics: Dict[str, Union[torch.Tensor, float]], step: Optional[int] = None, epoch: Optional[int] = None) -> None: 100 | metrics_renamed = {} 101 | for key, val in metrics.items(): 102 | tag = '{}/{}'.format(key, self.mode) 103 | if is_tensor(val): 104 | metrics_renamed[tag] = val.cpu().detach() 105 | else: 106 | metrics_renamed[tag] = val 107 | if epoch is None and step is None: 108 | self.experiment.log_metrics(metrics_renamed, step = self.step, epoch = self.epoch) 109 | elif epoch is None and step is not None: 110 | self.experiment.log_metrics(metrics_renamed, step = step) 111 | elif epoch is not None and step is None: 112 | self.experiment.log_metrics(metrics_renamed, epoch = epoch) 113 | else: 114 | self.experiment.log_metrics(metrics_renamed, step = step, epoch = epoch) 115 | 116 | def add_plot(self, figure_name, figure): 117 | """ 118 | Primarily for logging gate plots 119 | """ 120 | self.experiment.log_figure(figure_name = figure_name, figure = figure) 121 | 122 | def add_text(self, text, step): 123 | """ 124 | Log a text string at the given step 125 | """ 126 | self.experiment.log_text(text, step = step) 127 | 128 | def add_hist3d(self, hist, name): 129 | """ 130 | Log a 3D histogram 131 | """ 132 | self.experiment.log_histogram_3d(hist, name = name) 133 | 134 | def reset_experiment(self): 135 | self.experiment = None 136 | 137 | def finalize(self) -> None: 138 | self.experiment.end() 139 | self.reset_experiment() 140 | -------------------------------------------------------------------------------- /lib/utils/registry.py: -------------------------------------------------------------------------------- 1 | 2 | class Registry(object): 3 | def __init__(self, name): 4 | super(Registry, self).__init__() 5 | self._name = name 6 | self._module_dict = dict() 7 | 8 | @property 9 | def name(self): 10 | return self._name 11 | 12 | @property 13 | def module_dict(self): 14 | return self._module_dict 15 | 16 | def
__len__(self): 17 | return len(self.module_dict) 18 | 19 | def get(self, key): 20 | return self._module_dict[key] 21 | 22 | def register_module(self, module=None): 23 | if module is None: 24 | raise TypeError('failed to register None in Registry {}'.format(self.name)) 25 | module_name = module.__name__ 26 | if module_name in self._module_dict: 27 | raise KeyError('{} is already registered in Registry {}'.format(module_name, self.name)) 28 | self._module_dict[module_name] = module 29 | return module 30 | 31 | DATASETS = Registry('dataset') 32 | BACKBONES = Registry('backbone') 33 | NETS = Registry('nets') 34 | -------------------------------------------------------------------------------- /lib/utils/test_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from tqdm import tqdm 4 | 5 | def single_gpu_test(model, dataloader, prepare_func, inference_func, collect_func, save_step_func=None): 6 | model.eval() 7 | n_gpus = torch.cuda.device_count() 8 | #assert n_gpus == 1 9 | collect_list = [] 10 | total_num = len(dataloader) 11 | with tqdm(total=total_num) as pbar: 12 | with torch.no_grad(): 13 | for i_batch, sample in enumerate(dataloader): 14 | name = sample['name'] 15 | image_msf = prepare_func(sample) 16 | result_list = [] 17 | for img in image_msf: 18 | result = inference_func(model, img.cuda()) 19 | result_list.append(result) 20 | result_item = collect_func(result_list, sample) 21 | result_sample = {'predict': result_item, 'name':name[0]} 22 | #print('%d/%d'%(i_batch,len(dataloader))) 23 | pbar.set_description('Processing') 24 | pbar.update(1) 25 | time.sleep(0.001) 26 | 27 | if save_step_func is not None: 28 | save_step_func(result_sample) 29 | else: 30 | collect_list.append(result_sample) 31 | return collect_list 32 | 33 | 34 | def single_gpu_multimodel_ensemble_test(model,model2, dataloader, prepare_func, inference_func, collect_func, save_step_func=None): 35 | model.eval(); model2.eval() # both ensemble members should be in eval mode 36 | n_gpus = torch.cuda.device_count() 37 | #assert n_gpus == 1 38 | collect_list = [] 39 | total_num = len(dataloader) 40 | with tqdm(total=total_num) as pbar: 41 | with torch.no_grad(): 42 | for i_batch, sample in enumerate(dataloader): 43 | name = sample['name'] 44 | image_msf = prepare_func(sample) 45 | result_list = [] 46 | for img in image_msf: 47 | result = inference_func(model,model2, img.cuda()) 48 | result_list.append(result) 49 | result_item = collect_func(result_list, sample) 50 | result_sample = {'predict': result_item, 'name':name[0]} 51 | #print('%d/%d'%(i_batch,len(dataloader))) 52 | pbar.set_description('Processing') 53 | pbar.update(1) 54 | time.sleep(0.001) 55 | 56 | if save_step_func is not None: 57 | save_step_func(result_sample) 58 | else: 59 | collect_list.append(result_sample) 60 | return collect_list 61 | 62 | 63 | def single_gpu_triplemodel_ensemble_test(model,model2,model3, dataloader, prepare_func, inference_func, collect_func, save_step_func=None): 64 | model.eval(); model2.eval(); model3.eval() # all three ensemble members should be in eval mode 65 | n_gpus = torch.cuda.device_count() 66 | #assert n_gpus == 1 67 | collect_list = [] 68 | total_num = len(dataloader) 69 | with tqdm(total=total_num) as pbar: 70 | with torch.no_grad(): 71 | for i_batch, sample in enumerate(dataloader): 72 | name = sample['name'] 73 | image_msf = prepare_func(sample) 74 | result_list = [] 75 | for img in image_msf: 76 | result = inference_func(model,model2,model3, img.cuda()) 77 | result_list.append(result) 78 | result_item = collect_func(result_list, sample) 79 | result_sample = {'predict': result_item,
'name':name[0]} 80 | #print('%d/%d'%(i_batch,len(dataloader))) 81 | pbar.set_description('Processing') 82 | pbar.update(1) 83 | time.sleep(0.001) 84 | 85 | if save_step_func is not None: 86 | save_step_func(result_sample) 87 | else: 88 | collect_list.append(result_sample) 89 | return collect_list 90 | -------------------------------------------------------------------------------- /lib/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | import cv2 5 | from utils.DenseCRF import * 6 | #from cv2.ximgproc import l0Smooth 7 | 8 | def color_pro(pro, img=None, mode='hwc'): 9 | H, W = pro.shape 10 | pro_255 = (pro*255).astype(np.uint8) 11 | pro_255 = np.expand_dims(pro_255,axis=2) 12 | color = cv2.applyColorMap(pro_255,cv2.COLORMAP_JET) 13 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) 14 | if img is not None: 15 | rate = 0.5 16 | if mode == 'hwc': 17 | assert img.shape[0] == H and img.shape[1] == W 18 | color = cv2.addWeighted(img,rate,color,1-rate,0) 19 | elif mode == 'chw': 20 | assert img.shape[1] == H and img.shape[2] == W 21 | img = np.transpose(img,(1,2,0)) 22 | color = cv2.addWeighted(img,rate,color,1-rate,0) 23 | color = np.transpose(color,(2,0,1)) 24 | else: 25 | if mode == 'chw': 26 | color = np.transpose(color,(2,0,1)) 27 | return color 28 | 29 | def generate_vis(p, gt, img, func_label2color, threshold=0.1, norm=True, crf=False): 30 | # All the input should be numpy.array 31 | # img should be 0-255 uint8 32 | C, H, W = p.shape 33 | 34 | if norm: 35 | prob = max_norm(p, 'numpy') 36 | else: 37 | prob = p 38 | if gt is not None: 39 | prob = prob * gt 40 | prob[prob<=0] = 1e-5 41 | if threshold is not None: 42 | prob[0,:,:] = np.power(1-np.max(prob[1:,:,:],axis=0,keepdims=True), 4) 43 | 44 | CLS = ColorCLS(prob, func_label2color) 45 | CAM = ColorCAM(prob, img) 46 | if crf: 47 | prob_crf = dense_crf(prob, img, n_classes=C, n_iters=1) 48 | CLS_crf = ColorCLS(prob_crf, func_label2color) 49 | CAM_crf = ColorCAM(prob_crf, img) 50 | return CLS, CAM, CLS_crf, CAM_crf 51 | else: 52 | return CLS, CAM 53 | 54 | def max_norm(p, version='torch', e=1e-5): 55 | if version == 'torch': # '==' rather than 'is': identity comparison with string literals is unreliable 56 | if p.dim() == 3: 57 | C, H, W = p.size() 58 | p = F.relu(p, inplace=True) 59 | max_v = torch.max(p.view(C,-1),dim=-1)[0].view(C,1,1) 60 | min_v = torch.min(p.view(C,-1),dim=-1)[0].view(C,1,1) 61 | p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e) 62 | elif p.dim() == 4: 63 | N, C, H, W = p.size() 64 | p = F.relu(p, inplace=True) 65 | max_v = torch.max(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1) 66 | min_v = torch.min(p.view(N,C,-1),dim=-1)[0].view(N,C,1,1) 67 | p = F.relu(p-min_v-e, inplace=True)/(max_v-min_v+e) 68 | elif version == 'numpy' or version == 'np': 69 | if p.ndim == 3: 70 | C, H, W = p.shape 71 | p[p<0] = 0 # numpy counterpart of the torch normalization above 72 | max_v = np.max(p.reshape(C,-1), axis=-1).reshape(C,1,1) 73 | min_v = np.min(p.reshape(C,-1), axis=-1).reshape(C,1,1) 74 | p = np.clip(p-min_v-e, 0, None)/(max_v-min_v+e) 75 | return p 76 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.7.0 2 | torchvision>=0.8.1 3 | mxnet>=1.7.0.post1 4 | scipy>=1.5.1 5 | numpy>=1.19.4 6 | scikit_image>=0.17.2 7 | pydensecrf>=1.0rc3 8 | pandas>=1.0.5 9 | opencv_python>=4.3.0.36 10 | matplotlib>=3.3.0 11 | Pillow>=8.1.0 12 | tensorboardX>=2.1 13 | tqdm 14 | scikit-image 15 | comet-ml 16 | --------------------------------------------------------------------------------
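
For reference, the following is a minimal usage sketch of the multi-scale JSD consistency loss in lib/utils/JSD_loss.py. It is not the training recipe from train.py: the tensors are random stand-ins for real network outputs, the import path is assumed from the repository layout, and the plain sum of the two loss terms at the end is illustrative only.

```python
# Sketch: driving calc_jsd_multiscale with dummy data (shapes only).
import torch
import torch.nn as nn
from lib.utils.JSD_loss import calc_jsd_multiscale  # import path assumed from the repo layout

B, C, H, W = 2, 21, 64, 64                 # batch, classes (PASCAL VOC), spatial size
weight = nn.Parameter(torch.ones(3))       # learnable per-scale weights, softmaxed inside the loss
labels = torch.randint(0, C, (B, H, W))    # pseudo labels; 255 would mark ignored pixels
pred1, pred2, pred3 = (torch.randn(B, C, H, W) for _ in range(3))  # logits at three scales

ce_loss, consistency_loss, consistency_map, mixture = calc_jsd_multiscale(
    weight, labels, pred1, pred2, pred3, threshold=0.8)

total = ce_loss + consistency_loss         # illustrative combination of the two terms
total.backward()                           # gradients flow into the per-scale weights
```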
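In the same spirit, a short sketch of the IoU bookkeeping utilities in lib/utils/iou_computation.py: TP, P, and T are accumulated across batches and reduced to per-class IoU at the end. The arrays are random stand-ins for batch predictions and labels, and the import path is again an assumption.

```python
# Sketch: accumulating per-class IoU statistics over a few fake batches.
import numpy as np
from lib.utils.iou_computation import update_iou_stat, compute_iou  # import path assumed

num_classes = 21
TP, P, T = [0] * num_classes, [0] * num_classes, [0] * num_classes

for _ in range(4):  # stand-in for iterating over a dataloader
    pred = np.random.randint(0, num_classes, size=(2, 32, 32))  # argmax predictions, b,h,w
    gt = np.random.randint(0, num_classes, size=(2, 32, 32))    # labels, 255 = ignore
    TP, P, T = update_iou_stat(pred, gt, TP, P, T, num_classes=num_classes)

IoU = compute_iou(TP, P, T, num_classes=num_classes)
print('mIoU: %.4f' % float(np.mean(IoU)))
```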