├── .figure
│   ├── image_anno.jpg
│   ├── image_only.jpg
│   └── network.jpg
├── .gitignore
├── README.md
├── configs
│   ├── __init__.py
│   ├── cfg_base.py
│   ├── cfg_dataset.py
│   └── cfg_model.py
├── eval_utils.py
├── hrnet_code
│   ├── .gitignore
│   ├── LICENSE
│   ├── README.md
│   ├── lib
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── default.py
│   │   │   └── models.py
│   │   ├── core
│   │   │   ├── criterion.py
│   │   │   └── function.py
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── ade20k.py
│   │   │   ├── base_dataset.py
│   │   │   ├── cityscapes.py
│   │   │   ├── cocostuff.py
│   │   │   ├── lip.py
│   │   │   └── pascal_ctx.py
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   ├── bn_helper.py
│   │   │   ├── seg_hrnet.py
│   │   │   ├── seg_hrnet_ocr.py
│   │   │   └── sync_bn
│   │   │       ├── LICENSE
│   │   │       ├── __init__.py
│   │   │       └── inplace_abn
│   │   │           ├── __init__.py
│   │   │           ├── bn.py
│   │   │           ├── functions.py
│   │   │           └── src
│   │   │               ├── common.h
│   │   │               ├── inplace_abn.cpp
│   │   │               ├── inplace_abn.h
│   │   │               ├── inplace_abn_cpu.cpp
│   │   │               └── inplace_abn_cuda.cu
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── distributed.py
│   │       ├── modelsummary.py
│   │       └── utils.py
│   ├── requirements.txt
│   └── tools
│       ├── _init_paths.py
│       ├── test.py
│       └── train.py
├── inference.py
├── lib
│   ├── __init__.py
│   ├── cfg_helper.py
│   ├── data_factory
│   │   ├── __init__.py
│   │   ├── ds_base.py
│   │   ├── ds_cocotext.py
│   │   ├── ds_formatter.py
│   │   ├── ds_icdar13.py
│   │   ├── ds_loader.py
│   │   ├── ds_mlt.py
│   │   ├── ds_sampler.py
│   │   ├── ds_textseg.py
│   │   ├── ds_textssc.py
│   │   ├── ds_totaltext.py
│   │   └── ds_transform.py
│   ├── evaluate_service.py
│   ├── log_service.py
│   ├── loss.py
│   ├── model_zoo
│   │   ├── __init__.py
│   │   ├── deeplab.py
│   │   ├── get_model.py
│   │   ├── hrnet.py
│   │   ├── optim_manager.py
│   │   ├── resnet.py
│   │   ├── texrnet.py
│   │   └── utils.py
│   ├── nputils.py
│   ├── optimizer
│   │   ├── __init__.py
│   │   └── get_optimizer.py
│   └── torchutils.py
├── main.py
├── requirement.txt
└── train_utils.py
/.figure/image_anno.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/image_anno.jpg
--------------------------------------------------------------------------------
/.figure/image_only.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/image_only.jpg
--------------------------------------------------------------------------------
/.figure/network.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/.figure/network.jpg
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | /data
2 | /models
3 | /log
4 | /scripts
5 | /pretrained
6 | **/__pycache__
7 | **/.idea/*
8 | **/.vscode/*
9 | **/.ipynb_checkpoints/*
10 | **/old/*
11 | **/veryold/*
12 | *.ipynb
13 | **/build
14 | **/*.out
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach
2 |
3 | This is the repo to host the dataset TextSeg and code for TexRNet from the following paper:
4 |
5 | [Xingqian Xu](https://ifp-uiuc.github.io/), [Zhifei Zhang](http://web.eecs.utk.edu/~zzhang61/), [Zhaowen Wang](https://research.adobe.com/person/zhaowen-wang/), [Brian Price](https://research.adobe.com/person/brian-price/), [Zhonghao Wang](https://ifp-uiuc.github.io/) and [Humphrey Shi](https://www.humphreyshi.com), **Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach**, [ArXiv Link](https://arxiv.org/abs/2011.14021)
6 |
7 | **Note:**
8 |
11 | \[2021.04.21\] So far, our dataset is partially released with images and semantic labels. Since many people may request the dataset for OCR or non-segmentation tasks, please stay tuned, and we will release the dataset in full ASAP.
12 |
13 | \[2021.06.18\] **Our dataset is now fully released.** To download the data, please send a request email to *textseg.dataset@gmail.com* and tell us which school you are affiliated with. Please be aware that the released dataset is **version 2**, whose annotations differ slightly from those in the paper. In order to provide the most accurate dataset, we went through a second round of quality assurance, in which we fixed some faulty annotations and made them more consistent across the dataset. Since the TexRNet in the paper doesn't use OCR or character-instance labels (*i.e.*, word- and character-level bounding polygons and character-level masks), we will not release the older version of these labels. However, we do release the retroactive ```semantic_label_v1.tar.gz``` so researchers can reproduce the results in the paper. For more details about the dataset, please see below.
14 |
15 | ## Introduction
16 | Text in the real world is extremely diverse, yet current text datasets do not reflect this diversity very well. To bridge this gap, we propose TextSeg, a large-scale, finely annotated, multi-purpose text dataset that collects scene and design text with six types of annotations: word- and character-wise bounding polygons, masks, and transcriptions. We also introduce Text Refinement Network (TexRNet), a novel text segmentation approach that adapts to the unique properties of text, e.g. non-convex boundaries and diverse textures, which often impose burdens on traditional segmentation models. TexRNet refines the results of common segmentation approaches via key-feature pooling and attention, so that wrongly activated text regions can be corrected. We also introduce trimap and discriminator losses that yield significant improvements in text segmentation.
17 |
18 | ## TextSeg Dataset
19 |
20 | ### Image Collection
21 |
22 | ![Image Collection](.figure/image_only.jpg)
23 |
25 |
26 | ### Annotation
27 |
28 | ![Annotation](.figure/image_anno.jpg)
29 |
31 |
32 | ### Download
33 |
34 | Our dataset (TextSeg) is academia-only and cannot be used for any commercial project or research. To download the data, please send a request email to *textseg.dataset@gmail.com* and tell us which school you are affiliated with.
35 |
36 | A full download should contain these files:
37 |
38 | * ```image.tar.gz``` contains 4024 images.
39 | * ```annotation.tar.gz``` contains labels corresponding to the images. These three types of files are included:
40 |     * ```[dataID]_anno.json``` contains all word- and character-level transcriptions and bounding polygons.
41 |     * ```[dataID]_mask.png``` contains all character masks. Character mask label values are ordered from 1 to n. Label value 0 means background, 255 means ignore.
42 |     * ```[dataID]_maskeff.png``` contains all character masks **with effect**.
43 | * ```Adobe_Research_License_TextSeg.txt``` is the license file.
44 | * ```semantic_label.tar.gz``` contains all word-level (semantic-level) masks. It contains:
45 |     * ```[dataID]_maskfg.png```: 0 means background, 100 means word, 200 means word-effect, 255 means ignore. (It can also be generated from ```[dataID]_mask.png``` and ```[dataID]_maskeff.png```; see the sketch after this list.)
46 | * ```split.json``` is the official split of train, val and test.
47 | * [Optional] ```semantic_label_v1.tar.gz``` is the old version of the labels used in our paper. One can download it to reproduce our paper results.
48 |
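For convenience, here is a minimal sketch of how ```[dataID]_maskfg.png``` could be regenerated from the two character-level masks. It assumes only the value conventions listed above (0 = background, 1..n = characters, 255 = ignore); ```make_maskfg``` is a hypothetical helper name, not part of the released code:

```
import numpy as np
from PIL import Image

def make_maskfg(mask_path, maskeff_path, out_path):
    # Hypothetical sketch: compose [dataID]_maskfg.png from the two masks.
    mask = np.array(Image.open(mask_path))        # 0=bg, 1..n=characters, 255=ignore
    maskeff = np.array(Image.open(maskeff_path))  # same convention, with effects
    fg = np.zeros_like(mask, dtype=np.uint8)      # 0 = background
    fg[(maskeff > 0) & (maskeff < 255)] = 200     # 200 = word-effect
    fg[(mask > 0) & (mask < 255)] = 100           # 100 = word (characters override effect)
    fg[mask == 255] = 255                         # 255 = ignore
    Image.fromarray(fg).save(out_path)
```
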
49 | ## TexRNet Structure and Results
50 |
51 | ![TexRNet Structure](.figure/network.jpg)
52 |
54 |
55 | In this table, we report the performance of our TexRNet on five text segmentation datasets, including ours.
56 |
57 |
58 |
59 | | Method | TextSeg (Ours)<br>fgIoU / F-score | ICDAR13 FST<br>fgIoU / F-score | COCO_TS<br>fgIoU / F-score | MLT_S<br>fgIoU / F-score | Total-Text<br>fgIoU / F-score |
60 | |---|---|---|---|---|---|
61 | | DeeplabV3+ | 84.07 / 0.914 | 69.27 / 0.802 | 72.07 / 0.641 | 84.63 / 0.837 | 74.44 / 0.824 |
62 | | HRNetV2-W48 | 85.03 / 0.914 | 70.98 / 0.822 | 68.93 / 0.629 | 83.26 / 0.836 | 75.29 / 0.825 |
63 | | HRNetV2-W48 + OCR | 85.98 / 0.918 | 72.45 / 0.830 | 69.54 / 0.627 | 83.49 / 0.838 | 76.23 / 0.832 |
64 | | Ours: TexRNet + DeeplabV3+ | 86.06 / 0.921 | 72.16 / 0.835 | 73.98 / 0.722 | 86.31 / 0.830 | 76.53 / 0.844 |
65 | | Ours: TexRNet + HRNetV2-W48 | 86.84 / 0.924 | 73.38 / 0.850 | 72.39 / 0.720 | 86.09 / 0.865 | 78.47 / 0.848 |
117 |
118 | ## To run the code
119 |
120 | ### Set up the environment
121 | ```
122 | conda create -n texrnet python=3.7
123 | conda activate texrnet
124 | pip install -r requirement.txt
125 | ```
126 | ### To eval
127 |
128 | First, make the following directories to hold the pre-trained models, the dataset, and running logs:
129 | ```
130 | mkdir ./pretrained
131 | mkdir ./data
132 | mkdir ./log
133 | ```
134 |
135 | Second, download the models from [this link](https://drive.google.com/drive/folders/1EvGNvI5R6NKsW0YTy_0YHD9dpvtM0HDi?usp=sharing). Move those downloaded models to `./pretrained`.
136 |
137 | Third, make sure that `./data` contains the data. A sample root directory for **TextSeg** would be `./data/TextSeg`.
138 |
139 | Lastly, evaluate the model and compute fgIoU/F-score with the following command:
140 | ```
141 | python main.py --eval --pth [model path] [--hrnet] [--gpu 0 1 ...] --dsname [dataset name]
142 | ```
143 |
144 | Here is a sample command to evaluate TexRNet_HRNet on TextSeg with 4 GPUs:
145 | ```
146 | python main.py --eval --pth pretrained/texrnet_hrnet.pth --hrnet --gpu 0 1 2 3 --dsname textseg
147 | ```
148 |
149 | The program will store the results and the execution log in `./log/eval`.
150 |
151 | ### To train
152 |
153 | First, as before, these directories need to be created:
154 | ```
155 | mkdir ./pretrained
156 | mkdir ./pretrained/init
157 | mkdir ./data
158 | mkdir ./log
159 | ```
160 |
161 | Second, we use multiple pre-trained models for training. Download these initial models from [this link](https://drive.google.com/drive/folders/1EvGNvI5R6NKsW0YTy_0YHD9dpvtM0HDi?usp=sharing). Move those models to `./pretrained/init`. Also, make sure that `./data` contains the data.
162 |
163 | Lastly, execute the training code with the following command:
164 | ```
165 | python main.py [--hrnet] [--gpu 0 1 ...] --dsname [dataset name] [--trainwithcls]
166 | ```
167 |
168 | Here is a sample command to train TexRNet_HRNet on TextSeg with the classifier and discriminator loss using 4 GPUs:
169 | ```
170 | python main.py --hrnet --gpu 0 1 2 3 --dsname textseg --trainwithcls
171 | ```
172 |
173 | The training configs, logs, and models will be stored in `./log/texrnet_[dsname]/[exid]_[signature]`.
174 |
175 | ## Bibtex
176 | ```
177 | @article{xu2020rethinking,
178 | title={Rethinking Text Segmentation: A Novel Dataset and A Text-Specific Refinement Approach},
179 | author={Xu, Xingqian and Zhang, Zhifei and Wang, Zhaowen and Price, Brian and Wang, Zhonghao and Shi, Humphrey},
180 | journal={arXiv preprint arXiv:2011.14021},
181 | year={2020}
182 | }
183 | ```
184 |
185 | ## Acknowledgements
186 |
187 | The directory `./hrnet_code` is copied directly from the official HRNet GitHub repository [(link)](https://github.com/HRNet/HRNet-Semantic-Segmentation). Credit for the HRNet code belongs to the HRNet authors, and users should follow their terms of use.
188 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/configs/__init__.py
--------------------------------------------------------------------------------
/configs/cfg_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import copy
5 | import socket
6 |
7 | from easydict import EasyDict as edict
8 |
9 | cfg = edict()
10 |
11 | # -----------------------------BASE-----------------------------
12 |
13 | cfg.DEBUG = False
14 | cfg.EXPERIMENT_ID = 0
15 | cfg.GPU_DEVICE = 'all'
16 | cfg.CUDA = False
17 | cfg.MISC_DIR = osp.abspath(osp.join(osp.dirname(__file__), '..', 'log'))
18 | cfg.LOG_FILE = None
19 | cfg.RND_SEED = None
20 | cfg.RND_RECORDING = False
21 | # cfg.USE_FLOAT16 = False
22 | cfg.MATPLOTLIB_MODE = 'Agg'
23 | cfg.MAINLOOP_EXECUTE = True
24 | cfg.MAIN_CODE_PATH = None
25 | cfg.MAIN_CODE = []
26 | cfg.SAVE_CODE = True
27 | cfg.COMPUTER_NAME = socket.gethostname()
28 | cfg.TORCH_VERSION = 'unknown'
29 |
30 | cfg.DIST_URL = 'tcp://127.0.0.1:11233'
31 | cfg.DIST_BACKEND = 'nccl'
32 |
33 | cfg_train = copy.deepcopy(cfg)
34 | cfg_test = copy.deepcopy(cfg)
35 |
36 | # -----------------------------TRAIN-----------------------------
37 |
38 | cfg_train.TRAIN = edict()
39 | cfg_train.TRAIN.BATCH_SIZE = None
40 | cfg_train.TRAIN.BATCH_SIZE_PER_GPU = None
41 | cfg_train.TRAIN.MAX_STEP = 0
42 | cfg_train.TRAIN.MAX_STEP_TYPE = None
43 | cfg_train.TRAIN.SKIP_PARTIAL = True
44 | # cfg_train.TRAIN.LR_ADJUST_MODE = None
45 | cfg_train.TRAIN.LR_ITER_BY = None
46 | cfg_train.TRAIN.OPTIMIZER = None
47 | cfg_train.TRAIN.DISPLAY = 0
48 | cfg_train.TRAIN.VISUAL = None
49 | cfg_train.TRAIN.SAVE_INIT_MODEL = True
50 | cfg_train.TRAIN.SAVE_CODE = True
51 |
52 | # -----------------------------TEST-----------------------------
53 |
54 | cfg_test.TEST = edict()
55 | cfg_test.TEST.BATCH_SIZE = None
56 | cfg_test.TEST.BATCH_SIZE_PER_GPU = None
57 | cfg_test.TEST.VISUAL = None
58 |
59 | # -----------------------------COMBINED-----------------------------
60 |
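# Fold the TRAIN and TEST subtrees back into the base cfg so downstream code
# can reach cfg.TRAIN.* and cfg.TEST.* through a single config object.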
61 | cfg.update(cfg_train)
62 | cfg.update(cfg_test)
63 |
--------------------------------------------------------------------------------
/configs/cfg_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import copy
5 |
6 | from easydict import EasyDict as edict
7 |
8 | cfg = edict()
9 | cfg.DATASET_MODE = None
10 | cfg.LOADER_PIPELINE = []
11 | cfg.LOAD_BACKEND_IMAGE = 'pil'
12 | cfg.LOAD_IS_MC_IMAGE = False
13 | cfg.TRANS_PIPELINE = []
14 | cfg.NUM_WORKERS_PER_GPU = None
15 | cfg.NUM_WORKERS = None
16 | cfg.TRY_SAMPLE = None
17 |
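# Base dataset config: each dataset below deepcopies this template, then
# overrides the dataset-specific fields (paths, class counts, ignore labels).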
18 | ##############################
19 | ##### imagenet #####
20 | ##############################
21 |
22 | cfg_imagenet = copy.deepcopy(cfg)
23 | cfg_imagenet.DATASET_NAME = 'imagenet'
24 | cfg_imagenet.ROOT_DIR = osp.abspath(osp.join(
25 | osp.dirname(__file__), '..', 'data',
26 | 'ImageNet', 'ILSVRC2012'))
27 | cfg_imagenet.CLASS_INFO_JSON = osp.abspath(osp.join(
28 | osp.dirname(__file__), '..', 'data',
29 | 'ImageNet', 'addon', 'ILSVRC2012', '1000nids.json'))
30 | cfg_imagenet.IM_MEAN = [0.485, 0.456, 0.406]
31 | cfg_imagenet.IM_STD = [0.229, 0.224, 0.225]
32 | cfg_imagenet.CLASS_NUM = 1000
33 |
34 | #############################
35 | ##### textseg #####
36 | #############################
37 |
38 | cfg_textseg = copy.deepcopy(cfg)
39 | cfg_textseg.DATASET_NAME = 'textseg'
40 | cfg_textseg.ROOT_DIR = osp.abspath(osp.join(
41 | osp.dirname(__file__), '..', 'data', 'TextSeg'))
42 | cfg_textseg.CLASS_NUM = 2
43 | cfg_textseg.CLASS_NAME = [
44 | 'background',
45 | 'text']
46 | cfg_textseg.SEGLABEL_IGNORE_LABEL = 999
47 | cfg_textseg.SEMANTIC_PICK_CLASS = 'all'
48 | cfg_textseg.IM_MEAN = [0.485, 0.456, 0.406]
49 | cfg_textseg.IM_STD = [0.229, 0.224, 0.225]
50 | cfg_textseg.LOAD_IS_MC_SEGLABEL = True
51 |
52 | ##########################
53 | ##### cocotext #####
54 | ##########################
55 |
56 | cfg_cocotext = copy.deepcopy(cfg)
57 | cfg_cocotext.DATASET_NAME = 'cocotext'
58 | cfg_cocotext.ROOT_DIR = osp.abspath(osp.join(
59 | osp.dirname(__file__), '..', 'data', 'COCO'))
60 | cfg_cocotext.IM_MEAN = [0.485, 0.456, 0.406]
61 | cfg_cocotext.IM_STD = [0.229, 0.224, 0.225]
62 |
63 | ########################
64 | ##### cocots #####
65 | ########################
66 |
67 | cfg_cocots = copy.deepcopy(cfg)
68 | cfg_cocots.DATASET_NAME = 'cocots'
69 | cfg_cocots.ROOT_DIR = osp.abspath(osp.join(
70 | osp.dirname(__file__), '..', 'data', 'COCO'))
71 | cfg_cocots.IM_MEAN = [0.485, 0.456, 0.406]
72 | cfg_cocots.IM_STD = [0.229, 0.224, 0.225]
73 | cfg_cocots.CLASS_NUM = 2
74 | cfg_cocots.SEGLABEL_IGNORE_LABEL = 255
75 | cfg_cocots.LOAD_BACKEND_SEGLABEL = 'pil'
76 | cfg_cocots.LOAD_IS_MC_SEGLABEL = False
77 |
78 | #####################
79 | ##### mlt #####
80 | #####################
81 |
82 | cfg_mlt = copy.deepcopy(cfg)
83 | cfg_mlt.DATASET_NAME = 'mlt'
84 | cfg_mlt.ROOT_DIR = osp.abspath(osp.join(
85 | osp.dirname(__file__), '..', 'data', 'ICDAR17', 'challenge8'))
86 | cfg_mlt.IM_MEAN = [0.485, 0.456, 0.406]
87 | cfg_mlt.IM_STD = [0.229, 0.224, 0.225]
88 | cfg_mlt.CLASS_NUM = 2
89 | cfg_mlt.SEGLABEL_IGNORE_LABEL = 255
90 |
91 | #######################
92 | ##### icdar13 #####
93 | #######################
94 |
95 | cfg_icdar13 = copy.deepcopy(cfg)
96 | cfg_icdar13.DATASET_NAME = 'icdar13'
97 | cfg_icdar13.ROOT_DIR = osp.abspath(osp.join(
98 | osp.dirname(__file__), '..', 'data', 'ICDAR13'))
99 | cfg_icdar13.CLASS_NUM = 2
100 | cfg_icdar13.CLASS_NAME = [
101 | 'background',
102 | 'text']
103 | cfg_icdar13.SEGLABEL_IGNORE_LABEL = 999
104 | cfg_icdar13.SEMANTIC_PICK_CLASS = 'all'
105 | cfg_icdar13.IM_MEAN = [0.485, 0.456, 0.406]
106 | cfg_icdar13.IM_STD = [0.229, 0.224, 0.225]
107 | cfg_icdar13.LOAD_BACKEND_SEGLABEL = 'pil'
108 | cfg_icdar13.LOAD_IS_MC_SEGLABEL = False
109 | cfg_icdar13.FROM_SOURCE = 'addon'
110 | cfg_icdar13.USE_CACHE = False
111 |
112 | #########################
113 | ##### totaltext #####
114 | #########################
115 |
116 | cfg_totaltext = copy.deepcopy(cfg)
117 | cfg_totaltext.DATASET_NAME = 'totaltext'
118 | cfg_totaltext.ROOT_DIR = osp.abspath(osp.join(
119 | osp.dirname(__file__), '..', 'data', 'TotalText'))
120 | cfg_totaltext.CLASS_NUM = 2
121 | cfg_totaltext.CLASS_NAME = [
122 | 'background',
123 | 'text']
124 | cfg_totaltext.SEGLABEL_IGNORE_LABEL = 999
125 | # dummy, totaltext pixel level anno has no ignore label
126 | cfg_totaltext.IM_MEAN = [0.485, 0.456, 0.406]
127 | cfg_totaltext.IM_STD = [0.229, 0.224, 0.225]
128 |
129 | #######################
130 | ##### textssc #####
131 | #######################
132 | # text semantic segmentation composed
133 |
134 | cfg_textssc = copy.deepcopy(cfg)
135 | cfg_textssc.DATASET_NAME = 'textssc'
136 | cfg_textssc.ROOT_DIR = osp.abspath(osp.join(
137 | osp.dirname(__file__), '..', 'data', 'TextSSC'))
138 | cfg_textssc.CLASS_NUM = 2
139 | cfg_textssc.CLASS_NAME = [
140 | 'background',
141 | 'text']
142 | cfg_textssc.SEGLABEL_IGNORE_LABEL = 999
143 | cfg_textssc.IM_MEAN = [0.485, 0.456, 0.406]
144 | cfg_textssc.IM_STD = [0.229, 0.224, 0.225]
145 | cfg_textssc.LOAD_BACKEND_SEGLABEL = 'pil'
146 | cfg_textssc.LOAD_IS_MC_SEGLABEL = False
147 |
--------------------------------------------------------------------------------
/configs/cfg_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import copy
5 |
6 | from easydict import EasyDict as edict
7 |
8 | cfg = edict()
9 | cfg.MODEL_NAME = None
10 | # cfg.CONV_TYPE = 'conv'
11 | # cfg.BN_TYPE = 'bn'
12 | # cfg.RELU_TYPE = 'relu'
13 |
14 | # resnet
15 | cfg_resnet = copy.deepcopy(cfg)
16 | cfg_resnet.MODEL_NAME = 'resnet'
17 | cfg_resnet.RESNET = edict()
18 | cfg_resnet.RESNET.MODEL_TAGS = None
19 | cfg_resnet.RESNET.PRETRAINED_PTH = None
20 | cfg_resnet.RESNET.BN_TYPE = 'bn'
21 | cfg_resnet.RESNET.RELU_TYPE = 'relu'
22 |
23 | # deeplab
24 | cfg_deeplab = copy.deepcopy(cfg)
25 | cfg_deeplab.MODEL_NAME = 'deeplab'
26 | cfg_deeplab.DEEPLAB = edict()
27 | cfg_deeplab.DEEPLAB.MODEL_TAGS = None
28 | cfg_deeplab.DEEPLAB.PRETRAINED_PTH = None
29 | cfg_deeplab.DEEPLAB.FREEZE_BACKBONE_BN = False
30 | cfg_deeplab.DEEPLAB.BN_TYPE = 'bn'
31 | cfg_deeplab.DEEPLAB.RELU_TYPE = 'relu'
32 | # cfg_deeplab.DEEPLAB.ASPP_DROPOUT_TYPE = 'dropout|0.5'
33 | cfg_deeplab.DEEPLAB.ASPP_WITH_GAP = True
34 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT2_TYPE = 'dropout|0.5'
35 | # cfg_deeplab.DEEPLAB.DECODER_DROPOUT3_TYPE = 'dropout|0.1'
36 | cfg_deeplab.RESNET = cfg_resnet.RESNET
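# Note: this assignment shares (not copies) the RESNET subtree, so backbone
# settings changed through cfg_deeplab.RESNET also show up on cfg_resnet.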
37 |
38 | # hrnet
39 | cfg_hrnet = copy.deepcopy(cfg)
40 | cfg_hrnet.MODEL_NAME = 'hrnet'
41 | cfg_hrnet.HRNET = edict()
42 | cfg_hrnet.HRNET.MODEL_TAGS = None
43 | cfg_hrnet.HRNET.PRETRAINED_PTH = None
44 | cfg_hrnet.HRNET.BN_TYPE = 'bn'
45 | cfg_hrnet.HRNET.RELU_TYPE = 'relu'
46 |
47 | # texrnet
48 | cfg_texrnet = copy.deepcopy(cfg)
49 | cfg_texrnet.MODEL_NAME = 'texrnet'
50 | cfg_texrnet.TEXRNET = edict()
51 | cfg_texrnet.TEXRNET.MODEL_TAGS = None
52 | cfg_texrnet.TEXRNET.PRETRAINED_PTH = None
53 | cfg_texrnet.RESNET = cfg_resnet.RESNET
54 | cfg_texrnet.DEEPLAB = cfg_deeplab.DEEPLAB
55 |
--------------------------------------------------------------------------------
/hrnet_code/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode
2 | __pycache__/
3 | *.py[co]
4 | data/
5 | log/
6 | output/
7 | pretrained_models
8 | scripts/
9 | detail-api/
10 | data/list
--------------------------------------------------------------------------------
/hrnet_code/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) [2019] [Microsoft]
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 | =======================================================================================
24 | 3-clause BSD licenses
25 | =======================================================================================
26 | 1. syncbn - For details, see lib/models/syncbn/LICENSE
27 | Copyright (c) 2017 mapillary
28 |
--------------------------------------------------------------------------------
/hrnet_code/lib/config/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 | from __future__ import absolute_import
7 | from __future__ import division
8 | from __future__ import print_function
9 |
10 | from .default import _C as config
11 | from .default import update_config
12 | from .models import MODEL_EXTRAS
13 |
--------------------------------------------------------------------------------
/hrnet_code/lib/config/default.py:
--------------------------------------------------------------------------------
1 |
2 | # ------------------------------------------------------------------------------
3 | # Copyright (c) Microsoft
4 | # Licensed under the MIT License.
5 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 |
14 | from yacs.config import CfgNode as CN
15 |
16 |
17 | _C = CN()
18 |
19 | _C.OUTPUT_DIR = ''
20 | _C.LOG_DIR = ''
21 | _C.GPUS = (0,)
22 | _C.WORKERS = 4
23 | _C.PRINT_FREQ = 20
24 | _C.AUTO_RESUME = False
25 | _C.PIN_MEMORY = True
26 | _C.RANK = 0
27 |
28 | # Cudnn related params
29 | _C.CUDNN = CN()
30 | _C.CUDNN.BENCHMARK = True
31 | _C.CUDNN.DETERMINISTIC = False
32 | _C.CUDNN.ENABLED = True
33 |
34 | # common params for NETWORK
35 | _C.MODEL = CN()
36 | _C.MODEL.NAME = 'seg_hrnet'
37 | _C.MODEL.PRETRAINED = ''
38 | _C.MODEL.ALIGN_CORNERS = True
39 | _C.MODEL.NUM_OUTPUTS = 1
40 | _C.MODEL.EXTRA = CN(new_allowed=True)
41 |
42 |
43 | _C.MODEL.OCR = CN()
44 | _C.MODEL.OCR.MID_CHANNELS = 512
45 | _C.MODEL.OCR.KEY_CHANNELS = 256
46 | _C.MODEL.OCR.DROPOUT = 0.05
47 | _C.MODEL.OCR.SCALE = 1
48 |
49 | _C.LOSS = CN()
50 | _C.LOSS.USE_OHEM = False
51 | _C.LOSS.OHEMTHRES = 0.9
52 | _C.LOSS.OHEMKEEP = 100000
53 | _C.LOSS.CLASS_BALANCE = False
54 | _C.LOSS.BALANCE_WEIGHTS = [1]
55 |
56 | # DATASET related params
57 | _C.DATASET = CN()
58 | _C.DATASET.ROOT = ''
59 | _C.DATASET.DATASET = 'cityscapes'
60 | _C.DATASET.NUM_CLASSES = 19
61 | _C.DATASET.TRAIN_SET = 'list/cityscapes/train.lst'
62 | _C.DATASET.EXTRA_TRAIN_SET = ''
63 | _C.DATASET.TEST_SET = 'list/cityscapes/val.lst'
64 |
65 | # training
66 | _C.TRAIN = CN()
67 |
68 | _C.TRAIN.FREEZE_LAYERS = ''
69 | _C.TRAIN.FREEZE_EPOCHS = -1
70 | _C.TRAIN.NONBACKBONE_KEYWORDS = []
71 | _C.TRAIN.NONBACKBONE_MULT = 10
72 |
73 | _C.TRAIN.IMAGE_SIZE = [1024, 512] # width * height
74 | _C.TRAIN.BASE_SIZE = 2048
75 | _C.TRAIN.DOWNSAMPLERATE = 1
76 | _C.TRAIN.FLIP = True
77 | _C.TRAIN.MULTI_SCALE = True
78 | _C.TRAIN.SCALE_FACTOR = 16
79 |
80 | _C.TRAIN.RANDOM_BRIGHTNESS = False
81 | _C.TRAIN.RANDOM_BRIGHTNESS_SHIFT_VALUE = 10
82 |
83 | _C.TRAIN.LR_FACTOR = 0.1
84 | _C.TRAIN.LR_STEP = [90, 110]
85 | _C.TRAIN.LR = 0.01
86 | _C.TRAIN.EXTRA_LR = 0.001
87 |
88 | _C.TRAIN.OPTIMIZER = 'sgd'
89 | _C.TRAIN.MOMENTUM = 0.9
90 | _C.TRAIN.WD = 0.0001
91 | _C.TRAIN.NESTEROV = False
92 | _C.TRAIN.IGNORE_LABEL = -1
93 |
94 | _C.TRAIN.BEGIN_EPOCH = 0
95 | _C.TRAIN.END_EPOCH = 484
96 | _C.TRAIN.EXTRA_EPOCH = 0
97 |
98 | _C.TRAIN.RESUME = False
99 |
100 | _C.TRAIN.BATCH_SIZE_PER_GPU = 32
101 | _C.TRAIN.SHUFFLE = True
102 | # only using some training samples
103 | _C.TRAIN.NUM_SAMPLES = 0
104 |
105 | # testing
106 | _C.TEST = CN()
107 |
108 | _C.TEST.IMAGE_SIZE = [2048, 1024] # width * height
109 | _C.TEST.BASE_SIZE = 2048
110 |
111 | _C.TEST.BATCH_SIZE_PER_GPU = 32
112 | # only testing some samples
113 | _C.TEST.NUM_SAMPLES = 0
114 |
115 | _C.TEST.MODEL_FILE = ''
116 | _C.TEST.FLIP_TEST = False
117 | _C.TEST.MULTI_SCALE = False
118 | _C.TEST.SCALE_LIST = [1]
119 |
120 | _C.TEST.OUTPUT_INDEX = -1
121 |
122 | # debug
123 | _C.DEBUG = CN()
124 | _C.DEBUG.DEBUG = False
125 | _C.DEBUG.SAVE_BATCH_IMAGES_GT = False
126 | _C.DEBUG.SAVE_BATCH_IMAGES_PRED = False
127 | _C.DEBUG.SAVE_HEATMAPS_GT = False
128 | _C.DEBUG.SAVE_HEATMAPS_PRED = False
129 |
130 |
131 | def update_config(cfg, args):
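# Precedence: start from the defaults above, overlay the YAML file passed
# via --cfg, then apply any KEY VALUE overrides given on the command line.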
132 | cfg.defrost()
133 |
134 | cfg.merge_from_file(args.cfg)
135 | cfg.merge_from_list(args.opts)
136 |
137 | cfg.freeze()
138 |
139 |
140 | if __name__ == '__main__':
141 | import sys
142 | with open(sys.argv[1], 'w') as f:
143 | print(_C, file=f)
144 |
145 |
--------------------------------------------------------------------------------
/hrnet_code/lib/config/models.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from yacs.config import CfgNode as CN
12 |
13 | # high_resoluton_net related params for segmentation
14 | HIGH_RESOLUTION_NET = CN()
15 | HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*']
16 | HIGH_RESOLUTION_NET.STEM_INPLANES = 64
17 | HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1
18 | HIGH_RESOLUTION_NET.WITH_HEAD = True
19 |
20 | HIGH_RESOLUTION_NET.STAGE2 = CN()
21 | HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1
22 | HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2
23 | HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4]
24 | HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64]
25 | HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC'
26 | HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM'
27 |
28 | HIGH_RESOLUTION_NET.STAGE3 = CN()
29 | HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1
30 | HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3
31 | HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4]
32 | HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128]
33 | HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC'
34 | HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM'
35 |
36 | HIGH_RESOLUTION_NET.STAGE4 = CN()
37 | HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1
38 | HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4
39 | HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
40 | HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256]
41 | HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC'
42 | HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM'
43 |
44 | MODEL_EXTRAS = {
45 | 'seg_hrnet': HIGH_RESOLUTION_NET,
46 | }
47 |
--------------------------------------------------------------------------------
/hrnet_code/lib/core/criterion.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import functional as F
10 | import logging
11 | from config import config
12 |
13 |
14 | class CrossEntropy(nn.Module):
15 | def __init__(self, ignore_label=-1, weight=None):
16 | super(CrossEntropy, self).__init__()
17 | self.ignore_label = ignore_label
18 | self.criterion = nn.CrossEntropyLoss(
19 | weight=weight,
20 | ignore_index=ignore_label
21 | )
22 |
23 | def _forward(self, score, target):
24 | ph, pw = score.size(2), score.size(3)
25 | h, w = target.size(1), target.size(2)
26 | if ph != h or pw != w:
27 | score = F.interpolate(input=score, size=(
28 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
29 |
30 | loss = self.criterion(score, target)
31 |
32 | return loss
33 |
34 | def forward(self, score, target):
35 |
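# A model may return several outputs (e.g. a main and an auxiliary head);
# each output's loss is weighted by config.LOSS.BALANCE_WEIGHTS below.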
36 | if config.MODEL.NUM_OUTPUTS == 1:
37 | score = [score]
38 |
39 | weights = config.LOSS.BALANCE_WEIGHTS
40 | assert len(weights) == len(score)
41 |
42 | return sum([w * self._forward(x, target) for (w, x) in zip(weights, score)])
43 |
44 |
45 | class OhemCrossEntropy(nn.Module):
46 | def __init__(self, ignore_label=-1, thres=0.7,
47 | min_kept=100000, weight=None):
48 | super(OhemCrossEntropy, self).__init__()
49 | self.thresh = thres
50 | self.min_kept = max(1, min_kept)
51 | self.ignore_label = ignore_label
52 | self.criterion = nn.CrossEntropyLoss(
53 | weight=weight,
54 | ignore_index=ignore_label,
55 | reduction='none'
56 | )
57 |
58 | def _ce_forward(self, score, target):
59 | ph, pw = score.size(2), score.size(3)
60 | h, w = target.size(1), target.size(2)
61 | if ph != h or pw != w:
62 | score = F.interpolate(input=score, size=(
63 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
64 |
65 | loss = self.criterion(score, target)
66 |
67 | return loss
68 |
69 | def _ohem_forward(self, score, target, **kwargs):
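# Online hard example mining: compute the per-pixel CE loss, then keep only
# the hardest pixels -- those whose predicted probability on the ground-truth
# class falls below a threshold, relaxed so at least `min_kept` pixels survive.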
70 | ph, pw = score.size(2), score.size(3)
71 | h, w = target.size(1), target.size(2)
72 | if ph != h or pw != w:
73 | score = F.interpolate(input=score, size=(
74 | h, w), mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS)
75 | pred = F.softmax(score, dim=1)
76 | pixel_losses = self.criterion(score, target).contiguous().view(-1)
77 | mask = target.contiguous().view(-1) != self.ignore_label
78 |
79 | tmp_target = target.clone()
80 | tmp_target[tmp_target == self.ignore_label] = 0
81 | pred = pred.gather(1, tmp_target.unsqueeze(1))
82 | pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort()
83 | min_value = pred[min(self.min_kept, pred.numel() - 1)]
84 | threshold = max(min_value, self.thresh)
85 |
86 | pixel_losses = pixel_losses[mask][ind]
87 | pixel_losses = pixel_losses[pred < threshold]
88 | return pixel_losses.mean()
89 |
90 | def forward(self, score, target):
91 |
92 | if config.MODEL.NUM_OUTPUTS == 1:
93 | score = [score]
94 |
95 | weights = config.LOSS.BALANCE_WEIGHTS
96 | assert len(weights) == len(score)
97 |
98 | functions = [self._ce_forward] * \
99 | (len(weights) - 1) + [self._ohem_forward]
100 | return sum([
101 | w * func(x, target)
102 | for (w, x, func) in zip(weights, score, functions)
103 | ])
104 |
--------------------------------------------------------------------------------
/hrnet_code/lib/core/function.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import logging
8 | import os
9 | import time
10 |
11 | import numpy as np
12 | import numpy.ma as ma
13 | from tqdm import tqdm
14 |
15 | import torch
16 | import torch.nn as nn
17 | from torch.nn import functional as F
18 |
19 | from utils.utils import AverageMeter
20 | from utils.utils import get_confusion_matrix
21 | from utils.utils import adjust_learning_rate
22 |
23 | import utils.distributed as dist
24 |
25 |
26 | def reduce_tensor(inp):
27 | """
28 | Reduce the loss from all processes so that
29 | process with rank 0 has the averaged results.
30 | """
31 | world_size = dist.get_world_size()
32 | if world_size < 2:
33 | return inp
34 | with torch.no_grad():
35 | reduced_inp = inp
36 | torch.distributed.reduce(reduced_inp, dst=0)
37 | return reduced_inp / world_size
38 |
39 |
40 | def train(config, epoch, num_epoch, epoch_iters, base_lr,
41 | num_iters, trainloader, optimizer, model, writer_dict):
42 | # Training
43 | model.train()
44 |
45 | batch_time = AverageMeter()
46 | ave_loss = AverageMeter()
47 | tic = time.time()
48 | cur_iters = epoch*epoch_iters
49 | writer = writer_dict['writer']
50 | global_steps = writer_dict['train_global_steps']
51 |
52 | for i_iter, batch in enumerate(trainloader, 0):
53 | images, labels, _, _ = batch
54 | images = images.cuda()
55 | labels = labels.long().cuda()
56 |
57 | losses, _ = model(images, labels)
58 | loss = losses.mean()
59 |
60 | if dist.is_distributed():
61 | reduced_loss = reduce_tensor(loss)
62 | else:
63 | reduced_loss = loss
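# The rank-averaged `reduced_loss` is used only for logging below;
# backpropagation still runs on this process's local loss.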
64 |
65 | model.zero_grad()
66 | loss.backward()
67 | optimizer.step()
68 |
69 | # measure elapsed time
70 | batch_time.update(time.time() - tic)
71 | tic = time.time()
72 |
73 | # update average loss
74 | ave_loss.update(reduced_loss.item())
75 |
76 | lr = adjust_learning_rate(optimizer,
77 | base_lr,
78 | num_iters,
79 | i_iter+cur_iters)
80 |
81 | if i_iter % config.PRINT_FREQ == 0 and dist.get_rank() == 0:
82 | msg = 'Epoch: [{}/{}] Iter:[{}/{}], Time: {:.2f}, ' \
83 | 'lr: {}, Loss: {:.6f}' .format(
84 | epoch, num_epoch, i_iter, epoch_iters,
85 | batch_time.average(), [x['lr'] for x in optimizer.param_groups], ave_loss.average())
86 | logging.info(msg)
87 |
88 | writer.add_scalar('train_loss', ave_loss.average(), global_steps)
89 | writer_dict['train_global_steps'] = global_steps + 1
90 |
91 | def validate(config, testloader, model, writer_dict):
92 | model.eval()
93 | ave_loss = AverageMeter()
94 | nums = config.MODEL.NUM_OUTPUTS
95 | confusion_matrix = np.zeros(
96 | (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES, nums))
97 | with torch.no_grad():
98 | for idx, batch in enumerate(testloader):
99 | image, label, _, _ = batch
100 | size = label.size()
101 | image = image.cuda()
102 | label = label.long().cuda()
103 |
104 | losses, pred = model(image, label)
105 | if not isinstance(pred, (list, tuple)):
106 | pred = [pred]
107 | for i, x in enumerate(pred):
108 | x = F.interpolate(
109 | input=x, size=size[-2:],
110 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
111 | )
112 |
113 | confusion_matrix[..., i] += get_confusion_matrix(
114 | label,
115 | x,
116 | size,
117 | config.DATASET.NUM_CLASSES,
118 | config.TRAIN.IGNORE_LABEL
119 | )
120 |
121 | if idx % 10 == 0:
122 | print(idx)
123 |
124 | loss = losses.mean()
125 | if dist.is_distributed():
126 | reduced_loss = reduce_tensor(loss)
127 | else:
128 | reduced_loss = loss
129 | ave_loss.update(reduced_loss.item())
130 |
131 | if dist.is_distributed():
132 | confusion_matrix = torch.from_numpy(confusion_matrix).cuda()
133 | reduced_confusion_matrix = reduce_tensor(confusion_matrix)
134 | confusion_matrix = reduced_confusion_matrix.cpu().numpy()
135 |
136 | for i in range(nums):
137 | pos = confusion_matrix[..., i].sum(1)
138 | res = confusion_matrix[..., i].sum(0)
139 | tp = np.diag(confusion_matrix[..., i])
140 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
141 | mean_IoU = IoU_array.mean()
142 | if dist.get_rank() <= 0:
143 | logging.info('{} {} {}'.format(i, IoU_array, mean_IoU))
144 |
145 | writer = writer_dict['writer']
146 | global_steps = writer_dict['valid_global_steps']
147 | writer.add_scalar('valid_loss', ave_loss.average(), global_steps)
148 | writer.add_scalar('valid_mIoU', mean_IoU, global_steps)
149 | writer_dict['valid_global_steps'] = global_steps + 1
150 | return ave_loss.average(), mean_IoU, IoU_array
151 |
152 |
153 | def testval(config, test_dataset, testloader, model,
154 | sv_dir='', sv_pred=False):
155 | model.eval()
156 | confusion_matrix = np.zeros(
157 | (config.DATASET.NUM_CLASSES, config.DATASET.NUM_CLASSES))
158 | with torch.no_grad():
159 | for index, batch in enumerate(tqdm(testloader)):
160 | image, label, _, name, *border_padding = batch
161 | size = label.size()
162 | pred = test_dataset.multi_scale_inference(
163 | config,
164 | model,
165 | image,
166 | scales=config.TEST.SCALE_LIST,
167 | flip=config.TEST.FLIP_TEST)
168 |
169 | if len(border_padding) > 0:
170 | border_padding = border_padding[0]
171 | pred = pred[:, :, 0:pred.size(2) - border_padding[0], 0:pred.size(3) - border_padding[1]]
172 |
173 | if pred.size()[-2] != size[-2] or pred.size()[-1] != size[-1]:
174 | pred = F.interpolate(
175 | pred, size[-2:],
176 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
177 | )
178 |
179 | confusion_matrix += get_confusion_matrix(
180 | label,
181 | pred,
182 | size,
183 | config.DATASET.NUM_CLASSES,
184 | config.TRAIN.IGNORE_LABEL)
185 |
186 | if sv_pred:
187 | sv_path = os.path.join(sv_dir, 'test_results')
188 | if not os.path.exists(sv_path):
189 | os.mkdir(sv_path)
190 | test_dataset.save_pred(pred, sv_path, name)
191 |
192 | if index % 100 == 0:
193 | logging.info('processing: %d images' % index)
194 | pos = confusion_matrix.sum(1)
195 | res = confusion_matrix.sum(0)
196 | tp = np.diag(confusion_matrix)
197 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
198 | mean_IoU = IoU_array.mean()
199 | logging.info('mIoU: %.4f' % (mean_IoU))
200 |
201 | pos = confusion_matrix.sum(1)
202 | res = confusion_matrix.sum(0)
203 | tp = np.diag(confusion_matrix)
204 | pixel_acc = tp.sum()/pos.sum()
205 | mean_acc = (tp/np.maximum(1.0, pos)).mean()
206 | IoU_array = (tp / np.maximum(1.0, pos + res - tp))
207 | mean_IoU = IoU_array.mean()
208 |
209 | return mean_IoU, IoU_array, pixel_acc, mean_acc
210 |
211 |
212 | def test(config, test_dataset, testloader, model,
213 | sv_dir='', sv_pred=True):
214 | model.eval()
215 | with torch.no_grad():
216 | for _, batch in enumerate(tqdm(testloader)):
217 | image, size, name = batch
218 | size = size[0]
219 | pred = test_dataset.multi_scale_inference(
220 | config,
221 | model,
222 | image,
223 | scales=config.TEST.SCALE_LIST,
224 | flip=config.TEST.FLIP_TEST)
225 |
226 | if pred.size()[-2] != size[0] or pred.size()[-1] != size[1]:
227 | pred = F.interpolate(
228 | pred, size[-2:],
229 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
230 | )
231 |
232 | if sv_pred:
233 | sv_path = os.path.join(sv_dir, 'test_results')
234 | if not os.path.exists(sv_path):
235 | os.mkdir(sv_path)
236 | test_dataset.save_pred(pred, sv_path, name)
237 |
--------------------------------------------------------------------------------
/hrnet_code/lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | from .cityscapes import Cityscapes as cityscapes
12 | from .lip import LIP as lip
13 | from .pascal_ctx import PASCALContext as pascal_ctx
14 | from .ade20k import ADE20K as ade20k
15 | from .cocostuff import COCOStuff as cocostuff
--------------------------------------------------------------------------------
/hrnet_code/lib/datasets/ade20k.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 | from PIL import Image
15 |
16 | from .base_dataset import BaseDataset
17 |
18 |
19 | class ADE20K(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=150,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(520, 520),
30 | downsample_rate=1,
31 | scale_factor=11,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225]):
34 |
35 | super(ADE20K, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = root
39 | self.num_classes = num_classes
40 | self.list_path = list_path
41 | self.class_weights = None
42 |
43 | self.multi_scale = multi_scale
44 | self.flip = flip
45 | self.img_list = [line.strip().split() for line in open(root+list_path)]
46 |
47 | self.files = self.read_files()
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 |
51 | def read_files(self):
52 | files = []
53 | for item in self.img_list:
54 | image_path, label_path = item
55 | name = os.path.splitext(os.path.basename(label_path))[0]
56 | sample = {
57 | 'img': image_path,
58 | 'label': label_path,
59 | 'name': name
60 | }
61 | files.append(sample)
62 | return files
63 |
64 | def resize_image(self, image, label, size):
65 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
66 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
67 | return image, label
68 |
69 | def __getitem__(self, index):
70 | item = self.files[index]
71 | name = item["name"]
72 | image_path = os.path.join(self.root, 'ade20k', item['img'])
73 | label_path = os.path.join(self.root, 'ade20k', item['label'])
74 | image = cv2.imread(
75 | image_path,
76 | cv2.IMREAD_COLOR
77 | )
78 | label = np.array(
79 | Image.open(label_path).convert('P')
80 | )
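# ADE20K reserves label 0 for void: reduce_zero_label shifts ids down by one,
# so classes become 0..149 and void pixels land on the ignore label (-1).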
81 | label = self.reduce_zero_label(label)
82 | size = label.shape
83 |
84 | if 'testval' in self.list_path:
85 | image = self.resize_short_length(
86 | image,
87 | short_length=self.base_size,
88 | fit_stride=8
89 | )
90 | image = self.input_transform(image)
91 | image = image.transpose((2, 0, 1))
92 |
93 | return image.copy(), label.copy(), np.array(size), name
94 |
95 | if 'val' in self.list_path:
96 | image, label = self.resize_short_length(
97 | image,
98 | label=label,
99 | short_length=self.base_size,
100 | fit_stride=8
101 | )
102 | image, label = self.rand_crop(image, label)
103 | image = self.input_transform(image)
104 | image = image.transpose((2, 0, 1))
105 |
106 | return image.copy(), label.copy(), np.array(size), name
107 |
108 | image, label = self.resize_short_length(image, label, short_length=self.base_size)
109 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip)
110 |
111 | return image.copy(), label.copy(), np.array(size), name
--------------------------------------------------------------------------------
/hrnet_code/lib/datasets/cityscapes.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 | from PIL import Image
12 |
13 | import torch
14 | from torch.nn import functional as F
15 |
16 | from .base_dataset import BaseDataset
17 |
18 | class Cityscapes(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=19,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=2048,
28 | crop_size=(512, 1024),
29 | downsample_rate=1,
30 | scale_factor=16,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(Cityscapes, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std,)
36 |
37 | self.root = root
38 | self.list_path = list_path
39 | self.num_classes = num_classes
40 |
41 | self.multi_scale = multi_scale
42 | self.flip = flip
43 |
44 | self.img_list = [line.strip().split() for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 | if num_samples:
48 | self.files = self.files[:num_samples]
49 |
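# Map the 34 raw Cityscapes label ids onto the 19 classes used for training
# and evaluation; all other ids are sent to the ignore label.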
50 | self.label_mapping = {-1: ignore_label, 0: ignore_label,
51 | 1: ignore_label, 2: ignore_label,
52 | 3: ignore_label, 4: ignore_label,
53 | 5: ignore_label, 6: ignore_label,
54 | 7: 0, 8: 1, 9: ignore_label,
55 | 10: ignore_label, 11: 2, 12: 3,
56 | 13: 4, 14: ignore_label, 15: ignore_label,
57 | 16: ignore_label, 17: 5, 18: ignore_label,
58 | 19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
59 | 25: 12, 26: 13, 27: 14, 28: 15,
60 | 29: ignore_label, 30: ignore_label,
61 | 31: 16, 32: 17, 33: 18}
62 | self.class_weights = torch.FloatTensor([0.8373, 0.918, 0.866, 1.0345,
63 | 1.0166, 0.9969, 0.9754, 1.0489,
64 | 0.8786, 1.0023, 0.9539, 0.9843,
65 | 1.1116, 0.9037, 1.0865, 1.0955,
66 | 1.0865, 1.1529, 1.0507]).cuda()
67 |
68 | def read_files(self):
69 | files = []
70 | if 'test' in self.list_path:
71 | for item in self.img_list:
72 | image_path = item
73 | name = os.path.splitext(os.path.basename(image_path[0]))[0]
74 | files.append({
75 | "img": image_path[0],
76 | "name": name,
77 | })
78 | else:
79 | for item in self.img_list:
80 | image_path, label_path = item
81 | name = os.path.splitext(os.path.basename(label_path))[0]
82 | files.append({
83 | "img": image_path,
84 | "label": label_path,
85 | "name": name,
86 | "weight": 1
87 | })
88 | return files
89 |
90 | def convert_label(self, label, inverse=False):
91 | temp = label.copy()
92 | if inverse:
93 | for v, k in self.label_mapping.items():
94 | label[temp == k] = v
95 | else:
96 | for k, v in self.label_mapping.items():
97 | label[temp == k] = v
98 | return label
99 |
100 | def __getitem__(self, index):
101 | item = self.files[index]
102 | name = item["name"]
103 | image = cv2.imread(os.path.join(self.root,'cityscapes',item["img"]),
104 | cv2.IMREAD_COLOR)
105 | size = image.shape
106 |
107 | if 'test' in self.list_path:
108 | image = self.input_transform(image)
109 | image = image.transpose((2, 0, 1))
110 |
111 | return image.copy(), np.array(size), name
112 |
113 | label = cv2.imread(os.path.join(self.root,'cityscapes',item["label"]),
114 | cv2.IMREAD_GRAYSCALE)
115 | label = self.convert_label(label)
116 |
117 | image, label = self.gen_sample(image, label,
118 | self.multi_scale, self.flip)
119 |
120 | return image.copy(), label.copy(), np.array(size), name
121 |
122 | def multi_scale_inference(self, config, model, image, scales=[1], flip=False):
123 | batch, _, ori_height, ori_width = image.size()
124 | assert batch == 1, "only supporting batchsize 1."
125 | image = image.numpy()[0].transpose((1,2,0)).copy()
126 | stride_h = int(self.crop_size[0] * 1.0)
127 | stride_w = int(self.crop_size[1] * 1.0)
128 | final_pred = torch.zeros([1, self.num_classes,
129 | ori_height,ori_width]).cuda()
130 | for scale in scales:
131 | new_img = self.multi_scale_aug(image=image,
132 | rand_scale=scale,
133 | rand_crop=False)
134 | height, width = new_img.shape[:-1]
135 |
136 | if scale <= 1.0:
137 | new_img = new_img.transpose((2, 0, 1))
138 | new_img = np.expand_dims(new_img, axis=0)
139 | new_img = torch.from_numpy(new_img)
140 | preds = self.inference(config, model, new_img, flip)
141 | preds = preds[:, :, 0:height, 0:width]
142 | else:
143 | new_h, new_w = new_img.shape[:-1]
144 | rows = int(np.ceil(1.0 * (new_h -
145 | self.crop_size[0]) / stride_h)) + 1
146 | cols = int(np.ceil(1.0 * (new_w -
147 | self.crop_size[1]) / stride_w)) + 1
148 | preds = torch.zeros([1, self.num_classes,
149 | new_h,new_w]).cuda()
150 | count = torch.zeros([1,1, new_h, new_w]).cuda()
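# Sliding-window inference for scales > 1: logits from overlapping crops are
# accumulated in `preds`, then normalized by `count`, the number of windows
# covering each pixel.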
151 |
152 | for r in range(rows):
153 | for c in range(cols):
154 | h0 = r * stride_h
155 | w0 = c * stride_w
156 | h1 = min(h0 + self.crop_size[0], new_h)
157 | w1 = min(w0 + self.crop_size[1], new_w)
158 | h0 = max(int(h1 - self.crop_size[0]), 0)
159 | w0 = max(int(w1 - self.crop_size[1]), 0)
160 | crop_img = new_img[h0:h1, w0:w1, :]
161 | crop_img = crop_img.transpose((2, 0, 1))
162 | crop_img = np.expand_dims(crop_img, axis=0)
163 | crop_img = torch.from_numpy(crop_img)
164 | pred = self.inference(config, model, crop_img, flip)
165 | preds[:,:,h0:h1,w0:w1] += pred[:,:, 0:h1-h0, 0:w1-w0]
166 | count[:,:,h0:h1,w0:w1] += 1
167 | preds = preds / count
168 | preds = preds[:,:,:height,:width]
169 |
170 | preds = F.interpolate(
171 | preds, (ori_height, ori_width),
172 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
173 | )
174 | final_pred += preds
175 | return final_pred
176 |
177 | def get_palette(self, n):
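# Build the standard VOC-style palette: the bits of each label index are
# spread across the R, G and B channels from the most-significant bit down,
# giving nearby ids visually distinct colors.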
178 | palette = [0] * (n * 3)
179 | for j in range(0, n):
180 | lab = j
181 | palette[j * 3 + 0] = 0
182 | palette[j * 3 + 1] = 0
183 | palette[j * 3 + 2] = 0
184 | i = 0
185 | while lab:
186 | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
187 | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
188 | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
189 | i += 1
190 | lab >>= 3
191 | return palette
192 |
193 | def save_pred(self, preds, sv_path, name):
194 | palette = self.get_palette(256)
195 | preds = np.asarray(np.argmax(preds.cpu(), axis=1), dtype=np.uint8)
196 | for i in range(preds.shape[0]):
197 | pred = self.convert_label(preds[i], inverse=True)
198 | save_img = Image.fromarray(pred)
199 | save_img.putpalette(palette)
200 | save_img.save(os.path.join(sv_path, name[i]+'.png'))
201 |
202 |
203 |
204 |
--------------------------------------------------------------------------------
/hrnet_code/lib/datasets/cocostuff.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 | from PIL import Image
15 |
16 | from .base_dataset import BaseDataset
17 |
18 |
19 | class COCOStuff(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=171,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(520, 520),
30 | downsample_rate=1,
31 | scale_factor=11,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225]):
34 |
35 | super(COCOStuff, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = root
39 | self.num_classes = num_classes
40 | self.list_path = list_path
41 | self.class_weights = None
42 |
43 | self.multi_scale = multi_scale
44 | self.flip = flip
45 | self.img_list = [line.strip().split() for line in open(root+list_path)]
46 |
47 | self.files = self.read_files()
48 | if num_samples:
49 | self.files = self.files[:num_samples]
50 | self.mapping = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20,
51 | 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39,
52 | 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58,
53 | 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77,
54 | 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90, 92, 93, 94, 95, 96,
55 | 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112,
56 | 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128,
57 | 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144,
58 | 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160,
59 | 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176,
60 | 177, 178, 179, 180, 181, 182]
61 |
62 | def read_files(self):
63 | files = []
64 | for item in self.img_list:
65 | image_path, label_path = item
66 | name = os.path.splitext(os.path.basename(label_path))[0]
67 | sample = {
68 | 'img': image_path,
69 | 'label': label_path,
70 | 'name': name
71 | }
72 | files.append(sample)
73 | return files
74 |
75 | def encode_label(self, labelmap):
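# Densify the sparse raw COCO-Stuff ids: each id listed in `self.mapping` is
# mapped to its list index; ids not in the mapping fall through to 255 (ignore).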
76 | ret = np.ones_like(labelmap) * 255
77 | for idx, label in enumerate(self.mapping):
78 | ret[labelmap == label] = idx
79 |
80 | return ret
81 |
82 | def resize_image(self, image, label, size):
83 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
84 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
85 | return image, label
86 |
87 | def __getitem__(self, index):
88 | item = self.files[index]
89 | name = item["name"]
90 | image_path = os.path.join(self.root, 'cocostuff', item['img'])
91 | label_path = os.path.join(self.root, 'cocostuff', item['label'])
92 | image = cv2.imread(
93 | image_path,
94 | cv2.IMREAD_COLOR
95 | )
96 | label = np.array(
97 | Image.open(label_path).convert('P')
98 | )
99 | label = self.encode_label(label)
100 | label = self.reduce_zero_label(label)
101 | size = label.shape
102 |
103 | if 'testval' in self.list_path:
104 | image, border_padding = self.resize_short_length(
105 | image,
106 | short_length=self.base_size,
107 | fit_stride=8,
108 | return_padding=True
109 | )
110 | image = self.input_transform(image)
111 | image = image.transpose((2, 0, 1))
112 |
113 | return image.copy(), label.copy(), np.array(size), name, border_padding
114 |
115 | if 'val' in self.list_path:
116 | image, label = self.resize_short_length(
117 | image,
118 | label=label,
119 | short_length=self.base_size,
120 | fit_stride=8
121 | )
122 | image, label = self.rand_crop(image, label)
123 | image = self.input_transform(image)
124 | image = image.transpose((2, 0, 1))
125 |
126 | return image.copy(), label.copy(), np.array(size), name
127 |
128 | image, label = self.resize_short_length(image, label, short_length=self.base_size)
129 | image, label = self.gen_sample(image, label, self.multi_scale, self.flip)
130 |
131 | return image.copy(), label.copy(), np.array(size), name
--------------------------------------------------------------------------------
/hrnet_code/lib/datasets/lip.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import os
8 |
9 | import cv2
10 | import numpy as np
11 |
12 | import torch
13 | from torch.nn import functional as F
14 |
15 | from .base_dataset import BaseDataset
16 |
17 |
18 | class LIP(BaseDataset):
19 | def __init__(self,
20 | root,
21 | list_path,
22 | num_samples=None,
23 | num_classes=20,
24 | multi_scale=True,
25 | flip=True,
26 | ignore_label=-1,
27 | base_size=473,
28 | crop_size=(473, 473),
29 | downsample_rate=1,
30 | scale_factor=11,
31 | mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225]):
33 |
34 | super(LIP, self).__init__(ignore_label, base_size,
35 | crop_size, downsample_rate, scale_factor, mean, std)
36 |
37 | self.root = root
38 | self.num_classes = num_classes
39 | self.list_path = list_path
40 | self.class_weights = None
41 |
42 | self.multi_scale = multi_scale
43 | self.flip = flip
44 | self.img_list = [line.strip().split() for line in open(root+list_path)]
45 |
46 | self.files = self.read_files()
47 | if num_samples:
48 | self.files = self.files[:num_samples]
49 |
50 | def read_files(self):
51 | files = []
52 | for item in self.img_list:
53 | if 'train' in self.list_path:
54 | image_path, label_path, label_rev_path, _ = item
55 | name = os.path.splitext(os.path.basename(label_path))[0]
56 | sample = {"img": image_path,
57 | "label": label_path,
58 | "label_rev": label_rev_path,
59 | "name": name, }
60 | elif 'val' in self.list_path:
61 | image_path, label_path = item
62 | name = os.path.splitext(os.path.basename(label_path))[0]
63 | sample = {"img": image_path,
64 | "label": label_path,
65 | "name": name, }
66 | else:
67 | raise NotImplementedError('Unknown subset.')
68 | files.append(sample)
69 | return files
70 |
71 | def resize_image(self, image, label, size):
72 | image = cv2.resize(image, size, interpolation=cv2.INTER_LINEAR)
73 | label = cv2.resize(label, size, interpolation=cv2.INTER_NEAREST)
74 | return image, label
75 |
76 | def __getitem__(self, index):
77 | item = self.files[index]
78 | name = item["name"]
79 | item["img"] = item["img"].replace(
80 | "train_images", "LIP_Train").replace("val_images", "LIP_Val")
81 | item["label"] = item["label"].replace(
82 | "train_segmentations", "LIP_Train").replace("val_segmentations", "LIP_Val")
83 | image = cv2.imread(os.path.join(
84 | self.root, 'lip/TrainVal_images/', item["img"]),
85 | cv2.IMREAD_COLOR)
86 | label = cv2.imread(os.path.join(
87 | self.root, 'lip/TrainVal_parsing_annotations/',
88 | item["label"]),
89 | cv2.IMREAD_GRAYSCALE)
90 | size = label.shape
91 |
92 | if 'testval' in self.list_path:
93 | image = cv2.resize(image, self.crop_size,
94 | interpolation=cv2.INTER_LINEAR)
95 | image = self.input_transform(image)
96 | image = image.transpose((2, 0, 1))
97 |
98 | return image.copy(), label.copy(), np.array(size), name
99 |
100 | if self.flip:
101 | flip = np.random.choice(2) * 2 - 1
102 | image = image[:, ::flip, :]
103 | label = label[:, ::flip]
104 |
105 | if flip == -1:
106 | right_idx = [15, 17, 19]
107 | left_idx = [14, 16, 18]
108 | for i in range(0, 3):
109 | right_pos = np.where(label == right_idx[i])
110 | left_pos = np.where(label == left_idx[i])
111 | label[right_pos[0], right_pos[1]] = left_idx[i]
112 | label[left_pos[0], left_pos[1]] = right_idx[i]
113 |
114 | image, label = self.resize_image(image, label, self.crop_size)
115 | image, label = self.gen_sample(image, label,
116 | self.multi_scale, False)
117 |
118 | return image.copy(), label.copy(), np.array(size), name
119 |
120 | def inference(self, config, model, image, flip):
121 | size = image.size()
122 | pred = model(image)
123 | if config.MODEL.NUM_OUTPUTS > 1:
124 | pred = pred[config.TEST.OUTPUT_INDEX]
125 |
126 | pred = F.interpolate(
127 | input=pred, size=size[-2:],
128 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
129 | )
130 |
131 | if flip:
132 | flip_img = image.numpy()[:, :, :, ::-1]
133 | flip_output = model(torch.from_numpy(flip_img.copy()))
134 |
135 | if config.MODEL.NUM_OUTPUTS > 1:
136 | flip_output = flip_output[config.TEST.OUTPUT_INDEX]
137 |
138 | flip_output = F.interpolate(
139 | input=flip_output, size=size[-2:],
140 | mode='bilinear', align_corners=config.MODEL.ALIGN_CORNERS
141 | )
142 |
143 | flip_output = flip_output.cpu()
144 | flip_pred = flip_output.cpu().numpy().copy()
145 | flip_pred[:, 14, :, :] = flip_output[:, 15, :, :]
146 | flip_pred[:, 15, :, :] = flip_output[:, 14, :, :]
147 | flip_pred[:, 16, :, :] = flip_output[:, 17, :, :]
148 | flip_pred[:, 17, :, :] = flip_output[:, 16, :, :]
149 | flip_pred[:, 18, :, :] = flip_output[:, 19, :, :]
150 | flip_pred[:, 19, :, :] = flip_output[:, 18, :, :]
151 | flip_pred = torch.from_numpy(
152 | flip_pred[:, :, :, ::-1].copy()).cuda()
153 | pred += flip_pred
154 | pred = pred * 0.5
155 | return pred.exp()
156 |
--------------------------------------------------------------------------------
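Both the training-time flip in `__getitem__` (on label values) and the flip-augmented `inference` (on prediction channels 14/15, 16/17, 18/19) must swap the paired left/right ids, because mirroring the image turns a right limb into a left one. A small NumPy sketch of the label-space version; the semantic names in the comment follow the usual LIP conventions and are an assumption here:

```python
import numpy as np

PAIRS = [(14, 15), (16, 17), (18, 19)]  # (left, right): arm, leg, shoe

def flip_label(label):
    label = label[:, ::-1].copy()       # mirror horizontally
    for left, right in PAIRS:
        l_pos = label == left
        r_pos = label == right
        label[l_pos] = right            # a mirrored left id becomes right
        label[r_pos] = left             # ... and vice versa
    return label

lbl = np.array([[14, 0, 15]])
print(flip_label(lbl))                  # [[14  0 15]]: mirrored and swapped
```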
/hrnet_code/lib/datasets/pascal_ctx.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # Referring to the implementation in
6 | # https://github.com/zhanghang1989/PyTorch-Encoding
7 | # ------------------------------------------------------------------------------
8 |
9 | import os
10 |
11 | import cv2
12 | import numpy as np
13 | from PIL import Image
14 |
15 | import torch
16 |
17 | from .base_dataset import BaseDataset
18 |
19 | class PASCALContext(BaseDataset):
20 | def __init__(self,
21 | root,
22 | list_path,
23 | num_samples=None,
24 | num_classes=59,
25 | multi_scale=True,
26 | flip=True,
27 | ignore_label=-1,
28 | base_size=520,
29 | crop_size=(480, 480),
30 | downsample_rate=1,
31 | scale_factor=16,
32 | mean=[0.485, 0.456, 0.406],
33 | std=[0.229, 0.224, 0.225],):
34 |
35 | super(PASCALContext, self).__init__(ignore_label, base_size,
36 | crop_size, downsample_rate, scale_factor, mean, std)
37 |
38 | self.root = os.path.join(root, 'pascal_ctx/VOCdevkit/VOC2010')
39 | self.split = list_path
40 |
41 | self.num_classes = num_classes
42 | self.class_weights = None
43 |
44 | self.multi_scale = multi_scale
45 | self.flip = flip
46 | self.crop_size = crop_size
47 |
48 | # prepare data
49 | annots = os.path.join(self.root, 'trainval_merged.json')
50 | img_path = os.path.join(self.root, 'JPEGImages')
51 | from detail import Detail
52 | if 'val' in self.split:
53 | self.detail = Detail(annots, img_path, 'val')
54 | mask_file = os.path.join(self.root, 'val.pth')
55 | elif 'train' in self.split:
56 | self.mode = 'train'
57 | self.detail = Detail(annots, img_path, 'train')
58 | mask_file = os.path.join(self.root, 'train.pth')
59 | else:
60 | raise NotImplementedError('only supporting train and val set.')
61 | self.files = self.detail.getImgs()
62 |
63 | # generate masks
64 | self._mapping = np.sort(np.array([
65 | 0, 2, 259, 260, 415, 324, 9, 258, 144, 18, 19, 22,
66 | 23, 397, 25, 284, 158, 159, 416, 33, 162, 420, 454, 295, 296,
67 | 427, 44, 45, 46, 308, 59, 440, 445, 31, 232, 65, 354, 424,
68 | 68, 326, 72, 458, 34, 207, 80, 355, 85, 347, 220, 349, 360,
69 | 98, 187, 104, 105, 366, 189, 368, 113, 115]))
70 |
71 | self._key = np.array(range(len(self._mapping))).astype('uint8')
72 |
73 | print('mask_file:', mask_file)
74 | if os.path.exists(mask_file):
75 | self.masks = torch.load(mask_file)
76 | else:
77 | self.masks = self._preprocess(mask_file)
78 |
79 | def _class_to_index(self, mask):
80 | # assert the values
81 | values = np.unique(mask)
82 | for i in range(len(values)):
83 | assert(values[i] in self._mapping)
84 | index = np.digitize(mask.ravel(), self._mapping, right=True)
85 | return self._key[index].reshape(mask.shape)
86 |
87 | def _preprocess(self, mask_file):
88 | masks = {}
89 | print("Preprocessing mask, this will take a while." + \
90 | "But don't worry, it only run once for each split.")
91 | for i in range(len(self.files)):
92 | img_id = self.files[i]
93 | mask = Image.fromarray(self._class_to_index(
94 | self.detail.getMask(img_id)))
95 | masks[img_id['image_id']] = mask
96 | torch.save(masks, mask_file)
97 | return masks
98 |
99 | def __getitem__(self, index):
100 | item = self.files[index]
101 | name = item['file_name']
102 | img_id = item['image_id']
103 |
104 | image = cv2.imread(os.path.join(self.detail.img_folder,name),
105 | cv2.IMREAD_COLOR)
106 |         label = np.asarray(self.masks[img_id], dtype=int)  # np.int (removed in NumPy 1.24) was an alias for int
107 | size = image.shape
108 |
109 | if self.split == 'val':
110 | image = cv2.resize(image, self.crop_size,
111 | interpolation = cv2.INTER_LINEAR)
112 | image = self.input_transform(image)
113 | image = image.transpose((2, 0, 1))
114 |
115 | label = cv2.resize(label, self.crop_size,
116 | interpolation=cv2.INTER_NEAREST)
117 | label = self.label_transform(label)
118 | elif self.split == 'testval':
119 | # evaluate model on val dataset
120 | image = self.input_transform(image)
121 | image = image.transpose((2, 0, 1))
122 | label = self.label_transform(label)
123 | else:
124 | image, label = self.gen_sample(image, label,
125 | self.multi_scale, self.flip)
126 |
127 | return image.copy(), label.copy(), np.array(size), name
128 |
129 | def label_transform(self, label):
130 | if self.num_classes == 59:
131 | # background is ignored
132 | label = np.array(label).astype('int32') - 1
133 | label[label==-2] = -1
134 | else:
135 | label = np.array(label).astype('int32')
136 | return label
137 |
--------------------------------------------------------------------------------
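For the 59-class PASCAL-Context setting, `label_transform` shifts every id down by one so that background (0) lands on the ignore value -1; a pixel that was already -1 (and would shift to -2) is folded back to -1. A quick check of that behavior:

```python
import numpy as np

def label_transform_59(label):
    # Mirrors PASCALContext.label_transform for num_classes == 59.
    label = np.array(label).astype('int32') - 1
    label[label == -2] = -1
    return label

print(label_transform_59(np.array([-1, 0, 1, 59])))  # [-1 -1  0 58]
```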
/hrnet_code/lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import models.seg_hrnet
12 | import models.seg_hrnet_ocr
--------------------------------------------------------------------------------
/hrnet_code/lib/models/bn_helper.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 |
4 | if torch.__version__.startswith('0'):
5 | from .sync_bn.inplace_abn.bn import InPlaceABNSync
6 | BatchNorm2d = functools.partial(InPlaceABNSync, activation='none')
7 | BatchNorm2d_class = InPlaceABNSync
8 | relu_inplace = False
9 | else:
10 | # BatchNorm2d_class = BatchNorm2d = torch.nn.SyncBatchNorm
11 | BatchNorm2d_class = BatchNorm2d = torch.nn.BatchNorm2d
12 | relu_inplace = True
--------------------------------------------------------------------------------
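The point of this shim is that model code imports its norm layer and ReLU in-place flag from here instead of hard-coding them, so the same network definition runs under PyTorch 0.x (the bundled `InPlaceABNSync`, whose backward reconstructs activations from the stored output, hence the conservative `relu_inplace = False`) and under 1.x (plain `BatchNorm2d`). A sketch of the consumption pattern; the flat import path is an assumption, inside the repo it is the relative `from .bn_helper import ...`:

```python
import torch.nn as nn
from bn_helper import BatchNorm2d, relu_inplace  # flat path assumed for this sketch

def conv_bn_relu(cin, cout):
    # The block never names a concrete BatchNorm implementation; whatever
    # bn_helper selected for the running PyTorch version is used.
    return nn.Sequential(
        nn.Conv2d(cin, cout, kernel_size=3, padding=1, bias=False),
        BatchNorm2d(cout),
        nn.ReLU(inplace=relu_inplace),
    )
```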
/hrnet_code/lib/models/sync_bn/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | BSD 3-Clause License
3 |
4 | Copyright (c) 2017, mapillary
5 | All rights reserved.
6 |
7 | Redistribution and use in source and binary forms, with or without
8 | modification, are permitted provided that the following conditions are met:
9 |
10 | * Redistributions of source code must retain the above copyright notice, this
11 | list of conditions and the following disclaimer.
12 |
13 | * Redistributions in binary form must reproduce the above copyright notice,
14 | this list of conditions and the following disclaimer in the documentation
15 | and/or other materials provided with the distribution.
16 |
17 | * Neither the name of the copyright holder nor the names of its
18 | contributors may be used to endorse or promote products derived from
19 | this software without specific prior written permission.
20 |
21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31 |
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/__init__.py:
--------------------------------------------------------------------------------
1 | from .inplace_abn import bn
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/__init__.py:
--------------------------------------------------------------------------------
1 | from .bn import ABN, InPlaceABN, InPlaceABNSync
2 | from .functions import ACT_RELU, ACT_LEAKY_RELU, ACT_ELU, ACT_NONE
3 |
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/bn.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as functional
5 |
6 | try:
7 | from queue import Queue
8 | except ImportError:
9 | from Queue import Queue
10 |
11 | BASE_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.path.append(BASE_DIR)
13 | sys.path.append(os.path.join(BASE_DIR, '../src'))
14 | from functions import *
15 |
16 |
17 | class ABN(nn.Module):
18 | """Activated Batch Normalization
19 |
20 | This gathers a `BatchNorm2d` and an activation function in a single module
21 | """
22 |
23 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
24 | """Creates an Activated Batch Normalization module
25 |
26 | Parameters
27 | ----------
28 | num_features : int
29 | Number of feature channels in the input and output.
30 | eps : float
31 | Small constant to prevent numerical issues.
32 | momentum : float
33 |             Momentum factor applied to compute running statistics.
34 |         affine : bool
35 |             If `True`, apply a learned scale and shift transformation after normalization.
36 |         activation : str
37 |             Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
38 | slope : float
39 | Negative slope for the `leaky_relu` activation.
40 | """
41 | super(ABN, self).__init__()
42 | self.num_features = num_features
43 | self.affine = affine
44 | self.eps = eps
45 | self.momentum = momentum
46 | self.activation = activation
47 | self.slope = slope
48 | if self.affine:
49 | self.weight = nn.Parameter(torch.ones(num_features))
50 | self.bias = nn.Parameter(torch.zeros(num_features))
51 | else:
52 | self.register_parameter('weight', None)
53 | self.register_parameter('bias', None)
54 | self.register_buffer('running_mean', torch.zeros(num_features))
55 | self.register_buffer('running_var', torch.ones(num_features))
56 | self.reset_parameters()
57 |
58 | def reset_parameters(self):
59 | nn.init.constant_(self.running_mean, 0)
60 | nn.init.constant_(self.running_var, 1)
61 | if self.affine:
62 | nn.init.constant_(self.weight, 1)
63 | nn.init.constant_(self.bias, 0)
64 |
65 | def forward(self, x):
66 | x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
67 | self.training, self.momentum, self.eps)
68 |
69 | if self.activation == ACT_RELU:
70 | return functional.relu(x, inplace=True)
71 | elif self.activation == ACT_LEAKY_RELU:
72 | return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
73 | elif self.activation == ACT_ELU:
74 | return functional.elu(x, inplace=True)
75 | else:
76 | return x
77 |
78 | def __repr__(self):
79 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
80 | ' affine={affine}, activation={activation}'
81 | if self.activation == "leaky_relu":
82 | rep += ', slope={slope})'
83 | else:
84 | rep += ')'
85 | return rep.format(name=self.__class__.__name__, **self.__dict__)
86 |
87 |
88 | class InPlaceABN(ABN):
89 | """InPlace Activated Batch Normalization"""
90 |
91 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
92 | """Creates an InPlace Activated Batch Normalization module
93 |
94 | Parameters
95 | ----------
96 | num_features : int
97 | Number of feature channels in the input and output.
98 | eps : float
99 | Small constant to prevent numerical issues.
100 | momentum : float
101 |             Momentum factor applied to compute running statistics.
102 |         affine : bool
103 |             If `True`, apply a learned scale and shift transformation after normalization.
104 |         activation : str
105 |             Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
106 | slope : float
107 | Negative slope for the `leaky_relu` activation.
108 | """
109 | super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)
110 |
111 | def forward(self, x):
112 | return inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
113 | self.training, self.momentum, self.eps, self.activation, self.slope)
114 |
115 |
116 | class InPlaceABNSync(ABN):
117 | """InPlace Activated Batch Normalization with cross-GPU synchronization
118 |
119 | This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DataParallel`.
120 | """
121 |
122 | def __init__(self, num_features, devices=None, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu",
123 | slope=0.01):
124 | """Creates a synchronized, InPlace Activated Batch Normalization module
125 |
126 | Parameters
127 | ----------
128 | num_features : int
129 | Number of feature channels in the input and output.
130 | devices : list of int or None
131 | IDs of the GPUs that will run the replicas of this module.
132 | eps : float
133 | Small constant to prevent numerical issues.
134 | momentum : float
135 |             Momentum factor applied to compute running statistics.
136 |         affine : bool
137 |             If `True`, apply a learned scale and shift transformation after normalization.
138 |         activation : str
139 |             Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
140 | slope : float
141 | Negative slope for the `leaky_relu` activation.
142 | """
143 | super(InPlaceABNSync, self).__init__(num_features, eps, momentum, affine, activation, slope)
144 | self.devices = devices if devices else list(range(torch.cuda.device_count()))
145 |
146 | # Initialize queues
147 | self.worker_ids = self.devices[1:]
148 | self.master_queue = Queue(len(self.worker_ids))
149 | self.worker_queues = [Queue(1) for _ in self.worker_ids]
150 |
151 | def forward(self, x):
152 | if x.get_device() == self.devices[0]:
153 | # Master mode
154 | extra = {
155 | "is_master": True,
156 | "master_queue": self.master_queue,
157 | "worker_queues": self.worker_queues,
158 | "worker_ids": self.worker_ids
159 | }
160 | else:
161 | # Worker mode
162 | extra = {
163 | "is_master": False,
164 | "master_queue": self.master_queue,
165 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())]
166 | }
167 |
168 | return inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
169 | extra, self.training, self.momentum, self.eps, self.activation, self.slope)
170 |
171 | def __repr__(self):
172 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
173 | ' affine={affine}, devices={devices}, activation={activation}'
174 | if self.activation == "leaky_relu":
175 | rep += ', slope={slope})'
176 | else:
177 | rep += ')'
178 | return rep.format(name=self.__class__.__name__, **self.__dict__)
179 |
--------------------------------------------------------------------------------
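`ABN` is a drop-in replacement for a `BatchNorm2d` followed by an activation, which is what lets the in-place variants work on a single buffer. A minimal usage sketch; note that importing this module triggers the JIT build of the CUDA extension through `functions.py`, so a working CUDA toolchain is assumed:

```python
import torch
from bn import ABN  # assumes this file is importable; import triggers the JIT build

# One module instead of BatchNorm2d(64) followed by nn.LeakyReLU(0.01):
abn = ABN(num_features=64, activation="leaky_relu", slope=0.01)
x = torch.randn(2, 64, 32, 32)
y = abn(x)
print(y.shape)  # torch.Size([2, 64, 32, 32])
```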
/hrnet_code/lib/models/sync_bn/inplace_abn/functions.py:
--------------------------------------------------------------------------------
1 | from os import path
2 |
3 | import torch.autograd as autograd
4 | import torch.cuda.comm as comm
5 | from torch.autograd.function import once_differentiable
6 | from torch.utils.cpp_extension import load
7 |
8 | _src_path = path.join(path.dirname(path.abspath(__file__)), "src")
9 | _backend = load(name="inplace_abn",
10 | extra_cflags=["-O3"],
11 | sources=[path.join(_src_path, f) for f in [
12 | "inplace_abn.cpp",
13 | "inplace_abn_cpu.cpp",
14 | "inplace_abn_cuda.cu"
15 | ]],
16 | extra_cuda_cflags=["--expt-extended-lambda"])
17 |
18 | # Activation names
19 | ACT_RELU = "relu"
20 | ACT_LEAKY_RELU = "leaky_relu"
21 | ACT_ELU = "elu"
22 | ACT_NONE = "none"
23 |
24 |
25 | def _check(fn, *args, **kwargs):
26 | success = fn(*args, **kwargs)
27 | if not success:
28 | raise RuntimeError("CUDA Error encountered in {}".format(fn))
29 |
30 |
31 | def _broadcast_shape(x):
32 | out_size = []
33 | for i, s in enumerate(x.size()):
34 | if i != 1:
35 | out_size.append(1)
36 | else:
37 | out_size.append(s)
38 | return out_size
39 |
40 |
41 | def _reduce(x):
42 | if len(x.size()) == 2:
43 | return x.sum(dim=0)
44 | else:
45 | n, c = x.size()[0:2]
46 | return x.contiguous().view((n, c, -1)).sum(2).sum(0)
47 |
48 |
49 | def _count_samples(x):
50 | count = 1
51 | for i, s in enumerate(x.size()):
52 | if i != 1:
53 | count *= s
54 | return count
55 |
56 |
57 | def _act_forward(ctx, x):
58 | if ctx.activation == ACT_LEAKY_RELU:
59 | _backend.leaky_relu_forward(x, ctx.slope)
60 | elif ctx.activation == ACT_ELU:
61 | _backend.elu_forward(x)
62 | elif ctx.activation == ACT_NONE:
63 | pass
64 |
65 |
66 | def _act_backward(ctx, x, dx):
67 | if ctx.activation == ACT_LEAKY_RELU:
68 | _backend.leaky_relu_backward(x, dx, ctx.slope)
69 | elif ctx.activation == ACT_ELU:
70 | _backend.elu_backward(x, dx)
71 | elif ctx.activation == ACT_NONE:
72 | pass
73 |
74 |
75 | class InPlaceABN(autograd.Function):
76 | @staticmethod
77 | def forward(ctx, x, weight, bias, running_mean, running_var,
78 | training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
79 | # Save context
80 | ctx.training = training
81 | ctx.momentum = momentum
82 | ctx.eps = eps
83 | ctx.activation = activation
84 | ctx.slope = slope
85 | ctx.affine = weight is not None and bias is not None
86 |
87 | # Prepare inputs
88 | count = _count_samples(x)
89 | x = x.contiguous()
90 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
91 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
92 |
93 | if ctx.training:
94 | mean, var = _backend.mean_var(x)
95 |
96 | # Update running stats
97 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
98 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
99 |
100 | # Mark in-place modified tensors
101 | ctx.mark_dirty(x, running_mean, running_var)
102 | else:
103 | mean, var = running_mean.contiguous(), running_var.contiguous()
104 | ctx.mark_dirty(x)
105 |
106 | # BN forward + activation
107 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
108 | _act_forward(ctx, x)
109 |
110 | # Output
111 | ctx.var = var
112 | ctx.save_for_backward(x, var, weight, bias)
113 | return x
114 |
115 | @staticmethod
116 | @once_differentiable
117 | def backward(ctx, dz):
118 | z, var, weight, bias = ctx.saved_tensors
119 | dz = dz.contiguous()
120 |
121 | # Undo activation
122 | _act_backward(ctx, z, dz)
123 |
124 | if ctx.training:
125 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
126 | else:
127 | # TODO: implement simplified CUDA backward for inference mode
128 | edz = dz.new_zeros(dz.size(1))
129 | eydz = dz.new_zeros(dz.size(1))
130 |
131 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
132 | dweight = dweight if ctx.affine else None
133 | dbias = dbias if ctx.affine else None
134 |
135 | return dx, dweight, dbias, None, None, None, None, None, None, None
136 |
137 |
138 | class InPlaceABNSync(autograd.Function):
139 | @classmethod
140 | def forward(cls, ctx, x, weight, bias, running_mean, running_var,
141 | extra, training=True, momentum=0.1, eps=1e-05, activation=ACT_LEAKY_RELU, slope=0.01):
142 | # Save context
143 | cls._parse_extra(ctx, extra)
144 | ctx.training = training
145 | ctx.momentum = momentum
146 | ctx.eps = eps
147 | ctx.activation = activation
148 | ctx.slope = slope
149 | ctx.affine = weight is not None and bias is not None
150 |
151 | # Prepare inputs
152 | count = _count_samples(x) * (ctx.master_queue.maxsize + 1)
153 | x = x.contiguous()
154 | weight = weight.contiguous() if ctx.affine else x.new_empty(0)
155 | bias = bias.contiguous() if ctx.affine else x.new_empty(0)
156 |
157 | if ctx.training:
158 | mean, var = _backend.mean_var(x)
159 |
160 | if ctx.is_master:
161 | means, vars = [mean.unsqueeze(0)], [var.unsqueeze(0)]
162 | for _ in range(ctx.master_queue.maxsize):
163 | mean_w, var_w = ctx.master_queue.get()
164 | ctx.master_queue.task_done()
165 | means.append(mean_w.unsqueeze(0))
166 | vars.append(var_w.unsqueeze(0))
167 |
168 | means = comm.gather(means)
169 | vars = comm.gather(vars)
170 |
171 | mean = means.mean(0)
172 | var = (vars + (mean - means) ** 2).mean(0)
173 |
174 | tensors = comm.broadcast_coalesced((mean, var), [mean.get_device()] + ctx.worker_ids)
175 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
176 | queue.put(ts)
177 | else:
178 | ctx.master_queue.put((mean, var))
179 | mean, var = ctx.worker_queue.get()
180 | ctx.worker_queue.task_done()
181 |
182 | # Update running stats
183 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean)
184 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * var * count / (count - 1))
185 |
186 | # Mark in-place modified tensors
187 | ctx.mark_dirty(x, running_mean, running_var)
188 | else:
189 | mean, var = running_mean.contiguous(), running_var.contiguous()
190 | ctx.mark_dirty(x)
191 |
192 | # BN forward + activation
193 | _backend.forward(x, mean, var, weight, bias, ctx.affine, ctx.eps)
194 | _act_forward(ctx, x)
195 |
196 | # Output
197 | ctx.var = var
198 | ctx.save_for_backward(x, var, weight, bias)
199 | return x
200 |
201 | @staticmethod
202 | @once_differentiable
203 | def backward(ctx, dz):
204 | z, var, weight, bias = ctx.saved_tensors
205 | dz = dz.contiguous()
206 |
207 | # Undo activation
208 | _act_backward(ctx, z, dz)
209 |
210 | if ctx.training:
211 | edz, eydz = _backend.edz_eydz(z, dz, weight, bias, ctx.affine, ctx.eps)
212 |
213 | if ctx.is_master:
214 | edzs, eydzs = [edz], [eydz]
215 | for _ in range(len(ctx.worker_queues)):
216 | edz_w, eydz_w = ctx.master_queue.get()
217 | ctx.master_queue.task_done()
218 | edzs.append(edz_w)
219 | eydzs.append(eydz_w)
220 |
221 | edz = comm.reduce_add(edzs) / (ctx.master_queue.maxsize + 1)
222 | eydz = comm.reduce_add(eydzs) / (ctx.master_queue.maxsize + 1)
223 |
224 | tensors = comm.broadcast_coalesced((edz, eydz), [edz.get_device()] + ctx.worker_ids)
225 | for ts, queue in zip(tensors[1:], ctx.worker_queues):
226 | queue.put(ts)
227 | else:
228 | ctx.master_queue.put((edz, eydz))
229 | edz, eydz = ctx.worker_queue.get()
230 | ctx.worker_queue.task_done()
231 | else:
232 | edz = dz.new_zeros(dz.size(1))
233 | eydz = dz.new_zeros(dz.size(1))
234 |
235 | dx, dweight, dbias = _backend.backward(z, dz, var, weight, bias, edz, eydz, ctx.affine, ctx.eps)
236 | dweight = dweight if ctx.affine else None
237 | dbias = dbias if ctx.affine else None
238 |
239 | return dx, dweight, dbias, None, None, None, None, None, None, None, None
240 |
241 | @staticmethod
242 | def _parse_extra(ctx, extra):
243 | ctx.is_master = extra["is_master"]
244 | if ctx.is_master:
245 | ctx.master_queue = extra["master_queue"]
246 | ctx.worker_queues = extra["worker_queues"]
247 | ctx.worker_ids = extra["worker_ids"]
248 | else:
249 | ctx.master_queue = extra["master_queue"]
250 | ctx.worker_queue = extra["worker_queue"]
251 |
252 |
253 | inplace_abn = InPlaceABN.apply
254 | inplace_abn_sync = InPlaceABNSync.apply
255 |
256 | __all__ = ["inplace_abn", "inplace_abn_sync", "ACT_RELU", "ACT_LEAKY_RELU", "ACT_ELU", "ACT_NONE"]
257 |
--------------------------------------------------------------------------------
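The master/worker exchange in `InPlaceABNSync.forward` rests on one identity: for equally sized per-GPU shards, the global variance is the mean of the per-shard variances plus the mean squared deviation of the per-shard means, which is exactly the line `var = (vars + (mean - means) ** 2).mean(0)`. A NumPy check of that aggregation:

```python
import numpy as np

rng = np.random.default_rng(0)
shards = [rng.normal(loc=i, size=1000) for i in range(4)]  # one shard per "GPU"

means = np.array([s.mean() for s in shards])   # what each worker sends
vars_ = np.array([s.var() for s in shards])

mu = means.mean()                              # master's aggregation
var = (vars_ + (mu - means) ** 2).mean()

full = np.concatenate(shards)
print(np.allclose(mu, full.mean()), np.allclose(var, full.var()))  # True True
```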
/hrnet_code/lib/models/sync_bn/inplace_abn/src/common.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | /*
6 | * General settings
7 | */
8 | const int WARP_SIZE = 32;
9 | const int MAX_BLOCK_SIZE = 512;
10 |
11 | template<typename T>
12 | struct Pair {
13 | T v1, v2;
14 | __device__ Pair() {}
15 | __device__ Pair(T _v1, T _v2) : v1(_v1), v2(_v2) {}
16 | __device__ Pair(T v) : v1(v), v2(v) {}
17 | __device__ Pair(int v) : v1(v), v2(v) {}
18 | __device__ Pair &operator+=(const Pair &a) {
19 | v1 += a.v1;
20 | v2 += a.v2;
21 | return *this;
22 | }
23 | };
24 |
25 | /*
26 | * Utility functions
27 | */
28 | template<typename T>
29 | __device__ __forceinline__ T WARP_SHFL_XOR(T value, int laneMask, int width = warpSize,
30 | unsigned int mask = 0xffffffff) {
31 | #if CUDART_VERSION >= 9000
32 | return __shfl_xor_sync(mask, value, laneMask, width);
33 | #else
34 | return __shfl_xor(value, laneMask, width);
35 | #endif
36 | }
37 |
38 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); }
39 |
40 | static int getNumThreads(int nElem) {
41 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
42 | for (int i = 0; i != 5; ++i) {
43 | if (nElem <= threadSizes[i]) {
44 | return threadSizes[i];
45 | }
46 | }
47 | return MAX_BLOCK_SIZE;
48 | }
49 |
50 | template<typename T>
51 | static __device__ __forceinline__ T warpSum(T val) {
52 | #if __CUDA_ARCH__ >= 300
53 | for (int i = 0; i < getMSB(WARP_SIZE); ++i) {
54 | val += WARP_SHFL_XOR(val, 1 << i, WARP_SIZE);
55 | }
56 | #else
57 | __shared__ T values[MAX_BLOCK_SIZE];
58 | values[threadIdx.x] = val;
59 | __threadfence_block();
60 | const int base = (threadIdx.x / WARP_SIZE) * WARP_SIZE;
61 | for (int i = 1; i < WARP_SIZE; i++) {
62 | val += values[base + ((i + threadIdx.x) % WARP_SIZE)];
63 | }
64 | #endif
65 | return val;
66 | }
67 |
68 | template<typename T>
69 | static __device__ __forceinline__ Pair<T> warpSum(Pair<T> value) {
70 | value.v1 = warpSum(value.v1);
71 | value.v2 = warpSum(value.v2);
72 | return value;
73 | }
74 |
75 | template<typename T, typename Op>
76 | __device__ T reduce(Op op, int plane, int N, int C, int S) {
77 | T sum = (T)0;
78 | for (int batch = 0; batch < N; ++batch) {
79 | for (int x = threadIdx.x; x < S; x += blockDim.x) {
80 | sum += op(batch, plane, x);
81 | }
82 | }
83 |
84 | // sum over NumThreads within a warp
85 | sum = warpSum(sum);
86 |
87 | // 'transpose', and reduce within warp again
88 | __shared__ T shared[32];
89 | __syncthreads();
90 | if (threadIdx.x % WARP_SIZE == 0) {
91 | shared[threadIdx.x / WARP_SIZE] = sum;
92 | }
93 | if (threadIdx.x >= blockDim.x / WARP_SIZE && threadIdx.x < WARP_SIZE) {
94 | // zero out the other entries in shared
95 | shared[threadIdx.x] = (T)0;
96 | }
97 | __syncthreads();
98 | if (threadIdx.x / WARP_SIZE == 0) {
99 | sum = warpSum(shared[threadIdx.x]);
100 | if (threadIdx.x == 0) {
101 | shared[0] = sum;
102 | }
103 | }
104 | __syncthreads();
105 |
106 | // Everyone picks it up, should be broadcast into the whole gradInput
107 | return shared[0];
108 | }
--------------------------------------------------------------------------------
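The `warpSum` above is a butterfly reduction: in each of the log2(32) = 5 rounds, every lane adds the value of the lane whose id differs in exactly one bit, so after the last round all 32 lanes hold the full warp sum (which is what `__shfl_xor_sync` provides on the GPU). A pure-Python simulation of the shuffle pattern:

```python
WARP_SIZE = 32

def warp_sum(vals):
    vals = list(vals)
    step = 1
    while step < WARP_SIZE:
        # lane i reads from lane i XOR step, exactly like WARP_SHFL_XOR
        vals = [vals[i] + vals[i ^ step] for i in range(WARP_SIZE)]
        step <<= 1
    return vals

print(set(warp_sum(range(WARP_SIZE))))  # {496}: every lane has the total
```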
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | std::vector<at::Tensor> mean_var(at::Tensor x) {
8 | if (x.is_cuda()) {
9 | return mean_var_cuda(x);
10 | } else {
11 | return mean_var_cpu(x);
12 | }
13 | }
14 |
15 | at::Tensor forward(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
16 | bool affine, float eps) {
17 | if (x.is_cuda()) {
18 | return forward_cuda(x, mean, var, weight, bias, affine, eps);
19 | } else {
20 | return forward_cpu(x, mean, var, weight, bias, affine, eps);
21 | }
22 | }
23 |
24 | std::vector<at::Tensor> edz_eydz(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
25 | bool affine, float eps) {
26 | if (z.is_cuda()) {
27 | return edz_eydz_cuda(z, dz, weight, bias, affine, eps);
28 | } else {
29 | return edz_eydz_cpu(z, dz, weight, bias, affine, eps);
30 | }
31 | }
32 |
33 | std::vector<at::Tensor> backward(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
34 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
35 | if (z.is_cuda()) {
36 | return backward_cuda(z, dz, var, weight, bias, edz, eydz, affine, eps);
37 | } else {
38 | return backward_cpu(z, dz, var, weight, bias, edz, eydz, affine, eps);
39 | }
40 | }
41 |
42 | void leaky_relu_forward(at::Tensor z, float slope) {
43 | at::leaky_relu_(z, slope);
44 | }
45 |
46 | void leaky_relu_backward(at::Tensor z, at::Tensor dz, float slope) {
47 | if (z.is_cuda()) {
48 | return leaky_relu_backward_cuda(z, dz, slope);
49 | } else {
50 | return leaky_relu_backward_cpu(z, dz, slope);
51 | }
52 | }
53 |
54 | void elu_forward(at::Tensor z) {
55 | at::elu_(z);
56 | }
57 |
58 | void elu_backward(at::Tensor z, at::Tensor dz) {
59 | if (z.is_cuda()) {
60 | return elu_backward_cuda(z, dz);
61 | } else {
62 | return elu_backward_cpu(z, dz);
63 | }
64 | }
65 |
66 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
67 | m.def("mean_var", &mean_var, "Mean and variance computation");
68 | m.def("forward", &forward, "In-place forward computation");
69 | m.def("edz_eydz", &edz_eydz, "First part of backward computation");
70 | m.def("backward", &backward, "Second part of backward computation");
71 | m.def("leaky_relu_forward", &leaky_relu_forward, "Leaky relu forward computation");
72 | m.def("leaky_relu_backward", &leaky_relu_backward, "Leaky relu backward computation and inversion");
73 | m.def("elu_forward", &elu_forward, "Elu forward computation");
74 | m.def("elu_backward", &elu_backward, "Elu backward computation and inversion");
75 | }
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include <vector>
6 |
7 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x);
8 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x);
9 |
10 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
11 | bool affine, float eps);
12 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
13 | bool affine, float eps);
14 |
15 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
16 |                                      bool affine, float eps);
17 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
18 |                                       bool affine, float eps);
19 |
20 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
21 |                                      at::Tensor edz, at::Tensor eydz, bool affine, float eps);
22 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
23 |                                       at::Tensor edz, at::Tensor eydz, bool affine, float eps);
24 |
25 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope);
26 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope);
27 |
28 | void elu_backward_cpu(at::Tensor z, at::Tensor dz);
29 | void elu_backward_cuda(at::Tensor z, at::Tensor dz);
--------------------------------------------------------------------------------
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn_cpu.cpp:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <vector>
4 |
5 | #include "inplace_abn.h"
6 |
7 | at::Tensor reduce_sum(at::Tensor x) {
8 | if (x.ndimension() == 2) {
9 | return x.sum(0);
10 | } else {
11 | auto x_view = x.view({x.size(0), x.size(1), -1});
12 | return x_view.sum(-1).sum(0);
13 | }
14 | }
15 |
16 | at::Tensor broadcast_to(at::Tensor v, at::Tensor x) {
17 | if (x.ndimension() == 2) {
18 | return v;
19 | } else {
20 |     std::vector<int64_t> broadcast_size = {1, -1};
21 | for (int64_t i = 2; i < x.ndimension(); ++i)
22 | broadcast_size.push_back(1);
23 |
24 | return v.view(broadcast_size);
25 | }
26 | }
27 |
28 | int64_t count(at::Tensor x) {
29 | int64_t count = x.size(0);
30 | for (int64_t i = 2; i < x.ndimension(); ++i)
31 | count *= x.size(i);
32 |
33 | return count;
34 | }
35 |
36 | at::Tensor invert_affine(at::Tensor z, at::Tensor weight, at::Tensor bias, bool affine, float eps) {
37 | if (affine) {
38 | return (z - broadcast_to(bias, z)) / broadcast_to(at::abs(weight) + eps, z);
39 | } else {
40 | return z;
41 | }
42 | }
43 |
44 | std::vector<at::Tensor> mean_var_cpu(at::Tensor x) {
45 | auto num = count(x);
46 | auto mean = reduce_sum(x) / num;
47 | auto diff = x - broadcast_to(mean, x);
48 | auto var = reduce_sum(diff.pow(2)) / num;
49 |
50 | return {mean, var};
51 | }
52 |
53 | at::Tensor forward_cpu(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
54 | bool affine, float eps) {
55 | auto gamma = affine ? at::abs(weight) + eps : at::ones_like(var);
56 | auto mul = at::rsqrt(var + eps) * gamma;
57 |
58 | x.sub_(broadcast_to(mean, x));
59 | x.mul_(broadcast_to(mul, x));
60 | if (affine) x.add_(broadcast_to(bias, x));
61 |
62 | return x;
63 | }
64 |
65 | std::vector<at::Tensor> edz_eydz_cpu(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
66 | bool affine, float eps) {
67 | auto edz = reduce_sum(dz);
68 | auto y = invert_affine(z, weight, bias, affine, eps);
69 | auto eydz = reduce_sum(y * dz);
70 |
71 | return {edz, eydz};
72 | }
73 |
74 | std::vector<at::Tensor> backward_cpu(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
75 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
76 | auto y = invert_affine(z, weight, bias, affine, eps);
77 | auto mul = affine ? at::rsqrt(var + eps) * (at::abs(weight) + eps) : at::rsqrt(var + eps);
78 |
79 | auto num = count(z);
80 | auto dx = (dz - broadcast_to(edz / num, dz) - y * broadcast_to(eydz / num, dz)) * broadcast_to(mul, dz);
81 |
82 | auto dweight = at::empty(z.type(), {0});
83 | auto dbias = at::empty(z.type(), {0});
84 | if (affine) {
85 | dweight = eydz * at::sign(weight);
86 | dbias = edz;
87 | }
88 |
89 | return {dx, dweight, dbias};
90 | }
91 |
92 | void leaky_relu_backward_cpu(at::Tensor z, at::Tensor dz, float slope) {
93 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cpu", ([&] {
94 | int64_t count = z.numel();
95 |     auto *_z = z.data<scalar_t>();
96 |     auto *_dz = dz.data<scalar_t>();
97 |
98 | for (int64_t i = 0; i < count; ++i) {
99 | if (_z[i] < 0) {
100 | _z[i] *= 1 / slope;
101 | _dz[i] *= slope;
102 | }
103 | }
104 | }));
105 | }
106 |
107 | void elu_backward_cpu(at::Tensor z, at::Tensor dz) {
108 | AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cpu", ([&] {
109 | int64_t count = z.numel();
110 |     auto *_z = z.data<scalar_t>();
111 |     auto *_dz = dz.data<scalar_t>();
112 |
113 | for (int64_t i = 0; i < count; ++i) {
114 | if (_z[i] < 0) {
115 |         _dz[i] *= (_z[i] + 1.f);  // dy/dx = y + 1, using the stored y = elu(x)
116 |         _z[i] = log1p(_z[i]);     // then invert in place to recover x
117 | }
118 | }
119 | }));
120 | }
--------------------------------------------------------------------------------
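`invert_affine` is the heart of the memory saving: `forward_cpu` overwrites its input with `z = y * (|w| + eps) + b`, so the backward pass can rebuild the normalized activation `y` from the stored output instead of keeping the input alive. Using `abs(weight) + eps` keeps the scale strictly positive, hence invertible. A NumPy sketch of the round trip:

```python
import numpy as np

eps = 1e-5
w, b = -0.7, 0.3                       # |w| + eps is positive even for w < 0
y = np.random.randn(8)                 # normalized activations
z = y * (abs(w) + eps) + b             # what forward_cpu stores in place
y_back = (z - b) / (abs(w) + eps)      # invert_affine
print(np.allclose(y, y_back))          # True
```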
/hrnet_code/lib/models/sync_bn/inplace_abn/src/inplace_abn_cuda.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 |
3 | #include <thrust/device_ptr.h>
4 | #include <thrust/transform.h>
5 |
6 | #include <vector>
7 |
8 | #include "common.h"
9 | #include "inplace_abn.h"
10 |
11 | // Checks
12 | #ifndef AT_CHECK
13 | #define AT_CHECK AT_ASSERT
14 | #endif
15 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
16 | #define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
17 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
18 |
19 | // Utilities
20 | void get_dims(at::Tensor x, int64_t& num, int64_t& chn, int64_t& sp) {
21 | num = x.size(0);
22 | chn = x.size(1);
23 | sp = 1;
24 | for (int64_t i = 2; i < x.ndimension(); ++i)
25 | sp *= x.size(i);
26 | }
27 |
28 | // Operations for reduce
29 | template<typename T>
30 | struct SumOp {
31 | __device__ SumOp(const T *t, int c, int s)
32 | : tensor(t), chn(c), sp(s) {}
33 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
34 | return tensor[(batch * chn + plane) * sp + n];
35 | }
36 | const T *tensor;
37 | const int chn;
38 | const int sp;
39 | };
40 |
41 | template<typename T>
42 | struct VarOp {
43 | __device__ VarOp(T m, const T *t, int c, int s)
44 | : mean(m), tensor(t), chn(c), sp(s) {}
45 | __device__ __forceinline__ T operator()(int batch, int plane, int n) {
46 | T val = tensor[(batch * chn + plane) * sp + n];
47 | return (val - mean) * (val - mean);
48 | }
49 | const T mean;
50 | const T *tensor;
51 | const int chn;
52 | const int sp;
53 | };
54 |
55 | template<typename T>
56 | struct GradOp {
57 | __device__ GradOp(T _weight, T _bias, const T *_z, const T *_dz, int c, int s)
58 | : weight(_weight), bias(_bias), z(_z), dz(_dz), chn(c), sp(s) {}
59 |   __device__ __forceinline__ Pair<T> operator()(int batch, int plane, int n) {
60 | T _y = (z[(batch * chn + plane) * sp + n] - bias) / weight;
61 | T _dz = dz[(batch * chn + plane) * sp + n];
62 |     return Pair<T>(_dz, _y * _dz);
63 | }
64 | const T weight;
65 | const T bias;
66 | const T *z;
67 | const T *dz;
68 | const int chn;
69 | const int sp;
70 | };
71 |
72 | /***********
73 | * mean_var
74 | ***********/
75 |
76 | template<typename T>
77 | __global__ void mean_var_kernel(const T *x, T *mean, T *var, int num, int chn, int sp) {
78 | int plane = blockIdx.x;
79 | T norm = T(1) / T(num * sp);
80 |
81 |   T _mean = reduce<T, SumOp<T>>(SumOp<T>(x, chn, sp), plane, num, chn, sp) * norm;
82 |   __syncthreads();
83 |   T _var = reduce<T, VarOp<T>>(VarOp<T>(_mean, x, chn, sp), plane, num, chn, sp) * norm;
84 |
85 | if (threadIdx.x == 0) {
86 | mean[plane] = _mean;
87 | var[plane] = _var;
88 | }
89 | }
90 |
91 | std::vector<at::Tensor> mean_var_cuda(at::Tensor x) {
92 | CHECK_INPUT(x);
93 |
94 | // Extract dimensions
95 | int64_t num, chn, sp;
96 | get_dims(x, num, chn, sp);
97 |
98 | // Prepare output tensors
99 | auto mean = at::empty(x.type(), {chn});
100 | auto var = at::empty(x.type(), {chn});
101 |
102 | // Run kernel
103 | dim3 blocks(chn);
104 | dim3 threads(getNumThreads(sp));
105 | AT_DISPATCH_FLOATING_TYPES(x.type(), "mean_var_cuda", ([&] {
106 |     mean_var_kernel<scalar_t><<<blocks, threads>>>(
107 |         x.data<scalar_t>(),
108 |         mean.data<scalar_t>(),
109 |         var.data<scalar_t>(),
110 | num, chn, sp);
111 | }));
112 |
113 | return {mean, var};
114 | }
115 |
116 | /**********
117 | * forward
118 | **********/
119 |
120 | template<typename T>
121 | __global__ void forward_kernel(T *x, const T *mean, const T *var, const T *weight, const T *bias,
122 | bool affine, float eps, int num, int chn, int sp) {
123 | int plane = blockIdx.x;
124 |
125 | T _mean = mean[plane];
126 | T _var = var[plane];
127 | T _weight = affine ? abs(weight[plane]) + eps : T(1);
128 | T _bias = affine ? bias[plane] : T(0);
129 |
130 | T mul = rsqrt(_var + eps) * _weight;
131 |
132 | for (int batch = 0; batch < num; ++batch) {
133 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
134 | T _x = x[(batch * chn + plane) * sp + n];
135 | T _y = (_x - _mean) * mul + _bias;
136 |
137 | x[(batch * chn + plane) * sp + n] = _y;
138 | }
139 | }
140 | }
141 |
142 | at::Tensor forward_cuda(at::Tensor x, at::Tensor mean, at::Tensor var, at::Tensor weight, at::Tensor bias,
143 | bool affine, float eps) {
144 | CHECK_INPUT(x);
145 | CHECK_INPUT(mean);
146 | CHECK_INPUT(var);
147 | CHECK_INPUT(weight);
148 | CHECK_INPUT(bias);
149 |
150 | // Extract dimensions
151 | int64_t num, chn, sp;
152 | get_dims(x, num, chn, sp);
153 |
154 | // Run kernel
155 | dim3 blocks(chn);
156 | dim3 threads(getNumThreads(sp));
157 | AT_DISPATCH_FLOATING_TYPES(x.type(), "forward_cuda", ([&] {
158 |     forward_kernel<scalar_t><<<blocks, threads>>>(
159 |         x.data<scalar_t>(),
160 |         mean.data<scalar_t>(),
161 |         var.data<scalar_t>(),
162 |         weight.data<scalar_t>(),
163 |         bias.data<scalar_t>(),
164 | affine, eps, num, chn, sp);
165 | }));
166 |
167 | return x;
168 | }
169 |
170 | /***********
171 | * edz_eydz
172 | ***********/
173 |
174 | template<typename T>
175 | __global__ void edz_eydz_kernel(const T *z, const T *dz, const T *weight, const T *bias,
176 | T *edz, T *eydz, bool affine, float eps, int num, int chn, int sp) {
177 | int plane = blockIdx.x;
178 |
179 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
180 | T _bias = affine ? bias[plane] : 0.f;
181 |
182 |   Pair<T> res = reduce<Pair<T>, GradOp<T>>(GradOp<T>(_weight, _bias, z, dz, chn, sp), plane, num, chn, sp);
183 | __syncthreads();
184 |
185 | if (threadIdx.x == 0) {
186 | edz[plane] = res.v1;
187 | eydz[plane] = res.v2;
188 | }
189 | }
190 |
191 | std::vector<at::Tensor> edz_eydz_cuda(at::Tensor z, at::Tensor dz, at::Tensor weight, at::Tensor bias,
192 | bool affine, float eps) {
193 | CHECK_INPUT(z);
194 | CHECK_INPUT(dz);
195 | CHECK_INPUT(weight);
196 | CHECK_INPUT(bias);
197 |
198 | // Extract dimensions
199 | int64_t num, chn, sp;
200 | get_dims(z, num, chn, sp);
201 |
202 | auto edz = at::empty(z.type(), {chn});
203 | auto eydz = at::empty(z.type(), {chn});
204 |
205 | // Run kernel
206 | dim3 blocks(chn);
207 | dim3 threads(getNumThreads(sp));
208 | AT_DISPATCH_FLOATING_TYPES(z.type(), "edz_eydz_cuda", ([&] {
209 |     edz_eydz_kernel<scalar_t><<<blocks, threads>>>(
210 |         z.data<scalar_t>(),
211 |         dz.data<scalar_t>(),
212 |         weight.data<scalar_t>(),
213 |         bias.data<scalar_t>(),
214 |         edz.data<scalar_t>(),
215 |         eydz.data<scalar_t>(),
216 | affine, eps, num, chn, sp);
217 | }));
218 |
219 | return {edz, eydz};
220 | }
221 |
222 | /***********
223 | * backward
224 | ***********/
225 |
226 | template<typename T>
227 | __global__ void backward_kernel(const T *z, const T *dz, const T *var, const T *weight, const T *bias, const T *edz,
228 | const T *eydz, T *dx, T *dweight, T *dbias,
229 | bool affine, float eps, int num, int chn, int sp) {
230 | int plane = blockIdx.x;
231 |
232 | T _weight = affine ? abs(weight[plane]) + eps : 1.f;
233 | T _bias = affine ? bias[plane] : 0.f;
234 | T _var = var[plane];
235 | T _edz = edz[plane];
236 | T _eydz = eydz[plane];
237 |
238 | T _mul = _weight * rsqrt(_var + eps);
239 | T count = T(num * sp);
240 |
241 | for (int batch = 0; batch < num; ++batch) {
242 | for (int n = threadIdx.x; n < sp; n += blockDim.x) {
243 | T _dz = dz[(batch * chn + plane) * sp + n];
244 | T _y = (z[(batch * chn + plane) * sp + n] - _bias) / _weight;
245 |
246 | dx[(batch * chn + plane) * sp + n] = (_dz - _edz / count - _y * _eydz / count) * _mul;
247 | }
248 | }
249 |
250 | if (threadIdx.x == 0) {
251 | if (affine) {
252 | dweight[plane] = weight[plane] > 0 ? _eydz : -_eydz;
253 | dbias[plane] = _edz;
254 | }
255 | }
256 | }
257 |
258 | std::vector<at::Tensor> backward_cuda(at::Tensor z, at::Tensor dz, at::Tensor var, at::Tensor weight, at::Tensor bias,
259 | at::Tensor edz, at::Tensor eydz, bool affine, float eps) {
260 | CHECK_INPUT(z);
261 | CHECK_INPUT(dz);
262 | CHECK_INPUT(var);
263 | CHECK_INPUT(weight);
264 | CHECK_INPUT(bias);
265 | CHECK_INPUT(edz);
266 | CHECK_INPUT(eydz);
267 |
268 | // Extract dimensions
269 | int64_t num, chn, sp;
270 | get_dims(z, num, chn, sp);
271 |
272 | auto dx = at::zeros_like(z);
273 | auto dweight = at::zeros_like(weight);
274 | auto dbias = at::zeros_like(bias);
275 |
276 | // Run kernel
277 | dim3 blocks(chn);
278 | dim3 threads(getNumThreads(sp));
279 | AT_DISPATCH_FLOATING_TYPES(z.type(), "backward_cuda", ([&] {
280 |     backward_kernel<scalar_t><<<blocks, threads>>>(
281 |         z.data<scalar_t>(),
282 |         dz.data<scalar_t>(),
283 |         var.data<scalar_t>(),
284 |         weight.data<scalar_t>(),
285 |         bias.data<scalar_t>(),
286 |         edz.data<scalar_t>(),
287 |         eydz.data<scalar_t>(),
288 |         dx.data<scalar_t>(),
289 |         dweight.data<scalar_t>(),
290 |         dbias.data<scalar_t>(),
291 | affine, eps, num, chn, sp);
292 | }));
293 |
294 | return {dx, dweight, dbias};
295 | }
296 |
297 | /**************
298 | * activations
299 | **************/
300 |
301 | template<typename T>
302 | inline void leaky_relu_backward_impl(T *z, T *dz, float slope, int64_t count) {
303 | // Create thrust pointers
304 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
305 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
306 |
307 | thrust::transform_if(th_dz, th_dz + count, th_z, th_dz,
308 | [slope] __device__ (const T& dz) { return dz * slope; },
309 | [] __device__ (const T& z) { return z < 0; });
310 | thrust::transform_if(th_z, th_z + count, th_z,
311 | [slope] __device__ (const T& z) { return z / slope; },
312 | [] __device__ (const T& z) { return z < 0; });
313 | }
314 |
315 | void leaky_relu_backward_cuda(at::Tensor z, at::Tensor dz, float slope) {
316 | CHECK_INPUT(z);
317 | CHECK_INPUT(dz);
318 |
319 | int64_t count = z.numel();
320 |
321 | AT_DISPATCH_FLOATING_TYPES(z.type(), "leaky_relu_backward_cuda", ([&] {
322 |     leaky_relu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), slope, count);
323 | }));
324 | }
325 |
326 | template<typename T>
327 | inline void elu_backward_impl(T *z, T *dz, int64_t count) {
328 | // Create thrust pointers
329 |   thrust::device_ptr<T> th_z = thrust::device_pointer_cast(z);
330 |   thrust::device_ptr<T> th_dz = thrust::device_pointer_cast(dz);
331 |
332 | thrust::transform_if(th_dz, th_dz + count, th_z, th_z, th_dz,
333 | [] __device__ (const T& dz, const T& z) { return dz * (z + 1.); },
334 | [] __device__ (const T& z) { return z < 0; });
335 | thrust::transform_if(th_z, th_z + count, th_z,
336 | [] __device__ (const T& z) { return log1p(z); },
337 | [] __device__ (const T& z) { return z < 0; });
338 | }
339 |
340 | void elu_backward_cuda(at::Tensor z, at::Tensor dz) {
341 | CHECK_INPUT(z);
342 | CHECK_INPUT(dz);
343 |
344 | int64_t count = z.numel();
345 |
346 |   AT_DISPATCH_FLOATING_TYPES(z.type(), "elu_backward_cuda", ([&] {
347 |     elu_backward_impl<scalar_t>(z.data<scalar_t>(), dz.data<scalar_t>(), count);
348 | }));
349 | }
350 |
--------------------------------------------------------------------------------
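The activation handling at the bottom follows the same invert-from-output idea: the buffer holds post-activation values, so `leaky_relu_backward_impl` rebuilds the pre-activation of negative entries as `z / slope` while scaling their gradient by `slope`. A NumPy rendering of the two `transform_if` calls:

```python
import numpy as np

slope = 0.01
x = np.array([-2.0, -0.5, 1.0, 3.0])   # pre-activation (never stored on GPU)
z = np.where(x < 0, x * slope, x)      # forward leaky-ReLU, done in place
dz = np.ones_like(z)                   # incoming gradient

neg = z < 0
dz[neg] *= slope                       # first transform_if: scale gradients
z[neg] /= slope                        # second: invert the activation
print(np.allclose(z, x), dz)           # True [0.01 0.01 1.   1.  ]
```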
/hrnet_code/lib/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/hrnet_code/lib/utils/__init__.py
--------------------------------------------------------------------------------
/hrnet_code/lib/utils/distributed.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Jingyi Xie (hsfzxjy@gmail.com)
5 | # ------------------------------------------------------------------------------
6 |
7 | import torch
8 | import torch.distributed as torch_dist
9 |
10 | def is_distributed():
11 | return torch_dist.is_initialized()
12 |
13 | def get_world_size():
14 | if not torch_dist.is_initialized():
15 | return 1
16 | return torch_dist.get_world_size()
17 |
18 | def get_rank():
19 | if not torch_dist.is_initialized():
20 | return 0
21 | return torch_dist.get_rank()
--------------------------------------------------------------------------------
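These helpers degrade gracefully to single-process defaults (world size 1, rank 0) when `torch.distributed` was never initialized, so calling code need not branch on the launch mode. A typical rank-0 guard built on top of them; the flat import path is assumed for the sketch:

```python
from distributed import get_rank, get_world_size  # repo code imports this relatively

def log_master(msg):
    # Executes on exactly one process: rank 0 under torch.distributed,
    # and trivially in a plain single-process run.
    if get_rank() == 0:
        print('[world={}] {}'.format(get_world_size(), msg))
```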
/hrnet_code/lib/utils/modelsummary.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Bin Xiao (Bin.Xiao@microsoft.com)
5 | # Modified by Ke Sun (sunk@mail.ustc.edu.cn)
6 | # ------------------------------------------------------------------------------
7 |
8 | from __future__ import absolute_import
9 | from __future__ import division
10 | from __future__ import print_function
11 |
12 | import os
13 | import logging
14 | from collections import namedtuple
15 |
16 | import torch
17 | import torch.nn as nn
18 |
19 | def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
20 | """
21 | :param model:
22 | :param input_tensors:
23 | :param item_length:
24 | :return:
25 | """
26 |
27 | summary = []
28 |
29 | ModuleDetails = namedtuple(
30 | "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"])
31 | hooks = []
32 | layer_instances = {}
33 |
34 | def add_hooks(module):
35 |
36 | def hook(module, input, output):
37 | class_name = str(module.__class__.__name__)
38 |
39 | instance_index = 1
40 | if class_name not in layer_instances:
41 | layer_instances[class_name] = instance_index
42 | else:
43 | instance_index = layer_instances[class_name] + 1
44 | layer_instances[class_name] = instance_index
45 |
46 | layer_name = class_name + "_" + str(instance_index)
47 |
48 | params = 0
49 |
50 | if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
51 | class_name.find("Linear") != -1:
52 | for param_ in module.parameters():
53 | params += param_.view(-1).size(0)
54 |
55 | flops = "Not Available"
56 | if class_name.find("Conv") != -1 and hasattr(module, "weight"):
57 | flops = (
58 | torch.prod(
59 | torch.LongTensor(list(module.weight.data.size()))) *
60 | torch.prod(
61 | torch.LongTensor(list(output.size())[2:]))).item()
62 | elif isinstance(module, nn.Linear):
63 | flops = (torch.prod(torch.LongTensor(list(output.size()))) \
64 | * input[0].size(1)).item()
65 |
66 | if isinstance(input[0], list):
67 | input = input[0]
68 | if isinstance(output, list):
69 | output = output[0]
70 |
71 | summary.append(
72 | ModuleDetails(
73 | name=layer_name,
74 | input_size=list(input[0].size()),
75 | output_size=list(output.size()),
76 | num_parameters=params,
77 | multiply_adds=flops)
78 | )
79 |
80 | if not isinstance(module, nn.ModuleList) \
81 | and not isinstance(module, nn.Sequential) \
82 | and module != model:
83 | hooks.append(module.register_forward_hook(hook))
84 |
85 | model.eval()
86 | model.apply(add_hooks)
87 |
88 | space_len = item_length
89 |
90 | model(*input_tensors)
91 | for hook in hooks:
92 | hook.remove()
93 |
94 | details = ''
95 | if verbose:
96 | details = "Model Summary" + \
97 | os.linesep + \
98 | "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
99 | ' ' * (space_len - len("Name")),
100 | ' ' * (space_len - len("Input Size")),
101 | ' ' * (space_len - len("Output Size")),
102 | ' ' * (space_len - len("Parameters")),
103 | ' ' * (space_len - len("Multiply Adds (Flops)"))) \
104 | + os.linesep + '-' * space_len * 5 + os.linesep
105 |
106 | params_sum = 0
107 | flops_sum = 0
108 | for layer in summary:
109 | params_sum += layer.num_parameters
110 | if layer.multiply_adds != "Not Available":
111 | flops_sum += layer.multiply_adds
112 | if verbose:
113 | details += "{}{}{}{}{}{}{}{}{}{}".format(
114 | layer.name,
115 | ' ' * (space_len - len(layer.name)),
116 | layer.input_size,
117 | ' ' * (space_len - len(str(layer.input_size))),
118 | layer.output_size,
119 | ' ' * (space_len - len(str(layer.output_size))),
120 | layer.num_parameters,
121 | ' ' * (space_len - len(str(layer.num_parameters))),
122 | layer.multiply_adds,
123 | ' ' * (space_len - len(str(layer.multiply_adds)))) \
124 | + os.linesep + '-' * space_len * 5 + os.linesep
125 |
126 | details += os.linesep \
127 | + "Total Parameters: {:,}".format(params_sum) \
128 | + os.linesep + '-' * space_len * 5 + os.linesep
129 | details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
130 | + os.linesep + '-' * space_len * 5 + os.linesep
131 | details += "Number of Layers" + os.linesep
132 | for layer in layer_instances:
133 | details += "{} : {} layers ".format(layer, layer_instances[layer])
134 |
135 | return details
--------------------------------------------------------------------------------
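`get_model_summary` hooks every leaf module, runs a single forward pass, and tallies parameter counts plus conv/linear multiply-adds. A usage sketch against an off-the-shelf torchvision model (torchvision is only a stand-in here; the repo applies this to its HRNet models):

```python
import torch
import torchvision
from modelsummary import get_model_summary  # flat path assumed for this sketch

model = torchvision.models.resnet18()
dump = get_model_summary(model, torch.rand(1, 3, 224, 224), verbose=True)
print(dump)   # per-layer table plus parameter and multiply-add totals
```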
/hrnet_code/lib/utils/utils.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os
12 | import logging
13 | import time
14 | from pathlib import Path
15 |
16 | import numpy as np
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | class FullModel(nn.Module):
22 | """
23 |   Distribute the loss computation across multiple GPUs to reduce
24 |   the memory cost on the main GPU.
25 | You can check the following discussion.
26 | https://discuss.pytorch.org/t/dataparallel-imbalanced-memory-usage/22551/21
27 | """
28 | def __init__(self, model, loss):
29 | super(FullModel, self).__init__()
30 | self.model = model
31 | self.loss = loss
32 |
33 | def forward(self, inputs, labels, *args, **kwargs):
34 | outputs = self.model(inputs, *args, **kwargs)
35 | loss = self.loss(outputs, labels)
36 | return torch.unsqueeze(loss,0), outputs
37 |
38 | class AverageMeter(object):
39 | """Computes and stores the average and current value"""
40 |
41 | def __init__(self):
42 | self.initialized = False
43 | self.val = None
44 | self.avg = None
45 | self.sum = None
46 | self.count = None
47 |
48 | def initialize(self, val, weight):
49 | self.val = val
50 | self.avg = val
51 | self.sum = val * weight
52 | self.count = weight
53 | self.initialized = True
54 |
55 | def update(self, val, weight=1):
56 | if not self.initialized:
57 | self.initialize(val, weight)
58 | else:
59 | self.add(val, weight)
60 |
61 | def add(self, val, weight):
62 | self.val = val
63 | self.sum += val * weight
64 | self.count += weight
65 | self.avg = self.sum / self.count
66 |
67 | def value(self):
68 | return self.val
69 |
70 | def average(self):
71 | return self.avg
72 |
73 | def create_logger(cfg, cfg_name, phase='train'):
74 | root_output_dir = Path(cfg.OUTPUT_DIR)
75 | # set up logger
76 | if not root_output_dir.exists():
77 | print('=> creating {}'.format(root_output_dir))
78 | root_output_dir.mkdir()
79 |
80 | dataset = cfg.DATASET.DATASET
81 | model = cfg.MODEL.NAME
82 | cfg_name = os.path.basename(cfg_name).split('.')[0]
83 |
84 | final_output_dir = root_output_dir / dataset / cfg_name
85 |
86 | print('=> creating {}'.format(final_output_dir))
87 | final_output_dir.mkdir(parents=True, exist_ok=True)
88 |
89 | time_str = time.strftime('%Y-%m-%d-%H-%M')
90 | log_file = '{}_{}_{}.log'.format(cfg_name, time_str, phase)
91 | final_log_file = final_output_dir / log_file
92 | head = '%(asctime)-15s %(message)s'
93 | logging.basicConfig(filename=str(final_log_file),
94 | format=head)
95 | logger = logging.getLogger()
96 | logger.setLevel(logging.INFO)
97 | console = logging.StreamHandler()
98 | logging.getLogger('').addHandler(console)
99 |
100 | tensorboard_log_dir = Path(cfg.LOG_DIR) / dataset / model / \
101 | (cfg_name + '_' + time_str)
102 | print('=> creating {}'.format(tensorboard_log_dir))
103 | tensorboard_log_dir.mkdir(parents=True, exist_ok=True)
104 |
105 | return logger, str(final_output_dir), str(tensorboard_log_dir)
106 |
107 | def get_confusion_matrix(label, pred, size, num_class, ignore=-1):
108 | """
109 | Calculate the confusion matrix from the given label and prediction.
110 | """
111 | output = pred.cpu().numpy().transpose(0, 2, 3, 1)
112 | seg_pred = np.asarray(np.argmax(output, axis=3), dtype=np.uint8)
113 | seg_gt = np.asarray(
114 | label.cpu().numpy()[:, :size[-2], :size[-1]], dtype=int)  # np.int was removed in NumPy >= 1.24
115 |
116 | ignore_index = seg_gt != ignore
117 | seg_gt = seg_gt[ignore_index]
118 | seg_pred = seg_pred[ignore_index]
119 |
120 | index = (seg_gt * num_class + seg_pred).astype('int32')
121 | label_count = np.bincount(index)
122 | confusion_matrix = np.zeros((num_class, num_class))
123 |
124 | for i_label in range(num_class):
125 | for i_pred in range(num_class):
126 | cur_index = i_label * num_class + i_pred
127 | if cur_index < len(label_count):
128 | confusion_matrix[i_label,
129 | i_pred] = label_count[cur_index]
130 | return confusion_matrix
131 |
132 | def adjust_learning_rate(optimizer, base_lr, max_iters,
133 | cur_iters, power=0.9, nbb_mult=10):
134 | lr = base_lr*((1-float(cur_iters)/max_iters)**(power))
135 | optimizer.param_groups[0]['lr'] = lr
136 | if len(optimizer.param_groups) == 2:
137 | optimizer.param_groups[1]['lr'] = lr * nbb_mult
138 | return lr
--------------------------------------------------------------------------------
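As a usage note, a small sketch (not part of the repo) of how the matrix returned by `get_confusion_matrix` is typically reduced to IoU scores in the evaluation loops: per-class IoU = TP / (ground-truth pixels + predicted pixels - TP).

```python
import numpy as np

# rows are ground-truth classes, columns are predicted classes
confusion_matrix = np.array([[50., 10.],
                             [ 5., 35.]])
pos = confusion_matrix.sum(axis=1)   # ground-truth pixels per class
res = confusion_matrix.sum(axis=0)   # predicted pixels per class
tp = np.diag(confusion_matrix)       # correctly classified pixels
iou = tp / np.maximum(1.0, pos + res - tp)
print('per-class IoU:', iou, 'mean IoU:', iou.mean())
```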
/hrnet_code/requirements.txt:
--------------------------------------------------------------------------------
1 | EasyDict==1.7
2 | opencv-python==3.4.2.17
3 | shapely==1.6.4
4 | Cython
5 | scipy
6 | pandas
7 | pyyaml
8 | json_tricks
9 | scikit-image
10 | yacs>=0.1.5
11 | tensorboardX>=1.6
12 | tqdm
13 | ninja
14 |
--------------------------------------------------------------------------------
/hrnet_code/tools/_init_paths.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | from __future__ import absolute_import
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os.path as osp
12 | import sys
13 |
14 |
15 | def add_path(path):
16 | if path not in sys.path:
17 | sys.path.insert(0, path)
18 |
19 | this_dir = osp.dirname(__file__)
20 |
21 | lib_path = osp.join(this_dir, '..', 'lib')
22 | add_path(lib_path)
23 |
--------------------------------------------------------------------------------
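A minimal sketch of the pattern this module enables: importing it purely for its side effect makes the sibling `lib/` package importable, which is how `tools/test.py` below resolves `models`, `datasets`, and `config`.

```python
# run from inside hrnet_code/tools/
import _init_paths  # noqa: F401 -- side effect: prepends ../lib to sys.path

from config import config
import models
import datasets
```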
/hrnet_code/tools/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------
2 | # Copyright (c) Microsoft
3 | # Licensed under the MIT License.
4 | # Written by Ke Sun (sunk@mail.ustc.edu.cn)
5 | # ------------------------------------------------------------------------------
6 |
7 | import argparse
8 | import os
9 | import pprint
10 | import shutil
11 | import sys
12 |
13 | import logging
14 | import time
15 | import timeit
16 | from pathlib import Path
17 |
18 | import numpy as np
19 |
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 |
24 | import _init_paths
25 | import models
26 | import datasets
27 | from config import config
28 | from config import update_config
29 | from core.function import testval, test
30 | from utils.modelsummary import get_model_summary
31 | from utils.utils import create_logger, FullModel
32 |
33 | def parse_args():
34 | parser = argparse.ArgumentParser(description='Test segmentation network')
35 |
36 | parser.add_argument('--cfg',
37 | help='experiment configure file name',
38 | required=True,
39 | type=str)
40 | parser.add_argument('opts',
41 | help="Modify config options using the command-line",
42 | default=None,
43 | nargs=argparse.REMAINDER)
44 |
45 | args = parser.parse_args()
46 | update_config(config, args)
47 |
48 | return args
49 |
50 | def main():
51 | args = parse_args()
52 |
53 | logger, final_output_dir, _ = create_logger(
54 | config, args.cfg, 'test')
55 |
56 | logger.info(pprint.pformat(args))
57 | logger.info(pprint.pformat(config))
58 |
59 | # cudnn related setting
60 | cudnn.benchmark = config.CUDNN.BENCHMARK
61 | cudnn.deterministic = config.CUDNN.DETERMINISTIC
62 | cudnn.enabled = config.CUDNN.ENABLED
63 |
64 | # build model
65 | if torch.__version__.startswith('1'):
66 | module = eval('models.'+config.MODEL.NAME)
67 | module.BatchNorm2d_class = module.BatchNorm2d = torch.nn.BatchNorm2d
68 | model = eval('models.'+config.MODEL.NAME +
69 | '.get_seg_model')(config)
70 |
71 | dump_input = torch.rand(
72 | (1, 3, config.TRAIN.IMAGE_SIZE[1], config.TRAIN.IMAGE_SIZE[0])
73 | )
74 | logger.info(get_model_summary(model.cuda(), dump_input.cuda()))
75 |
76 | if config.TEST.MODEL_FILE:
77 | model_state_file = config.TEST.MODEL_FILE
78 | else:
79 | model_state_file = os.path.join(final_output_dir, 'final_state.pth')
80 | logger.info('=> loading model from {}'.format(model_state_file))
81 |
82 | pretrained_dict = torch.load(model_state_file)
83 | if 'state_dict' in pretrained_dict:
84 | pretrained_dict = pretrained_dict['state_dict']
85 | model_dict = model.state_dict()
86 | pretrained_dict = {k[6:]: v for k, v in pretrained_dict.items()
87 | if k[6:] in model_dict.keys()}
88 | for k, _ in pretrained_dict.items():
89 | logger.info(
90 | '=> loading {} from pretrained model'.format(k))
91 | model_dict.update(pretrained_dict)
92 | model.load_state_dict(model_dict)
93 |
94 | gpus = list(config.GPUS)
95 | model = nn.DataParallel(model, device_ids=gpus).cuda()
96 |
97 | # prepare data
98 | test_size = (config.TEST.IMAGE_SIZE[1], config.TEST.IMAGE_SIZE[0])
99 | test_dataset = eval('datasets.'+config.DATASET.DATASET)(
100 | root=config.DATASET.ROOT,
101 | list_path=config.DATASET.TEST_SET,
102 | num_samples=None,
103 | num_classes=config.DATASET.NUM_CLASSES,
104 | multi_scale=False,
105 | flip=False,
106 | ignore_label=config.TRAIN.IGNORE_LABEL,
107 | base_size=config.TEST.BASE_SIZE,
108 | crop_size=test_size,
109 | downsample_rate=1)
110 |
111 | testloader = torch.utils.data.DataLoader(
112 | test_dataset,
113 | batch_size=1,
114 | shuffle=False,
115 | num_workers=config.WORKERS,
116 | pin_memory=True)
117 |
118 | start = timeit.default_timer()
119 | if 'val' in config.DATASET.TEST_SET:
120 | mean_IoU, IoU_array, pixel_acc, mean_acc = testval(config,
121 | test_dataset,
122 | testloader,
123 | model)
124 |
125 | msg = 'MeanIU: {: 4.4f}, Pixel_Acc: {: 4.4f}, \
126 | Mean_Acc: {: 4.4f}, Class IoU: '.format(mean_IoU,
127 | pixel_acc, mean_acc)
128 | logging.info(msg)
129 | logging.info(IoU_array)
130 | elif 'test' in config.DATASET.TEST_SET:
131 | test(config,
132 | test_dataset,
133 | testloader,
134 | model,
135 | sv_dir=final_output_dir)
136 |
137 | end = timeit.default_timer()
138 | logger.info('Mins: %d' % int((end-start)/60))
139 | logger.info('Done')
140 |
141 |
142 | if __name__ == '__main__':
143 | main()
144 |
--------------------------------------------------------------------------------
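The checkpoint-loading block in `main()` above follows a common partial-loading pattern: strip the `model.` prefix that `FullModel` adds to every key during training, then keep only entries that exist in the target network. A standalone sketch (the helper name is hypothetical):

```python
import torch

def load_filtered_checkpoint(model, ckpt_path):
    # load to CPU first so this works without a GPU
    pretrained = torch.load(ckpt_path, map_location='cpu')
    if 'state_dict' in pretrained:
        pretrained = pretrained['state_dict']
    model_dict = model.state_dict()
    # k[6:] drops the 'model.' prefix added by FullModel during training
    filtered = {k[6:]: v for k, v in pretrained.items() if k[6:] in model_dict}
    model_dict.update(filtered)
    model.load_state_dict(model_dict)
    return model
```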
/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import os
4 | import os.path as osp
5 | import numpy as np
6 | from PIL import Image
7 |
8 | import argparse
9 |
10 | from lib.model_zoo.texrnet import TexRNet
11 | from lib.model_zoo.hrnet import HRNet_Base
12 | from lib.model_zoo.deeplab import DeepLabv3p_Base
13 | from lib.model_zoo.resnet import ResNet_Dilated_Base
14 | from lib import torchutils
15 |
16 | from tqdm import tqdm
17 |
18 | class TextRNet_HRNet_Wrapper(object):
19 | """
20 | Inference wrapper for TexRNet with an HRNet backbone; runs multi-scale, flip-augmented prediction on a single image.
21 | """
22 | def __init__(self,
23 | device,
24 | pth=None,):
25 | """
26 | Create an inference instance
27 | :param device: device to run the pipeline on (should be 'cuda' if a GPU is accessible)
28 | :param pth: path to the model checkpoint
29 | """
30 | self.model = self.make_model(pth)
31 | self.model.eval()
32 | self.model = self.model.to(device)
33 | self.device = device
34 |
35 | @staticmethod
36 | def make_model(pth=None):
37 | backbone = HRNet_Base(
38 | oc_n=720,
39 | align_corners=True,
40 | ignore_label=999,
41 | stage1_para={
42 | 'BLOCK' : 'BOTTLENECK',
43 | 'FUSE_METHOD' : 'SUM',
44 | 'NUM_BLOCKS' : [4],
45 | 'NUM_BRANCHES': 1,
46 | 'NUM_CHANNELS': [64],
47 | 'NUM_MODULES' : 1 },
48 | stage2_para={
49 | 'BLOCK' : 'BASIC',
50 | 'FUSE_METHOD' : 'SUM',
51 | 'NUM_BLOCKS' : [4, 4],
52 | 'NUM_BRANCHES': 2,
53 | 'NUM_CHANNELS': [48, 96],
54 | 'NUM_MODULES' : 1 },
55 | stage3_para={
56 | 'BLOCK' : 'BASIC',
57 | 'FUSE_METHOD' : 'SUM',
58 | 'NUM_BLOCKS' : [4, 4, 4],
59 | 'NUM_BRANCHES': 3,
60 | 'NUM_CHANNELS': [48, 96, 192],
61 | 'NUM_MODULES' : 4 },
62 | stage4_para={
63 | 'BLOCK' : 'BASIC',
64 | 'FUSE_METHOD' : 'SUM',
65 | 'NUM_BLOCKS' : [4, 4, 4, 4],
66 | 'NUM_BRANCHES': 4,
67 | 'NUM_CHANNELS': [48, 96, 192, 384],
68 | 'NUM_MODULES' : 3 },
69 | final_conv_kernel = 1,
70 | )
71 |
72 | model = TexRNet(
73 | bbn_name='hrnet',
74 | bbn=backbone,
75 | ic_n=720,
76 | rfn_c_n=[725, 64, 64],
77 | sem_n=2,
78 | conv_type='conv',
79 | bn_type='bn',
80 | relu_type='relu',
81 | align_corners=True,
82 | ignore_label=None,
83 | bias_att_type='cossim',
84 | ineval_output_argmax=False,
85 | )
86 | if pth is not None:
87 | paras = torch.load(pth, map_location=torch.device('cpu'))
88 | new_paras = model.state_dict()
89 | new_paras.update(paras)
90 | model.load_state_dict(new_paras)
91 | return model
92 |
93 | def process(self, pil_image):
94 | im = np.array(pil_image.convert("RGB"))
95 | im = im/255
96 | im = im - np.array([0.485, 0.456, 0.406])
97 | im = im / np.array([0.229, 0.224, 0.225])
98 | im = np.transpose(im, (2, 0, 1))[None]
99 | im = torch.FloatTensor(im).to(self.device)
100 |
101 | # Unwrap the model if it is wrapped by torch DDP / DataParallel
102 | netm = getattr(self.model, 'module', self.model)
103 | _, _, oh, ow = im.shape
104 | ac = True
105 |
106 | prfnc_ms, pcount_ms = {}, {}
107 |
108 | for mstag, mssize in [
109 | ['0.75x', 385],
110 | ['1.00x', 513],
111 | ['1.25x', 641],
112 | ['1.50x', 769],
113 | ['1.75x', 897],
114 | ['2.00x', 1025],
115 | ['2.25x', 1153],
116 | ['2.50x', 1281], ]:
117 | # by area
118 | ratio = np.sqrt(mssize**2 / (oh*ow))
119 | th, tw = int(oh*ratio), int(ow*ratio)
120 | tw = tw//32*32+1
121 | th = th//32*32+1
122 |
123 | imi = {
124 | 'nofp' : torchutils.interpolate_2d(
125 | size=(th, tw), mode='bilinear',
126 | align_corners=ac)(im)}
127 | imi['flip'] = torch.flip(imi['nofp'], dims=[-1])
128 |
129 | for fliptag, imii in imi.items():
130 | with torch.no_grad():
131 | pred = netm(imii)
132 | psem = torchutils.interpolate_2d(
133 | size=(oh, ow),
134 | mode='bilinear', align_corners=ac)(pred['predsem'])
135 | prfn = torchutils.interpolate_2d(
136 | size=(oh, ow),
137 | mode='bilinear', align_corners=ac)(pred['predrfn'])
138 |
139 | if fliptag == 'flip':
140 | psem = torch.flip(psem, dims=[-1])
141 | prfn = torch.flip(prfn, dims=[-1])
142 | elif fliptag == 'nofp':
143 | pass
144 | else:
145 | raise ValueError
146 |
147 | try:
148 | prfnc_ms[mstag] += prfn
149 | pcount_ms[mstag] += 1
150 | except:
151 | prfnc_ms[mstag] = prfn
152 | pcount_ms[mstag] = 1
153 |
154 | pred = sum([pi for pi in prfnc_ms.values()])
155 | pred /= sum([ni for ni in pcount_ms.values()])
156 | pred = torch.argmax(pred, dim=1)  # argmax over the multi-scale average; argmax(psem) would discard it
157 | pred = pred[0].cpu().detach().numpy()
158 | pred = (pred * 255).astype(np.uint8)
159 | return Image.fromarray(pred)
160 |
161 | class TextRNet_Deeplab_Wrapper(TextRNet_HRNet_Wrapper):
162 | @staticmethod
163 | def make_model(pth=None):
164 | raise NotImplementedError
165 | # resnet = ResNet_Dilated_Base(
166 | # block =
167 | # layer_n =
168 | # )
169 |
170 | # model = TexRNet(
171 | # bbn_name='hrnet',
172 | # bbn=backbone,
173 | # ic_n=720,
174 | # rfn_c_n=[725, 64, 64],
175 | # sem_n=2,
176 | # conv_type='conv',
177 | # bn_type='bn',
178 | # relu_type='relu',
179 | # align_corners=True,
180 | # ignore_label=None,
181 | # bias_att_type='cossim',
182 | # ineval_output_argmax=False,
183 | # )
184 | # if pth is not None:
185 | # paras = torch.load(pth, map_location=torch.device('cpu'))
186 | # new_paras = model.state_dict()
187 | # new_paras.update(paras)
188 | # model.load_state_dict(new_paras)
189 | # return model
190 |
191 | if __name__ == "__main__":
192 | parser = argparse.ArgumentParser()
193 | parser.add_argument("--input", type=str, required=True, help="input folder or a single input file")
194 | parser.add_argument("--output", type=str, required=True, help="output folder or a single output file")
195 | parser.add_argument("--method", '-m', type=str, default='textrnet_hrnet')
196 | args = parser.parse_args()
197 |
198 | if osp.isdir(args.input):
199 | if not osp.exists(args.output):
200 | os.makedirs(args.output)
201 | assert osp.isdir(args.output), \
202 | "When --input is a directory, --output must be a directory!"
203 | elif osp.isfile(args.input):
204 | assert not osp.isdir(args.output), \
205 | "When --input is a file, --output must be a file!"
206 | else:
207 | assert False, "No such input!"
208 |
209 | assert args.input != args.output, \
210 | "--input and --output points to the same location, "\
211 | "this is not allowed because it will override the input files."
212 |
213 | if args.method == 'textrnet_hrnet':
214 | wrapper = TextRNet_HRNet_Wrapper
215 | model_path = 'pretrained/texrnet_hrnet.pth'
216 | elif args.method == 'textrnet_deeplab':
217 | wrapper = TextRNet_Deeplab_Wrapper
218 | model_path = 'pretrained/texrnet_deeplab.pth'
219 | else:
220 | assert False, 'No such model.'
221 |
222 | enl = wrapper(torch.device("cuda:0"), model_path)
223 |
224 | if osp.isfile(args.input):
225 | imgs = [args.input]
226 | outs = [args.output]
227 | else:
228 | imgs = sorted(os.listdir(args.input))
229 | outs = [
230 | osp.join(args.output, '{}.png'.format(osp.splitext(fi)[0]))
231 | for fi in imgs
232 | ]
233 | imgs = [osp.join(args.input, fi) for fi in imgs]
234 |
235 | for fin, fout in tqdm(zip(imgs, outs), total=len(imgs)):
236 | x = Image.open(fin).convert('RGB')
237 | y = enl.process(x)
238 | y.save(fout)
239 |
--------------------------------------------------------------------------------
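A minimal programmatic usage sketch for the wrapper above; the checkpoint path is the one `__main__` expects under `pretrained/`, and the image filenames are hypothetical:

```python
import torch
from PIL import Image
from inference import TextRNet_HRNet_Wrapper

wrapper = TextRNet_HRNet_Wrapper(torch.device('cuda:0'),
                                 pth='pretrained/texrnet_hrnet.pth')
mask = wrapper.process(Image.open('example.jpg'))  # hypothetical input image
mask.save('example_mask.png')                      # 0/255 text mask
```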
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SHI-Labs/Rethinking-Text-Segmentation/223f7ffc822c345ce1a7c0eb3d4fac58a43d6a3a/lib/__init__.py
--------------------------------------------------------------------------------
/lib/cfg_helper.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import shutil
4 | import copy
5 | import time
6 | from easydict import EasyDict as edict
7 | import pprint
8 | import numpy as np
9 | import torch
10 | import matplotlib
11 | import argparse
12 |
13 | from configs.cfg_base import cfg_train, cfg_test
14 | import json
15 |
16 | def singleton(class_):
17 | instances = {}
18 | def getinstance(*args, **kwargs):
19 | if class_ not in instances:
20 | instances[class_] = class_(*args, **kwargs)
21 | return instances[class_]
22 | return getinstance
23 |
24 | @singleton
25 | class cfg_unique_holder(object):
26 | def __init__(self):
27 | self.cfg = None
28 | # this is used to track the main code files.
29 | self.code = set()
30 | def save_cfg(self, cfg):
31 | self.cfg = copy.deepcopy(cfg)
32 | def add_code(self, code):
33 | """
34 | Record that a new main code file is reached
35 | by adding its name.
36 | """
37 | self.code.add(code)
38 |
39 | def get_experiment_id():
40 | time.sleep(0.5)
41 | return int(time.time()*100)
42 |
43 | def set_debug_cfg(cfg, istrain=True):
44 | if istrain:
45 | cfg.EXPERIMENT_ID = 999999999999
46 | cfg.TRAIN.SAVE_INIT_MODEL = False
47 | cfg.TRAIN.COMMENT = "Debug"
48 | cfg.LOG_DIR = osp.join(
49 | cfg.MISC_DIR,
50 | '{}_{}'.format(cfg.MODEL.MODEL_NAME, cfg.DATA.DATASET_NAME),
51 | str(cfg.EXPERIMENT_ID))
52 | cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'train.log')
53 | cfg.TRAIN.BATCH_SIZE = None
54 | cfg.TRAIN.BATCH_SIZE_PER_GPU = 2
55 | else:
56 | cfg.LOG_DIR = cfg.LOG_DIR.replace(cfg.TEST.SUB_DIR, 'debug')
57 | cfg.TEST.SUB_DIR = 'debug'
58 | cfg.LOG_FILE = osp.join(
59 | cfg.LOG_DIR, 'eval.log')
60 | cfg.TEST.BATCH_SIZE = None
61 | cfg.TEST.BATCH_SIZE_PER_GPU = 1
62 |
63 | cfg.DATA.NUM_WORKERS = None
64 | cfg.DATA.NUM_WORKERS_PER_GPU = 0
65 | cfg.MATPLOTLIB_MODE = 'TKAgg'
66 | return cfg
67 |
68 | def experiment_folder(cfg,
69 | isnew=False,
70 | sig=['nosig'],
71 | mdds_override=None,
72 | **kwargs):
73 | """
74 | Args:
75 | cfg: easydict,
76 | the config easydict
77 | isnew: bool,
78 | whether this is a new folder or not
79 | True, create a path using exid and sig
80 | False, find a path based on exid and refpath
81 | sig: [sig1, ...] array of str
82 | when isnew == True, these are the signatures
83 | joined as [exid]_[sig1]_.._[sign];
84 | signatures after (and including) 'hided' will be
85 | hidden from the name.
86 | mdds_override: str or None,
87 | the override folder for [modelname]_[dataset],
88 | None, no override
89 | Returns:
90 | workdir: str,
91 | the absolute path to the folder.
92 | """
93 | if mdds_override is None:
94 | refdir = osp.join(
95 | cfg.MISC_DIR,
96 | '{}_{}'.format(
97 | cfg.MODEL.MODEL_NAME, cfg.DATA.DATASET_NAME),)
98 | else:
99 | refdir = osp.abspath(osp.join(
100 | cfg.MISC_DIR, mdds_override))
101 |
102 | if isnew:
103 | try:
104 | hided_after = sig.index('hided')
105 | except:
106 | hided_after = len(sig)
107 | sig = sig[0:hided_after]
108 | workdir = '_'.join([str(cfg.EXPERIMENT_ID)] + sig)
109 | return osp.join(refdir, workdir)
110 | else:
111 | for d in os.listdir(refdir):
112 | if not d.find(str(cfg.EXPERIMENT_ID))==0:
113 | continue
114 | if not osp.isdir(osp.join(refdir, d)):
115 | continue
116 | try:
117 | workdir = osp.join(
118 | refdir, d, cfg.TEST.SUB_DIR)
119 | except:
120 | workdir = osp.join(
121 | refdir, d)
122 | return workdir
123 | raise ValueError
124 |
125 | def get_experiment_folder(exid,
126 | path,
127 | full_path=False,
128 | **kwargs):
129 | """
130 | Args:
131 | exid: int,
132 | experiment ID
133 | path: path,
134 | the base folder to search...
135 | each candidate folder should be named like [exid]_[signature].
136 | full_path: bool,
137 | whether return the full path or not.
138 | """
139 | for d in os.listdir(path):
140 | if d.find(str(exid))==0:
141 | if osp.isdir(osp.join(path, d)):
142 | if not full_path:
143 | return d
144 | else:
145 | return osp.abspath(osp.join(path, d))
146 | raise ValueError
147 |
148 | def set_experiment_folder(exid,
149 | signature,
150 | **kwargs):
151 | """
152 | Args:
153 | exid: experiment ID
154 | signature: string or array of strings giving the tags appended after exid
155 | to form an experiment folder name.
156 | """
157 | if isinstance(signature, str):
158 | signature = [signature]
159 | return '_'.join([str(exid)] + signature)
160 |
161 | def hided_sig_to_str(sig):
162 | """
163 | Args:
164 | sig: [] of str,
165 | Returns:
166 | out: str
167 | If sig is [..., 'hided', 'sig1', 'sig2']
168 | out = hided: sig1_sig2_...
169 | If sig does not contain 'hided',
170 | out = None
171 | """
172 | try:
173 | hided_after = sig.index('hided')
174 | except:
175 | return None
176 |
177 | return 'hided: '+'_'.join(sig[hided_after+1:])
178 |
179 | def common_initiates(cfg):
180 | if cfg.GPU_DEVICE != 'all':
181 | os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(
182 | [str(gid) for gid in cfg.GPU_DEVICE])
183 | cfg.GPU_COUNT = len(cfg.GPU_DEVICE)
184 | else:
185 | cfg.GPU_COUNT = torch.cuda.device_count()
186 |
187 | if 'TRAIN' in cfg:
188 | if (cfg.TRAIN.BATCH_SIZE is None) and \
189 | (cfg.TRAIN.BATCH_SIZE_PER_GPU is None):
190 | raise ValueError
191 | elif (cfg.TRAIN.BATCH_SIZE is not None) and \
192 | (cfg.TRAIN.BATCH_SIZE_PER_GPU is not None):
193 | if cfg.TRAIN.BATCH_SIZE != \
194 | cfg.TRAIN.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT:
195 | raise ValueError
196 | elif cfg.TRAIN.BATCH_SIZE is None:
197 | cfg.TRAIN.BATCH_SIZE = \
198 | cfg.TRAIN.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT
199 | else:
200 | cfg.TRAIN.BATCH_SIZE_PER_GPU = \
201 | cfg.TRAIN.BATCH_SIZE // cfg.GPU_COUNT
202 | if 'TEST' in cfg:
203 | if (cfg.TEST.BATCH_SIZE is None) and \
204 | (cfg.TEST.BATCH_SIZE_PER_GPU is None):
205 | raise ValueError
206 | elif (cfg.TEST.BATCH_SIZE is not None) and \
207 | (cfg.TEST.BATCH_SIZE_PER_GPU is not None):
208 | if cfg.TEST.BATCH_SIZE != \
209 | cfg.TEST.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT:
210 | raise ValueError
211 | elif cfg.TEST.BATCH_SIZE is None:
212 | cfg.TEST.BATCH_SIZE = \
213 | cfg.TEST.BATCH_SIZE_PER_GPU * cfg.GPU_COUNT
214 | else:
215 | cfg.TEST.BATCH_SIZE_PER_GPU = \
216 | cfg.TEST.BATCH_SIZE // cfg.GPU_COUNT
217 |
218 | if (cfg.DATA.NUM_WORKERS is None) and \
219 | (cfg.DATA.NUM_WORKERS_PER_GPU is None):
220 | raise ValueError
221 | elif (cfg.DATA.NUM_WORKERS is not None) and \
222 | (cfg.DATA.NUM_WORKERS_PER_GPU is not None):
223 | if cfg.DATA.NUM_WORKERS != \
224 | cfg.DATA.NUM_WORKERS_PER_GPU * cfg.GPU_COUNT:
225 | raise ValueError
226 | elif cfg.DATA.NUM_WORKERS is None:
227 | cfg.DATA.NUM_WORKERS = \
228 | cfg.DATA.NUM_WORKERS_PER_GPU * cfg.GPU_COUNT
229 | else:
230 | cfg.DATA.NUM_WORKERS_PER_GPU = \
231 | cfg.DATA.NUM_WORKERS // cfg.GPU_COUNT
232 |
233 | cfg.MAIN_CODE_PATH = osp.abspath(osp.join(
234 | osp.dirname(__file__), '..'))
235 | cfg.MAIN_CODE = list(cfg_unique_holder().code)
236 |
237 | cfg.TORCH_VERSION = torch.__version__
238 |
239 | pprint.pprint(cfg)
240 | if cfg.LOG_FILE is not None:
241 | if not osp.exists(osp.dirname(cfg.LOG_FILE)):
242 | os.makedirs(osp.dirname(cfg.LOG_FILE))
243 | with open(cfg.LOG_FILE, 'w') as f:
244 | pprint.pprint(cfg, f)
245 | with open(osp.join(cfg.LOG_DIR, 'config.json'), 'w') as f:
246 | json.dump(cfg, f, indent=4)
247 |
248 | # step3.1 code saving
249 | if cfg.SAVE_CODE:
250 | codedir = osp.join(cfg.LOG_DIR, 'code')
251 | if osp.exists(codedir):
252 | shutil.rmtree(codedir)
253 | for d in ['configs', 'lib']:
254 | fromcodedir = osp.abspath(
255 | osp.join(cfg.MAIN_CODE_PATH, d))
256 | tocodedir = osp.join(codedir, d)
257 | shutil.copytree(
258 | fromcodedir, tocodedir,
259 | ignore=shutil.ignore_patterns('*__pycache__*', '*build*'))
260 | for codei in cfg.MAIN_CODE:
261 | shutil.copy(osp.join(cfg.MAIN_CODE_PATH, codei), codedir)
262 |
263 | # step3.2
264 | if cfg.RND_SEED is None:
265 | pass
266 | elif isinstance(cfg.RND_SEED, int):
267 | np.random.seed(cfg.RND_SEED)
268 | torch.manual_seed(cfg.RND_SEED)
269 | else:
270 | raise ValueError
271 |
272 | # step3.3
273 | if cfg.RND_RECORDING:
274 | rnduh().reset(osp.join(cfg.LOG_DIR, 'rcache'), None)
275 | if not isinstance(cfg.RND_SEED, str):
276 | pass
277 | elif osp.isfile(cfg.RND_SEED):
278 | print('[Warning]: RND_SEED is a file, but RND_RECORDING is on, which disables the file.')
279 | # raise ValueError
280 |
281 | # step3.4
282 | try:
283 | if cfg.MATPLOTLIB_MODE is not None:
284 | matplotlib.use(cfg.MATPLOTLIB_MODE)
285 | except:
286 | pass
287 |
288 | return cfg
289 |
290 | def common_argparse(extra_parsing_f=None):
291 | """
292 | Outputs:
293 | cfg: edict,
294 | 'DEBUG'
295 | 'GPU_DEVICE'
296 | 'ISTRAIN'
297 | exid: [] of int -or- None
298 | experiment id followed by --eval
299 | None so do the regular training
300 | """
301 |
302 | parser = argparse.ArgumentParser()
303 | parser.add_argument('--debug', action='store_true')
304 | parser.add_argument('--eval', nargs='+', type=int)
305 | parser.add_argument('--gpu', nargs='+', type=int)
306 | parser.add_argument('--port', type=int)
307 |
308 | if extra_parsing_f is not None:
309 | args, cfg = extra_parsing_f(parser)
310 | else:
311 | args = parser.parse_args()
312 | cfg = edict()
313 |
314 | cfg.DEBUG = args.debug
315 | try:
316 | cfg.GPU_DEVICE = list(args.gpu)
317 | except:
318 | pass
319 |
320 | try:
321 | eval_exid = list(args.eval)
322 | except:
323 | eval_exid = None
324 |
325 | try:
326 | port = int(args.port)
327 | cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(port)
328 | except:
329 | pass
330 |
331 | return cfg, eval_exid
332 |
--------------------------------------------------------------------------------
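A self-contained sketch of the `@singleton` decorator defined above: every construction after the first returns the same instance, which is how `cfg_unique_holder` shares one config across all modules.

```python
def singleton(class_):
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance

@singleton
class Holder(object):
    def __init__(self):
        self.cfg = None

a, b = Holder(), Holder()
a.cfg = {'EXPERIMENT_ID': 999999999999}
assert a is b and b.cfg['EXPERIMENT_ID'] == 999999999999
```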
/lib/data_factory/__init__.py:
--------------------------------------------------------------------------------
1 | from .ds_base import get_dataset, collate
2 | from .ds_loader import get_loader
3 | from .ds_transform import get_transform
4 | from .ds_formatter import get_formatter
5 | from .ds_sampler import DistributedSampler
6 |
--------------------------------------------------------------------------------
/lib/data_factory/ds_base.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import torchvision
7 | import copy
8 | import itertools
9 |
10 | import sys
11 | code_dir = osp.abspath(osp.join(osp.dirname(__file__), '..', '..'))
12 | sys.path.append(code_dir)
13 |
14 | from .. import nputils
15 | from ..cfg_helper import cfg_unique_holder as cfguh
16 | from ..log_service import print_log
17 |
18 | class ds_base(torch.utils.data.Dataset):
19 | def __init__(self,
20 | mode,
21 | loader = None,
22 | estimator = None,
23 | transforms = None,
24 | formatter = None,
25 | **kwargs):
26 | self.init_load_info(mode)
27 | self.loader = loader
28 | self.transforms = transforms
29 | self.formatter = formatter
30 |
31 | console_info = '{}: '.format(self.__class__.__name__)
32 | console_info += 'total {} unique images, '.format(len(self.load_info))
33 |
34 | self.load_info = sorted(self.load_info, key=lambda x:x['unique_id'])
35 | if estimator is not None:
36 | self.load_info = estimator(self.load_info)
37 |
38 | console_info += 'total {0} unique images after estimation.'.format(
39 | len(self.load_info))
40 | print_log(console_info)
41 |
42 | for idx, info in enumerate(self.load_info):
43 | info['idx'] = idx
44 |
45 | try:
46 | trysome = cfguh().cfg.DATA.TRY_SOME_SAMPLE
47 | except:
48 | trysome = None
49 |
50 | if trysome is not None:
51 | if isinstance(trysome, str):
52 | trysome = [trysome]
53 | elif isinstance(trysome, (list, tuple)):
54 | trysome = list(trysome)
55 | else:
56 | raise ValueError
57 |
58 | self.load_info = [
59 | infoi for infoi in self.load_info \
60 | if osp.splitext(
61 | osp.basename(infoi['image_path'])
62 | )[0] in trysome
63 | ]
64 | print_log('try {} samples.'.format(len(self.load_info)))
65 |
66 | def init_load_info(self, mode):
67 | # implemented by the subclass
68 | raise ValueError
69 |
70 | def __len__(self):
71 | try:
72 | try_sample = cfguh().cfg.DATA.TRY_SAMPLE
73 | except:
74 | try_sample = None
75 | if try_sample is not None:
76 | return try_sample
77 | return len(self.load_info)
78 |
79 | def __getitem__(self, idx):
80 | element = copy.deepcopy(self.load_info[idx])
81 | element = self.loader(element)
82 | if self.transforms is not None:
83 | element = self.transforms(element)
84 | if self.formatter is not None:
85 | return self.formatter(element)
86 | else:
87 | return element
88 |
89 | def singleton(class_):
90 | instances = {}
91 | def getinstance(*args, **kwargs):
92 | if class_ not in instances:
93 | instances[class_] = class_(*args, **kwargs)
94 | return instances[class_]
95 | return getinstance
96 |
97 | @singleton
98 | class get_dataset(object):
99 | def __init__(self):
100 | self.dataset = {}
101 |
102 | def register(self, dsf):
103 | self.dataset[dsf.__name__] = dsf
104 |
105 | def __call__(self, dsname=None):
106 | if dsname is None:
107 | dsname = cfguh().cfg.DATA.DATASET_NAME
108 |
109 | # the register is in each file
110 | if dsname == 'textseg':
111 | from . import ds_textseg
112 | elif dsname == 'cocots':
113 | from . import ds_cocotext
114 | elif dsname == 'mlt':
115 | from . import ds_mlt
116 | elif dsname == 'icdar13':
117 | from . import ds_icdar13
118 | elif dsname == 'totaltext':
119 | from . import ds_totaltext
120 | elif dsname == 'textssc':
121 | from . import ds_textssc
122 |
123 | return self.dataset[dsname]
124 |
125 | def register():
126 | def wrapper(class_):
127 | get_dataset().register(class_)
128 | return class_
129 | return wrapper
130 |
131 | # some other helpers
132 |
133 | class collate(object):
134 | def __init__(self):
135 | self.default_collate = torch.utils.data._utils.collate.default_collate
136 |
137 | def __call__(self, batch):
138 | elem = batch[0]
139 | if not isinstance(elem, (tuple, list)):
140 | return self.default_collate(batch)
141 |
142 | rv = []
143 | # transposed
144 | for i in zip(*batch):
145 | if isinstance(i[0], list):
146 | if len(i[0]) != 1:
147 | raise ValueError
148 | try:
149 | i = [[self.default_collate(ii).squeeze(0)] for ii in i]
150 | except:
151 | pass
152 | rvi = list(itertools.chain.from_iterable(i))
153 | rv.append(rvi) # list concat
154 | elif i[0] is None:
155 | rv.append(None)
156 | else:
157 | rv.append(self.default_collate(i))
158 | return rv
159 |
--------------------------------------------------------------------------------
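A hypothetical sketch of how a new dataset would plug into the registry above: decorate a `ds_base` subclass with `@register()` and fetch the class back through the `get_dataset()` singleton by name. (Instantiating it needs a populated config via `cfg_unique_holder`, so only the lookup is shown; run from the repo root.)

```python
from lib.data_factory.ds_base import ds_base, register, get_dataset

@register()
class mydataset(ds_base):  # hypothetical dataset name
    def init_load_info(self, mode):
        self.load_info = [{'unique_id': '000', 'image_path': 'img/000.jpg'}]

# get_dataset() returns the registered class by its __name__
assert get_dataset()('mydataset') is mydataset
```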
/lib/data_factory/ds_cocotext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import torchvision
7 | import PIL
8 | import json
9 | import cv2
10 | import copy
11 | PIL.Image.MAX_IMAGE_PIXELS = None
12 |
13 | from .ds_base import ds_base, register as regdataset
14 | from .ds_loader import pre_loader_checkings, register as regloader
15 | from .ds_transform import TBase, have, register as regtrans
16 | from .ds_formatter import register as regformat
17 |
18 | from .. import nputils
19 | from ..cfg_helper import cfg_unique_holder as cfguh
20 | from ..log_service import print_log
21 |
22 | class common(ds_base):
23 | def init_load_info(self, mode):
24 | cfgd = cfguh().cfg.DATA
25 | annofile = osp.join(cfgd.ROOT_DIR, 'coco_text', 'cocotext.v2.json')
26 | with open(annofile, 'r') as f:
27 | annoinfo = json.load(f)
28 |
29 | im_list = [i for _, i in annoinfo['imgs'].items()]
30 | im_train = list(filter(lambda x:x['set'] == 'train', im_list))
31 | im_val = list(filter(lambda x:x['set'] == 'val' , im_list))
32 | im_list = []
33 |
34 | for m in mode.split('+'):
35 | if m == 'train':
36 | im_list += im_train
37 | elif m == 'val':
38 | im_list += im_val
39 | else:
40 | raise ValueError
41 |
42 | self.load_info = []
43 |
44 | for im in im_list:
45 | filename = im['file_name']
46 | path = filename.split('_')[1]
47 | imagepath = osp.join(cfgd.ROOT_DIR, path, filename)
48 | annids = annoinfo['imgToAnns'][str(im['id'])]
49 | info = {
50 | 'unique_id' : filename.split('.')[0],
51 | 'filename' : filename,
52 | 'image_path': imagepath,
53 | 'coco_text_anno' : [annoinfo['anns'][str(i)] for i in annids],
54 | }
55 | self.load_info.append(info)
56 |
57 | @regdataset()
58 | class cocotext(common):
59 | def init_load_info(self, mode):
60 | super().init_load_info(mode)
61 |
62 | @regdataset()
63 | class cocots(common):
64 | def init_load_info(self, mode):
65 | cfgd = cfguh().cfg.DATA
66 | super().init_load_info(mode)
67 |
68 | cocots_annopath = osp.join(cfgd.ROOT_DIR, 'coco_ts_labels')
69 |
70 | info_new = []
71 | for info in self.load_info:
72 | annf = info['filename'].split('.')[0]+'.png'
73 | segpath = osp.join(cocots_annopath, annf)
74 | if osp.exists(segpath):
75 | info = copy.deepcopy(info)
76 | info['seglabel_path'] = segpath
77 | info['bbox_path'] = info['coco_text_anno'] # a hack
78 | info_new.append(info)
79 | self.load_info = info_new
80 |
81 | def get_semantic_classname(self,):
82 | map = {
83 | 0 : 'background' ,
84 | 1 : 'text' ,
85 | }
86 | return map
87 |
--------------------------------------------------------------------------------
/lib/data_factory/ds_formatter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import cv2
7 | # import scipy.ndimage
8 | from PIL import Image
9 | import copy
10 | import gc
11 | import itertools
12 |
13 | from ..cfg_helper import cfg_unique_holder as cfguh
14 | from .. import nputils
15 |
16 | def singleton(class_):
17 | instances = {}
18 | def getinstance(*args, **kwargs):
19 | if class_ not in instances:
20 | instances[class_] = class_(*args, **kwargs)
21 | return instances[class_]
22 | return getinstance
23 |
24 | @singleton
25 | class get_formatter(object):
26 | def __init__(self):
27 | self.formatter = {}
28 |
29 | def register(self, formatf, kwmap, kwfix):
30 | self.formatter[formatf.__name__] = [formatf, kwmap, kwfix]
31 |
32 | def __call__(self, format_name=None):
33 | cfgd = cfguh().cfg.DATA
34 | if format_name is None:
35 | format_name = cfgd.FORMATTER
36 |
37 | formatf, kwmap, kwfix = self.formatter[format_name]
38 | kw = {k1:cfgd[k2] for k1, k2 in kwmap.items()}
39 | kw.update(kwfix)
40 | return formatf(**kw)
41 |
42 | def register(kwmap={}, kwfix={}):
43 | def wrapper(class_):
44 | get_formatter().register(class_, kwmap, kwfix)
45 | return class_
46 | return wrapper
47 |
48 | @register()
49 | class SemanticFormatter(object):
50 | def __init__(self,
51 | **kwargs):
52 | pass
53 |
54 | def __call__(self, element):
55 | im = element['image']
56 | semlabel = element['seglabel']
57 | return im.astype(np.float32), semlabel.astype(int), element['unique_id']
58 |
--------------------------------------------------------------------------------
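A small sketch (run from the repo root with the requirements installed) of what `SemanticFormatter` emits per sample: a float32 image, an integer label map, and the unique id; `collate()` in `ds_base.py` later batches these tuples.

```python
import numpy as np
from lib.data_factory.ds_formatter import SemanticFormatter

element = {'image': np.zeros((3, 4, 4), np.uint8),
           'seglabel': np.zeros((4, 4), np.uint8),
           'unique_id': 'demo_000'}
im, sem, uid = SemanticFormatter()(element)
print(im.dtype, sem.dtype, uid)  # float32, platform int, demo_000
```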
/lib/data_factory/ds_loader.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import numpy as np
3 | import numpy.random as npr
4 | import PIL
5 | import cv2
6 |
7 | import torch
8 | import torchvision
9 | import xml.etree.ElementTree as ET
10 | import json
11 | import copy
12 |
13 | from ..cfg_helper import cfg_unique_holder as cfguh
14 | from .. import nputils
15 |
16 | def singleton(class_):
17 | instances = {}
18 | def getinstance(*args, **kwargs):
19 | if class_ not in instances:
20 | instances[class_] = class_(*args, **kwargs)
21 | return instances[class_]
22 | return getinstance
23 |
24 | @singleton
25 | class get_loader(object):
26 | def __init__(self):
27 | self.loader = {}
28 |
29 | def register(self, loadf, kwmap, kwfix):
30 | self.loader[loadf.__name__] = [loadf, kwmap, kwfix]
31 |
32 | def __call__(self, pipeline=None):
33 | cfgd = cfguh().cfg.DATA
34 | if pipeline is None:
35 | pipeline = cfgd.LOADER_PIPELINE
36 |
37 | loader = []
38 | for tag in pipeline:
39 | loadf, kwmap, kwfix = self.loader[tag]
40 | kw = {k1:cfgd[k2] for k1, k2 in kwmap.items()}
41 | kw.update(kwfix)
42 | loader.append(loadf(**kw))
43 | if len(loader) == 0:
44 | return None
45 | else:
46 | return compose(loader)
47 |
48 | class compose(object):
49 | def __init__(self, loaders):
50 | self.loaders = loaders
51 |
52 | def __call__(self, element):
53 | for l in self.loaders:
54 | element = l(element)
55 | return element
56 |
57 | def register(kwmap={}, kwfix={}):
58 | def wrapper(class_):
59 | get_loader().register(class_, kwmap, kwfix)
60 | return class_
61 | return wrapper
62 |
63 | def pre_loader_checkings(ltype):
64 | lpath = ltype+'_path'
65 | # cache feature added on 20201021
66 | lcache = ltype+'_cache'
67 | def wrapper(func):
68 | def inner(self, element):
69 | if lcache in element:
70 | # cache feature added on 20201021
71 | data = element[lcache]
72 | else:
73 | if ltype in element:
74 | raise ValueError
75 | if lpath not in element:
76 | raise ValueError
77 |
78 | if element[lpath] is None:
79 | data = None
80 | else:
81 | data = func(self, element[lpath], element)
82 | element[ltype] = data
83 |
84 | if ltype == 'image':
85 | if isinstance(data, np.ndarray):
86 | imsize = data.shape[-2:]
87 | elif isinstance(data, PIL.Image.Image):
88 | imsize = data.size[::-1]
89 | elif data is None:
90 | imsize = None
91 | else:
92 | raise ValueError
93 | element['imsize'] = imsize
94 | element['imsize_current'] = copy.deepcopy(imsize)
95 | return element
96 | return inner
97 | return wrapper
98 |
99 | ###########
100 | # general #
101 | ###########
102 |
103 | @register(
104 | {'backend':'LOAD_BACKEND_IMAGE',
105 | 'is_mc' :'LOAD_IS_MC_IMAGE' ,})
106 | class NumpyImageLoader(object):
107 | def __init__(self,
108 | backend='pil',
109 | is_mc=False,):
110 | self.backend = backend
111 | self.is_mc = is_mc
112 |
113 | @pre_loader_checkings('image')
114 | def __call__(self, path, element):
115 | return self.load(path)
116 |
117 | def load(self, path):
118 | if not self.is_mc:
119 | if self.backend == 'cv2':
120 | data = cv2.imread(
121 | path, cv2.IMREAD_COLOR)[:, :, ::-1]
122 | elif self.backend == 'pil':
123 | data = np.array(PIL.Image.open(
124 | path).convert('RGB'))
125 | else:
126 | raise ValueError
127 | else:
128 | # multichannel loading should not assume the
129 | # image is RGB by default
130 | datai = []
131 | for p in path:
132 | if self.backend == 'cv2':
133 | i = cv2.imread(p)[:, :, ::-1]
134 | elif self.backend == 'pil':
135 | i = np.array(PIL.Image.open(p))
136 | else:
137 | raise ValueError
138 | if len(i.shape) == 2:
139 | i = i[:, :, np.newaxis]
140 | datai.append(i)
141 | data = np.concatenate(datai, axis=2)
142 | return np.transpose(data, (2, 0, 1)).astype(np.uint8)
143 |
144 | @register(
145 | {'backend':'LOAD_BACKEND_IMAGE',
146 | 'is_mc' :'LOAD_IS_MC_IMAGE' ,})
147 | class NumpyImageLoaderWithCache(NumpyImageLoader):
148 | def __init__(self,
149 | backend='pil',
150 | is_mc=False):
151 | super().__init__(backend, is_mc)
152 | self.cache_data = None
153 | self.cache_path = None
154 |
155 | @pre_loader_checkings('image')
156 | def __call__(self, path, element):
157 | if path == self.cache_path:
158 | return self.cache_data
159 | else:
160 | self.cache_data, self.cache_path = super().load(path), path  # also record the path, otherwise the cache never hits
161 | return self.cache_data
162 |
163 | @register(
164 | {'backend':'LOAD_BACKEND_SEGLABEL',
165 | 'is_mc' :'LOAD_IS_MC_SEGLABEL' ,})
166 | class NumpySeglabelLoader(object):
167 | def __init__(self,
168 | backend='pil',
169 | is_mc=False):
170 | self.backend = backend
171 | self.is_mc = is_mc
172 |
173 | @pre_loader_checkings('seglabel')
174 | def __call__(self, path, element):
175 | return self.load(path)
176 |
177 | def load(self, path):
178 | if not self.is_mc:
179 | if self.backend == 'cv2':
180 | data = cv2.imread(
181 | path, cv2.IMREAD_GRAYSCALE)
182 | elif self.backend == 'pil':
183 | data = np.array(PIL.Image.open(path))
184 | else:
185 | raise ValueError
186 | else:
187 | # seglabel doesn't convert to rgb.
188 | datai = []
189 | for p in path:
190 | if self.backend == 'cv2':
191 | i = cv2.imread(
192 | p, cv2.IMREAD_GRAYSCALE)
193 | elif self.backend == 'pil':
194 | i = np.array(PIL.Image.open(p))
195 | else:
196 | raise ValueError
197 | if len(i.shape) == 2:
198 | i = i[:, :, np.newaxis]
199 | datai.append(i)
200 | data = np.concatenate(datai, axis=2)
201 | data = np.transpose(data, (2, 0, 1))
202 | if data.shape[0] == 1:
203 | data = data[0]
204 | return data.astype(np.int32)
205 |
206 | @register(
207 | {'backend':'LOAD_BACKEND_MASK',
208 | 'is_mc' :'LOAD_IS_MC_MASK' ,})
209 | class NumpyMaskLoader(NumpySeglabelLoader):
210 | def __init__(self,
211 | backend='pil',
212 | is_mc=False):
213 | super().__init__(
214 | backend, is_mc)
215 |
216 | @pre_loader_checkings('mask')
217 | def __call__(self, path, element):
218 | data = super().load(path)
219 | return (data!=0).astype(np.uint8)
220 |
221 |
--------------------------------------------------------------------------------
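The `kwmap`/`kwfix` indirection used by `get_loader` above (and by `get_formatter`) deserves a note: at construction time, `kwmap` pulls constructor kwargs out of `cfg.DATA` by key, while `kwfix` supplies constants fixed at registration. A standalone sketch with a stand-in config dict:

```python
kwmap = {'backend': 'LOAD_BACKEND_IMAGE'}  # ctor kwarg <- cfg.DATA key
kwfix = {'is_mc': False}                   # ctor kwarg fixed at register time
cfg_data = {'LOAD_BACKEND_IMAGE': 'pil'}   # stand-in for cfguh().cfg.DATA

kw = {k1: cfg_data[k2] for k1, k2 in kwmap.items()}
kw.update(kwfix)
print(kw)  # {'backend': 'pil', 'is_mc': False} -> NumpyImageLoader(**kw)
```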
/lib/data_factory/ds_mlt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import torchvision
7 | import PIL
8 | import json
9 | import cv2
10 | import copy
11 | PIL.Image.MAX_IMAGE_PIXELS = None
12 |
13 | from .ds_base import ds_base, register as regdataset
14 | from .ds_loader import pre_loader_checkings, register as regloader
15 | from .ds_transform import TBase, have, register as regtrans
16 | from .ds_formatter import register as regformat
17 |
18 | from .. import nputils
19 | from ..cfg_helper import cfg_unique_holder as cfguh
20 | from ..log_service import print_log
21 |
22 | @regdataset()
23 | class mlt(ds_base):
24 | def init_load_info(self, mode):
25 | cfgd = cfguh().cfg.DATA
26 | self.root_dir = cfgd.ROOT_DIR
27 |
28 | trainlist = self.get_trainlist()
29 | vallist = self.get_vallist()
30 |
31 | trainlist = [i for _, i in trainlist.items()]
32 | vallist = [i for _, i in vallist.items() ]
33 |
34 | self.load_info = []
35 | for mi in mode.split('+'):
36 | if mi == 'train':
37 | self.load_info += trainlist
38 | elif mi == 'trainseg':
39 | self.load_info += [
40 | i for i in trainlist if i['seglabel_path'] is not None]
41 | elif mi == 'val':
42 | self.load_info += vallist
43 | elif mi == 'valseg':
44 | self.load_info += [
45 | i for i in vallist if i['seglabel_path'] is not None]
46 | else:
47 | raise ValueError
48 |
49 | def get_trainlist(self):
50 | # get train data
51 | imdir_list = [
52 | 'ch8_training_images_1', 'ch8_training_images_2',
53 | 'ch8_training_images_3', 'ch8_training_images_4',
54 | 'ch8_training_images_5', 'ch8_training_images_6',
55 | 'ch8_training_images_7', 'ch8_training_images_8', ]
56 |
57 | # get image info
58 | trainlist = {}
59 | for di in imdir_list:
60 | for fi in os.listdir(osp.join(self.root_dir, di)):
61 | fname = fi.split('.')[0]
62 | fpath = osp.join(self.root_dir, di, fi)
63 | trainlist[fname] = {
64 | 'unique_id' : '00_train_' + fname,
65 | 'filename' : fi,
66 | 'image_path' : fpath}
67 |
68 | # get bpoly label info
69 | labeldir = 'ch8_training_localization_transcription_gt_v2'
70 | for fi in trainlist.keys():
71 | bpoly_path = osp.join(self.root_dir, labeldir, 'gt_'+fi+'.txt')
72 | if osp.exists(bpoly_path):
73 | trainlist[fi]['bpoly_path'] = bpoly_path
74 | else:
75 | raise ValueError
76 |
77 | # get seglabel info
78 | labeldir = osp.join('MLT_S_labels', 'training_labels')
79 | for fi in trainlist.keys():
80 | seglabel_path = osp.join(self.root_dir, labeldir, fi+'.png')
81 | if osp.exists(seglabel_path):
82 | trainlist[fi]['seglabel_path'] = seglabel_path
83 | else:
84 | trainlist[fi]['seglabel_path'] = None
85 | return trainlist
86 |
87 | def get_vallist(self):
88 | # get train data
89 | imdir_list = [
90 | 'ch8_validation_images', ]
91 |
92 | # get image info
93 | vallist = {}
94 | for di in imdir_list:
95 | for fi in os.listdir(osp.join(self.root_dir, di)):
96 | fname = fi.split('.')[0]
97 | fpath = osp.join(self.root_dir, di, fi)
98 | vallist[fname] = {
99 | 'unique_id' : '01_val_' + fname,
100 | 'filename' : fi,
101 | 'image_path' : fpath}
102 |
103 | # get bpoly label info
104 | labeldir = 'ch8_validation_localization_transcription_gt_v2'
105 | for fi in vallist.keys():
106 | bpoly_path = osp.join(self.root_dir, labeldir, 'gt_'+fi+'.txt')
107 | if osp.exists(bpoly_path):
108 | vallist[fi]['bpoly_path'] = bpoly_path
109 | else:
110 | raise ValueError
111 |
112 | # get seglabel info
113 | labeldir = osp.join('MLT_S_labels', 'validation_labels')
114 | for fi in vallist.keys():
115 | seglabel_path = osp.join(self.root_dir, labeldir, fi+'.png')
116 | if osp.exists(seglabel_path):
117 | vallist[fi]['seglabel_path'] = seglabel_path
118 | else:
119 | vallist[fi]['seglabel_path'] = None
120 | return vallist
121 |
122 | def get_semantic_classname(self,):
123 | map = {
124 | 0 : 'background' ,
125 | 1 : 'text' ,
126 | }
127 | return map
128 |
129 | # ---- loader ----
130 |
131 | @regloader()
132 | class Mlt_SeglabelLoader(object):
133 | def __init__(self):
134 | pass
135 |
136 | @pre_loader_checkings('seglabel')
137 | def __call__(self, path, element):
138 | sem = np.array(PIL.Image.open(path)).astype(int)
139 | return sem
140 |
--------------------------------------------------------------------------------
/lib/data_factory/ds_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import numpy.random as npr
4 | import torch.distributed as dist
5 | import math
6 |
7 | from ..log_service import print_log
8 |
9 | class DistributedSampler(torch.utils.data.Sampler):
10 | def __init__(self,
11 | dataset,
12 | num_replicas=None,
13 | rank=None,
14 | shuffle=True,
15 | extend=False,):
16 | if num_replicas is None:
17 | if not dist.is_available():
18 | raise ValueError
19 | num_replicas = dist.get_world_size()
20 | if rank is None:
21 | if not dist.is_available():
22 | raise ValueError
23 | rank = dist.get_rank()
24 |
25 | self.dataset = dataset
26 | self.num_replicas = num_replicas
27 | self.rank = rank
28 |
29 | num_samples = len(dataset) // num_replicas
30 | if extend:
31 | if len(dataset) != num_samples*num_replicas:
32 | num_samples+=1
33 |
34 | self.num_samples = num_samples
35 | self.total_size = num_samples * num_replicas
36 | self.shuffle = shuffle
37 | self.extend = extend
38 |
39 | def __iter__(self):
40 | indices = self.get_sync_order()
41 | if self.extend:
42 | # extend using the front indices
43 | indices = (indices+indices)[0:self.total_size]
44 | else:
45 | # truncate
46 | indices = indices[0:self.total_size]
47 | # subsample
48 | indices = indices[self.rank : len(indices) : self.num_replicas]
49 | return iter(indices)
50 |
51 | def __len__(self):
52 | return self.num_samples
53 |
54 | def set_epoch(self, epoch):
55 | # legacy
56 | pass
57 |
58 | def get_sync_order(self):
59 | # g = torch.Generator()
60 | # g.manual_seed(self.epoch)
61 | if self.shuffle:
62 | indices = torch.randperm(len(self.dataset)).to(self.rank)
63 | dist.broadcast(indices, src=0)
64 | indices = indices.to('cpu').tolist()
65 | else:
66 | indices = list(range(len(self.dataset)))
67 | print_log(str(indices[0:5]))
68 | return indices
69 |
--------------------------------------------------------------------------------
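A pure-Python sketch of the truncate-vs-extend arithmetic in `DistributedSampler` above: with 10 samples across 4 replicas, truncation drops two samples so each rank sees two, while `extend=True` wraps indices around so each rank sees three.

```python
n, num_replicas, rank = 10, 4, 1
indices = list(range(n))

# truncate: total_size = (n // num_replicas) * num_replicas = 8
trunc = indices[:(n // num_replicas) * num_replicas][rank::num_replicas]

# extend: bump the per-rank count, then reuse indices from the front
num_samples = n // num_replicas + (1 if n % num_replicas else 0)
ext = (indices + indices)[:num_samples * num_replicas][rank::num_replicas]

print(trunc, ext)  # [1, 5] [1, 5, 9]
```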
/lib/data_factory/ds_textssc.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import torchvision
7 | import PIL
8 | import json
9 | import cv2
10 | import copy
11 | import scipy
12 | import pandas
13 | PIL.Image.MAX_IMAGE_PIXELS = None
14 |
15 | from .ds_base import ds_base, register as regdataset
16 | from .ds_loader import pre_loader_checkings, register as regloader
17 | from .ds_transform import TBase, have, register as regtrans
18 | from .ds_formatter import register as regformat
19 |
20 | from .. import nputils
21 | from ..cfg_helper import cfg_unique_holder as cfguh
22 | from ..log_service import print_log
23 |
24 | @regdataset()
25 | class textssc(ds_base):
26 | def init_load_info(self, mode):
27 | cfgd = cfguh().cfg.DATA
28 | self.root_dir = cfgd.ROOT_DIR
29 |
30 | imdir = []
31 | segdir = []
32 | for modei in mode.split('+'):
33 | dsi, seti = modei.split('_')
34 | imdir += [osp.join(self.root_dir, dsi, 'image_'+seti)]
35 | segdir += [osp.join(self.root_dir, dsi, 'seglabel_'+seti)]
36 |
37 | self.load_info = []
38 |
39 | for imdiri, segdiri in zip(imdir, segdir):
40 | for fi in os.listdir(imdiri):
41 | ftag = fi.split('.')[0]
42 | info = {
43 | 'unique_id' : ftag,
44 | 'filename' : fi,
45 | 'image_path' : osp.join(imdiri, fi),
46 | 'seglabel_path' : osp.join(segdiri, ftag+'.png'),
47 | }
48 | self.load_info.append(info)
49 |
50 | def get_semantic_classname(self,):
51 | map = {
52 | 0 : 'background' ,
53 | 1 : 'text' ,
54 | }
55 | return map
56 |
--------------------------------------------------------------------------------
/lib/data_factory/ds_totaltext.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import numpy as np
4 | import numpy.random as npr
5 | import torch
6 | import torchvision
7 | import PIL
8 | import json
9 | import cv2
10 | import copy
11 | PIL.Image.MAX_IMAGE_PIXELS = None
12 |
13 | from .ds_base import ds_base, register as regdataset
14 | from .ds_loader import pre_loader_checkings, register as regloader
15 | from .ds_transform import TBase, have, register as regtrans
16 | from .ds_formatter import register as regformat
17 |
18 | from .. import nputils
19 | from ..cfg_helper import cfg_unique_holder as cfguh
20 | from ..log_service import print_log
21 |
22 | @regdataset()
23 | class totaltext(ds_base):
24 | def init_load_info(self, mode):
25 | cfgd = cfguh().cfg.DATA
26 | self.root_dir = cfgd.ROOT_DIR
27 |
28 | im_path = []
29 | seg_path = []
30 |
31 | for mi in mode.split('+'):
32 | if mi == 'train':
33 | im_path += [osp.join(self.root_dir, 'Images', 'Train')]
34 | seg_path += [osp.join(self.root_dir, 'groundtruth_pixel', 'Train')]
35 | elif mi == 'test':
36 | im_path += [osp.join(self.root_dir, 'Images', 'Test')]
37 | seg_path += [osp.join(self.root_dir, 'groundtruth_pixel', 'Test')]
38 | else:
39 | raise ValueError
40 |
41 | self.load_info = []
42 | for impi, segpi in zip(im_path, seg_path):
43 | for fi in os.listdir(impi):
44 | uid = fi.split('.')[0]
45 | self.load_info.append({
46 | 'unique_id' : uid,
47 | 'filename' : fi,
48 | 'image_path' : osp.join(impi, fi),
49 | 'seglabel_path' : osp.join(segpi, uid+'.jpg'),
50 | })
51 |
52 | def get_semantic_classname(self,):
53 | map = {
54 | 0 : 'background' ,
55 | 1 : 'text' ,
56 | }
57 | return map
58 |
59 | # ---- loader ----
60 |
61 | @regloader()
62 | class TotalText_SeglabelLoader(object):
63 | def __init__(self):
64 | pass
65 |
66 | @pre_loader_checkings('seglabel')
67 | def __call__(self, path, element):
68 | sem = np.array(PIL.Image.open(path)).astype(int)
69 | sem = (sem>127).astype(int)
70 | return sem
71 |
--------------------------------------------------------------------------------
/lib/log_service.py:
--------------------------------------------------------------------------------
1 | import timeit
2 | import numpy as np
3 | import os.path as osp
4 | import torch
5 | import torch.nn as nn
6 | from .cfg_helper import cfg_unique_holder as cfguh
7 |
8 | def print_log(console_info):
9 | print(console_info)
10 | log_file = cfguh().cfg.LOG_FILE
11 | if log_file is not None:
12 | with open(log_file, 'a') as f:
13 | f.write(console_info + '\n')
14 |
15 | class log_manager(object):
16 | """
17 | The helper to print logs.
18 | """
19 | def __init__(self,
20 | **kwargs):
21 | self.data = {}
22 | self.cnt = {}
23 | self.time_check = timeit.default_timer()
24 |
25 | def accumulate(self,
26 | n,
27 | data,
28 | **kwargs):
29 | """
30 | Args:
31 | n: number of items (i.e. the batchsize)
32 | data: {itemname : float} data (i.e. the loss values)
33 | which are going to be accumulated.
34 | """
35 | if n < 0:
36 | raise ValueError
37 |
38 | for itemn, di in data.items():
39 | try:
40 | self.data[itemn] += di * n
41 | except:
42 | self.data[itemn] = di * n
43 |
44 | try:
45 | self.cnt[itemn] += n
46 | except:
47 | self.cnt[itemn] = n
48 |
49 | def print(self, rank, itern, epochn, samplen, lr):
50 | console_info = [
51 | 'Rank:{}'.format(rank),
52 | 'Iter:{}'.format(itern),
53 | 'Epoch:{}'.format(epochn),
54 | 'Sample:{}'.format(samplen),
55 | 'LR:{:.4f}'.format(lr)]
56 |
57 | cntgroups = {}
58 | for itemn, ci in self.cnt.items():
59 | try:
60 | cntgroups[ci].append(itemn)
61 | except:
62 | cntgroups[ci] = [itemn]
63 |
64 | for ci, itemng in cntgroups.items():
65 | console_info.append('cnt:{}'.format(ci))
66 | for itemn in sorted(itemng):
67 | console_info.append('{}:{:.4f}'.format(
68 | itemn, self.data[itemn]/ci))
69 |
70 | console_info.append('Time:{:.2f}s'.format(
71 | timeit.default_timer() - self.time_check))
72 | return ' , '.join(console_info)
73 |
74 | def clear(self):
75 | self.data = {}
76 | self.cnt = {}
77 | self.time_check = timeit.default_timer()
78 |
79 | def pop(self, rank, itern, epochn, samplen, lr):
80 | console_info = self.print(
81 | rank, itern, epochn, samplen, lr)
82 | self.clear()
83 | return console_info
84 |
85 | # ----- also include some small utils -----
86 |
87 | def torch_to_numpy(*argv):
88 | if len(argv) > 1:
89 | data = list(argv)
90 | else:
91 | data = argv[0]
92 |
93 | if isinstance(data, torch.Tensor):
94 | return data.to('cpu').detach().numpy()
95 |
96 | elif isinstance(data, (list, tuple)):
97 | out = []
98 | for di in data:
99 | out.append(torch_to_numpy(di))
100 | return out
101 |
102 | elif isinstance(data, dict):
103 | out = {}
104 | for ni, di in data.items():
105 | out[ni] = torch_to_numpy(di)
106 | return out
107 |
108 | else:
109 | return data
110 |
--------------------------------------------------------------------------------
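A minimal sketch (run from the repo root, since `lib` pulls in the `configs` package) of the accumulate/pop cycle in `log_manager`: values are summed weighted by batch size, so `pop()` reports per-sample averages since the last `clear()`. The loss key name is hypothetical.

```python
from lib.log_service import log_manager

lm = log_manager()
lm.accumulate(4, {'lossrfn': 0.50})  # batch of 4 samples
lm.accumulate(2, {'lossrfn': 0.20})  # batch of 2 samples
print(lm.pop(rank=0, itern=10, epochn=1, samplen=6, lr=1e-3))
# ... lossrfn:0.4000 ...  (= (0.50*4 + 0.20*2) / 6)
```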
/lib/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import copy
6 | from . import torchutils
7 |
8 | class finalize_loss(object):
9 | def __init__(self,
10 | weight = None,
11 | normalize_weight = True,
12 | **kwargs):
13 | if weight is None:
14 | self.weight = None
15 | else:
16 | for _, wi in weight.items():
17 | if wi < 0:
18 | raise ValueError
19 |
20 | if not normalize_weight:
21 | self.weight = weight
22 | else:
23 | sum_weight = 0
24 | for _, wi in weight.items():
25 | sum_weight += wi
26 | if sum_weight == 0:
27 | raise ValueError
28 | self.weight = {
29 | itemn:wi/sum_weight for itemn, wi in weight.items()}
30 |
31 | self.normalize_weight = normalize_weight
32 |
33 | def __call__(self,
34 | loss_input,):
35 | item = {n : v.item() for n, v in loss_input.items()}
36 | lossname = [n for n in loss_input.keys() if n[0:4]=='loss']
37 |
38 | if self.weight is not None:
39 | if sorted(lossname) \
40 | != sorted(list(self.weight.keys())):
41 | raise ValueError
42 |
43 | loss_num = len(lossname)
44 | loss = None
45 |
46 | for n in lossname:
47 | v = loss_input[n]
48 | if loss is not None:
49 | if self.weight is not None:
50 | loss += v * self.weight[n]
51 | else:
52 | loss += v
53 | else:
54 | if self.weight is not None:
55 | loss = v * self.weight[n]
56 | else:
57 | loss = v
58 |
59 | if (self.weight is None) and (self.normalize_weight):
60 | loss /= loss_num
61 |
62 | item['Loss'] = loss.item()
63 | return loss, item
64 |
--------------------------------------------------------------------------------
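A minimal sketch of `finalize_loss` with the default `normalize_weight=True`: the weights `{'losssem': 1, 'lossrfn': 3}` are rescaled to 0.25/0.75 before the weighted sum, and only keys starting with `loss` are combined (the key names here are hypothetical; run from the repo root).

```python
import torch
from lib.loss import finalize_loss

fl = finalize_loss(weight={'losssem': 1.0, 'lossrfn': 3.0})
loss, items = fl({'losssem': torch.tensor(0.8), 'lossrfn': torch.tensor(0.4)})
print(loss.item(), items['Loss'])  # ~0.5 = 0.25*0.8 + 0.75*0.4
```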
/lib/model_zoo/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_model import get_model, save_state_dict
2 |
--------------------------------------------------------------------------------
/lib/model_zoo/deeplab.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .. import nputils
6 | from .. import torchutils
7 | from .. import loss
8 | from .get_model import get_model, register
9 | from .optim_manager import optim_manager
10 |
11 | from . import utils
12 |
13 | version = 'v33'
14 |
15 | class ASPP(nn.Module):
16 | def __init__(self,
17 | ic_n,
18 | c_n,
19 | oc_n,
20 | dilation_n,
21 | conv_type='conv',
22 | bn_type='bn',
23 | relu_type='relu',
24 | dropout_type='dropout|0.5',
25 | with_gap=True,
26 | **kwargs):
27 | super().__init__()
28 |
29 | conv, bn, relu = utils.conv_bn_relu(conv_type, bn_type, relu_type)
30 | dropout = utils.nn_component(dropout_type)
31 |
32 | d1, d2, d3 = dilation_n
33 | self.conv1 = nn.Sequential(
34 | conv(ic_n, c_n, 1, 1, padding=0, dilation=1),
35 | bn(c_n),
36 | relu(inplace=True))
37 | self.conv2 = nn.Sequential(
38 | conv(ic_n, c_n, 3, 1, padding=d1, dilation=d1),
39 | bn(c_n),
40 | relu(inplace=True))
41 | self.conv3 = nn.Sequential(
42 | conv(ic_n, c_n, 3, 1, padding=d2, dilation=d2),
43 | bn(c_n),
44 | relu(inplace=True))
45 | self.conv4 = nn.Sequential(
46 | conv(ic_n, c_n, 3, 1, padding=d3, dilation=d3),
47 | bn(c_n),
48 | relu(inplace=True))
49 | if with_gap:
50 | self.conv5 = nn.Sequential(
51 | nn.AdaptiveAvgPool2d((1, 1)),
52 | conv(ic_n, c_n, 1, 1, padding=0, dilation=1),
53 | bn(c_n),
54 | relu(inplace=True))
55 |             total_layers = 5
56 |         else:
57 |             total_layers = 4
58 |
59 | self.bottleneck = nn.Sequential(
60 | conv(c_n*total_layers, oc_n, 1, 1, 0),
61 | bn(oc_n),
62 | relu(inplace=True),)
63 | self.dropout = dropout()
64 | self.with_gap = with_gap
65 |
66 | def forward(self, x):
67 | _, _, h, w = x.size()
68 | feat1 = self.conv1(x)
69 | feat2 = self.conv2(x)
70 | feat3 = self.conv3(x)
71 | feat4 = self.conv4(x)
72 | if self.with_gap:
73 | feat5 = F.interpolate(
74 | self.conv5(x), size=(h, w), mode='nearest')
75 | out = torch.cat((feat1, feat2, feat3, feat4, feat5), dim=1)
76 | else:
77 | out = torch.cat((feat1, feat2, feat3, feat4), dim=1)
78 | out = self.bottleneck(out)
79 | out = self.dropout(out)
80 | return out
81 |
82 | class Decoder(nn.Module):
83 | def __init__(self,
84 | bic_n,
85 | xic_n,
86 | oc_n,
87 | align_corners=False,
88 | conv_type='conv',
89 | bn_type='bn',
90 | relu_type='relu',
91 | dropout2_type='dropout|0.5',
92 | dropout3_type='dropout|0.1',
93 | **kwargs):
94 | super().__init__()
95 |
96 | conv, bn, relu = utils.conv_bn_relu(conv_type, bn_type, relu_type)
97 | dropout2 = utils.nn_component(dropout2_type)
98 | dropout3 = utils.nn_component(dropout3_type)
99 |
100 | self.conv1 = conv(bic_n, 48, 1, 1, 0)
101 | self.bn1 = bn(48)
102 | self.relu = relu(inplace=True)
103 |
104 | self.conv2 = conv(xic_n+48, 256, 3, 1, 1)
105 | self.bn2 = bn(256)
106 | self.dropout2 = dropout2()
107 |
108 | self.conv3 = conv(256, oc_n, 3, 1, 1)
109 | self.bn3 = bn(oc_n)
110 | self.dropout3 = dropout3()
111 |
112 | self.align_corners = align_corners
113 |
114 | def forward(self,
115 | b,
116 | x):
117 | b = self.relu(self.bn1(self.conv1(b)))
118 | x = torchutils.interpolate_2d(
119 | size = b.shape[2:], mode='bilinear',
120 | align_corners=self.align_corners)(x)
121 | x = torch.cat([x, b], dim=1)
122 | x = self.relu(self.bn2(self.conv2(x)))
123 | x = self.dropout2(x)
124 | x = self.relu(self.bn3(self.conv3(x)))
125 | x = self.dropout3(x)
126 | return x
127 |
128 | class DeepLabv3p_Base(nn.Module):
129 | def __init__(self,
130 | bbn_name,
131 | bbn,
132 | oc_n,
133 | aspp_ic_n,
134 | aspp_dilation_n,
135 | decoder_bic_n,
136 | aspp_dropout_type='dropout|0.5',
137 | aspp_with_gap=True,
138 | decoder_dropout2_type='dropout|0.5',
139 | decoder_dropout3_type='dropout|0.1',
140 | conv_type='conv',
141 | bn_type='bn',
142 | relu_type='relu',
143 | align_corners=False,
144 | **kwargs):
145 | super().__init__()
146 |
147 | setattr(self, bbn_name, bbn)
148 | self.aspp = ASPP(
149 | aspp_ic_n, 256, 256, aspp_dilation_n,
150 | conv_type=conv_type,
151 | bn_type=bn_type,
152 | relu_type=relu_type,
153 | dropout_type=aspp_dropout_type,
154 | with_gap=aspp_with_gap)
155 | self.decoder = Decoder(
156 | decoder_bic_n, 256, oc_n,
157 | align_corners=align_corners,
158 | conv_type=conv_type,
159 | bn_type=bn_type,
160 | relu_type=relu_type,
161 | dropout2_type=decoder_dropout2_type,
162 | dropout3_type=decoder_dropout3_type,)
163 |
164 | self.bbn_name = bbn_name
165 |
166 | # initialize the weight
167 | utils.init_module([self.aspp, self.decoder])
168 |
169 | # prepare opmgr
170 | self.opmgr = getattr(self, self.bbn_name).opmgr
171 | self.opmgr.inheritant(self.bbn_name)
172 | self.opmgr.pushback(
173 | 'deeplab', ['aspp', 'decoder']
174 | )
175 |
176 | def forward(self, x):
177 | xs = getattr(self, self.bbn_name)(x)
178 | b, x = xs[0], xs[-1]
179 | x = self.aspp(x)
180 | x = self.decoder(b, x)
181 | return x
182 |
183 | class DeepLabv3p(DeepLabv3p_Base):
184 | def __init__(self,
185 | bbn_name,
186 | bbn,
187 | class_n,
188 | aspp_ic_n,
189 | aspp_dilation_n,
190 | decoder_bic_n,
191 | aspp_dropout_type='dropout|0.5',
192 | aspp_with_gap=True,
193 | decoder_dropout2_type='dropout|0.5',
194 | decoder_dropout3_type='dropout|0.1',
195 | conv_type='conv',
196 | bn_type='bn',
197 | relu_type='relu',
198 | align_corners=False,
199 | ignore_label=None,
200 | loss_type='ce',
201 | intrain_getpred=False,
202 | ineval_output_argmax=True,
203 | **kwargs):
204 |
205 | super().__init__(
206 | bbn_name,
207 | bbn,
208 | 256,
209 | aspp_ic_n,
210 | aspp_dilation_n,
211 | decoder_bic_n,
212 | aspp_dropout_type,
213 | aspp_with_gap,
214 | decoder_dropout2_type,
215 | decoder_dropout3_type,
216 | conv_type,
217 | bn_type,
218 | relu_type,
219 | align_corners,
220 | )
221 |
222 | self.semhead = utils.semantic_head(
223 | 256, class_n,
224 | align_corners=align_corners,
225 | ignore_label=ignore_label,
226 | loss_type=loss_type,
227 | ineval_output_argmax=ineval_output_argmax,
228 | intrain_getpred = intrain_getpred,
229 | )
230 |
231 | # initialize the weight
232 | # (aspp and decoder initalized in base class)
233 | utils.init_module([self.semhead])
234 |
235 | # prepare opmgr
236 | module_name = self.opmgr.popback()
237 | module_name.append('semhead')
238 | self.opmgr.pushback('deeplab', module_name)
239 |
240 | def forward(self,
241 | x,
242 | gtsem=None,
243 | ):
244 | x = super().forward(x)
245 | o = self.semhead(x, gtsem)
246 | return o
247 |
248 | @register(
249 | 'DEEPLAB',
250 | {
251 | # base
252 | 'freeze_backbone_bn' : 'FREEZE_BACKBONE_BN',
253 | 'oc_n' : 'OUTPUT_CHANNEL_NUM',
254 | 'conv_type' : 'CONV_TYPE',
255 | 'bn_type' : 'BN_TYPE',
256 | 'relu_type' : 'RELU_TYPE',
257 | 'aspp_dropout_type' : 'ASPP_DROPOUT_TYPE',
258 | 'aspp_with_gap' : 'ASPP_WITH_GAP',
259 | 'decoder_dropout2_type' : 'DECODER_DROPOUT2_TYPE',
260 | 'decoder_dropout3_type' : 'DECODER_DROPOUT3_TYPE',
261 | 'align_corners' : 'INTERPOLATE_ALIGN_CORNERS',
262 | # non_base
263 | 'ignore_label' : 'SEMANTIC_IGNORE_LABEL',
264 | 'class_n' : 'SEMANTIC_CLASS_NUM',
265 | 'loss_type' : 'LOSS_TYPE',
266 | 'intrain_getpred' : 'INTRAIN_GETPRED',
267 | 'ineval_output_argmax' : 'INEVAL_OUTPUT_ARGMAX',
268 | })
269 | def deeplab(tags, **para):
270 | if 'resnet' in tags:
271 | bbn = get_model()('resnet')
272 | para['bbn_name'] = 'resnet'
273 | para['bbn'] = bbn
274 | para['aspp_ic_n'] = 512*bbn.block_expansion
275 | para['decoder_bic_n'] = 64*bbn.block_expansion
276 |
277 | if 'os16' in tags:
278 | para['aspp_dilation_n'] = [6, 12, 18]
279 | elif 'os8' in tags:
280 | para['aspp_dilation_n'] = [12, 24, 36]
281 | else:
282 |             raise ValueError("expect 'os16' or 'os8' in MODEL_TAGS")
283 |
284 | try:
285 | freezebn = para.pop('freeze_backbone_bn')
286 |     except KeyError:
287 | freezebn = False
288 | if freezebn:
289 | for m in bbn.modules():
290 | if isinstance(m, nn.BatchNorm2d):
291 | for i in m.parameters():
292 | i.requires_grad = False
293 |
294 | if 'v3+' in tags:
295 | if 'base' in tags:
296 | net = DeepLabv3p_Base(**para)
297 | else:
298 | net = DeepLabv3p(**para)
299 |
300 | return net
301 |
--------------------------------------------------------------------------------
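For orientation: ASPP runs four parallel atrous convolutions (plus an optional global-average-pooling branch) over the same input and fuses them with a 1x1 bottleneck, so the spatial size is preserved and the channel count becomes oc_n; DeepLabv3p then feeds this through the Decoder against an early backbone feature. A shape-check sketch, assuming the default 'conv'/'bn'/'relu' strings resolve to the standard torch layers in lib/model_zoo/utils.py; ic_n=2048 mirrors the 512 * block_expansion wiring that deeplab() uses for a resnet50-style backbone:

import torch
from lib.model_zoo.deeplab import ASPP

aspp = ASPP(ic_n=2048, c_n=256, oc_n=256, dilation_n=[6, 12, 18])
aspp.eval()  # the global-average-pooling branch emits 1x1 maps, which
             # BatchNorm rejects in training mode at batch size 1
with torch.no_grad():
    y = aspp(torch.randn(1, 2048, 33, 33))
print(y.shape)  # torch.Size([1, 256, 33, 33])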
/lib/model_zoo/get_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models
3 | import os.path as osp
4 | from ..cfg_helper import cfg_unique_holder as cfguh
5 | from ..log_service import print_log
6 | from .utils import get_total_param, get_total_param_sum, freeze
7 |
8 | def load_state_dict(net, model_path):
9 | paras = torch.load(model_path, map_location=torch.device('cpu'))
10 | new_paras = net.state_dict()
11 | new_paras.update(paras)
12 | net.load_state_dict(new_paras)
13 | return
14 |
15 | def save_state_dict(net, path):
16 | if isinstance(net, (torch.nn.DataParallel,
17 | torch.nn.parallel.DistributedDataParallel)):
18 | torch.save(net.module.state_dict(), path)
19 | else:
20 | torch.save(net.state_dict(), path)
21 |
22 | def singleton(class_):
23 | instances = {}
24 | def getinstance(*args, **kwargs):
25 | if class_ not in instances:
26 | instances[class_] = class_(*args, **kwargs)
27 | return instances[class_]
28 | return getinstance
29 |
30 | @singleton
31 | class get_model(object):
32 | def __init__(self):
33 | self.model = {}
34 |
35 | def register(self, modelf, cfgname, kwmap, kwfix):
36 | self.model[modelf.__name__] = [modelf, cfgname, kwmap, kwfix]
37 |
38 | def __call__(self, name=None, cfgm=None):
39 | if cfgm is None:
40 | cfgm = cfguh().cfg.MODEL
41 | if name is None:
42 | name = cfgm.MODEL_NAME
43 |
44 | # the register is in each file
45 | if name == 'resnet':
46 | from . import resnet
47 | elif name == 'deeplab':
48 | from . import deeplab
49 | elif name == 'hrnet':
50 | from . import hrnet
51 | elif name == 'texrnet':
52 | from . import texrnet
53 |
54 | modelf, cfgname, kwmap, kwfix = self.model[name]
55 | cfgm = cfgm.__getitem__(cfgname)
56 |
57 | # MODEL_TAGS and PRETRAINED_PTH are two special args
58 |         # FREEZE_BACKBONE_BN is not frequently used.
59 | kw = {'tags' : cfgm.MODEL_TAGS}
60 | for k1, k2 in kwmap.items():
61 | if k2 in cfgm.keys():
62 | kw[k1] = cfgm[k2]
63 | kw.update(kwfix)
64 | net = modelf(**kw)
65 |
66 | # load init model
67 | if cfgm.PRETRAINED_PTH is not None:
68 | print_log('Load model from {0}.'.format(
69 | cfgm.PRETRAINED_PTH))
70 | load_state_dict(
71 | net, cfgm.PRETRAINED_PTH)
72 |
73 | # display param_num & param_sum
74 |         print_log('Load {} with total {} parameters, {:.3f} parameter sum.'.format(
75 | name, get_total_param(net), get_total_param_sum(net)))
76 |
77 | return net
78 |
79 | def register(cfgname, kwmap={}, kwfix={}):
80 | def wrapper(class_):
81 | get_model().register(class_, cfgname, kwmap, kwfix)
82 | return class_
83 | return wrapper
84 |
--------------------------------------------------------------------------------
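get_model is a process-wide singleton registry: the @register decorator stores a model factory together with its config-section name and a map from keyword arguments to UPPER_CASE config keys, and __call__ later resolves a name to that factory, translates the config entries, and loads PRETRAINED_PTH when set. A sketch with a hypothetical 'toynet' entry (the names and config keys below are illustrative only; also note the print_log call inside __call__ may expect the global cfg/logging to be initialized first):

import torch.nn as nn
from easydict import EasyDict as edict
from lib.model_zoo.get_model import get_model, register

@register('TOYNET', kwmap={'class_n': 'CLASS_NUM'})  # hypothetical entry
def toynet(tags, class_n, **para):
    return nn.Linear(8, class_n)

cfgm = edict({
    'MODEL_NAME': 'toynet',
    'TOYNET': edict({
        'MODEL_TAGS': ['base'],
        'CLASS_NUM': 2,
        'PRETRAINED_PTH': None,
    }),
})
net = get_model()('toynet', cfgm)  # builds toynet(tags=['base'], class_n=2)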
/lib/model_zoo/hrnet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | # import os
3 | import os.path as osp
4 | sys.path.append(osp.join(osp.dirname(__file__), '..', '..', 'hrnet_code', 'lib'))
5 |
6 | from models import seg_hrnet
7 | from models import seg_hrnet_ocr
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .. import nputils
14 | from .. import torchutils
15 | from .. import loss
16 | from .get_model import get_model, register
17 | from .optim_manager import optim_manager
18 |
19 | from . import utils
20 |
21 | version = 'v0'
22 | """
23 | v0: the original code from the GitHub repo released by the paper authors
24 | """
25 |
26 | class HRNet_Base(seg_hrnet.HighResolutionNet):
27 | def __init__(self,
28 | oc_n,
29 | align_corners,
30 | ignore_label,
31 | stage1_para,
32 | stage2_para,
33 | stage3_para,
34 | stage4_para,
35 | final_conv_kernel,
36 | **kwargs):
37 | from easydict import EasyDict as edict
38 | config = edict()
39 | config.MODEL = edict()
40 | config.MODEL.ALIGN_CORNERS = align_corners
41 | config.MODEL.EXTRA = {}
42 | config.MODEL.EXTRA['STAGE1'] = stage1_para
43 | config.MODEL.EXTRA['STAGE2'] = stage2_para
44 | config.MODEL.EXTRA['STAGE3'] = stage3_para
45 | config.MODEL.EXTRA['STAGE4'] = stage4_para
46 | config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
47 | config.DATASET = edict()
48 | config.DATASET.NUM_CLASSES = 1 # dummy
49 | super().__init__(config)
50 |
51 | self.opmgr = optim_manager(
52 | group = {'hrnet': 'self'},
53 | order = ['hrnet'],
54 | )
55 |
56 | last_inp_channels = self.last_layer[0].in_channels
57 | # BatchNorm2d = nn.SyncBatchNorm
58 | BatchNorm2d = nn.BatchNorm2d
59 | relu_inplace = True
60 | self.last_layer = nn.Sequential(
61 | nn.Conv2d(
62 | in_channels=last_inp_channels,
63 | out_channels=oc_n,
64 | kernel_size=1,
65 | stride=1,
66 | padding=0),
67 | BatchNorm2d(oc_n, momentum=0.1),
68 | nn.ReLU(inplace=relu_inplace),
69 | )
70 |
71 | def forward(self, x):
72 | x = super().forward(x)
73 | return x
74 |
75 | class HRNet(seg_hrnet.HighResolutionNet):
76 | def __init__(self,
77 | cls_n,
78 | align_corners,
79 | ignore_label,
80 | stage1_para,
81 | stage2_para,
82 | stage3_para,
83 | stage4_para,
84 | final_conv_kernel,
85 | loss_type='ce',
86 | intrain_getpred=False,
87 | ineval_output_argmax=False,
88 | **kwargs):
89 | from easydict import EasyDict as edict
90 | config = edict()
91 | config.MODEL = edict()
92 | config.MODEL.ALIGN_CORNERS = align_corners
93 | config.MODEL.EXTRA = {}
94 | config.MODEL.EXTRA['STAGE1'] = stage1_para
95 | config.MODEL.EXTRA['STAGE2'] = stage2_para
96 | config.MODEL.EXTRA['STAGE3'] = stage3_para
97 | config.MODEL.EXTRA['STAGE4'] = stage4_para
98 | config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
99 | config.DATASET = edict()
100 | config.DATASET.NUM_CLASSES = cls_n
101 | super().__init__(config)
102 |
103 | self.semhead = utils.semantic_head_noconv(
104 | align_corners = align_corners,
105 | ignore_label = ignore_label,
106 | loss_type = loss_type,
107 | intrain_getpred = intrain_getpred,
108 | ineval_output_argmax = ineval_output_argmax,
109 | )
110 |
111 | self.opmgr = optim_manager(
112 | group = {'hrnet': 'self'},
113 | order = ['hrnet'],
114 | )
115 |
116 | def forward(self, x, gtsem=None):
117 | x = super().forward(x)
118 | o = self.semhead(x, gtsem)
119 | return o
120 |
121 | class HRNet_Ocr(seg_hrnet_ocr.HighResolutionNet):
122 | def __init__(self,
123 | cls_n,
124 | align_corners,
125 | ignore_label,
126 | stage1_para,
127 | stage2_para,
128 | stage3_para,
129 | stage4_para,
130 | final_conv_kernel,
131 | loss_type='ce',
132 | intrain_getpred=False,
133 | ineval_output_argmax=False,
134 | ocr_mc_n = 512,
135 | ocr_keyc_n = 256,
136 | ocr_dropout_rate = 0.05,
137 | ocr_scale = 1,
138 | **kwargs):
139 | from easydict import EasyDict as edict
140 | config = edict()
141 | config.MODEL = edict()
142 | config.MODEL.ALIGN_CORNERS = align_corners
143 | config.MODEL.EXTRA = {}
144 | config.MODEL.EXTRA['STAGE1'] = stage1_para
145 | config.MODEL.EXTRA['STAGE2'] = stage2_para
146 | config.MODEL.EXTRA['STAGE3'] = stage3_para
147 | config.MODEL.EXTRA['STAGE4'] = stage4_para
148 | config.MODEL.EXTRA['FINAL_CONV_KERNEL'] = final_conv_kernel
149 | config.DATASET = edict()
150 | config.DATASET.NUM_CLASSES = cls_n
151 |
152 | # OCR special
153 | config.MODEL.OCR = edict()
154 | config.MODEL.OCR.MID_CHANNELS = ocr_mc_n
155 | config.MODEL.OCR.KEY_CHANNELS = ocr_keyc_n
156 | config.MODEL.OCR.DROPOUT = ocr_dropout_rate
157 | config.MODEL.OCR.SCALE = ocr_scale
158 |
159 | super().__init__(config)
160 |
161 | self.auxhead = utils.semantic_head_noconv(
162 | align_corners = align_corners,
163 | ignore_label = ignore_label,
164 | loss_type = loss_type,
165 | intrain_getpred = False,
166 | ineval_output_argmax = False,
167 | ) # the aux head (not main)
168 |
169 | self.semhead = utils.semantic_head_noconv(
170 | align_corners = align_corners,
171 | ignore_label = ignore_label,
172 | loss_type = loss_type,
173 | intrain_getpred = intrain_getpred,
174 | ineval_output_argmax = ineval_output_argmax,
175 | )
176 |
177 | self.opmgr = optim_manager(
178 | group = {'hrnet': 'self'},
179 | order = ['hrnet'],
180 | )
181 |
182 | def forward(self, x, gtsem=None):
183 | x = super().forward(x)
184 | o = self.semhead(x[1], gtsem)
185 | if self.training:
186 | o['lossaux'] = self.auxhead(x[0], gtsem)['losssem']
187 | return o
188 |
189 |
190 | @register(
191 | 'HRNET',
192 | {
193 | 'oc_n' : 'OUTPUT_CHANNEL_NUM',
194 | 'cls_n' : 'CLASS_NUM',
195 | 'align_corners' : 'ALIGN_CORNERS',
196 | 'ignore_label' : 'IGNORE_LABEL',
197 | 'stage1_para' : 'STAGE1_PARA',
198 | 'stage2_para' : 'STAGE2_PARA',
199 | 'stage3_para' : 'STAGE3_PARA',
200 | 'stage4_para' : 'STAGE4_PARA',
201 | 'final_conv_kernel' : 'FINAL_CONV_KERNEL',
202 | 'loss_type' : 'LOSS_TYPE',
203 | 'intrain_getpred' : 'INTRAIN_GETPRED',
204 | 'ineval_output_argmax' : 'INEVAL_OUTPUT_ARGMAX',
205 | })
206 | def hrnet(tags, **para):
207 | if 'base' in tags:
208 | net = HRNet_Base(**para)
209 | elif 'ocr' in tags:
210 | net = HRNet_Ocr(**para)
211 | else:
212 | net = HRNet(**para)
213 | return net
214 |
--------------------------------------------------------------------------------
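The stage*_para dicts above are forwarded verbatim into config.MODEL.EXTRA, so they must follow the upstream HRNet config schema. For illustration, a stage-2 dict in the shape the upstream HRNet-Semantic-Segmentation configs use for HRNet-W48 (the exact numbers are an assumption here, not values read from this repo's configs):

stage2_para = {
    'NUM_MODULES':  1,
    'NUM_BRANCHES': 2,
    'BLOCK':        'BASIC',
    'NUM_BLOCKS':   [4, 4],
    'NUM_CHANNELS': [48, 96],
    'FUSE_METHOD':  'SUM',
}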
/lib/model_zoo/optim_manager.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import itertools
5 |
6 | class optim_manager(object):
7 | def __init__(self,
8 | group,
9 | order=None,
10 | group_lrscale=None,):
11 | self.group = {}
12 | self.order = []
13 |         for gn in (order if order is not None else list(group.keys())):
14 | self.pushback(gn, group[gn])
15 | self.group_lrscale = group_lrscale
16 |
17 | def pushback(self,
18 | group_name,
19 | module_name):
20 | if group_name in self.group.keys():
21 | raise ValueError
22 | if isinstance(module_name, (list, tuple)):
23 | self.group[group_name] = list(module_name)
24 | else:
25 | self.group[group_name] = [module_name]
26 | self.order += [group_name]
27 | self.group_lrscale = None
28 |
29 | def popback(self):
30 | group = self.group.pop(self.order[-1])
31 | self.order = self.order[:-1]
32 | self.group_lrscale = None
33 | return group
34 |
35 | def replace(self,
36 | group_name,
37 | module_name,):
38 | if group_name not in self.group.keys():
39 | raise ValueError
40 | if isinstance(module_name, (list, tuple)):
41 | module_name = list(module_name)
42 | else:
43 | module_name = [module_name]
44 | self.group[group_name] = module_name
45 | self.group_lrscale = None
46 |
47 | def inheritant(self,
48 | supername):
49 | for gn in self.group.keys():
50 | module_name = self.group[gn]
51 | module_name_new = []
52 | for mn in module_name:
53 | if mn == 'self':
54 | module_name_new.append(supername)
55 | else:
56 | module_name_new.append(supername+'.'+mn)
57 | self.group[gn] = module_name_new
58 |
59 | def set_lrscale(self,
60 | group_lrscale):
61 | if sorted(self.order) != \
62 | sorted(list(group_lrscale.keys())):
63 |             raise ValueError('group_lrscale keys must match the group names')
64 | self.group_lrscale = group_lrscale
65 |
66 | def pg_generator(self,
67 | netref,
68 | module_name):
69 | if not isinstance(module_name, list):
70 | raise ValueError
71 | # the "self" special case
72 | if (len(module_name)==1) \
73 | and (module_name[0] == 'self'):
74 | return netref.parameters()
75 | pg = []
76 | for mn in module_name:
77 | if mn == 'self':
78 | raise ValueError
79 | mn = mn.split('.')
80 | module = netref
81 | for mni in mn:
82 | module = getattr(module, mni)
83 |
84 | pg.append(module.parameters())
85 |
86 | pg = itertools.chain(*pg)
87 | return pg
88 |
89 | def get_pg(self, net):
90 | try:
91 | netref = net.module
92 |         except AttributeError:
93 | netref = net
94 |
95 | return [
96 | {'params': self.pg_generator(netref, self.group[gn])} \
97 | for gn in self.order]
98 |
99 | def get_pglr(self, idx, base_lr):
100 | return base_lr * self.group_lrscale[self.order[idx]]
101 |
--------------------------------------------------------------------------------
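optim_manager keeps named, ordered groups of module paths so that get_pg can hand the optimizer one parameter group per entry (in declared order) and get_pglr can scale the base learning rate per group; inheritant rewrites every stored path to live under a parent attribute, which is how DeepLabv3p_Base above absorbs its backbone's opmgr into its own. A minimal sketch on a toy module (assuming the package imports resolve):

import torch.nn as nn
from lib.model_zoo.optim_manager import optim_manager

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)
        self.head = nn.Linear(4, 2)

net = Toy()
opmgr = optim_manager(
    group={'backbone': 'backbone', 'head': 'head'},
    order=['backbone', 'head'])
opmgr.set_lrscale({'backbone': 0.1, 'head': 1.0})

pgs = opmgr.get_pg(net)                # two parameter groups, in order
lr_backbone = opmgr.get_pglr(0, 0.01)  # 0.01 * 0.1 = 0.001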
/lib/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from .get_optimizer import get_optimizer, adjust_lr, lr_scheduler
--------------------------------------------------------------------------------
/lib/optimizer/get_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.optim as optim
3 | import numpy as np
4 |
5 | from ..cfg_helper import cfg_unique_holder as cfguh
6 |
7 | def get_optimizer(net, optimizer_name = None, opmgr = None):
8 | cfg = cfguh().cfg
9 | if optimizer_name is None:
10 | optimizer_name = cfg.TRAIN.OPTIMIZER
11 |
12 |     # all lrs are initialized to 0 because they are
13 |     # set later via adjust_lr outside this function
14 | if opmgr is None:
15 | parameter_groups = net.parameters()
16 | else:
17 | para_num = len([i for i in net.parameters()])
18 | para_ckt = sum([
19 | len([i for i in pg['params']]) for pg in opmgr.get_pg(net)])
20 | if para_num != para_ckt:
21 |             # this checks whether the opmgr parameter groups include all parameters.
22 |             # TODO: maybe issue a warning here instead of raising
23 | raise ValueError
24 | parameter_groups = opmgr.get_pg(net)
25 |
26 | if optimizer_name == "sgd":
27 | optimizer = optim.SGD(
28 | parameter_groups,
29 | lr = 0,
30 | momentum = cfg.TRAIN.SGD_MOMENTUM,
31 | weight_decay = cfg.TRAIN.SGD_WEIGHT_DECAY)
32 |
33 | elif optimizer_name == "adam":
34 | optimizer = optim.Adam(
35 | parameter_groups,
36 | lr = 0,
37 | betas = cfg.TRAIN.ADAM_BETAS,
38 | eps = cfg.TRAIN.ADAM_EPS,
39 | weight_decay = cfg.TRAIN.ADAM_WEIGHT_DECAY)
40 |
41 | else:
42 | raise ValueError
43 |
44 | return optimizer
45 |
46 | def adjust_lr(op, new_lr, opmgr = None):
47 | for idx, param_group in enumerate(op.param_groups):
48 | if opmgr is None:
49 | param_group['lr'] = new_lr
50 | else:
51 | param_group['lr'] = opmgr.get_pglr(idx, new_lr)
52 |
53 | class lr_scheduler(object):
54 | def __init__(self,
55 | types):
56 | self.lr = []
57 |         for spec in types:  # avoid shadowing the builtin `type`
58 |             if spec[0] == 'constant':
59 |                 _, v, n = spec
60 |                 lr = [v] * n
61 |             elif spec[0] in ('poly', 'ploy'):  # 'ploy' kept for backward compatibility
62 |                 _, va, vb, n, pw = spec
63 |                 lr = [ vb + (va-vb) * ((1-i/n)**pw) for i in range(n) ]
64 |             elif spec[0] == 'linear':
65 |                 _, va, vb, n = spec
66 |                 lr = [ vb + (va-vb) * (1-i/n) for i in range(n) ]
67 |             else:
68 |                 raise ValueError('unknown lr schedule type: {}'.format(spec[0]))
69 |             self.lr += lr
70 | self.lr = np.array(self.lr)
71 |
72 | def __call__(self, i):
73 | if i < len(self.lr):
74 | return self.lr[i]
75 | else:
76 | return self.lr[-1]
77 |
78 |
79 |
80 |
81 |
82 |
--------------------------------------------------------------------------------
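lr_scheduler concatenates per-phase schedules into one precomputed lookup table indexed by iteration: ('constant', value, n), ('poly', start, end, n, power), and ('linear', start, end, n); past the end of the table the last value is held, and in a training loop the value is pushed into the optimizer via adjust_lr(optimizer, sched(itern), opmgr). A short sketch (assuming the package imports resolve):

from lib.optimizer import lr_scheduler

# a 500-iter linear warmup from 0 to 1e-2, then a 10k-iter poly decay to 0
sched = lr_scheduler([
    ('linear', 0.0, 1e-2, 500),
    ('poly', 1e-2, 0.0, 10000, 0.9),
])
print(sched(0), sched(499), sched(500))  # 0.0, ~0.00998, 0.01
print(sched(10500) == sched(10499))      # True: the last value is held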
/main.py:
--------------------------------------------------------------------------------
1 | import torch.distributed as dist
2 | import torch.multiprocessing as mp
3 |
4 | import os
5 | import os.path as osp
6 | import sys
7 | import numpy as np
8 | import copy
9 | import gc
10 | import time
11 |
12 | import argparse
13 | from easydict import EasyDict as edict
14 |
15 | from lib.model_zoo.texrnet import version as VERSION
16 | from lib.cfg_helper import cfg_unique_holder as cfguh, \
17 | get_experiment_id, \
18 | experiment_folder, \
19 | common_initiates
20 |
21 | from configs.cfg_dataset import cfg_textseg, cfg_cocots, cfg_mlt, cfg_icdar13, cfg_totaltext
22 | from configs.cfg_model import cfg_texrnet as cfg_mdel
23 | from configs.cfg_base import cfg_train, cfg_test
24 |
25 | from train_utils import \
26 | set_cfg as set_cfg_train, \
27 | set_cfg_hrnetw48 as set_cfg_hrnetw48_train, \
28 | ts, ts_with_classifier, train
29 |
30 | from eval_utils import \
31 | set_cfg as set_cfg_eval, \
32 | set_cfg_hrnetw48 as set_cfg_hrnetw48_eval, \
33 | es, eval
34 |
35 | cfguh().add_code(osp.basename(__file__))
36 |
37 | def common_argparse(args):
38 |     # build the common part of the config from the parsed
39 |     # command-line arguments; args must be passed in explicitly
40 | cfg = edict()
41 | cfg.DEBUG = args.debug
42 | cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(args.port)
43 | is_eval = args.eval
44 | pth = args.pth
45 | return cfg, is_eval, pth
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser()
49 | parser.add_argument('--debug' , action='store_true', default=False)
50 | parser.add_argument('--hrnet' , action='store_true', default=False)
51 | parser.add_argument('--eval' , action='store_true', default=False)
52 | parser.add_argument('--pth' , type=str)
53 | parser.add_argument('--gpu' , nargs='+', type=int)
54 | parser.add_argument('--port' , type=int, default=11233)
55 | parser.add_argument('--dsname', type=str, default='textseg')
56 | parser.add_argument('--trainwithcls', action='store_true', default=False)
57 | args = parser.parse_args()
58 |
59 | istrain = not args.eval
60 |
61 | if istrain:
62 | cfg = copy.deepcopy(cfg_train)
63 | else:
64 | cfg = copy.deepcopy(cfg_test)
65 |
66 | if istrain:
67 | cfg.EXPERIMENT_ID = get_experiment_id()
68 | else:
69 | cfg.EXPERIMENT_ID = None
70 |
71 | if args.dsname == "textseg":
72 | cfg_data = cfg_textseg
73 | elif args.dsname == "cocots":
74 | cfg_data = cfg_cocots
75 | elif args.dsname == "mlt":
76 | cfg_data = cfg_mlt
77 | elif args.dsname == "icdar13":
78 | cfg_data = cfg_icdar13
79 | elif args.dsname == "totaltext":
80 | cfg_data = cfg_totaltext
81 | else:
82 | raise ValueError
83 |
84 | cfg.DEBUG = args.debug
85 | cfg.DIST_URL = 'tcp://127.0.0.1:{}'.format(args.port)
86 | if args.gpu is None:
87 | cfg.GPU_DEVICE = 'all'
88 | else:
89 | cfg.GPU_DEVICE = args.gpu
90 |
91 | cfg.MODEL = copy.deepcopy(cfg_mdel)
92 | cfg.DATA = copy.deepcopy(cfg_data)
93 |
94 | if istrain:
95 | cfg = set_cfg_train(cfg, dsname=args.dsname)
96 | if args.hrnet:
97 | cfg = set_cfg_hrnetw48_train(cfg)
98 | else:
99 | cfg = set_cfg_eval(cfg, dsname=args.dsname)
100 | if args.hrnet:
101 | cfg = set_cfg_hrnetw48_eval(cfg)
102 | cfg.MODEL.TEXRNET.PRETRAINED_PTH = args.pth
103 |
104 | if istrain:
105 | if args.dsname == "textseg":
106 | cfg.DATA.DATASET_MODE = 'train+val'
107 | elif args.dsname == "cocots":
108 | cfg.DATA.DATASET_MODE = 'train'
109 | elif args.dsname == "mlt":
110 | cfg.DATA.DATASET_MODE = 'trainseg'
111 | elif args.dsname == "icdar13":
112 | cfg.DATA.DATASET_MODE = 'train_fst'
113 | elif args.dsname == "totaltext":
114 | cfg.DATA.DATASET_MODE = 'train'
115 | else:
116 | raise ValueError
117 | else:
118 | if args.dsname == "textseg":
119 | cfg.DATA.DATASET_MODE = 'test'
120 | elif args.dsname == "cocots":
121 | cfg.DATA.DATASET_MODE = 'val'
122 | elif args.dsname == "mlt":
123 | cfg.DATA.DATASET_MODE = 'valseg'
124 | elif args.dsname == "icdar13":
125 | cfg.DATA.DATASET_MODE = 'test_fst'
126 | elif args.dsname == "totaltext":
127 | cfg.DATA.DATASET_MODE = 'test'
128 | else:
129 | raise ValueError
130 |
131 | if istrain:
132 | if args.trainwithcls:
133 | if args.dsname == 'textseg':
134 | cfg.DATA.LOADER_PIPELINE = [
135 | 'NumpyImageLoader',
136 | 'TextSeg_SeglabelLoader',
137 | 'CharBboxSpLoader',]
138 | cfg.DATA.RANDOM_RESIZE_CROP_SIZE = [32, 32]
139 | cfg.DATA.RANDOM_RESIZE_CROP_SCALE = [0.8, 1.2]
140 | cfg.DATA.RANDOM_RESIZE_CROP_RATIO = [3/4, 4/3]
141 | cfg.DATA.TRANS_PIPELINE = [
142 | 'UniformNumpyType',
143 | 'TextSeg_RandomResizeCropCharBbox',
144 | 'NormalizeUint8ToZeroOne',
145 | 'Normalize',
146 | 'RandomScaleOneSide',
147 | 'RandomCrop',
148 | ]
149 | elif args.dsname == 'icdar13':
150 | cfg.DATA.LOADER_PIPELINE = [
151 | 'NumpyImageLoader',
152 | 'SeglabelLoader',
153 | 'CharBboxSpLoader',]
154 | cfg.DATA.TRANS_PIPELINE = [
155 | 'UniformNumpyType',
156 | 'NormalizeUint8ToZeroOne',
157 | 'Normalize',
158 | 'RandomScaleOneSide',
159 | 'RandomCrop',
160 | ]
161 | else:
162 | raise ValueError
163 | cfg.DATA.FORMATTER = 'SemChinsChbbxFormatter'
164 | cfg.DATA.LOADER_SQUARE_BBOX = True
165 | cfg.DATA.RANDOM_RESIZE_CROP_FROM = 'sem'
166 | cfg.MODEL.TEXRNET.INTRAIN_GETPRED_FROM = 'sem'
167 | # the one with 93.98% and trained on semantic crops
168 | cfg.TRAIN.CLASSIFIER_PATH = osp.join(
169 | 'pretrained', 'init', 'resnet50_textcls.pth',
170 | )
171 | cfg.TRAIN.ROI_BBOX_PADDING_TYPE = 'semcrop'
172 | cfg.TRAIN.ROI_ALIGN_SIZE = [32, 32]
173 | cfg.TRAIN.UPDATE_CLASSIFIER = False
174 | cfg.TRAIN.ACTIVATE_CLASSIFIER_FOR_SEGMODEL_AFTER = 0
175 | cfg.TRAIN.LOSS_WEIGHT = {
176 | 'losssem' : 1,
177 | 'lossrfn' : 0.5,
178 | 'lossrfntri': 0.5,
179 | 'losscls' : 0.1,
180 | }
181 |
182 | if istrain:
183 | if args.hrnet:
184 | cfg.TRAIN.SIGNATURE = ['texrnet', 'hrnet']
185 | else:
186 | cfg.TRAIN.SIGNATURE = ['texrnet', 'deeplab']
187 | cfg.LOG_DIR = experiment_folder(cfg, isnew=True, sig=cfg.TRAIN.SIGNATURE)
188 | cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'train.log')
189 | else:
190 | cfg.LOG_DIR = osp.join(cfg.MISC_DIR, 'eval')
191 | cfg.LOG_FILE = osp.join(cfg.LOG_DIR, 'eval.log')
192 | cfg.TEST.SUB_DIR = None
193 |
194 | if cfg.DEBUG:
195 | cfg.EXPERIMENT_ID = 999999999999
196 | cfg.DATA.NUM_WORKERS_PER_GPU = 0
197 | cfg.TRAIN.BATCH_SIZE_PER_GPU = 2
198 |
199 | cfg = common_initiates(cfg)
200 |
201 | if istrain:
202 | if args.trainwithcls:
203 | exec_ts = ts_with_classifier()
204 | else:
205 | exec_ts = ts()
206 | trainer = train(cfg)
207 | trainer.register_stage(exec_ts)
208 |
209 | # trainer(0)
210 | mp.spawn(trainer,
211 | args=(),
212 | nprocs=cfg.GPU_COUNT,
213 | join=True)
214 | else:
215 | exec_es = es()
216 | tester = eval(cfg)
217 | tester.register_stage(exec_es)
218 |
219 | # tester(0)
220 | mp.spawn(tester,
221 | args=(),
222 | nprocs=cfg.GPU_COUNT,
223 | join=True)
224 |
--------------------------------------------------------------------------------
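Putting main.py together: it selects the train or eval base config, patches in the dataset and model configs chosen by the command-line flags, and spawns one worker per GPU with mp.spawn. With the flags defined above, a TextSeg training run with the HRNet backbone and the classifier-assisted losses would look like python main.py --dsname textseg --hrnet --trainwithcls --gpu 0 1, and evaluating a checkpoint (the .pth path here is hypothetical) like python main.py --eval --pth pretrained/texrnet_hrnet.pth --dsname textseg --hrnet --gpu 0.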
/requirement.txt:
--------------------------------------------------------------------------------
1 | torch==1.6
2 | torchvision==0.7
3 | matplotlib==3.3.2
4 | opencv-python==4.5.1.48
5 | easydict==1.9
6 |
--------------------------------------------------------------------------------
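The pins above target the PyTorch 1.6 era (Python 3.6-3.8 wheels); note the singular file name when installing: pip install -r requirement.txt. The vendored hrnet_code tree ships its own, separate requirements.txt.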