├── .figure ├── image_anno.jpg ├── image_only.jpg └── network.jpg ├── .gitignore ├── README.md ├── configs ├── __init__.py ├── cfg_base.py ├── cfg_dataset.py └── cfg_model.py ├── eval_utils.py ├── hrnet_code ├── .gitignore ├── LICENSE ├── README.md ├── lib │ ├── config │ │ ├── __init__.py │ │ ├── default.py │ │ └── models.py │ ├── core │ │ ├── criterion.py │ │ └── function.py │ ├── datasets │ │ ├── __init__.py │ │ ├── ade20k.py │ │ ├── base_dataset.py │ │ ├── cityscapes.py │ │ ├── cocostuff.py │ │ ├── lip.py │ │ └── pascal_ctx.py │ ├── models │ │ ├── __init__.py │ │ ├── bn_helper.py │ │ ├── seg_hrnet.py │ │ ├── seg_hrnet_ocr.py │ │ └── sync_bn │ │ │ ├── LICENSE │ │ │ ├── __init__.py │ │ │ └── inplace_abn │ │ │ ├── __init__.py │ │ │ ├── bn.py │ │ │ ├── functions.py │ │ │ └── src │ │ │ ├── common.h │ │ │ ├── inplace_abn.cpp │ │ │ ├── inplace_abn.h │ │ │ ├── inplace_abn_cpu.cpp │ │ │ └── inplace_abn_cuda.cu │ └── utils │ │ ├── __init__.py │ │ ├── distributed.py │ │ ├── modelsummary.py │ │ └── utils.py ├── requirements.txt └── tools │ ├── _init_paths.py │ ├── test.py │ └── train.py ├── inference.py ├── lib ├── __init__.py ├── cfg_helper.py ├── data_factory │ ├── __init__.py │ ├── ds_base.py │ ├── ds_cocotext.py │ ├── ds_formatter.py │ ├── ds_icdar13.py │ ├── ds_loader.py │ ├── ds_mlt.py │ ├── ds_sampler.py │ ├── ds_textseg.py │ ├── ds_textssc.py │ ├── ds_totaltext.py │ └── ds_transform.py ├── evaluate_service.py ├── log_service.py ├── loss.py ├── model_zoo │ ├── __init__.py │ ├── deeplab.py │ ├── get_model.py │ ├── hrnet.py │ ├── optim_manager.py │ ├── resnet.py │ ├── texrnet.py │ └── utils.py ├── nputils.py ├── optimizer │ ├── __init__.py │ └── get_optimizer.py └── torchutils.py ├── main.py ├── requirement.txt └── train_utils.py /.figure/image_anno.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/image_anno.jpg -------------------------------------------------------------------------------- /.figure/image_only.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/image_only.jpg -------------------------------------------------------------------------------- /.figure/network.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/network.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /data 2 | /models 3 | /log 4 | /scripts 5 | /pretrained 6 | **/__pycache__ 7 | **/.idea/* 8 | **/.vscode/* 9 | **/.ipynb_checkpoints/* 10 | **/old/* 11 | **/veryold/* 12 | *.ipynb 13 | **/build 14 | **/*.out 15 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach 2 | 3 | This is the repo to host the dataset TextSeg and code for TexRNet from the following paper: 4 | 5 | [Xingqian Xu](https://ifp-uiuc.github.io/), [Zhifei Zhang](http://web.eecs.utk.edu/~zzhang61/), [Zhaowen 
Wang](https://research.adobe.com/person/zhaowen-wang/), [Brian Price](https://research.adobe.com/person/brian-price/), [Zhonghao Wang](https://ifp-uiuc.github.io/) and [Humphrey Shi](https://www.humphreyshi.com), **Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach**, [ArXiv Link](https://arxiv.org/abs/2011.14021)

**Note:**

\[2021.04.21\] So far, our dataset is partially released with images and semantic labels. Since many people may request the dataset for OCR or non-segmentation tasks, please stay tuned; we will release the dataset in full ASAP.

\[2021.06.18\] **Our dataset is now fully released.** To download the data, please send a request email to *textseg.dataset@gmail.com* and tell us which school you are affiliated with. Please be aware that the released dataset is **version 2**, whose annotations differ slightly from those in the paper. To provide the most accurate dataset, we went through a second round of quality assurance, in which we fixed some faulty annotations and made them more consistent across the dataset. Since the TexRNet in the paper uses neither OCR nor character instance labels (*i.e.* word- and character-level bounding polygons, character-level masks), we will not release the older version of those labels. However, we do release the retroactive ```semantic_label_v1.tar.gz``` so that researchers can reproduce the results in the paper. For more details about the dataset, please see below.

## Introduction

Text in the real world is extremely diverse, yet current text datasets do not reflect this diversity well. To bridge the gap, we propose TextSeg, a large-scale, finely annotated, multi-purpose text dataset collecting scene and design text with six types of annotations: word- and character-wise bounding polygons, masks and transcriptions. We also introduce the Text Refinement Network (TexRNet), a novel text segmentation approach that adapts to the unique properties of text, e.g. non-convex boundaries and diverse textures, which often burden traditional segmentation models. TexRNet refines the results of common segmentation approaches via key feature pooling and attention, so that wrongly activated text regions can be corrected; a conceptual sketch follows below. We also introduce trimap and discriminator losses that bring significant improvements on text segmentation.
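For readers skimming the code, below is a minimal, illustrative sketch of the "key feature pooling + attention" idea described above. It is our shorthand reading of that one sentence, **not** the actual TexRNet implementation (see `lib/model_zoo/texrnet.py`); the tensor shapes, probability-weighted pooling, and cosine-similarity attention are all assumptions made for illustration.

```
# Illustrative sketch only -- not the authors' implementation.
import torch
import torch.nn.functional as F

def refine_with_key_pooling(feat, prob):
    """feat: (B, C, H, W) backbone features; prob: (B, K, H, W) softmax maps."""
    B, C, H, W = feat.shape
    K = prob.shape[1]
    f = feat.flatten(2)                      # (B, C, H*W)
    p = prob.flatten(2)                      # (B, K, H*W)
    # Key pooling: one prototype feature per class, weighted by class probability.
    keys = torch.bmm(p, f.transpose(1, 2))   # (B, K, C)
    keys = keys / p.sum(dim=2, keepdim=True).clamp(min=1e-6)
    # Attention: cosine similarity between every pixel feature and each class key,
    # giving a refined per-class activation map.
    sim = torch.bmm(F.normalize(keys, dim=2), F.normalize(f, dim=1))  # (B, K, H*W)
    return sim.view(B, K, H, W)

# e.g. refined = refine_with_key_pooling(features, F.softmax(logits, dim=1))
```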
## TextSeg Dataset

### Image Collection

![Image collection examples](.figure/image_only.jpg)

### Annotation
![Annotation examples](.figure/image_anno.jpg)

### Download

Our dataset (TextSeg) is academia-only and cannot be used in any commercial project or research. To download the data, please send a request email to *textseg.dataset@gmail.com* and tell us which school you are affiliated with.

A full download should contain these files:

* ```image.tar.gz``` contains 4024 images.
* ```annotation.tar.gz``` contains the labels corresponding to the images. Three types of files are included:
    * ```[dataID]_anno.json``` contains all word- and character-level transcriptions and bounding polygons.
    * ```[dataID]_mask.png``` contains all character masks. Character mask label values are ordered from 1 to n; label value 0 means background and 255 means ignore.
    * ```[dataID]_maskeff.png``` contains all character masks **with effect**.
* ```Adobe_Research_License_TextSeg.txt``` is the license file.
* ```semantic_label.tar.gz``` contains all word-level (semantic-level) masks:
    * ```[dataID]_maskfg.png```: 0 means background, 100 means word, 200 means word-effect, and 255 means ignore. (```[dataID]_maskfg.png``` can also be generated from ```[dataID]_mask.png``` and ```[dataID]_maskeff.png```; see the sketch after this list.)
* ```split.json``` is the official split of train, val and test.
* [Optional] ```semantic_label_v1.tar.gz``` is the old version of the labels used in our paper. One can download it to reproduce our paper results.
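Since ```[dataID]_maskfg.png``` can be derived from the two character-level masks, here is a minimal sketch of one way to do it. The precedence rules below (ignore over word, word over word-effect) and the assumption that the effect mask is a superset of the character mask are ours, not the official conversion recipe.

```
# Minimal sketch -- assumed precedence rules, not the official script.
import numpy as np
from PIL import Image

def build_maskfg(mask_path, maskeff_path):
    mask = np.array(Image.open(mask_path))        # 0 = bg, 1..n = characters, 255 = ignore
    maskeff = np.array(Image.open(maskeff_path))  # same scheme, but including effects
    fg = np.zeros_like(mask, dtype=np.uint8)      # start from all-background (0)
    fg[(maskeff > 0) & (maskeff < 255)] = 200     # 200 = word-effect region
    fg[(mask > 0) & (mask < 255)] = 100           # 100 = word, overrides the effect
    fg[(mask == 255) | (maskeff == 255)] = 255    # 255 = ignore wins over everything
    return fg
```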
## TexRNet Structure and Results

![TexRNet network structure](.figure/network.jpg)

In the table below, we report the performance of our TexRNet on five text segmentation datasets, including ours. Each cell shows fgIoU / F-score.
| Method | TextSeg (Ours) | ICDAR13 FST | COCO_TS | MLT_S | Total-Text |
| --- | --- | --- | --- | --- | --- |
| DeeplabV3+ | 84.07 / 0.914 | 69.27 / 0.802 | 72.07 / 0.641 | 84.63 / 0.837 | 74.44 / 0.824 |
| HRNetV2-W48 | 85.03 / 0.914 | 70.98 / 0.822 | 68.93 / 0.629 | 83.26 / 0.836 | 75.29 / 0.825 |
| HRNetV2-W48 + OCR | 85.98 / 0.918 | 72.45 / 0.830 | 69.54 / 0.627 | 83.49 / 0.838 | 76.23 / 0.832 |
| Ours: TexRNet + DeeplabV3+ | 86.06 / 0.921 | 72.16 / 0.835 | 73.98 / 0.722 | 86.31 / 0.830 | 76.53 / 0.844 |
| Ours: TexRNet + HRNetV2-W48 | 86.84 / 0.924 | 73.38 / 0.850 | 72.39 / 0.720 | 86.09 / 0.865 | 78.47 / 0.848 |
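For reference, the two metrics can be computed from binary prediction/ground-truth masks as sketched below, assuming the standard definitions (fgIoU: intersection-over-union of the text class, in percent; F-score: harmonic mean of foreground pixel precision and recall). The official numbers are produced by `main.py --eval`, not by this sketch.

```
# Reference sketch under assumed standard metric definitions.
import numpy as np

def fgiou_fscore(pred, gt, ignore_label=255):
    """pred, gt: integer arrays with 0 = background, 1 = text."""
    valid = gt != ignore_label
    p = (pred == 1) & valid
    g = (gt == 1) & valid
    tp = (p & g).sum()                       # foreground true positives
    fp = (p & ~g).sum()
    fn = (~p & g).sum()
    fgiou = 100.0 * tp / max(tp + fp + fn, 1)
    precision = tp / max(tp + fp, 1)
    recall = tp / max(tp + fn, 1)
    fscore = 2 * precision * recall / max(precision + recall, 1e-6)
    return fgiou, fscore
```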
## To run the code

### Set up the environment

```
conda create -n texrnet python=3.7
conda activate texrnet
pip install -r requirement.txt
```

### To eval

First, make the following directories to hold pre-trained models, the dataset, and running logs:

```
mkdir ./pretrained
mkdir ./data
mkdir ./log
```

Second, download the models from [this link](https://drive.google.com/drive/folders/1EvGNvI5R6NKsW0YTy_0YHD9dpvtM0HDi?usp=sharing) and move them to `./pretrained`.

Third, make sure that `./data` contains the data. A sample root directory for **TextSeg** would be `./data/TextSeg`.

Lastly, evaluate the model and compute fgIoU/F-score with the following command:

```
python main.py --eval --pth [model path] [--hrnet] [--gpu 0 1 ...] --dsname [dataset name]
```

Here is a sample command to eval TexRNet_HRNet on TextSeg with 4 GPUs:

```
python main.py --eval --pth pretrained/texrnet_hrnet.pth --hrnet --gpu 0 1 2 3 --dsname textseg
```

The program will store the results and execution log in `./log/eval`.

### To train

Similarly, these directories need to be created:

```
mkdir ./pretrained
mkdir ./pretrained/init
mkdir ./data
mkdir ./log
```

Second, we use multiple pre-trained models for training. Download these initial models from [this link](https://drive.google.com/drive/folders/1EvGNvI5R6NKsW0YTy_0YHD9dpvtM0HDi?usp=sharing) and move them to `./pretrained/init`. Also, make sure that `./data` contains the data.

Lastly, execute the training code with the following command:

```
python main.py [--hrnet] [--gpu 0 1 ...] --dsname [dataset name] [--trainwithcls]
```

Here is a sample command to train TexRNet_HRNet on TextSeg with the classifier and discriminator loss using 4 GPUs:

```
python main.py --hrnet --gpu 0 1 2 3 --dsname textseg --trainwithcls
```

The training configs, logs, and models will be stored in `./log/texrnet_[dsname]/[exid]_[signature]`.

## Bibtex

```
@article{xu2020rethinking,
  title={Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach},
  author={Xu, Xingqian and Zhang, Zhifei and Wang, Zhaowen and Price, Brian and Wang, Zhonghao and Shi, Humphrey},
  journal={arXiv preprint arXiv:2011.14021},
  year={2020}
}
```

## Acknowledgements

The directory `./hrnet_code` is directly copied from the official HRNet GitHub repository [(link)](https://github.com/HRNet/HRNet-Semantic-Segmentation). Credit for the HRNet code belongs to the HRNet authors, and users should follow their terms of usage.
188 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/configs/__init__.py -------------------------------------------------------------------------------- /configs/cfg_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import copy 5 | import socket 6 | 7 | from easydict import EasyDict as edict 8 | 9 | cfg = edict() 10 | 11 | # -----------------------------BASE----------------------------- 12 | 13 | cfg.DEBUG = False 14 | cfg.EXPERIMENT_ID = 0 15 | cfg.GPU_DEVICE = 'all' 16 | cfg.CUDA = False 17 | cfg.MISC_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', 'log')) 18 | cfg.LOG_FILE = None 19 | cfg.RND_SEED = None 20 | cfg.RND_RECORDING = False 21 | # cfg.USE_FLOAT16 = False 22 | cfg.MATPLOTLIB_MODE = 'Agg' 23 | cfg.MAINLOOP_EXECUTE = True 24 | cfg.MAIN_CODE_PATH = None 25 | cfg.MAIN_CODE = [] 26 | cfg.SAVE_CODE = True 27 | cfg.COMPUTER_NAME = socket.gethostname() 28 | cfg.TORCH_VERSION = 'unknown' 29 | 30 | cfg.DIST_URL = 'tcp://127.0.0.1:11233' 31 | cfg.DIST_BACKEND = 'nccl' 32 | 33 | cfg_train = copy.deepcopy(cfg) 34 | cfg_test = copy.deepcopy(cfg) 35 | 36 | # -----------------------------TRAIN----------------------------- 37 | 38 | cfg_train.TRAIN = edict() 39 | cfg_train.TRAIN.BATCH_SIZE = None 40 | cfg_train.TRAIN.BATCH_SIZE_PER_GPU = None 41 | cfg_train.TRAIN.MAX_STEP = 0 42 | cfg_train.TRAIN.MAX_STEP_TYPE = None 43 | cfg_train.TRAIN.SKIP_PARTIAL = True 44 | # cfg_train.TRAIN.LR_ADJUST_MODE = None 45 | cfg_train.TRAIN.LR_ITER_BY = None 46 | cfg_train.TRAIN.OPTIMIZER = None 47 | cfg_train.TRAIN.DISPLAY = 0 48 | cfg_train.TRAIN.VISUAL = None 49 | cfg_train.TRAIN.SAVE_INIT_MODEL = True 50 | cfg_train.TRAIN.SAVE_CODE = True 51 | 52 | # -----------------------------TEST----------------------------- 53 | 54 | cfg_test.TEST = edict() 55 | cfg_test.TEST.BATCH_SIZE = None 56 | cfg_test.TEST.BATCH_SIZE_PER_GPU = None 57 | cfg_test.TEST.VISUAL = None 58 | 59 | # -----------------------------COMBINED----------------------------- 60 | 61 | cfg.update(cfg_train) 62 | cfg.update(cfg_test) 63 | -------------------------------------------------------------------------------- /configs/cfg_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import copy 5 | 6 | from easydict import EasyDict as edict 7 | 8 | cfg = edict() 9 | cfg.DATASET_MODE = None 10 | cfg.LOADER_PIPELINE = [] 11 | cfg.LOAD_BACKEND_IMAGE = 'pil' 12 | cfg.LOAD_IS_MC_IMAGE = False 13 | cfg.TRANS_PIPELINE = [] 14 | cfg.NUM_WORKERS_PER_GPU = None 15 | cfg.NUM_WORKERS = None 16 | cfg.TRY_SAMPLE = None 17 | 18 | ############################## 19 | ##### imagenet ##### 20 | ############################## 21 | 22 | cfg_imagenet = copy.deepcopy(cfg) 23 | cfg_imagenet.DATASET_NAME = 'imagenet' 24 | cfg_imagenet.ROOT_DIR = osp.abspath(osp.join( 25 | osp.dirname(__file__), '..', 'data', 26 | 'ImageNet', 'ILSVRC2012')) 27 | cfg_imagenet.CLASS_INFO_JSON = osp.abspath(osp.join( 28 | osp.dirname(__file__), '..', 'data', 29 | 'ImageNet', 'addon', 'ILSVRC2012', '1000nids.json')) 30 | cfg_imagenet.IM_MEAN = [0.485, 0.456, 0.406] 31 | cfg_imagenet.IM_STD = [0.229, 0.224, 0.225] 32 | 
cfg_imagenet.CLASS_NUM = 1000 33 | 34 | ############################# 35 | ##### textseg ##### 36 | ############################# 37 | 38 | cfg_textseg = copy.deepcopy(cfg) 39 | cfg_textseg.DATASET_NAME = 'textseg' 40 | cfg_textseg.ROOT_DIR = osp.abspath(osp.join( 41 | osp.dirname(__file__), '..', 'data', 'TextSeg')) 42 | cfg_textseg.CLASS_NUM = 2 43 | cfg_textseg.CLASS_NAME = [ 44 | 'background', 45 | 'text'] 46 | cfg_textseg.SEGLABEL_IGNORE_LABEL = 999 47 | cfg_textseg.SEMANTIC_PICK_CLASS = 'all' 48 | cfg_textseg.IM_MEAN = [0.485, 0.456, 0.406] 49 | cfg_textseg.IM_STD = [0.229, 0.224, 0.225] 50 | cfg_textseg.LOAD_IS_MC_SEGLABEL = True 51 | 52 | ########################## 53 | ##### cocotext ##### 54 | ########################## 55 | 56 | cfg_cocotext = copy.deepcopy(cfg) 57 | cfg_cocotext.DATASET_NAME = 'cocotext' 58 | cfg_cocotext.ROOT_DIR = osp.abspath(osp.join( 59 | osp.dirname(__file__), '..', 'data', 'COCO')) 60 | cfg_cocotext.IM_MEAN = [0.485, 0.456, 0.406] 61 | cfg_cocotext.IM_STD = [0.229, 0.224, 0.225] 62 | 63 | ######################## 64 | ##### cocots ##### 65 | ######################## 66 | 67 | cfg_cocots = copy.deepcopy(cfg) 68 | cfg_cocots.DATASET_NAME = 'cocots' 69 | cfg_cocots.ROOT_DIR = osp.abspath(osp.join( 70 | osp.dirname(__file__), '..', 'data', 'COCO')) 71 | cfg_cocots.IM_MEAN = [0.485, 0.456, 0.406] 72 | cfg_cocots.IM_STD = [0.229, 0.224, 0.225] 73 | cfg_cocots.CLASS_NUM = 2 74 | cfg_cocots.SEGLABEL_IGNORE_LABEL = 255 75 | cfg_cocots.LOAD_BACKEND_SEGLABEL = 'pil' 76 | cfg_cocots.LOAD_IS_MC_SEGLABEL = False 77 | 78 | ##################### 79 | ##### mlt ##### 80 | ##################### 81 | 82 | cfg_mlt = copy.deepcopy(cfg) 83 | cfg_mlt.DATASET_NAME = 'mlt' 84 | cfg_mlt.ROOT_DIR = osp.abspath(osp.join( 85 | osp.dirname(__file__), '..', 'data', 'ICDAR17', 'challenge8')) 86 | cfg_mlt.IM_MEAN = [0.485, 0.456, 0.406] 87 | cfg_mlt.IM_STD = [0.229, 0.224, 0.225] 88 | cfg_mlt.CLASS_NUM = 2 89 | cfg_mlt.SEGLABEL_IGNORE_LABEL = 255 90 | 91 | ####################### 92 | ##### icdar13 ##### 93 | ####################### 94 | 95 | cfg_icdar13 = copy.deepcopy(cfg) 96 | cfg_icdar13.DATASET_NAME = 'icdar13' 97 | cfg_icdar13.ROOT_DIR = osp.abspath(osp.join( 98 | osp.dirname(__file__), '..', 'data', 'ICDAR13')) 99 | cfg_icdar13.CLASS_NUM = 2 100 | cfg_icdar13.CLASS_NAME = [ 101 | 'background', 102 | 'text'] 103 | cfg_icdar13.SEGLABEL_IGNORE_LABEL = 999 104 | cfg_icdar13.SEMANTIC_PICK_CLASS = 'all' 105 | cfg_icdar13.IM_MEAN = [0.485, 0.456, 0.406] 106 | cfg_icdar13.IM_STD = [0.229, 0.224, 0.225] 107 | cfg_icdar13.LOAD_BACKEND_SEGLABEL = 'pil' 108 | cfg_icdar13.LOAD_IS_MC_SEGLABEL = False 109 | cfg_icdar13.FROM_SOURCE = 'addon' 110 | cfg_icdar13.USE_CACHE = False 111 | 112 | ######################### 113 | ##### totaltext ##### 114 | ######################### 115 | 116 | cfg_totaltext = copy.deepcopy(cfg) 117 | cfg_totaltext.DATASET_NAME = 'totaltext' 118 | cfg_totaltext.ROOT_DIR = osp.abspath(osp.join( 119 | osp.dirname(__file__), '..', 'data', 'TotalText')) 120 | cfg_totaltext.CLASS_NUM = 2 121 | cfg_totaltext.CLASS_NAME = [ 122 | 'background', 123 | 'text'] 124 | cfg_totaltext.SEGLABEL_IGNORE_LABEL = 999 125 | # dummy, totaltext pixel level anno has no ignore label 126 | cfg_totaltext.IM_MEAN = [0.485, 0.456, 0.406] 127 | cfg_totaltext.IM_STD = [0.229, 0.224, 0.225] 128 | 129 | ####################### 130 | ##### textssc ##### 131 | ####################### 132 | # text semantic segmentation composed 133 | 134 | cfg_textssc = copy.deepcopy(cfg) 135 | cfg_textssc.DATASET_NAME = 
'textssc' 136 | cfg_textssc.ROOT_DIR = osp.abspath(osp.join( 137 | osp.dirname(__file__), '..', 'data', 'TextSSC')) 138 | cfg_textssc.CLASS_NUM = 2 139 | cfg_textssc.CLASS_NAME = [ 140 | 'background', 141 | 'text'] 142 | cfg_textssc.SEGLABEL_IGNORE_LABEL = 999 143 | cfg_textssc.IM_MEAN = [0.485, 0.456, 0.406] 144 | cfg_textssc.IM_STD = [0.229, 0.224, 0.225] 145 | cfg_textssc.LOAD_BACKEND_SEGLABEL = 'pil' 146 | cfg_textssc.LOAD_IS_MC_SEGLABEL = False 147 | -------------------------------------------------------------------------------- /configs/cfg_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import copy 5 | 6 | from easydict import EasyDict as edict 7 | 8 | cfg = edict() 9 | cfg.MODEL_NAME = None 10 | # cfg.CONV_TYPE = 'conv' 11 | # cfg.BN_TYPE = 'bn' 12 | # cfg.RELU_TYPE = 'relu' 13 | 14 | # resnet 15 | cfg_resnet = copy.deepcopy(cfg) 16 | cfg_resnet.MODEL_NAME = 'resnet' 17 | cfg_resnet.RESNET = edict() 18 | cfg_resnet.RESNET.MODEL_TAGS = None 19 | cfg_resnet.RESNET.PRETRAINED_PTH = None 20 | cfg_resnet.RESNET.BN_TYPE = 'bn' 21 | cfg_resnet.RESNET.RELU_TYPE = 'relu' 22 | 23 | # deeplab 24 | cfg_deeplab = copy.deepcopy(cfg) 25 | cfg_deeplab.MODEL_NAME = 'deeplab' 26 | cfg_deeplab.DEEPLAB = edict() 27 | cfg_deeplab.DEEPLAB.MODEL_TAGS = None 28 | cfg_deeplab.DEEPLAB.PRETRAINED_PTH = None 29 | cfg_deeplab.DEEPLAB.FREEZE_BACKBONE_BN = False 30 | cfg_deeplab.DEEPLAB.BN_TYPE = 'bn' 31 | cfg_deeplab.DEEPLAB.RELU_TYPE = 'relu' 32 | # cfg_deeplab.DEEPLAB.ASPP_DROPOUT_TYPE = 'dropout|0.5' 33 | cfg_deeplab.DEEPLAB.ASPP_WITH_GAP = True 34 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT2_TYPE = 'dropout|0.5' 35 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT3_TYPE = 'dropout|0.1' 36 | cfg_deeplab.RESNET = cfg_resnet.RESNET 37 | 38 | # hrnet 39 | cfg_hrnet = copy.deepcopy(cfg) 40 | cfg_hrnet.MODEL_NAME = 'hrnet' 41 | cfg_hrnet.HRNET = edict() 42 | cfg_hrnet.HRNET.MODEL_TAGS = None 43 | cfg_hrnet.HRNET.PRETRAINED_PTH = None 44 | cfg_hrnet.HRNET.BN_TYPE = 'bn' 45 | cfg_hrnet.HRNET.RELU_TYPE = 'relu' 46 | 47 | # texrnet 48 | cfg_texrnet = copy.deepcopy(cfg) 49 | cfg_texrnet.MODEL_NAME = 'texrnet' 50 | cfg_texrnet.TEXRNET = edict() 51 | cfg_texrnet.TEXRNET.MODEL_TAGS = None 52 | cfg_texrnet.TEXRNET.PRETRAINED_PTH = None 53 | cfg_texrnet.RESNET = cfg_resnet.RESNET 54 | cfg_texrnet.DEEPLAB = cfg_deeplab.DEEPLAB 55 | -------------------------------------------------------------------------------- /hrnet_code/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | __pycache__/ 3 | *.py[co] 4 | data/ 5 | log/ 6 | output/ 7 | pretrained_models 8 | scripts/ 9 | detail-api/ 10 | data/list -------------------------------------------------------------------------------- /hrnet_code/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [2019] [Microsoft] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 
| copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | ======================================================================================= 24 | 3-clause BSD licenses 25 | ======================================================================================= 26 | 1. syncbn - For details, see lib/models/syncbn/LICENSE 27 | Copyright (c) 2017 mapillary 28 | -------------------------------------------------------------------------------- /hrnet_code/lib/config/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | from __future__ import absolute_import 7 | from __future__ import division 8 | from __future__ import print_function 9 | 10 | from .default import _C as config 11 | from .default import update_config 12 | from .models import MODEL_EXTRAS 13 | -------------------------------------------------------------------------------- /hrnet_code/lib/config/default.py: -------------------------------------------------------------------------------- 1 | 2 | # ------------------------------------------------------------------------------ 3 | # Copyright (c) Microsoft 4 | # Licensed under the MIT License. 
5 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | 14 | from yacs.config import CfgNode as CN 15 | 16 | 17 | _C = CN() 18 | 19 | _C.OUTPUT_DIR = '' 20 | _C.LOG_DIR = '' 21 | _C.GPUS = (0,) 22 | _C.WORKERS = 4 23 | _C.PRINT_FREQ = 20 24 | _C.AUTO_RESUME = False 25 | _C.PIN_MEMORY = True 26 | _C.RANK = 0 27 | 28 | # Cudnn related params 29 | _C.CUDNN = CN() 30 | _C.CUDNN.BENCHMARK = True 31 | _C.CUDNN.DETERMINISTIC = False 32 | _C.CUDNN.ENABLED = True 33 | 34 | # common params for NETWORK 35 | _C.MODEL = CN() 36 | _C.MODEL.NAME = 'seg_hrnet' 37 | _C.MODEL.PRETRAINED = '' 38 | _C.MODEL.ALIGN_CORNERS = True 39 | _C.MODEL.NUM_OUTPUTS = 1 40 | _C.MODEL.EXTRA = CN(new_allowed=True) 41 | 42 | 43 | _C.MODEL.OCR = CN() 44 | _C.MODEL.OCR.MID_CHANNELS = 512 45 | _C.MODEL.OCR.KEY_CHANNELS = 256 46 | _C.MODEL.OCR.DROPOUT = 0.05 47 | _C.MODEL.OCR.SCALE = 1 48 | 49 | _C.LOSS = CN() 50 | _C.LOSS.USE_OHEM = False 51 | _C.LOSS.OHEMTHRES = 0.9 52 | _C.LOSS.OHEMKEEP = 100000 53 | _C.LOSS.CLASS_BALANCE = False 54 | _C.LOSS.BALANCE_WEIGHTS = [1] 55 | 56 | # DATASET related params 57 | _C.DATASET = CN() 58 | _C.DATASET.ROOT = '' 59 | _C.DATASET.DATASET = 'cityscapes' 60 | _C.DATASET.NUM_CLASSES = 19 61 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst' 62 | _C.DATASET.EXTRA_TRAIN_SET = '' 63 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst' 64 | 65 | # training 66 | _C.TRAIN = CN() 67 | 68 | _C.TRAIN.FREEZE_LAYERS = '' 69 | _C.TRAIN.FREEZE_EPOCHS = -1 70 | _C.TRAIN.NONBACKBONE_KEYWORDS = [] 71 | _C.TRAIN.NONBACKBONE_MULT = 10 72 | 73 | _C.TRAIN.IMAGE_SIZE = [1024, 512] # width * height 74 | _C.TRAIN.BASE_SIZE = 2048 75 | _C.TRAIN.DOWNSAMPLERATE = 1 76 | _C.TRAIN.FLIP = True 77 | _C.TRAIN.MULTI_SCALE = True 78 | _C.TRAIN.SCALE_FACTOR = 16 79 | 80 | _C.TRAIN.RANDOM_BRIGHTNESS = False 81 | _C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 10 82 | 83 | _C.TRAIN.LR_FACTOR = 0.1 84 | _C.TRAIN.LR_STEP = [90, 110] 85 | _C.TRAIN.LR = 0.01 86 | _C.TRAIN.EXTRA_LR = 0.001 87 | 88 | _C.TRAIN.OPTIMIZER = 'sgd' 89 | _C.TRAIN.MOMENTUM = 0.9 90 | _C.TRAIN.WD = 0.0001 91 | _C.TRAIN.NESTEROV = False 92 | _C.TRAIN.IGNORE_LABEL = -1 93 | 94 | _C.TRAIN.BEGIN_EPOCH = 0 95 | _C.TRAIN.END_EPOCH = 484 96 | _C.TRAIN.EXTRA_EPOCH = 0 97 | 98 | _C.TRAIN.RESUME = False 99 | 100 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32 101 | _C.TRAIN.SHUFFLE = True 102 | # only using some training samples 103 | _C.TRAIN.NUM_SAMPLES = 0 104 | 105 | # testing 106 | _C.TEST = CN() 107 | 108 | _C.TEST.IMAGE_SIZE = [2048, 1024] # width * height 109 | _C.TEST.BASE_SIZE = 2048 110 | 111 | _C.TEST.BATCH_SIZE_PER_GPU = 32 112 | # only testing some samples 113 | _C.TEST.NUM_SAMPLES = 0 114 | 115 | _C.TEST.MODEL_FILE = '' 116 | _C.TEST.FLIP_TEST = False 117 | _C.TEST.MULTI_SCALE = False 118 | _C.TEST.SCALE_LIST = [1] 119 | 120 | _C.TEST.OUTPUT_INDEX = -1 121 | 122 | # debug 123 | _C.DEBUG = CN() 124 | _C.DEBUG.DEBUG = False 125 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False 126 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False 127 | _C.DEBUG.SAVE_HEATMAPS_GT = False 128 | _C.DEBUG.SAVE_HEATMAPS_PRED = False 129 | 130 | 131 | def update_config(cfg, args): 132 | cfg.defrost() 133 | 134 | cfg.merge_from_file(args.cfg) 135 | cfg.merge_from_list(args.opts) 136 | 137 | cfg.freeze() 138 | 139 | 140 | if __name__ == '__main__': 141 | import sys 142 | with open(sys.argv[1], 'w') 
as f: 143 | print(_C, file=f) 144 | 145 | -------------------------------------------------------------------------------- /hrnet_code/lib/config/models.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from yacs.config import CfgNode as CN 12 | 13 | # high_resoluton_net related params for segmentation 14 | HIGH_RESOLUTION_NET = CN() 15 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] 16 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64 17 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 18 | HIGH_RESOLUTION_NET.WITH_HEAD = True 19 | 20 | HIGH_RESOLUTION_NET.STAGE2 = CN() 21 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 22 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] 24 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] 25 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' 26 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' 27 | 28 | HIGH_RESOLUTION_NET.STAGE3 = CN() 29 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 30 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] 32 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] 33 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' 34 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' 35 | 36 | HIGH_RESOLUTION_NET.STAGE4 = CN() 37 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 38 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] 40 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] 41 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' 42 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' 43 | 44 | MODEL_EXTRAS = { 45 | 'seg_hrnet': HIGH_RESOLUTION_NET, 46 | } 47 | -------------------------------------------------------------------------------- /hrnet_code/lib/core/criterion.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import functional as F 10 | import logging 11 | from config import config 12 | 13 | 14 | class CrossEntropy(nn.Module): 15 | def __init__(self, ignore_label=-1, weight=None): 16 | super(CrossEntropy, self).__init__() 17 | self.ignore_label = ignore_label 18 | self.criterion = nn.CrossEntropyLoss( 19 | weight=weight, 20 | ignore_index=ignore_label 21 | ) 22 | 23 | def _forward(self, score, target): 24 | ph, pw = score.size(2), score.size(3) 25 | h, w = target.size(1), target.size(2) 26 | if ph != h or pw != w: 27 | score = F.interpolate(input=score, size=( 28 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 29 | 30 | loss = self.criterion(score, target) 31 | 32 | return loss 33 | 34 | def forward(self, score, target): 35 | 36 | if config.MODEL.NUM_OUTPUTS == 1: 37 | score = [score] 38 | 39 | weights = config.LOSS.BALANCE_WEIGHTS 40 | assert len(weights) == len(score) 41 | 42 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)]) 43 | 44 | 45 | class OhemCrossEntropy(nn.Module): 46 | def __init__(self, ignore_label=-1, thres=0.7, 47 | min_kept=100000, weight=None): 48 | super(OhemCrossEntropy, self).__init__() 49 | self.thresh = thres 50 | self.min_kept = max(1, min_kept) 51 | self.ignore_label = ignore_label 52 | self.criterion = nn.CrossEntropyLoss( 53 | weight=weight, 54 | ignore_index=ignore_label, 55 | reduction='none' 56 | ) 57 | 58 | def _ce_forward(self, score, target): 59 | ph, pw = score.size(2), score.size(3) 60 | h, w = target.size(1), target.size(2) 61 | if ph != h or pw != w: 62 | score = F.interpolate(input=score, size=( 63 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 64 | 65 | loss = self.criterion(score, target) 66 | 67 | return loss 68 | 69 | def _ohem_forward(self, score, target, **kwargs): 70 | ph, pw = score.size(2), score.size(3) 71 | h, w = target.size(1), target.size(2) 72 | if ph != h or pw != w: 73 | score = F.interpolate(input=score, size=( 74 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS) 75 | pred = F.softmax(score, dim=1) 76 | pixel_losses = self.criterion(score, target).contiguous().view(-1) 77 | mask = target.contiguous().view(-1) != self.ignore_label 78 | 79 | tmp_target = target.clone() 80 | tmp_target[tmp_target == self.ignore_label] = 0 81 | pred = pred.gather(1, tmp_target.unsqueeze(1)) 82 | pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() 83 | min_value = pred[min(self.min_kept, pred.numel() - 1)] 84 | threshold = max(min_value, self.thresh) 85 | 86 | pixel_losses = pixel_losses[mask][ind] 87 | pixel_losses = pixel_losses[pred < threshold] 88 | return pixel_losses.mean() 89 | 90 | def forward(self, score, target): 91 | 92 | if config.MODEL.NUM_OUTPUTS == 1: 93 | score = [score] 94 | 95 | weights = config.LOSS.BALANCE_WEIGHTS 96 | assert len(weights) == len(score) 97 | 98 | functions = [self._ce_forward] * \ 99 | (len(weights) - 1) + [self._ohem_forward] 100 | return sum([ 101 | w * func(x, target) 102 | for (w, x, func) in zip(weights, score, functions) 103 | ]) 104 | -------------------------------------------------------------------------------- /hrnet_code/lib/core/function.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # 
Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import logging 8 | import os 9 | import time 10 | 11 | import numpy as np 12 | import numpy.ma as ma 13 | from tqdm import tqdm 14 | 15 | import torch 16 | import torch.nn as nn 17 | from torch.nn import functional as F 18 | 19 | from utils.utils import AverageMeter 20 | from utils.utils import get_confusion_matrix 21 | from utils.utils import adjust_learning_rate 22 | 23 | import utils.distributed as dist 24 | 25 | 26 | def reduce_tensor(inp): 27 | """ 28 | Reduce the loss from all processes so that 29 | process with rank 0 has the averaged results. 30 | """ 31 | world_size = dist.get_world_size() 32 | if world_size < 2: 33 | return inp 34 | with torch.no_grad(): 35 | reduced_inp = inp 36 | torch.distributed.reduce(reduced_inp, dst=0) 37 | return reduced_inp / world_size 38 | 39 | 40 | def train(config, epoch, num_epoch, epoch_iters, base_lr, 41 | num_iters, trainloader, optimizer, model, writer_dict): 42 | # Training 43 | model.train() 44 | 45 | batch_time = AverageMeter() 46 | ave_loss = AverageMeter() 47 | tic = time.time() 48 | cur_iters = epoch*epoch_iters 49 | writer = writer_dict['writer'] 50 | global_steps = writer_dict['train_global_steps'] 51 | 52 | for i_iter, batch in enumerate(trainloader, 0): 53 | images, labels, _, _ = batch 54 | images = images.cuda() 55 | labels = labels.long().cuda() 56 | 57 | losses, _ = model(images, labels) 58 | loss = losses.mean() 59 | 60 | if dist.is_distributed(): 61 | reduced_loss = reduce_tensor(loss) 62 | else: 63 | reduced_loss = loss 64 | 65 | model.zero_grad() 66 | loss.backward() 67 | optimizer.step() 68 | 69 | # measure elapsed time 70 | batch_time.update(time.time() - tic) 71 | tic = time.time() 72 | 73 | # update average loss 74 | ave_loss.update(reduced_loss.item()) 75 | 76 | lr = adjust_learning_rate(optimizer, 77 | base_lr, 78 | num_iters, 79 | i_iter+cur_iters) 80 | 81 | if i_iter % config.PRINT_FREQ == 0 and dist.get_rank() == 0: 82 | msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \ 83 | 'lr: {}, Loss: {:.6f}' .format( 84 | epoch, num_epoch, i_iter, epoch_iters, 85 | batch_time.average(), [x['lr'] for x in optimizer.param_groups], ave_loss.average()) 86 | logging.info(msg) 87 | 88 | writer.add_scalar('train_loss', ave_loss.average(), global_steps) 89 | writer_dict['train_global_steps'] = global_steps + 1 90 | 91 | def validate(config, testloader, model, writer_dict): 92 | model.eval() 93 | ave_loss = AverageMeter() 94 | nums = config.MODEL.NUM_OUTPUTS 95 | confusion_matrix = np.zeros( 96 | (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES, nums)) 97 | with torch.no_grad(): 98 | for idx, batch in enumerate(testloader): 99 | image, label, _, _ = batch 100 | size = label.size() 101 | image = image.cuda() 102 | label = label.long().cuda() 103 | 104 | losses, pred = model(image, label) 105 | if not isinstance(pred, (list, tuple)): 106 | pred = [pred] 107 | for i, x in enumerate(pred): 108 | x = F.interpolate( 109 | input=x, size=size[-2:], 110 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS 111 | ) 112 | 113 | confusion_matrix[..., i] += get_confusion_matrix( 114 | label, 115 | x, 116 | size, 117 | config.DATASET.NUM_CLASSES, 118 | config.TRAIN.IGNORE_LABEL 119 | ) 120 | 121 | if idx % 10 == 0: 122 | print(idx) 123 | 124 | loss = losses.mean() 125 | if dist.is_distributed(): 126 | reduced_loss = reduce_tensor(loss) 
127 | else: 128 | reduced_loss = loss 129 | ave_loss.update(reduced_loss.item()) 130 | 131 | if dist.is_distributed(): 132 | confusion_matrix = torch.from_numpy(confusion_matrix).cuda() 133 | reduced_confusion_matrix = reduce_tensor(confusion_matrix) 134 | confusion_matrix = reduced_confusion_matrix.cpu().numpy() 135 | 136 | for i in range(nums): 137 | pos = confusion_matrix[..., i].sum(1) 138 | res = confusion_matrix[..., i].sum(0) 139 | tp = np.diag(confusion_matrix[..., i]) 140 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 141 | mean_IoU = IoU_array.mean() 142 | if dist.get_rank() <= 0: 143 | logging.info('{} {} {}'.format(i, IoU_array, mean_IoU)) 144 | 145 | writer = writer_dict['writer'] 146 | global_steps = writer_dict['valid_global_steps'] 147 | writer.add_scalar('valid_loss', ave_loss.average(), global_steps) 148 | writer.add_scalar('valid_mIoU', mean_IoU, global_steps) 149 | writer_dict['valid_global_steps'] = global_steps + 1 150 | return ave_loss.average(), mean_IoU, IoU_array 151 | 152 | 153 | def testval(config, test_dataset, testloader, model, 154 | sv_dir='', sv_pred=False): 155 | model.eval() 156 | confusion_matrix = np.zeros( 157 | (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES)) 158 | with torch.no_grad(): 159 | for index, batch in enumerate(tqdm(testloader)): 160 | image, label, _, name, *border_padding = batch 161 | size = label.size() 162 | pred = test_dataset.multi_scale_inference( 163 | config, 164 | model, 165 | image, 166 | scales=config.TEST.SCALE_LIST, 167 | flip=config.TEST.FLIP_TEST) 168 | 169 | if len(border_padding) > 0: 170 | border_padding = border_padding[0] 171 | pred = pred[:, :, 0:pred.size(2) - border_padding[0], 0:pred.size(3) - border_padding[1]] 172 | 173 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]: 174 | pred = F.interpolate( 175 | pred, size[-2:], 176 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS 177 | ) 178 | 179 | confusion_matrix += get_confusion_matrix( 180 | label, 181 | pred, 182 | size, 183 | config.DATASET.NUM_CLASSES, 184 | config.TRAIN.IGNORE_LABEL) 185 | 186 | if sv_pred: 187 | sv_path = os.path.join(sv_dir, 'test_results') 188 | if not os.path.exists(sv_path): 189 | os.mkdir(sv_path) 190 | test_dataset.save_pred(pred, sv_path, name) 191 | 192 | if index % 100 == 0: 193 | logging.info('processing: %d images' % index) 194 | pos = confusion_matrix.sum(1) 195 | res = confusion_matrix.sum(0) 196 | tp = np.diag(confusion_matrix) 197 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 198 | mean_IoU = IoU_array.mean() 199 | logging.info('mIoU: %.4f' % (mean_IoU)) 200 | 201 | pos = confusion_matrix.sum(1) 202 | res = confusion_matrix.sum(0) 203 | tp = np.diag(confusion_matrix) 204 | pixel_acc = tp.sum()/pos.sum() 205 | mean_acc = (tp/np.maximum(1.0, pos)).mean() 206 | IoU_array = (tp / np.maximum(1.0, pos + res - tp)) 207 | mean_IoU = IoU_array.mean() 208 | 209 | return mean_IoU, IoU_array, pixel_acc, mean_acc 210 | 211 | 212 | def test(config, test_dataset, testloader, model, 213 | sv_dir='', sv_pred=True): 214 | model.eval() 215 | with torch.no_grad(): 216 | for _, batch in enumerate(tqdm(testloader)): 217 | image, size, name = batch 218 | size = size[0] 219 | pred = test_dataset.multi_scale_inference( 220 | config, 221 | model, 222 | image, 223 | scales=config.TEST.SCALE_LIST, 224 | flip=config.TEST.FLIP_TEST) 225 | 226 | if pred.size()[-2] != size[0] or pred.size()[-1] != size[1]: 227 | pred = F.interpolate( 228 | pred, size[-2:], 229 | mode='bilinear', 
align_corners=config.MODEL.ALIGN_CORNERS 230 | ) 231 | 232 | if sv_pred: 233 | sv_path = os.path.join(sv_dir, 'test_results') 234 | if not os.path.exists(sv_path): 235 | os.mkdir(sv_path) 236 | test_dataset.save_pred(pred, sv_path, name) 237 | -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | from .cityscapes import Cityscapes as cityscapes 12 | from .lip import LIP as lip 13 | from .pascal_ctx import PASCALContext as pascal_ctx 14 | from .ade20k import ADE20K as ade20k 15 | from .cocostuff import COCOStuff as cocostuff -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/ade20k.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | from PIL import Image 15 | 16 | from .base_dataset import BaseDataset 17 | 18 | 19 | class ADE20K(BaseDataset): 20 | def __init__(self, 21 | root, 22 | list_path, 23 | num_samples=None, 24 | num_classes=150, 25 | multi_scale=True, 26 | flip=True, 27 | ignore_label=-1, 28 | base_size=520, 29 | crop_size=(520, 520), 30 | downsample_rate=1, 31 | scale_factor=11, 32 | mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225]): 34 | 35 | super(ADE20K, self).__init__(ignore_label, base_size, 36 | crop_size, downsample_rate, scale_factor, mean, std) 37 | 38 | self.root = root 39 | self.num_classes = num_classes 40 | self.list_path = list_path 41 | self.class_weights = None 42 | 43 | self.multi_scale = multi_scale 44 | self.flip = flip 45 | self.img_list = [line.strip().split() for line in open(root+list_path)] 46 | 47 | self.files = self.read_files() 48 | if num_samples: 49 | self.files = self.files[:num_samples] 50 | 51 | def read_files(self): 52 | files = [] 53 | for item in self.img_list: 54 | image_path, label_path = item 55 | name = os.path.splitext(os.path.basename(label_path))[0] 56 | sample = { 57 | 'img': image_path, 58 | 'label': label_path, 59 | 'name': name 60 | } 61 | files.append(sample) 62 | return files 63 | 64 | def resize_image(self, image, label, size): 65 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 66 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST) 67 | return image, label 68 | 69 | def __getitem__(self, index): 70 | item = self.files[index] 71 | name = item["name"] 72 | image_path = os.path.join(self.root, 'ade20k', item['img']) 73 | label_path = os.path.join(self.root, 'ade20k', item['label']) 74 | image = cv2.imread( 75 | image_path, 76 | cv2.IMREAD_COLOR 77 | ) 78 | label = np.array( 79 | Image.open(label_path).convert('P') 80 | ) 81 | label = 
self.reduce_zero_label(label) 82 | size = label.shape 83 | 84 | if 'testval' in self.list_path: 85 | image = self.resize_short_length( 86 | image, 87 | short_length=self.base_size, 88 | fit_stride=8 89 | ) 90 | image = self.input_transform(image) 91 | image = image.transpose((2, 0, 1)) 92 | 93 | return image.copy(), label.copy(), np.array(size), name 94 | 95 | if 'val' in self.list_path: 96 | image, label = self.resize_short_length( 97 | image, 98 | label=label, 99 | short_length=self.base_size, 100 | fit_stride=8 101 | ) 102 | image, label = self.rand_crop(image, label) 103 | image = self.input_transform(image) 104 | image = image.transpose((2, 0, 1)) 105 | 106 | return image.copy(), label.copy(), np.array(size), name 107 | 108 | image, label = self.resize_short_length(image, label, short_length=self.base_size) 109 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip) 110 | 111 | return image.copy(), label.copy(), np.array(size), name -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/cityscapes.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | from PIL import Image 12 | 13 | import torch 14 | from torch.nn import functional as F 15 | 16 | from .base_dataset import BaseDataset 17 | 18 | class Cityscapes(BaseDataset): 19 | def __init__(self, 20 | root, 21 | list_path, 22 | num_samples=None, 23 | num_classes=19, 24 | multi_scale=True, 25 | flip=True, 26 | ignore_label=-1, 27 | base_size=2048, 28 | crop_size=(512, 1024), 29 | downsample_rate=1, 30 | scale_factor=16, 31 | mean=[0.485, 0.456, 0.406], 32 | std=[0.229, 0.224, 0.225]): 33 | 34 | super(Cityscapes, self).__init__(ignore_label, base_size, 35 | crop_size, downsample_rate, scale_factor, mean, std,) 36 | 37 | self.root = root 38 | self.list_path = list_path 39 | self.num_classes = num_classes 40 | 41 | self.multi_scale = multi_scale 42 | self.flip = flip 43 | 44 | self.img_list = [line.strip().split() for line in open(root+list_path)] 45 | 46 | self.files = self.read_files() 47 | if num_samples: 48 | self.files = self.files[:num_samples] 49 | 50 | self.label_mapping = {-1: ignore_label, 0: ignore_label, 51 | 1: ignore_label, 2: ignore_label, 52 | 3: ignore_label, 4: ignore_label, 53 | 5: ignore_label, 6: ignore_label, 54 | 7: 0, 8: 1, 9: ignore_label, 55 | 10: ignore_label, 11: 2, 12: 3, 56 | 13: 4, 14: ignore_label, 15: ignore_label, 57 | 16: ignore_label, 17: 5, 18: ignore_label, 58 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11, 59 | 25: 12, 26: 13, 27: 14, 28: 15, 60 | 29: ignore_label, 30: ignore_label, 61 | 31: 16, 32: 17, 33: 18} 62 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345, 63 | 1.0166, 0.9969, 0.9754, 1.0489, 64 | 0.8786, 1.0023, 0.9539, 0.9843, 65 | 1.1116, 0.9037, 1.0865, 1.0955, 66 | 1.0865, 1.1529, 1.0507]).cuda() 67 | 68 | def read_files(self): 69 | files = [] 70 | if 'test' in self.list_path: 71 | for item in self.img_list: 72 | image_path = item 73 | name = os.path.splitext(os.path.basename(image_path[0]))[0] 74 | files.append({ 75 | "img": image_path[0], 76 | "name": name, 77 | }) 78 | else: 79 | for item in self.img_list: 80 | 
image_path, label_path = item 81 | name = os.path.splitext(os.path.basename(label_path))[0] 82 | files.append({ 83 | "img": image_path, 84 | "label": label_path, 85 | "name": name, 86 | "weight": 1 87 | }) 88 | return files 89 | 90 | def convert_label(self, label, inverse=False): 91 | temp = label.copy() 92 | if inverse: 93 | for v, k in self.label_mapping.items(): 94 | label[temp == k] = v 95 | else: 96 | for k, v in self.label_mapping.items(): 97 | label[temp == k] = v 98 | return label 99 | 100 | def __getitem__(self, index): 101 | item = self.files[index] 102 | name = item["name"] 103 | image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]), 104 | cv2.IMREAD_COLOR) 105 | size = image.shape 106 | 107 | if 'test' in self.list_path: 108 | image = self.input_transform(image) 109 | image = image.transpose((2, 0, 1)) 110 | 111 | return image.copy(), np.array(size), name 112 | 113 | label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]), 114 | cv2.IMREAD_GRAYSCALE) 115 | label = self.convert_label(label) 116 | 117 | image, label = self.gen_sample(image, label, 118 | self.multi_scale, self.flip) 119 | 120 | return image.copy(), label.copy(), np.array(size), name 121 | 122 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False): 123 | batch, _, ori_height, ori_width = image.size() 124 | assert batch == 1, "only supporting batchsize 1." 125 | image = image.numpy()[0].transpose((1,2,0)).copy() 126 | stride_h = np.int(self.crop_size[0] * 1.0) 127 | stride_w = np.int(self.crop_size[1] * 1.0) 128 | final_pred = torch.zeros([1, self.num_classes, 129 | ori_height,ori_width]).cuda() 130 | for scale in scales: 131 | new_img = self.multi_scale_aug(image=image, 132 | rand_scale=scale, 133 | rand_crop=False) 134 | height, width = new_img.shape[:-1] 135 | 136 | if scale <= 1.0: 137 | new_img = new_img.transpose((2, 0, 1)) 138 | new_img = np.expand_dims(new_img, axis=0) 139 | new_img = torch.from_numpy(new_img) 140 | preds = self.inference(config, model, new_img, flip) 141 | preds = preds[:, :, 0:height, 0:width] 142 | else: 143 | new_h, new_w = new_img.shape[:-1] 144 | rows = np.int(np.ceil(1.0 * (new_h - 145 | self.crop_size[0]) / stride_h)) + 1 146 | cols = np.int(np.ceil(1.0 * (new_w - 147 | self.crop_size[1]) / stride_w)) + 1 148 | preds = torch.zeros([1, self.num_classes, 149 | new_h,new_w]).cuda() 150 | count = torch.zeros([1,1, new_h, new_w]).cuda() 151 | 152 | for r in range(rows): 153 | for c in range(cols): 154 | h0 = r * stride_h 155 | w0 = c * stride_w 156 | h1 = min(h0 + self.crop_size[0], new_h) 157 | w1 = min(w0 + self.crop_size[1], new_w) 158 | h0 = max(int(h1 - self.crop_size[0]), 0) 159 | w0 = max(int(w1 - self.crop_size[1]), 0) 160 | crop_img = new_img[h0:h1, w0:w1, :] 161 | crop_img = crop_img.transpose((2, 0, 1)) 162 | crop_img = np.expand_dims(crop_img, axis=0) 163 | crop_img = torch.from_numpy(crop_img) 164 | pred = self.inference(config, model, crop_img, flip) 165 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0] 166 | count[:,:,h0:h1,w0:w1] += 1 167 | preds = preds / count 168 | preds = preds[:,:,:height,:width] 169 | 170 | preds = F.interpolate( 171 | preds, (ori_height, ori_width), 172 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS 173 | ) 174 | final_pred += preds 175 | return final_pred 176 | 177 | def get_palette(self, n): 178 | palette = [0] * (n * 3) 179 | for j in range(0, n): 180 | lab = j 181 | palette[j * 3 + 0] = 0 182 | palette[j * 3 + 1] = 0 183 | palette[j * 3 + 2] = 0 184 | i = 0 185 | while lab: 
186 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) 187 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) 188 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) 189 | i += 1 190 | lab >>= 3 191 | return palette 192 | 193 | def save_pred(self, preds, sv_path, name): 194 | palette = self.get_palette(256) 195 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8) 196 | for i in range(preds.shape[0]): 197 | pred = self.convert_label(preds[i], inverse=True) 198 | save_img = Image.fromarray(pred) 199 | save_img.putpalette(palette) 200 | save_img.save(os.path.join(sv_path, name[i]+'.png')) 201 | 202 | 203 | 204 | -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/cocostuff.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | from PIL import Image 15 | 16 | from .base_dataset import BaseDataset 17 | 18 | 19 | class COCOStuff(BaseDataset): 20 | def __init__(self, 21 | root, 22 | list_path, 23 | num_samples=None, 24 | num_classes=171, 25 | multi_scale=True, 26 | flip=True, 27 | ignore_label=-1, 28 | base_size=520, 29 | crop_size=(520, 520), 30 | downsample_rate=1, 31 | scale_factor=11, 32 | mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225]): 34 | 35 | super(COCOStuff, self).__init__(ignore_label, base_size, 36 | crop_size, downsample_rate, scale_factor, mean, std) 37 | 38 | self.root = root 39 | self.num_classes = num_classes 40 | self.list_path = list_path 41 | self.class_weights = None 42 | 43 | self.multi_scale = multi_scale 44 | self.flip = flip 45 | self.img_list = [line.strip().split() for line in open(root+list_path)] 46 | 47 | self.files = self.read_files() 48 | if num_samples: 49 | self.files = self.files[:num_samples] 50 | self.mapping = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 51 | 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 52 | 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 53 | 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 54 | 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95, 96, 55 | 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 56 | 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 57 | 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 58 | 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 59 | 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 60 | 177, 178, 179, 180, 181, 182] 61 | 62 | def read_files(self): 63 | files = [] 64 | for item in self.img_list: 65 | image_path, label_path = item 66 | name = os.path.splitext(os.path.basename(label_path))[0] 67 | sample = { 68 | 'img': image_path, 69 | 'label': label_path, 70 | 'name': name 71 | } 72 | files.append(sample) 73 | return files 74 | 75 | def encode_label(self, labelmap): 76 | ret = np.ones_like(labelmap) * 255 77 | for idx, label in enumerate(self.mapping): 78 | ret[labelmap == label] = idx 79 | 80 | return ret 81 | 82 | def resize_image(self, 
image, label, size): 83 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 84 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST) 85 | return image, label 86 | 87 | def __getitem__(self, index): 88 | item = self.files[index] 89 | name = item["name"] 90 | image_path = os.path.join(self.root, 'cocostuff', item['img']) 91 | label_path = os.path.join(self.root, 'cocostuff', item['label']) 92 | image = cv2.imread( 93 | image_path, 94 | cv2.IMREAD_COLOR 95 | ) 96 | label = np.array( 97 | Image.open(label_path).convert('P') 98 | ) 99 | label = self.encode_label(label) 100 | label = self.reduce_zero_label(label) 101 | size = label.shape 102 | 103 | if 'testval' in self.list_path: 104 | image, border_padding = self.resize_short_length( 105 | image, 106 | short_length=self.base_size, 107 | fit_stride=8, 108 | return_padding=True 109 | ) 110 | image = self.input_transform(image) 111 | image = image.transpose((2, 0, 1)) 112 | 113 | return image.copy(), label.copy(), np.array(size), name, border_padding 114 | 115 | if 'val' in self.list_path: 116 | image, label = self.resize_short_length( 117 | image, 118 | label=label, 119 | short_length=self.base_size, 120 | fit_stride=8 121 | ) 122 | image, label = self.rand_crop(image, label) 123 | image = self.input_transform(image) 124 | image = image.transpose((2, 0, 1)) 125 | 126 | return image.copy(), label.copy(), np.array(size), name 127 | 128 | image, label = self.resize_short_length(image, label, short_length=self.base_size) 129 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip) 130 | 131 | return image.copy(), label.copy(), np.array(size), name -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/lip.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import os 8 | 9 | import cv2 10 | import numpy as np 11 | 12 | import torch 13 | from torch.nn import functional as F 14 | 15 | from .base_dataset import BaseDataset 16 | 17 | 18 | class LIP(BaseDataset): 19 | def __init__(self, 20 | root, 21 | list_path, 22 | num_samples=None, 23 | num_classes=20, 24 | multi_scale=True, 25 | flip=True, 26 | ignore_label=-1, 27 | base_size=473, 28 | crop_size=(473, 473), 29 | downsample_rate=1, 30 | scale_factor=11, 31 | mean=[0.485, 0.456, 0.406], 32 | std=[0.229, 0.224, 0.225]): 33 | 34 | super(LIP, self).__init__(ignore_label, base_size, 35 | crop_size, downsample_rate, scale_factor, mean, std) 36 | 37 | self.root = root 38 | self.num_classes = num_classes 39 | self.list_path = list_path 40 | self.class_weights = None 41 | 42 | self.multi_scale = multi_scale 43 | self.flip = flip 44 | self.img_list = [line.strip().split() for line in open(root+list_path)] 45 | 46 | self.files = self.read_files() 47 | if num_samples: 48 | self.files = self.files[:num_samples] 49 | 50 | def read_files(self): 51 | files = [] 52 | for item in self.img_list: 53 | if 'train' in self.list_path: 54 | image_path, label_path, label_rev_path, _ = item 55 | name = os.path.splitext(os.path.basename(label_path))[0] 56 | sample = {"img": image_path, 57 | "label": label_path, 58 | "label_rev": label_rev_path, 59 | "name": name, } 60 | elif 'val' in self.list_path: 61 | image_path, label_path = item 62 | name = os.path.splitext(os.path.basename(label_path))[0] 63 | sample = {"img": image_path, 64 | "label": label_path, 65 | "name": name, } 66 | else: 67 | raise NotImplementedError('Unknown subset.') 68 | files.append(sample) 69 | return files 70 | 71 | def resize_image(self, image, label, size): 72 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR) 73 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST) 74 | return image, label 75 | 76 | def __getitem__(self, index): 77 | item = self.files[index] 78 | name = item["name"] 79 | item["img"] = item["img"].replace( 80 | "train_images", "LIP_Train").replace("val_images", "LIP_Val") 81 | item["label"] = item["label"].replace( 82 | "train_segmentations", "LIP_Train").replace("val_segmentations", "LIP_Val") 83 | image = cv2.imread(os.path.join( 84 | self.root, 'lip/TrainVal_images/', item["img"]), 85 | cv2.IMREAD_COLOR) 86 | label = cv2.imread(os.path.join( 87 | self.root, 'lip/TrainVal_parsing_annotations/', 88 | item["label"]), 89 | cv2.IMREAD_GRAYSCALE) 90 | size = label.shape 91 | 92 | if 'testval' in self.list_path: 93 | image = cv2.resize(image, self.crop_size, 94 | interpolation=cv2.INTER_LINEAR) 95 | image = self.input_transform(image) 96 | image = image.transpose((2, 0, 1)) 97 | 98 | return image.copy(), label.copy(), np.array(size), name 99 | 100 | if self.flip: 101 | flip = np.random.choice(2) * 2 - 1 102 | image = image[:, ::flip, :] 103 | label = label[:, ::flip] 104 | 105 | if flip == -1: 106 | right_idx = [15, 17, 19] 107 | left_idx = [14, 16, 18] 108 | for i in range(0, 3): 109 | right_pos = np.where(label == right_idx[i]) 110 | left_pos = np.where(label == left_idx[i]) 111 | label[right_pos[0], right_pos[1]] = left_idx[i] 112 | label[left_pos[0], left_pos[1]] = right_idx[i] 113 | 114 | image, label = self.resize_image(image, label, self.crop_size) 115 | image, label = self.gen_sample(image, label, 116 | self.multi_scale, False) 117 | 118 | return image.copy(), 
label.copy(), np.array(size), name 119 | 120 | def inference(self, config, model, image, flip): 121 | size = image.size() 122 | pred = model(image) 123 | if config.MODEL.NUM_OUTPUTS > 1: 124 | pred = pred[config.TEST.OUTPUT_INDEX] 125 | 126 | pred = F.interpolate( 127 | input=pred, size=size[-2:], 128 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS 129 | ) 130 | 131 | if flip: 132 | flip_img = image.numpy()[:, :, :, ::-1] 133 | flip_output = model(torch.from_numpy(flip_img.copy())) 134 | 135 | if config.MODEL.NUM_OUTPUTS > 1: 136 | flip_output = flip_output[config.TEST.OUTPUT_INDEX] 137 | 138 | flip_output = F.interpolate( 139 | input=flip_output, size=size[-2:], 140 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS 141 | ) 142 | 143 | flip_output = flip_output.cpu() 144 | flip_pred = flip_output.cpu().numpy().copy() 145 | flip_pred[:, 14, :, :] = flip_output[:, 15, :, :] 146 | flip_pred[:, 15, :, :] = flip_output[:, 14, :, :] 147 | flip_pred[:, 16, :, :] = flip_output[:, 17, :, :] 148 | flip_pred[:, 17, :, :] = flip_output[:, 16, :, :] 149 | flip_pred[:, 18, :, :] = flip_output[:, 19, :, :] 150 | flip_pred[:, 19, :, :] = flip_output[:, 18, :, :] 151 | flip_pred = torch.from_numpy( 152 | flip_pred[:, :, :, ::-1].copy()).cuda() 153 | pred += flip_pred 154 | pred = pred * 0.5 155 | return pred.exp() 156 | -------------------------------------------------------------------------------- /hrnet_code/lib/datasets/pascal_ctx.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # Referring to the implementation in 6 | # https://github.com/zhanghang1989/PyTorch-Encoding 7 | # ------------------------------------------------------------------------------ 8 | 9 | import os 10 | 11 | import cv2 12 | import numpy as np 13 | from PIL import Image 14 | 15 | import torch 16 | 17 | from .base_dataset import BaseDataset 18 | 19 | class PASCALContext(BaseDataset): 20 | def __init__(self, 21 | root, 22 | list_path, 23 | num_samples=None, 24 | num_classes=59, 25 | multi_scale=True, 26 | flip=True, 27 | ignore_label=-1, 28 | base_size=520, 29 | crop_size=(480, 480), 30 | downsample_rate=1, 31 | scale_factor=16, 32 | mean=[0.485, 0.456, 0.406], 33 | std=[0.229, 0.224, 0.225],): 34 | 35 | super(PASCALContext, self).__init__(ignore_label, base_size, 36 | crop_size, downsample_rate, scale_factor, mean, std) 37 | 38 | self.root = os.path.join(root, 'pascal_ctx/VOCdevkit/VOC2010') 39 | self.split = list_path 40 | 41 | self.num_classes = num_classes 42 | self.class_weights = None 43 | 44 | self.multi_scale = multi_scale 45 | self.flip = flip 46 | self.crop_size = crop_size 47 | 48 | # prepare data 49 | annots = os.path.join(self.root, 'trainval_merged.json') 50 | img_path = os.path.join(self.root, 'JPEGImages') 51 | from detail import Detail 52 | if 'val' in self.split: 53 | self.detail = Detail(annots, img_path, 'val') 54 | mask_file = os.path.join(self.root, 'val.pth') 55 | elif 'train' in self.split: 56 | self.mode = 'train' 57 | self.detail = Detail(annots, img_path, 'train') 58 | mask_file = os.path.join(self.root, 'train.pth') 59 | else: 60 | raise NotImplementedError('only supporting train and val set.') 61 | self.files = self.detail.getImgs() 62 | 63 | # generate masks 64 | self._mapping = np.sort(np.array([ 65 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 
19, 22, 66 | 23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296, 67 | 427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424, 68 | 68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360, 69 | 98, 187, 104, 105, 366, 189, 368, 113, 115])) 70 | 71 | self._key = np.array(range(len(self._mapping))).astype('uint8') 72 | 73 | print('mask_file:', mask_file) 74 | if os.path.exists(mask_file): 75 | self.masks = torch.load(mask_file) 76 | else: 77 | self.masks = self._preprocess(mask_file) 78 | 79 | def _class_to_index(self, mask): 80 | # assert the values 81 | values = np.unique(mask) 82 | for i in range(len(values)): 83 | assert(values[i] in self._mapping) 84 | index = np.digitize(mask.ravel(), self._mapping, right=True) 85 | return self._key[index].reshape(mask.shape) 86 | 87 | def _preprocess(self, mask_file): 88 | masks = {} 89 | print("Preprocessing mask, this will take a while." + \ 90 | "But don't worry, it only run once for each split.") 91 | for i in range(len(self.files)): 92 | img_id = self.files[i] 93 | mask = Image.fromarray(self._class_to_index( 94 | self.detail.getMask(img_id))) 95 | masks[img_id['image_id']] = mask 96 | torch.save(masks, mask_file) 97 | return masks 98 | 99 | def __getitem__(self, index): 100 | item = self.files[index] 101 | name = item['file_name'] 102 | img_id = item['image_id'] 103 | 104 | image = cv2.imread(os.path.join(self.detail.img_folder,name), 105 | cv2.IMREAD_COLOR) 106 | label = np.asarray(self.masks[img_id],dtype=np.int) 107 | size = image.shape 108 | 109 | if self.split == 'val': 110 | image = cv2.resize(image, self.crop_size, 111 | interpolation = cv2.INTER_LINEAR) 112 | image = self.input_transform(image) 113 | image = image.transpose((2, 0, 1)) 114 | 115 | label = cv2.resize(label, self.crop_size, 116 | interpolation=cv2.INTER_NEAREST) 117 | label = self.label_transform(label) 118 | elif self.split == 'testval': 119 | # evaluate model on val dataset 120 | image = self.input_transform(image) 121 | image = image.transpose((2, 0, 1)) 122 | label = self.label_transform(label) 123 | else: 124 | image, label = self.gen_sample(image, label, 125 | self.multi_scale, self.flip) 126 | 127 | return image.copy(), label.copy(), np.array(size), name 128 | 129 | def label_transform(self, label): 130 | if self.num_classes == 59: 131 | # background is ignored 132 | label = np.array(label).astype('int32') - 1 133 | label[label==-2] = -1 134 | else: 135 | label = np.array(label).astype('int32') 136 | return label 137 | -------------------------------------------------------------------------------- /hrnet_code/lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import models.seg_hrnet 12 | import models.seg_hrnet_ocr -------------------------------------------------------------------------------- /hrnet_code/lib/models/bn_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | 4 | if torch.__version__.startswith('0'): 5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync 6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none') 7 | BatchNorm2d_class = InPlaceABNSync 8 | relu_inplace = False 9 | else: 10 | # BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm 11 | BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d 12 | relu_inplace = True -------------------------------------------------------------------------------- /hrnet_code/lib/models/sync_bn/LICENSE: -------------------------------------------------------------------------------- 1 | 2 | BSD 3-Clause License 3 | 4 | Copyright (c) 2017, mapillary 5 | All rights reserved. 6 | 7 | Redistribution and use in source and binary forms, with or without 8 | modification, are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from 19 | this software without specific prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
31 | -------------------------------------------------------------------------------- /hrnet_code/lib/models/sync_bn/__init__.py: -------------------------------------------------------------------------------- 1 | from .inplace_abn import bn -------------------------------------------------------------------------------- /hrnet_code/lib/models/sync_bn/inplace_abn/__init__.py: -------------------------------------------------------------------------------- 1 | from .bn import ABN, InPlaceABN, InPlaceABNSync 2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE 3 | -------------------------------------------------------------------------------- /hrnet_code/lib/models/sync_bn/inplace_abn/bn.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as functional 5 | 6 | try: 7 | from queue import Queue 8 | except ImportError: 9 | from Queue import Queue 10 | 11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.path.append(BASE_DIR) 13 | sys.path.append(os.path.join(BASE_DIR, '../src')) 14 | from functions import * 15 | 16 | 17 | class ABN(nn.Module): 18 | """Activated Batch Normalization 19 | 20 | This gathers a `BatchNorm2d` and an activation function in a single module 21 | """ 22 | 23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 24 | """Creates an Activated Batch Normalization module 25 | 26 | Parameters 27 | ---------- 28 | num_features : int 29 | Number of feature channels in the input and output. 30 | eps : float 31 | Small constant to prevent numerical issues. 32 | momentum : float 33 | Momentum factor applied to compute running statistics as. 34 | affine : bool 35 | If `True` apply learned scale and shift transformation after normalization. 36 | activation : str 37 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 38 | slope : float 39 | Negative slope for the `leaky_relu` activation. 
40 | """ 41 | super(ABN, self).__init__() 42 | self.num_features = num_features 43 | self.affine = affine 44 | self.eps = eps 45 | self.momentum = momentum 46 | self.activation = activation 47 | self.slope = slope 48 | if self.affine: 49 | self.weight = nn.Parameter(torch.ones(num_features)) 50 | self.bias = nn.Parameter(torch.zeros(num_features)) 51 | else: 52 | self.register_parameter('weight', None) 53 | self.register_parameter('bias', None) 54 | self.register_buffer('running_mean', torch.zeros(num_features)) 55 | self.register_buffer('running_var', torch.ones(num_features)) 56 | self.reset_parameters() 57 | 58 | def reset_parameters(self): 59 | nn.init.constant_(self.running_mean, 0) 60 | nn.init.constant_(self.running_var, 1) 61 | if self.affine: 62 | nn.init.constant_(self.weight, 1) 63 | nn.init.constant_(self.bias, 0) 64 | 65 | def forward(self, x): 66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, 67 | self.training, self.momentum, self.eps) 68 | 69 | if self.activation == ACT_RELU: 70 | return functional.relu(x, inplace=True) 71 | elif self.activation == ACT_LEAKY_RELU: 72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True) 73 | elif self.activation == ACT_ELU: 74 | return functional.elu(x, inplace=True) 75 | else: 76 | return x 77 | 78 | def __repr__(self): 79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 80 | ' affine={affine}, activation={activation}' 81 | if self.activation == "leaky_relu": 82 | rep += ', slope={slope})' 83 | else: 84 | rep += ')' 85 | return rep.format(name=self.__class__.__name__, **self.__dict__) 86 | 87 | 88 | class InPlaceABN(ABN): 89 | """InPlace Activated Batch Normalization""" 90 | 91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01): 92 | """Creates an InPlace Activated Batch Normalization module 93 | 94 | Parameters 95 | ---------- 96 | num_features : int 97 | Number of feature channels in the input and output. 98 | eps : float 99 | Small constant to prevent numerical issues. 100 | momentum : float 101 | Momentum factor applied to compute running statistics as. 102 | affine : bool 103 | If `True` apply learned scale and shift transformation after normalization. 104 | activation : str 105 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 106 | slope : float 107 | Negative slope for the `leaky_relu` activation. 108 | """ 109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope) 110 | 111 | def forward(self, x): 112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var, 113 | self.training, self.momentum, self.eps, self.activation, self.slope) 114 | 115 | 116 | class InPlaceABNSync(ABN): 117 | """InPlace Activated Batch Normalization with cross-GPU synchronization 118 | 119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`. 120 | """ 121 | 122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", 123 | slope=0.01): 124 | """Creates a synchronized, InPlace Activated Batch Normalization module 125 | 126 | Parameters 127 | ---------- 128 | num_features : int 129 | Number of feature channels in the input and output. 130 | devices : list of int or None 131 | IDs of the GPUs that will run the replicas of this module. 132 | eps : float 133 | Small constant to prevent numerical issues. 
134 | momentum : float 135 | Momentum factor applied to compute running statistics as. 136 | affine : bool 137 | If `True` apply learned scale and shift transformation after normalization. 138 | activation : str 139 | Name of the activation functions, one of: `leaky_relu`, `elu` or `none`. 140 | slope : float 141 | Negative slope for the `leaky_relu` activation. 142 | """ 143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope) 144 | self.devices = devices if devices else list(range(torch.cuda.device_count())) 145 | 146 | # Initialize queues 147 | self.worker_ids = self.devices[1:] 148 | self.master_queue = Queue(len(self.worker_ids)) 149 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 150 | 151 | def forward(self, x): 152 | if x.get_device() == self.devices[0]: 153 | # Master mode 154 | extra = { 155 | "is_master": True, 156 | "master_queue": self.master_queue, 157 | "worker_queues": self.worker_queues, 158 | "worker_ids": self.worker_ids 159 | } 160 | else: 161 | # Worker mode 162 | extra = { 163 | "is_master": False, 164 | "master_queue": self.master_queue, 165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 166 | } 167 | 168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var, 169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope) 170 | 171 | def __repr__(self): 172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 173 | ' affine={affine}, devices={devices}, activation={activation}' 174 | if self.activation == "leaky_relu": 175 | rep += ', slope={slope})' 176 | else: 177 | rep += ')' 178 | return rep.format(name=self.__class__.__name__, **self.__dict__) 179 | -------------------------------------------------------------------------------- /hrnet_code/lib/models/sync_bn/inplace_abn/functions.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import torch.autograd as autograd 4 | import torch.cuda.comm as comm 5 | from torch.autograd.function import once_differentiable 6 | from torch.utils.cpp_extension import load 7 | 8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src") 9 | _backend = load(name="inplace_abn", 10 | extra_cflags=["-O3"], 11 | sources=[path.join(_src_path, f) for f in [ 12 | "inplace_abn.cpp", 13 | "inplace_abn_cpu.cpp", 14 | "inplace_abn_cuda.cu" 15 | ]], 16 | extra_cuda_cflags=["--expt-extended-lambda"]) 17 | 18 | # Activation names 19 | ACT_RELU = "relu" 20 | ACT_LEAKY_RELU = "leaky_relu" 21 | ACT_ELU = "elu" 22 | ACT_NONE = "none" 23 | 24 | 25 | def _check(fn, *args, **kwargs): 26 | success = fn(*args, **kwargs) 27 | if not success: 28 | raise RuntimeError("CUDA Error encountered in {}".format(fn)) 29 | 30 | 31 | def _broadcast_shape(x): 32 | out_size = [] 33 | for i, s in enumerate(x.size()): 34 | if i != 1: 35 | out_size.append(1) 36 | else: 37 | out_size.append(s) 38 | return out_size 39 | 40 | 41 | def _reduce(x): 42 | if len(x.size()) == 2: 43 | return x.sum(dim=0) 44 | else: 45 | n, c = x.size()[0:2] 46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0) 47 | 48 | 49 | def _count_samples(x): 50 | count = 1 51 | for i, s in enumerate(x.size()): 52 | if i != 1: 53 | count *= s 54 | return count 55 | 56 | 57 | def _act_forward(ctx, x): 58 | if ctx.activation == ACT_LEAKY_RELU: 59 | _backend.leaky_relu_forward(x, ctx.slope) 60 | elif ctx.activation == ACT_ELU: 61 | _backend.elu_forward(x) 62 | elif ctx.activation == 
ACT_NONE: 63 | pass 64 | 65 | 66 | def _act_backward(ctx, x, dx): 67 | if ctx.activation == ACT_LEAKY_RELU: 68 | _backend.leaky_relu_backward(x, dx, ctx.slope) 69 | elif ctx.activation == ACT_ELU: 70 | _backend.elu_backward(x, dx) 71 | elif ctx.activation == ACT_NONE: 72 | pass 73 | 74 | 75 | class InPlaceABN(autograd.Function): 76 | @staticmethod 77 | def forward(ctx, x, weight, bias, running_mean, running_var, 78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 79 | # Save context 80 | ctx.training = training 81 | ctx.momentum = momentum 82 | ctx.eps = eps 83 | ctx.activation = activation 84 | ctx.slope = slope 85 | ctx.affine = weight is not None and bias is not None 86 | 87 | # Prepare inputs 88 | count = _count_samples(x) 89 | x = x.contiguous() 90 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 91 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 92 | 93 | if ctx.training: 94 | mean, var = _backend.mean_var(x) 95 | 96 | # Update running stats 97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 98 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1)) 99 | 100 | # Mark in-place modified tensors 101 | ctx.mark_dirty(x, running_mean, running_var) 102 | else: 103 | mean, var = running_mean.contiguous(), running_var.contiguous() 104 | ctx.mark_dirty(x) 105 | 106 | # BN forward + activation 107 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps) 108 | _act_forward(ctx, x) 109 | 110 | # Output 111 | ctx.var = var 112 | ctx.save_for_backward(x, var, weight, bias) 113 | return x 114 | 115 | @staticmethod 116 | @once_differentiable 117 | def backward(ctx, dz): 118 | z, var, weight, bias = ctx.saved_tensors 119 | dz = dz.contiguous() 120 | 121 | # Undo activation 122 | _act_backward(ctx, z, dz) 123 | 124 | if ctx.training: 125 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps) 126 | else: 127 | # TODO: implement simplified CUDA backward for inference mode 128 | edz = dz.new_zeros(dz.size(1)) 129 | eydz = dz.new_zeros(dz.size(1)) 130 | 131 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps) 132 | dweight = dweight if ctx.affine else None 133 | dbias = dbias if ctx.affine else None 134 | 135 | return dx, dweight, dbias, None, None, None, None, None, None, None 136 | 137 | 138 | class InPlaceABNSync(autograd.Function): 139 | @classmethod 140 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 141 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01): 142 | # Save context 143 | cls._parse_extra(ctx, extra) 144 | ctx.training = training 145 | ctx.momentum = momentum 146 | ctx.eps = eps 147 | ctx.activation = activation 148 | ctx.slope = slope 149 | ctx.affine = weight is not None and bias is not None 150 | 151 | # Prepare inputs 152 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1) 153 | x = x.contiguous() 154 | weight = weight.contiguous() if ctx.affine else x.new_empty(0) 155 | bias = bias.contiguous() if ctx.affine else x.new_empty(0) 156 | 157 | if ctx.training: 158 | mean, var = _backend.mean_var(x) 159 | 160 | if ctx.is_master: 161 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)] 162 | for _ in range(ctx.master_queue.maxsize): 163 | mean_w, var_w = ctx.master_queue.get() 164 | ctx.master_queue.task_done() 165 | means.append(mean_w.unsqueeze(0)) 166 | vars.append(var_w.unsqueeze(0)) 167 | 168 | means = comm.gather(means) 169 | vars = 
comm.gather(vars)
170 | 
171 |                 mean = means.mean(0)
172 |                 var = (vars + (mean - means) ** 2).mean(0)
173 | 
174 |                 tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
175 |                 for ts, queue in zip(tensors[1:], ctx.worker_queues):
176 |                     queue.put(ts)
177 |             else:
178 |                 ctx.master_queue.put((mean, var))
179 |                 mean, var = ctx.worker_queue.get()
180 |                 ctx.worker_queue.task_done()
181 | 
182 |             # Update running stats
183 |             running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
184 |             running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
185 | 
186 |             # Mark in-place modified tensors
187 |             ctx.mark_dirty(x, running_mean, running_var)
188 |         else:
189 |             mean, var = running_mean.contiguous(), running_var.contiguous()
190 |             ctx.mark_dirty(x)
191 | 
192 |         # BN forward + activation
193 |         _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194 |         _act_forward(ctx, x)
195 | 
196 |         # Output
197 |         ctx.var = var
198 |         ctx.save_for_backward(x, var, weight, bias)
199 |         return x
200 | 
201 |     @staticmethod
202 |     @once_differentiable
203 |     def backward(ctx, dz):
204 |         z, var, weight, bias = ctx.saved_tensors
205 |         dz = dz.contiguous()
206 | 
207 |         # Undo activation
208 |         _act_backward(ctx, z, dz)
209 | 
210 |         if ctx.training:
211 |             edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212 | 
213 |             if ctx.is_master:
214 |                 edzs, eydzs = [edz], [eydz]
215 |                 for _ in range(len(ctx.worker_queues)):
216 |                     edz_w, eydz_w = ctx.master_queue.get()
217 |                     ctx.master_queue.task_done()
218 |                     edzs.append(edz_w)
219 |                     eydzs.append(eydz_w)
220 | 
221 |                 edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
222 |                 eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)
223 | 
224 |                 tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
225 |                 for ts, queue in zip(tensors[1:], ctx.worker_queues):
226 |                     queue.put(ts)
227 |             else:
228 |                 ctx.master_queue.put((edz, eydz))
229 |                 edz, eydz = ctx.worker_queue.get()
230 |                 ctx.worker_queue.task_done()
231 |         else:
232 |             edz = dz.new_zeros(dz.size(1))
233 |             eydz = dz.new_zeros(dz.size(1))
234 | 
235 |         dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
236 |         dweight = dweight if ctx.affine else None
237 |         dbias = dbias if ctx.affine else None
238 | 
239 |         return dx, dweight, dbias, None, None, None, None, None, None, None, None
240 | 
241 |     @staticmethod
242 |     def _parse_extra(ctx, extra):
243 |         ctx.is_master = extra["is_master"]
244 |         if ctx.is_master:
245 |             ctx.master_queue = extra["master_queue"]
246 |             ctx.worker_queues = extra["worker_queues"]
247 |             ctx.worker_ids = extra["worker_ids"]
248 |         else:
249 |             ctx.master_queue = extra["master_queue"]
250 |             ctx.worker_queue = extra["worker_queue"]
251 | 
252 | 
253 | inplace_abn = InPlaceABN.apply
254 | inplace_abn_sync = InPlaceABNSync.apply
255 | 
256 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
257 | 
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | /*
6 |  * General settings
7 |  */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 | 
11 | template<typename T>
12 | struct Pair {
13 |   T v1, v2;
14 |   __device__ Pair() {}
15 |   __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 |   __device__ Pair(T v) : v1(v), v2(v) {}
17 |   __device__ Pair(int v) : v1(v), v2(v) {}
18 |   __device__ Pair<T> &operator+=(const Pair<T> &a) {
19 |     v1 += a.v1;
20 |     v2 += a.v2;
21 |     return *this;
22 |   }
23 | };
24 | 
25 | /*
26 |  * Utility functions
27 |  */
28 | template<typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 |                                            unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 |   return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 |   return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 | 
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 | 
40 | static int getNumThreads(int nElem) {
41 |   int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 |   for (int i = 0; i != 5; ++i) {
43 |     if (nElem <= threadSizes[i]) {
44 |       return threadSizes[i];
45 |     }
46 |   }
47 |   return MAX_BLOCK_SIZE;
48 | }
49 | 
50 | template<typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 |   for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 |     val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 |   }
56 | #else
57 |   __shared__ T values[MAX_BLOCK_SIZE];
58 |   values[threadIdx.x] = val;
59 |   __threadfence_block();
60 |   const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 |   for (int i = 1; i < WARP_SIZE; i++) {
62 |     val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 |   }
64 | #endif
65 |   return val;
66 | }
67 | 
68 | template<typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 |   value.v1 = warpSum(value.v1);
71 |   value.v2 = warpSum(value.v2);
72 |   return value;
73 | }
74 | 
75 | template<typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 |   T sum = (T)0;
78 |   for (int batch = 0; batch < N; ++batch) {
79 |     for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 |       sum += op(batch, plane, x);
81 |     }
82 |   }
83 | 
84 |   // sum over NumThreads within a warp
85 |   sum = warpSum(sum);
86 | 
87 |   // 'transpose', and reduce within warp again
88 |   __shared__ T shared[32];
89 |   __syncthreads();
90 |   if (threadIdx.x % WARP_SIZE == 0) {
91 |     shared[threadIdx.x / WARP_SIZE] = sum;
92 |   }
93 |   if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 |     // zero out the other entries in shared
95 |     shared[threadIdx.x] = (T)0;
96 |   }
97 |   __syncthreads();
98 |   if (threadIdx.x / WARP_SIZE == 0) {
99 |     sum = warpSum(shared[threadIdx.x]);
100 |     if (threadIdx.x == 0) {
101 |       shared[0] = sum;
102 |     }
103 |   }
104 |   __syncthreads();
105 | 
106 |   // Everyone picks it up, should be broadcast into the whole gradInput
107 |   return shared[0];
108 | }
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <vector>
4 | 
5 | #include "inplace_abn.h"
6 | 
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 |   if (x.is_cuda()) {
9 |     return mean_var_cuda(x);
10 |   } else {
11 |     return mean_var_cpu(x);
12 |   }
13 | }
14 | 
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 |                    bool affine, float eps) {
17 |   if (x.is_cuda()) {
18 |     return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 |   } else {
20 |     return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 |   }
22 | }
23 | 
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 |                                  bool affine, float eps) {
26 |   if (z.is_cuda()) {
27 |     return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 |   } else {
29 |     return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 |   }
31 | }
32 | 
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 |                                  at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 |   if (z.is_cuda()) {
36 |     return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 |   } else {
38 |     return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 |   }
40 | }
41 | 
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 |   at::leaky_relu_(z, slope);
44 | }
45 | 
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 |   if (z.is_cuda()) {
48 |     return leaky_relu_backward_cuda(z, dz, slope);
49 |   } else {
50 |     return leaky_relu_backward_cpu(z, dz, slope);
51 |   }
52 | }
53 | 
54 | void elu_forward(at::Tensor z) {
55 |   at::elu_(z);
56 | }
57 | 
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 |   if (z.is_cuda()) {
60 |     return elu_backward_cuda(z, dz);
61 |   } else {
62 |     return elu_backward_cpu(z, dz);
63 |   }
64 | }
65 | 
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 |   m.def("mean_var", &mean_var, "Mean and variance computation");
68 |   m.def("forward", &forward, "In-place forward computation");
69 |   m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 |   m.def("backward", &backward, "Second part of backward computation");
71 |   m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 |   m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 |   m.def("elu_forward", &elu_forward, "Elu forward computation");
74 |   m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 | 
3 | #include <ATen/ATen.h>
4 | 
5 | #include <vector>
6 | 
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 | 
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 |                        bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 |                         bool affine, float eps);
14 | 
15 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 |                                      bool affine, float eps);
17 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 |                                       bool affine, float eps);
19 | 
20 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 |                                      at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 |                                       at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 | 
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 | 
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | #include <vector>
4 | 
5 | #include "inplace_abn.h"
6 | 
7 | at::Tensor reduce_sum(at::Tensor x) {
8 |   if (x.ndimension() == 2) {
9 |     return x.sum(0);
10 |   } else {
11 |     auto x_view = x.view({x.size(0), x.size(1), -1});
12 |     return x_view.sum(-1).sum(0);
13 |   }
14 | }
15 | 
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 |   if (x.ndimension() == 2) {
18 |     return v;
19 |   } else {
20 |     std::vector<int64_t> broadcast_size = {1, -1};
21 |     for (int64_t i = 2; i < x.ndimension(); ++i)
22 |       broadcast_size.push_back(1);
23 | 
24 |     return v.view(broadcast_size);
25 |   }
26 | }
27 | 
28 | int64_t count(at::Tensor x) {
29 |   int64_t count = x.size(0);
30 |   for (int64_t i = 2; i < x.ndimension(); ++i)
31 |     count *= x.size(i);
32 | 
33 |   return count;
34 | }
35 | 
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 |   if (affine) {
38 |     return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 |   } else {
40 |     return z;
41 |   }
42 | }
43 | 
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 |   auto num = count(x);
46 |   auto mean = reduce_sum(x) / num;
47 |   auto diff = x - broadcast_to(mean, x);
48 |   auto var = reduce_sum(diff.pow(2)) / num;
49 | 
50 |   return {mean, var};
51 | }
52 | 
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 |                        bool affine, float eps) {
55 |   auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 |   auto mul = at::rsqrt(var + eps) * gamma;
57 | 
58 |   x.sub_(broadcast_to(mean, x));
59 |   x.mul_(broadcast_to(mul, x));
60 |   if (affine) x.add_(broadcast_to(bias, x));
61 | 
62 |   return x;
63 | }
64 | 
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 |                                      bool affine, float eps) {
67 |   auto edz = reduce_sum(dz);
68 |   auto y = invert_affine(z, weight, bias, affine, eps);
69 |   auto eydz = reduce_sum(y * dz);
70 | 
71 |   return {edz, eydz};
72 | }
73 | 
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 |                                      at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 |   auto y = invert_affine(z, weight, bias, affine, eps);
77 |   auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 | 
79 |   auto num = count(z);
80 |   auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 | 
82 |   auto dweight = at::empty(z.type(), {0});
83 |   auto dbias = at::empty(z.type(), {0});
84 |   if (affine) {
85 |     dweight = eydz * at::sign(weight);
86 |     dbias = edz;
87 |   }
88 | 
89 |   return {dx, dweight, dbias};
90 | }
91 | 
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 |     int64_t count = z.numel();
95 |     auto *_z = z.data<scalar_t>();
96 |     auto *_dz = dz.data<scalar_t>();
97 | 
98 |     for (int64_t i = 0; i < count; ++i) {
99 |       if (_z[i] < 0) {
100 |         _z[i] *= 1 / slope;
101 |         _dz[i] *= slope;
102 |       }
103 |     }
104 |   }));
105 | }
106 | 
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 |     int64_t count = z.numel();
110 |     auto *_z = z.data<scalar_t>();
111 |     auto *_dz = dz.data<scalar_t>();
112 | 
113 |     for (int64_t i = 0; i < count; ++i) {
114 |       if (_z[i] < 0) {
115 |         _z[i] = log1p(_z[i]);
116 |         _dz[i] *= (_z[i] + 1.f);
117 |       }
118 |     }
119 |   }));
120 | }
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | 
3 | #include <thrust/device_ptr.h>
4 | #include <thrust/transform.h>
5 | 
6 | #include <vector>
7 | 
8 | #include "common.h"
9 | #include "inplace_abn.h"
10 | 
11 | // Checks
12 | #ifndef AT_CHECK
13 | #define AT_CHECK AT_ASSERT
14 | #endif
15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18 | 
19 | // Utilities
20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
21 |   num = x.size(0);
22 |   chn = x.size(1);
23 |   sp = 1;
24 |   for (int64_t i = 2; i < x.ndimension(); ++i)
25 |     sp *= x.size(i);
26 | }
27 | 
28 | // Operations for reduce
29 | template<typename T>
30 | struct SumOp {
31 |   __device__ SumOp(const T *t, int c, int s)
32 |       : tensor(t), chn(c), sp(s) {}
33 |   __device__ __forceinline__ T operator()(int batch, int plane, int n) {
34 |     return tensor[(batch * chn + plane) * sp + n];
35 |   }
36 |   const T *tensor;
37 |   const int chn;
38 |   const int sp;
39 | };
40 | 
41 | template<typename T>
42 | struct VarOp {
43 |   __device__ VarOp(T m, const T *t, int c, int s)
44 |       : mean(m), tensor(t), chn(c), sp(s) {}
45 |   __device__ __forceinline__ T operator()(int batch, int plane, int n) {
46 |     T val = tensor[(batch * chn + plane) * sp + n];
47 |     return (val - mean) * (val - mean);
48 |   }
49 |   const T mean;
50 |   const T *tensor;
51 |   const int chn;
52 |   const int sp;
53 | };
54 | 
55 | template<typename T>
56 | struct GradOp {
57 |   __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
58 |       : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
59 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
60 |     T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
61 |     T _dz = dz[(batch * chn + plane) * sp + n];
62 |     return Pair<T>(_dz, _y * _dz);
63 |   }
64 |   const T weight;
65 |   const T bias;
66 |   const T *z;
67 |   const T *dz;
68 |   const int chn;
69 |   const int sp;
70 | };
71 | 
72 | /***********
73 |  * mean_var
74 |  ***********/
75 | 
76 | template<typename T>
77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
78 |   int plane = blockIdx.x;
79 |   T norm = T(1) / T(num * sp);
80 | 
81 |   T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, chn, sp) * norm;
82 |   __syncthreads();
83 |   T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, chn, sp) * norm;
84 | 
85 |   if (threadIdx.x == 0) {
86 |     mean[plane] = _mean;
87 |     var[plane] = _var;
88 |   }
89 | }
90 | 
91 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
92 |   CHECK_INPUT(x);
93 | 
94 |   // Extract dimensions
95 |   int64_t num, chn, sp;
96 |   get_dims(x, num, chn, sp);
97 | 
98 |   // Prepare output tensors
99 |   auto mean = at::empty(x.type(), {chn});
100 |   auto var = at::empty(x.type(), {chn});
101 | 
102 |   // Run kernel
103 |   dim3 blocks(chn);
104 |   dim3 threads(getNumThreads(sp));
105 |   AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
106 |     mean_var_kernel<scalar_t><<<blocks, threads>>>(
107 |         x.data<scalar_t>(),
108 |         mean.data<scalar_t>(),
109 |         var.data<scalar_t>(),
110 |         num, chn, sp);
111 |   }));
112 | 
113 |   return {mean, var};
114 | }
115 | 
116 | /**********
117 |  * forward
118 |  **********/
119 | 
120 | template<typename T>
121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
122 |                                bool affine, float eps, int num, int chn, int sp) {
123 |   int plane = blockIdx.x;
124 | 
125 |   T _mean = mean[plane];
126 |   T _var = var[plane];
127 |   T _weight = affine ? abs(weight[plane]) + eps : T(1);
128 |   T _bias = affine ? bias[plane] : T(0);
129 | 
130 |   T mul = rsqrt(_var + eps) * _weight;
131 | 
132 |   for (int batch = 0; batch < num; ++batch) {
133 |     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
134 |       T _x = x[(batch * chn + plane) * sp + n];
135 |       T _y = (_x - _mean) * mul + _bias;
136 | 
137 |       x[(batch * chn + plane) * sp + n] = _y;
138 |     }
139 |   }
140 | }
141 | 
142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
143 |                         bool affine, float eps) {
144 |   CHECK_INPUT(x);
145 |   CHECK_INPUT(mean);
146 |   CHECK_INPUT(var);
147 |   CHECK_INPUT(weight);
148 |   CHECK_INPUT(bias);
149 | 
150 |   // Extract dimensions
151 |   int64_t num, chn, sp;
152 |   get_dims(x, num, chn, sp);
153 | 
154 |   // Run kernel
155 |   dim3 blocks(chn);
156 |   dim3 threads(getNumThreads(sp));
157 |   AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
158 |     forward_kernel<scalar_t><<<blocks, threads>>>(
159 |         x.data<scalar_t>(),
160 |         mean.data<scalar_t>(),
161 |         var.data<scalar_t>(),
162 |         weight.data<scalar_t>(),
163 |         bias.data<scalar_t>(),
164 |         affine, eps, num, chn, sp);
165 |   }));
166 | 
167 |   return x;
168 | }
169 | 
170 | /***********
171 |  * edz_eydz
172 |  ***********/
173 | 
174 | template<typename T>
175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
176 |                                 T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
177 |   int plane = blockIdx.x;
178 | 
179 |   T _weight = affine ? abs(weight[plane]) + eps : 1.f;
180 |   T _bias = affine ? bias[plane] : 0.f;
181 | 
182 |   Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp);
183 |   __syncthreads();
184 | 
185 |   if (threadIdx.x == 0) {
186 |     edz[plane] = res.v1;
187 |     eydz[plane] = res.v2;
188 |   }
189 | }
190 | 
191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
192 |                                       bool affine, float eps) {
193 |   CHECK_INPUT(z);
194 |   CHECK_INPUT(dz);
195 |   CHECK_INPUT(weight);
196 |   CHECK_INPUT(bias);
197 | 
198 |   // Extract dimensions
199 |   int64_t num, chn, sp;
200 |   get_dims(z, num, chn, sp);
201 | 
202 |   auto edz = at::empty(z.type(), {chn});
203 |   auto eydz = at::empty(z.type(), {chn});
204 | 
205 |   // Run kernel
206 |   dim3 blocks(chn);
207 |   dim3 threads(getNumThreads(sp));
208 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
209 |     edz_eydz_kernel<scalar_t><<<blocks, threads>>>(
210 |         z.data<scalar_t>(),
211 |         dz.data<scalar_t>(),
212 |         weight.data<scalar_t>(),
213 |         bias.data<scalar_t>(),
214 |         edz.data<scalar_t>(),
215 |         eydz.data<scalar_t>(),
216 |         affine, eps, num, chn, sp);
217 |   }));
218 | 
219 |   return {edz, eydz};
220 | }
221 | 
222 | /***********
223 |  * backward
224 |  ***********/
225 | 
226 | template<typename T>
227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
228 |                                 const T *eydz, T *dx, T *dweight, T *dbias,
229 |                                 bool affine, float eps, int num, int chn, int sp) {
230 |   int plane = blockIdx.x;
231 | 
232 |   T _weight = affine ? abs(weight[plane]) + eps : 1.f;
233 |   T _bias = affine ? bias[plane] : 0.f;
234 |   T _var = var[plane];
235 |   T _edz = edz[plane];
236 |   T _eydz = eydz[plane];
237 | 
238 |   T _mul = _weight * rsqrt(_var + eps);
239 |   T count = T(num * sp);
240 | 
241 |   for (int batch = 0; batch < num; ++batch) {
242 |     for (int n = threadIdx.x; n < sp; n += blockDim.x) {
243 |       T _dz = dz[(batch * chn + plane) * sp + n];
244 |       T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
245 | 
246 |       dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
247 |     }
248 |   }
249 | 
250 |   if (threadIdx.x == 0) {
251 |     if (affine) {
252 |       dweight[plane] = weight[plane] > 0 ? _eydz : -_eydz;
253 |       dbias[plane] = _edz;
254 |     }
255 |   }
256 | }
257 | 
258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
259 |                                       at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
260 |   CHECK_INPUT(z);
261 |   CHECK_INPUT(dz);
262 |   CHECK_INPUT(var);
263 |   CHECK_INPUT(weight);
264 |   CHECK_INPUT(bias);
265 |   CHECK_INPUT(edz);
266 |   CHECK_INPUT(eydz);
267 | 
268 |   // Extract dimensions
269 |   int64_t num, chn, sp;
270 |   get_dims(z, num, chn, sp);
271 | 
272 |   auto dx = at::zeros_like(z);
273 |   auto dweight = at::zeros_like(weight);
274 |   auto dbias = at::zeros_like(bias);
275 | 
276 |   // Run kernel
277 |   dim3 blocks(chn);
278 |   dim3 threads(getNumThreads(sp));
279 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
280 |     backward_kernel<scalar_t><<<blocks, threads>>>(
281 |         z.data<scalar_t>(),
282 |         dz.data<scalar_t>(),
283 |         var.data<scalar_t>(),
284 |         weight.data<scalar_t>(),
285 |         bias.data<scalar_t>(),
286 |         edz.data<scalar_t>(),
287 |         eydz.data<scalar_t>(),
288 |         dx.data<scalar_t>(),
289 |         dweight.data<scalar_t>(),
290 |         dbias.data<scalar_t>(),
291 |         affine, eps, num, chn, sp);
292 |   }));
293 | 
294 |   return {dx, dweight, dbias};
295 | }
296 | 
297 | /**************
298 |  * activations
299 |  **************/
300 | 
301 | template<typename T>
302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
303 |   // Create thrust pointers
304 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
305 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
306 | 
307 |   thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
308 |                        [slope] __device__ (const T& dz) { return dz * slope; },
309 |                        [] __device__ (const T& z) { return z < 0; });
310 |   thrust::transform_if(th_z, th_z + count, th_z,
311 |                        [slope] __device__ (const T& z) { return z / slope; },
312 |                        [] __device__ (const T& z) { return z < 0; });
313 | }
314 | 
315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
316 |   CHECK_INPUT(z);
317 |   CHECK_INPUT(dz);
318 | 
319 |   int64_t count = z.numel();
320 | 
321 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
322 |     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
323 |   }));
324 | }
325 | 
326 | template<typename T>
327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
328 |   // Create thrust pointers
329 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
330 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
331 | 
332 |   thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz,
333 |                        [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
334 |                        [] __device__ (const T& z) { return z < 0; });
335 |   thrust::transform_if(th_z, th_z + count, th_z,
336 |                        [] __device__ (const T& z) { return log1p(z); },
337 |                        [] __device__ (const T& z) { return z < 0; });
338 | }
339 | 
340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
341 |   CHECK_INPUT(z);
342 |   CHECK_INPUT(dz);
343 | 
344 |   int64_t count = z.numel();
345 | 
346 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
347 |     elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
348 |   }));
349 | }
350 | 
--------------------------------------------------------------------------------
/hrnet_code/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/hrnet_code/lib/utils/__init__.py
--------------------------------------------------------------------------------
/hrnet_code/lib/utils/distributed.py: 
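The InPlaceABNSync implementation above reduces per-GPU batch statistics through a master/worker queue handshake. A minimal single-process sketch of that protocol, using plain Python queues and scalar statistics (all names here are illustrative, not part of the repo):

from queue import Queue

def master_step(local_mean, local_var, master_queue, worker_queues):
    # Collect one (mean, var) message from every worker.
    stats = [(local_mean, local_var)]
    for _ in range(master_queue.maxsize):
        stats.append(master_queue.get())
        master_queue.task_done()
    # Combine: the global mean is the mean of per-GPU means; the global variance
    # adds the spread of the per-GPU means (law of total variance), mirroring
    # forward() in functions.py above.
    mean = sum(m for m, _ in stats) / len(stats)
    var = sum(v + (m - mean) ** 2 for m, v in stats) / len(stats)
    for q in worker_queues:                    # broadcast the reduced stats back
        q.put((mean, var))
    return mean, var

def worker_step(local_mean, local_var, master_queue, worker_queue):
    master_queue.put((local_mean, local_var))  # hand local stats to the master
    return worker_queue.get()                  # block until reduced stats arrive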
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Jingyi Xie (hsfzxjy@gmail.com) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import torch 8 | import torch.distributed as torch_dist 9 | 10 | def is_distributed(): 11 | return torch_dist.is_initialized() 12 | 13 | def get_world_size(): 14 | if not torch_dist.is_initialized(): 15 | return 1 16 | return torch_dist.get_world_size() 17 | 18 | def get_rank(): 19 | if not torch_dist.is_initialized(): 20 | return 0 21 | return torch_dist.get_rank() -------------------------------------------------------------------------------- /hrnet_code/lib/utils/modelsummary.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com) 5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn) 6 | # ------------------------------------------------------------------------------ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os 13 | import logging 14 | from collections import namedtuple 15 | 16 | import torch 17 | import torch.nn as nn 18 | 19 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False): 20 | """ 21 | :param model: 22 | :param input_tensors: 23 | :param item_length: 24 | :return: 25 | """ 26 | 27 | summary = [] 28 | 29 | ModuleDetails = namedtuple( 30 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"]) 31 | hooks = [] 32 | layer_instances = {} 33 | 34 | def add_hooks(module): 35 | 36 | def hook(module, input, output): 37 | class_name = str(module.__class__.__name__) 38 | 39 | instance_index = 1 40 | if class_name not in layer_instances: 41 | layer_instances[class_name] = instance_index 42 | else: 43 | instance_index = layer_instances[class_name] + 1 44 | layer_instances[class_name] = instance_index 45 | 46 | layer_name = class_name + "_" + str(instance_index) 47 | 48 | params = 0 49 | 50 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \ 51 | class_name.find("Linear") != -1: 52 | for param_ in module.parameters(): 53 | params += param_.view(-1).size(0) 54 | 55 | flops = "Not Available" 56 | if class_name.find("Conv") != -1 and hasattr(module, "weight"): 57 | flops = ( 58 | torch.prod( 59 | torch.LongTensor(list(module.weight.data.size()))) * 60 | torch.prod( 61 | torch.LongTensor(list(output.size())[2:]))).item() 62 | elif isinstance(module, nn.Linear): 63 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \ 64 | * input[0].size(1)).item() 65 | 66 | if isinstance(input[0], list): 67 | input = input[0] 68 | if isinstance(output, list): 69 | output = output[0] 70 | 71 | summary.append( 72 | ModuleDetails( 73 | name=layer_name, 74 | input_size=list(input[0].size()), 75 | output_size=list(output.size()), 76 | num_parameters=params, 77 | multiply_adds=flops) 78 | ) 79 | 80 | if not isinstance(module, nn.ModuleList) \ 81 | and not isinstance(module, nn.Sequential) \ 82 | and module != model: 83 | hooks.append(module.register_forward_hook(hook)) 84 | 85 | model.eval() 86 | model.apply(add_hooks) 87 | 88 | space_len = 
item_length 89 | 90 | model(*input_tensors) 91 | for hook in hooks: 92 | hook.remove() 93 | 94 | details = '' 95 | if verbose: 96 | details = "Model Summary" + \ 97 | os.linesep + \ 98 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format( 99 | ' ' * (space_len - len("Name")), 100 | ' ' * (space_len - len("Input Size")), 101 | ' ' * (space_len - len("Output Size")), 102 | ' ' * (space_len - len("Parameters")), 103 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \ 104 | + os.linesep + '-' * space_len * 5 + os.linesep 105 | 106 | params_sum = 0 107 | flops_sum = 0 108 | for layer in summary: 109 | params_sum += layer.num_parameters 110 | if layer.multiply_adds != "Not Available": 111 | flops_sum += layer.multiply_adds 112 | if verbose: 113 | details += "{}{}{}{}{}{}{}{}{}{}".format( 114 | layer.name, 115 | ' ' * (space_len - len(layer.name)), 116 | layer.input_size, 117 | ' ' * (space_len - len(str(layer.input_size))), 118 | layer.output_size, 119 | ' ' * (space_len - len(str(layer.output_size))), 120 | layer.num_parameters, 121 | ' ' * (space_len - len(str(layer.num_parameters))), 122 | layer.multiply_adds, 123 | ' ' * (space_len - len(str(layer.multiply_adds)))) \ 124 | + os.linesep + '-' * space_len * 5 + os.linesep 125 | 126 | details += os.linesep \ 127 | + "Total Parameters: {:,}".format(params_sum) \ 128 | + os.linesep + '-' * space_len * 5 + os.linesep 129 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \ 130 | + os.linesep + '-' * space_len * 5 + os.linesep 131 | details += "Number of Layers" + os.linesep 132 | for layer in layer_instances: 133 | details += "{} : {} layers ".format(layer, layer_instances[layer]) 134 | 135 | return details -------------------------------------------------------------------------------- /hrnet_code/lib/utils/utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os 12 | import logging 13 | import time 14 | from pathlib import Path 15 | 16 | import numpy as np 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | class FullModel(nn.Module): 22 | """ 23 | Distribute the loss on multi-gpu to reduce 24 | the memory cost in the main gpu. 25 | You can check the following discussion. 
26 | https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21 27 | """ 28 | def __init__(self, model, loss): 29 | super(FullModel, self).__init__() 30 | self.model = model 31 | self.loss = loss 32 | 33 | def forward(self, inputs, labels, *args, **kwargs): 34 | outputs = self.model(inputs, *args, **kwargs) 35 | loss = self.loss(outputs, labels) 36 | return torch.unsqueeze(loss,0), outputs 37 | 38 | class AverageMeter(object): 39 | """Computes and stores the average and current value""" 40 | 41 | def __init__(self): 42 | self.initialized = False 43 | self.val = None 44 | self.avg = None 45 | self.sum = None 46 | self.count = None 47 | 48 | def initialize(self, val, weight): 49 | self.val = val 50 | self.avg = val 51 | self.sum = val * weight 52 | self.count = weight 53 | self.initialized = True 54 | 55 | def update(self, val, weight=1): 56 | if not self.initialized: 57 | self.initialize(val, weight) 58 | else: 59 | self.add(val, weight) 60 | 61 | def add(self, val, weight): 62 | self.val = val 63 | self.sum += val * weight 64 | self.count += weight 65 | self.avg = self.sum / self.count 66 | 67 | def value(self): 68 | return self.val 69 | 70 | def average(self): 71 | return self.avg 72 | 73 | def create_logger(cfg, cfg_name, phase='train'): 74 | root_output_dir = Path(cfg.OUTPUT_DIR) 75 | # set up logger 76 | if not root_output_dir.exists(): 77 | print('=> creating {}'.format(root_output_dir)) 78 | root_output_dir.mkdir() 79 | 80 | dataset = cfg.DATASET.DATASET 81 | model = cfg.MODEL.NAME 82 | cfg_name = os.path.basename(cfg_name).split('.')[0] 83 | 84 | final_output_dir = root_output_dir / dataset / cfg_name 85 | 86 | print('=> creating {}'.format(final_output_dir)) 87 | final_output_dir.mkdir(parents=True, exist_ok=True) 88 | 89 | time_str = time.strftime('%Y-%m-%d-%H-%M') 90 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase) 91 | final_log_file = final_output_dir / log_file 92 | head = '%(asctime)-15s %(message)s' 93 | logging.basicConfig(filename=str(final_log_file), 94 | format=head) 95 | logger = logging.getLogger() 96 | logger.setLevel(logging.INFO) 97 | console = logging.StreamHandler() 98 | logging.getLogger('').addHandler(console) 99 | 100 | tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \ 101 | (cfg_name + '_' + time_str) 102 | print('=> creating {}'.format(tensorboard_log_dir)) 103 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True) 104 | 105 | return logger, str(final_output_dir), str(tensorboard_log_dir) 106 | 107 | def get_confusion_matrix(label, pred, size, num_class, ignore=-1): 108 | """ 109 | Calcute the confusion matrix by given label and pred 110 | """ 111 | output = pred.cpu().numpy().transpose(0, 2, 3, 1) 112 | seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8) 113 | seg_gt = np.asarray( 114 | label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=np.int) 115 | 116 | ignore_index = seg_gt != ignore 117 | seg_gt = seg_gt[ignore_index] 118 | seg_pred = seg_pred[ignore_index] 119 | 120 | index = (seg_gt * num_class + seg_pred).astype('int32') 121 | label_count = np.bincount(index) 122 | confusion_matrix = np.zeros((num_class, num_class)) 123 | 124 | for i_label in range(num_class): 125 | for i_pred in range(num_class): 126 | cur_index = i_label * num_class + i_pred 127 | if cur_index < len(label_count): 128 | confusion_matrix[i_label, 129 | i_pred] = label_count[cur_index] 130 | return confusion_matrix 131 | 132 | def adjust_learning_rate(optimizer, base_lr, max_iters, 133 | cur_iters, power=0.9, nbb_mult=10): 134 
| lr = base_lr*((1-float(cur_iters)/max_iters)**(power)) 135 | optimizer.param_groups[0]['lr'] = lr 136 | if len(optimizer.param_groups) == 2: 137 | optimizer.param_groups[1]['lr'] = lr * nbb_mult 138 | return lr -------------------------------------------------------------------------------- /hrnet_code/requirements.txt: -------------------------------------------------------------------------------- 1 | EasyDict==1.7 2 | opencv-python==3.4.2.17 3 | shapely==1.6.4 4 | Cython 5 | scipy 6 | pandas 7 | pyyaml 8 | json_tricks 9 | scikit-image 10 | yacs>=0.1.5 11 | tensorboardX>=1.6 12 | tqdm 13 | ninja 14 | -------------------------------------------------------------------------------- /hrnet_code/tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | from __future__ import absolute_import 8 | from __future__ import division 9 | from __future__ import print_function 10 | 11 | import os.path as osp 12 | import sys 13 | 14 | 15 | def add_path(path): 16 | if path not in sys.path: 17 | sys.path.insert(0, path) 18 | 19 | this_dir = osp.dirname(__file__) 20 | 21 | lib_path = osp.join(this_dir, '..', 'lib') 22 | add_path(lib_path) 23 | -------------------------------------------------------------------------------- /hrnet_code/tools/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------ 2 | # Copyright (c) Microsoft 3 | # Licensed under the MIT License. 
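# A quick numeric check (not part of the original file) of the poly learning-rate
# schedule implemented by utils.adjust_learning_rate above:
#     lr = base_lr * (1 - cur_iters / max_iters) ** power
# e.g. with base_lr=0.01, power=0.9 and max_iters=100, at cur_iters=50 this gives
# 0.01 * 0.5 ** 0.9 ≈ 0.00536, and the rate decays smoothly to 0 at the last iteration.
def poly_lr(base_lr, cur_iters, max_iters, power=0.9):
    return base_lr * ((1 - float(cur_iters) / max_iters) ** power)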
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn) 5 | # ------------------------------------------------------------------------------ 6 | 7 | import argparse 8 | import os 9 | import pprint 10 | import shutil 11 | import sys 12 | 13 | import logging 14 | import time 15 | import timeit 16 | from pathlib import Path 17 | 18 | import numpy as np 19 | 20 | import torch 21 | import torch.nn as nn 22 | import torch.backends.cudnn as cudnn 23 | 24 | import _init_paths 25 | import models 26 | import datasets 27 | from config import config 28 | from config import update_config 29 | from core.function import testval, test 30 | from utils.modelsummary import get_model_summary 31 | from utils.utils import create_logger, FullModel 32 | 33 | def parse_args(): 34 | parser = argparse.ArgumentParser(description='Train segmentation network') 35 | 36 | parser.add_argument('--cfg', 37 | help='experiment configure file name', 38 | required=True, 39 | type=str) 40 | parser.add_argument('opts', 41 | help="Modify config options using the command-line", 42 | default=None, 43 | nargs=argparse.REMAINDER) 44 | 45 | args = parser.parse_args() 46 | update_config(config, args) 47 | 48 | return args 49 | 50 | def main(): 51 | args = parse_args() 52 | 53 | logger, final_output_dir, _ = create_logger( 54 | config, args.cfg, 'test') 55 | 56 | logger.info(pprint.pformat(args)) 57 | logger.info(pprint.pformat(config)) 58 | 59 | # cudnn related setting 60 | cudnn.benchmark = config.CUDNN.BENCHMARK 61 | cudnn.deterministic = config.CUDNN.DETERMINISTIC 62 | cudnn.enabled = config.CUDNN.ENABLED 63 | 64 | # build model 65 | if torch.__version__.startswith('1'): 66 | module = eval('models.'+config.MODEL.NAME) 67 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d 68 | model = eval('models.'+config.MODEL.NAME + 69 | '.get_seg_model')(config) 70 | 71 | dump_input = torch.rand( 72 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0]) 73 | ) 74 | logger.info(get_model_summary(model.cuda(), dump_input.cuda())) 75 | 76 | if config.TEST.MODEL_FILE: 77 | model_state_file = config.TEST.MODEL_FILE 78 | else: 79 | model_state_file = os.path.join(final_output_dir, 'final_state.pth') 80 | logger.info('=> loading model from {}'.format(model_state_file)) 81 | 82 | pretrained_dict = torch.load(model_state_file) 83 | if 'state_dict' in pretrained_dict: 84 | pretrained_dict = pretrained_dict['state_dict'] 85 | model_dict = model.state_dict() 86 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items() 87 | if k[6:] in model_dict.keys()} 88 | for k, _ in pretrained_dict.items(): 89 | logger.info( 90 | '=> loading {} from pretrained model'.format(k)) 91 | model_dict.update(pretrained_dict) 92 | model.load_state_dict(model_dict) 93 | 94 | gpus = list(config.GPUS) 95 | model = nn.DataParallel(model, device_ids=gpus).cuda() 96 | 97 | # prepare data 98 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0]) 99 | test_dataset = eval('datasets.'+config.DATASET.DATASET)( 100 | root=config.DATASET.ROOT, 101 | list_path=config.DATASET.TEST_SET, 102 | num_samples=None, 103 | num_classes=config.DATASET.NUM_CLASSES, 104 | multi_scale=False, 105 | flip=False, 106 | ignore_label=config.TRAIN.IGNORE_LABEL, 107 | base_size=config.TEST.BASE_SIZE, 108 | crop_size=test_size, 109 | downsample_rate=1) 110 | 111 | testloader = torch.utils.data.DataLoader( 112 | test_dataset, 113 | batch_size=1, 114 | shuffle=False, 115 | num_workers=config.WORKERS, 116 | pin_memory=True) 117 | 118 | start = timeit.default_timer() 119 | if 
'val' in config.DATASET.TEST_SET: 120 | mean_IoU, IoU_array, pixel_acc, mean_acc = testval(config, 121 | test_dataset, 122 | testloader, 123 | model) 124 | 125 | msg = 'MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \ 126 | Mean_Acc: {: 4.4f}, Class IoU: '.format(mean_IoU, 127 | pixel_acc, mean_acc) 128 | logging.info(msg) 129 | logging.info(IoU_array) 130 | elif 'test' in config.DATASET.TEST_SET: 131 | test(config, 132 | test_dataset, 133 | testloader, 134 | model, 135 | sv_dir=final_output_dir) 136 | 137 | end = timeit.default_timer() 138 | logger.info('Mins: %d' % np.int((end-start)/60)) 139 | logger.info('Done') 140 | 141 | 142 | if __name__ == '__main__': 143 | main() 144 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | from PIL import Image 7 | 8 | import argparse 9 | 10 | from lib.model_zoo.texrnet import TexRNet 11 | from lib.model_zoo.hrnet import HRNet_Base 12 | from lib.model_zoo.deeplab import DeepLabv3p_Base 13 | from lib.model_zoo.resnet import ResNet_Dilated_Base 14 | from lib import torchutils 15 | 16 | import tqdm 17 | 18 | class TextRNet_HRNet_Wrapper(object): 19 | """ 20 | This is the UltraSRWrapper with render-level batchification. 21 | """ 22 | def __init__(self, 23 | device, 24 | pth=None,): 25 | """ 26 | Create uspcale instance 27 | :param device: device on run the upscale pipeline (if GPU is accessible should be 'cuda') 28 | :param pth: path to model 29 | """ 30 | self.model = self.make_model(pth) 31 | self.model.eval() 32 | self.model = self.model.to(device) 33 | self.device = device 34 | 35 | @staticmethod 36 | def make_model(pth=None): 37 | backbone = HRNet_Base( 38 | oc_n=720, 39 | align_corners=True, 40 | ignore_label=999, 41 | stage1_para={ 42 | 'BLOCK' : 'BOTTLENECK', 43 | 'FUSE_METHOD' : 'SUM', 44 | 'NUM_BLOCKS' : [4], 45 | 'NUM_BRANCHES': 1, 46 | 'NUM_CHANNELS': [64], 47 | 'NUM_MODULES' : 1 }, 48 | stage2_para={ 49 | 'BLOCK' : 'BASIC', 50 | 'FUSE_METHOD' : 'SUM', 51 | 'NUM_BLOCKS' : [4, 4], 52 | 'NUM_BRANCHES': 2, 53 | 'NUM_CHANNELS': [48, 96], 54 | 'NUM_MODULES' : 1 }, 55 | stage3_para={ 56 | 'BLOCK' : 'BASIC', 57 | 'FUSE_METHOD' : 'SUM', 58 | 'NUM_BLOCKS' : [4, 4, 4], 59 | 'NUM_BRANCHES': 3, 60 | 'NUM_CHANNELS': [48, 96, 192], 61 | 'NUM_MODULES' : 4 }, 62 | stage4_para={ 63 | 'BLOCK' : 'BASIC', 64 | 'FUSE_METHOD' : 'SUM', 65 | 'NUM_BLOCKS' : [4, 4, 4, 4], 66 | 'NUM_BRANCHES': 4, 67 | 'NUM_CHANNELS': [48, 96, 192, 384], 68 | 'NUM_MODULES' : 3 }, 69 | final_conv_kernel = 1, 70 | ) 71 | 72 | model = TexRNet( 73 | bbn_name='hrnet', 74 | bbn=backbone, 75 | ic_n=720, 76 | rfn_c_n=[725, 64, 64], 77 | sem_n=2, 78 | conv_type='conv', 79 | bn_type='bn', 80 | relu_type='relu', 81 | align_corners=True, 82 | ignore_label=None, 83 | bias_att_type='cossim', 84 | ineval_output_argmax=False, 85 | ) 86 | if pth is not None: 87 | paras = torch.load(pth, map_location=torch.device('cpu')) 88 | new_paras = model.state_dict() 89 | new_paras.update(paras) 90 | model.load_state_dict(new_paras) 91 | return model 92 | 93 | def process(self, pil_image): 94 | im = np.array(pil_image.convert("RGB")) 95 | im = im/255 96 | im = im - np.array([0.485, 0.456, 0.406]) 97 | im = im / np.array([0.229, 0.224, 0.225]) 98 | im = np.transpose(im, (2, 0, 1))[None] 99 | im = torch.FloatTensor(im).to(self.device) 100 | 101 | # This step will auto-adjust model if it is torch-DDP 102 | netm = 
getattr(self.model, 'module', self.model)
103 |         _, _, oh, ow = im.shape
104 |         ac = True
105 | 
106 |         prfnc_ms, pcount_ms = {}, {}
107 | 
108 |         for mstag, mssize in [
109 |                 ['0.75x', 385],
110 |                 ['1.00x', 513],
111 |                 ['1.25x', 641],
112 |                 ['1.50x', 769],
113 |                 ['1.75x', 897],
114 |                 ['2.00x', 1025],
115 |                 ['2.25x', 1153],
116 |                 ['2.50x', 1281], ]:
117 |             # pick the target size so the rescaled image area matches mssize**2
118 |             ratio = np.sqrt(mssize**2 / (oh*ow))
119 |             th, tw = int(oh*ratio), int(ow*ratio)
120 |             tw = tw//32*32+1
121 |             th = th//32*32+1
122 | 
123 |             imi = {
124 |                 'nofp' : torchutils.interpolate_2d(
125 |                     size=(th, tw), mode='bilinear',
126 |                     align_corners=ac)(im)}
127 |             imi['flip'] = torch.flip(imi['nofp'], dims=[-1])
128 | 
129 |             for fliptag, imii in imi.items():
130 |                 with torch.no_grad():
131 |                     pred = netm(imii)
132 |                 psem = torchutils.interpolate_2d(
133 |                     size=(oh, ow),
134 |                     mode='bilinear', align_corners=ac)(pred['predsem'])
135 |                 prfn = torchutils.interpolate_2d(
136 |                     size=(oh, ow),
137 |                     mode='bilinear', align_corners=ac)(pred['predrfn'])
138 | 
139 |                 if fliptag == 'flip':
140 |                     psem = torch.flip(psem, dims=[-1])
141 |                     prfn = torch.flip(prfn, dims=[-1])
142 |                 elif fliptag == 'nofp':
143 |                     pass
144 |                 else:
145 |                     raise ValueError
146 | 
147 |                 try:
148 |                     prfnc_ms[mstag] += prfn
149 |                     pcount_ms[mstag] += 1
150 |                 except KeyError:
151 |                     prfnc_ms[mstag] = prfn
152 |                     pcount_ms[mstag] = 1
153 | 
154 |         pred = sum([pi for pi in prfnc_ms.values()])
155 |         pred /= sum([ni for ni in pcount_ms.values()])
156 |         pred = torch.argmax(pred, dim=1)  # argmax over the fused multi-scale prediction
157 |         pred = pred[0].cpu().detach().numpy()
158 |         pred = (pred * 255).astype(np.uint8)
159 |         return Image.fromarray(pred)
160 | 
161 | class TextRNet_Deeplab_Wrapper(TextRNet_HRNet_Wrapper):
162 |     @staticmethod
163 |     def make_model(pth=None):
164 |         raise NotImplementedError
165 |         # resnet = ResNet_Dilated_Base(
166 |         #     block =
167 |         #     layer_n =
168 |         # )
169 | 
170 |         # model = TexRNet(
171 |         #     bbn_name='hrnet',
172 |         #     bbn=backbone,
173 |         #     ic_n=720,
174 |         #     rfn_c_n=[725, 64, 64],
175 |         #     sem_n=2,
176 |         #     conv_type='conv',
177 |         #     bn_type='bn',
178 |         #     relu_type='relu',
179 |         #     align_corners=True,
180 |         #     ignore_label=None,
181 |         #     bias_att_type='cossim',
182 |         #     ineval_output_argmax=False,
183 |         # )
184 |         # if pth is not None:
185 |         #     paras = torch.load(pth, map_location=torch.device('cpu'))
186 |         #     new_paras = model.state_dict()
187 |         #     new_paras.update(paras)
188 |         #     model.load_state_dict(new_paras)
189 |         # return model
190 | 
191 | if __name__ == "__main__":
192 |     parser = argparse.ArgumentParser()
193 |     parser.add_argument("--input", type=str, required=True, help="input folder or a single input file")
194 |     parser.add_argument("--output", type=str, required=True, help="output folder or a single output file")
195 |     parser.add_argument("--method", '-m', type=str, default='textrnet_hrnet')
196 |     args = parser.parse_args()
197 | 
198 |     if osp.isdir(args.input):
199 |         if not osp.exists(args.output):
200 |             os.makedirs(args.output)
201 |         assert osp.isdir(args.output), \
202 |             "When --input is a directory, --output must be a directory!"
203 |     elif osp.isfile(args.input):
204 |         assert not osp.isdir(args.output), \
205 |             "When --input is a file, --output must be a file!"
206 |     else:
207 |         assert False, "No such input!"
208 | 
209 |     assert args.input != args.output, \
210 |         "--input and --output point to the same location; "\
211 |         "this is not allowed because it would overwrite the input files."
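    # A condensed, self-contained sketch of the multi-scale + horizontal-flip
    # fusion that TextRNet_HRNet_Wrapper.process implements above (it omits the
    # area-based resizing and the 32k+1 size snapping). `forward_fn` is a
    # hypothetical stand-in for the network call and is not part of this repo.
    def _tta_fuse_sketch(forward_fn, im, scales=(0.75, 1.0, 1.25)):
        import torch.nn.functional as F
        acc, n = None, 0
        for s in scales:
            for flip in (False, True):
                x = torch.flip(im, dims=[-1]) if flip else im
                x = F.interpolate(x, scale_factor=s, mode='bilinear', align_corners=True)
                # keep only the refined logits, resized back to the input size
                y = F.interpolate(forward_fn(x)['predrfn'], size=im.shape[-2:],
                                  mode='bilinear', align_corners=True)
                y = torch.flip(y, dims=[-1]) if flip else y
                acc = y if acc is None else acc + y
                n += 1
        return torch.argmax(acc / n, dim=1)  # fused prediction, as in process()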
212 | 
213 |     if args.method == 'textrnet_hrnet':
214 |         wrapper = TextRNet_HRNet_Wrapper
215 |         model_path = 'pretrained/texrnet_hrnet.pth'
216 |     elif args.method == 'textrnet_deeplab':
217 |         wrapper = TextRNet_Deeplab_Wrapper
218 |         model_path = 'pretrained/texrnet_deeplab.pth'
219 |     else:
220 |         assert False, 'No such model.'
221 | 
222 |     enl = wrapper(torch.device("cuda:0"), model_path)
223 | 
224 |     if osp.isfile(args.input):
225 |         imgs = [args.input]
226 |         outs = [args.output]
227 |     else:
228 |         imgs = sorted(os.listdir(args.input))
229 |         outs = [
230 |             osp.join(args.output, '{}.png'.format(osp.splitext(fi)[0]))
231 |             for fi in imgs
232 |         ]
233 |         imgs = [osp.join(args.input, fi) for fi in imgs]
234 | 
235 |     for fin, fout in tqdm.tqdm(zip(imgs, outs), total=len(imgs)):
236 |         x = Image.open(fin).convert('RGB')
237 |         y = enl.process(x)
238 |         y.save(fout)
239 | 
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/lib/__init__.py
--------------------------------------------------------------------------------
/lib/cfg_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import copy
5 | import time
6 | from easydict import EasyDict as edict
7 | import pprint
8 | import numpy as np
9 | import torch
10 | import matplotlib
11 | import argparse
12 | 
13 | from configs.cfg_base import cfg_train, cfg_test
14 | import json
15 | 
16 | def singleton(class_):
17 |     instances = {}
18 |     def getinstance(*args, **kwargs):
19 |         if class_ not in instances:
20 |             instances[class_] = class_(*args, **kwargs)
21 |         return instances[class_]
22 |     return getinstance
23 | 
24 | @singleton
25 | class cfg_unique_holder(object):
26 |     def __init__(self):
27 |         self.cfg = None
28 |         # this is used to track the main code files.
29 |         self.code = set()
30 |     def save_cfg(self, cfg):
31 |         self.cfg = copy.deepcopy(cfg)
32 |     def add_code(self, code):
33 |         """
34 |         A new main code file is reached and
35 |         its name is added.
36 |         """
37 |         self.code.add(code)
38 | 
39 | def get_experiment_id():
40 |     time.sleep(0.5)
41 |     return int(time.time()*100)
42 | 
43 | def set_debug_cfg(cfg, istrain=True):
44 |     if istrain:
45 |         cfg.EXPERIMENT_ID = 999999999999
46 |         cfg.TRAIN.SAVE_INIT_MODEL = False
47 |         cfg.TRAIN.COMMENT = "Debug"
48 |         cfg.LOG_DIR = osp.join(
49 |             cfg.MISC_DIR,
50 |             '{}_{}'.format(cfg.MODEL.MODEL_NAME, cfg.DATA.DATASET_NAME),
51 |             str(cfg.EXPERIMENT_ID))
52 |         cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'train.log')
53 |         cfg.TRAIN.BATCH_SIZE = None
54 |         cfg.TRAIN.BATCH_SIZE_PER_GPU = 2
55 |     else:
56 |         cfg.LOG_DIR = cfg.LOG_DIR.replace(cfg.TEST.SUB_DIR, 'debug')
57 |         cfg.TEST.SUB_DIR = 'debug'
58 |         cfg.LOG_FILE = osp.join(
59 |             cfg.LOG_DIR, 'eval.log')
60 |         cfg.TEST.BATCH_SIZE = None
61 |         cfg.TEST.BATCH_SIZE_PER_GPU = 1
62 | 
63 |     cfg.DATA.NUM_WORKERS = None
64 |     cfg.DATA.NUM_WORKERS_PER_GPU = 0
65 |     cfg.MATPLOTLIB_MODE = 'TKAgg'
66 |     return cfg
67 | 
68 | def experiment_folder(cfg,
69 |                       isnew=False,
70 |                       sig=['nosig'],
71 |                       mdds_override=None,
72 |                       **kwargs):
73 |     """
74 |     Args:
75 |         cfg: easydict,
76 |             the config easydict
77 |         isnew: bool,
78 |             whether this is a new folder or not.
79 |             True, create a path using exid and sig;
80 |             False, find a path based on exid and refpath.
81 |         sig: [sig1, ...]
array of str 82 | when isnew == True, these are the signature 83 | put as [exid]_[sig1]_.._[sign] 84 | signatures after (and include) 'hided' will be 85 | hided from the name. 86 | mdds_override: str or None, 87 | the override folder for [modelname]_[dataset], 88 | None, no override 89 | Returns: 90 | workdir: str, 91 | the absolute path to the folder. 92 | """ 93 | if mdds_override is None: 94 | refdir = osp.join( 95 | cfg.MISC_DIR, 96 | '{}_{}'.format( 97 | cfg.MODEL.MODEL_NAME, cfg.DATA.DATASET_NAME),) 98 | else: 99 | refdir = osp.abspath(osp.join( 100 | cfg.MISC_DIR, mdds_override)) 101 | 102 | if isnew: 103 | try: 104 | hided_after = sig.index('hided') 105 | except: 106 | hided_after = len(sig) 107 | sig = sig[0:hided_after] 108 | workdir = '_'.join([str(cfg.EXPERIMENT_ID)] + sig) 109 | return osp.join(refdir, workdir) 110 | else: 111 | for d in os.listdir(refdir): 112 | if not d.find(str(cfg.EXPERIMENT_ID))==0: 113 | continue 114 | if not osp.isdir(osp.join(refdir, d)): 115 | continue 116 | try: 117 | workdir = osp.join( 118 | refdir, d, cfg.TEST.SUB_DIR) 119 | except: 120 | workdir = osp.join( 121 | refdir, d) 122 | return workdir 123 | raise ValueError 124 | 125 | def get_experiment_folder(exid, 126 | path, 127 | full_path=False, 128 | **kwargs): 129 | """ 130 | Args: 131 | exid: int, 132 | experiment ID 133 | path: path, 134 | the base folder to search... 135 | folder should be like _.... 136 | full_path: bool, 137 | whether return the full path or not. 138 | """ 139 | for d in os.listdir(path): 140 | if d.find(str(exid))==0: 141 | if osp.isdir(osp.join(path, d)): 142 | if not full_path: 143 | return d 144 | else: 145 | return osp.abspath(osp.join(path, d)) 146 | raise ValueError 147 | 148 | def set_experiment_folder(exid, 149 | signature, 150 | **kwargs): 151 | """ 152 | Args: 153 | exid: experiment ID 154 | signature: string or array of strings tells the tags that append after exid 155 | as a experiment folder... 156 | """ 157 | if isinstance(signature, str): 158 | signature = [signature] 159 | return '_'.join([str(exid)] + signature) 160 | 161 | def hided_sig_to_str(sig): 162 | """ 163 | Args: 164 | sig: [] of str, 165 | Returns: 166 | out: str 167 | If sig is [..., 'hided', 'sig1', 'sig2'] 168 | out = hided: sig1_sig2_... 
169 | If sig do not have 'hided' 170 | out = None 171 | """ 172 | try: 173 | hided_after = sig.index('hided') 174 | except: 175 | return None 176 | 177 | return 'hided: '+'_'.join(sig[hided_after+1:]) 178 | 179 | def common_initiates(cfg): 180 | if cfg.GPU_DEVICE != 'all': 181 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join( 182 | [str(gid) for gid in cfg.GPU_DEVICE]) 183 | cfg.GPU_COUNT = len(cfg.GPU_DEVICE) 184 | else: 185 | cfg.GPU_COUNT = torch.cuda.device_count() 186 | 187 | if 'TRAIN' in cfg: 188 | if (cfg.TRAIN.BATCH_SIZE is None) and \ 189 | (cfg.TRAIN.BATCH_SIZE_PER_GPU is None): 190 | raise ValueError 191 | elif (cfg.TRAIN.BATCH_SIZE is not None) and \ 192 | (cfg.TRAIN.BATCH_SIZE_PER_GPU is not None): 193 | if cfg.TRAIN.BATCH_SIZE != \ 194 | cfg.TRAIN.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT: 195 | raise ValueError 196 | elif cfg.TRAIN.BATCH_SIZE is None: 197 | cfg.TRAIN.BATCH_SIZE = \ 198 | cfg.TRAIN.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT 199 | else: 200 | cfg.TRAIN.BATCH_SIZE_PER_GPU = \ 201 | cfg.TRAIN.BATCH_SIZE // cfg.GPU_COUNT 202 | if 'TEST' in cfg: 203 | if (cfg.TEST.BATCH_SIZE is None) and \ 204 | (cfg.TEST.BATCH_SIZE_PER_GPU is None): 205 | raise ValueError 206 | elif (cfg.TEST.BATCH_SIZE is not None) and \ 207 | (cfg.TEST.BATCH_SIZE_PER_GPU is not None): 208 | if cfg.TEST.BATCH_SIZE != \ 209 | cfg.TEST.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT: 210 | raise ValueError 211 | elif cfg.TEST.BATCH_SIZE is None: 212 | cfg.TEST.BATCH_SIZE = \ 213 | cfg.TEST.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT 214 | else: 215 | cfg.TEST.BATCH_SIZE_PER_GPU = \ 216 | cfg.TEST.BATCH_SIZE // cfg.GPU_COUNT 217 | 218 | if (cfg.DATA.NUM_WORKERS is None) and \ 219 | (cfg.DATA.NUM_WORKERS_PER_GPU is None): 220 | raise ValueError 221 | elif (cfg.DATA.NUM_WORKERS is not None) and \ 222 | (cfg.DATA.NUM_WORKERS_PER_GPU is not None): 223 | if cfg.DATA.NUM_WORKERS != \ 224 | cfg.DATA.NUM_WORKERS_PER_GPU * cfg.GPU_COUNT: 225 | raise ValueError 226 | elif cfg.DATA.NUM_WORKERS is None: 227 | cfg.DATA.NUM_WORKERS = \ 228 | cfg.DATA.NUM_WORKERS_PER_GPU * cfg.GPU_COUNT 229 | else: 230 | cfg.DATA.NUM_WORKERS_PER_GPU = \ 231 | cfg.DATA.NUM_WORKERS // cfg.GPU_COUNT 232 | 233 | cfg.MAIN_CODE_PATH = osp.abspath(osp.join( 234 | osp.dirname(__file__), '..')) 235 | cfg.MAIN_CODE = list(cfg_unique_holder().code) 236 | 237 | cfg.TORCH_VERSION = torch.__version__ 238 | 239 | pprint.pprint(cfg) 240 | if cfg.LOG_FILE is not None: 241 | if not osp.exists(osp.dirname(cfg.LOG_FILE)): 242 | os.makedirs(osp.dirname(cfg.LOG_FILE)) 243 | with open(cfg.LOG_FILE, 'w') as f: 244 | pprint.pprint(cfg, f) 245 | with open(osp.join(cfg.LOG_DIR, 'config.json'), 'w') as f: 246 | json.dump(cfg, f, indent=4) 247 | 248 | # step3.1 code saving 249 | if cfg.SAVE_CODE: 250 | codedir = osp.join(cfg.LOG_DIR, 'code') 251 | if osp.exists(codedir): 252 | shutil.rmtree(codedir) 253 | for d in ['configs', 'lib']: 254 | fromcodedir = osp.abspath( 255 | osp.join(cfg.MAIN_CODE_PATH, d)) 256 | tocodedir = osp.join(codedir, d) 257 | shutil.copytree( 258 | fromcodedir, tocodedir, 259 | ignore=shutil.ignore_patterns('*__pycache__*', '*build*')) 260 | for codei in cfg.MAIN_CODE: 261 | shutil.copy(osp.join(cfg.MAIN_CODE_PATH, codei), codedir) 262 | 263 | # step3.2 264 | if cfg.RND_SEED is None: 265 | pass 266 | elif isinstance(cfg.RND_SEED, int): 267 | np.random.seed(cfg.RND_SEED) 268 | torch.manual_seed(cfg.RND_SEED) 269 | else: 270 | raise ValueError 271 | 272 | # step3.3 273 | if cfg.RND_RECORDING: 274 | rnduh().reset(osp.join(cfg.LOG_DIR, 'rcache'), None) 275 | if not 
isinstance(cfg.RND_SEED, str): 276 | pass 277 | elif osp.isfile(cfg.RND_SEED): 278 | print('[Warning]: RND_SEED is a file but RND_RECORDING is on and disables the file.') 279 | # raise ValueError 280 | 281 | # step3.4 282 | try: 283 | if cfg.MATPLOTLIB_MODE is not None: 284 | matplotlib.use(cfg.MATPLOTLIB_MODE) 285 | except: 286 | pass 287 | 288 | return cfg 289 | 290 | def common_argparse(extra_parsing_f=None): 291 | """ 292 | Outputs: 293 | cfg: edict, 294 | 'DEBUG' 295 | 'GPU_DEVICE' 296 | 'ISTRAIN' 297 | exid: [] of int -or- None 298 | experiment id followed by --eval 299 | None so do the regular training 300 | """ 301 | 302 | parser = argparse.ArgumentParser() 303 | parser.add_argument('--debug', action='store_true') 304 | parser.add_argument('--eval', nargs='+', type=int) 305 | parser.add_argument('--gpu', nargs='+', type=int) 306 | parser.add_argument('--port', type=int) 307 | 308 | if extra_parsing_f is not None: 309 | args, cfg = extra_parsing_f(parser) 310 | else: 311 | args = parser.parse_args() 312 | cfg = edict() 313 | 314 | cfg.DEBUG = args.debug 315 | try: 316 | cfg.GPU_DEVICE = list(args.gpu) 317 | except: 318 | pass 319 | 320 | try: 321 | eval_exid = list(args.eval) 322 | except: 323 | eval_exid = None 324 | 325 | try: 326 | port = int(args.port) 327 | cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(port) 328 | except: 329 | pass 330 | 331 | return cfg, eval_exid 332 | -------------------------------------------------------------------------------- /lib/data_factory/__init__.py: -------------------------------------------------------------------------------- 1 | from .ds_base import get_dataset, collate 2 | from .ds_loader import get_loader 3 | from .ds_transform import get_transform 4 | from .ds_formatter import get_formatter 5 | from .ds_sampler import DistributedSampler 6 | -------------------------------------------------------------------------------- /lib/data_factory/ds_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import torchvision 7 | import copy 8 | import itertools 9 | 10 | import sys 11 | code_dir = osp.abspath(osp.join(osp.dirname(__file__), '..', '..')) 12 | sys.path.append(code_dir) 13 | 14 | from .. 
import nputils 15 | from ..cfg_helper import cfg_unique_holder as cfguh 16 | from ..log_service import print_log 17 | 18 | class ds_base(torch.utils.data.Dataset): 19 | def __init__(self, 20 | mode, 21 | loader = None, 22 | estimator = None, 23 | transforms = None, 24 | formatter = None, 25 | **kwargs): 26 | self.init_load_info(mode) 27 | self.loader = loader 28 | self.transforms = transforms 29 | self.formatter = formatter 30 | 31 | console_info = '{}: '.format(self.__class__.__name__) 32 | console_info += 'total {} unique images, '.format(len(self.load_info)) 33 | 34 | self.load_info = sorted(self.load_info, key=lambda x:x['unique_id']) 35 | if estimator is not None: 36 | self.load_info = estimator(self.load_info) 37 | 38 | console_info += 'total {0} unique images after estimation.'.format( 39 | len(self.load_info)) 40 | print_log(console_info) 41 | 42 | for idx, info in enumerate(self.load_info): 43 | info['idx'] = idx 44 | 45 | try: 46 | trysome = cfguh().cfg.DATA.TRY_SOME_SAMPLE 47 | except: 48 | trysome = None 49 | 50 | if trysome is not None: 51 | if isinstance(trysome, str): 52 | trysome = [trysome] 53 | elif isinstance(trysome, (list, tuple)): 54 | trysome = list(trysome) 55 | else: 56 | raise ValueError 57 | 58 | self.load_info = [ 59 | infoi for infoi in self.load_info \ 60 | if osp.splitext( 61 | osp.basename(infoi['image_path']) 62 | )[0] in trysome 63 | ] 64 | print_log('try {} samples.'.format(len(self.load_info))) 65 | 66 | def init_load_info(self, mode): 67 | # implement by sub class 68 | raise ValueError 69 | 70 | def __len__(self): 71 | try: 72 | try_sample = cfguh().cfg.DATA.TRY_SAMPLE 73 | except: 74 | try_sample = None 75 | if try_sample is not None: 76 | return try_sample 77 | return len(self.load_info) 78 | 79 | def __getitem__(self, idx): 80 | element = copy.deepcopy(self.load_info[idx]) 81 | element = self.loader(element) 82 | if self.transforms is not None: 83 | element = self.transforms(element) 84 | if self.formatter is not None: 85 | return self.formatter(element) 86 | else: 87 | return element 88 | 89 | def singleton(class_): 90 | instances = {} 91 | def getinstance(*args, **kwargs): 92 | if class_ not in instances: 93 | instances[class_] = class_(*args, **kwargs) 94 | return instances[class_] 95 | return getinstance 96 | 97 | @singleton 98 | class get_dataset(object): 99 | def __init__(self): 100 | self.dataset = {} 101 | 102 | def register(self, dsf): 103 | self.dataset[dsf.__name__] = dsf 104 | 105 | def __call__(self, dsname=None): 106 | if dsname is None: 107 | dsname = cfguh().cfg.DATA.DATASET_NAME 108 | 109 | # the register is in each file 110 | if dsname == 'textseg': 111 | from . import ds_textseg 112 | elif dsname == 'cocots': 113 | from . import ds_cocotext 114 | elif dsname == 'mlt': 115 | from . import ds_mlt 116 | elif dsname == 'icdar13': 117 | from . import ds_icdar13 118 | elif dsname == 'totaltext': 119 | from . import ds_totaltext 120 | elif dsname == 'textssc': 121 | from . 
import ds_textssc 122 | 123 | return self.dataset[dsname] 124 | 125 | def register(): 126 | def wrapper(class_): 127 | get_dataset().register(class_) 128 | return class_ 129 | return wrapper 130 | 131 | # some other helpers 132 | 133 | class collate(object): 134 | def __init__(self): 135 | self.default_collate = torch.utils.data._utils.collate.default_collate 136 | 137 | def __call__(self, batch): 138 | elem = batch[0] 139 | if not isinstance(elem, (tuple, list)): 140 | return self.default_collate(batch) 141 | 142 | rv = [] 143 | # transposed 144 | for i in zip(*batch): 145 | if isinstance(i[0], list): 146 | if len(i[0]) != 1: 147 | raise ValueError 148 | try: 149 | i = [[self.default_collate(ii).squeeze(0)] for ii in i] 150 | except: 151 | pass 152 | rvi = list(itertools.chain.from_iterable(i)) 153 | rv.append(rvi) # list concat 154 | elif i[0] is None: 155 | rv.append(None) 156 | else: 157 | rv.append(self.default_collate(i)) 158 | return rv 159 | -------------------------------------------------------------------------------- /lib/data_factory/ds_cocotext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import torchvision 7 | import PIL 8 | import json 9 | import cv2 10 | import copy 11 | PIL.Image.MAX_IMAGE_PIXELS = None 12 | 13 | from .ds_base import ds_base, register as regdataset 14 | from .ds_loader import pre_loader_checkings, register as regloader 15 | from .ds_transform import TBase, have, register as regtrans 16 | from .ds_formatter import register as regformat 17 | 18 | from .. import nputils 19 | from ..cfg_helper import cfg_unique_holder as cfguh 20 | from ..log_service import print_log 21 | 22 | class common(ds_base): 23 | def init_load_info(self, mode): 24 | cfgd = cfguh().cfg.DATA 25 | annofile = osp.join(cfgd.ROOT_DIR, 'coco_text', 'cocotext.v2.json') 26 | with open(annofile, 'r') as f: 27 | annoinfo = json.load(f) 28 | 29 | im_list = [i for _, i in annoinfo['imgs'].items()] 30 | im_train = list(filter(lambda x:x['set'] == 'train', im_list)) 31 | im_val = list(filter(lambda x:x['set'] == 'val' , im_list)) 32 | im_list = [] 33 | 34 | for m in mode.split('+'): 35 | if m == 'train': 36 | im_list += im_train 37 | elif m == 'val': 38 | im_list += im_val 39 | else: 40 | raise ValueError 41 | 42 | self.load_info = [] 43 | 44 | for im in im_list: 45 | filename = im['file_name'] 46 | path = filename.split('_')[1] 47 | imagepath = osp.join(cfgd.ROOT_DIR, path, filename) 48 | annids = annoinfo['imgToAnns'][str(im['id'])] 49 | info = { 50 | 'unique_id' : filename.split('.')[0], 51 | 'filename' : filename, 52 | 'image_path': imagepath, 53 | 'coco_text_anno' : [annoinfo['anns'][str(i)] for i in annids], 54 | } 55 | self.load_info.append(info) 56 | 57 | @regdataset() 58 | class cocotext(common): 59 | def init_load_info(self, mode): 60 | super().init_load_info(mode) 61 | 62 | @regdataset() 63 | class cocots(common): 64 | def init_load_info(self, mode): 65 | cfgd = cfguh().cfg.DATA 66 | super().init_load_info(mode) 67 | 68 | cocots_annopath = osp.join(cfgd.ROOT_DIR, 'coco_ts_labels') 69 | 70 | info_new = [] 71 | for info in self.load_info: 72 | annf = info['filename'].split('.')[0]+'.png' 73 | segpath = osp.join(cocots_annopath, annf) 74 | if osp.exists(segpath): 75 | info = copy.deepcopy(info) 76 | info['seglabel_path'] = segpath 77 | info['bbox_path'] = info['coco_text_anno'] # a hack 78 | info_new.append(info) 79 | self.load_info = 
info_new 80 | 81 | def get_semantic_classname(self,): 82 | map = { 83 | 0 : 'background' , 84 | 1 : 'text' , 85 | } 86 | return map 87 | -------------------------------------------------------------------------------- /lib/data_factory/ds_formatter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import cv2 7 | # import scipy.ndimage 8 | from PIL import Image 9 | import copy 10 | import gc 11 | import itertools 12 | 13 | from ..cfg_helper import cfg_unique_holder as cfguh 14 | from .. import nputils 15 | 16 | def singleton(class_): 17 | instances = {} 18 | def getinstance(*args, **kwargs): 19 | if class_ not in instances: 20 | instances[class_] = class_(*args, **kwargs) 21 | return instances[class_] 22 | return getinstance 23 | 24 | @singleton 25 | class get_formatter(object): 26 | def __init__(self): 27 | self.formatter = {} 28 | 29 | def register(self, formatf, kwmap, kwfix): 30 | self.formatter[formatf.__name__] = [formatf, kwmap, kwfix] 31 | 32 | def __call__(self, format_name=None): 33 | cfgd = cfguh().cfg.DATA 34 | if format_name is None: 35 | format_name = cfgd.FORMATTER 36 | 37 | formatf, kwmap, kwfix = self.formatter[format_name] 38 | kw = {k1:cfgd[k2] for k1, k2 in kwmap.items()} 39 | kw.update(kwfix) 40 | return formatf(**kw) 41 | 42 | def register(kwmap={}, kwfix={}): 43 | def wrapper(class_): 44 | get_formatter().register(class_, kwmap, kwfix) 45 | return class_ 46 | return wrapper 47 | 48 | @register() 49 | class SemanticFormatter(object): 50 | def __init__(self, 51 | **kwargs): 52 | pass 53 | 54 | def __call__(self, element): 55 | im = element['image'] 56 | semlabel = element['seglabel'] 57 | return im.astype(np.float32), semlabel.astype(int), element['unique_id'] 58 | -------------------------------------------------------------------------------- /lib/data_factory/ds_loader.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import numpy as np 3 | import numpy.random as npr 4 | import PIL 5 | import cv2 6 | 7 | import torch 8 | import torchvision 9 | import xml.etree.ElementTree as ET 10 | import json 11 | import copy 12 | 13 | from ..cfg_helper import cfg_unique_holder as cfguh 14 | from .. 
import nputils 15 | 16 | def singleton(class_): 17 | instances = {} 18 | def getinstance(*args, **kwargs): 19 | if class_ not in instances: 20 | instances[class_] = class_(*args, **kwargs) 21 | return instances[class_] 22 | return getinstance 23 | 24 | @singleton 25 | class get_loader(object): 26 | def __init__(self): 27 | self.loader = {} 28 | 29 | def register(self, loadf, kwmap, kwfix): 30 | self.loader[loadf.__name__] = [loadf, kwmap, kwfix] 31 | 32 | def __call__(self, pipeline=None): 33 | cfgd = cfguh().cfg.DATA 34 | if pipeline is None: 35 | pipeline = cfgd.LOADER_PIPELINE 36 | 37 | loader = [] 38 | for tag in pipeline: 39 | loadf, kwmap, kwfix = self.loader[tag] 40 | kw = {k1:cfgd[k2] for k1, k2 in kwmap.items()} 41 | kw.update(kwfix) 42 | loader.append(loadf(**kw)) 43 | if len(loader) == 0: 44 | return None 45 | else: 46 | return compose(loader) 47 | 48 | class compose(object): 49 | def __init__(self, loaders): 50 | self.loaders = loaders 51 | 52 | def __call__(self, element): 53 | for l in self.loaders: 54 | element = l(element) 55 | return element 56 | 57 | def register(kwmap={}, kwfix={}): 58 | def wrapper(class_): 59 | get_loader().register(class_, kwmap, kwfix) 60 | return class_ 61 | return wrapper 62 | 63 | def pre_loader_checkings(ltype): 64 | lpath = ltype+'_path' 65 | # cache feature added on 20201021 66 | lcache = ltype+'_cache' 67 | def wrapper(func): 68 | def inner(self, element): 69 | if lcache in element: 70 | # cache feature added on 20201021 71 | data = element[lcache] 72 | else: 73 | if ltype in element: 74 | raise ValueError 75 | if lpath not in element: 76 | raise ValueError 77 | 78 | if element[lpath] is None: 79 | data = None 80 | else: 81 | data = func(self, element[lpath], element) 82 | element[ltype] = data 83 | 84 | if ltype == 'image': 85 | if isinstance(data, np.ndarray): 86 | imsize = data.shape[-2:] 87 | elif isinstance(data, PIL.Image.Image): 88 | imsize = data.size[::-1] 89 | elif data is None: 90 | imsize = None 91 | else: 92 | raise ValueError 93 | element['imsize'] = imsize 94 | element['imsize_current'] = copy.deepcopy(imsize) 95 | return element 96 | return inner 97 | return wrapper 98 | 99 | ########### 100 | # general # 101 | ########### 102 | 103 | @register( 104 | {'backend':'LOAD_BACKEND_IMAGE', 105 | 'is_mc' :'LOAD_IS_MC_IMAGE' ,}) 106 | class NumpyImageLoader(object): 107 | def __init__(self, 108 | backend='pil', 109 | is_mc=False,): 110 | self.backend = backend 111 | self.is_mc = is_mc 112 | 113 | @pre_loader_checkings('image') 114 | def __call__(self, path, element): 115 | return self.load(path) 116 | 117 | def load(self, path): 118 | if not self.is_mc: 119 | if self.backend == 'cv2': 120 | data = cv2.imread( 121 | path, cv2.IMREAD_COLOR)[:, :, ::-1] 122 | elif self.backend == 'pil': 123 | data = np.array(PIL.Image.open( 124 | path).convert('RGB')) 125 | else: 126 | raise ValueError 127 | else: 128 | # multichannel should not assume image is 129 | # defaultly RGB 130 | datai = [] 131 | for p in path: 132 | if self.backend == 'cv2': 133 | i = cv2.imread(p)[:, :, ::-1] 134 | elif self.backend == 'pil': 135 | i = np.array(PIL.Image.open(p)) 136 | else: 137 | raise ValueError 138 | if len(i.shape) == 2: 139 | i = i[:, :, np.newaxis] 140 | datai.append(i) 141 | data = np.concatenate(datai, axis=2) 142 | return np.transpose(data, (2, 0, 1)).astype(np.uint8) 143 | 144 | @register( 145 | {'backend':'LOAD_BACKEND_IMAGE', 146 | 'is_mc' :'LOAD_IS_MC_IMAGE' ,}) 147 | class NumpyImageLoaderWithCache(NumpyImageLoader): 148 | def __init__(self, 149 | 
backend='pil', 150 | is_mc=False): 151 | super().__init__(backend, is_mc) 152 | self.cache_data = None 153 | self.cache_path = None 154 | 155 | @pre_loader_checkings('image') 156 | def __call__(self, path, element): 157 | if path == self.cache_path: 158 | return self.cache_data 159 | else: 160 | self.cache_data = super().load(path) 161 | return self.cache_data 162 | 163 | @register( 164 | {'backend':'LOAD_BACKEND_SEGLABEL', 165 | 'is_mc' :'LOAD_IS_MC_SEGLABEL' ,}) 166 | class NumpySeglabelLoader(object): 167 | def __init__(self, 168 | backend='pil', 169 | is_mc=False): 170 | self.backend = backend 171 | self.is_mc = is_mc 172 | 173 | @pre_loader_checkings('seglabel') 174 | def __call__(self, path, element): 175 | return self.load(path) 176 | 177 | def load(self, path): 178 | if not self.is_mc: 179 | if self.backend == 'cv2': 180 | data = cv2.imread( 181 | path, cv2.IMREAD_GRAYSCALE) 182 | elif self.backend == 'pil': 183 | data = np.array(PIL.Image.open(path)) 184 | else: 185 | raise ValueError 186 | else: 187 | # seglabel doesn't convert to rgb. 188 | datai = [] 189 | for p in path: 190 | if self.backend == 'cv2': 191 | i = cv2.imread( 192 | p, cv2.IMREAD_GRAYSCALE) 193 | elif self.backend == 'pil': 194 | i = np.array(PIL.Image.open(p)) 195 | else: 196 | raise ValueError 197 | if len(i.shape) == 2: 198 | i = i[:, :, np.newaxis] 199 | datai.append(i) 200 | data = np.concatenate(datai, axis=2) 201 | data = np.transpose(data, (2, 0, 1)) 202 | if data.shape[0] == 1: 203 | data = data[0] 204 | return data.astype(np.int32) 205 | 206 | @register( 207 | {'backend':'LOAD_BACKEND_MASK', 208 | 'is_mc' :'LOAD_IS_MC_MASK' ,}) 209 | class NumpyMaskLoader(NumpySeglabelLoader): 210 | def __init__(self, 211 | backend='pil', 212 | is_mc=False): 213 | super().__init__( 214 | backend, is_mc) 215 | 216 | @pre_loader_checkings('mask') 217 | def __call__(self, path, element): 218 | data = super().load(path) 219 | return (data!=0).astype(np.uint8) 220 | 221 | -------------------------------------------------------------------------------- /lib/data_factory/ds_mlt.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import torchvision 7 | import PIL 8 | import json 9 | import cv2 10 | import copy 11 | PIL.Image.MAX_IMAGE_PIXELS = None 12 | 13 | from .ds_base import ds_base, register as regdataset 14 | from .ds_loader import pre_loader_checkings, register as regloader 15 | from .ds_transform import TBase, have, register as regtrans 16 | from .ds_formatter import register as regformat 17 | 18 | from .. 
import nputils 19 | from ..cfg_helper import cfg_unique_holder as cfguh 20 | from ..log_service import print_log 21 | 22 | @regdataset() 23 | class mlt(ds_base): 24 | def init_load_info(self, mode): 25 | cfgd = cfguh().cfg.DATA 26 | self.root_dir = cfgd.ROOT_DIR 27 | 28 | trainlist = self.get_trainlist() 29 | vallist = self.get_vallist() 30 | 31 | trainlist = [i for _, i in trainlist.items()] 32 | vallist = [i for _, i in vallist.items() ] 33 | 34 | self.load_info = [] 35 | for mi in mode.split('+'): 36 | if mi == 'train': 37 | self.load_info += trainlist 38 | elif mi == 'trainseg': 39 | self.load_info += [ 40 | i for i in trainlist if i['seglabel_path'] is not None] 41 | elif mi == 'val': 42 | self.load_info += vallist 43 | elif mi == 'valseg': 44 | self.load_info += [ 45 | i for i in vallist if i['seglabel_path'] is not None] 46 | else: 47 | raise ValueError 48 | 49 | def get_trainlist(self): 50 | # get train data 51 | imdir_list = [ 52 | 'ch8_training_images_1', 'ch8_training_images_2', 53 | 'ch8_training_images_3', 'ch8_training_images_4', 54 | 'ch8_training_images_5', 'ch8_training_images_6', 55 | 'ch8_training_images_7', 'ch8_training_images_8', ] 56 | 57 | # get image info 58 | trainlist = {} 59 | for di in imdir_list: 60 | for fi in os.listdir(osp.join(self.root_dir, di)): 61 | fname = fi.split('.')[0] 62 | fpath = osp.join(self.root_dir, di, fi) 63 | trainlist[fname] = { 64 | 'unique_id' : '00_train_' + fname, 65 | 'filename' : fi, 66 | 'image_path' : fpath} 67 | 68 | # get bpoly label info 69 | labeldir = 'ch8_training_localization_transcription_gt_v2' 70 | for fi in trainlist.keys(): 71 | bpoly_path = osp.join(self.root_dir, labeldir, 'gt_'+fi+'.txt') 72 | if osp.exists(bpoly_path): 73 | trainlist[fi]['bpoly_path'] = bpoly_path 74 | else: 75 | raise ValueError 76 | 77 | # get seglabel info 78 | labeldir = osp.join('MLT_S_labels', 'training_labels') 79 | for fi in trainlist.keys(): 80 | seglabel_path = osp.join(self.root_dir, labeldir, fi+'.png') 81 | if osp.exists(seglabel_path): 82 | trainlist[fi]['seglabel_path'] = seglabel_path 83 | else: 84 | trainlist[fi]['seglabel_path'] = None 85 | return trainlist 86 | 87 | def get_vallist(self): 88 | # get train data 89 | imdir_list = [ 90 | 'ch8_validation_images', ] 91 | 92 | # get image info 93 | vallist = {} 94 | for di in imdir_list: 95 | for fi in os.listdir(osp.join(self.root_dir, di)): 96 | fname = fi.split('.')[0] 97 | fpath = osp.join(self.root_dir, di, fi) 98 | vallist[fname] = { 99 | 'unique_id' : '01_val_' + fname, 100 | 'filename' : fi, 101 | 'image_path' : fpath} 102 | 103 | # get bpoly label info 104 | labeldir = 'ch8_validation_localization_transcription_gt_v2' 105 | for fi in vallist.keys(): 106 | bpoly_path = osp.join(self.root_dir, labeldir, 'gt_'+fi+'.txt') 107 | if osp.exists(bpoly_path): 108 | vallist[fi]['bpoly_path'] = bpoly_path 109 | else: 110 | raise ValueError 111 | 112 | # get seglabel info 113 | labeldir = osp.join('MLT_S_labels', 'validation_labels') 114 | for fi in vallist.keys(): 115 | seglabel_path = osp.join(self.root_dir, labeldir, fi+'.png') 116 | if osp.exists(seglabel_path): 117 | vallist[fi]['seglabel_path'] = seglabel_path 118 | else: 119 | vallist[fi]['seglabel_path'] = None 120 | return vallist 121 | 122 | def get_semantic_classname(self,): 123 | map = { 124 | 0 : 'background' , 125 | 1 : 'text' , 126 | } 127 | return map 128 | 129 | # ---- loader ---- 130 | 131 | @regloader() 132 | class Mlt_SeglabelLoader(object): 133 | def __init__(self): 134 | pass 135 | 136 | 
@pre_loader_checkings('seglabel') 137 | def __call__(self, path, element): 138 | sem = np.array(PIL.Image.open(path)).astype(int) #.convert('RGB')) 139 | return sem 140 | -------------------------------------------------------------------------------- /lib/data_factory/ds_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import numpy.random as npr 4 | import torch.distributed as dist 5 | import math 6 | 7 | from ..log_service import print_log 8 | 9 | class DistributedSampler(torch.utils.data.Sampler): 10 | def __init__(self, 11 | dataset, 12 | num_replicas=None, 13 | rank=None, 14 | shuffle=True, 15 | extend=False,): 16 | if num_replicas is None: 17 | if not dist.is_available(): 18 | raise ValueError 19 | num_replicas = dist.get_world_size() 20 | if rank is None: 21 | if not dist.is_available(): 22 | raise ValueError 23 | rank = dist.get_rank() 24 | 25 | self.dataset = dataset 26 | self.num_replicas = num_replicas 27 | self.rank = rank 28 | 29 | num_samples = len(dataset) // num_replicas 30 | if extend: 31 | if len(dataset) != num_samples*num_replicas: 32 | num_samples+=1 33 | 34 | self.num_samples = num_samples 35 | self.total_size = num_samples * num_replicas 36 | self.shuffle = shuffle 37 | self.extend = extend 38 | 39 | def __iter__(self): 40 | indices = self.get_sync_order() 41 | if self.extend: 42 | # extend using the front indices 43 | indices = (indices+indices)[0:self.total_size] 44 | else: 45 | # truncate 46 | indices = indices[0:self.total_size] 47 | # subsample 48 | indices = indices[self.rank : len(indices) : self.num_replicas] 49 | return iter(indices) 50 | 51 | def __len__(self): 52 | return self.num_samples 53 | 54 | def set_epoch(self, epoch): 55 | # legacy 56 | pass 57 | 58 | def get_sync_order(self): 59 | # g = torch.Generator() 60 | # g.manual_seed(self.epoch) 61 | if self.shuffle: 62 | indices = torch.randperm(len(self.dataset)).to(self.rank) 63 | dist.broadcast(indices, src=0) 64 | indices = indices.to('cpu').tolist() 65 | else: 66 | indices = list(range(len(self.dataset))) 67 | print_log(str(indices[0:5])) 68 | return indices 69 | -------------------------------------------------------------------------------- /lib/data_factory/ds_textssc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import torchvision 7 | import PIL 8 | import json 9 | import cv2 10 | import copy 11 | import scipy 12 | import pandas 13 | PIL.Image.MAX_IMAGE_PIXELS = None 14 | 15 | from .ds_base import ds_base, register as regdataset 16 | from .ds_loader import pre_loader_checkings, register as regloader 17 | from .ds_transform import TBase, have, register as regtrans 18 | from .ds_formatter import register as regformat 19 | 20 | from .. 
import nputils 21 | from ..cfg_helper import cfg_unique_holder as cfguh 22 | from ..log_service import print_log 23 | 24 | @regdataset() 25 | class textssc(ds_base): 26 | def init_load_info(self, mode): 27 | cfgd = cfguh().cfg.DATA 28 | self.root_dir = cfgd.ROOT_DIR 29 | 30 | imdir = [] 31 | segdir = [] 32 | for modei in mode.split('+'): 33 | dsi, seti = modei.split('_') 34 | imdir += [osp.join(self.root_dir, dsi, 'image_'+seti)] 35 | segdir += [osp.join(self.root_dir, dsi, 'seglabel_'+seti)] 36 | 37 | self.load_info = [] 38 | 39 | for imdiri, segdiri in zip(imdir, segdir): 40 | for fi in os.listdir(imdiri): 41 | ftag = fi.split('.')[0] 42 | info = { 43 | 'unique_id' : ftag, 44 | 'filename' : fi, 45 | 'image_path' : osp.join(imdiri, fi), 46 | 'seglabel_path' : osp.join(segdiri, ftag+'.png'), 47 | } 48 | self.load_info.append(info) 49 | 50 | def get_semantic_classname(self,): 51 | map = { 52 | 0 : 'background' , 53 | 1 : 'text' , 54 | } 55 | return map 56 | -------------------------------------------------------------------------------- /lib/data_factory/ds_totaltext.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import numpy as np 4 | import numpy.random as npr 5 | import torch 6 | import torchvision 7 | import PIL 8 | import json 9 | import cv2 10 | import copy 11 | PIL.Image.MAX_IMAGE_PIXELS = None 12 | 13 | from .ds_base import ds_base, register as regdataset 14 | from .ds_loader import pre_loader_checkings, register as regloader 15 | from .ds_transform import TBase, have, register as regtrans 16 | from .ds_formatter import register as regformat 17 | 18 | from .. import nputils 19 | from ..cfg_helper import cfg_unique_holder as cfguh 20 | from ..log_service import print_log 21 | 22 | @regdataset() 23 | class totaltext(ds_base): 24 | def init_load_info(self, mode): 25 | cfgd = cfguh().cfg.DATA 26 | self.root_dir = cfgd.ROOT_DIR 27 | 28 | im_path = [] 29 | seg_path = [] 30 | 31 | for mi in mode.split('+'): 32 | if mi == 'train': 33 | im_path += [osp.join(self.root_dir, 'Images', 'Train')] 34 | seg_path += [osp.join(self.root_dir, 'groundtruth_pixel', 'Train')] 35 | elif mi == 'test': 36 | im_path += [osp.join(self.root_dir, 'Images', 'Test')] 37 | seg_path += [osp.join(self.root_dir, 'groundtruth_pixel', 'Test')] 38 | else: 39 | raise ValueError 40 | 41 | self.load_info = [] 42 | for impi, segpi in zip(im_path, seg_path): 43 | for fi in os.listdir(impi): 44 | uid = fi.split('.')[0] 45 | self.load_info.append({ 46 | 'unique_id' : uid, 47 | 'filename' : fi, 48 | 'image_path' : osp.join(impi, fi), 49 | 'seglabel_path' : osp.join(segpi, uid+'.jpg'), 50 | }) 51 | 52 | def get_semantic_classname(self,): 53 | map = { 54 | 0 : 'background' , 55 | 1 : 'text' , 56 | } 57 | return map 58 | 59 | # ---- loader ---- 60 | 61 | @regloader() 62 | class TotalText_SeglabelLoader(object): 63 | def __init__(self): 64 | pass 65 | 66 | @pre_loader_checkings('seglabel') 67 | def __call__(self, path, element): 68 | sem = np.array(PIL.Image.open(path)).astype(int) 69 | sem = (sem>127).astype(int) 70 | return sem 71 | -------------------------------------------------------------------------------- /lib/log_service.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | import numpy as np 3 | import os.path as osp 4 | import torch 5 | import torch.nn as nn 6 | from .cfg_helper import cfg_unique_holder as cfguh 7 | 8 | def print_log(console_info): 9 | print(console_info) 10 | log_file = 
cfguh().cfg.LOG_FILE 11 | if log_file is not None: 12 | with open(log_file, 'a') as f: 13 | f.write(console_info + '\n') 14 | 15 | class log_manager(object): 16 | """ 17 | The helper to print logs. 18 | """ 19 | def __init__(self, 20 | **kwargs): 21 | self.data = {} 22 | self.cnt = {} 23 | self.time_check = timeit.default_timer() 24 | 25 | def accumulate(self, 26 | n, 27 | data, 28 | **kwargs): 29 | """ 30 | Args: 31 | n: number of items (i.e. the batchsize) 32 | data: {itemname : float} data (i.e. the loss values) 33 | which are going to be accumulated. 34 | """ 35 | if n < 0: 36 | raise ValueError 37 | 38 | for itemn, di in data.items(): 39 | try: 40 | self.data[itemn] += di * n 41 | except: 42 | self.data[itemn] = di * n 43 | 44 | try: 45 | self.cnt[itemn] += n 46 | except: 47 | self.cnt[itemn] = n 48 | 49 | def print(self, rank, itern, epochn, samplen, lr): 50 | console_info = [ 51 | 'Rank:{}'.format(rank), 52 | 'Iter:{}'.format(itern), 53 | 'Epoch:{}'.format(epochn), 54 | 'Sample:{}'.format(samplen), 55 | 'LR:{:.4f}'.format(lr)] 56 | 57 | cntgroups = {} 58 | for itemn, ci in self.cnt.items(): 59 | try: 60 | cntgroups[ci].append(itemn) 61 | except: 62 | cntgroups[ci] = [itemn] 63 | 64 | for ci, itemng in cntgroups.items(): 65 | console_info.append('cnt:{}'.format(ci)) 66 | for itemn in sorted(itemng): 67 | console_info.append('{}:{:.4f}'.format( 68 | itemn, self.data[itemn]/ci)) 69 | 70 | console_info.append('Time:{:.2f}s'.format( 71 | timeit.default_timer() - self.time_check)) 72 | return ' , '.join(console_info) 73 | 74 | def clear(self): 75 | self.data = {} 76 | self.cnt = {} 77 | self.time_check = timeit.default_timer() 78 | 79 | def pop(self, rank, itern, epochn, samplen, lr): 80 | console_info = self.print( 81 | rank, itern, epochn, samplen, lr) 82 | self.clear() 83 | return console_info 84 | 85 | # ----- also include some small utils ----- 86 | 87 | def torch_to_numpy(*argv): 88 | if len(argv) > 1: 89 | data = list(argv) 90 | else: 91 | data = argv[0] 92 | 93 | if isinstance(data, torch.Tensor): 94 | return data.to('cpu').detach().numpy() 95 | 96 | elif isinstance(data, (list, tuple)): 97 | out = [] 98 | for di in data: 99 | out.append(torch_to_numpy(di)) 100 | return out 101 | 102 | elif isinstance(data, dict): 103 | out = {} 104 | for ni, di in data.items(): 105 | out[ni] = torch_to_numpy(di) 106 | return out 107 | 108 | else: 109 | return data 110 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import copy 6 | from . 
import torchutils 7 | 8 | class finalize_loss(object): 9 | def __init__(self, 10 | weight = None, 11 | normalize_weight = True, 12 | **kwargs): 13 | if weight is None: 14 | self.weight = None 15 | else: 16 | for _, wi in weight.items(): 17 | if wi < 0: 18 | raise ValueError 19 | 20 | if not normalize_weight: 21 | self.weight = weight 22 | else: 23 | sum_weight = 0 24 | for _, wi in weight.items(): 25 | sum_weight += wi 26 | if sum_weight == 0: 27 | raise ValueError 28 | self.weight = { 29 | itemn:wi/sum_weight for itemn, wi in weight.items()} 30 | 31 | self.normalize_weight = normalize_weight 32 | 33 | def __call__(self, 34 | loss_input,): 35 | item = {n : v.item() for n, v in loss_input.items()} 36 | lossname = [n for n in loss_input.keys() if n[0:4]=='loss'] 37 | 38 | if self.weight is not None: 39 | if sorted(lossname) \ 40 | != sorted(list(self.weight.keys())): 41 | raise ValueError 42 | 43 | loss_num = len(lossname) 44 | loss = None 45 | 46 | for n in lossname: 47 | v = loss_input[n] 48 | if loss is not None: 49 | if self.weight is not None: 50 | loss += v * self.weight[n] 51 | else: 52 | loss += v 53 | else: 54 | if self.weight is not None: 55 | loss = v * self.weight[n] 56 | else: 57 | loss = v 58 | 59 | if (self.weight is None) and (self.normalize_weight): 60 | loss /= loss_num 61 | 62 | item['Loss'] = loss.item() 63 | return loss, item 64 | -------------------------------------------------------------------------------- /lib/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .get_model import get_model, save_state_dict 2 | -------------------------------------------------------------------------------- /lib/model_zoo/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .. import nputils 6 | from .. import torchutils 7 | from .. import loss 8 | from .get_model import get_model, register 9 | from .optim_manager import optim_manager 10 | 11 | from . 
import utils 12 | 13 | version = 'v33' 14 | 15 | class ASPP(nn.Module): 16 | def __init__(self, 17 | ic_n, 18 | c_n, 19 | oc_n, 20 | dilation_n, 21 | conv_type='conv', 22 | bn_type='bn', 23 | relu_type='relu', 24 | dropout_type='dropout|0.5', 25 | with_gap=True, 26 | **kwargs): 27 | super().__init__() 28 | 29 | conv, bn, relu = utils.conv_bn_relu(conv_type, bn_type, relu_type) 30 | dropout = utils.nn_component(dropout_type) 31 | 32 | d1, d2, d3 = dilation_n 33 | self.conv1 = nn.Sequential( 34 | conv(ic_n, c_n, 1, 1, padding=0, dilation=1), 35 | bn(c_n), 36 | relu(inplace=True)) 37 | self.conv2 = nn.Sequential( 38 | conv(ic_n, c_n, 3, 1, padding=d1, dilation=d1), 39 | bn(c_n), 40 | relu(inplace=True)) 41 | self.conv3 = nn.Sequential( 42 | conv(ic_n, c_n, 3, 1, padding=d2, dilation=d2), 43 | bn(c_n), 44 | relu(inplace=True)) 45 | self.conv4 = nn.Sequential( 46 | conv(ic_n, c_n, 3, 1, padding=d3, dilation=d3), 47 | bn(c_n), 48 | relu(inplace=True)) 49 | if with_gap: 50 | self.conv5 = nn.Sequential( 51 | nn.AdaptiveAvgPool2d((1, 1)), 52 | conv(ic_n, c_n, 1, 1, padding=0, dilation=1), 53 | bn(c_n), 54 | relu(inplace=True)) 55 | total_layers=5 56 | else: 57 | total_layers=4 58 | 59 | self.bottleneck = nn.Sequential( 60 | conv(c_n*total_layers, oc_n, 1, 1, 0), 61 | bn(oc_n), 62 | relu(inplace=True),) 63 | self.dropout = dropout() 64 | self.with_gap = with_gap 65 | 66 | def forward(self, x): 67 | _, _, h, w = x.size() 68 | feat1 = self.conv1(x) 69 | feat2 = self.conv2(x) 70 | feat3 = self.conv3(x) 71 | feat4 = self.conv4(x) 72 | if self.with_gap: 73 | feat5 = F.interpolate( 74 | self.conv5(x), size=(h, w), mode='nearest') 75 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1) 76 | else: 77 | out = torch.cat((feat1, feat2, feat3, feat4), dim=1) 78 | out = self.bottleneck(out) 79 | out = self.dropout(out) 80 | return out 81 | 82 | class Decoder(nn.Module): 83 | def __init__(self, 84 | bic_n, 85 | xic_n, 86 | oc_n, 87 | align_corners=False, 88 | conv_type='conv', 89 | bn_type='bn', 90 | relu_type='relu', 91 | dropout2_type='dropout|0.5', 92 | dropout3_type='dropout|0.1', 93 | **kwargs): 94 | super().__init__() 95 | 96 | conv, bn, relu = utils.conv_bn_relu(conv_type, bn_type, relu_type) 97 | dropout2 = utils.nn_component(dropout2_type) 98 | dropout3 = utils.nn_component(dropout3_type) 99 | 100 | self.conv1 = conv(bic_n, 48, 1, 1, 0) 101 | self.bn1 = bn(48) 102 | self.relu = relu(inplace=True) 103 | 104 | self.conv2 = conv(xic_n+48, 256, 3, 1, 1) 105 | self.bn2 = bn(256) 106 | self.dropout2 = dropout2() 107 | 108 | self.conv3 = conv(256, oc_n, 3, 1, 1) 109 | self.bn3 = bn(oc_n) 110 | self.dropout3 = dropout3() 111 | 112 | self.align_corners = align_corners 113 | 114 | def forward(self, 115 | b, 116 | x): 117 | b = self.relu(self.bn1(self.conv1(b))) 118 | x = torchutils.interpolate_2d( 119 | size = b.shape[2:], mode='bilinear', 120 | align_corners=self.align_corners)(x) 121 | x = torch.cat([x, b], dim=1) 122 | x = self.relu(self.bn2(self.conv2(x))) 123 | x = self.dropout2(x) 124 | x = self.relu(self.bn3(self.conv3(x))) 125 | x = self.dropout3(x) 126 | return x 127 | 128 | class DeepLabv3p_Base(nn.Module): 129 | def __init__(self, 130 | bbn_name, 131 | bbn, 132 | oc_n, 133 | aspp_ic_n, 134 | aspp_dilation_n, 135 | decoder_bic_n, 136 | aspp_dropout_type='dropout|0.5', 137 | aspp_with_gap=True, 138 | decoder_dropout2_type='dropout|0.5', 139 | decoder_dropout3_type='dropout|0.1', 140 | conv_type='conv', 141 | bn_type='bn', 142 | relu_type='relu', 143 | align_corners=False, 144 | **kwargs): 145 | 
super().__init__() 146 | 147 | setattr(self, bbn_name, bbn) 148 | self.aspp = ASPP( 149 | aspp_ic_n, 256, 256, aspp_dilation_n, 150 | conv_type=conv_type, 151 | bn_type=bn_type, 152 | relu_type=relu_type, 153 | dropout_type=aspp_dropout_type, 154 | with_gap=aspp_with_gap) 155 | self.decoder = Decoder( 156 | decoder_bic_n, 256, oc_n, 157 | align_corners=align_corners, 158 | conv_type=conv_type, 159 | bn_type=bn_type, 160 | relu_type=relu_type, 161 | dropout2_type=decoder_dropout2_type, 162 | dropout3_type=decoder_dropout3_type,) 163 | 164 | self.bbn_name = bbn_name 165 | 166 | # initialize the weight 167 | utils.init_module([self.aspp, self.decoder]) 168 | 169 | # prepare opmgr 170 | self.opmgr = getattr(self, self.bbn_name).opmgr 171 | self.opmgr.inheritant(self.bbn_name) 172 | self.opmgr.pushback( 173 | 'deeplab', ['aspp', 'decoder'] 174 | ) 175 | 176 | def forward(self, x): 177 | xs = getattr(self, self.bbn_name)(x) 178 | b, x = xs[0], xs[-1] 179 | x = self.aspp(x) 180 | x = self.decoder(b, x) 181 | return x 182 | 183 | class DeepLabv3p(DeepLabv3p_Base): 184 | def __init__(self, 185 | bbn_name, 186 | bbn, 187 | class_n, 188 | aspp_ic_n, 189 | aspp_dilation_n, 190 | decoder_bic_n, 191 | aspp_dropout_type='dropout|0.5', 192 | aspp_with_gap=True, 193 | decoder_dropout2_type='dropout|0.5', 194 | decoder_dropout3_type='dropout|0.1', 195 | conv_type='conv', 196 | bn_type='bn', 197 | relu_type='relu', 198 | align_corners=False, 199 | ignore_label=None, 200 | loss_type='ce', 201 | intrain_getpred=False, 202 | ineval_output_argmax=True, 203 | **kwargs): 204 | 205 | super().__init__( 206 | bbn_name, 207 | bbn, 208 | 256, 209 | aspp_ic_n, 210 | aspp_dilation_n, 211 | decoder_bic_n, 212 | aspp_dropout_type, 213 | aspp_with_gap, 214 | decoder_dropout2_type, 215 | decoder_dropout3_type, 216 | conv_type, 217 | bn_type, 218 | relu_type, 219 | align_corners, 220 | ) 221 | 222 | self.semhead = utils.semantic_head( 223 | 256, class_n, 224 | align_corners=align_corners, 225 | ignore_label=ignore_label, 226 | loss_type=loss_type, 227 | ineval_output_argmax=ineval_output_argmax, 228 | intrain_getpred = intrain_getpred, 229 | ) 230 | 231 | # initialize the weight 232 | # (aspp and decoder initalized in base class) 233 | utils.init_module([self.semhead]) 234 | 235 | # prepare opmgr 236 | module_name = self.opmgr.popback() 237 | module_name.append('semhead') 238 | self.opmgr.pushback('deeplab', module_name) 239 | 240 | def forward(self, 241 | x, 242 | gtsem=None, 243 | ): 244 | x = super().forward(x) 245 | o = self.semhead(x, gtsem) 246 | return o 247 | 248 | @register( 249 | 'DEEPLAB', 250 | { 251 | # base 252 | 'freeze_backbone_bn' : 'FREEZE_BACKBONE_BN', 253 | 'oc_n' : 'OUTPUT_CHANNEL_NUM', 254 | 'conv_type' : 'CONV_TYPE', 255 | 'bn_type' : 'BN_TYPE', 256 | 'relu_type' : 'RELU_TYPE', 257 | 'aspp_dropout_type' : 'ASPP_DROPOUT_TYPE', 258 | 'aspp_with_gap' : 'ASPP_WITH_GAP', 259 | 'decoder_dropout2_type' : 'DECODER_DROPOUT2_TYPE', 260 | 'decoder_dropout3_type' : 'DECODER_DROPOUT3_TYPE', 261 | 'align_corners' : 'INTERPOLATE_ALIGN_CORNERS', 262 | # non_base 263 | 'ignore_label' : 'SEMANTIC_IGNORE_LABEL', 264 | 'class_n' : 'SEMANTIC_CLASS_NUM', 265 | 'loss_type' : 'LOSS_TYPE', 266 | 'intrain_getpred' : 'INTRAIN_GETPRED', 267 | 'ineval_output_argmax' : 'INEVAL_OUTPUT_ARGMAX', 268 | }) 269 | def deeplab(tags, **para): 270 | if 'resnet' in tags: 271 | bbn = get_model()('resnet') 272 | para['bbn_name'] = 'resnet' 273 | para['bbn'] = bbn 274 | para['aspp_ic_n'] = 512*bbn.block_expansion 275 | para['decoder_bic_n'] = 

    if 'os16' in tags:
        para['aspp_dilation_n'] = [6, 12, 18]
    elif 'os8' in tags:
        para['aspp_dilation_n'] = [12, 24, 36]
    else:
        raise ValueError('expected an output-stride tag: "os16" or "os8"')

    try:
        freezebn = para.pop('freeze_backbone_bn')
    except KeyError:
        freezebn = False
    if freezebn:
        for m in bbn.modules():
            if isinstance(m, nn.BatchNorm2d):
                for i in m.parameters():
                    i.requires_grad = False

    if 'v3+' in tags:
        if 'base' in tags:
            net = DeepLabv3p_Base(**para)
        else:
            net = DeepLabv3p(**para)
    else:
        raise ValueError('expected the "v3+" tag; no other DeepLab variant is implemented')

    return net
--------------------------------------------------------------------------------
/lib/model_zoo/get_model.py:
--------------------------------------------------------------------------------
import torch
import torchvision.models
import os.path as osp
from ..cfg_helper import cfg_unique_holder as cfguh
from ..log_service import print_log
from .utils import get_total_param, get_total_param_sum, freeze

def load_state_dict(net, model_path):
    paras = torch.load(model_path, map_location=torch.device('cpu'))
    new_paras = net.state_dict()
    new_paras.update(paras)
    net.load_state_dict(new_paras)
    return

def save_state_dict(net, path):
    if isinstance(net, (torch.nn.DataParallel,
                        torch.nn.parallel.DistributedDataParallel)):
        torch.save(net.module.state_dict(), path)
    else:
        torch.save(net.state_dict(), path)

def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class get_model(object):
    def __init__(self):
        self.model = {}

    def register(self, modelf, cfgname, kwmap, kwfix):
        self.model[modelf.__name__] = [modelf, cfgname, kwmap, kwfix]

    def __call__(self, name=None, cfgm=None):
        if cfgm is None:
            cfgm = cfguh().cfg.MODEL
        if name is None:
            name = cfgm.MODEL_NAME

        # each model file performs its registration on import
        if name == 'resnet':
            from . import resnet
        elif name == 'deeplab':
            from . import deeplab
        elif name == 'hrnet':
            from . import hrnet
        elif name == 'texrnet':
            from . import texrnet

        modelf, cfgname, kwmap, kwfix = self.model[name]
        cfgm = cfgm.__getitem__(cfgname)

        # MODEL_TAGS and PRETRAINED_PTH are two special args;
        # FREEZE_BACKBONE_BN is not frequently used.
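        # kwmap translates python kwarg names (left) to upper-case config keys
        # (right); keys absent from the cfg are simply skipped below, so the
        # model falls back to its python-side default for those arguments.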
        kw = {'tags' : cfgm.MODEL_TAGS}
        for k1, k2 in kwmap.items():
            if k2 in cfgm.keys():
                kw[k1] = cfgm[k2]
        kw.update(kwfix)
        net = modelf(**kw)

        # load init model
        if cfgm.PRETRAINED_PTH is not None:
            print_log('Load model from {0}.'.format(
                cfgm.PRETRAINED_PTH))
            load_state_dict(
                net, cfgm.PRETRAINED_PTH)

        # display param_num & param_sum
        print_log('Load {} with total {} parameters, {:.3f} parameter sum.'.format(
            name, get_total_param(net), get_total_param_sum(net)))

        return net

def register(cfgname, kwmap={}, kwfix={}):
    def wrapper(class_):
        get_model().register(class_, cfgname, kwmap, kwfix)
        return class_
    return wrapper
--------------------------------------------------------------------------------
/lib/model_zoo/hrnet.py:
--------------------------------------------------------------------------------
import sys
# import os
import os.path as osp
sys.path.append(osp.join(osp.dirname(__file__), '..', '..', 'hrnet_code', 'lib'))
# sys.path.append('/home/james/Spy/hrnet/HRNet-Semantic-Segmentation-HRNet-OCR/lib')
from models import seg_hrnet
from models import seg_hrnet_ocr

import torch
import torch.nn as nn
import torch.nn.functional as F

from .. import nputils
from .. import torchutils
from .. import loss
from .get_model import get_model, register
from .optim_manager import optim_manager

from . import utils

version = 'v0'
"""
v0: the original code from GitHub by the paper authors
"""

class HRNet_Base(seg_hrnet.HighResolutionNet):
    def __init__(self,
                 oc_n,
                 align_corners,
                 ignore_label,
                 stage1_para,
                 stage2_para,
                 stage3_para,
                 stage4_para,
                 final_conv_kernel,
                 **kwargs):
        from easydict import EasyDict as edict
        config = edict()
        config.MODEL = edict()
        config.MODEL.ALIGN_CORNERS = align_corners
        config.MODEL.EXTRA = {}
        config.MODEL.EXTRA['STAGE1'] = stage1_para
        config.MODEL.EXTRA['STAGE2'] = stage2_para
        config.MODEL.EXTRA['STAGE3'] = stage3_para
        config.MODEL.EXTRA['STAGE4'] = stage4_para
        config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
        config.DATASET = edict()
        config.DATASET.NUM_CLASSES = 1  # dummy
        super().__init__(config)

        self.opmgr = optim_manager(
            group = {'hrnet': 'self'},
            order = ['hrnet'],
        )

        last_inp_channels = self.last_layer[0].in_channels
        # BatchNorm2d = nn.SyncBatchNorm
        BatchNorm2d = nn.BatchNorm2d
        relu_inplace = True
        self.last_layer = nn.Sequential(
            nn.Conv2d(
                in_channels=last_inp_channels,
                out_channels=oc_n,
                kernel_size=1,
                stride=1,
                padding=0),
            BatchNorm2d(oc_n, momentum=0.1),
            nn.ReLU(inplace=relu_inplace),
        )

    def forward(self, x):
        x = super().forward(x)
        return x

class HRNet(seg_hrnet.HighResolutionNet):
    def __init__(self,
                 cls_n,
                 align_corners,
                 ignore_label,
                 stage1_para,
                 stage2_para,
                 stage3_para,
                 stage4_para,
                 final_conv_kernel,
                 loss_type='ce',
                 intrain_getpred=False,
                 ineval_output_argmax=False,
                 **kwargs):
        from easydict import EasyDict as edict
        config = edict()
        config.MODEL = edict()
        config.MODEL.ALIGN_CORNERS = align_corners
        config.MODEL.EXTRA = {}
        config.MODEL.EXTRA['STAGE1'] = stage1_para
        config.MODEL.EXTRA['STAGE2'] = stage2_para
        config.MODEL.EXTRA['STAGE3'] = stage3_para
        config.MODEL.EXTRA['STAGE4'] = stage4_para
        config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
        config.DATASET = edict()
        config.DATASET.NUM_CLASSES = cls_n
        super().__init__(config)

        self.semhead = utils.semantic_head_noconv(
            align_corners = align_corners,
            ignore_label = ignore_label,
            loss_type = loss_type,
            intrain_getpred = intrain_getpred,
            ineval_output_argmax = ineval_output_argmax,
        )

        self.opmgr = optim_manager(
            group = {'hrnet': 'self'},
            order = ['hrnet'],
        )

    def forward(self, x, gtsem=None):
        x = super().forward(x)
        o = self.semhead(x, gtsem)
        return o

class HRNet_Ocr(seg_hrnet_ocr.HighResolutionNet):
    def __init__(self,
                 cls_n,
                 align_corners,
                 ignore_label,
                 stage1_para,
                 stage2_para,
                 stage3_para,
                 stage4_para,
                 final_conv_kernel,
                 loss_type='ce',
                 intrain_getpred=False,
                 ineval_output_argmax=False,
                 ocr_mc_n = 512,
                 ocr_keyc_n = 256,
                 ocr_dropout_rate = 0.05,
                 ocr_scale = 1,
                 **kwargs):
        from easydict import EasyDict as edict
        config = edict()
        config.MODEL = edict()
        config.MODEL.ALIGN_CORNERS = align_corners
        config.MODEL.EXTRA = {}
        config.MODEL.EXTRA['STAGE1'] = stage1_para
        config.MODEL.EXTRA['STAGE2'] = stage2_para
        config.MODEL.EXTRA['STAGE3'] = stage3_para
        config.MODEL.EXTRA['STAGE4'] = stage4_para
        config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
        config.DATASET = edict()
        config.DATASET.NUM_CLASSES = cls_n

        # OCR special
        config.MODEL.OCR = edict()
        config.MODEL.OCR.MID_CHANNELS = ocr_mc_n
        config.MODEL.OCR.KEY_CHANNELS = ocr_keyc_n
        config.MODEL.OCR.DROPOUT = ocr_dropout_rate
        config.MODEL.OCR.SCALE = ocr_scale

        super().__init__(config)

        self.auxhead = utils.semantic_head_noconv(
            align_corners = align_corners,
            ignore_label = ignore_label,
            loss_type = loss_type,
            intrain_getpred = False,
            ineval_output_argmax = False,
        )  # the aux head (not main)

        self.semhead = utils.semantic_head_noconv(
            align_corners = align_corners,
            ignore_label = ignore_label,
            loss_type = loss_type,
            intrain_getpred = intrain_getpred,
            ineval_output_argmax = ineval_output_argmax,
        )

        self.opmgr = optim_manager(
            group = {'hrnet': 'self'},
            order = ['hrnet'],
        )

    def forward(self, x, gtsem=None):
        x = super().forward(x)
        o = self.semhead(x[1], gtsem)
        if self.training:
            o['lossaux'] = self.auxhead(x[0], gtsem)['losssem']
        return o


@register(
    'HRNET',
    {
        'oc_n'                 : 'OUTPUT_CHANNEL_NUM',
        'cls_n'                : 'CLASS_NUM',
        'align_corners'        : 'ALIGN_CORNERS',
        'ignore_label'         : 'IGNORE_LABEL',
        'stage1_para'          : 'STAGE1_PARA',
        'stage2_para'          : 'STAGE2_PARA',
        'stage3_para'          : 'STAGE3_PARA',
        'stage4_para'          : 'STAGE4_PARA',
        'final_conv_kernel'    : 'FINAL_CONV_KERNEL',
        'loss_type'            : 'LOSS_TYPE',
        'intrain_getpred'      : 'INTRAIN_GETPRED',
        'ineval_output_argmax' : 'INEVAL_OUTPUT_ARGMAX',
    })
def hrnet(tags, **para):
    if 'base' in tags:
        net = HRNet_Base(**para)
    elif 'ocr' in tags:
        net = HRNet_Ocr(**para)
    else:
        net = HRNet(**para)
    return net
--------------------------------------------------------------------------------
/lib/model_zoo/optim_manager.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools

class optim_manager(object):
    def __init__(self,
                 group,
                 order=None,
                 group_lrscale=None,):
        self.group = {}
        self.order = []
        if order is None:
            # default to the dict's own key order when no explicit order is given
            order = list(group.keys())
        for gn in order:
            self.pushback(gn, group[gn])
        self.group_lrscale = group_lrscale

    def pushback(self,
                 group_name,
                 module_name):
        if group_name in self.group.keys():
            raise ValueError
        if isinstance(module_name, (list, tuple)):
            self.group[group_name] = list(module_name)
        else:
            self.group[group_name] = [module_name]
        self.order += [group_name]
        self.group_lrscale = None

    def popback(self):
        group = self.group.pop(self.order[-1])
        self.order = self.order[:-1]
        self.group_lrscale = None
        return group

    def replace(self,
                group_name,
                module_name,):
        if group_name not in self.group.keys():
            raise ValueError
        if isinstance(module_name, (list, tuple)):
            module_name = list(module_name)
        else:
            module_name = [module_name]
        self.group[group_name] = module_name
        self.group_lrscale = None

    def inheritant(self,
                   supername):
        for gn in self.group.keys():
            module_name = self.group[gn]
            module_name_new = []
            for mn in module_name:
                if mn == 'self':
                    module_name_new.append(supername)
                else:
                    module_name_new.append(supername+'.'+mn)
            self.group[gn] = module_name_new

    def set_lrscale(self,
                    group_lrscale):
        if sorted(self.order) != \
                sorted(list(group_lrscale.keys())):
            raise ValueError
        self.group_lrscale = group_lrscale

    def pg_generator(self,
                     netref,
                     module_name):
        if not isinstance(module_name, list):
            raise ValueError
        # the "self" special case
        if (len(module_name)==1) \
                and (module_name[0] == 'self'):
            return netref.parameters()
        pg = []
        for mn in module_name:
            if mn == 'self':
                raise ValueError
            mn = mn.split('.')
            module = netref
            for mni in mn:
                module = getattr(module, mni)

            pg.append(module.parameters())

        pg = itertools.chain(*pg)
        return pg

    def get_pg(self, net):
        try:
            netref = net.module
        except AttributeError:
            netref = net

        return [
            {'params': self.pg_generator(netref, self.group[gn])} \
            for gn in self.order]

    def get_pglr(self, idx, base_lr):
        return base_lr * self.group_lrscale[self.order[idx]]
--------------------------------------------------------------------------------
/lib/optimizer/__init__.py:
--------------------------------------------------------------------------------
from .get_optimizer import get_optimizer, adjust_lr, lr_scheduler
--------------------------------------------------------------------------------
/lib/optimizer/get_optimizer.py:
--------------------------------------------------------------------------------
import torch
import torch.optim as optim
import numpy as np

from ..cfg_helper import cfg_unique_holder as cfguh

def get_optimizer(net, optimizer_name = None, opmgr = None):
    cfg = cfguh().cfg
    if optimizer_name is None:
        optimizer_name = cfg.TRAIN.OPTIMIZER

    # all lr are initialized as 0,
    # because it will be set outside this function
    if opmgr is None:
        parameter_groups = net.parameters()
    else:
        para_num = len([i for i in net.parameters()])
        para_ckt = sum([
            len([i for i in pg['params']]) for pg in opmgr.get_pg(net)])
        if para_num != para_ckt:
            # this checks whether the opmgr param groups include all parameters
            # TODO: may put a warning here
            raise ValueError
        parameter_groups = opmgr.get_pg(net)

    if optimizer_name == "sgd":
        optimizer = optim.SGD(
            parameter_groups,
            lr = 0,
            momentum = cfg.TRAIN.SGD_MOMENTUM,
            weight_decay = cfg.TRAIN.SGD_WEIGHT_DECAY)

    elif optimizer_name == "adam":
        optimizer = optim.Adam(
            parameter_groups,
            lr = 0,
            betas = cfg.TRAIN.ADAM_BETAS,
            eps = cfg.TRAIN.ADAM_EPS,
            weight_decay = cfg.TRAIN.ADAM_WEIGHT_DECAY)

    else:
        raise ValueError

    return optimizer

def adjust_lr(op, new_lr, opmgr = None):
    for idx, param_group in enumerate(op.param_groups):
        if opmgr is None:
            param_group['lr'] = new_lr
        else:
            param_group['lr'] = opmgr.get_pglr(idx, new_lr)

class lr_scheduler(object):
    def __init__(self,
                 types):
        self.lr = []
        for spec in types:
            if spec[0] == 'constant':
                _, v, n = spec
                lr = [v for i in range(n)]
            elif spec[0] in ('poly', 'ploy'):
                # 'ploy' is accepted as an alias since existing configs
                # use the misspelled name
                _, va, vb, n, pw = spec
                lr = [vb + (va-vb) * ((1-i/n)**pw) for i in range(n)]
            elif spec[0] == 'linear':
                _, va, vb, n = spec
                lr = [vb + (va-vb) * (1-i/n) for i in range(n)]
            else:
                raise ValueError
            self.lr += lr
        self.lr = np.array(self.lr)

    def __call__(self, i):
        if i < len(self.lr):
            return self.lr[i]
        else:
            return self.lr[-1]
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import torch.distributed as dist
import torch.multiprocessing as mp

import os
import os.path as osp
import sys
import numpy as np
import copy
import gc
import time

import argparse
from easydict import EasyDict as edict

from lib.model_zoo.texrnet import version as VERSION
from lib.cfg_helper import cfg_unique_holder as cfguh, \
    get_experiment_id, \
    experiment_folder, \
    common_initiates

from configs.cfg_dataset import cfg_textseg, cfg_cocots, cfg_mlt, cfg_icdar13, cfg_totaltext
from configs.cfg_model import cfg_texrnet as cfg_mdel
from configs.cfg_base import cfg_train, cfg_test

from train_utils import \
    set_cfg as set_cfg_train, \
    set_cfg_hrnetw48 as set_cfg_hrnetw48_train, \
    ts, ts_with_classifier, train

from eval_utils import \
    set_cfg as set_cfg_eval, \
    set_cfg_hrnetw48 as set_cfg_hrnetw48_eval, \
    es, eval

cfguh().add_code(osp.basename(__file__))

def common_argparse():
    # argument parsing (assumed; mirrors the relevant flags in __main__ below)
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true', default=False)
    parser.add_argument('--eval' , action='store_true', default=False)
    parser.add_argument('--pth'  , type=str)
    parser.add_argument('--port' , type=int, default=11233)
    args = parser.parse_args()

    cfg = edict()
    cfg.DEBUG = args.debug
    cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(args.port)
    is_eval = args.eval
    pth = args.pth
    return cfg, is_eval, pth

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug' , action='store_true', default=False)
    parser.add_argument('--hrnet' , action='store_true', default=False)
    parser.add_argument('--eval'  , action='store_true', default=False)
    parser.add_argument('--pth'   , type=str)
    parser.add_argument('--gpu'   , nargs='+', type=int)
    parser.add_argument('--port'  , type=int, default=11233)
    parser.add_argument('--dsname', type=str, default='textseg')
    parser.add_argument('--trainwithcls', action='store_true', default=False)
    args = parser.parse_args()

    istrain = not args.eval

    if istrain:
        cfg = copy.deepcopy(cfg_train)
    else:
        cfg = copy.deepcopy(cfg_test)

    if istrain:
        cfg.EXPERIMENT_ID = get_experiment_id()
    else:
        cfg.EXPERIMENT_ID = None

    if args.dsname == "textseg":
        cfg_data = cfg_textseg
    elif args.dsname == "cocots":
        cfg_data = cfg_cocots
    elif args.dsname == "mlt":
        cfg_data = cfg_mlt
    elif args.dsname == "icdar13":
        cfg_data = cfg_icdar13
    elif args.dsname == "totaltext":
        cfg_data = cfg_totaltext
    else:
        raise ValueError

    cfg.DEBUG = args.debug
    cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(args.port)
    if args.gpu is None:
        cfg.GPU_DEVICE = 'all'
    else:
        cfg.GPU_DEVICE = args.gpu

    cfg.MODEL = copy.deepcopy(cfg_mdel)
    cfg.DATA = copy.deepcopy(cfg_data)

    if istrain:
        cfg = set_cfg_train(cfg, dsname=args.dsname)
        if args.hrnet:
            cfg = set_cfg_hrnetw48_train(cfg)
    else:
        cfg = set_cfg_eval(cfg, dsname=args.dsname)
        if args.hrnet:
            cfg = set_cfg_hrnetw48_eval(cfg)
        cfg.MODEL.TEXRNET.PRETRAINED_PTH = args.pth

    if istrain:
        if args.dsname == "textseg":
            cfg.DATA.DATASET_MODE = 'train+val'
        elif args.dsname == "cocots":
            cfg.DATA.DATASET_MODE = 'train'
        elif args.dsname == "mlt":
            cfg.DATA.DATASET_MODE = 'trainseg'
        elif args.dsname == "icdar13":
            cfg.DATA.DATASET_MODE = 'train_fst'
        elif args.dsname == "totaltext":
            cfg.DATA.DATASET_MODE = 'train'
        else:
            raise ValueError
    else:
        if args.dsname == "textseg":
            cfg.DATA.DATASET_MODE = 'test'
        elif args.dsname == "cocots":
            cfg.DATA.DATASET_MODE = 'val'
        elif args.dsname == "mlt":
            cfg.DATA.DATASET_MODE = 'valseg'
        elif args.dsname == "icdar13":
            cfg.DATA.DATASET_MODE = 'test_fst'
        elif args.dsname == "totaltext":
            cfg.DATA.DATASET_MODE = 'test'
        else:
            raise ValueError

    if istrain:
        if args.trainwithcls:
            if args.dsname == 'textseg':
                cfg.DATA.LOADER_PIPELINE = [
                    'NumpyImageLoader',
                    'TextSeg_SeglabelLoader',
                    'CharBboxSpLoader',]
                cfg.DATA.RANDOM_RESIZE_CROP_SIZE = [32, 32]
                cfg.DATA.RANDOM_RESIZE_CROP_SCALE = [0.8, 1.2]
                cfg.DATA.RANDOM_RESIZE_CROP_RATIO = [3/4, 4/3]
                cfg.DATA.TRANS_PIPELINE = [
                    'UniformNumpyType',
                    'TextSeg_RandomResizeCropCharBbox',
                    'NormalizeUint8ToZeroOne',
                    'Normalize',
                    'RandomScaleOneSide',
                    'RandomCrop',
                ]
            elif args.dsname == 'icdar13':
                cfg.DATA.LOADER_PIPELINE = [
                    'NumpyImageLoader',
                    'SeglabelLoader',
                    'CharBboxSpLoader',]
                cfg.DATA.TRANS_PIPELINE = [
                    'UniformNumpyType',
                    'NormalizeUint8ToZeroOne',
                    'Normalize',
                    'RandomScaleOneSide',
                    'RandomCrop',
                ]
            else:
                raise ValueError
            cfg.DATA.FORMATTER = 'SemChinsChbbxFormatter'
            cfg.DATA.LOADER_SQUARE_BBOX = True
            cfg.DATA.RANDOM_RESIZE_CROP_FROM = 'sem'
            cfg.MODEL.TEXRNET.INTRAIN_GETPRED_FROM = 'sem'
            # the classifier with 93.98% accuracy, trained on semantic crops
            cfg.TRAIN.CLASSIFIER_PATH = osp.join(
                'pretrained', 'init', 'resnet50_textcls.pth',
            )
            cfg.TRAIN.ROI_BBOX_PADDING_TYPE = 'semcrop'
            cfg.TRAIN.ROI_ALIGN_SIZE = [32, 32]
            cfg.TRAIN.UPDATE_CLASSIFIER = False
            cfg.TRAIN.ACTIVATE_CLASSIFIER_FOR_SEGMODEL_AFTER = 0
            cfg.TRAIN.LOSS_WEIGHT = {
                'losssem'   : 1,
                'lossrfn'   : 0.5,
                'lossrfntri': 0.5,
                'losscls'   : 0.1,
            }

    if istrain:
        if args.hrnet:
            cfg.TRAIN.SIGNATURE = ['texrnet', 'hrnet']
        else:
            cfg.TRAIN.SIGNATURE = ['texrnet', 'deeplab']
        cfg.LOG_DIR = experiment_folder(cfg, isnew=True, sig=cfg.TRAIN.SIGNATURE)
        cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'train.log')
    else:
        cfg.LOG_DIR = osp.join(cfg.MISC_DIR, 'eval')
        cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'eval.log')
        cfg.TEST.SUB_DIR = None

    if cfg.DEBUG:
        cfg.EXPERIMENT_ID = 999999999999
        cfg.DATA.NUM_WORKERS_PER_GPU = 0
        cfg.TRAIN.BATCH_SIZE_PER_GPU = 2

    cfg = common_initiates(cfg)

    if istrain:
        if args.trainwithcls:
            exec_ts = ts_with_classifier()
        else:
            exec_ts = ts()
        trainer = train(cfg)
        trainer.register_stage(exec_ts)

        # trainer(0)
        mp.spawn(trainer,
                 args=(),
                 nprocs=cfg.GPU_COUNT,
                 join=True)
    else:
        exec_es = es()
        tester = eval(cfg)
        tester.register_stage(exec_es)

        # tester(0)
        mp.spawn(tester,
                 args=(),
                 nprocs=cfg.GPU_COUNT,
                 join=True)
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
torch==1.6
torchvision==0.7
matplotlib==3.3.2
opencv-python==4.5.1.48
easydict==1.9
--------------------------------------------------------------------------------
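For reference, here is a minimal sketch of how the model registry, the optimizer manager, and the lr scheduler above compose at training time. This is not code from the repo: it assumes the config singleton is already populated (the way `main.py` sets it up), it assumes the ResNet backbone registers a single `'resnet'` param group (so `opmgr.order == ['resnet', 'deeplab']`), and all numeric values are illustrative.

```python
# Hypothetical wiring sketch under the assumptions stated above.
from lib.model_zoo.get_model import get_model
from lib.optimizer import get_optimizer, adjust_lr, lr_scheduler

net = get_model()('deeplab')      # lazy-imports deeplab.py, builds from cfg.MODEL.DEEPLAB
opmgr = net.opmgr
opmgr.set_lrscale({'resnet': 0.1, 'deeplab': 1.0})  # e.g. 10x smaller lr on the backbone

optimizer = get_optimizer(net, 'sgd', opmgr=opmgr)  # one param group per opmgr group, lr starts at 0
sched = lr_scheduler([('poly', 0.01, 1e-4, 30000, 0.9)])  # poly decay 0.01 -> 1e-4 over 30k iters

for it in range(30000):
    adjust_lr(optimizer, sched(it), opmgr=opmgr)    # per-group lr = sched(it) * group scale
    # ... forward pass, loss, backward, optimizer.step() ...
```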