├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── test │ └── test_full_image.cfg └── train │ └── train_cityscapes.cfg ├── data_prep └── get_cityscapes_list.py ├── setup.py ├── test └── predict_full_image.py ├── train ├── __init__.py ├── solver.py └── train_model.py └── tusimple_duc ├── __init__.py ├── core ├── __init__.py ├── cityscapes_labels.py ├── cityscapes_loader.py ├── lr_scheduler.py ├── metrics.py └── utils.py ├── networks ├── __init__.py ├── network_duc_hdc.py └── resnet.py └── test ├── __init__.py ├── predictor.py └── tester.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # Swap file 7 | *.swp 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | env/ 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask instance folder 60 | instance/ 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | 83 | # virtualenv 84 | venv/ 85 | ENV/ 86 | 87 | # Spyder project settings 88 | .spyderproject 89 | 90 | # Rope project settings 91 | .ropeproject 92 | 93 | # IDEA 94 | .idea/ 95 | 96 | # Model and results 97 | models 98 | models/* 99 | results 100 | results/* 101 | 102 | 103 | # Data 104 | data/* 105 | data 106 | 107 | # linux temp 108 | *~ 109 | 110 | # image and pdf 111 | *.pdf 112 | *.gv 113 | *.png 114 | 115 | # text results 116 | *.lst 117 | *.txt 118 | 119 | # big models 120 | */*.params 121 | */*.caffemodel 122 | 123 | *.params 124 | *.json 125 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "mxnet"] 2 | path = mxnet 3 | url = git@github.com:TuSimple/mxnet.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright {TuSimple} {name of copyright owner} 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TuSimple-DUC 2 | 3 | by Panqu Wang, Pengfei Chen, Ye Yuan, Ding Liu, Zehua Huang, Xiaodi Hou, and Garrison Cottrell. 4 | 5 | ## Introduction 6 | 7 | This repository is for [Understanding Convolution for Semantic Segmentation](https://arxiv.org/abs/1702.08502) (WACV 2018), which achieved state-of-the-art result on the CityScapes, PASCAL VOC 2012, and Kitti Road benchmark. 8 | 9 | ## Requirement 10 | 11 | We tested our code on: 12 | 13 | Ubuntu 16.04, Python 2.7 with 14 | 15 | [MXNet (0.11.0)](https://github.com/TuSimple/mxnet), numpy(1.13.1), cv2(3.2.0), PIL(4.2.1), and cython(0.25.2) 16 | 17 | ## Usage 18 | 19 | 1. Clone the repository: 20 | 21 | ```shell 22 | git clone git@github.com:TuSimple/TuSimple-DUC.git 23 | python setup.py develop --user 24 | ``` 25 | 26 | 2. Download the pretrained model from [Google Drive](https://drive.google.com/drive/folders/0B72xLTlRb0SoREhISlhibFZTRmM?resourcekey=0-g2Fr58Nn591bs5ZvZ0Vlwg&usp=sharing). 27 | 28 | 3. Build MXNet (only tested on the TuSimple version): 29 | 30 | ```shell 31 | git clone --recursive git@github.com:TuSimple/mxnet.git 32 | vim make/config.mk (we should have USE_CUDA = 1, modify USE_CUDA_PATH, and have USE_CUDNN = 1 to enable GPU usage.) 33 | make -j 34 | cd python 35 | python setup.py develop --user 36 | ``` 37 | 38 | For more MXNet tutorials, please refer to the [official documentation](https://mxnet.incubator.apache.org/install/index.html). 39 | 40 | 3. Training: 41 | 42 | ```shell 43 | cd train 44 | python train_model.py ../configs/train/train_cityscapes.cfg 45 | ``` 46 | 47 | The paths/dirs in the ``.cfg`` file need to be specified by the user. 48 | 49 | 4. Testing 50 | 51 | ``` 52 | cd test 53 | python predict_full_image.py ../configs/test/test_full_image.cfg 54 | ``` 55 | 56 | The paths/dirs in the ``.cfg`` file need to be specified by the user. 57 | 58 | 5. Results: 59 | 60 | Modify the ``result_dir`` path in the config file to save the label map and visualizations. The expected scores are: 61 | 62 | (single scale testing denotes as 'ss' and multiple scale testing denotes as 'ms') 63 | 64 | - ResNet101-DUC-HDC on CityScapes testset (mIoU): 79.1(ss) / 80.1(ms) 65 | - ResNet152-DUC on VOC2012 (mIoU): 83.1(ss) 66 | 67 | ## Citation 68 | 69 | If you find the repository is useful for your research, please consider citing: 70 | 71 | @article{wang2017understanding, 72 | title={Understanding convolution for semantic segmentation}, 73 | author={Wang, Panqu and Chen, Pengfei and Yuan, Ye and Liu, Ding and Huang, Zehua and Hou, Xiaodi and Cottrell, Garrison}, 74 | journal={arXiv preprint arXiv:1702.08502}, 75 | year={2017} 76 | } 77 | 78 | ## Questions 79 | 80 | Please contact panqu.wang@tusimple.ai or pengfei.chen@tusimple.ai . 81 | -------------------------------------------------------------------------------- /configs/test/test_full_image.cfg: -------------------------------------------------------------------------------- 1 | [model] 2 | model_dir=../models 3 | model_prefix=ResNet_DUC_HDC_CityScapes 4 | model_epoch=20 5 | gpu=0 6 | result_dir=../results/ 7 | label_num=19 8 | multi_scales=False 9 | 10 | [data] 11 | image_list=../data/cityscapes/imagesets/cityscapes_fine/val.lst 12 | test_img_dir=../data/cityscapes 13 | gt_dir=../data/cityscapes 14 | ds_rate=8 15 | cell_width=2 16 | test_shape=1024,2048 17 | result_shape=1024,2048 18 | rgb_mean=122.675, 116.669, 104.008 19 | test_scales=1 20 | -------------------------------------------------------------------------------- /configs/train/train_cityscapes.cfg: -------------------------------------------------------------------------------- 1 | [env] 2 | use_cpu=False 3 | gpus=0 4 | kv_store=local 5 | multi_thread=False 6 | 7 | [network] 8 | label_num=19 9 | aspp=4 10 | aspp_stride=6 11 | cell_width=2 12 | ignore_label=255 13 | bn_use_global_stats=True 14 | 15 | [model] 16 | num_epochs=50 17 | model_dir=../models 18 | save_model_prefix=ResNet_DUC_HDC_CityScapes 19 | checkpoint_interval=1 20 | lr=1e-4 21 | lr_policy=step 22 | lr_factor=0.5 23 | lr_factor_epoch=5 24 | momentum=0.9 25 | weight_decay=0.00005 26 | load_model_dir=../models 27 | load_model_prefix=ResNet_DUC_HDC_CityScapes 28 | load_epoch=20 29 | eval_metric=acc_ignore, IoU 30 | 31 | [data] 32 | data_dir=../data/cityscapes/ 33 | label_dir=../data/cityscapes/ 34 | train_list=../data/cityscapes/imagesets/cityscapes_fine/train_bigger_patch.lst 35 | use_val=False 36 | val_list=None 37 | rgb_mean=122.675, 116.669, 104.008 38 | batch_size=1 39 | ds_rate=8 40 | convert_label=True 41 | scale_factors=0.5, 0.75, 1.0, 1.25, 1.5 42 | crop_shape=800,800 43 | use_mirror=True 44 | use_random_crop=True 45 | random_bound=120,120 46 | 47 | [misc] 48 | draw_network=False 49 | -------------------------------------------------------------------------------- /data_prep/get_cityscapes_list.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | 4 | 5 | def get_cityscapes_list_augmented(root, image_path, label_path, lst_path, is_fine=True, sample_rate=1): 6 | index = 0 7 | train_lst = [] 8 | label_prefix = 'gtFine_labelIds' if is_fine else 'gtCoarse_labelIds' 9 | 10 | # images 11 | all_images = glob.glob(os.path.join(root, image_path, '*/*.png')) 12 | all_images.sort() 13 | for p in all_images: 14 | l = p.replace(image_path, label_path).replace('leftImg8bit', label_prefix) 15 | if os.path.isfile(l): 16 | index += 1 17 | if index % 100 == 0: 18 | print "%d out of %d done." % (index, len(all_images)) 19 | if index % sample_rate != 0: 20 | continue 21 | for i in range(1, 8): 22 | train_lst.append([str(index), p, l, "512", str(256 * i)]) 23 | else: 24 | print "dismiss %s" % (p) 25 | 26 | train_out = open(lst_path, "w") 27 | for line in train_lst: 28 | print >> train_out, '\t'.join(line) 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup( 5 | name='tusimple_duc', 6 | version='1.0.0', 7 | author='Pengfei Chen & Panqu Wang', 8 | description='semantic segmentation module on the Cityscapes dataset', 9 | install_requires=['configparser', 'numpy', 'Pillow'], 10 | url='https://github.com/TuSimple/TuSimple-DUC', 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /test/predict_full_image.py: -------------------------------------------------------------------------------- 1 | import ConfigParser 2 | import os 3 | import sys 4 | import time 5 | from PIL import Image 6 | 7 | import cv2 as cv 8 | import numpy as np 9 | 10 | from tusimple_duc.test.tester import Tester 11 | 12 | 13 | class ImageListTester: 14 | def __init__(self, config): 15 | self.config = config 16 | # # model 17 | self.model_dir = config.get('model', 'model_dir') 18 | self.model_prefix = config.get('model', 'model_prefix') 19 | self.model_epoch = config.getint('model', 'model_epoch') 20 | self.result_dir = config.get('model', 'result_dir') 21 | if not os.path.isdir(self.result_dir): 22 | os.mkdir(self.result_dir) 23 | if not os.path.isdir(os.path.join(self.result_dir, 'visualization')): 24 | os.mkdir(os.path.join(self.result_dir, 'visualization')) 25 | if not os.path.isdir(os.path.join(self.result_dir, 'score')): 26 | os.mkdir(os.path.join(self.result_dir, 'score')) 27 | 28 | # data 29 | self.image_list = config.get('data', 'image_list') 30 | self.test_img_dir = config.get('data', 'test_img_dir') 31 | self.result_shape = [int(f) for f in config.get('data', 'result_shape').split(',')] 32 | # initialize tester 33 | self.tester = Tester(self.config) 34 | 35 | def predict_single(self, item): 36 | # img_name = item.strip().replace('/', '_') 37 | img_name = item.strip().split('/')[-1] 38 | img_path = os.path.join(self.test_img_dir, item.strip().split('\t')[1]) 39 | 40 | # read image as rgb 41 | im = cv.imread(img_path)[:, :, ::-1] 42 | result_width = self.result_shape[1] 43 | result_height = self.result_shape[0] 44 | 45 | concat_img = Image.new('RGB', (result_width * 2, result_height * 2)) 46 | 47 | results = self.tester.predict_single( 48 | img=im, 49 | ret_converted=True, 50 | ret_heat_map=True, 51 | ret_softmax=True) 52 | 53 | # label 54 | heat_map = results['heat_map'] 55 | cvt_labels = results['converted'] 56 | raw_labels = results['raw'] 57 | softmax = results['softmax'] 58 | 59 | confidence = float(np.max(softmax, axis=0).mean()) 60 | 61 | result_img = Image.fromarray(self.tester.colorize(raw_labels)).resize(self.result_shape[::-1]) 62 | 63 | # paste raw image 64 | concat_img.paste(Image.fromarray(im).convert('RGB'), (0, 0)) 65 | # paste color result 66 | concat_img.paste(result_img, (0, result_height)) 67 | # paste blended result 68 | concat_img.paste(Image.fromarray(cv.addWeighted(im[:, :, ::-1], 0.5, np.array(result_img), 0.5, 0)), 69 | (result_width, 0)) 70 | # paste heat map 71 | concat_img.paste(Image.fromarray(heat_map[:, :, [2, 1, 0]]).resize(self.result_shape[::-1]), 72 | (result_width, result_height)) 73 | concat_img.save(os.path.join(self.result_dir, 'visualization', img_name.replace('jpg', 'png'))) 74 | 75 | # save results for score 76 | cv.imwrite(os.path.join(self.result_dir, 'score', img_name.replace('jpg', 'png')), cvt_labels) 77 | return confidence, img_path 78 | 79 | def predict_all(self): 80 | img_list = [line for line in open(self.image_list, 'r')] 81 | idx = 0 82 | conf_lst = [] 83 | for item in img_list[:]: 84 | idx += 1 85 | start_time = time.time() 86 | conf_lst.append(self.predict_single(item)) 87 | print 'Process %d out of %d image ... %s, time cost:%.3f, confidence:%.3f' % \ 88 | (idx, len(img_list), item.strip().split('/')[-1], time.time() - start_time, conf_lst[-1][0]) 89 | conf_file = open(os.path.join(self.result_dir, self.model_prefix + str(self.model_epoch) + '.txt'), 'w') 90 | conf_lst.sort() 91 | for item in conf_lst: 92 | 93 | print >> conf_file, "{}\t{}".format(item[1], item[0]) 94 | 95 | 96 | if __name__ == '__main__': 97 | config_path = sys.argv[1] 98 | config = ConfigParser.RawConfigParser() 99 | config.read(config_path) 100 | tester = ImageListTester(config) 101 | tester.predict_all() 102 | -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuSimple/TuSimple-DUC/f3b05b21cf3252baf89c7b88667f55cc74cd5418/train/__init__.py -------------------------------------------------------------------------------- /train/solver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from datetime import datetime 4 | 5 | import mxnet as mx 6 | from mxnet import metric 7 | from tusimple_duc.core import utils 8 | from tusimple_duc.core import lr_scheduler 9 | from tusimple_duc.core.cityscapes_loader import CityLoader 10 | from tusimple_duc.core.metrics import CompositeEvalMetric, AccWithIgnoreMetric, IoUMetric, SoftmaxLoss 11 | from tusimple_duc.networks.network_duc_hdc import get_symbol_duc_hdc 12 | 13 | 14 | class Solver(object): 15 | def __init__(self, config): 16 | try: 17 | self.config = config 18 | # environment 19 | self.use_cpu = config.getboolean('env', 'use_cpu') 20 | self.gpus = config.get('env', 'gpus') 21 | self.kv_store = config.get('env', 'kv_store') 22 | self.ctx = mx.cpu() if self.use_cpu is True else [ 23 | mx.gpu(int(i)) for i in self.gpus.split(',')] 24 | self.multi_thread = config.getboolean('env', 'multi_thread') 25 | 26 | # network parameter 27 | self.label_num = config.getint('network', 'label_num') 28 | self.aspp = config.getint('network', 'aspp') 29 | self.aspp_stride = config.getint('network', 'aspp_stride') 30 | self.cell_width = config.getint('network', 'cell_width') 31 | self.ignore_label = config.getint('network', 'ignore_label') 32 | self.bn_use_global_stats = config.getboolean('network', 'bn_use_global_stats') 33 | 34 | # model 35 | self.num_epochs = config.getint('model', 'num_epochs') 36 | self.model_dir = config.get('model', 'model_dir') 37 | self.save_model_prefix = config.get('model', 'save_model_prefix') 38 | self.checkpoint_interval = config.getint('model', 'checkpoint_interval') 39 | # SGD parameters 40 | self.optimizer = 'sgd' 41 | self.lr = config.getfloat('model', 'lr') 42 | self.lr_policy = config.get('model', 'lr_policy') 43 | self.lr_factor = config.getfloat('model', 'lr_factor') 44 | self.lr_factor_epoch = config.getfloat('model', 'lr_factor_epoch') 45 | self.momentum = config.getfloat('model', 'momentum') 46 | self.weight_decay = config.getfloat('model', 'weight_decay') 47 | # fine tuning 48 | self.load_model_dir = config.get('model', 'load_model_dir') 49 | self.load_model_prefix = config.get('model', 'load_model_prefix') 50 | self.load_epoch = config.getint('model', 'load_epoch') 51 | # evaluation metric 52 | self.eval_metric = [m.strip() for m in config.get('model', 'eval_metric').split(',')] 53 | 54 | # data 55 | self.data_dir = config.get('data', 'data_dir') 56 | self.label_dir = config.get('data', 'label_dir') 57 | self.train_list = config.get('data', 'train_list') 58 | self.use_val = config.getboolean('data', 'use_val') 59 | if self.use_val: 60 | self.val_list = config.get('data', 'val_list') 61 | self.rgb_mean = tuple([float(color.strip()) for color in config.get('data', 'rgb_mean').split(',')]) 62 | self.batch_size = config.getint('data', 'batch_size') 63 | self.ds_rate = config.getint('data', 'ds_rate') 64 | self.convert_label = config.getboolean('data', 'convert_label') 65 | self.scale_factors = [float(scale.strip()) for scale in config.get('data', 'scale_factors').split(',')] 66 | self.crop_shape = tuple([int(l.strip()) for l in config.get('data', 'crop_shape').split(',')]) 67 | self.use_mirror = config.getboolean('data', 'use_mirror') 68 | self.use_random_crop = config.getboolean('data', 'use_random_crop') 69 | self.random_bound = tuple([int(l.strip()) for l in config.get('data', 'random_bound').split(',')]) 70 | 71 | # miscellaneous 72 | self.draw_network = config.getboolean('misc', 'draw_network') 73 | 74 | # inference 75 | self.train_size = 0 76 | with open(self.train_list, 'r') as f: 77 | for _ in f: 78 | self.train_size += 1 79 | self.epoch_size = self.train_size / self.batch_size 80 | self.data_shape = [tuple(list([self.batch_size, 3, self.crop_shape[0], self.crop_shape[1]]))] 81 | self.label_shape = [tuple([self.batch_size, (self.crop_shape[1]*self.crop_shape[0]/self.cell_width**2)])] 82 | self.data_name = ['data'] 83 | self.label_name = ['seg_loss_label'] 84 | self.symbol = None 85 | self.arg_params = None 86 | self.aux_params = None 87 | 88 | except ValueError: 89 | logging.error('Config parameter error') 90 | 91 | def get_data_iterator(self): 92 | loader = CityLoader 93 | train_args = { 94 | 'data_path' : self.data_dir, 95 | 'label_path' : self.label_dir, 96 | 'rgb_mean' : self.rgb_mean, 97 | 'batch_size' : self.batch_size, 98 | 'scale_factors' : self.scale_factors, 99 | 'data_name' : self.data_name, 100 | 'label_name' : self.label_name, 101 | 'data_shape' : self.data_shape, 102 | 'label_shape' : self.label_shape, 103 | 'use_random_crop' : self.use_random_crop, 104 | 'use_mirror' : self.use_mirror, 105 | 'ds_rate' : self.ds_rate, 106 | 'convert_label' : self.convert_label, 107 | 'multi_thread' : self.multi_thread, 108 | 'cell_width' : self.cell_width, 109 | 'random_bound' : self.random_bound, 110 | } 111 | val_args = train_args.copy() 112 | val_args['scale_factors'] = [1] 113 | val_args['use_random_crop'] = False 114 | val_args['use_mirror'] = False 115 | train_dataloader = loader(self.train_list, train_args) 116 | if self.use_val: 117 | val_dataloader = loader(self.val_list, val_args) 118 | else: 119 | val_dataloader = None 120 | return train_dataloader, val_dataloader 121 | 122 | def get_symbol(self): 123 | self.symbol = get_symbol_duc_hdc( 124 | cell_cap=(self.ds_rate / self.cell_width) ** 2, 125 | label_num=self.label_num, 126 | ignore_label=self.ignore_label, 127 | bn_use_global_stats=self.bn_use_global_stats, 128 | aspp_num=self.aspp, 129 | aspp_stride=self.aspp_stride, 130 | ) 131 | 132 | # build up symbol, parameters and auxiliary parameters 133 | def get_model(self): 134 | self.get_symbol() 135 | 136 | # load model 137 | if self.load_model_prefix is not None and self.load_epoch > 0: 138 | self.symbol, self.arg_params, self.aux_params = \ 139 | mx.model.load_checkpoint(os.path.join(self.load_model_dir, self.load_model_prefix), self.load_epoch) 140 | 141 | def fit(self): 142 | # kvstore 143 | if self.kv_store is 'local' and ( 144 | self.gpus is None or len(self.gpus.split(',')) is 1): 145 | kv = None 146 | else: 147 | kv = mx.kvstore.create(self.kv_store) 148 | 149 | # setup module, including symbol, params and aux 150 | # get_model should always be called before get_data_iterator to ensure correct data loader 151 | self.get_model() 152 | 153 | # get dataloader 154 | train_data, eval_data = self.get_data_iterator() 155 | 156 | # evaluate metrics 157 | eval_metric_lst = [] 158 | if "acc" in self.eval_metric: 159 | eval_metric_lst.append(metric.create(self.eval_metric)) 160 | if "acc_ignore" in self.eval_metric and self.ignore_label is not None: 161 | eval_metric_lst.append(AccWithIgnoreMetric(self.ignore_label, name="acc_ignore")) 162 | if "IoU" in self.eval_metric and self.ignore_label is not None: 163 | eval_metric_lst.append(IoUMetric(self.ignore_label, label_num=self.label_num, name="IoU")) 164 | eval_metric_lst.append(SoftmaxLoss(self.ignore_label, label_num=self.label_num, name="SoftmaxLoss")) 165 | eval_metrics = CompositeEvalMetric(metrics=eval_metric_lst) 166 | 167 | optimizer_params = {} 168 | # optimizer 169 | # lr policy 170 | if self.lr_policy == 'step' and self.lr_factor < 1 and self.lr_factor_epoch > 0: 171 | optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler( 172 | step=max(int(self.epoch_size * self.lr_factor_epoch), 1), 173 | factor=self.lr_factor) 174 | elif self.lr_policy == 'poly': 175 | optimizer_params['lr_scheduler'] = lr_scheduler.PolyScheduler( 176 | origin_lr=self.lr, 177 | max_samples=max(int(self.epoch_size * self.num_epochs), 1), 178 | factor=self.lr_factor) 179 | else: 180 | logging.error('Unknown lr policy: %s' % self.lr_policy) 181 | optimizer_params['learning_rate'] = self.lr 182 | optimizer_params['momentum'] = self.momentum 183 | optimizer_params['wd'] = self.weight_decay 184 | optimizer_params['rescale_grad'] = 1.0 / self.batch_size 185 | optimizer_params['clip_gradient'] = 5 186 | 187 | # directory for saving models 188 | model_path = os.path.join(self.model_dir, self.save_model_prefix) 189 | if not os.path.isdir(model_path): 190 | os.mkdir(model_path) 191 | model_full_path = os.path.join(model_path, datetime.now().strftime('%Y_%m_%d_%H:%M:%S')) 192 | if not os.path.isdir(model_full_path): 193 | os.mkdir(model_full_path) 194 | checkpoint = utils.do_checkpoint(os.path.join(model_full_path, self.save_model_prefix), self.checkpoint_interval) 195 | with open(os.path.join(model_full_path, 196 | 'train_' + datetime.now().strftime('%Y_%m_%d_%H:%M:%S') + '.cfg'), 'w') as f: 197 | self.config.write(f) 198 | utils.save_symbol(self.symbol, os.path.join(model_full_path, self.save_model_prefix)) 199 | utils.save_log(self.save_model_prefix, model_full_path) 200 | 201 | # draw network 202 | if self.draw_network is True: 203 | utils.draw_network(self.symbol, os.path.join(model_full_path, self.save_model_prefix), self.data_shape[0]) 204 | 205 | # batch_end_callback 206 | batch_end_callback = list() 207 | batch_end_callback.append(utils.Speedometer(self.batch_size, 10)) 208 | 209 | module = mx.module.Module(self.symbol, context=self.ctx, data_names=self.data_name, label_names=self.label_name) 210 | 211 | # initialize (base_module now no more do this initialization) 212 | train_data.reset() 213 | module.fit( 214 | train_data=train_data, 215 | eval_data=eval_data, 216 | eval_metric=eval_metrics, 217 | epoch_end_callback=checkpoint, 218 | batch_end_callback=batch_end_callback, 219 | kvstore=kv, 220 | optimizer=self.optimizer, 221 | optimizer_params=optimizer_params, 222 | initializer=mx.init.Xavier(factor_type="in", magnitude=2.34), 223 | arg_params=self.arg_params, 224 | aux_params=self.aux_params, 225 | allow_missing=True, 226 | begin_epoch=self.load_epoch, 227 | num_epoch=self.num_epochs, 228 | ) 229 | -------------------------------------------------------------------------------- /train/train_model.py: -------------------------------------------------------------------------------- 1 | from solver import Solver 2 | import ConfigParser 3 | import sys 4 | 5 | 6 | def train_end2end(): 7 | config = ConfigParser.RawConfigParser() 8 | config_path = sys.argv[1] 9 | config.read(config_path) 10 | 11 | model = Solver(config) 12 | model.fit() 13 | 14 | if __name__ == '__main__': 15 | train_end2end() 16 | -------------------------------------------------------------------------------- /tusimple_duc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuSimple/TuSimple-DUC/f3b05b21cf3252baf89c7b88667f55cc74cd5418/tusimple_duc/__init__.py -------------------------------------------------------------------------------- /tusimple_duc/core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuSimple/TuSimple-DUC/f3b05b21cf3252baf89c7b88667f55cc74cd5418/tusimple_duc/core/__init__.py -------------------------------------------------------------------------------- /tusimple_duc/core/cityscapes_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Cityscapes labels 4 | # 5 | 6 | from collections import namedtuple 7 | 8 | 9 | #-------------------------------------------------------------------------------- 10 | # Definitions 11 | #-------------------------------------------------------------------------------- 12 | 13 | # a label and all meta information 14 | Label = namedtuple( 'Label' , [ 15 | 16 | 'name' , # The identifier of this label, e.g. 'car', 'person', ... . 17 | # We use them to uniquely name a class 18 | 19 | 'id' , # An integer ID that is associated with this label. 20 | # The IDs are used to represent the label in ground truth images 21 | # An ID of -1 means that this label does not have an ID and thus 22 | # is ignored when creating ground truth images (e.g. license plate). 23 | # Do not modify these IDs, since exactly these IDs are expected by the 24 | # evaluation server. 25 | 26 | 'trainId' , # Feel free to modify these IDs as suitable for your method. Then create 27 | # ground truth images with train IDs, using the tools provided in the 28 | # 'preparation' folder. However, make sure to validate or submit results 29 | # to our evaluation server using the regular IDs above! 30 | # For trainIds, multiple labels might have the same ID. Then, these labels 31 | # are mapped to the same class in the ground truth images. For the inverse 32 | # mapping, we use the label that is defined first in the list below. 33 | # For example, mapping all void-type classes to the same ID in training, 34 | # might make sense for some approaches. 35 | # Max value is 255! 36 | 37 | 'category' , # The name of the category that this label belongs to 38 | 39 | 'categoryId' , # The ID of this category. Used to create ground truth images 40 | # on category level. 41 | 42 | 'hasInstances', # Whether this label distinguishes between single instances or not 43 | 44 | 'ignoreInEval', # Whether pixels having this class as ground truth label are ignored 45 | # during evaluations or not 46 | 47 | 'color' , # The color of this label 48 | ] ) 49 | 50 | 51 | #-------------------------------------------------------------------------------- 52 | # A list of all labels 53 | #-------------------------------------------------------------------------------- 54 | 55 | # Please adapt the train IDs as appropriate for you approach. 56 | # Note that you might want to ignore labels with ID 255 during training. 57 | # Further note that the current train IDs are only a suggestion. You can use whatever you like. 58 | # Make sure to provide your results using the original IDs and not the training IDs. 59 | # Note that many IDs are ignored in evaluation and thus you never need to predict these! 60 | 61 | labels = [ 62 | # name id trainId category catId hasInstances ignoreInEval color 63 | Label( 'unlabeled' , 0 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 64 | Label( 'ego vehicle' , 1 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 65 | Label( 'rectification border' , 2 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 66 | Label( 'out of roi' , 3 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 67 | Label( 'static' , 4 , 255 , 'void' , 0 , False , True , ( 0, 0, 0) ), 68 | Label( 'dynamic' , 5 , 255 , 'void' , 0 , False , True , (111, 74, 0) ), 69 | Label( 'ground' , 6 , 255 , 'void' , 0 , False , True , ( 81, 0, 81) ), 70 | Label( 'road' , 7 , 0 , 'flat' , 1 , False , False , (128, 64,128) ), 71 | Label( 'sidewalk' , 8 , 1 , 'flat' , 1 , False , False , (244, 35,232) ), 72 | Label( 'parking' , 9 , 255 , 'flat' , 1 , False , True , (250,170,160) ), 73 | Label( 'rail track' , 10 , 255 , 'flat' , 1 , False , True , (230,150,140) ), 74 | Label( 'building' , 11 , 2 , 'construction' , 2 , False , False , ( 70, 70, 70) ), 75 | Label( 'wall' , 12 , 3 , 'construction' , 2 , False , False , (102,102,156) ), 76 | Label( 'fence' , 13 , 4 , 'construction' , 2 , False , False , (190,153,153) ), 77 | Label( 'guard rail' , 14 , 255 , 'construction' , 2 , False , True , (180,165,180) ), 78 | Label( 'bridge' , 15 , 255 , 'construction' , 2 , False , True , (150,100,100) ), 79 | Label( 'tunnel' , 16 , 255 , 'construction' , 2 , False , True , (150,120, 90) ), 80 | Label( 'pole' , 17 , 5 , 'object' , 3 , False , False , (153,153,153) ), 81 | Label( 'polegroup' , 18 , 255 , 'object' , 3 , False , True , (153,153,153) ), 82 | Label( 'traffic light' , 19 , 6 , 'object' , 3 , False , False , (250,170, 30) ), 83 | Label( 'traffic sign' , 20 , 7 , 'object' , 3 , False , False , (220,220, 0) ), 84 | Label( 'vegetation' , 21 , 8 , 'nature' , 4 , False , False , (107,142, 35) ), 85 | Label( 'terrain' , 22 , 9 , 'nature' , 4 , False , False , (152,251,152) ), 86 | Label( 'sky' , 23 , 10 , 'sky' , 5 , False , False , ( 70,130,180) ), 87 | Label( 'person' , 24 , 11 , 'human' , 6 , True , False , (220, 20, 60) ), 88 | Label( 'rider' , 25 , 12 , 'human' , 6 , True , False , (255, 0, 0) ), 89 | Label( 'car' , 26 , 13 , 'vehicle' , 7 , True , False , ( 0, 0,142) ), 90 | Label( 'truck' , 27 , 14 , 'vehicle' , 7 , True , False , ( 0, 0, 70) ), 91 | Label( 'bus' , 28 , 15 , 'vehicle' , 7 , True , False , ( 0, 60,100) ), 92 | Label( 'caravan' , 29 , 255 , 'vehicle' , 7 , True , True , ( 0, 0, 90) ), 93 | Label( 'trailer' , 30 , 255 , 'vehicle' , 7 , True , True , ( 0, 0,110) ), 94 | Label( 'train' , 31 , 16 , 'vehicle' , 7 , True , False , ( 0, 80,100) ), 95 | Label( 'motorcycle' , 32 , 17 , 'vehicle' , 7 , True , False , ( 0, 0,230) ), 96 | Label( 'bicycle' , 33 , 18 , 'vehicle' , 7 , True , False , (119, 11, 32) ), 97 | Label( 'license plate' , -1 , -1 , 'vehicle' , 7 , False , True , ( 0, 0,142) ), 98 | ] 99 | 100 | 101 | #-------------------------------------------------------------------------------- 102 | # Create dictionaries for a fast lookup 103 | #-------------------------------------------------------------------------------- 104 | 105 | # Please refer to the main method below for example usages! 106 | 107 | # name to label object 108 | name2label = { label.name : label for label in labels } 109 | # id to label object 110 | id2label = { label.id : label for label in labels } 111 | # trainId to label object 112 | trainId2label = { label.trainId : label for label in reversed(labels) } 113 | # category to list of label objects 114 | category2labels = {} 115 | for label in labels: 116 | category = label.category 117 | if category in category2labels: 118 | category2labels[category].append(label) 119 | else: 120 | category2labels[category] = [label] 121 | 122 | #-------------------------------------------------------------------------------- 123 | # Assure single instance name 124 | #-------------------------------------------------------------------------------- 125 | 126 | # returns the label name that describes a single instance (if possible) 127 | # e.g. input | output 128 | # ---------------------- 129 | # car | car 130 | # cargroup | car 131 | # foo | None 132 | # foogroup | None 133 | # skygroup | None 134 | def assureSingleInstanceName( name ): 135 | # if the name is known, it is not a group 136 | if name in name2label: 137 | return name 138 | # test if the name actually denotes a group 139 | if not name.endswith("group"): 140 | return None 141 | # remove group 142 | name = name[:-len("group")] 143 | # test if the new name exists 144 | if not name in name2label: 145 | return None 146 | # test if the new name denotes a label that actually has instances 147 | if not name2label[name].hasInstances: 148 | return None 149 | # all good then 150 | return name 151 | 152 | #-------------------------------------------------------------------------------- 153 | # Main for testing 154 | #-------------------------------------------------------------------------------- 155 | 156 | # just a dummy main 157 | if __name__ == "__main__": 158 | # Print all the labels 159 | print("List of cityscapes labels:") 160 | print("") 161 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( 'name', 'id', 'trainId', 'category', 'categoryId', 'hasInstances', 'ignoreInEval' )) 162 | print(" " + ('-' * 98)) 163 | for label in labels: 164 | print(" {:>21} | {:>3} | {:>7} | {:>14} | {:>10} | {:>12} | {:>12}".format( label.name, label.id, label.trainId, label.category, label.categoryId, label.hasInstances, label.ignoreInEval )) 165 | print("") 166 | 167 | print("Example usages:") 168 | 169 | # Map from name to label 170 | name = 'car' 171 | id = name2label[name].id 172 | print("ID of label '{name}': {id}".format( name=name, id=id )) 173 | 174 | # Map from ID to label 175 | category = id2label[id].category 176 | print("Category of label with ID '{id}': {category}".format( id=id, category=category )) 177 | 178 | # Map from trainID to label 179 | trainId = 0 180 | name = trainId2label[trainId].name 181 | print("Name of label with trainID '{id}': {name}".format( id=trainId, name=name )) -------------------------------------------------------------------------------- /tusimple_duc/core/cityscapes_loader.py: -------------------------------------------------------------------------------- 1 | import Queue 2 | import atexit 3 | import logging 4 | import multiprocessing as mp 5 | import random 6 | 7 | import mxnet as mx 8 | import numpy as np 9 | 10 | import utils 11 | 12 | 13 | class CityLoader(mx.io.DataIter): 14 | """ 15 | Data Loader for Cityscapes Dataset 16 | """ 17 | def __init__(self, data_list, input_args): 18 | super(CityLoader, self).__init__() 19 | self.input_args = input_args 20 | self.data_list = data_list 21 | self.data = CityLoader.read_data(self.data_list) 22 | self.data_path = input_args.get('data_path', '') 23 | self.data_shape = input_args.get('data_shape') 24 | self.label_shape = input_args.get('label_shape') 25 | self.multi_thread = input_args.get('multi_thread', False) 26 | self.n_thread = input_args.get('n_thread', 7) 27 | self.data_name = input_args.get('data_name', ['data']) 28 | self.label_name = input_args.get('label_name', ['seg_loss_label']) 29 | self.data_loader = input_args.get('data_loader') 30 | self.stop_word = input_args.get('stop_word', '==STOP--') 31 | self.batch_size = input_args.pop('batch_size', 4) 32 | self.current_batch = None 33 | self.data_num = None 34 | self.current = None 35 | self.worker_proc = None 36 | 37 | if self.multi_thread: 38 | self.stop_flag = mp.Value('b', False) 39 | self.result_queue = mp.Queue(maxsize=self.batch_size*3) 40 | self.data_queue = mp.Queue() 41 | 42 | @staticmethod 43 | def read_data(data_list): 44 | data = [] 45 | with open(data_list, 'r') as f: 46 | for line in f: 47 | frags = line.strip().split('\t') 48 | item = list() 49 | item.append(frags[1]) # item[0] is image path 50 | item.append(frags[2]) # item[1] is label path 51 | if len(frags) > 3: 52 | item.append(frags[3:]) # item[2] is parameters for cropping 53 | data.append(item) 54 | return data 55 | 56 | def _insert_queue(self): 57 | for item in self.data: 58 | self.data_queue.put(item) 59 | [self.data_queue.put(self.stop_word) for pid in range(self.n_thread)] 60 | 61 | def _thread_start(self): 62 | self.stop_flag = False 63 | self.worker_proc = [mp.Process(target=CityLoader._worker, 64 | args=[pid, 65 | self.data_queue, 66 | self.result_queue, 67 | self.input_args, 68 | self.stop_word, 69 | self.stop_flag]) 70 | for pid in range(self.n_thread)] 71 | [item.start() for item in self.worker_proc] 72 | 73 | def cleanup(): 74 | self.shutdown() 75 | atexit.register(cleanup) 76 | 77 | @staticmethod 78 | def _worker(worker_id, data_queue, result_queue, input_args, stop_word, stop_flag): 79 | count = 0 80 | for item in iter(data_queue.get, stop_word): 81 | if stop_flag == 1: 82 | break 83 | image, label = CityLoader._get_single(item, input_args) 84 | result_queue.put((image, label)) 85 | count += 1 86 | 87 | @property 88 | def provide_label(self): 89 | return [(self.label_name[i], self.label_shape[i]) for i in range(len(self.label_name))] 90 | 91 | @property 92 | def provide_data(self): 93 | return [(self.data_name[i], self.data_shape[i]) for i in range(len(self.data_name))] 94 | 95 | def reset(self): 96 | self.data_num = len(self.data) 97 | self.current = 0 98 | self.shuffle() 99 | if self.multi_thread: 100 | self.shutdown() 101 | self._insert_queue() 102 | self._thread_start() 103 | 104 | def get_batch_size(self): 105 | return self.batch_size 106 | 107 | def shutdown(self): 108 | if self.multi_thread: 109 | # clean queue 110 | while True: 111 | try: 112 | self.result_queue.get(timeout=1) 113 | except Queue.Empty: 114 | break 115 | while True: 116 | try: 117 | self.data_queue.get(timeout=1) 118 | except Queue.Empty: 119 | break 120 | # stop worker 121 | self.stop_flag = True 122 | if self.worker_proc: 123 | for i, worker in enumerate(self.worker_proc): 124 | worker.join(timeout=1) 125 | if worker.is_alive(): 126 | logging.error('worker {} is join fail'.format(i)) 127 | worker.terminate() 128 | 129 | def shuffle(self): 130 | random.shuffle(self.data) 131 | 132 | def next(self): 133 | if self._get_next(): 134 | return self.current_batch 135 | else: 136 | raise StopIteration 137 | 138 | def _get_next(self): 139 | batch_size = self.batch_size 140 | if self.current + batch_size > self.data_num: 141 | return False 142 | xs = [np.zeros(ds) for ds in self.data_shape] 143 | ys = [np.zeros(ls) for ls in self.label_shape] 144 | cnt = 0 145 | for i in range(self.current, self.current + batch_size): 146 | if self.multi_thread: 147 | image, label = self.result_queue.get() 148 | else: 149 | image, label = CityLoader._get_single(self.data[i], self.input_args) 150 | for j in range(len(image)): 151 | xs[j][cnt, :, :, :] = image[j] 152 | for j in range(len(label)): 153 | ys[j][cnt, :] = label[j] 154 | cnt += 1 155 | xs = [mx.ndarray.array(x) for x in xs] 156 | ys = [mx.ndarray.array(y) for y in ys] 157 | self.current_batch = mx.io.DataBatch(data=xs, label=ys, pad=0, index=None) 158 | self.current += batch_size 159 | return True 160 | 161 | @staticmethod 162 | def _get_single(item, input_args): 163 | return utils.get_single_image_duc(item, input_args) 164 | -------------------------------------------------------------------------------- /tusimple_duc/core/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | """ 2 | poly learning rate scheduler, re-implement the 'poly' learning policy from caffe. 3 | learning rate decays as epochs grows 4 | """ 5 | from mxnet.lr_scheduler import LRScheduler 6 | import logging 7 | 8 | 9 | class PolyScheduler(LRScheduler): 10 | """Reduce learning rate in a power way 11 | Assume the weight has been updated by n times, then the learning rate will 12 | be 13 | base_lr * (floor(1-n/max_time))^factor 14 | Parameters 15 | ---------- 16 | origin_lr: int 17 | original learning rate 18 | max_samples: int 19 | schedule learning rate after n updates 20 | show_num: int 21 | show current learning rate after n updates 22 | factor: float 23 | the factor for reducing the learning rate 24 | """ 25 | def __init__(self, origin_lr, max_samples, show_num=10,factor=1, stop_factor_lr=1e-8): 26 | super(PolyScheduler, self).__init__() 27 | if max_samples < 1: 28 | raise ValueError("Schedule max time must be greater or equal than 1 round") 29 | if factor > 1.0: 30 | raise ValueError("Factor must be no more than 1 to make lr reduce") 31 | self.max_samples = max_samples 32 | self.factor = factor 33 | self.stop_factor_lr = stop_factor_lr 34 | self.count = 0 35 | self.origin_lr = origin_lr 36 | self.base_lr = origin_lr 37 | self.show_num = show_num 38 | 39 | def __call__(self, num_update): 40 | """ 41 | Call to schedule current learning rate 42 | Parameters 43 | ---------- 44 | num_update: int 45 | the maximal number of updates applied to a weight. 46 | """ 47 | if num_update > self.count: 48 | self.base_lr = self.origin_lr * pow((1 - 1.0*num_update/self.max_samples), self.factor) 49 | if self.base_lr < self.stop_factor_lr: 50 | self.base_lr = self.stop_factor_lr 51 | logging.info("Update[%d]: now learning rate arrived at %0.5e, will not " 52 | "change in the future", num_update, self.base_lr) 53 | elif num_update % self.show_num == 0: 54 | logging.info("Update[%d]: Change learning rate to %0.8e", 55 | num_update, self.base_lr) 56 | self.count = num_update 57 | return self.base_lr 58 | -------------------------------------------------------------------------------- /tusimple_duc/core/metrics.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import mxnet as mx 4 | from mxnet.metric import EvalMetric 5 | 6 | 7 | class CompositeEvalMetric(EvalMetric): 8 | """Manage multiple evaluation metrics.""" 9 | 10 | def __init__(self, **kwargs): 11 | super(CompositeEvalMetric, self).__init__('composite') 12 | try: 13 | self.metrics = kwargs['metrics'] 14 | except KeyError: 15 | self.metrics = [] 16 | 17 | def add(self, metric): 18 | self.metrics.append(metric) 19 | 20 | def get_metric(self, index): 21 | try: 22 | return self.metrics[index] 23 | except IndexError: 24 | return ValueError("Metric index {} is out of range 0 and {}".format( 25 | index, len(self.metrics))) 26 | 27 | def update(self, labels, preds): 28 | for metric in self.metrics: 29 | metric.update(labels, preds) 30 | 31 | def reset(self): 32 | try: 33 | for metric in self.metrics: 34 | metric.reset() 35 | except AttributeError: 36 | pass 37 | 38 | def get(self): 39 | names = [] 40 | results = [] 41 | for metric in self.metrics: 42 | result = metric.get() 43 | names.append(result[0]) 44 | results.append(result[1]) 45 | return names, results 46 | 47 | def print_log(self): 48 | names, results = self.get() 49 | logging.info('; '.join(['{}: {}'.format(name, val) for name, val in zip(names, results)])) 50 | 51 | 52 | def check_label_shapes(labels, preds, shape=0): 53 | if shape == 0: 54 | label_shape, pred_shape = len(labels), len(preds) 55 | else: 56 | label_shape, pred_shape = labels.shape, preds.shape 57 | 58 | if label_shape != pred_shape: 59 | raise ValueError("Shape of labels {} does not match shape of " 60 | "predictions {}".format(label_shape, pred_shape)) 61 | 62 | 63 | class AccWithIgnoreMetric(EvalMetric): 64 | def __init__(self, ignore_label, name='AccWithIgnore'): 65 | super(AccWithIgnoreMetric, self).__init__(name=name) 66 | self._ignore_label = ignore_label 67 | self._iter_size = 200 68 | self._nomin_buffer = [] 69 | self._denom_buffer = [] 70 | 71 | def update(self, labels, preds): 72 | check_label_shapes(labels, preds) 73 | for i in range(len(labels)): 74 | pred_label = mx.ndarray.argmax_channel(preds[i]).asnumpy().astype('int32') 75 | label = labels[i].asnumpy().astype('int32') 76 | 77 | check_label_shapes(label, pred_label) 78 | 79 | self.sum_metric += (pred_label.flat == label.flat).sum() 80 | self.num_inst += len(pred_label.flat) - (label.flat == self._ignore_label).sum() 81 | 82 | 83 | class IoUMetric(EvalMetric): 84 | def __init__(self, ignore_label, label_num, name='IoU'): 85 | self._ignore_label = ignore_label 86 | self._label_num = label_num 87 | super(IoUMetric, self).__init__(name=name) 88 | 89 | def reset(self): 90 | self._tp = [0.0] * self._label_num 91 | self._denom = [0.0] * self._label_num 92 | 93 | def update(self, labels, preds): 94 | check_label_shapes(labels, preds) 95 | for i in range(len(labels)): 96 | pred_label = mx.ndarray.argmax_channel(preds[i]).asnumpy().astype('int32') 97 | label = labels[i].asnumpy().astype('int32') 98 | 99 | check_label_shapes(label, pred_label) 100 | 101 | iou = 0 102 | eps = 1e-6 103 | # skip_label_num = 0 104 | for j in range(self._label_num): 105 | pred_cur = (pred_label.flat == j) 106 | gt_cur = (label.flat == j) 107 | tp = np.logical_and(pred_cur, gt_cur).sum() 108 | denom = np.logical_or(pred_cur, gt_cur).sum() - np.logical_and(pred_cur, label.flat == self._ignore_label).sum() 109 | assert tp <= denom 110 | self._tp[j] += tp 111 | self._denom[j] += denom 112 | iou += self._tp[j] / (self._denom[j] + eps) 113 | iou /= self._label_num 114 | self.sum_metric = iou 115 | self.num_inst = 1 116 | 117 | 118 | class SoftmaxLoss(EvalMetric): 119 | def __init__(self, ignore_label, label_num, name='OverallSoftmaxLoss'): 120 | super(SoftmaxLoss, self).__init__(name=name) 121 | self._ignore_label = ignore_label 122 | self._label_num = label_num 123 | 124 | def update(self, labels, preds): 125 | check_label_shapes(labels, preds) 126 | 127 | loss = 0.0 128 | cnt = 0.0 129 | eps = 1e-6 130 | for i in range(len(labels)): 131 | prediction = preds[i].asnumpy()[:] 132 | shape = prediction.shape 133 | if len(shape) == 4: 134 | shape = (shape[0], shape[1], shape[2]*shape[3]) 135 | prediction = prediction.reshape(shape) 136 | label = labels[i].asnumpy() 137 | soft_label = np.zeros(prediction.shape) 138 | for b in range(soft_label.shape[0]): 139 | for c in range(self._label_num): 140 | soft_label[b][c][label[b] == c] = 1.0 141 | 142 | loss += (-np.log(prediction[soft_label == 1] + eps)).sum() 143 | cnt += prediction[soft_label == 1].size 144 | self.sum_metric += loss 145 | self.num_inst += cnt 146 | 147 | -------------------------------------------------------------------------------- /tusimple_duc/core/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | import os 4 | import random 5 | import time 6 | from PIL import Image 7 | from datetime import datetime 8 | 9 | import cv2 as cv 10 | import mxnet as mx 11 | import numpy as np 12 | 13 | import cityscapes_labels 14 | 15 | 16 | # save symbol 17 | def save_symbol(net, net_prefix): 18 | net.save('%s-symbol.json' % net_prefix) 19 | 20 | 21 | # save parameters 22 | def save_parameter(net, net_prefix, data_shape): 23 | executor = net.simple_bind(mx.gpu(0), data=data_shape) 24 | arg_params = executor.arg_dict 25 | aux_params = executor.aux_dict 26 | 27 | save_dict = {('arg:%s' % k): v for k, v in arg_params.items()} 28 | save_dict.update({('aux:%s' % k): v for k, v in aux_params.items()}) 29 | param_name = '%s.params' % net_prefix 30 | mx.ndarray.save(param_name, save_dict) 31 | 32 | 33 | # save log 34 | def save_log(prefix, output_dir): 35 | fmt = '%(asctime)s %(message)s' 36 | date_fmt = '%m-%d %H:%M:%S' 37 | logging.basicConfig(level=logging.INFO, 38 | format=fmt, 39 | datefmt=date_fmt, 40 | filename=os.path.join(output_dir, 41 | prefix + '_' + datetime.now().strftime('%Y_%m_%d_%H:%M:%S') + '.log'), 42 | filemode='w') 43 | console = logging.StreamHandler() 44 | console.setLevel(logging.INFO) 45 | formatter = logging.Formatter(fmt=fmt, datefmt=date_fmt) 46 | console.setFormatter(formatter) 47 | logging.getLogger('').addHandler(console) 48 | 49 | 50 | # replace ids with train_ids 51 | def replace_city_labels(label_data): 52 | labels = cityscapes_labels.labels 53 | converted = np.ones(label_data.shape, dtype=np.float) * 255 54 | # id to trainId 55 | id2trainId = {label.id: label.trainId for label in labels} 56 | for id in id2trainId: 57 | trainId = id2trainId[id] 58 | converted[label_data == id] = trainId 59 | return converted 60 | 61 | 62 | # get the data of image and label for networks including a ye layer 63 | def get_single_image_duc(item, input_args): 64 | # parse options 65 | data_path = input_args.get('data_path') 66 | label_path = input_args.get('label_path', '') 67 | data_shape = input_args.get('data_shape') 68 | convert_label = input_args.get('convert_label', False) 69 | 70 | crop_sz = (data_shape[0][-1], data_shape[0][-2]) 71 | use_random_crop = input_args.get('use_random_crop', False) 72 | use_mirror = input_args.get('use_mirror', False) 73 | scale_factors = input_args.get('scale_factors', [1]) 74 | rgb_mean = input_args.get('rgb_mean', [128, 128, 128]) 75 | ignore_label = input_args.get('ignore_label', 255) 76 | stride = input_args.get('ds_rate', 8) 77 | cell_width = input_args.get('cell_width', 1) 78 | random_bound = input_args.get('random_bound') 79 | 80 | # read data, scale, and random crop 81 | im = cv.imread(os.path.join(data_path, item[0])) 82 | # change bgr to rgb 83 | im = im[:, :, [2, 1, 0]] 84 | 85 | im_size = (im.shape[0], im.shape[1]) 86 | scale_factor = random.choice(scale_factors) 87 | scaled_shape = (int(im_size[0]*scale_factor), int(im_size[1]*scale_factor)) 88 | random_bound = (int(random_bound[0]*scale_factor), int(random_bound[1]*scale_factor)) 89 | crop_coor = [int(int(c) * scale_factor) for c in item[-1]] 90 | 91 | if use_random_crop: 92 | x0 = crop_coor[0] + random.randint(-random_bound[0], random_bound[0]) - crop_sz[0] / 2 93 | y0 = crop_coor[1] + random.randint(-random_bound[1], random_bound[1]) - crop_sz[1] / 2 94 | else: 95 | # center crop 96 | x0 = crop_coor[0] - crop_sz[0] / 2 97 | y0 = crop_coor[1] - crop_sz[1] / 2 98 | x1 = x0 + crop_sz[0] 99 | y1 = y0 + crop_sz[1] 100 | 101 | # resize 102 | scaled_img = cv.resize(im, (scaled_shape[1], scaled_shape[0]), interpolation=cv.INTER_LINEAR) 103 | 104 | # crop and make boarder 105 | pad_w_left = max(0 - y0, 0) 106 | pad_w_right = max(y1 - scaled_shape[1], 0) 107 | pad_h_up = max(0 - x0, 0) 108 | pad_h_bottom = max(x1 - scaled_shape[0], 0) 109 | 110 | x0 += pad_h_up 111 | x1 += pad_h_up 112 | y0 += pad_w_left 113 | y1 += pad_w_left 114 | 115 | img_data = np.array(scaled_img, dtype=np.float) 116 | img_data = cv.copyMakeBorder(img_data, pad_h_up, pad_h_bottom, pad_w_left, pad_w_right, cv.BORDER_CONSTANT, 117 | value=list(rgb_mean)) 118 | img_data = img_data[x0:x1, y0:y1, :] 119 | img_data = np.transpose(img_data, (2, 0, 1)) 120 | 121 | # subtract rgb mean 122 | for i in range(3): 123 | img_data[i] -= rgb_mean[i] 124 | 125 | # read label 126 | img_label = np.array(Image.open(os.path.join(label_path, item[1]))) 127 | img_label = cv.resize(img_label, (scaled_shape[1], scaled_shape[0]), interpolation=cv.INTER_NEAREST) 128 | img_label = np.array(img_label, dtype=np.float) 129 | img_label = cv.copyMakeBorder(img_label, pad_h_up, pad_h_bottom, pad_w_left, pad_w_right, cv.BORDER_CONSTANT, 130 | value=ignore_label) 131 | img_label = img_label[x0:x1, y0:y1] 132 | 133 | # resize label according to down sample rate 134 | if cell_width > 1: 135 | img_label = cv.resize(img_label, (crop_sz[1] / cell_width, crop_sz[0] / cell_width), 136 | interpolation=cv.INTER_NEAREST) 137 | 138 | # use mirror 139 | if use_mirror and random.randint(0, 1) == 1: 140 | img_data = img_data[:, :, ::-1] 141 | img_label = img_label[:, ::-1] 142 | 143 | # convert label from label id to train id 144 | if convert_label: 145 | img_label = replace_city_labels(img_label) 146 | 147 | feat_height = int(math.ceil(float(crop_sz[0]) / stride)) 148 | feat_width = int(math.ceil(float(crop_sz[1]) / stride)) 149 | 150 | img_label = img_label.reshape((feat_height, stride / cell_width, feat_width, stride / cell_width)) 151 | img_label = np.transpose(img_label, (1, 3, 0, 2)) 152 | img_label = img_label.reshape((-1, feat_height, feat_width)) 153 | img_label = img_label.reshape(-1) 154 | return [img_data], [img_label] 155 | 156 | 157 | # get palette for coloring 158 | def get_palette(): 159 | # get palette 160 | trainId2colors = {label.trainId: label.color for label in cityscapes_labels.labels} 161 | palette = [0] * 256 * 3 162 | for trainId in trainId2colors: 163 | colors = trainId2colors[trainId] 164 | if trainId == 255: 165 | colors = (0, 0, 0) 166 | for i in range(3): 167 | palette[trainId * 3 + i] = colors[i] 168 | return palette 169 | 170 | 171 | # check point 172 | def do_checkpoint(prefix, interval): 173 | def _callback(iter_no, sym, arg, aux): 174 | if (iter_no + 1) % interval == 0: 175 | mx.model.save_checkpoint(prefix, iter_no + 1, sym, arg, aux) 176 | return _callback 177 | 178 | 179 | # speed calculator 180 | class Speedometer(object): 181 | def __init__(self, batch_size, frequent=50): 182 | self.batch_size = batch_size 183 | self.frequent = frequent 184 | self.tic = time.time() 185 | self.last_count = 0 186 | 187 | def __call__(self, param): 188 | if param.nbatch % self.frequent == 0: 189 | speed = self.frequent * self.batch_size / (time.time() - self.tic) 190 | logging.info('Epoch[%d] Batch [%d]\tSpeed: %.2f samples/sec', 191 | param.epoch, param.nbatch, speed) 192 | param.eval_metric.print_log() 193 | self.tic = time.time() 194 | 195 | 196 | # draw network 197 | def draw_network(net, title, data_shape=(8, 3, 224, 224)): 198 | t = mx.viz.plot_network(net, title=title, shape={'data': data_shape}) 199 | t.render() 200 | -------------------------------------------------------------------------------- /tusimple_duc/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuSimple/TuSimple-DUC/f3b05b21cf3252baf89c7b88667f55cc74cd5418/tusimple_duc/networks/__init__.py -------------------------------------------------------------------------------- /tusimple_duc/networks/network_duc_hdc.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | from resnet import get_resnet_hdc 3 | 4 | 5 | def get_symbol_duc_hdc(label_num=19, ignore_label=255, bn_use_global_stats=True, 6 | aspp_num=4, aspp_stride=6, cell_cap=64, exp="cityscapes"): 7 | """ 8 | Get 9 | Parameters 10 | ---------- 11 | label_num: the number of labels 12 | ignore_label: id for ignore label 13 | bn_use_global_stats: whether batch normalizations should use global_stats 14 | aspp_num: number of ASPPs 15 | aspp_stride: stride of ASPPs 16 | cell_cap: capacity of a cell in dense upsampling convolutions 17 | exp: expression 18 | 19 | Returns 20 | ------- 21 | 22 | """ 23 | # Base Network 24 | res = get_resnet_hdc(bn_use_global_stats=bn_use_global_stats) 25 | 26 | # ASPP 27 | aspp_list = list() 28 | for i in range(aspp_num): 29 | pad = ((i + 1) * aspp_stride, (i + 1) * aspp_stride) 30 | dilate = pad 31 | conv_aspp=mx.symbol.Convolution(data=res, num_filter=cell_cap * label_num, kernel=(3, 3), pad=pad, 32 | dilate=dilate, name=('fc1_%s_c%d' % (exp, i)), workspace=8192) 33 | aspp_list.append(conv_aspp) 34 | 35 | summ = mx.symbol.ElementWiseSum(*aspp_list, name=('fc1_%s' % exp)) 36 | 37 | cls_score_reshape = mx.symbol.Reshape(data=summ, shape=(0, label_num, -1), name='cls_score_reshape') 38 | cls = mx.symbol.SoftmaxOutput(data=cls_score_reshape, multi_output=True, 39 | normalization='valid', use_ignore=True, ignore_label=ignore_label, name='seg_loss') 40 | return cls 41 | 42 | if __name__ == '__main__': 43 | symbol = get_symbol_duc_hdc(label_num=19, cell_cap=16) 44 | 45 | t = mx.viz.plot_network(symbol, shape={'data': (3, 3, 480, 480)}) 46 | t.render() 47 | -------------------------------------------------------------------------------- /tusimple_duc/networks/resnet.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | no_bias = True 3 | use_global_stats = True 4 | fix_gamma = False 5 | bn_momentum = 0.9995 6 | eps = 1e-6 7 | 8 | 9 | def Conv(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), dilate=(1, 1), name=None): 10 | conv = mx.symbol.Convolution(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, 11 | dilate=dilate, no_bias=no_bias, name=('%s' % name), workspace=4096) 12 | return conv 13 | 14 | 15 | def ReLU(data, name): 16 | return mx.symbol.Activation(data=data, act_type='relu', name=name) 17 | 18 | 19 | def Conv_AC(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), dilate=(1, 1), name=None): 20 | conv = Conv(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, dilate=dilate, name=name) 21 | act = ReLU(data=conv, name=('%s_relu' % name)) 22 | return act 23 | 24 | 25 | def Conv_BN(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), dilate=(1, 1), name=None, suffix=''): 26 | conv = Conv(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, dilate=dilate, name=name) 27 | bn = mx.symbol.BatchNorm(data=conv, name=('%s/bn' % suffix), eps=eps, use_global_stats=use_global_stats, 28 | momentum=bn_momentum, fix_gamma=fix_gamma) 29 | return bn 30 | 31 | 32 | def Conv_BN_AC(data, num_filter, kernel, stride=(1, 1), pad=(0, 0), dilate=(1, 1), name=None, suffix=''): 33 | conv = Conv_BN(data=data, num_filter=num_filter, kernel=kernel, stride=stride, pad=pad, dilate=dilate, 34 | name=name, suffix=suffix) 35 | act = ReLU(data=conv, name=('%s/relu' % name)) 36 | return act 37 | 38 | 39 | def ResidualFactory_o(data, num_1x1_a, num_3x3_b, num_1x1_c, dilate, suffix): 40 | branch1 = Conv_BN(data=data, num_filter=num_1x1_c, kernel=(1, 1), name=('conv%s_1x1_proj' % suffix), 41 | suffix=('conv%s_1x1_proj' % suffix), pad=(0, 0)) 42 | branch2a = Conv_BN_AC(data=data, num_filter=num_1x1_a, kernel=(1, 1), name=('conv%s_1x1_reduce' % suffix), 43 | suffix=('conv%s_1x1_reduce' % suffix), pad=(0, 0)) 44 | branch2b = Conv_BN_AC(data=branch2a, num_filter=num_3x3_b, kernel=(3, 3), name=('conv%s_3x3' % suffix), 45 | suffix=('conv%s_3x3' % suffix), pad=dilate, dilate=dilate) 46 | branch2c = Conv_BN(data=branch2b, num_filter=num_1x1_c, kernel=(1, 1), name=('conv%s_1x1_increase' % suffix), 47 | suffix=('conv%s_1x1_increase' % suffix), pad=(0, 0)) 48 | summ = mx.symbol.ElementWiseSum(*[branch2c, branch1], name=('conv%s' % suffix)) 49 | summ_ac = ReLU(data=summ, name=('res%s_relu' % suffix)) 50 | return summ_ac 51 | 52 | 53 | def ResidualFactory_x(data, num_1x1_a, num_3x3_b, num_1x1_c, dilate, suffix): 54 | branch2a = Conv_BN_AC(data=data, num_filter=num_1x1_a, kernel=(1, 1), name=('conv%s_1x1_reduce' % suffix), 55 | suffix=('conv%s_1x1_reduce' % suffix), pad=(0, 0)) 56 | branch2b = Conv_BN_AC(data=branch2a, num_filter=num_3x3_b, kernel=(3, 3), name=('conv%s_3x3' % suffix), 57 | suffix=('conv%s_3x3' % suffix), pad=dilate, dilate=dilate) 58 | branch2c = Conv_BN(data=branch2b, num_filter=num_1x1_c, kernel=(1, 1), name=('conv%s_1x1_increase' % suffix), 59 | suffix=('conv%s_1x1_increase' % suffix), pad=(0, 0)) 60 | summ = mx.symbol.ElementWiseSum(*[data, branch2c], name=('res%s' % suffix)) 61 | summ_ac = ReLU(data=summ, name=('res%s_relu' % suffix)) 62 | return summ_ac 63 | 64 | 65 | def ResidualFactory_d(data, num_1x1_a, num_3x3_b, num_1x1_c, suffix): 66 | branch1 = Conv_BN(data=data, num_filter=num_1x1_c, kernel=(1, 1), name=('conv%s_1x1_proj' % suffix), 67 | suffix=('conv%s_1x1_proj' % suffix), pad=(0, 0), stride=(2, 2)) 68 | branch2a = Conv_BN_AC(data=data, num_filter=num_1x1_a, kernel=(1, 1), name=('conv%s_1x1_reduce' % suffix), 69 | suffix=('conv%s_1x1_reduce' % suffix), pad=(0, 0), stride=(2, 2)) 70 | branch2b = Conv_BN_AC(data=branch2a, num_filter=num_3x3_b, kernel=(3, 3), name=('conv%s_3x3' % suffix), 71 | suffix=('conv%s_3x3' % suffix), pad=(1, 1)) 72 | branch2c = Conv_BN(data=branch2b, num_filter=num_1x1_c, kernel=(1, 1), name=('conv%s_1x1_increase' % suffix), 73 | suffix=('conv%s_1x1_increase' % suffix), pad=(0, 0)) 74 | summ = mx.symbol.ElementWiseSum(*[branch2c, branch1], name=('res%s' % suffix)) 75 | summ_ac = ReLU(data=summ, name=('res%s_relu' % suffix)) 76 | return summ_ac 77 | 78 | 79 | def get_resnet_hdc(bn_use_global_stats=True): 80 | """ 81 | Get resnet with hybrid dilated convolutions 82 | Parameters 83 | ---------- 84 | bn_use_global_stats: whether the batch normalization layers should use global stats 85 | 86 | Returns the symbol generated 87 | ------- 88 | 89 | """ 90 | global use_global_stats 91 | use_global_stats = bn_use_global_stats 92 | 93 | data = mx.symbol.Variable(name="data") 94 | 95 | # group 1 96 | res1_1 = Conv_BN_AC(data=data, num_filter=64, kernel=(3, 3), name='conv1_1_3x3_s2', suffix='conv1_1_3x3_s2', pad=(1, 1), stride=(2, 2)) 97 | res1_2 = Conv_BN_AC(data=res1_1, num_filter=64, kernel=(3, 3), name='conv1_2_3x3', suffix='conv1_2_3x3', pad=(1, 1), stride=(1, 1)) 98 | res1_3 = Conv_BN_AC(data=res1_2, num_filter=128, kernel=(3, 3), name='conv1_3_3x3', suffix='conv1_3_3x3', pad=(1, 1), stride=(1, 1)) 99 | pool1 = mx.symbol.Pooling(data=res1_3, pool_type="max", kernel=(3, 3), stride=(2, 2), name="pool1_3x3_s2") 100 | 101 | # group 2 102 | res2a = ResidualFactory_o(pool1, 64, 64, 256, (1, 1), '2_1') 103 | res2b = ResidualFactory_x(res2a, 64, 64, 256, (1, 1), '2_2') 104 | res2c = ResidualFactory_x(res2b, 64, 64, 256, (1, 1), '2_3') 105 | 106 | # group 3 107 | res3a = ResidualFactory_d(res2c, 128, 128, 512, '3_1') 108 | res3b1 = ResidualFactory_x(res3a, 128, 128, 512, (1, 1), '3_2') 109 | res3b2 = ResidualFactory_x(res3b1, 128, 128, 512, (1, 1), '3_3') 110 | res3b3 = ResidualFactory_x(res3b2, 128, 128, 512, (1, 1), '3_4') 111 | 112 | # group 4 113 | res4a = ResidualFactory_o(res3b3, 256, 256, 1024, (2, 2), '4_1') 114 | res4b1 = ResidualFactory_x(res4a, 256, 256, 1024, (2, 2), '4_2') 115 | res4b2 = ResidualFactory_x(res4b1, 256, 256, 1024, (5, 5), '4_3') 116 | res4b3 = ResidualFactory_x(res4b2, 256, 256, 1024, (9, 9), '4_4') 117 | res4b4 = ResidualFactory_x(res4b3, 256, 256, 1024, (1, 1), '4_5') 118 | res4b5 = ResidualFactory_x(res4b4, 256, 256, 1024, (2, 2), '4_6') 119 | res4b6 = ResidualFactory_x(res4b5, 256, 256, 1024, (5, 5), '4_7') 120 | res4b7 = ResidualFactory_x(res4b6, 256, 256, 1024, (9, 9), '4_8') 121 | res4b8 = ResidualFactory_x(res4b7, 256, 256, 1024, (1, 1), '4_9') 122 | res4b9 = ResidualFactory_x(res4b8, 256, 256, 1024, (2, 2), '4_10') 123 | res4b10 = ResidualFactory_x(res4b9, 256, 256, 1024, (5, 5), '4_11') 124 | res4b11 = ResidualFactory_x(res4b10, 256, 256, 1024, (9, 9), '4_12') 125 | res4b12 = ResidualFactory_x(res4b11, 256, 256, 1024, (1, 1), '4_13') 126 | res4b13 = ResidualFactory_x(res4b12, 256, 256, 1024, (2, 2), '4_14') 127 | res4b14 = ResidualFactory_x(res4b13, 256, 256, 1024, (5, 5), '4_15') 128 | res4b15 = ResidualFactory_x(res4b14, 256, 256, 1024, (9, 9), '4_16') 129 | res4b16 = ResidualFactory_x(res4b15, 256, 256, 1024, (1, 1), '4_17') 130 | res4b17 = ResidualFactory_x(res4b16, 256, 256, 1024, (2, 2), '4_18') 131 | res4b18 = ResidualFactory_x(res4b17, 256, 256, 1024, (5, 5), '4_19') 132 | res4b19 = ResidualFactory_x(res4b18, 256, 256, 1024, (9, 9), '4_20') 133 | res4b20 = ResidualFactory_x(res4b19, 256, 256, 1024, (1, 1), '4_21') 134 | res4b21 = ResidualFactory_x(res4b20, 256, 256, 1024, (2, 2), '4_22') 135 | res4b22 = ResidualFactory_x(res4b21, 256, 256, 1024, (5, 5), '4_23') 136 | # group 5 137 | res5a = ResidualFactory_o(res4b22, 512, 512, 2048, (5, 5), '5_1') 138 | res5b = ResidualFactory_x(res5a, 512, 512, 2048, (9, 9), '5_2') 139 | res5c = ResidualFactory_x(res5b, 512, 512, 2048, (17, 17), '5_3') 140 | return res5c 141 | -------------------------------------------------------------------------------- /tusimple_duc/test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TuSimple/TuSimple-DUC/f3b05b21cf3252baf89c7b88667f55cc74cd5418/tusimple_duc/test/__init__.py -------------------------------------------------------------------------------- /tusimple_duc/test/predictor.py: -------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import cv2 as cv 3 | import numpy as np 4 | 5 | 6 | class Predictor: 7 | def __init__(self, modules, label_num, ds_rate=8, cell_width=2, 8 | result_shape=(1024, 2048), test_scales=[1]): 9 | self._modules = modules 10 | self._label_num = label_num 11 | 12 | self._ds_rate = ds_rate 13 | self._cell_width = cell_width 14 | self._rpn_width = self._ds_rate / self._cell_width 15 | self._result_shape = result_shape 16 | self._test_scales = test_scales 17 | self._im_shape = None 18 | 19 | def predict(self, imgs): 20 | result_height, result_width = self._result_shape 21 | label_list = [] 22 | 23 | # multi scale test 24 | for index, test_scale in enumerate(self._test_scales): 25 | _, _, img_height, img_width = imgs[index].shape 26 | 27 | class CustomNDArrayIter(mx.io.NDArrayIter): 28 | @property 29 | def provide_data(self): 30 | return [('data', self.data.shape)] 31 | data_iter = CustomNDArrayIter(imgs[index], np.zeros(1), 1, shuffle=False) 32 | labels = self._modules[index].predict(data_iter).asnumpy().squeeze() 33 | test_width = (int(img_width) / self._ds_rate) * self._ds_rate 34 | test_height = (int(img_height) / self._ds_rate) * self._ds_rate 35 | feat_width = test_width / self._ds_rate 36 | feat_height = test_height / self._ds_rate 37 | # re-arrange duc results 38 | labels = labels.reshape((self._label_num, self._ds_rate/self._cell_width, self._ds_rate/self._cell_width, 39 | feat_height, feat_width)) 40 | labels = np.transpose(labels, (0, 3, 1, 4, 2)) 41 | labels = labels.reshape((self._label_num, test_height / self._cell_width, test_width / self._cell_width)) 42 | 43 | labels = labels[:, :int(img_height / self._cell_width), 44 | :int(img_width / self._cell_width)] 45 | labels = np.transpose(labels, [1, 2, 0]) 46 | labels = cv.resize(labels, (result_width, result_height), interpolation=cv.INTER_LINEAR) 47 | labels = np.transpose(labels, [2, 0, 1]) 48 | label_list.append(labels) 49 | labels = np.array(label_list).sum(axis=0) / len(label_list) 50 | 51 | return labels 52 | -------------------------------------------------------------------------------- /tusimple_duc/test/tester.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | import cv2 as cv 5 | import mxnet as mx 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from predictor import Predictor 10 | 11 | from tusimple_duc.core import utils 12 | from tusimple_duc.core import cityscapes_labels 13 | 14 | 15 | class Tester: 16 | def __init__(self, config): 17 | self.config = config 18 | # model 19 | self.model_dir = config.get('model', 'model_dir') 20 | self.model_prefix = config.get('model', 'model_prefix') 21 | self.model_epoch = config.getint('model', 'model_epoch') 22 | self.label_num = config.getint('model', 'label_num') 23 | self.ctx = mx.gpu(config.getint('model', 'gpu')) 24 | 25 | # data 26 | self.ds_rate = int(config.get('data', 'ds_rate')) 27 | self.cell_width = int(config.get('data', 'cell_width')) 28 | self.test_shape = [int(f) for f in config.get('data', 'test_shape').split(',')] 29 | self.result_shape = [int(f) for f in config.get('data', 'result_shape').split(',')] 30 | self.rgb_mean = [float(f) for f in config.get('data', 'rgb_mean').split(',')] 31 | # rescale for test 32 | self.test_scales = [float(f) for f in config.get('data', 'test_scales').split(',')] 33 | self.cell_shapes = [[math.ceil(l * s / self.ds_rate)*self.ds_rate for l in self.test_shape] 34 | for s in self.test_scales] 35 | self.modules = [] 36 | for i, test_scale in enumerate(self.test_scales): 37 | predictor = mx.module.Module.load( 38 | prefix=os.path.join(self.model_dir, self.model_prefix), 39 | epoch=self.model_epoch, 40 | context=self.ctx) 41 | data_shape = (1, 3, int(self.cell_shapes[i][0]), int(self.cell_shapes[i][1])) 42 | predictor.bind(data_shapes=[('data', data_shape)], for_training=False) 43 | self.modules.append(predictor) 44 | self.predictor = Predictor( 45 | modules=self.modules, 46 | label_num=self.label_num, 47 | ds_rate=self.ds_rate, 48 | cell_width=self.cell_width, 49 | result_shape=self.result_shape, 50 | test_scales=self.test_scales 51 | ) 52 | 53 | def preprocess(self, im): 54 | imgs = [] 55 | for index, test_scale in enumerate(self.test_scales): 56 | # resize to test scale 57 | test_img = cv.resize(im, (int(im.shape[1] * test_scale), int(im.shape[0] * test_scale)), 58 | interpolation=cv.INTER_LINEAR) 59 | test_img = test_img.astype(np.float32)[:int(self.test_shape[0] * test_scale), 60 | :int(self.test_shape[1] * test_scale)] 61 | test_img = cv.copyMakeBorder(test_img, 0, max(0, int(self.cell_shapes[index][0] * test_scale) - im.shape[0]), 62 | 0, max(0, int(self.cell_shapes[index][1] * test_scale) - im.shape[1]), 63 | cv.BORDER_CONSTANT, value=self.rgb_mean) 64 | 65 | test_img = np.transpose(test_img, (2, 0, 1)) 66 | # subtract rbg mean 67 | for i in range(3): 68 | test_img[i] -= self.rgb_mean[i] 69 | test_img = np.expand_dims(test_img, axis=0) 70 | mx.ndarray.array(test_img) 71 | imgs.append(test_img) 72 | return imgs 73 | 74 | 75 | @staticmethod 76 | def convert_label(label): 77 | cvt_label = np.zeros(label.shape) 78 | for l in cityscapes_labels.labels: 79 | cvt_label[label == l.trainId] = cityscapes_labels.trainId2label[l.trainId].id 80 | return cvt_label 81 | 82 | @staticmethod 83 | def colorize(labels): 84 | """ 85 | colorize the labels with predefined palette 86 | :param labels: labels organized in their train ids 87 | :return: a segmented result of colorful image as numpy array in RGB order 88 | """ 89 | # label 90 | result_img = Image.fromarray(labels).convert('P') 91 | result_img.putpalette(utils.get_palette()) 92 | return np.array(result_img.convert('RGB')) 93 | 94 | def predict_single(self, img, ret_converted=False, ret_softmax=False, ret_heat_map=False): 95 | """ 96 | predict single image by predefined models and configuration 97 | :param img: image array 98 | :param ret_converted: whether return labels with their original ids or train ids 99 | :param ret_softmax: whether return softmax results 100 | :param ret_heat_map: whether return heat map results 101 | """ 102 | 103 | rets = {} 104 | imgs = self.preprocess(img) 105 | labels = self.predictor.predict(imgs) 106 | 107 | # return softmax results 108 | if ret_softmax: 109 | rets['softmax'] = labels 110 | # feature_symbol = self.checkpoint.symbol.get_internals()['conv5_3_relu_output'] 111 | # feature_model = mx.model.FeedForward(symbol=feature_symbol, arg_params = self.checkpoint.arg_params, 112 | # aux_params = self.checkpoint.aux_params, ctx = self.ctx, 113 | # allow_extra_params = True) 114 | # features = feature_model.predict(self.prepocess(im, 1)).squeeze() 115 | # np.savez(os.path.join(self.save_softmax_dir, img_name.replace('.jpg', '')), softmax=features) 116 | # return heat map 117 | if ret_heat_map: 118 | heat = np.max(labels, axis=0) 119 | heat = heat * 256 - 1 120 | heat_map = cv.applyColorMap(heat.astype(np.uint8), cv.COLORMAP_JET) 121 | rets['heat_map'] = heat_map 122 | 123 | results = np.argmax(labels, axis=0).astype(np.uint8) 124 | rets['raw'] = results 125 | 126 | # return converted labels in their original ids rather than train ids 127 | if ret_converted: 128 | rets['converted'] = self.convert_label(results) 129 | return rets 130 | --------------------------------------------------------------------------------