├── .gitignore ├── LICENSE ├── README.md ├── configs └── trans10K │ ├── translab.yaml │ └── translab_bs4.yaml ├── demo ├── imgs │ ├── 1.png │ ├── 2.png │ ├── 3.png │ └── 4.png └── result │ ├── 1_glass.png │ ├── 2_glass.png │ ├── 3_glass.png │ └── 4_glass.png ├── segmentron ├── __init__.py ├── config │ ├── __init__.py │ ├── config.py │ └── settings.py ├── data │ ├── __init__.py │ └── dataloader │ │ ├── __init__.py │ │ ├── seg_data_base.py │ │ ├── trans10k.py │ │ ├── trans10k_boundary.py │ │ ├── trans10k_extra.py │ │ └── utils.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── build.py │ │ ├── eespnet.py │ │ ├── hrnet.py │ │ ├── mobilenet.py │ │ ├── resnet.py │ │ └── xception.py │ ├── model_zoo.py │ ├── segbase.py │ └── translab.py ├── modules │ ├── __init__.py │ ├── basic.py │ ├── batch_norm.py │ ├── cc_attention.py │ ├── csrc │ │ ├── criss_cross_attention │ │ │ ├── ca.h │ │ │ └── ca_cuda.cu │ │ └── vision.cpp │ ├── module.py │ └── sync_bn │ │ └── syncbn.py ├── solver │ ├── __init__.py │ ├── loss.py │ ├── lovasz_losses.py │ ├── lr_scheduler.py │ └── optimizer.py └── utils │ ├── __init__.py │ ├── default_setup.py │ ├── distributed.py │ ├── download.py │ ├── env.py │ ├── filesystem.py │ ├── logger.py │ ├── options.py │ ├── parallel.py │ ├── registry.py │ ├── score.py │ └── visualize.py ├── setup.py └── tools ├── dist_test.sh ├── dist_train.sh ├── test_demo.py ├── test_translab.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # mac os 107 | __MACOSX/ 108 | 109 | #model 110 | trash 111 | workdirs 112 | *.bak 113 | datasets 114 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Segment_Transparent_Objects
2 | ## Introduction
3 | This repository contains the data and code for the ECCV 2020 paper [Segmenting Transparent Objects in the Wild](https://arxiv.org/abs/2003.13948).
4 |
5 | To download the data, refer to the [Trans10K Website](https://xieenze.github.io/projects/TransLAB/TransLAB.html).
6 |
7 |
8 | ## Environments
9 |
10 | - python 3
11 | - torch = 1.1.0 (versions >1.1.0 cause a performance drop; we could not find the reason)
12 | - torchvision
13 | - pyyaml
14 | - Pillow
15 | - numpy
16 |
17 | ## INSTALL
18 |
19 | ```
20 | python setup.py develop
21 | ```
22 | ## Pretrained Models and Logs
23 | We provide the trained models and logs for TransLab on
24 | [Google Drive](https://drive.google.com/drive/folders/1yJMEB4rNKIZt5IWL13Nn-YwckrvAPNuz?usp=sharing)
25 |
26 | ## Demo
27 | 1. Put the images in './demo/imgs'.
28 | 2. Download the trained model from [Google Drive](https://drive.google.com/drive/folders/1yJMEB4rNKIZt5IWL13Nn-YwckrvAPNuz?usp=sharing)
29 | and put it at './demo/16.pth'.
30 | 3. Run this script:
31 | ```
32 | CUDA_VISIBLE_DEVICES=0 python -u ./tools/test_demo.py --config-file configs/trans10K/translab.yaml TEST.TEST_MODEL_PATH ./demo/16.pth DEMO_DIR ./demo/imgs
33 | ```
34 | 4. The results are generated in './demo/results'.
35 |
36 |
37 | ## Data Preparation
38 | 1. Create the directory './datasets/Trans10K'.
39 | 2. Download the data from the [Trans10K Website](https://xieenze.github.io/projects/TransLAB/TransLAB.html).
40 | 3. Put the train/validation/test data under './datasets/Trans10K'. The data structure is shown below.
41 | ```
42 | Trans10K/
43 | ├── test
44 | │   ├── easy
45 | │   └── hard
46 | ├── train
47 | │   ├── images
48 | │   └── masks
49 | └── validation
50 |     ├── easy
51 |     └── hard
52 | ```
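You can sanity-check that the layout above is in place before training — a minimal sketch (the paths simply mirror the structure shown):
```
# Minimal sketch: verify the expected Trans10K layout before training.
import os

for sub in ['train/images', 'train/masks',
            'validation/easy', 'validation/hard',
            'test/easy', 'test/hard']:
    path = os.path.join('datasets/Trans10K', sub)
    print(path, 'OK' if os.path.isdir(path) else 'MISSING')
```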
53 | ## Pretrained backbone models
54 |
55 | Pretrained backbone models will be downloaded automatically to the PyTorch default directory (```~/.cache/torch/checkpoints/```).
56 |
57 | ## Train
58 | Our experiments are based on one machine with 8 V100 GPUs (32 GB memory each). If you run into memory errors, try the 'batchsize=4' version.
59 | ### Train with batchsize=8 (costs ~15 GB memory)
60 | ```
61 | bash tools/dist_train.sh configs/trans10K/translab.yaml 8 TRAIN.MODEL_SAVE_DIR workdirs/translab_bs8
62 | ```
63 | ### Train with batchsize=4 (costs ~8 GB memory)
64 | ```
65 | bash tools/dist_train.sh configs/trans10K/translab_bs4.yaml 8 TRAIN.MODEL_SAVE_DIR workdirs/translab_bs4
66 | ```
67 |
68 | ## Eval
69 | For example (batchsize=8):
70 | ```
71 | CUDA_VISIBLE_DEVICES=0 python -u ./tools/test_translab.py --config-file configs/trans10K/translab.yaml TEST.TEST_MODEL_PATH workdirs/translab_bs8/16.pth
72 | ```
73 |
74 | ## License
75 |
76 | For academic use, this project is licensed under the Apache License - see the LICENSE file for details. For commercial use, please contact the authors.
77 |
78 | ## Citations
79 | Please consider citing our paper in your publications if the project helps your research. The BibTeX reference is as follows.
80 |
81 | ```
82 | @article{xie2020segmenting,
83 |   title={Segmenting Transparent Objects in the Wild},
84 |   author={Xie, Enze and Wang, Wenjia and Wang, Wenhai and Ding, Mingyu and Shen, Chunhua and Luo, Ping},
85 |   journal={arXiv preprint arXiv:2003.13948},
86 |   year={2020}
87 | }
88 | ```
89 |
-------------------------------------------------------------------------------- /configs/trans10K/translab.yaml: --------------------------------------------------------------------------------
1 | DATASET:
2 |   NAME: "trans10k_boundary"
3 |   MEAN: [0.485, 0.456, 0.406]
4 |   STD: [0.229, 0.224, 0.225]
5 | TRAIN:
6 |   EPOCHS: 16
7 |   BATCH_SIZE: 8
8 |   CROP_SIZE: 769
9 |   MODEL_SAVE_DIR: 'workdirs/debug'
10 | TEST:
11 |   BATCH_SIZE: 1
12 |
13 | SOLVER:
14 |   LR: 0.02
15 |
16 | MODEL:
17 |   MODEL_NAME: "TransLab"
18 |   BACKBONE: "resnet50"
19 |
20 |
-------------------------------------------------------------------------------- /configs/trans10K/translab_bs4.yaml: --------------------------------------------------------------------------------
1 | DATASET:
2 |   NAME: "trans10k_boundary"
3 |   MEAN: [0.485, 0.456, 0.406]
4 |   STD: [0.229, 0.224, 0.225]
5 | TRAIN:
6 |   EPOCHS: 16
7 |   BATCH_SIZE: 4
8 |   CROP_SIZE: 769
9 |   MODEL_SAVE_DIR: 'workdirs/debug'
10 | TEST:
11 |   BATCH_SIZE: 1
12 |
13 | SOLVER:
14 |   LR: 0.01
15 |
16 | MODEL:
17 |   MODEL_NAME: "TransLab"
18 |   BACKBONE: "resnet50"
19 |
20 |
-------------------------------------------------------------------------------- /demo/imgs/1.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/imgs/1.png
-------------------------------------------------------------------------------- /demo/imgs/2.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/imgs/2.png
-------------------------------------------------------------------------------- /demo/imgs/3.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/imgs/3.png
-------------------------------------------------------------------------------- /demo/imgs/4.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/imgs/4.png
-------------------------------------------------------------------------------- /demo/result/1_glass.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/result/1_glass.png
-------------------------------------------------------------------------------- /demo/result/2_glass.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/result/2_glass.png
-------------------------------------------------------------------------------- /demo/result/3_glass.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/result/3_glass.png
-------------------------------------------------------------------------------- /demo/result/4_glass.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/demo/result/4_glass.png
-------------------------------------------------------------------------------- /segmentron/__init__.py: --------------------------------------------------------------------------------
1 | from . import modules, models, utils, data
-------------------------------------------------------------------------------- /segmentron/config/__init__.py: --------------------------------------------------------------------------------
1 | from .settings import cfg
-------------------------------------------------------------------------------- /segmentron/config/config.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import unicode_literals
5 |
6 | import codecs
7 | import yaml
8 | import six
9 | import time
10 |
11 | from ast import literal_eval
12 |
13 | class SegmentronConfig(dict):
14 |     def __init__(self, *args, **kwargs):
15 |         super(SegmentronConfig, self).__init__(*args, **kwargs)
16 |         self.immutable = False
17 |
18 |     def __setattr__(self, key, value, create_if_not_exist=True):
19 |         if key in ["immutable"]:
20 |             self.__dict__[key] = value
21 |             return
22 |
23 |         t = self
24 |         keylist = key.split(".")
25 |         for k in keylist[:-1]:
26 |             t = t.__getattr__(k, create_if_not_exist)
27 |
28 |         t.__getattr__(keylist[-1], create_if_not_exist)
29 |         t[keylist[-1]] = value
30 |
31 |     def __getattr__(self, key, create_if_not_exist=True):
32 |         if key in ["immutable"]:
33 |             if key not in self.__dict__:
34 |                 self.__dict__[key] = False
35 |             return self.__dict__[key]
36 |
37 |         if key not in self:
38 |             if not create_if_not_exist:
39 |                 raise KeyError
40 |             self[key] = SegmentronConfig()
41 |         return self[key]
42 |
43 |     def __setitem__(self, key, value):
44 |         #
45 |         if self.immutable:
46 |             raise AttributeError(
47 |                 'Attempted to set "{}" to "{}", but SegmentronConfig is immutable'.
48 | format(key, value)) 49 | # 50 | if isinstance(value, six.string_types): 51 | try: 52 | value = literal_eval(value) 53 | except ValueError: 54 | pass 55 | except SyntaxError: 56 | pass 57 | super(SegmentronConfig, self).__setitem__(key, value) 58 | 59 | def update_from_other_cfg(self, other): 60 | if isinstance(other, dict): 61 | other = SegmentronConfig(other) 62 | assert isinstance(other, SegmentronConfig) 63 | cfg_list = [("", other)] 64 | while len(cfg_list): 65 | prefix, tdic = cfg_list[0] 66 | cfg_list = cfg_list[1:] 67 | for key, value in tdic.items(): 68 | key = "{}.{}".format(prefix, key) if prefix else key 69 | if isinstance(value, dict): 70 | cfg_list.append((key, value)) 71 | continue 72 | try: 73 | self.__setattr__(key, value, create_if_not_exist=False) 74 | except KeyError: 75 | raise KeyError('Non-existent config key: {}'.format(key)) 76 | 77 | def remove_irrelevant_cfg(self): 78 | model_name = self.MODEL.MODEL_NAME 79 | 80 | from ..models.model_zoo import MODEL_REGISTRY 81 | model_list = MODEL_REGISTRY.get_list() 82 | model_list_lower = [x.lower() for x in model_list] 83 | # print('model_list:', model_list) 84 | assert model_name.lower() in model_list_lower, "Expected model name in {}, but received {}"\ 85 | .format(model_list, model_name) 86 | pop_keys = [] 87 | for key in self.MODEL.keys(): 88 | if key.lower() in model_list_lower and key.lower() != model_name.lower(): 89 | pop_keys.append(key) 90 | for key in pop_keys: 91 | self.MODEL.pop(key) 92 | 93 | 94 | 95 | def check_and_freeze(self): 96 | self.TIME_STAMP = time.strftime('%Y-%m-%d-%H-%M', time.localtime()) 97 | # TODO: remove irrelevant config and then freeze 98 | self.remove_irrelevant_cfg() 99 | self.immutable = True 100 | 101 | def update_from_list(self, config_list): 102 | if len(config_list) % 2 != 0: 103 | raise ValueError( 104 | "Command line options config format error! Please check it: {}". 
105 |                 format(config_list))
106 |         for key, value in zip(config_list[0::2], config_list[1::2]):
107 |             try:
108 |                 self.__setattr__(key, value, create_if_not_exist=False)
109 |             except KeyError:
110 |                 raise KeyError('Non-existent config key: {}'.format(key))
111 |
112 |     def update_from_file(self, config_file):
113 |         with codecs.open(config_file, 'r', 'utf-8') as file:
114 |             loaded_cfg = yaml.load(file, Loader=yaml.FullLoader)
115 |         self.update_from_other_cfg(loaded_cfg)
116 |
117 |     def set_immutable(self, immutable):
118 |         self.immutable = immutable
119 |         for value in self.values():
120 |             if isinstance(value, SegmentronConfig):
121 |                 value.set_immutable(immutable)
122 |
123 |     def is_immutable(self):
124 |         return self.immutable
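To make the mechanics above concrete, here is a minimal usage sketch (the YAML path matches the configs shipped with this repo; the override pair is illustrative): dotted keys resolve to nested configs, string values are parsed with `literal_eval`, and `check_and_freeze()` makes the whole tree read-only.
```
# Minimal sketch of typical SegmentronConfig usage; the override value is illustrative.
from segmentron.config import cfg

cfg.update_from_file('configs/trans10K/translab.yaml')  # merge a YAML file into the defaults
cfg.update_from_list(['SOLVER.LR', '0.01'])             # CLI-style pairs; '0.01' becomes a float via literal_eval
print(cfg.SOLVER.LR)                                    # nested attribute access -> 0.01
cfg.check_and_freeze()                                  # stamps TIME_STAMP, prunes unused model cfgs, freezes
# cfg.SOLVER.LR = 0.1                                   # would now raise AttributeError (config is immutable)
```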
-------------------------------------------------------------------------------- /segmentron/config/settings.py: --------------------------------------------------------------------------------
1 | from .config import SegmentronConfig
2 |
3 | cfg = SegmentronConfig()
4 |
5 | ########################## basic set ###########################################
6 | # random seed
7 | cfg.SEED = 1024
8 | # training time stamp, auto-generated; no need to set
9 | cfg.TIME_STAMP = ''
10 | # root path
11 | cfg.ROOT_PATH = ''
12 | # model phase ['train', 'test']
13 | cfg.PHASE = 'train'
14 |
15 | ########################## dataset config #########################################
16 | # dataset name
17 | cfg.DATASET.NAME = ''
18 | # pixel mean
19 | cfg.DATASET.MEAN = [0.5, 0.5, 0.5]
20 | # pixel std
21 | cfg.DATASET.STD = [0.5, 0.5, 0.5]
22 | # dataset ignore index
23 | cfg.DATASET.IGNORE_INDEX = -1
24 | # workers
25 | cfg.DATASET.WORKERS = 8
26 | # val dataset mode
27 | cfg.DATASET.MODE = 'testval'
28 | ########################### data augment ######################################
29 | # data augment image mirror
30 | cfg.AUG.MIRROR = True
31 | # blur probability
32 | cfg.AUG.BLUR_PROB = 0.0
33 | # blur radius
34 | cfg.AUG.BLUR_RADIUS = 0.0
35 | # color jitter, float or tuple: (0.1, 0.2, 0.3, 0.4)
36 | cfg.AUG.COLOR_JITTER = None
37 | ########################### train config ##########################################
38 | # epochs
39 | cfg.TRAIN.EPOCHS = 30
40 | # batch size
41 | cfg.TRAIN.BATCH_SIZE = 1
42 | # train crop size
43 | cfg.TRAIN.CROP_SIZE = 769
44 | # train base size
45 | cfg.TRAIN.BASE_SIZE = 512
46 | # model output dir
47 | cfg.TRAIN.MODEL_SAVE_DIR = 'workdirs/'
48 | # log dir
49 | cfg.TRAIN.LOG_SAVE_DIR = cfg.TRAIN.MODEL_SAVE_DIR
50 | # pretrained model for eval or finetune
51 | cfg.TRAIN.PRETRAINED_MODEL_PATH = ''
52 | # use a backbone model pretrained on ImageNet
53 | cfg.TRAIN.BACKBONE_PRETRAINED = True
54 | # backbone pretrained model path; if not specified, it is loaded from a URL when backbone pretraining is enabled
55 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH = ''
56 | # resume model path
57 | cfg.TRAIN.RESUME_MODEL_PATH = ''
58 | # whether to use synchronized batch norm
59 | cfg.TRAIN.SYNC_BATCH_NORM = True
60 | # save model every checkpoint-epoch
61 | cfg.TRAIN.SNAPSHOT_EPOCH = 1
62 |
63 | ########################### optimizer config ##################################
64 | # base learning rate
65 | cfg.SOLVER.LR = 1e-4
66 | # optimizer method
67 | cfg.SOLVER.OPTIMIZER = "sgd"
68 | # optimizer epsilon
69 | cfg.SOLVER.EPSILON = 1e-8
70 | # optimizer momentum
71 | cfg.SOLVER.MOMENTUM = 0.9
72 | # weight decay
73 | cfg.SOLVER.WEIGHT_DECAY = 1e-4 #0.00004
74 | # decoder lr x10
75 | cfg.SOLVER.DECODER_LR_FACTOR = 10.0
76 | # lr scheduler mode
77 | cfg.SOLVER.LR_SCHEDULER = "poly"
78 | # poly power
79 | cfg.SOLVER.POLY.POWER = 0.9
80 | # step gamma
81 | cfg.SOLVER.STEP.GAMMA = 0.1
82 | # milestone of step lr scheduler
83 | cfg.SOLVER.STEP.DECAY_EPOCH = [10, 20]
84 | # warm up epochs, can be a float
85 | cfg.SOLVER.WARMUP.EPOCHS = 0.
86 | # warm up factor
87 | cfg.SOLVER.WARMUP.FACTOR = 1.0 / 3
88 | # warm up method
89 | cfg.SOLVER.WARMUP.METHOD = 'linear'
90 | # whether to use ohem
91 | cfg.SOLVER.OHEM = False
92 | # whether to use aux loss
93 | cfg.SOLVER.AUX = False
94 | # aux loss weight
95 | cfg.SOLVER.AUX_WEIGHT = 0.4
96 | # loss name
97 | cfg.SOLVER.LOSS_NAME = ''
98 | ########################## test config ###########################################
99 | # val/test model path
100 | cfg.TEST.TEST_MODEL_PATH = ''
101 | # test batch size
102 | cfg.TEST.BATCH_SIZE = 1
103 | # eval crop size
104 | cfg.TEST.CROP_SIZE = None
105 | # multiscale eval
106 | cfg.TEST.SCALES = [1.0]
107 | # flip
108 | cfg.TEST.FLIP = False
109 |
110 | ########################## visual config ###########################################
111 | # visual result output dir
112 | cfg.VISUAL.OUTPUT_DIR = '../runs/visual/'
113 |
114 | ########################## model #######################################
115 | # model name
116 | cfg.MODEL.MODEL_NAME = ''
117 | # model backbone
118 | cfg.MODEL.BACKBONE = ''
119 | # model backbone channel scale
120 | cfg.MODEL.BACKBONE_SCALE = 1.0
121 | # support resnet b, c. b is standard resnet in pytorch official repo
122 | # cfg.MODEL.RESNET_VARIANT = 'b'
123 | # multi branch loss weight
124 | cfg.MODEL.MULTI_LOSS_WEIGHT = [1.0]
125 | # gn groups
126 | cfg.MODEL.DEFAULT_GROUP_NUMBER = 32
127 | # whole model default epsilon
128 | cfg.MODEL.DEFAULT_EPSILON = 1e-5
129 | # batch norm, support ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN']
130 | cfg.MODEL.BN_TYPE = 'BN'
131 | # batch norm epsilon for encoder; if set to None, the api default value is used.
132 | cfg.MODEL.BN_EPS_FOR_ENCODER = None
133 | # batch norm epsilon for decoder; if set to None, the api default value is used.
134 | cfg.MODEL.BN_EPS_FOR_DECODER = None
135 | # backbone output stride
136 | cfg.MODEL.OUTPUT_STRIDE = 16
137 | # BatchNorm momentum; if set to None, the api default value is used.
138 | cfg.MODEL.BN_MOMENTUM = None
139 |
140 |
141 | ########################## DeepLab config ####################################
142 | # whether to use aspp
143 | cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP = True
144 | # whether to use decoder
145 | cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER = True
146 | # whether aspp use sep conv
147 | cfg.MODEL.DEEPLABV3_PLUS.ASPP_WITH_SEP_CONV = True
148 | # whether decoder use sep conv
149 | cfg.MODEL.DEEPLABV3_PLUS.DECODER_USE_SEP_CONV = True
150 | ########################## Demo ####################################
151 |
152 | cfg.DEMO_DIR = 'demo/imgs'
153 |
-------------------------------------------------------------------------------- /segmentron/data/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/segmentron/data/__init__.py
3 | """ 4 | from .trans10k import TransSegmentation 5 | from .trans10k_boundary import TransSegmentationBoundary 6 | from .trans10k_extra import TransExtraSegmentation 7 | 8 | datasets = { 9 | 'trans10k': TransSegmentation, 10 | 'trans10k_boundary': TransSegmentationBoundary, 11 | 'trans10k_extra': TransExtraSegmentation 12 | } 13 | 14 | 15 | def get_segmentation_dataset(name, **kwargs): 16 | """Segmentation Datasets""" 17 | return datasets[name.lower()](**kwargs) 18 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/seg_data_base.py: -------------------------------------------------------------------------------- 1 | """Base segmentation dataset""" 2 | import os 3 | import random 4 | import numpy as np 5 | import torchvision 6 | 7 | from PIL import Image, ImageOps, ImageFilter 8 | from ...config import cfg 9 | 10 | __all__ = ['SegmentationDataset'] 11 | 12 | 13 | class SegmentationDataset(object): 14 | """Segmentation Base Dataset""" 15 | 16 | def __init__(self, root, split, mode, transform, base_size=520, crop_size=480): 17 | super(SegmentationDataset, self).__init__() 18 | self.root = os.path.join(cfg.ROOT_PATH, root) 19 | self.transform = transform 20 | self.split = split 21 | self.mode = mode if mode is not None else split 22 | self.base_size = base_size 23 | self.crop_size = self.to_tuple(crop_size) 24 | self.color_jitter = self._get_color_jitter() 25 | 26 | def to_tuple(self, size): 27 | if isinstance(size, (list, tuple)): 28 | return tuple(size) 29 | elif isinstance(size, (int, float)): 30 | return tuple((size, size)) 31 | else: 32 | raise ValueError('Unsupport datatype: {}'.format(type(size))) 33 | 34 | def _get_color_jitter(self): 35 | color_jitter = cfg.AUG.COLOR_JITTER 36 | if color_jitter is None: 37 | return None 38 | if isinstance(color_jitter, (list, tuple)): 39 | # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation 40 | # or 4 if also augmenting hue 41 | assert len(color_jitter) in (3, 4) 42 | else: 43 | # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue 44 | color_jitter = (float(color_jitter),) * 3 45 | return torchvision.transforms.ColorJitter(*color_jitter) 46 | 47 | def _val_sync_transform(self, img, mask): 48 | short_size = self.base_size 49 | img = img.resize((short_size, short_size), Image.BILINEAR) 50 | mask = mask.resize((short_size, short_size), Image.NEAREST) 51 | 52 | # final transform 53 | img, mask = self._img_transform(img), self._mask_transform(mask) 54 | return img, mask 55 | 56 | def _sync_transform(self, img, mask): 57 | short_size = self.base_size 58 | img = img.resize((short_size, short_size), Image.BILINEAR) 59 | mask = mask.resize((short_size, short_size), Image.NEAREST) 60 | 61 | # final transform 62 | img, mask = self._img_transform(img), self._mask_transform(mask) 63 | return img, mask 64 | 65 | def _img_transform(self, img): 66 | return np.array(img) 67 | 68 | def _mask_transform(self, mask): 69 | return np.array(mask).astype('int32') 70 | 71 | @property 72 | def num_class(self): 73 | """Number of categories.""" 74 | return self.NUM_CLASS 75 | 76 | @property 77 | def pred_offset(self): 78 | return 0 79 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/trans10k.py: -------------------------------------------------------------------------------- 1 | """Prepare Trans10K dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | from PIL import 
-------------------------------------------------------------------------------- /segmentron/data/dataloader/trans10k.py: --------------------------------------------------------------------------------
1 | """Prepare Trans10K dataset"""
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 |
7 | from PIL import Image
8 | from .seg_data_base import SegmentationDataset
9 | from IPython import embed
10 |
11 | class TransSegmentation(SegmentationDataset):
12 |     """Trans10K Semantic Segmentation Dataset.
13 |
14 |     Parameters
15 |     ----------
16 |     root : string
17 |         Path to Trans10K folder. Default is './datasets/Trans10K'
18 |     split: string
19 |         'train', 'validation', 'test'
20 |     transform : callable, optional
21 |         A function that transforms the image
22 |     """
23 |     BASE_DIR = 'Trans10K'
24 |     NUM_CLASS = 3
25 |
26 |     def __init__(self, root='datasets/Trans10K', split='train', mode=None, transform=None, **kwargs):
27 |         super(TransSegmentation, self).__init__(root, split, mode, transform, **kwargs)
28 |         # self.root = os.path.join(root, self.BASE_DIR)
29 |         assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/Trans10K"
30 |         self.images, self.mask_paths = _get_trans10k_pairs(self.root, self.split)
31 |         assert (len(self.images) == len(self.mask_paths))
32 |         if len(self.images) == 0:
33 |             raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
34 |         self.valid_classes = [0,1,2]
35 |         self._key = np.array([0,1,2])
36 |         # self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
37 |         self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') + 1
38 |
39 |     def _class_to_index(self, mask):
40 |         # assert the value
41 |         values = np.unique(mask)
42 |
43 |         for value in values:
44 |             assert (value in self._mapping)
45 |         index = np.digitize(mask.ravel(), self._mapping, right=True)
46 |         return self._key[index].reshape(mask.shape)
47 |
48 |     def __getitem__(self, index):
49 |         img = Image.open(self.images[index]).convert('RGB')
50 |         if self.mode == 'test':
51 |             if self.transform is not None:
52 |                 img = self.transform(img)
53 |             return img, os.path.basename(self.images[index])
54 |         mask = Image.open(self.mask_paths[index])
55 |         # convert the RGB mask to class indices (85 -> 1, 255 -> 2)
56 |         mask = np.array(mask)[:,:,:3].mean(-1)
57 |         mask[mask==85.0] = 1
58 |         mask[mask==255.0] = 2
59 |         assert mask.max()<=2, mask.max()
60 |         mask = Image.fromarray(mask)
61 |
62 |         # synchronized transform
63 |         if self.mode == 'train':
64 |             img, mask = self._sync_transform(img, mask)
65 |         elif self.mode == 'val':
66 |             img, mask = self._val_sync_transform(img, mask)
67 |         else:
68 |             assert self.mode == 'testval'
69 |             img, mask = self._img_transform(img), self._mask_transform(mask)
70 |
71 |         # general resize, normalize and toTensor
72 |         if self.transform is not None:
73 |             img = self.transform(img)
74 |         return img, mask, os.path.basename(self.images[index])
75 |
76 |     def _mask_transform(self, mask):
77 |         target = self._class_to_index(np.array(mask).astype('int32'))
78 |         return torch.LongTensor(np.array(target).astype('int32'))
79 |
80 |     def __len__(self):
81 |         return len(self.images)
82 |
83 |     @property
84 |     def pred_offset(self):
85 |         return 0
86 |
87 |     @property
88 |     def classes(self):
89 |         """Category names."""
90 |         return ('background', 'things', 'stuff')
91 |
92 |
93 | def _get_trans10k_pairs(folder, split='train'):
94 |
95 |     def get_path_pairs(img_folder, mask_folder):
96 |         img_paths = []
97 |         mask_paths = []
98 |         imgs = os.listdir(img_folder)
99 |
100 |         for imgname in imgs:
101 |             imgpath = os.path.join(img_folder, imgname)
102 |             maskname = imgname.replace('.jpg', '_mask.png')
103 |             maskpath = os.path.join(mask_folder, maskname)
104 |             if os.path.isfile(imgpath) and os.path.isfile(maskpath):
105 |                 img_paths.append(imgpath)
106 |                 mask_paths.append(maskpath)
107 |             else:
108 |                 logging.info('cannot find the mask or image: %s %s', imgpath,
maskpath) 109 | 110 | logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder)) 111 | return img_paths, mask_paths 112 | 113 | 114 | if split == 'train': 115 | img_folder = os.path.join(folder, split, 'images') 116 | mask_folder = os.path.join(folder, split, 'masks') 117 | img_paths, mask_paths = get_path_pairs(img_folder, mask_folder) 118 | else: 119 | assert split == 'validation' or split == 'test' 120 | easy_img_folder = os.path.join(folder, split, 'easy', 'images') 121 | easy_mask_folder = os.path.join(folder, split, 'easy', 'masks') 122 | hard_img_folder = os.path.join(folder, split, 'hard', 'images') 123 | hard_mask_folder = os.path.join(folder, split, 'hard', 'masks') 124 | easy_img_paths, easy_mask_paths = get_path_pairs(easy_img_folder, easy_mask_folder) 125 | hard_img_paths, hard_mask_paths = get_path_pairs(hard_img_folder, hard_mask_folder) 126 | easy_img_paths.extend(hard_img_paths) 127 | easy_mask_paths.extend(hard_mask_paths) 128 | img_paths = easy_img_paths 129 | mask_paths = easy_mask_paths 130 | return img_paths, mask_paths 131 | 132 | if __name__ == '__main__': 133 | dataset = TransSegmentation() 134 | -------------------------------------------------------------------------------- /segmentron/data/dataloader/trans10k_boundary.py: -------------------------------------------------------------------------------- 1 | """Prepare Trans10K dataset""" 2 | import os 3 | import torch 4 | import numpy as np 5 | import logging 6 | 7 | from PIL import Image 8 | from .seg_data_base import SegmentationDataset 9 | from IPython import embed 10 | import cv2 11 | 12 | class TransSegmentationBoundary(SegmentationDataset): 13 | """Trans10K Semantic Segmentation Dataset. 14 | 15 | Parameters 16 | ---------- 17 | root : string 18 | Path to Trans10K folder. 
Default is './datasets/Trans10K'
19 |     split: string
20 |         'train', 'validation', 'test'
21 |     transform : callable, optional
22 |         A function that transforms the image
23 |     """
24 |     BASE_DIR = 'Trans10K'
25 |     NUM_CLASS = 3
26 |
27 |     def __init__(self, root='datasets/Trans10K', split='train', mode=None, transform=None, **kwargs):
28 |         super(TransSegmentationBoundary, self).__init__(root, split, mode, transform, **kwargs)
29 |         # self.root = os.path.join(root, self.BASE_DIR)
30 |         assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/Trans10K"
31 |         self.images, self.mask_paths = _get_trans10k_pairs(self.root, self.split)
32 |         assert (len(self.images) == len(self.mask_paths))
33 |         if len(self.images) == 0:
34 |             raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
35 |         self.valid_classes = [0,1,2]
36 |         self._key = np.array([0,1,2])
37 |         # self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')
38 |         self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32') + 1
39 |
40 |     def _class_to_index(self, mask):
41 |         # assert the value
42 |         values = np.unique(mask)
43 |
44 |         for value in values:
45 |             assert (value in self._mapping)
46 |         index = np.digitize(mask.ravel(), self._mapping, right=True)
47 |         return self._key[index].reshape(mask.shape)
48 |
49 |     def __getitem__(self, index):
50 |         img = Image.open(self.images[index]).convert('RGB')
51 |         if self.mode == 'test':
52 |             if self.transform is not None:
53 |                 img = self.transform(img)
54 |             return img, os.path.basename(self.images[index])
55 |         mask = Image.open(self.mask_paths[index])
56 |         # convert the RGB mask to class indices (85 -> 1, 255 -> 2)
57 |         mask = np.array(mask)[:,:,:3].mean(-1)
58 |         mask[mask==85.0] = 1
59 |         mask[mask==255.0] = 2
60 |         assert mask.max()<=2, mask.max()
61 |         mask = Image.fromarray(mask)
62 |
63 |         # synchronized transform
64 |         if self.mode == 'train':
65 |             img, mask = self._sync_transform(img, mask)
66 |         elif self.mode == 'val':
67 |             img, mask = self._val_sync_transform(img, mask)
68 |         else:
69 |             assert self.mode == 'testval'
70 |             img, mask = self._img_transform(img), self._mask_transform(mask)
71 |
72 |
73 |         boundary = self.get_boundary(mask)
74 |         boundary = torch.LongTensor(np.array(boundary).astype('int32'))
75 |
76 |         # general resize, normalize and toTensor
77 |         if self.transform is not None:
78 |             img = self.transform(img)
79 |         return img, mask, boundary, self.images[index]
80 |
81 |     def _mask_transform(self, mask):
82 |         target = self._class_to_index(np.array(mask).astype('int32'))
83 |         return torch.LongTensor(np.array(target).astype('int32'))
84 |
85 |     def __len__(self):
86 |         return len(self.images)
87 |
88 |     def get_boundary(self, mask, thicky=8):
89 |         tmp = mask.data.numpy().astype('uint8')
90 |         contour, _ = cv2.findContours(tmp, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
91 |         boundary = np.zeros_like(tmp)
92 |         boundary = cv2.drawContours(boundary, contour, -1, 1, thicky)
93 |         return boundary
94 |
95 |     @property
96 |     def pred_offset(self):
97 |         return 0
98 |
99 |     @property
100 |     def classes(self):
101 |         """Category names."""
102 |         return ('background', 'things', 'stuff')
103 |
104 |
105 | def _get_trans10k_pairs(folder, split='train'):
106 |
107 |     def get_path_pairs(img_folder, mask_folder):
108 |         img_paths = []
109 |         mask_paths = []
110 |         imgs = os.listdir(img_folder)
111 |
112 |         for imgname in imgs:
113 |             imgpath = os.path.join(img_folder, imgname)
114 |             maskname = imgname.replace('.jpg', '_mask.png')
115 |             maskpath = os.path.join(mask_folder, maskname)
116 |             if os.path.isfile(imgpath) and os.path.isfile(maskpath):
117 |                 img_paths.append(imgpath)
118 |                 mask_paths.append(maskpath)
119 |             else:
120 |                 logging.info('cannot find the mask or image: %s %s', imgpath, maskpath)
121 |
122 |         logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder))
123 |         return img_paths, mask_paths
124 |
125 |
126 |     if split == 'train':
127 |         img_folder = os.path.join(folder, split, 'images')
128 |         mask_folder = os.path.join(folder, split, 'masks')
129 |         img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
130 |     else:
131 |         assert split == 'validation' or split == 'test'
132 |         easy_img_folder = os.path.join(folder, split, 'easy', 'images')
133 |         easy_mask_folder = os.path.join(folder, split, 'easy', 'masks')
134 |         hard_img_folder = os.path.join(folder, split, 'hard', 'images')
135 |         hard_mask_folder = os.path.join(folder, split, 'hard', 'masks')
136 |         easy_img_paths, easy_mask_paths = get_path_pairs(easy_img_folder, easy_mask_folder)
137 |         hard_img_paths, hard_mask_paths = get_path_pairs(hard_img_folder, hard_mask_folder)
138 |         easy_img_paths.extend(hard_img_paths)
139 |         easy_mask_paths.extend(hard_mask_paths)
140 |         img_paths = easy_img_paths
141 |         mask_paths = easy_mask_paths
142 |     return img_paths, mask_paths
143 |
144 |
145 |
146 | if __name__ == '__main__':
147 |     dataset = TransSegmentationBoundary()
148 |
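The `get_boundary` helper above is easy to sanity-check in isolation — a minimal sketch with a synthetic mask (OpenCV 4-style `findContours`, as in the dataset code; the thickness of 8 matches the `thicky` default):
```
# Minimal sketch: derive a thick boundary map from a class mask, mirroring get_boundary.
import numpy as np
import cv2

mask = np.zeros((128, 128), dtype='uint8')
mask[32:96, 32:96] = 1                       # synthetic square "transparent" region
contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
boundary = np.zeros_like(mask)
boundary = cv2.drawContours(boundary, contours, -1, 1, 8)  # draw value 1, thickness 8
print(int(boundary.sum()))                   # number of boundary pixels
```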
-------------------------------------------------------------------------------- /segmentron/data/dataloader/trans10k_extra.py: --------------------------------------------------------------------------------
1 | """Prepare Trans10K dataset"""
2 | import os
3 | import torch
4 | import numpy as np
5 | import logging
6 |
7 | from PIL import Image
8 | from .seg_data_base import SegmentationDataset
9 | from IPython import embed
10 |
11 | class TransExtraSegmentation(SegmentationDataset):
12 |     """Trans10K Semantic Segmentation Dataset.
13 |
14 |     Parameters
15 |     ----------
16 |     root : string
17 |         Path to the image folder. Default is 'demo/imgs'
18 |     split: string
19 |         'train', 'validation', 'test'
20 |     transform : callable, optional
21 |         A function that transforms the image
22 |     """
23 |     BASE_DIR = 'Trans10K'
24 |     NUM_CLASS = 3
25 |
26 |     def __init__(self, root='demo/imgs', split='train', mode=None, transform=None, **kwargs):
27 |         super(TransExtraSegmentation, self).__init__(root, split, mode, transform, **kwargs)
28 |         # self.root = os.path.join(root, self.BASE_DIR)
29 |         assert os.path.exists(self.root), "Please put dataset in {SEG_ROOT}/datasets/Extra"
30 |         self.images = _get_demo_pairs(self.root)
31 |         if len(self.images) == 0:
32 |             raise RuntimeError("Found 0 images in subfolders of:" + root + "\n")
33 |
34 |     def __getitem__(self, index):
35 |         img = Image.open(self.images[index]).convert('RGB')
36 |         mask = np.zeros_like(np.array(img))[:,:,0]
37 |         assert mask.max()<=2, mask.max()
38 |         mask = Image.fromarray(mask)
39 |
40 |         # synchronized transform
41 |         img, mask = self._val_sync_transform(img, mask)
42 |         # general resize, normalize and toTensor
43 |         if self.transform is not None:
44 |             img = self.transform(img)
45 |         return img, mask, self.images[index]
46 |
47 |     def __len__(self):
48 |         return len(self.images)
49 |
50 |     @property
51 |     def pred_offset(self):
52 |         return 0
53 |
54 |     @property
55 |     def classes(self):
56 |         """Category names."""
57 |         return ('background', 'things', 'stuff')
58 |
59 |
60 | def _get_demo_pairs(folder):
61 |
62 |     def get_path_pairs(img_folder):
63 |         img_paths = []
64 |         imgs = os.listdir(img_folder)
65 |         for imgname in imgs:
66 |             imgpath = os.path.join(img_folder, imgname)
67 |             if os.path.isfile(imgpath):
68 |                 img_paths.append(imgpath)
69 |             else:
70 |                 logging.info('cannot find the image: %s', imgpath)
71 |
72 |         logging.info('Found {} images in the folder {}'.format(len(img_paths), img_folder))
73 |         return img_paths
74 |
75 |     img_folder = folder
76 |     img_paths = get_path_pairs(img_folder)
77 |
78 |     return img_paths
79 |
80 | if __name__ == '__main__':
81 |     dataset = TransExtraSegmentation()
82 |
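For the demo flow, `trans10k_extra` simply walks a folder of images and pairs each with an all-zero placeholder mask — a minimal sketch of driving it directly (the folder matches the README demo setup):
```
# Minimal sketch: load demo images without ground-truth masks for inference.
import torchvision.transforms as transforms
from segmentron.data.dataloader import get_segmentation_dataset

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
demo_set = get_segmentation_dataset('trans10k_extra', root='demo/imgs',
                                    split='test', transform=transform)
img, dummy_mask, path = demo_set[0]  # mask is all zeros here; only img and path matter
```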
-------------------------------------------------------------------------------- /segmentron/data/dataloader/utils.py: --------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 | import errno
4 | import tarfile
5 | from six.moves import urllib
6 | from torch.utils.model_zoo import tqdm
7 |
8 | def gen_bar_updater():
9 |     pbar = tqdm(total=None)
10 |
11 |     def bar_update(count, block_size, total_size):
12 |         if pbar.total is None and total_size:
13 |             pbar.total = total_size
14 |         progress_bytes = count * block_size
15 |         pbar.update(progress_bytes - pbar.n)
16 |
17 |     return bar_update
18 |
19 | def check_integrity(fpath, md5=None):
20 |     if md5 is None:
21 |         return True
22 |     if not os.path.isfile(fpath):
23 |         return False
24 |     md5o = hashlib.md5()
25 |     with open(fpath, 'rb') as f:
26 |         # read in 1MB chunks
27 |         for chunk in iter(lambda: f.read(1024 * 1024), b''):
28 |             md5o.update(chunk)
29 |     md5c = md5o.hexdigest()
30 |     if md5c != md5:
31 |         return False
32 |     return True
33 |
34 | def makedir_exist_ok(dirpath):
35 |     try:
36 |         os.makedirs(dirpath)
37 |     except OSError as e:
38 |         if e.errno == errno.EEXIST:
39 |             pass
40 |         else:
41 |             raise
42 |
43 | def download_url(url, root, filename=None, md5=None):
44 |     """Download a file from a url and place it in root."""
45 |     root = os.path.expanduser(root)
46 |     if not filename:
47 |         filename = os.path.basename(url)
48 |     fpath = os.path.join(root, filename)
49 |
50 |     makedir_exist_ok(root)
51 |
52 |     # download the file
53 |     if os.path.isfile(fpath) and check_integrity(fpath, md5):
54 |         print('Using downloaded and verified file: ' + fpath)
55 |     else:
56 |         try:
57 |             print('Downloading ' + url + ' to ' + fpath)
58 |             urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
59 |         except OSError:
60 |             if url[:5] == 'https':
61 |                 url = url.replace('https:', 'http:')
62 |                 print('Failed download. Trying https -> http instead.'
63 |                       ' Downloading ' + url + ' to ' + fpath)
64 |                 urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
65 |
66 | def download_extract(url, root, filename, md5):
67 |     download_url(url, root, filename, md5)
68 |     with tarfile.open(os.path.join(root, filename), "r") as tar:
69 |         tar.extractall(path=root)
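These helpers mirror torchvision's download utilities — a minimal sketch (the URL and filename below are illustrative placeholders, not real artifacts):
```
# Minimal sketch of the download helpers; URL and filename are placeholders.
from segmentron.data.dataloader.utils import download_url, check_integrity

url = 'https://example.com/some_archive.tar'  # placeholder URL
download_url(url, root='datasets', filename='some_archive.tar', md5=None)
print(check_integrity('datasets/some_archive.tar'))  # True: md5=None skips the checksum
```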
15 | """ 16 | 17 | model_urls = { 18 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 19 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 20 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 21 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 22 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 23 | 'resnet50c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet50-25c4b509.pth', 24 | 'resnet101c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet101-2a57e44d.pth', 25 | 'resnet152c': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/resnet152-0d43d698.pth', 26 | 'xception65': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/tf-xception65-270e81cf.pth', 27 | 'hrnet_w18_small_v1': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/hrnet-w18-small-v1-08f8ae64.pth', 28 | 'mobilenet_v2': 'https://github.com/LikeLy-Journey/SegmenTron/releases/download/v0.1.0/mobilenetV2-15498621.pth', 29 | } 30 | 31 | 32 | def load_backbone_pretrained(model, backbone): 33 | if cfg.PHASE == 'train' and cfg.TRAIN.BACKBONE_PRETRAINED and (not cfg.TRAIN.PRETRAINED_MODEL_PATH): 34 | if os.path.isfile(cfg.TRAIN.BACKBONE_PRETRAINED_PATH): 35 | logging.info('Load backbone pretrained model from {}'.format( 36 | cfg.TRAIN.BACKBONE_PRETRAINED_PATH 37 | )) 38 | msg = model.load_state_dict(torch.load(cfg.TRAIN.BACKBONE_PRETRAINED_PATH), strict=False) 39 | logging.info(msg) 40 | elif backbone not in model_urls: 41 | logging.info('{} has no pretrained model'.format(backbone)) 42 | return 43 | else: 44 | logging.info('load backbone pretrained model from url..') 45 | msg = model.load_state_dict(model_zoo.load_url(model_urls[backbone]), strict=False) 46 | logging.info(msg) 47 | 48 | 49 | def get_segmentation_backbone(backbone, norm_layer=torch.nn.BatchNorm2d): 50 | """ 51 | Built the backbone model, defined by `cfg.MODEL.BACKBONE`. 
52 | """ 53 | model = BACKBONE_REGISTRY.get(backbone)(norm_layer) 54 | load_backbone_pretrained(model, backbone) 55 | return model 56 | 57 | -------------------------------------------------------------------------------- /segmentron/models/backbones/eespnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ...modules import _ConvBNPReLU, _ConvBN, _BNPReLU, EESP 7 | from .build import BACKBONE_REGISTRY 8 | from ...config import cfg 9 | 10 | __all__ = ['EESPNet', 'eespnet'] 11 | 12 | 13 | class DownSampler(nn.Module): 14 | 15 | def __init__(self, in_channels, out_channels, k=4, r_lim=9, reinf=True, inp_reinf=3, norm_layer=None): 16 | super(DownSampler, self).__init__() 17 | channels_diff = out_channels - in_channels 18 | self.eesp = EESP(in_channels, channels_diff, stride=2, k=k, 19 | r_lim=r_lim, down_method='avg', norm_layer=norm_layer) 20 | self.avg = nn.AvgPool2d(kernel_size=3, padding=1, stride=2) 21 | if reinf: 22 | self.inp_reinf = nn.Sequential( 23 | _ConvBNPReLU(inp_reinf, inp_reinf, 3, 1, 1), 24 | _ConvBN(inp_reinf, out_channels, 1, 1)) 25 | self.act = nn.PReLU(out_channels) 26 | 27 | def forward(self, x, x2=None): 28 | avg_out = self.avg(x) 29 | eesp_out = self.eesp(x) 30 | output = torch.cat([avg_out, eesp_out], 1) 31 | if x2 is not None: 32 | w1 = avg_out.size(2) 33 | while True: 34 | x2 = F.avg_pool2d(x2, kernel_size=3, padding=1, stride=2) 35 | w2 = x2.size(2) 36 | if w2 == w1: 37 | break 38 | output = output + self.inp_reinf(x2) 39 | 40 | return self.act(output) 41 | 42 | 43 | class EESPNet(nn.Module): 44 | def __init__(self, num_classes=1000, scale=1, reinf=True, norm_layer=nn.BatchNorm2d): 45 | super(EESPNet, self).__init__() 46 | inp_reinf = 3 if reinf else None 47 | reps = [0, 3, 7, 3] 48 | r_lim = [13, 11, 9, 7, 5] 49 | K = [4] * len(r_lim) 50 | 51 | # set out_channels 52 | base, levels, base_s = 32, 5, 0 53 | out_channels = [base] * levels 54 | for i in range(levels): 55 | if i == 0: 56 | base_s = int(base * scale) 57 | base_s = math.ceil(base_s / K[0]) * K[0] 58 | out_channels[i] = base if base_s > base else base_s 59 | else: 60 | out_channels[i] = base_s * pow(2, i) 61 | if scale <= 1.5: 62 | out_channels.append(1024) 63 | elif scale in [1.5, 2]: 64 | out_channels.append(1280) 65 | else: 66 | raise ValueError("Unknown scale value.") 67 | 68 | self.level1 = _ConvBNPReLU(3, out_channels[0], 3, 2, 1, norm_layer=norm_layer) 69 | 70 | self.level2_0 = DownSampler(out_channels[0], out_channels[1], k=K[0], r_lim=r_lim[0], 71 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 72 | 73 | self.level3_0 = DownSampler(out_channels[1], out_channels[2], k=K[1], r_lim=r_lim[1], 74 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 75 | self.level3 = nn.ModuleList() 76 | for i in range(reps[1]): 77 | self.level3.append(EESP(out_channels[2], out_channels[2], k=K[2], r_lim=r_lim[2], 78 | norm_layer=norm_layer)) 79 | 80 | self.level4_0 = DownSampler(out_channels[2], out_channels[3], k=K[2], r_lim=r_lim[2], 81 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 82 | self.level4 = nn.ModuleList() 83 | for i in range(reps[2]): 84 | self.level4.append(EESP(out_channels[3], out_channels[3], k=K[3], r_lim=r_lim[3], 85 | norm_layer=norm_layer)) 86 | 87 | self.level5_0 = DownSampler(out_channels[3], out_channels[4], k=K[3], r_lim=r_lim[3], 88 | reinf=reinf, inp_reinf=inp_reinf, norm_layer=norm_layer) 89 | self.level5 = 
nn.ModuleList() 90 | for i in range(reps[2]): 91 | self.level5.append(EESP(out_channels[4], out_channels[4], k=K[4], r_lim=r_lim[4], 92 | norm_layer=norm_layer)) 93 | 94 | self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[4], 3, 1, 1, 95 | groups=out_channels[4], norm_layer=norm_layer)) 96 | self.level5.append(_ConvBNPReLU(out_channels[4], out_channels[5], 1, 1, 0, 97 | groups=K[4], norm_layer=norm_layer)) 98 | 99 | self.fc = nn.Linear(out_channels[5], num_classes) 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 104 | if m.bias is not None: 105 | nn.init.constant_(m.bias, 0) 106 | elif isinstance(m, nn.BatchNorm2d): 107 | nn.init.constant_(m.weight, 1) 108 | nn.init.constant_(m.bias, 0) 109 | elif isinstance(m, nn.Linear): 110 | nn.init.normal_(m.weight, std=0.001) 111 | if m.bias is not None: 112 | nn.init.constant_(m.bias, 0) 113 | 114 | def forward(self, x, seg=True): 115 | out_l1 = self.level1(x) 116 | 117 | out_l2 = self.level2_0(out_l1, x) 118 | 119 | out_l3_0 = self.level3_0(out_l2, x) 120 | for i, layer in enumerate(self.level3): 121 | if i == 0: 122 | out_l3 = layer(out_l3_0) 123 | else: 124 | out_l3 = layer(out_l3) 125 | 126 | out_l4_0 = self.level4_0(out_l3, x) 127 | for i, layer in enumerate(self.level4): 128 | if i == 0: 129 | out_l4 = layer(out_l4_0) 130 | else: 131 | out_l4 = layer(out_l4) 132 | 133 | if not seg: 134 | out_l5_0 = self.level5_0(out_l4) # down-sampled 135 | for i, layer in enumerate(self.level5): 136 | if i == 0: 137 | out_l5 = layer(out_l5_0) 138 | else: 139 | out_l5 = layer(out_l5) 140 | 141 | output_g = F.adaptive_avg_pool2d(out_l5, output_size=1) 142 | output_g = F.dropout(output_g, p=0.2, training=self.training) 143 | output_1x1 = output_g.view(output_g.size(0), -1) 144 | 145 | return self.fc(output_1x1) 146 | return out_l1, out_l2, out_l3, out_l4 147 | 148 | 149 | @BACKBONE_REGISTRY.register() 150 | def eespnet(norm_layer=nn.BatchNorm2d): 151 | return EESPNet(norm_layer=norm_layer) 152 | 153 | # def eespnet(pretrained=False, **kwargs): 154 | # model = EESPNet(**kwargs) 155 | # if pretrained: 156 | # raise ValueError("Don't support pretrained") 157 | # return model 158 | 159 | 160 | if __name__ == '__main__': 161 | img = torch.randn(1, 3, 224, 224) 162 | model = eespnet() 163 | out = model(img) 164 | -------------------------------------------------------------------------------- /segmentron/models/backbones/mobilenet.py: -------------------------------------------------------------------------------- 1 | """MobileNet and MobileNetV2.""" 2 | import torch.nn as nn 3 | 4 | from .build import BACKBONE_REGISTRY 5 | from ...modules import _ConvBNReLU, _DepthwiseConv, InvertedResidual 6 | from ...config import cfg 7 | 8 | __all__ = ['MobileNet', 'MobileNetV2'] 9 | 10 | 11 | class MobileNet(nn.Module): 12 | def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d): 13 | super(MobileNet, self).__init__() 14 | multiplier = cfg.MODEL.BACKBONE_SCALE 15 | conv_dw_setting = [ 16 | [64, 1, 1], 17 | [128, 2, 2], 18 | [256, 2, 2], 19 | [512, 6, 2], 20 | [1024, 2, 2]] 21 | input_channels = int(32 * multiplier) if multiplier > 1.0 else 32 22 | features = [_ConvBNReLU(3, input_channels, 3, 2, 1, norm_layer=norm_layer)] 23 | 24 | for c, n, s in conv_dw_setting: 25 | out_channels = int(c * multiplier) 26 | for i in range(n): 27 | stride = s if i == 0 else 1 28 | features.append(_DepthwiseConv(input_channels, out_channels, stride, norm_layer)) 29 | 
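                # carry the widened channel count forward: later repeats in this
                # stage and the first block of the next stage take out_channels
                # as their input width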
input_channels = out_channels 30 | self.last_inp_channels = int(1024 * multiplier) 31 | features.append(nn.AdaptiveAvgPool2d(1)) 32 | self.features = nn.Sequential(*features) 33 | 34 | self.classifier = nn.Linear(int(1024 * multiplier), num_classes) 35 | 36 | # weight initialization 37 | for m in self.modules(): 38 | if isinstance(m, nn.Conv2d): 39 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 40 | if m.bias is not None: 41 | nn.init.zeros_(m.bias) 42 | elif isinstance(m, nn.BatchNorm2d): 43 | nn.init.ones_(m.weight) 44 | nn.init.zeros_(m.bias) 45 | elif isinstance(m, nn.Linear): 46 | nn.init.normal_(m.weight, 0, 0.01) 47 | nn.init.zeros_(m.bias) 48 | 49 | def forward(self, x): 50 | x = self.features(x) 51 | x = self.classifier(x.view(x.size(0), x.size(1))) 52 | return x 53 | 54 | 55 | class MobileNetV2(nn.Module): 56 | def __init__(self, num_classes=1000, norm_layer=nn.BatchNorm2d): 57 | super(MobileNetV2, self).__init__() 58 | output_stride = cfg.MODEL.OUTPUT_STRIDE 59 | self.multiplier = cfg.MODEL.BACKBONE_SCALE 60 | if output_stride == 32: 61 | dilations = [1, 1] 62 | elif output_stride == 16: 63 | dilations = [1, 2] 64 | elif output_stride == 8: 65 | dilations = [2, 4] 66 | else: 67 | raise NotImplementedError 68 | inverted_residual_setting = [ 69 | # t, c, n, s 70 | [1, 16, 1, 1], 71 | [6, 24, 2, 2], 72 | [6, 32, 3, 2], 73 | [6, 64, 4, 2], 74 | [6, 96, 3, 1], 75 | [6, 160, 3, 2], 76 | [6, 320, 1, 1]] 77 | # building first layer 78 | input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32 79 | # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280 80 | self.conv1 = _ConvBNReLU(3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer) 81 | 82 | # building inverted residual blocks 83 | self.planes = input_channels 84 | self.block1 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[0:1], 85 | norm_layer=norm_layer) 86 | self.block2 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[1:2], 87 | norm_layer=norm_layer) 88 | self.block3 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[2:3], 89 | norm_layer=norm_layer) 90 | self.block4 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[3:5], 91 | dilations[0], norm_layer=norm_layer) 92 | self.block5 = self._make_layer(InvertedResidual, self.planes, inverted_residual_setting[5:], 93 | dilations[1], norm_layer=norm_layer) 94 | self.last_inp_channels = self.planes 95 | 96 | # building last several layers 97 | # features = list() 98 | # features.append(_ConvBNReLU(input_channels, last_channels, 1, relu6=True, norm_layer=norm_layer)) 99 | # features.append(nn.AdaptiveAvgPool2d(1)) 100 | # self.features = nn.Sequential(*features) 101 | # 102 | # self.classifier = nn.Sequential( 103 | # nn.Dropout2d(0.2), 104 | # nn.Linear(last_channels, num_classes)) 105 | 106 | # weight initialization 107 | for m in self.modules(): 108 | if isinstance(m, nn.Conv2d): 109 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 110 | if m.bias is not None: 111 | nn.init.zeros_(m.bias) 112 | elif isinstance(m, nn.BatchNorm2d): 113 | nn.init.ones_(m.weight) 114 | nn.init.zeros_(m.bias) 115 | elif isinstance(m, nn.Linear): 116 | nn.init.normal_(m.weight, 0, 0.01) 117 | if m.bias is not None: 118 | nn.init.zeros_(m.bias) 119 | 120 | def _make_layer(self, block, planes, inverted_residual_setting, dilation=1, norm_layer=nn.BatchNorm2d): 121 | features = list() 122 | for t, c, n, s in inverted_residual_setting: 123 | out_channels = 
int(c * self.multiplier) 124 | stride = s if dilation == 1 else 1 125 | features.append(block(planes, out_channels, stride, t, dilation, norm_layer)) 126 | planes = out_channels 127 | for i in range(n - 1): 128 | features.append(block(planes, out_channels, 1, t, norm_layer=norm_layer)) 129 | planes = out_channels 130 | self.planes = planes 131 | return nn.Sequential(*features) 132 | 133 | def forward(self, x): 134 | x = self.conv1(x) 135 | x = self.block1(x) 136 | c1 = self.block2(x) 137 | c2 = self.block3(c1) 138 | c3 = self.block4(c2) 139 | c4 = self.block5(c3) 140 | 141 | # x = self.features(x) 142 | # x = self.classifier(x.view(x.size(0), x.size(1))) 143 | return c1, c2, c3, c4 144 | 145 | 146 | @BACKBONE_REGISTRY.register() 147 | def mobilenet_v1(norm_layer=nn.BatchNorm2d): 148 | return MobileNet(norm_layer=norm_layer) 149 | 150 | 151 | @BACKBONE_REGISTRY.register() 152 | def mobilenet_v2(norm_layer=nn.BatchNorm2d): 153 | return MobileNetV2(norm_layer=norm_layer) 154 | 155 | -------------------------------------------------------------------------------- /segmentron/models/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from .build import BACKBONE_REGISTRY 4 | from ...config import cfg 5 | 6 | __all__ = ['ResNetV1'] 7 | 8 | 9 | class BasicBlockV1b(nn.Module): 10 | expansion = 1 11 | 12 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 13 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 14 | super(BasicBlockV1b, self).__init__() 15 | self.conv1 = nn.Conv2d(inplanes, planes, 3, stride, 16 | dilation, dilation, bias=False) 17 | self.bn1 = norm_layer(planes) 18 | self.relu = nn.ReLU(True) 19 | self.conv2 = nn.Conv2d(planes, planes, 3, 1, previous_dilation, 20 | dilation=previous_dilation, bias=False) 21 | self.bn2 = norm_layer(planes) 22 | self.downsample = downsample 23 | self.stride = stride 24 | 25 | def forward(self, x): 26 | identity = x 27 | 28 | out = self.conv1(x) 29 | out = self.bn1(out) 30 | out = self.relu(out) 31 | 32 | out = self.conv2(out) 33 | out = self.bn2(out) 34 | 35 | if self.downsample is not None: 36 | identity = self.downsample(x) 37 | 38 | out += identity 39 | out = self.relu(out) 40 | 41 | return out 42 | 43 | 44 | class BottleneckV1b(nn.Module): 45 | expansion = 4 46 | 47 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, 48 | previous_dilation=1, norm_layer=nn.BatchNorm2d): 49 | super(BottleneckV1b, self).__init__() 50 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 51 | self.bn1 = norm_layer(planes) 52 | self.conv2 = nn.Conv2d(planes, planes, 3, stride, 53 | dilation, dilation, bias=False) 54 | self.bn2 = norm_layer(planes) 55 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 56 | self.bn3 = norm_layer(planes * self.expansion) 57 | self.relu = nn.ReLU(True) 58 | self.downsample = downsample 59 | self.stride = stride 60 | 61 | def forward(self, x): 62 | identity = x 63 | 64 | out = self.conv1(x) 65 | out = self.bn1(out) 66 | out = self.relu(out) 67 | 68 | out = self.conv2(out) 69 | out = self.bn2(out) 70 | out = self.relu(out) 71 | 72 | out = self.conv3(out) 73 | out = self.bn3(out) 74 | 75 | if self.downsample is not None: 76 | identity = self.downsample(x) 77 | 78 | out += identity 79 | out = self.relu(out) 80 | 81 | return out 82 | 83 | 84 | class ResNetV1(nn.Module): 85 | 86 | def __init__(self, block, layers, num_classes=1000, deep_stem=False, 87 | zero_init_residual=False, 
norm_layer=nn.BatchNorm2d): 88 | output_stride = cfg.MODEL.OUTPUT_STRIDE 89 | scale = cfg.MODEL.BACKBONE_SCALE 90 | if output_stride == 32: 91 | dilations = [1, 1] 92 | strides = [2, 2] 93 | elif output_stride == 16: 94 | dilations = [1, 2] 95 | strides = [2, 1] 96 | elif output_stride == 8: 97 | dilations = [2, 4] 98 | strides = [1, 1] 99 | else: 100 | raise NotImplementedError 101 | self.inplanes = int((128 if deep_stem else 64) * scale) 102 | super(ResNetV1, self).__init__() 103 | if deep_stem: 104 | # resnet vc 105 | mid_channel = int(64 * scale) 106 | self.conv1 = nn.Sequential( 107 | nn.Conv2d(3, mid_channel, 3, 2, 1, bias=False), 108 | norm_layer(mid_channel), 109 | nn.ReLU(True), 110 | nn.Conv2d(mid_channel, mid_channel, 3, 1, 1, bias=False), 111 | norm_layer(mid_channel), 112 | nn.ReLU(True), 113 | nn.Conv2d(mid_channel, self.inplanes, 3, 1, 1, bias=False) 114 | ) 115 | else: 116 | self.conv1 = nn.Conv2d(3, self.inplanes, 7, 2, 3, bias=False) 117 | self.bn1 = norm_layer(self.inplanes) 118 | self.relu = nn.ReLU(True) 119 | self.maxpool = nn.MaxPool2d(3, 2, 1) 120 | self.layer1 = self._make_layer(block, int(64 * scale), layers[0], norm_layer=norm_layer) 121 | self.layer2 = self._make_layer(block, int(128 * scale), layers[1], stride=2, norm_layer=norm_layer) 122 | 123 | self.layer3 = self._make_layer(block, int(256 * scale), layers[2], stride=strides[0], dilation=dilations[0], 124 | norm_layer=norm_layer) 125 | self.layer4 = self._make_layer(block, int(512 * scale), layers[3], stride=strides[1], dilation=dilations[1], 126 | norm_layer=norm_layer, multi_grid=cfg.MODEL.DANET.MULTI_GRID, 127 | multi_dilation=cfg.MODEL.DANET.MULTI_DILATION) 128 | 129 | self.last_inp_channels = int(512 * block.expansion * scale) 130 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 131 | self.fc = nn.Linear(int(512 * block.expansion * scale), num_classes) 132 | 133 | for m in self.modules(): 134 | if isinstance(m, nn.Conv2d): 135 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 136 | elif isinstance(m, nn.BatchNorm2d): 137 | nn.init.constant_(m.weight, 1) 138 | nn.init.constant_(m.bias, 0) 139 | 140 | if zero_init_residual: 141 | for m in self.modules(): 142 | if isinstance(m, BottleneckV1b): 143 | nn.init.constant_(m.bn3.weight, 0) 144 | elif isinstance(m, BasicBlockV1b): 145 | nn.init.constant_(m.bn2.weight, 0) 146 | 147 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=nn.BatchNorm2d, 148 | multi_grid=False, multi_dilation=None): 149 | downsample = None 150 | if stride != 1 or self.inplanes != planes * block.expansion: 151 | downsample = nn.Sequential( 152 | nn.Conv2d(self.inplanes, planes * block.expansion, 1, stride, bias=False), 153 | norm_layer(planes * block.expansion), 154 | ) 155 | 156 | layers = [] 157 | if not multi_grid: 158 | if dilation in (1, 2): 159 | layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample, 160 | previous_dilation=dilation, norm_layer=norm_layer)) 161 | elif dilation == 4: 162 | layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample, 163 | previous_dilation=dilation, norm_layer=norm_layer)) 164 | else: 165 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 166 | else: 167 | layers.append(block(self.inplanes, planes, stride, dilation=multi_dilation[0], 168 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 169 | self.inplanes = planes * block.expansion 170 | 171 | if multi_grid: 172 | div = len(multi_dilation) 173 | for i 
in range(1, blocks): 174 | layers.append(block(self.inplanes, planes, dilation=multi_dilation[i % div], 175 | previous_dilation=dilation, norm_layer=norm_layer)) 176 | else: 177 | for _ in range(1, blocks): 178 | layers.append(block(self.inplanes, planes, dilation=dilation, 179 | previous_dilation=dilation, norm_layer=norm_layer)) 180 | 181 | return nn.Sequential(*layers) 182 | 183 | def forward(self, x): 184 | x = self.conv1(x) 185 | x = self.bn1(x) 186 | x = self.relu(x) 187 | x = self.maxpool(x) 188 | 189 | c1 = self.layer1(x) 190 | c2 = self.layer2(c1) 191 | c3 = self.layer3(c2) 192 | c4 = self.layer4(c3) 193 | 194 | # for classification 195 | # x = self.avgpool(c4) 196 | # x = x.view(x.size(0), -1) 197 | # x = self.fc(x) 198 | 199 | return c1, c2, c3, c4 200 | 201 | 202 | @BACKBONE_REGISTRY.register() 203 | def resnet18(norm_layer=nn.BatchNorm2d): 204 | num_block = [2, 2, 2, 2] 205 | return ResNetV1(BasicBlockV1b, num_block, norm_layer=norm_layer) 206 | 207 | 208 | @BACKBONE_REGISTRY.register() 209 | def resnet34(norm_layer=nn.BatchNorm2d): 210 | num_block = [3, 4, 6, 3] 211 | return ResNetV1(BasicBlockV1b, num_block, norm_layer=norm_layer) 212 | 213 | 214 | @BACKBONE_REGISTRY.register() 215 | def resnet50(norm_layer=nn.BatchNorm2d): 216 | num_block = [3, 4, 6, 3] 217 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer) 218 | 219 | 220 | @BACKBONE_REGISTRY.register() 221 | def resnet101(norm_layer=nn.BatchNorm2d): 222 | num_block = [3, 4, 23, 3] 223 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer) 224 | 225 | 226 | @BACKBONE_REGISTRY.register() 227 | def resnet152(norm_layer=nn.BatchNorm2d): 228 | num_block = [3, 8, 36, 3] 229 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer) 230 | 231 | 232 | @BACKBONE_REGISTRY.register() 233 | def resnet50c(norm_layer=nn.BatchNorm2d): 234 | num_block = [3, 4, 6, 3] 235 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer, deep_stem=True) 236 | 237 | 238 | @BACKBONE_REGISTRY.register() 239 | def resnet101c(norm_layer=nn.BatchNorm2d): 240 | num_block = [3, 4, 23, 3] 241 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer, deep_stem=True) 242 | 243 | 244 | @BACKBONE_REGISTRY.register() 245 | def resnet152c(norm_layer=nn.BatchNorm2d): 246 | num_block = [3, 8, 36, 3] 247 | return ResNetV1(BottleneckV1b, num_block, norm_layer=norm_layer, deep_stem=True) 248 | 249 | -------------------------------------------------------------------------------- /segmentron/models/model_zoo.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | 4 | from collections import OrderedDict 5 | from segmentron.utils.registry import Registry 6 | from ..config import cfg 7 | 8 | MODEL_REGISTRY = Registry("MODEL") 9 | MODEL_REGISTRY.__doc__ = """ 10 | Registry for segment model, i.e. the whole model. 11 | 12 | The registered object will be called with `obj()` 13 | and expected to return a `nn.Module` object. 14 | """ 15 | 16 | 17 | def get_segmentation_model(): 18 | """ 19 | Built the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`. 
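    (The name is read from `cfg.MODEL.MODEL_NAME` below, not from
    `cfg.MODEL.META_ARCHITECTURE`.)

    A minimal usage sketch (illustrative; it assumes the config has been
    loaded and the model modules imported so `MODEL_REGISTRY` is populated):

        cfg.MODEL.MODEL_NAME = 'TransLab'
        model = get_segmentation_model()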
20 | """ 21 | model_name = cfg.MODEL.MODEL_NAME 22 | model = MODEL_REGISTRY.get(model_name)() 23 | load_model_pretrain(model) 24 | return model 25 | 26 | 27 | def load_model_pretrain(model): 28 | if cfg.PHASE == 'train': 29 | if cfg.TRAIN.PRETRAINED_MODEL_PATH: 30 | logging.info('load pretrained model from {}'.format(cfg.TRAIN.PRETRAINED_MODEL_PATH)) 31 | state_dict_to_load = torch.load(cfg.TRAIN.PRETRAINED_MODEL_PATH) 32 | keys_wrong_shape = [] 33 | state_dict_suitable = OrderedDict() 34 | state_dict = model.state_dict() 35 | for k, v in state_dict_to_load.items(): 36 | if v.shape == state_dict[k].shape: 37 | state_dict_suitable[k] = v 38 | else: 39 | keys_wrong_shape.append(k) 40 | logging.info('Shape unmatched weights: {}'.format(keys_wrong_shape)) 41 | msg = model.load_state_dict(state_dict_suitable, strict=False) 42 | logging.info(msg) 43 | else: 44 | if cfg.TEST.TEST_MODEL_PATH: 45 | logging.info('load test model from {}'.format(cfg.TEST.TEST_MODEL_PATH)) 46 | msg = model.load_state_dict(torch.load(cfg.TEST.TEST_MODEL_PATH), strict=False) 47 | logging.info(msg) -------------------------------------------------------------------------------- /segmentron/models/segbase.py: -------------------------------------------------------------------------------- 1 | """Base Model for Semantic Segmentation""" 2 | import math 3 | import numbers 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from .backbones import get_segmentation_backbone 10 | from ..data.dataloader import datasets 11 | from ..modules import get_norm 12 | from ..config import cfg 13 | __all__ = ['SegBaseModel'] 14 | 15 | 16 | class SegBaseModel(nn.Module): 17 | r"""Base Model for Semantic Segmentation 18 | """ 19 | def __init__(self, need_backbone=True): 20 | super(SegBaseModel, self).__init__() 21 | self.nclass = datasets[cfg.DATASET.NAME].NUM_CLASS 22 | self.aux = cfg.SOLVER.AUX 23 | self.norm_layer = get_norm(cfg.MODEL.BN_TYPE) 24 | self.backbone = None 25 | self.encoder = None 26 | if need_backbone: 27 | self.get_backbone() 28 | 29 | def get_backbone(self): 30 | self.backbone = cfg.MODEL.BACKBONE.lower() 31 | self.encoder = get_segmentation_backbone(self.backbone, self.norm_layer) 32 | 33 | def base_forward(self, x): 34 | """forwarding backbone network""" 35 | c1, c2, c3, c4 = self.encoder(x) 36 | return c1, c2, c3, c4 37 | 38 | def demo(self, x): 39 | pred = self.forward(x) 40 | if self.aux: 41 | pred = pred[0] 42 | return pred 43 | 44 | def evaluate(self, image): 45 | """evaluating network with inputs and targets""" 46 | scales = cfg.TEST.SCALES 47 | batch, _, h, w = image.shape 48 | base_size = max(h, w) 49 | # scores = torch.zeros((batch, self.nclass, h, w)).to(image.device) 50 | scores = None 51 | for scale in scales: 52 | long_size = int(math.ceil(base_size * scale)) 53 | if h > w: 54 | height = long_size 55 | width = int(1.0 * w * long_size / h + 0.5) 56 | else: 57 | width = long_size 58 | height = int(1.0 * h * long_size / w + 0.5) 59 | 60 | # resize image to current size 61 | cur_img = _resize_image(image, height, width) 62 | outputs = self.forward(cur_img)[0][..., :height, :width] 63 | 64 | score = _resize_image(outputs, h, w) 65 | 66 | if scores is None: 67 | scores = score 68 | else: 69 | scores += score 70 | return scores 71 | 72 | 73 | def _resize_image(img, h, w): 74 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 75 | 76 | 77 | def _pad_image(img, crop_size): 78 | b, c, h, w = img.shape 79 | assert(c == 3) 80 | padh = 
crop_size[0] - h if h < crop_size[0] else 0 81 | padw = crop_size[1] - w if w < crop_size[1] else 0 82 | if padh == 0 and padw == 0: 83 | return img 84 | img_pad = F.pad(img, (0, padh, 0, padw)) 85 | 86 | # TODO clean this code 87 | # mean = cfg.DATASET.MEAN 88 | # std = cfg.DATASET.STD 89 | # pad_values = -np.array(mean) / np.array(std) 90 | # img_pad = torch.zeros((b, c, h + padh, w + padw)).to(img.device) 91 | # for i in range(c): 92 | # # print(img[:, i, :, :].unsqueeze(1).shape) 93 | # img_pad[:, i, :, :] = torch.squeeze( 94 | # F.pad(img[:, i, :, :].unsqueeze(1), (0, padh, 0, padw), 95 | # 'constant', value=pad_values[i]), 1) 96 | # assert(img_pad.shape[2] >= crop_size[0] and img_pad.shape[3] >= crop_size[1]) 97 | 98 | return img_pad 99 | 100 | 101 | def _crop_image(img, h0, h1, w0, w1): 102 | return img[:, :, h0:h1, w0:w1] 103 | 104 | 105 | def _flip_image(img): 106 | assert(img.ndim == 4) 107 | return img.flip((3)) 108 | 109 | 110 | def _to_tuple(size): 111 | if isinstance(size, (list, tuple)): 112 | assert len(size), 'Expect eval crop size contains two element, ' \ 113 | 'but received {}'.format(len(size)) 114 | return tuple(size) 115 | elif isinstance(size, numbers.Number): 116 | return tuple((size, size)) 117 | else: 118 | raise ValueError('Unsupport datatype: {}'.format(type(size))) 119 | -------------------------------------------------------------------------------- /segmentron/models/translab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .segbase import SegBaseModel 6 | from .model_zoo import MODEL_REGISTRY 7 | from ..modules import _ConvBNReLU, SeparableConv2d, _ASPP, _FCNHead 8 | from ..config import cfg 9 | from IPython import embed 10 | import math 11 | 12 | __all__ = ['TransLab'] 13 | 14 | def _resize_image(img, h, w): 15 | return F.interpolate(img, size=[h, w], mode='bilinear', align_corners=True) 16 | 17 | @MODEL_REGISTRY.register(name='TransLab') 18 | class TransLab(SegBaseModel): 19 | def __init__(self): 20 | super(TransLab, self).__init__() 21 | if self.backbone.startswith('mobilenet'): 22 | c1_channels = 24 23 | c4_channels = 320 24 | else: 25 | c1_channels = 256 26 | c4_channels = 2048 27 | c2_channel = 512 28 | 29 | self.head = _DeepLabHead_attention(self.nclass, c1_channels=c1_channels, c4_channels=c4_channels, c2_channel=c2_channel) 30 | self.head_b = _DeepLabHead(1, c1_channels=c1_channels, c4_channels=c4_channels) 31 | 32 | self.fus_head1 = FusHead() 33 | self.fus_head2 = FusHead(inplane=2048) 34 | self.fus_head3 = FusHead(inplane=512) 35 | 36 | if self.aux: 37 | self.auxlayer = _FCNHead(728, self.nclass) 38 | self.__setattr__('decoder', ['head', 'auxlayer'] if self.aux else ['head']) 39 | 40 | def forward(self, x): 41 | size = x.size()[2:] 42 | c1, c2, c3, c4 = self.encoder(x) 43 | outputs = list() 44 | outputs_b = list() 45 | 46 | x_b = self.head_b(c4, c1) 47 | 48 | #attention c1 c4 49 | attention_map = x_b.sigmoid() 50 | 51 | c1 = self.fus_head1(c1, attention_map) 52 | c4 = self.fus_head2(c4, attention_map) 53 | c2 = self.fus_head3(c2, attention_map) 54 | 55 | x = self.head(c4, c2, c1, attention_map) 56 | 57 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 58 | x_b = F.interpolate(x_b, size, mode='bilinear', align_corners=True) 59 | 60 | outputs.append(x) 61 | outputs_b.append(x_b)#.sigmoid()) 62 | 63 | return tuple(outputs), tuple(outputs_b) 64 | 65 | def evaluate(self, image): 66 | """evaluating network with 
inputs and targets""" 67 | scales = cfg.TEST.SCALES 68 | batch, _, h, w = image.shape 69 | base_size = max(h, w) 70 | # scores = torch.zeros((batch, self.nclass, h, w)).to(image.device) 71 | scores = None 72 | scores_boundary = None 73 | for scale in scales: 74 | long_size = int(math.ceil(base_size * scale)) 75 | if h > w: 76 | height = long_size 77 | width = int(1.0 * w * long_size / h + 0.5) 78 | else: 79 | width = long_size 80 | height = int(1.0 * h * long_size / w + 0.5) 81 | 82 | # resize image to current size 83 | cur_img = _resize_image(image, height, width) 84 | outputs, outputs_boundary = self.forward(cur_img) 85 | outputs = outputs[0][..., :height, :width] 86 | outputs_boundary = outputs_boundary[0][..., :height, :width] 87 | 88 | score = _resize_image(outputs, h, w) 89 | score_boundary = _resize_image(outputs_boundary, h, w) 90 | 91 | if scores is None: 92 | scores = score 93 | scores_boundary = score_boundary 94 | else: 95 | scores += score 96 | scores_boundary += score_boundary 97 | return scores, scores_boundary 98 | 99 | 100 | class _DeepLabHead(nn.Module): 101 | def __init__(self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d): 102 | super(_DeepLabHead, self).__init__() 103 | # self.use_aspp = cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP 104 | # self.use_decoder = cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER 105 | self.use_aspp = True 106 | self.use_decoder = True 107 | last_channels = c4_channels 108 | if self.use_aspp: 109 | self.aspp = _ASPP(c4_channels, 256) 110 | last_channels = 256 111 | if self.use_decoder: 112 | self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer) 113 | last_channels += 48 114 | self.block = nn.Sequential( 115 | SeparableConv2d(last_channels, 256, 3, norm_layer=norm_layer, relu_first=False), 116 | SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False), 117 | nn.Conv2d(256, nclass, 1)) 118 | 119 | def forward(self, x, c1): 120 | size = c1.size()[2:] 121 | if self.use_aspp: 122 | x = self.aspp(x) 123 | if self.use_decoder: 124 | x = F.interpolate(x, size, mode='bilinear', align_corners=True) 125 | c1 = self.c1_block(c1) 126 | cat_fmap = torch.cat([x, c1], dim=1) 127 | return self.block(cat_fmap) 128 | 129 | return self.block(x) 130 | 131 | 132 | class _DeepLabHead_attention(nn.Module): 133 | def __init__(self, nclass, c1_channels=256, c4_channels=2048, c2_channel=512, norm_layer=nn.BatchNorm2d): 134 | super(_DeepLabHead_attention, self).__init__() 135 | # self.use_aspp = cfg.MODEL.DEEPLABV3_PLUS.USE_ASPP 136 | # self.use_decoder = cfg.MODEL.DEEPLABV3_PLUS.ENABLE_DECODER 137 | self.use_aspp = True 138 | self.use_decoder = True 139 | last_channels = c4_channels 140 | if self.use_aspp: 141 | self.aspp = _ASPP(c4_channels, 256) 142 | last_channels = 256 143 | if self.use_decoder: 144 | self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer) 145 | last_channels += 48 146 | 147 | self.c2_block = _ConvBNReLU(c2_channel, 24, 1, norm_layer=norm_layer) 148 | last_channels += 24 149 | 150 | self.block = nn.Sequential( 151 | SeparableConv2d(256+24+48, 256, 3, norm_layer=norm_layer, relu_first=False), 152 | SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False), 153 | nn.Conv2d(256, nclass, 1)) 154 | 155 | self.block_c2 = nn.Sequential( 156 | SeparableConv2d(256+24, 256+24, 3, norm_layer=norm_layer, relu_first=False), 157 | SeparableConv2d(256+24, 256+24, 3, norm_layer=norm_layer, relu_first=False)) 158 | 159 | 160 | self.fus_head_c2 = FusHead(inplane=256+24) 161 | self.fus_head_c1 = 
FusHead(inplane=256+24+48) 162 | 163 | 164 | def forward(self, x, c2, c1, attention_map): 165 | c1_size = c1.size()[2:] 166 | c2_size = c2.size()[2:] 167 | if self.use_aspp: 168 | x = self.aspp(x) 169 | 170 | 171 | if self.use_decoder: 172 | x = F.interpolate(x, c2_size, mode='bilinear', align_corners=True) 173 | c2 = self.c2_block(c2) 174 | x = torch.cat([x, c2], dim=1) 175 | x = self.fus_head_c2(x, attention_map) 176 | x = self.block_c2(x) 177 | 178 | x = F.interpolate(x, c1_size, mode='bilinear', align_corners=True) 179 | c1 = self.c1_block(c1) 180 | x = torch.cat([x, c1], dim=1) 181 | x = self.fus_head_c1(x, attention_map) 182 | return self.block(x) 183 | 184 | return self.block(x) 185 | 186 | 187 | class FusHead(nn.Module): 188 | def __init__(self, norm_layer=nn.BatchNorm2d, inplane=256): 189 | super(FusHead, self).__init__() 190 | self.conv1 = SeparableConv2d(inplane*2, inplane, 3, norm_layer=norm_layer, relu_first=False) 191 | self.fc1 = nn.Conv2d(inplane, inplane // 16, kernel_size=1) 192 | self.fc2 = nn.Conv2d(inplane // 16, inplane, kernel_size=1) 193 | 194 | def forward(self, c, att_map): 195 | if c.size() != att_map.size(): 196 | att_map = F.interpolate(att_map, c.size()[2:], mode='bilinear', align_corners=True) 197 | 198 | atted_c = c * att_map 199 | x = torch.cat([c, atted_c], 1)#512 200 | x = self.conv1(x) #256 201 | 202 | weight = F.avg_pool2d(x, x.size(2)) 203 | weight = F.relu(self.fc1(weight)) 204 | weight = torch.sigmoid(self.fc2(weight)) 205 | x = x * weight 206 | return x 207 | -------------------------------------------------------------------------------- /segmentron/modules/__init__.py: -------------------------------------------------------------------------------- 1 | """Seg NN Modules""" 2 | 3 | from .basic import * 4 | from .module import * 5 | from .batch_norm import get_norm -------------------------------------------------------------------------------- /segmentron/modules/basic.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | 5 | from collections import OrderedDict 6 | 7 | __all__ = ['_ConvBNPReLU', '_ConvBN', '_BNPReLU', '_ConvBNReLU', '_DepthwiseConv', 'InvertedResidual', 8 | 'SeparableConv2d'] 9 | 10 | _USE_FIXED_PAD = False 11 | 12 | 13 | def _pytorch_padding(kernel_size, stride=1, dilation=1, **_): 14 | if _USE_FIXED_PAD: 15 | return 0 # FIXME remove once verified 16 | else: 17 | padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 18 | 19 | # FIXME remove once verified 20 | fp = _fixed_padding(kernel_size, dilation) 21 | assert all(padding == p for p in fp) 22 | 23 | return padding 24 | 25 | 26 | def _fixed_padding(kernel_size, dilation): 27 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 28 | pad_total = kernel_size_effective - 1 29 | pad_beg = pad_total // 2 30 | pad_end = pad_total - pad_beg 31 | return [pad_beg, pad_end, pad_beg, pad_end] 32 | 33 | 34 | class SeparableConv2d(nn.Module): 35 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, relu_first=True, 36 | bias=False, norm_layer=nn.BatchNorm2d): 37 | super().__init__() 38 | depthwise = nn.Conv2d(inplanes, inplanes, kernel_size, 39 | stride=stride, padding=dilation, 40 | dilation=dilation, groups=inplanes, bias=bias) 41 | bn_depth = norm_layer(inplanes) 42 | pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias) 43 | bn_point = norm_layer(planes) 44 | 45 | if relu_first: 46 | self.block = 
nn.Sequential(OrderedDict([('relu', nn.ReLU()), 47 | ('depthwise', depthwise), 48 | ('bn_depth', bn_depth), 49 | ('pointwise', pointwise), 50 | ('bn_point', bn_point) 51 | ])) 52 | else: 53 | self.block = nn.Sequential(OrderedDict([('depthwise', depthwise), 54 | ('bn_depth', bn_depth), 55 | ('relu1', nn.ReLU(inplace=True)), 56 | ('pointwise', pointwise), 57 | ('bn_point', bn_point), 58 | ('relu2', nn.ReLU(inplace=True)) 59 | ])) 60 | 61 | def forward(self, x): 62 | return self.block(x) 63 | 64 | 65 | class _ConvBNReLU(nn.Module): 66 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 67 | dilation=1, groups=1, relu6=False, norm_layer=nn.BatchNorm2d): 68 | super(_ConvBNReLU, self).__init__() 69 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 70 | self.bn = norm_layer(out_channels) 71 | self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True) 72 | 73 | def forward(self, x): 74 | x = self.conv(x) 75 | x = self.bn(x) 76 | x = self.relu(x) 77 | return x 78 | 79 | 80 | class _ConvBNPReLU(nn.Module): 81 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 82 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d): 83 | super(_ConvBNPReLU, self).__init__() 84 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 85 | self.bn = norm_layer(out_channels) 86 | self.prelu = nn.PReLU(out_channels) 87 | 88 | def forward(self, x): 89 | x = self.conv(x) 90 | x = self.bn(x) 91 | x = self.prelu(x) 92 | return x 93 | 94 | 95 | class _ConvBN(nn.Module): 96 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, 97 | dilation=1, groups=1, norm_layer=nn.BatchNorm2d, **kwargs): 98 | super(_ConvBN, self).__init__() 99 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=False) 100 | self.bn = norm_layer(out_channels) 101 | 102 | def forward(self, x): 103 | x = self.conv(x) 104 | x = self.bn(x) 105 | return x 106 | 107 | 108 | class _BNPReLU(nn.Module): 109 | def __init__(self, out_channels, norm_layer=nn.BatchNorm2d): 110 | super(_BNPReLU, self).__init__() 111 | self.bn = norm_layer(out_channels) 112 | self.prelu = nn.PReLU(out_channels) 113 | 114 | def forward(self, x): 115 | x = self.bn(x) 116 | x = self.prelu(x) 117 | return x 118 | 119 | 120 | # ----------------------------------------------------------------- 121 | # For MobileNet 122 | # ----------------------------------------------------------------- 123 | class _DepthwiseConv(nn.Module): 124 | """conv_dw in MobileNet""" 125 | 126 | def __init__(self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs): 127 | super(_DepthwiseConv, self).__init__() 128 | self.conv = nn.Sequential( 129 | _ConvBNReLU(in_channels, in_channels, 3, stride, 1, groups=in_channels, norm_layer=norm_layer), 130 | _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)) 131 | 132 | def forward(self, x): 133 | return self.conv(x) 134 | 135 | 136 | # ----------------------------------------------------------------- 137 | # For MobileNetV2 138 | # ----------------------------------------------------------------- 139 | class InvertedResidual(nn.Module): 140 | def __init__(self, in_channels, out_channels, stride, expand_ratio, dilation=1, norm_layer=nn.BatchNorm2d): 141 | super(InvertedResidual, self).__init__() 142 | assert stride in [1, 2] 143 | self.use_res_connect = stride == 1 and in_channels == out_channels 144 
| 145 | layers = list() 146 | inter_channels = int(round(in_channels * expand_ratio)) 147 | if expand_ratio != 1: 148 | # pw 149 | layers.append(_ConvBNReLU(in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer)) 150 | layers.extend([ 151 | # dw 152 | _ConvBNReLU(inter_channels, inter_channels, 3, stride, dilation, dilation, 153 | groups=inter_channels, relu6=True, norm_layer=norm_layer), 154 | # pw-linear 155 | nn.Conv2d(inter_channels, out_channels, 1, bias=False), 156 | norm_layer(out_channels)]) 157 | self.conv = nn.Sequential(*layers) 158 | 159 | def forward(self, x): 160 | if self.use_res_connect: 161 | return x + self.conv(x) 162 | else: 163 | return self.conv(x) 164 | 165 | 166 | if __name__ == '__main__': 167 | x = torch.randn(1, 32, 64, 64) 168 | model = InvertedResidual(32, 64, 2, 1) 169 | out = model(x) 170 | -------------------------------------------------------------------------------- /segmentron/modules/batch_norm.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | import logging 3 | import torch 4 | import torch.distributed as dist 5 | from torch import nn 6 | from torch.autograd.function import Function 7 | from ..utils.distributed import get_world_size 8 | 9 | 10 | class FrozenBatchNorm2d(nn.Module): 11 | """ 12 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 13 | 14 | It contains non-trainable buffers called 15 | "weight" and "bias", "running_mean", "running_var", 16 | initialized to perform identity transformation. 17 | 18 | The pre-trained backbone models from Caffe2 only contain "weight" and "bias", 19 | which are computed from the original four parameters of BN. 20 | The affine transform `x * weight + bias` will perform the equivalent 21 | computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. 22 | When loading a backbone model from Caffe2, "running_mean" and "running_var" 23 | will be left unchanged as identity transformation. 24 | 25 | Other pre-trained backbone models may contain all 4 parameters. 26 | 27 | The forward is implemented by `F.batch_norm(..., training=False)`. 
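    (In this copy the equivalent computation is written out explicitly in
    `forward`: `scale = weight * rsqrt(running_var + eps)`,
    `bias = bias - running_mean * scale`, and the output is `x * scale + bias`.)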
28 | """ 29 | 30 | _version = 3 31 | 32 | def __init__(self, num_features, eps=1e-5): 33 | super().__init__() 34 | self.num_features = num_features 35 | self.eps = eps 36 | self.register_buffer("weight", torch.ones(num_features)) 37 | self.register_buffer("bias", torch.zeros(num_features)) 38 | self.register_buffer("running_mean", torch.zeros(num_features)) 39 | self.register_buffer("running_var", torch.ones(num_features) - eps) 40 | 41 | def forward(self, x): 42 | scale = self.weight * (self.running_var + self.eps).rsqrt() 43 | bias = self.bias - self.running_mean * scale 44 | scale = scale.reshape(1, -1, 1, 1) 45 | bias = bias.reshape(1, -1, 1, 1) 46 | return x * scale + bias 47 | 48 | def _load_from_state_dict( 49 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 50 | ): 51 | version = local_metadata.get("version", None) 52 | 53 | if version is None or version < 2: 54 | # No running_mean/var in early versions 55 | # This will silent the warnings 56 | if prefix + "running_mean" not in state_dict: 57 | state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) 58 | if prefix + "running_var" not in state_dict: 59 | state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) 60 | 61 | if version is not None and version < 3: 62 | # logger = logging.getLogger(__name__) 63 | logging.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip("."))) 64 | # In version < 3, running_var are used without +eps. 65 | state_dict[prefix + "running_var"] -= self.eps 66 | 67 | super()._load_from_state_dict( 68 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 69 | ) 70 | 71 | def __repr__(self): 72 | return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) 73 | 74 | @classmethod 75 | def convert_frozen_batchnorm(cls, module): 76 | """ 77 | Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. 78 | 79 | Args: 80 | module (torch.nn.Module): 81 | 82 | Returns: 83 | If module is BatchNorm/SyncBatchNorm, returns a new module. 84 | Otherwise, in-place convert module and return it. 
85 | 86 | Similar to convert_sync_batchnorm in 87 | https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py 88 | """ 89 | bn_module = nn.modules.batchnorm 90 | bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) 91 | res = module 92 | if isinstance(module, bn_module): 93 | res = cls(module.num_features) 94 | if module.affine: 95 | res.weight.data = module.weight.data.clone().detach() 96 | res.bias.data = module.bias.data.clone().detach() 97 | res.running_mean.data = module.running_mean.data 98 | res.running_var.data = module.running_var.data + module.eps 99 | else: 100 | for name, child in module.named_children(): 101 | new_child = cls.convert_frozen_batchnorm(child) 102 | if new_child is not child: 103 | res.add_module(name, new_child) 104 | return res 105 | 106 | 107 | def groupNorm(num_channels, eps=1e-5, momentum=0.1, affine=True): 108 | return nn.GroupNorm(min(32, num_channels), num_channels, eps=eps, affine=affine) 109 | 110 | 111 | def get_norm(norm): 112 | """ 113 | Args: 114 | norm (str or callable): 115 | 116 | Returns: 117 | nn.Module or None: the normalization layer 118 | """ 119 | support_norm_type = ['BN', 'SyncBN', 'FrozenBN', 'GN', 'nnSyncBN'] 120 | assert norm in support_norm_type, 'Unknown norm type {}, support norm types are {}'.format( 121 | norm, support_norm_type) 122 | if isinstance(norm, str): 123 | if len(norm) == 0: 124 | return None 125 | norm = { 126 | "BN": nn.BatchNorm2d, 127 | "SyncBN": NaiveSyncBatchNorm, 128 | "FrozenBN": FrozenBatchNorm2d, 129 | "GN": groupNorm, 130 | "nnSyncBN": nn.SyncBatchNorm, # keep for debugging 131 | }[norm] 132 | return norm 133 | 134 | 135 | class AllReduce(Function): 136 | @staticmethod 137 | def forward(ctx, input): 138 | input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())] 139 | # Use allgather instead of allreduce since I don't trust in-place operations .. 140 | dist.all_gather(input_list, input, async_op=False) 141 | inputs = torch.stack(input_list, dim=0) 142 | return torch.sum(inputs, dim=0) 143 | 144 | @staticmethod 145 | def backward(ctx, grad_output): 146 | dist.all_reduce(grad_output, async_op=False) 147 | return grad_output 148 | 149 | 150 | class NaiveSyncBatchNorm(nn.BatchNorm2d): 151 | """ 152 | `torch.nn.SyncBatchNorm` has known unknown bugs. 153 | It produces significantly worse AP (and sometimes goes NaN) 154 | when the batch size on each worker is quite different 155 | (e.g., when scale augmentation is used, or when it is applied to mask head). 156 | 157 | Use this implementation before `nn.SyncBatchNorm` is fixed. 158 | It is slower than `nn.SyncBatchNorm`. 
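    Mechanically (see `forward` below): each worker computes per-channel mean
    and mean-of-squares, the concatenated vector is summed across workers via
    an all-gather, and the variance is recovered as E[x^2] - E[x]^2 before the
    usual scale and shift.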
159 | """ 160 | 161 | def forward(self, input): 162 | if get_world_size() == 1 or not self.training: 163 | return super().forward(input) 164 | 165 | assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" 166 | C = input.shape[1] 167 | mean = torch.mean(input, dim=[0, 2, 3]) 168 | meansqr = torch.mean(input * input, dim=[0, 2, 3]) 169 | 170 | vec = torch.cat([mean, meansqr], dim=0) 171 | vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size()) 172 | 173 | mean, meansqr = torch.split(vec, C) 174 | var = meansqr - mean * mean 175 | self.running_mean += self.momentum * (mean.detach() - self.running_mean) 176 | self.running_var += self.momentum * (var.detach() - self.running_var) 177 | 178 | invstd = torch.rsqrt(var + self.eps) 179 | scale = self.weight * invstd 180 | bias = self.bias - mean * scale 181 | scale = scale.reshape(1, -1, 1, 1) 182 | bias = bias.reshape(1, -1, 1, 1) 183 | return input * scale + bias 184 | -------------------------------------------------------------------------------- /segmentron/modules/cc_attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torch.autograd.function import once_differentiable 6 | from segmentron import _C 7 | 8 | __all__ = ['CrissCrossAttention', 'ca_weight', 'ca_map'] 9 | 10 | 11 | class _CAWeight(torch.autograd.Function): 12 | @staticmethod 13 | def forward(ctx, t, f): 14 | weight = _C.ca_forward(t, f) 15 | 16 | ctx.save_for_backward(t, f) 17 | 18 | return weight 19 | 20 | @staticmethod 21 | @once_differentiable 22 | def backward(ctx, dw): 23 | t, f = ctx.saved_tensors 24 | 25 | dt, df = _C.ca_backward(dw, t, f) 26 | return dt, df 27 | 28 | 29 | class _CAMap(torch.autograd.Function): 30 | @staticmethod 31 | def forward(ctx, weight, g): 32 | out = _C.ca_map_forward(weight, g) 33 | 34 | ctx.save_for_backward(weight, g) 35 | 36 | return out 37 | 38 | @staticmethod 39 | @once_differentiable 40 | def backward(ctx, dout): 41 | weight, g = ctx.saved_tensors 42 | 43 | dw, dg = _C.ca_map_backward(dout, weight, g) 44 | 45 | return dw, dg 46 | 47 | 48 | ca_weight = _CAWeight.apply 49 | ca_map = _CAMap.apply 50 | 51 | 52 | class CrissCrossAttention(nn.Module): 53 | """Criss-Cross Attention Module""" 54 | 55 | def __init__(self, in_channels): 56 | super(CrissCrossAttention, self).__init__() 57 | self.query_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 58 | self.key_conv = nn.Conv2d(in_channels, in_channels // 8, 1) 59 | self.value_conv = nn.Conv2d(in_channels, in_channels, 1) 60 | self.gamma = nn.Parameter(torch.zeros(1)) 61 | 62 | def forward(self, x): 63 | proj_query = self.query_conv(x) 64 | proj_key = self.key_conv(x) 65 | proj_value = self.value_conv(x) 66 | 67 | energy = ca_weight(proj_query, proj_key) 68 | attention = F.softmax(energy, 1) 69 | out = ca_map(attention, proj_value) 70 | out = self.gamma * out + x 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /segmentron/modules/csrc/criss_cross_attention/ca.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | namespace segmentron { 6 | at::Tensor ca_forward_cuda( 7 | const at::Tensor& t, 8 | const at::Tensor& f); 9 | 10 | std::tuple ca_backward_cuda( 11 | const at::Tensor& dw, 12 | const at::Tensor& t, 13 | const at::Tensor& f); 14 | 15 | at::Tensor ca_map_forward_cuda( 16 | const at::Tensor& weight, 17 | 
const at::Tensor& g); 18 | 19 | std::tuple ca_map_backward_cuda( 20 | const at::Tensor& dout, 21 | const at::Tensor& weight, 22 | const at::Tensor& g); 23 | 24 | 25 | at::Tensor ca_forward(const at::Tensor& t, 26 | const at::Tensor& f) { 27 | if (t.type().is_cuda()) { 28 | #ifdef WITH_CUDA 29 | return ca_forward_cuda(t, f); 30 | #else 31 | AT_ERROR("Not compiled with GPU support"); 32 | #endif 33 | } 34 | AT_ERROR("Not implemented on the CPU"); 35 | } 36 | 37 | std::tuple ca_backward(const at::Tensor& dw, 38 | const at::Tensor& t, 39 | const at::Tensor& f) { 40 | if (dw.type().is_cuda()) { 41 | #ifdef WITH_CUDA 42 | return ca_backward_cuda(dw, t, f); 43 | #else 44 | AT_ERROR("Not compiled with GPU support"); 45 | #endif 46 | } 47 | AT_ERROR("Not implemented on the CPU"); 48 | } 49 | 50 | at::Tensor ca_map_forward(const at::Tensor& weight, 51 | const at::Tensor& g) { 52 | if (weight.type().is_cuda()) { 53 | #ifdef WITH_CUDA 54 | return ca_map_forward_cuda(weight, g); 55 | #else 56 | AT_ERROR("Not compiled with GPU support"); 57 | #endif 58 | } 59 | AT_ERROR("Not implemented on the CPU"); 60 | } 61 | 62 | std::tuple ca_map_backward(const at::Tensor& dout, 63 | const at::Tensor& weight, 64 | const at::Tensor& g) { 65 | if (dout.type().is_cuda()) { 66 | #ifdef WITH_CUDA 67 | return ca_map_backward_cuda(dout, weight, g); 68 | #else 69 | AT_ERROR("Not compiled with GPU support"); 70 | #endif 71 | } 72 | AT_ERROR("Not implemented on the CPU"); 73 | } 74 | 75 | } // namespace segmentron 76 | -------------------------------------------------------------------------------- /segmentron/modules/csrc/vision.cpp: -------------------------------------------------------------------------------- 1 | 2 | #include "criss_cross_attention/ca.h" 3 | 4 | namespace segmentron { 5 | 6 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 7 | m.def("ca_forward", &ca_forward, "ca_forward"); 8 | m.def("ca_backward", &ca_backward, "ca_backward"); 9 | m.def("ca_map_forward", &ca_map_forward, "ca_map_forward"); 10 | m.def("ca_map_backward", &ca_map_backward, "ca_map_backward"); 11 | } 12 | 13 | } // namespace segmentron 14 | -------------------------------------------------------------------------------- /segmentron/modules/module.py: -------------------------------------------------------------------------------- 1 | """Basic Module for Semantic Segmentation""" 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from collections import OrderedDict 7 | from .basic import _ConvBNReLU, SeparableConv2d, _ConvBN, _BNPReLU, _ConvBNPReLU 8 | from ..config import cfg 9 | from IPython import embed 10 | 11 | __all__ = ['_FCNHead', '_ASPP', 'PyramidPooling', 'PAM_Module', 'CAM_Module', 'EESP'] 12 | 13 | 14 | class _FCNHead(nn.Module): 15 | def __init__(self, in_channels, channels, norm_layer=nn.BatchNorm2d): 16 | super(_FCNHead, self).__init__() 17 | inter_channels = in_channels // 4 18 | self.block = nn.Sequential( 19 | nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 20 | norm_layer(inter_channels), 21 | nn.ReLU(inplace=True), 22 | nn.Dropout(0.1), 23 | nn.Conv2d(inter_channels, channels, 1) 24 | ) 25 | 26 | def forward(self, x): 27 | return self.block(x) 28 | 29 | 30 | # ----------------------------------------------------------------- 31 | # For deeplab 32 | # ----------------------------------------------------------------- 33 | class _ASPP(nn.Module): 34 | def __init__(self, in_channels=2048, out_channels=256): 35 | super().__init__() 36 | output_stride = cfg.MODEL.OUTPUT_STRIDE 37 
| if output_stride == 16: 38 | dilations = [6, 12, 18] 39 | elif output_stride == 8: 40 | dilations = [12, 24, 36] 41 | elif output_stride == 32: 42 | dilations = [6, 12, 18] 43 | else: 44 | raise NotImplementedError 45 | 46 | self.aspp0 = nn.Sequential(OrderedDict([('conv', nn.Conv2d(in_channels, out_channels, 1, bias=False)), 47 | ('bn', nn.BatchNorm2d(out_channels)), 48 | ('relu', nn.ReLU(inplace=True))])) 49 | self.aspp1 = SeparableConv2d(in_channels, out_channels, dilation=dilations[0], relu_first=False) 50 | self.aspp2 = SeparableConv2d(in_channels, out_channels, dilation=dilations[1], relu_first=False) 51 | self.aspp3 = SeparableConv2d(in_channels, out_channels, dilation=dilations[2], relu_first=False) 52 | 53 | self.image_pooling = nn.Sequential(OrderedDict([('gap', nn.AdaptiveAvgPool2d((1, 1))), 54 | ('conv', nn.Conv2d(in_channels, out_channels, 1, bias=False)), 55 | ('bn', nn.BatchNorm2d(out_channels)), 56 | ('relu', nn.ReLU(inplace=True))])) 57 | 58 | self.conv = nn.Conv2d(out_channels*5, out_channels, 1, bias=False) 59 | self.bn = nn.BatchNorm2d(out_channels) 60 | self.relu = nn.ReLU(inplace=True) 61 | self.dropout = nn.Dropout2d(p=0.1) 62 | 63 | def forward(self, x): 64 | pool = self.image_pooling(x) 65 | pool = F.interpolate(pool, size=x.shape[2:], mode='bilinear', align_corners=True) 66 | 67 | x0 = self.aspp0(x) 68 | x1 = self.aspp1(x) 69 | x2 = self.aspp2(x) 70 | x3 = self.aspp3(x) 71 | x = torch.cat((pool, x0, x1, x2, x3), dim=1) 72 | 73 | x = self.conv(x) 74 | x = self.bn(x) 75 | x = self.relu(x) 76 | x = self.dropout(x) 77 | 78 | return x 79 | 80 | # ----------------------------------------------------------------- 81 | # For PSPNet, fast_scnn 82 | # ----------------------------------------------------------------- 83 | class PyramidPooling(nn.Module): 84 | def __init__(self, in_channels, sizes=(1, 2, 3, 6), norm_layer=nn.BatchNorm2d, **kwargs): 85 | super(PyramidPooling, self).__init__() 86 | out_channels = int(in_channels / 4) 87 | self.avgpools = nn.ModuleList() 88 | self.convs = nn.ModuleList() 89 | for size in sizes: 90 | self.avgpools.append(nn.AdaptiveAvgPool2d(size)) 91 | self.convs.append(_ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer)) 92 | 93 | def forward(self, x): 94 | size = x.size()[2:] 95 | feats = [x] 96 | for (avgpool, conv) in zip(self.avgpools, self.convs): 97 | feats.append(F.interpolate(conv(avgpool(x)), size, mode='bilinear', align_corners=True)) 98 | return torch.cat(feats, dim=1) 99 | 100 | 101 | class PAM_Module(nn.Module): 102 | """ Position attention module""" 103 | def __init__(self, in_dim): 104 | super(PAM_Module, self).__init__() 105 | self.chanel_in = in_dim 106 | 107 | self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 108 | self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1) 109 | self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1) 110 | self.gamma = nn.Parameter(torch.zeros(1)) 111 | self.softmax = nn.Softmax(dim=-1) 112 | 113 | def forward(self, x): 114 | """ 115 | inputs : 116 | x : input feature maps( B X C X H X W) 117 | returns : 118 | out : attention value + input feature 119 | attention: B X (HxW) X (HxW) 120 | """ 121 | m_batchsize, C, height, width = x.size() 122 | proj_query = self.query_conv(x).view(m_batchsize, -1, width*height).permute(0, 2, 1) 123 | proj_key = self.key_conv(x).view(m_batchsize, -1, width*height) 124 | energy = torch.bmm(proj_query, proj_key) 125 | attention = self.softmax(energy) 126 | 
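        # attention has shape (B, H*W, H*W); row j holds the weights position j
        # puts on every spatial position i, so the bmm with its transpose below
        # aggregates values as out[b, :, j] = sum_i attention[b, j, i] * value[b, :, i]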
proj_value = self.value_conv(x).view(m_batchsize, -1, width*height) 127 | 128 | out = torch.bmm(proj_value, attention.permute(0, 2, 1)) 129 | out = out.view(m_batchsize, C, height, width) 130 | 131 | out = self.gamma*out + x 132 | return out 133 | 134 | 135 | class CAM_Module(nn.Module): 136 | """ Channel attention module""" 137 | def __init__(self, in_dim): 138 | super(CAM_Module, self).__init__() 139 | self.chanel_in = in_dim 140 | self.gamma = nn.Parameter(torch.zeros(1)) 141 | self.softmax = nn.Softmax(dim=-1) 142 | 143 | def forward(self,x): 144 | """ 145 | inputs : 146 | x : input feature maps( B X C X H X W) 147 | returns : 148 | out : attention value + input feature 149 | attention: B X C X C 150 | """ 151 | m_batchsize, C, height, width = x.size() 152 | proj_query = x.view(m_batchsize, C, -1) 153 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 154 | energy = torch.bmm(proj_query, proj_key) 155 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 156 | attention = self.softmax(energy_new) 157 | proj_value = x.view(m_batchsize, C, -1) 158 | 159 | out = torch.bmm(attention, proj_value) 160 | out = out.view(m_batchsize, C, height, width) 161 | 162 | out = self.gamma*out + x 163 | return out 164 | 165 | 166 | class EESP(nn.Module): 167 | 168 | def __init__(self, in_channels, out_channels, stride=1, k=4, r_lim=7, down_method='esp', norm_layer=nn.BatchNorm2d): 169 | super(EESP, self).__init__() 170 | self.stride = stride 171 | n = int(out_channels / k) 172 | n1 = out_channels - (k - 1) * n 173 | assert down_method in ['avg', 'esp'], 'One of these is suppported (avg or esp)' 174 | assert n == n1, "n(={}) and n1(={}) should be equal for Depth-wise Convolution ".format(n, n1) 175 | self.proj_1x1 = _ConvBNPReLU(in_channels, n, 1, stride=1, groups=k, norm_layer=norm_layer) 176 | 177 | map_receptive_ksize = {3: 1, 5: 2, 7: 3, 9: 4, 11: 5, 13: 6, 15: 7, 17: 8} 178 | self.k_sizes = list() 179 | for i in range(k): 180 | ksize = int(3 + 2 * i) 181 | ksize = ksize if ksize <= r_lim else 3 182 | self.k_sizes.append(ksize) 183 | self.k_sizes.sort() 184 | self.spp_dw = nn.ModuleList() 185 | for i in range(k): 186 | dilation = map_receptive_ksize[self.k_sizes[i]] 187 | self.spp_dw.append(nn.Conv2d(n, n, 3, stride, dilation, dilation=dilation, groups=n, bias=False)) 188 | self.conv_1x1_exp = _ConvBN(out_channels, out_channels, 1, 1, groups=k, norm_layer=norm_layer) 189 | self.br_after_cat = _BNPReLU(out_channels, norm_layer) 190 | self.module_act = nn.PReLU(out_channels) 191 | self.downAvg = True if down_method == 'avg' else False 192 | 193 | def forward(self, x): 194 | output1 = self.proj_1x1(x) 195 | output = [self.spp_dw[0](output1)] 196 | for k in range(1, len(self.spp_dw)): 197 | out_k = self.spp_dw[k](output1) 198 | out_k = out_k + output[k - 1] 199 | output.append(out_k) 200 | expanded = self.conv_1x1_exp(self.br_after_cat(torch.cat(output, 1))) 201 | del output 202 | if self.stride == 2 and self.downAvg: 203 | return expanded 204 | 205 | if expanded.size() == x.size(): 206 | expanded = expanded + x 207 | 208 | return self.module_act(expanded) -------------------------------------------------------------------------------- /segmentron/modules/sync_bn/syncbn.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | 
## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | """Synchronized Cross-GPU Batch Normalization Module""" 12 | import warnings 13 | import torch 14 | 15 | from torch.nn.modules.batchnorm import _BatchNorm 16 | from queue import Queue 17 | from .functions import * 18 | 19 | __all__ = ['SyncBatchNorm', 'BatchNorm1d', 'BatchNorm2d', 'BatchNorm3d'] 20 | 21 | 22 | # Adopt from https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/nn/syncbn.py 23 | class SyncBatchNorm(_BatchNorm): 24 | """Cross-GPU Synchronized Batch normalization (SyncBN) 25 | 26 | Parameters: 27 | num_features: num_features from an expected input of 28 | size batch_size x num_features x height x width 29 | eps: a value added to the denominator for numerical stability. 30 | Default: 1e-5 31 | momentum: the value used for the running_mean and running_var 32 | computation. Default: 0.1 33 | sync: a boolean value that when set to ``True``, synchronize across 34 | different gpus. Default: ``True`` 35 | activation : str 36 | Name of the activation functions, one of: `leaky_relu` or `none`. 37 | slope : float 38 | Negative slope for the `leaky_relu` activation. 39 | 40 | Shape: 41 | - Input: :math:`(N, C, H, W)` 42 | - Output: :math:`(N, C, H, W)` (same shape as input) 43 | Reference: 44 | .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* 45 | .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* 46 | Examples: 47 | >>> m = SyncBatchNorm(100) 48 | >>> net = torch.nn.DataParallel(m) 49 | >>> output = net(input) 50 | """ 51 | 52 | def __init__(self, num_features, eps=1e-5, momentum=0.1, sync=True, activation='none', slope=0.01, inplace=True): 53 | super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=True) 54 | self.activation = activation 55 | self.inplace = False if activation == 'none' else inplace 56 | self.slope = slope 57 | self.devices = list(range(torch.cuda.device_count())) 58 | self.sync = sync if len(self.devices) > 1 else False 59 | # Initialize queues 60 | self.worker_ids = self.devices[1:] 61 | self.master_queue = Queue(len(self.worker_ids)) 62 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 63 | 64 | def forward(self, x): 65 | # resize the input to (B, C, -1) 66 | input_shape = x.size() 67 | x = x.view(input_shape[0], self.num_features, -1) 68 | if x.get_device() == self.devices[0]: 69 | # Master mode 70 | extra = { 71 | "is_master": True, 72 | "master_queue": self.master_queue, 73 | "worker_queues": self.worker_queues, 74 | "worker_ids": self.worker_ids 75 | } 76 | else: 77 | # Worker mode 78 | extra = { 79 | "is_master": False, 80 | "master_queue": self.master_queue, 81 | "worker_queue": self.worker_queues[self.worker_ids.index(x.get_device())] 82 | } 83 | if self.inplace: 84 | return inp_syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 85 | extra, self.sync, self.training, self.momentum, self.eps, 86 | self.activation, self.slope).view(input_shape) 87 | else: 88 | return syncbatchnorm(x, self.weight, self.bias, self.running_mean, self.running_var, 89 | extra, self.sync, self.training, self.momentum, self.eps, 90 | self.activation, 
self.slope).view(input_shape)
91 |
92 |     def extra_repr(self):
93 |         if self.activation == 'none':
94 |             return 'sync={}'.format(self.sync)
95 |         else:
96 |             return 'sync={}, act={}, slope={}, inplace={}'.format(
97 |                 self.sync, self.activation, self.slope, self.inplace)
98 |
99 |
100 | class BatchNorm1d(SyncBatchNorm):
101 |     """BatchNorm1d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
102 |
103 |     def __init__(self, *args, **kwargs):
104 |         warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
105 |                       .format('BatchNorm1d', SyncBatchNorm.__name__), DeprecationWarning)
106 |         super(BatchNorm1d, self).__init__(*args, **kwargs)
107 |
108 |
109 | class BatchNorm2d(SyncBatchNorm):
110 |     """BatchNorm2d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
111 |
112 |     def __init__(self, *args, **kwargs):
113 |         warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
114 |                       .format('BatchNorm2d', SyncBatchNorm.__name__), DeprecationWarning)
115 |         super(BatchNorm2d, self).__init__(*args, **kwargs)
116 |
117 |
118 | class BatchNorm3d(SyncBatchNorm):
119 |     """BatchNorm3d is deprecated in favor of :class:`core.nn.sync_bn.SyncBatchNorm`."""
120 |
121 |     def __init__(self, *args, **kwargs):
122 |         warnings.warn("core.nn.sync_bn.{} is now deprecated in favor of core.nn.sync_bn.{}."
123 |                       .format('BatchNorm3d', SyncBatchNorm.__name__), DeprecationWarning)
124 |         super(BatchNorm3d, self).__init__(*args, **kwargs)
125 |
--------------------------------------------------------------------------------
/segmentron/solver/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xieenze/Segment_Transparent_Objects/06a9c806d32fec178e37700095f0c5443a4f109a/segmentron/solver/__init__.py
--------------------------------------------------------------------------------
/segmentron/solver/lovasz_losses.py:
--------------------------------------------------------------------------------
1 | """
2 | Lovasz-Softmax and Jaccard hinge loss in PyTorch
3 | Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License)
4 | https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
5 | """
6 |
7 | from __future__ import print_function, division
8 |
9 | import torch
10 | from torch.autograd import Variable
11 | import torch.nn.functional as F
12 | import numpy as np
13 |
14 | try:
15 |     from itertools import ifilterfalse
16 | except ImportError:  # py3k
17 |     from itertools import filterfalse as ifilterfalse
18 |
19 |
20 | def lovasz_grad(gt_sorted):
21 |     """
22 |     Computes gradient of the Lovasz extension w.r.t sorted errors
23 |     See Alg. 1 in paper
24 |     """
25 |     p = len(gt_sorted)
26 |     gts = gt_sorted.sum()
27 |     intersection = gts - gt_sorted.float().cumsum(0)
28 |     union = gts + (1 - gt_sorted).float().cumsum(0)
29 |     jaccard = 1. - intersection / union
30 |     if p > 1:  # cover 1-pixel case
31 |         jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
32 |     return jaccard
33 |
34 |
35 | def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True):
36 |     """
37 |     IoU for foreground class
38 |     binary: 1 foreground, 0 background
39 |     """
40 |     if not per_image:
41 |         preds, labels = (preds,), (labels,)
42 |     ious = []
43 |     for pred, label in zip(preds, labels):
44 |         intersection = ((label == 1) & (pred == 1)).sum()
45 |         union = ((label == 1) | ((pred == 1) & (label != ignore))).sum()
46 |         if not union:
47 |             iou = EMPTY
48 |         else:
49 |             iou = float(intersection) / float(union)
50 |         ious.append(iou)
51 |     iou = mean(ious)  # mean across images if per_image
52 |     return 100 * iou
53 |
54 |
55 | def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False):
56 |     """
57 |     Array of IoU for each (non ignored) class
58 |     """
59 |     if not per_image:
60 |         preds, labels = (preds,), (labels,)
61 |     ious = []
62 |     for pred, label in zip(preds, labels):
63 |         iou = []
64 |         for i in range(C):
65 |             if i != ignore:  # The ignored label is sometimes among predicted classes (ENet - CityScapes)
66 |                 intersection = ((label == i) & (pred == i)).sum()
67 |                 union = ((label == i) | ((pred == i) & (label != ignore))).sum()
68 |                 if not union:
69 |                     iou.append(EMPTY)
70 |                 else:
71 |                     iou.append(float(intersection) / float(union))
72 |         ious.append(iou)
73 |     ious = [mean(iou) for iou in zip(*ious)]  # mean across images if per_image
74 |     return 100 * np.array(ious)
75 |
76 |
77 | # --------------------------- BINARY LOSSES ---------------------------
78 |
79 | def lovasz_hinge(logits, labels, per_image=True, ignore=None):
80 |     """
81 |     Binary Lovasz hinge loss
82 |       logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty)
83 |       labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
84 |       per_image: compute the loss per image instead of per batch
85 |       ignore: void class id
86 |     """
87 |     if per_image:
88 |         loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore))
89 |                     for log, lab in zip(logits, labels))
90 |     else:
91 |         loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))
92 |     return loss
93 |
94 |
95 | def lovasz_hinge_flat(logits, labels):
96 |     """
97 |     Binary Lovasz hinge loss
98 |       logits: [P] Variable, logits at each prediction (between -\infty and +\infty)
99 |       labels: [P] Tensor, binary ground truth labels (0 or 1)
100 |       ignore: label to ignore
101 |     """
102 |     if len(labels) == 0:
103 |         # only void pixels, the gradients should be 0
104 |         return logits.sum() * 0.
105 |     signs = 2. * labels.float() - 1.
106 |     errors = (1. 
- logits * Variable(signs)) 107 | errors_sorted, perm = torch.sort(errors, dim=0, descending=True) 108 | perm = perm.data 109 | gt_sorted = labels[perm] 110 | grad = lovasz_grad(gt_sorted) 111 | loss = torch.dot(F.relu(errors_sorted), Variable(grad)) 112 | return loss 113 | 114 | 115 | def flatten_binary_scores(scores, labels, ignore=None): 116 | """ 117 | Flattens predictions in the batch (binary case) 118 | Remove labels equal to 'ignore' 119 | """ 120 | scores = scores.view(-1) 121 | labels = labels.view(-1) 122 | if ignore is None: 123 | return scores, labels 124 | valid = (labels != ignore) 125 | vscores = scores[valid] 126 | vlabels = labels[valid] 127 | return vscores, vlabels 128 | 129 | 130 | class StableBCELoss(torch.nn.modules.Module): 131 | def __init__(self): 132 | super(StableBCELoss, self).__init__() 133 | 134 | def forward(self, input, target): 135 | neg_abs = - input.abs() 136 | loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() 137 | return loss.mean() 138 | 139 | 140 | def binary_xloss(logits, labels, ignore=None): 141 | """ 142 | Binary Cross entropy loss 143 | logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) 144 | labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) 145 | ignore: void class id 146 | """ 147 | logits, labels = flatten_binary_scores(logits, labels, ignore) 148 | loss = StableBCELoss()(logits, Variable(labels.float())) 149 | return loss 150 | 151 | 152 | # --------------------------- MULTICLASS LOSSES --------------------------- 153 | 154 | 155 | def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): 156 | """ 157 | Multi-class Lovasz-Softmax loss 158 | probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). 159 | Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. 160 | labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) 161 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 162 | per_image: compute the loss per image instead of per batch 163 | ignore: void class labels 164 | """ 165 | if per_image: 166 | loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) 167 | for prob, lab in zip(probas, labels)) 168 | else: 169 | loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) 170 | return loss 171 | 172 | 173 | def lovasz_softmax_flat(probas, labels, classes='present'): 174 | """ 175 | Multi-class Lovasz-Softmax loss 176 | probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) 177 | labels: [P] Tensor, ground truth labels (between 0 and C - 1) 178 | classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. 179 | """ 180 | if probas.numel() == 0: 181 | # only void pixels, the gradients should be 0 182 | return probas * 0. 
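    # For each selected class c: take the absolute error between the binary
    # foreground mask and the predicted probability, sort the errors in
    # decreasing order, and dot them with the Lovasz gradient of the ground
    # truth sorted the same way; the mean over classes is the final loss.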
183 |     C = probas.size(1)
184 |     losses = []
185 |     class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
186 |     for c in class_to_sum:
187 |         fg = (labels == c).float()  # foreground for class c
188 |         if classes == 'present' and fg.sum() == 0:
189 |             continue
190 |         if C == 1:
191 |             if len(classes) > 1:
192 |                 raise ValueError('Sigmoid output possible only with 1 class')
193 |             class_pred = probas[:, 0]
194 |         else:
195 |             class_pred = probas[:, c]
196 |         errors = (Variable(fg) - class_pred).abs()
197 |         errors_sorted, perm = torch.sort(errors, 0, descending=True)
198 |         perm = perm.data
199 |         fg_sorted = fg[perm]
200 |         losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted))))
201 |     return mean(losses)
202 |
203 |
204 | def flatten_probas(probas, labels, ignore=None):
205 |     """
206 |     Flattens predictions in the batch
207 |     """
208 |     if probas.dim() == 3:
209 |         # assumes output of a sigmoid layer
210 |         B, H, W = probas.size()
211 |         probas = probas.view(B, 1, H, W)
212 |     B, C, H, W = probas.size()
213 |     probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C = P, C
214 |     labels = labels.view(-1)
215 |     if ignore is None:
216 |         return probas, labels
217 |     valid = (labels != ignore)
218 |     vprobas = probas[valid.nonzero().squeeze()]
219 |     vlabels = labels[valid]
220 |     return vprobas, vlabels
221 |
222 |
223 | def xloss(logits, labels, ignore=None):
224 |     """
225 |     Cross entropy loss; honors `ignore` when given, else uses the void label 255
226 |     """
227 |     return F.cross_entropy(logits, Variable(labels), ignore_index=255 if ignore is None else ignore)
228 |
229 |
230 | # --------------------------- HELPER FUNCTIONS ---------------------------
231 | def isnan(x):
232 |     return x != x
233 |
234 |
235 | def mean(l, ignore_nan=False, empty=0):
236 |     """
237 |     nanmean compatible with generators.
238 |     """
239 |     l = iter(l)
240 |     if ignore_nan:
241 |         l = ifilterfalse(isnan, l)
242 |     try:
243 |         n = 1
244 |         acc = next(l)
245 |     except StopIteration:
246 |         if empty == 'raise':
247 |             raise ValueError('Empty mean')
248 |         return empty
249 |     for n, v in enumerate(l, 2):
250 |         acc += v
251 |     if n == 1:
252 |         return acc
253 |     return acc / n
254 |
--------------------------------------------------------------------------------
/segmentron/solver/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # this code heavily references detectron2
2 | from __future__ import division
3 | import math
4 | import torch
5 |
6 | from typing import List
7 | from bisect import bisect_right
8 | from segmentron.config import cfg
9 |
10 | __all__ = ['get_scheduler']
11 |
12 |
13 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
14 |     def __init__(self, optimizer, target_lr=0, max_iters=0, power=0.9, warmup_factor=1.0 / 3,
15 |                  warmup_iters=500, warmup_method='linear', last_epoch=-1):
16 |         if warmup_method not in ("constant", "linear"):
17 |             raise ValueError(
18 |                 "Only 'constant' or 'linear' warmup_method accepted, "
19 |                 "got {}".format(warmup_method))
20 |
21 |         self.target_lr = target_lr
22 |         self.max_iters = max_iters
23 |         self.power = power
24 |         self.warmup_factor = warmup_factor
25 |         self.warmup_iters = warmup_iters
26 |         self.warmup_method = warmup_method
27 |
28 |         super(WarmupPolyLR, self).__init__(optimizer, last_epoch)
29 |
30 |     def get_lr(self):
31 |         N = self.max_iters - self.warmup_iters
32 |         T = self.last_epoch - self.warmup_iters
33 |         if self.last_epoch < self.warmup_iters:
34 |             if self.warmup_method == 'constant':
35 |                 warmup_factor = self.warmup_factor
36 |             elif self.warmup_method == 'linear':
37 |                 alpha = 
float(self.last_epoch) / self.warmup_iters 38 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 39 | else: 40 | raise ValueError("Unknown warmup type.") 41 | return [self.target_lr + (base_lr - self.target_lr) * warmup_factor for base_lr in self.base_lrs] 42 | factor = pow(1 - T / N, self.power) 43 | return [self.target_lr + (base_lr - self.target_lr) * factor for base_lr in self.base_lrs] 44 | 45 | 46 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 47 | def __init__( 48 | self, 49 | optimizer: torch.optim.Optimizer, 50 | milestones: List[int], 51 | gamma: float = 0.1, 52 | warmup_factor: float = 0.001, 53 | warmup_iters: int = 1000, 54 | warmup_method: str = "linear", 55 | last_epoch: int = -1, 56 | ): 57 | if not list(milestones) == sorted(milestones): 58 | raise ValueError( 59 | "Milestones should be a list of" " increasing integers. Got {}", milestones 60 | ) 61 | self.milestones = milestones 62 | self.gamma = gamma 63 | self.warmup_factor = warmup_factor 64 | self.warmup_iters = warmup_iters 65 | self.warmup_method = warmup_method 66 | super().__init__(optimizer, last_epoch) 67 | 68 | def get_lr(self) -> List[float]: 69 | warmup_factor = _get_warmup_factor_at_iter( 70 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 71 | ) 72 | return [ 73 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 74 | for base_lr in self.base_lrs 75 | ] 76 | 77 | def _compute_values(self) -> List[float]: 78 | # The new interface 79 | return self.get_lr() 80 | 81 | 82 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 83 | def __init__( 84 | self, 85 | optimizer: torch.optim.Optimizer, 86 | max_iters: int, 87 | warmup_factor: float = 0.001, 88 | warmup_iters: int = 1000, 89 | warmup_method: str = "linear", 90 | last_epoch: int = -1, 91 | ): 92 | self.max_iters = max_iters 93 | self.warmup_factor = warmup_factor 94 | self.warmup_iters = warmup_iters 95 | self.warmup_method = warmup_method 96 | super().__init__(optimizer, last_epoch) 97 | 98 | def get_lr(self) -> List[float]: 99 | warmup_factor = _get_warmup_factor_at_iter( 100 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 101 | ) 102 | # Different definitions of half-cosine with warmup are possible. For 103 | # simplicity we multiply the standard half-cosine schedule by the warmup 104 | # factor. An alternative is to start the period of the cosine at warmup_iters 105 | # instead of at 0. In the case that warmup_iters << max_iters the two are 106 | # very close to each other. 107 | return [ 108 | base_lr 109 | * warmup_factor 110 | * 0.5 111 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) 112 | for base_lr in self.base_lrs 113 | ] 114 | 115 | def _compute_values(self) -> List[float]: 116 | # The new interface 117 | return self.get_lr() 118 | 119 | 120 | def _get_warmup_factor_at_iter( 121 | method: str, iter: int, warmup_iters: int, warmup_factor: float 122 | ) -> float: 123 | """ 124 | Return the learning rate warmup factor at a specific iteration. 125 | See https://arxiv.org/abs/1706.02677 for more details. 126 | 127 | Args: 128 | method (str): warmup method; either "constant" or "linear". 129 | iter (int): iteration at which to calculate the warmup factor. 130 | warmup_iters (int): the number of warmup iterations. 131 | warmup_factor (float): the base warmup factor (the meaning changes according 132 | to the method used). 133 | 134 | Returns: 135 | float: the effective warmup factor at the given iteration. 
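    Example (an illustrative sanity check, not part of the original code):
        with method="linear", warmup_iters=1000 and warmup_factor=0.001,
        iter=0 yields 0.001, iter=500 yields roughly 0.5005, and any
        iter >= 1000 yields 1.0.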
136 |     """
137 |     if iter >= warmup_iters:
138 |         return 1.0
139 |
140 |     if method == "constant":
141 |         return warmup_factor
142 |     elif method == "linear":
143 |         alpha = iter / warmup_iters
144 |         return warmup_factor * (1 - alpha) + alpha
145 |     else:
146 |         raise ValueError("Unknown warmup method: {}".format(method))
147 |
148 |
149 | def get_scheduler(optimizer, max_iters, iters_per_epoch):
150 |     mode = cfg.SOLVER.LR_SCHEDULER.lower()
151 |     warm_up_iters = iters_per_epoch * cfg.SOLVER.WARMUP.EPOCHS
152 |     if mode == 'poly':
153 |         return WarmupPolyLR(optimizer, max_iters=max_iters, power=cfg.SOLVER.POLY.POWER,
154 |                             warmup_factor=cfg.SOLVER.WARMUP.FACTOR, warmup_iters=warm_up_iters,
155 |                             warmup_method=cfg.SOLVER.WARMUP.METHOD)
156 |     elif mode == 'cosine':
157 |         return WarmupCosineLR(optimizer, max_iters=max_iters, warmup_factor=cfg.SOLVER.WARMUP.FACTOR,
158 |                               warmup_iters=warm_up_iters, warmup_method=cfg.SOLVER.WARMUP.METHOD)
159 |     elif mode == 'step':
160 |         milestones = [x * iters_per_epoch for x in cfg.SOLVER.STEP.DECAY_EPOCH]
161 |         return WarmupMultiStepLR(optimizer, milestones=milestones, gamma=cfg.SOLVER.STEP.GAMMA,
162 |                                  warmup_factor=cfg.SOLVER.WARMUP.FACTOR, warmup_iters=warm_up_iters,
163 |                                  warmup_method=cfg.SOLVER.WARMUP.METHOD)
164 |     else:
165 |         raise ValueError("Unsupported lr scheduler method: {}".format(mode))
166 |
167 |
--------------------------------------------------------------------------------
/segmentron/solver/optimizer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.nn as nn
3 |
4 | from torch import optim
5 | from segmentron.config import cfg
6 |
7 |
8 | def _set_batch_norm_attr(named_modules, attr, value):
9 |     for m in named_modules:
10 |         if isinstance(m[1], (nn.BatchNorm2d, nn.SyncBatchNorm)):
11 |             setattr(m[1], attr, value)
12 |
13 |
14 | def _get_parameters(model):
15 |     params_list = list()
16 |     if hasattr(model, 'encoder') and model.encoder is not None and hasattr(model, 'decoder'):
17 |         params_list.append({'params': model.encoder.parameters(), 'lr': cfg.SOLVER.LR})
18 |         if cfg.MODEL.BN_EPS_FOR_ENCODER:
19 |             logging.info('Set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
20 |             _set_batch_norm_attr(model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
21 |
22 |         for module in model.decoder:
23 |             params_list.append({'params': getattr(model, module).parameters(),
24 |                                 'lr': cfg.SOLVER.LR * cfg.SOLVER.DECODER_LR_FACTOR})
25 |
26 |         if cfg.MODEL.BN_EPS_FOR_DECODER:
27 |             logging.info('Set bn custom eps for bn in decoder: {}'.format(cfg.MODEL.BN_EPS_FOR_DECODER))
28 |             for module in model.decoder:
29 |                 _set_batch_norm_attr(getattr(model, module).named_modules(), 'eps',
30 |                                      cfg.MODEL.BN_EPS_FOR_DECODER)
31 |     else:
32 |         logging.info('Model does not have encoder/decoder attributes; the params list is taken from model.parameters(), '
33 |                      'and the arguments BN_EPS_FOR_ENCODER, BN_EPS_FOR_DECODER, DECODER_LR_FACTOR are not used!')
34 |         params_list = model.parameters()
35 |
36 |     if cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE in ['BN']:
37 |         logging.info('Set bn custom momentum: {}'.format(cfg.MODEL.BN_MOMENTUM))
38 |         _set_batch_norm_attr(model.named_modules(), 'momentum', cfg.MODEL.BN_MOMENTUM)
39 |     elif cfg.MODEL.BN_MOMENTUM and cfg.MODEL.BN_TYPE not in ['BN']:
40 |         logging.info('Batch norm type is {}, custom bn momentum is not effective!'.format(cfg.MODEL.BN_TYPE))
41 |
42 |     return params_list
43 |
44 |
45 | def get_optimizer(model):
46 |     parameters = _get_parameters(model)
47 |     opt_lower = 
cfg.SOLVER.OPTIMIZER.lower() 48 | 49 | if opt_lower == 'sgd': 50 | optimizer = optim.SGD( 51 | parameters, lr=cfg.SOLVER.LR, momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 52 | elif opt_lower == 'adam': 53 | optimizer = optim.Adam( 54 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 55 | elif opt_lower == 'adadelta': 56 | optimizer = optim.Adadelta( 57 | parameters, lr=cfg.SOLVER.LR, eps=cfg.SOLVER.EPSILON, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 58 | elif opt_lower == 'rmsprop': 59 | optimizer = optim.RMSprop( 60 | parameters, lr=cfg.SOLVER.LR, alpha=0.9, eps=cfg.SOLVER.EPSILON, 61 | momentum=cfg.SOLVER.MOMENTUM, weight_decay=cfg.SOLVER.WEIGHT_DECAY) 62 | else: 63 | raise ValueError("Expected optimizer method in [sgd, adam, adadelta, rmsprop], but received " 64 | "{}".format(opt_lower)) 65 | 66 | return optimizer 67 | -------------------------------------------------------------------------------- /segmentron/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | from __future__ import absolute_import 3 | 4 | from .download import download, check_sha1 5 | from .filesystem import makedirs 6 | -------------------------------------------------------------------------------- /segmentron/utils/default_setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import json 4 | import torch 5 | 6 | from .distributed import get_rank, synchronize 7 | from .logger import setup_logger 8 | from .env import seed_all_rng 9 | from ..config import cfg 10 | 11 | def default_setup(args): 12 | num_gpus = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 13 | args.num_gpus = num_gpus 14 | args.distributed = num_gpus > 1 15 | 16 | if not args.no_cuda and torch.cuda.is_available(): 17 | # cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = True 19 | args.device = "cuda" 20 | else: 21 | args.distributed = False 22 | args.device = "cpu" 23 | if args.distributed: 24 | torch.cuda.set_device(args.local_rank) 25 | torch.distributed.init_process_group(backend="nccl", init_method="env://") 26 | synchronize() 27 | 28 | # TODO 29 | # if args.save_pred: 30 | # outdir = '../runs/pred_pic/{}_{}_{}'.format(args.model, args.backbone, args.dataset) 31 | # if not os.path.exists(outdir): 32 | # os.makedirs(outdir) 33 | 34 | save_dir = cfg.TRAIN.MODEL_SAVE_DIR if cfg.PHASE == 'train' else None 35 | setup_logger("Segmentron", save_dir, get_rank(), filename='{}_{}_{}_{}_log.txt'.format( 36 | cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, cfg.DATASET.NAME, cfg.TIME_STAMP)) 37 | 38 | logging.info("Using {} GPUs".format(num_gpus)) 39 | logging.info(args) 40 | logging.info(json.dumps(cfg, indent=8)) 41 | 42 | seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + get_rank()) -------------------------------------------------------------------------------- /segmentron/utils/distributed.py: -------------------------------------------------------------------------------- 1 | """ 2 | code is heavily based on https://github.com/facebookresearch/maskrcnn-benchmark 3 | """ 4 | import math 5 | import pickle 6 | import torch 7 | import torch.utils.data as data 8 | import torch.distributed as dist 9 | 10 | from torch.utils.data.sampler import Sampler, BatchSampler 11 | 12 | __all__ = ['get_world_size', 'get_rank', 'synchronize', 'is_main_process', 13 | 'all_gather', 'make_data_sampler', 'make_batch_data_sampler', 14 | 
'reduce_dict', 'reduce_loss_dict'] 15 | 16 | 17 | def get_world_size(): 18 | if not dist.is_available(): 19 | return 1 20 | if not dist.is_initialized(): 21 | return 1 22 | return dist.get_world_size() 23 | 24 | 25 | def get_rank(): 26 | if not dist.is_available(): 27 | return 0 28 | if not dist.is_initialized(): 29 | return 0 30 | return dist.get_rank() 31 | 32 | 33 | def is_main_process(): 34 | return get_rank() == 0 35 | 36 | 37 | def synchronize(): 38 | """ 39 | Helper function to synchronize (barrier) among all processes when 40 | using distributed training 41 | """ 42 | if not dist.is_available(): 43 | return 44 | if not dist.is_initialized(): 45 | return 46 | world_size = dist.get_world_size() 47 | if world_size == 1: 48 | return 49 | dist.barrier() 50 | 51 | 52 | def all_gather(data): 53 | """ 54 | Run all_gather on arbitrary picklable data (not necessarily tensors) 55 | Args: 56 | data: any picklable object 57 | Returns: 58 | list[data]: list of data gathered from each rank 59 | """ 60 | world_size = get_world_size() 61 | if world_size == 1: 62 | return [data] 63 | 64 | # serialized to a Tensor 65 | buffer = pickle.dumps(data) 66 | storage = torch.ByteStorage.from_buffer(buffer) 67 | tensor = torch.ByteTensor(storage).to("cuda") 68 | 69 | # obtain Tensor size of each rank 70 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 71 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 72 | dist.all_gather(size_list, local_size) 73 | size_list = [int(size.item()) for size in size_list] 74 | max_size = max(size_list) 75 | 76 | # receiving Tensor from all ranks 77 | # we pad the tensor because torch all_gather does not support 78 | # gathering tensors of different shapes 79 | tensor_list = [] 80 | for _ in size_list: 81 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 82 | if local_size != max_size: 83 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 84 | tensor = torch.cat((tensor, padding), dim=0) 85 | dist.all_gather(tensor_list, tensor) 86 | 87 | data_list = [] 88 | for size, tensor in zip(size_list, tensor_list): 89 | buffer = tensor.cpu().numpy().tobytes()[:size] 90 | data_list.append(pickle.loads(buffer)) 91 | 92 | return data_list 93 | 94 | 95 | def reduce_dict(input_dict, average=True): 96 | """ 97 | Args: 98 | input_dict (dict): all the values will be reduced 99 | average (bool): whether to do average or sum 100 | Reduce the values in the dictionary from all processes so that process with rank 101 | 0 has the averaged results. Returns a dict with the same fields as 102 | input_dict, after reduction. 103 | """ 104 | world_size = get_world_size() 105 | if world_size < 2: 106 | return input_dict 107 | with torch.no_grad(): 108 | names = [] 109 | values = [] 110 | # sort the keys so that they are consistent across processes 111 | for k in sorted(input_dict.keys()): 112 | names.append(k) 113 | values.append(input_dict[k]) 114 | values = torch.stack(values, dim=0) 115 | dist.reduce(values, dst=0) 116 | if dist.get_rank() == 0 and average: 117 | # only main process gets accumulated, so only divide by 118 | # world_size in this case 119 | values /= world_size 120 | reduced_dict = {k: v for k, v in zip(names, values)} 121 | return reduced_dict 122 | 123 | 124 | def reduce_loss_dict(loss_dict): 125 | """ 126 | Reduce the loss dictionary from all processes so that process with rank 127 | 0 has the averaged results. Returns a dict with the same fields as 128 | loss_dict, after reduction. 
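    Example (a minimal sketch; assumes the default process group is
    initialized and that `seg_loss`/`aux_loss` are scalar loss tensors):
        loss_dict = {'loss_seg': seg_loss, 'loss_aux': aux_loss}
        losses = sum(loss for loss in loss_dict.values())  # for backward()
        loss_dict_reduced = reduce_loss_dict(loss_dict)    # averaged on rank 0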
129 | """ 130 | world_size = get_world_size() 131 | if world_size < 2: 132 | return loss_dict 133 | with torch.no_grad(): 134 | loss_names = [] 135 | all_losses = [] 136 | for k in sorted(loss_dict.keys()): 137 | loss_names.append(k) 138 | all_losses.append(loss_dict[k]) 139 | all_losses = torch.stack(all_losses, dim=0) 140 | dist.reduce(all_losses, dst=0) 141 | if dist.get_rank() == 0: 142 | # only main process gets accumulated, so only divide by 143 | # world_size in this case 144 | all_losses /= world_size 145 | reduced_losses = {k: v for k, v in zip(loss_names, all_losses)} 146 | return reduced_losses 147 | 148 | 149 | def make_data_sampler(dataset, shuffle, distributed): 150 | if distributed: 151 | return DistributedSampler(dataset, shuffle=shuffle) 152 | if shuffle: 153 | sampler = data.sampler.RandomSampler(dataset) 154 | else: 155 | sampler = data.sampler.SequentialSampler(dataset) 156 | return sampler 157 | 158 | 159 | def make_batch_data_sampler(sampler, images_per_batch, num_iters=None, start_iter=0, drop_last=True): 160 | batch_sampler = data.sampler.BatchSampler(sampler, images_per_batch, drop_last=drop_last) 161 | if num_iters is not None: 162 | batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iters, start_iter) 163 | return batch_sampler 164 | 165 | 166 | class DistributedSampler(Sampler): 167 | """Sampler that restricts data loading to a subset of the dataset. 168 | It is especially useful in conjunction with 169 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 170 | process can pass a DistributedSampler instance as a DataLoader sampler, 171 | and load a subset of the original dataset that is exclusive to it. 172 | .. note:: 173 | Dataset is assumed to be of constant size. 174 | Arguments: 175 | dataset: Dataset used for sampling. 176 | num_replicas (optional): Number of processes participating in 177 | distributed training. 178 | rank (optional): Rank of the current process within num_replicas. 
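    Example (a minimal sketch, assuming the process group is initialized
    and `dataset` / `num_epochs` are defined)::

        sampler = DistributedSampler(dataset)
        loader = data.DataLoader(dataset, sampler=sampler)
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)  # deterministic reshuffle per epoch
            for batch in loader:
                ...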
179 | """ 180 | 181 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 182 | if num_replicas is None: 183 | if not dist.is_available(): 184 | raise RuntimeError("Requires distributed package to be available") 185 | num_replicas = dist.get_world_size() 186 | if rank is None: 187 | if not dist.is_available(): 188 | raise RuntimeError("Requires distributed package to be available") 189 | rank = dist.get_rank() 190 | self.dataset = dataset 191 | self.num_replicas = num_replicas 192 | self.rank = rank 193 | self.epoch = 0 194 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 195 | self.total_size = self.num_samples * self.num_replicas 196 | self.shuffle = shuffle 197 | 198 | def __iter__(self): 199 | if self.shuffle: 200 | # deterministically shuffle based on epoch 201 | g = torch.Generator() 202 | g.manual_seed(self.epoch) 203 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 204 | else: 205 | indices = torch.arange(len(self.dataset)).tolist() 206 | 207 | # add extra samples to make it evenly divisible 208 | indices += indices[: (self.total_size - len(indices))] 209 | assert len(indices) == self.total_size 210 | 211 | # subsample 212 | offset = self.num_samples * self.rank 213 | indices = indices[offset: offset + self.num_samples] 214 | assert len(indices) == self.num_samples 215 | 216 | return iter(indices) 217 | 218 | def __len__(self): 219 | return self.num_samples 220 | 221 | def set_epoch(self, epoch): 222 | self.epoch = epoch 223 | 224 | 225 | class IterationBasedBatchSampler(BatchSampler): 226 | """ 227 | Wraps a BatchSampler, resampling from it until 228 | a specified number of iterations have been sampled 229 | """ 230 | 231 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 232 | self.batch_sampler = batch_sampler 233 | self.num_iterations = num_iterations 234 | self.start_iter = start_iter 235 | 236 | def __iter__(self): 237 | iteration = self.start_iter 238 | while iteration <= self.num_iterations: 239 | # if the underlying sampler has a set_epoch method, like 240 | # DistributedSampler, used for making each process see 241 | # a different split of the dataset, then set it 242 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 243 | self.batch_sampler.sampler.set_epoch(iteration) 244 | for batch in self.batch_sampler: 245 | iteration += 1 246 | if iteration > self.num_iterations: 247 | break 248 | yield batch 249 | 250 | def __len__(self): 251 | return self.num_iterations 252 | -------------------------------------------------------------------------------- /segmentron/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | import requests 4 | from tqdm import tqdm 5 | 6 | def check_sha1(filename, sha1_hash): 7 | """Check whether the sha1 hash of the file content matches the expected hash. 8 | Parameters 9 | ---------- 10 | filename : str 11 | Path to the file. 12 | sha1_hash : str 13 | Expected sha1 hash in hexadecimal digits. 14 | Returns 15 | ------- 16 | bool 17 | Whether the file content matches the expected hash. 
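    Examples
    --------
    Illustrative only; the file name and hash prefix are placeholders::

        if not check_sha1('model.pth', '2d864e86'):  # a hash prefix is enough
            print('hash mismatch, consider re-downloading the file')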
18 | """ 19 | sha1 = hashlib.sha1() 20 | with open(filename, 'rb') as f: 21 | while True: 22 | data = f.read(1048576) 23 | if not data: 24 | break 25 | sha1.update(data) 26 | 27 | sha1_file = sha1.hexdigest() 28 | l = min(len(sha1_file), len(sha1_hash)) 29 | return sha1.hexdigest()[0:l] == sha1_hash[0:l] 30 | 31 | def download(url, path=None, overwrite=False, sha1_hash=None): 32 | """Download an given URL 33 | Parameters 34 | ---------- 35 | url : str 36 | URL to download 37 | path : str, optional 38 | Destination path to store downloaded file. By default stores to the 39 | current directory with same name as in url. 40 | overwrite : bool, optional 41 | Whether to overwrite destination file if already exists. 42 | sha1_hash : str, optional 43 | Expected sha1 hash in hexadecimal digits. Will ignore existing file when hash is specified 44 | but doesn't match. 45 | Returns 46 | ------- 47 | str 48 | The file path of the downloaded file. 49 | """ 50 | if path is None: 51 | fname = url.split('/')[-1] 52 | else: 53 | path = os.path.expanduser(path) 54 | if os.path.isdir(path): 55 | fname = os.path.join(path, url.split('/')[-1]) 56 | else: 57 | fname = path 58 | 59 | if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)): 60 | dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname))) 61 | if not os.path.exists(dirname): 62 | os.makedirs(dirname) 63 | 64 | print('Downloading %s from %s...'%(fname, url)) 65 | r = requests.get(url, stream=True) 66 | if r.status_code != 200: 67 | raise RuntimeError("Failed downloading url %s"%url) 68 | total_length = r.headers.get('content-length') 69 | with open(fname, 'wb') as f: 70 | if total_length is None: # no content length header 71 | for chunk in r.iter_content(chunk_size=1024): 72 | if chunk: # filter out keep-alive new chunks 73 | f.write(chunk) 74 | else: 75 | total_length = int(total_length) 76 | for chunk in tqdm(r.iter_content(chunk_size=1024), 77 | total=int(total_length / 1024. + 0.5), 78 | unit='KB', unit_scale=False, dynamic_ncols=True): 79 | f.write(chunk) 80 | 81 | if sha1_hash and not check_sha1(fname, sha1_hash): 82 | raise UserWarning('File {} is downloaded but the content hash does not match. ' \ 83 | 'The repo may be outdated or download may be incomplete. ' \ 84 | 'If the "repo_url" is overridden, consider switching to ' \ 85 | 'the default repo.'.format(fname)) 86 | 87 | return fname -------------------------------------------------------------------------------- /segmentron/utils/env.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import numpy as np 5 | import os 6 | import random 7 | from datetime import datetime 8 | import torch 9 | 10 | __all__ = ["seed_all_rng"] 11 | 12 | 13 | def seed_all_rng(seed=None): 14 | """ 15 | Set the random seed for the RNG in torch, numpy and python. 16 | 17 | Args: 18 | seed (int): if None, will use a strong random seed. 
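    Example (illustrative):
        seed_all_rng(42)    # deterministic, reproducible runs
        seed_all_rng(None)  # derives a seed from pid, time and os.urandom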
19 | """ 20 | if seed is None: 21 | seed = ( 22 | os.getpid() 23 | + int(datetime.now().strftime("%S%f")) 24 | + int.from_bytes(os.urandom(2), "big") 25 | ) 26 | logger = logging.getLogger(__name__) 27 | logger.info("Using a generated random seed {}".format(seed)) 28 | np.random.seed(seed) 29 | torch.set_rng_state(torch.manual_seed(seed).get_state()) 30 | random.seed(seed) 31 | -------------------------------------------------------------------------------- /segmentron/utils/filesystem.py: -------------------------------------------------------------------------------- 1 | """Filesystem utility functions.""" 2 | from __future__ import absolute_import 3 | import os 4 | import errno 5 | import torch 6 | import logging 7 | 8 | from ..config import cfg 9 | 10 | def save_checkpoint(model, epoch, optimizer=None, lr_scheduler=None, is_best=False): 11 | """Save Checkpoint""" 12 | directory = os.path.expanduser(cfg.TRAIN.MODEL_SAVE_DIR) 13 | # directory = os.path.join(directory, '{}_{}_{}_{}'.format(cfg.MODEL.MODEL_NAME, cfg.MODEL.BACKBONE, 14 | # cfg.DATASET.NAME, cfg.TIME_STAMP)) 15 | if not os.path.exists(directory): 16 | os.makedirs(directory) 17 | filename = '{}.pth'.format(str(epoch)) 18 | filename = os.path.join(directory, filename) 19 | model_state_dict = model.module.state_dict() if hasattr(model, 'module') else model.state_dict() 20 | if is_best: 21 | best_filename = 'best_model.pth' 22 | best_filename = os.path.join(directory, best_filename) 23 | torch.save(model_state_dict, best_filename) 24 | else: 25 | if not os.path.exists(filename): 26 | torch.save(model_state_dict, filename) 27 | logging.info('Epoch {} model saved in: {}'.format(epoch, filename)) 28 | 29 | # remove last epoch 30 | pre_filename = '{}.pth'.format(str(epoch - 1)) 31 | pre_filename = os.path.join(directory, pre_filename) 32 | try: 33 | if os.path.exists(pre_filename): 34 | os.remove(pre_filename) 35 | except OSError as e: 36 | logging.info(e) 37 | 38 | def makedirs(path): 39 | """Create directory recursively if not exists. 40 | Similar to `makedir -p`, you can skip checking existence before this function. 
41 | Parameters 42 | ---------- 43 | path : str 44 | Path of the desired dir 45 | """ 46 | try: 47 | os.makedirs(path) 48 | except OSError as exc: 49 | if exc.errno != errno.EEXIST: 50 | raise 51 | 52 | -------------------------------------------------------------------------------- /segmentron/utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import sys 4 | 5 | __all__ = ['setup_logger'] 6 | 7 | 8 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt", mode='w'): 9 | if distributed_rank > 0: 10 | return 11 | 12 | logging.root.name = name 13 | logging.root.setLevel(logging.INFO) 14 | # don't log results for the non-master process 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.DEBUG) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logging.root.addHandler(ch) 20 | 21 | if save_dir: 22 | if not os.path.exists(save_dir): 23 | os.makedirs(save_dir) 24 | fh = logging.FileHandler(os.path.join(save_dir, filename), mode=mode) # 'a+' for add, 'w' for overwrite 25 | fh.setLevel(logging.DEBUG) 26 | fh.setFormatter(formatter) 27 | logging.root.addHandler(fh) 28 | -------------------------------------------------------------------------------- /segmentron/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description='Segmentron') 5 | parser.add_argument('--config-file', metavar="FILE", 6 | help='config file path') 7 | # cuda setting 8 | parser.add_argument('--no-cuda', action='store_true', default=False, 9 | help='disables CUDA training') 10 | parser.add_argument('--local_rank', type=int, default=0) 11 | # checkpoint and log 12 | parser.add_argument('--resume', type=str, default=None, 13 | help='put the path to resuming file if needed') 14 | parser.add_argument('--log-iter', type=int, default=10, 15 | help='print log every log-iter') 16 | # for evaluation 17 | parser.add_argument('--val-epoch', type=int, default=1, 18 | help='run validation every val-epoch') 19 | parser.add_argument('--skip-val', action='store_true', default=False, 20 | help='skip validation during training') 21 | # for visual 22 | parser.add_argument('--input-img', type=str, default='tools/demo_vis.png', 23 | help='path to the input image or a directory of images') 24 | # config options 25 | parser.add_argument('opts', help='See config for all options', 26 | default=None, nargs=argparse.REMAINDER) 27 | args = parser.parse_args() 28 | 29 | return args -------------------------------------------------------------------------------- /segmentron/utils/parallel.py: -------------------------------------------------------------------------------- 1 | """Utils for Semantic Segmentation""" 2 | import threading 3 | import torch 4 | import torch.cuda.comm as comm 5 | from torch.nn.parallel.data_parallel import DataParallel 6 | from torch.nn.parallel._functions import Broadcast 7 | from torch.autograd import Function 8 | 9 | __all__ = ['DataParallelModel', 'DataParallelCriterion'] 10 | 11 | 12 | class Reduce(Function): 13 | @staticmethod 14 | def forward(ctx, *inputs): 15 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 16 | inputs = sorted(inputs, key=lambda i: i.get_device()) 17 | return comm.reduce_add(inputs) 18 | 19 | @staticmethod 20 | def backward(ctx, gradOutputs): 21 | return 
Broadcast.apply(ctx.target_gpus, gradOutputs) 22 | 23 | 24 | class DataParallelModel(DataParallel): 25 | """Data parallelism 26 | 27 | Hide the difference of single/multiple GPUs to the user. 28 | In the forward pass, the module is replicated on each device, 29 | and each replica handles a portion of the input. During the backwards 30 | pass, gradients from each replica are summed into the original module. 31 | 32 | The batch size should be larger than the number of GPUs used. 33 | 34 | Parameters 35 | ---------- 36 | module : object 37 | Network to be parallelized. 38 | sync : bool 39 | enable synchronization (default: False). 40 | Inputs: 41 | - **inputs**: list of input 42 | Outputs: 43 | - **outputs**: list of output 44 | Example:: 45 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 46 | >>> output = net(input_var) # input_var can be on any device, including CPU 47 | """ 48 | 49 | def gather(self, outputs, output_device): 50 | return outputs 51 | 52 | def replicate(self, module, device_ids): 53 | modules = super(DataParallelModel, self).replicate(module, device_ids) 54 | return modules 55 | 56 | 57 | # Reference: https://github.com/zhanghang1989/PyTorch-Encoding/blob/master/encoding/parallel.py 58 | class DataParallelCriterion(DataParallel): 59 | """ 60 | Calculate loss in multiple-GPUs, which balance the memory usage for 61 | Semantic Segmentation. 62 | 63 | The targets are splitted across the specified devices by chunking in 64 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`. 65 | 66 | Example:: 67 | >>> net = DataParallelModel(model, device_ids=[0, 1, 2]) 68 | >>> criterion = DataParallelCriterion(criterion, device_ids=[0, 1, 2]) 69 | >>> y = net(x) 70 | >>> loss = criterion(y, target) 71 | """ 72 | 73 | def forward(self, inputs, *targets, **kwargs): 74 | # the inputs should be the outputs of DataParallelModel 75 | if not self.device_ids: 76 | return self.module(inputs, *targets, **kwargs) 77 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids) 78 | if len(self.device_ids) == 1: 79 | return self.module(inputs, *targets[0], **kwargs[0]) 80 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) 81 | outputs = criterion_parallel_apply(replicas, inputs, targets, kwargs) 82 | return Reduce.apply(*outputs) / len(outputs) 83 | 84 | 85 | def get_a_var(obj): 86 | if isinstance(obj, torch.Tensor): 87 | return obj 88 | 89 | if isinstance(obj, list) or isinstance(obj, tuple): 90 | for result in map(get_a_var, obj): 91 | if isinstance(result, torch.Tensor): 92 | return result 93 | 94 | if isinstance(obj, dict): 95 | for result in map(get_a_var, obj.items()): 96 | if isinstance(result, torch.Tensor): 97 | return result 98 | return None 99 | 100 | 101 | def criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None): 102 | r"""Applies each `module` in :attr:`modules` in parallel on arguments 103 | contained in :attr:`inputs` (positional), attr:'targets' (positional) and :attr:`kwargs_tup` (keyword) 104 | on each of :attr:`devices`. 105 | 106 | Args: 107 | modules (Module): modules to be parallelized 108 | inputs (tensor): inputs to the modules 109 | targets (tensor): targets to the modules 110 | devices (list of int or torch.device): CUDA devices 111 | :attr:`modules`, :attr:`inputs`, :attr:'targets' :attr:`kwargs_tup` (if given), and 112 | :attr:`devices` (if given) should all have same length. 
Moreover, each 113 | element of :attr:`inputs` can either be a single object as the only argument 114 | to a module, or a collection of positional arguments. 115 | """ 116 | assert len(modules) == len(inputs) 117 | assert len(targets) == len(inputs) 118 | if kwargs_tup is not None: 119 | assert len(modules) == len(kwargs_tup) 120 | else: 121 | kwargs_tup = ({},) * len(modules) 122 | if devices is not None: 123 | assert len(modules) == len(devices) 124 | else: 125 | devices = [None] * len(modules) 126 | lock = threading.Lock() 127 | results = {} 128 | grad_enabled = torch.is_grad_enabled() 129 | 130 | def _worker(i, module, input, target, kwargs, device=None): 131 | torch.set_grad_enabled(grad_enabled) 132 | if device is None: 133 | device = get_a_var(input).get_device() 134 | try: 135 | with torch.cuda.device(device): 136 | output = module(*(list(input) + target), **kwargs) 137 | with lock: 138 | results[i] = output 139 | except Exception as e: 140 | with lock: 141 | results[i] = e 142 | 143 | if len(modules) > 1: 144 | threads = [threading.Thread(target=_worker, 145 | args=(i, module, input, target, kwargs, device)) 146 | for i, (module, input, target, kwargs, device) in 147 | enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] 148 | 149 | for thread in threads: 150 | thread.start() 151 | for thread in threads: 152 | thread.join() 153 | else: 154 | _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0]) 155 | 156 | outputs = [] 157 | for i in range(len(inputs)): 158 | output = results[i] 159 | if isinstance(output, Exception): 160 | raise output 161 | outputs.append(output) 162 | return outputs 163 | -------------------------------------------------------------------------------- /segmentron/utils/registry.py: -------------------------------------------------------------------------------- 1 | # this code heavily based on detectron2 2 | 3 | import logging 4 | import torch 5 | 6 | from ..config import cfg 7 | 8 | class Registry(object): 9 | """ 10 | The registry that provides name -> object mapping, to support third-party users' custom modules. 11 | 12 | To create a registry (inside segmentron): 13 | 14 | .. code-block:: python 15 | 16 | BACKBONE_REGISTRY = Registry('BACKBONE') 17 | 18 | To register an object: 19 | 20 | .. code-block:: python 21 | 22 | @BACKBONE_REGISTRY.register() 23 | class MyBackbone(): 24 | ... 25 | 26 | Or: 27 | 28 | .. code-block:: python 29 | 30 | BACKBONE_REGISTRY.register(MyBackbone) 31 | """ 32 | 33 | def __init__(self, name): 34 | """ 35 | Args: 36 | name (str): the name of this registry 37 | """ 38 | self._name = name 39 | 40 | self._obj_map = {} 41 | 42 | def _do_register(self, name, obj): 43 | assert ( 44 | name not in self._obj_map 45 | ), "An object named '{}' was already registered in '{}' registry!".format(name, self._name) 46 | self._obj_map[name] = obj 47 | 48 | def register(self, obj=None, name=None): 49 | """ 50 | Register the given object under the the name `obj.__name__`. 51 | Can be used as either a decorator or not. See docstring of this class for usage. 
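        It can also be called with an explicit name (an illustrative variant;
        `MyBackbone` is a placeholder):

        .. code-block:: python

            BACKBONE_REGISTRY.register(MyBackbone, name='my_backbone')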
52 | """ 53 | if obj is None: 54 | # used as a decorator 55 | def deco(func_or_class, name=name): 56 | if name is None: 57 | name = func_or_class.__name__ 58 | self._do_register(name, func_or_class) 59 | return func_or_class 60 | 61 | return deco 62 | 63 | # used as a function call 64 | if name is None: 65 | name = obj.__name__ 66 | self._do_register(name, obj) 67 | 68 | 69 | 70 | def get(self, name): 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError("No object named '{}' found in '{}' registry!".format(name, self._name)) 74 | 75 | return ret 76 | 77 | def get_list(self): 78 | return list(self._obj_map.keys()) 79 | -------------------------------------------------------------------------------- /segmentron/utils/score.py: -------------------------------------------------------------------------------- 1 | """Evaluation Metrics for Semantic Segmentation""" 2 | import torch 3 | import numpy as np 4 | from torch import distributed as dist 5 | import copy 6 | from IPython import embed 7 | 8 | __all__ = ['SegmentationMetric', 'batch_pix_accuracy', 'batch_intersection_union', 9 | 'pixelAccuracy', 'intersectionAndUnion', 'hist_info', 'compute_score'] 10 | 11 | 12 | class SegmentationMetric(object): 13 | """Computes pixAcc and mIoU metric scores 14 | """ 15 | 16 | def __init__(self, nclass, distributed, num_gpu): 17 | super(SegmentationMetric, self).__init__() 18 | self.nclass = nclass 19 | self.distributed = distributed 20 | self.num_gpu = num_gpu 21 | self.reset() 22 | 23 | def update(self, preds, labels): 24 | """Updates the internal evaluation result. 25 | 26 | Parameters 27 | ---------- 28 | labels : 'NumpyArray' or list of `NumpyArray` 29 | The labels of the data. 30 | preds : 'NumpyArray' or list of `NumpyArray` 31 | Predicted values. 32 | """ 33 | 34 | def reduce_tensor(tensor): 35 | if isinstance(tensor, torch.Tensor): 36 | rt = tensor.clone() 37 | else: 38 | rt = copy.deepcopy(tensor) 39 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 40 | return rt 41 | 42 | def evaluate_worker(self, pred, label): 43 | 44 | correct, labeled = batch_pix_accuracy(pred, label) 45 | inter, union = batch_intersection_union(pred, label, self.nclass) 46 | mae = batch_mae(pred, label) 47 | bers, bers_count = batch_ber(pred, label) 48 | 49 | if self.distributed: 50 | correct = reduce_tensor(correct) 51 | labeled = reduce_tensor(labeled) 52 | inter = reduce_tensor(inter.cuda()) 53 | union = reduce_tensor(union.cuda()) 54 | mae = reduce_tensor(mae.cuda()) 55 | bers = reduce_tensor(bers.cuda()) 56 | bers_count = reduce_tensor((bers_count.cuda())) 57 | 58 | torch.cuda.synchronize() 59 | self.total_correct += correct.item() 60 | self.total_label += labeled.item() 61 | 62 | if self.total_inter.device != inter.device: 63 | self.total_inter = self.total_inter.to(inter.device) 64 | self.total_union = self.total_union.to(union.device) 65 | self.total_inter += inter 66 | self.total_union += union 67 | 68 | self.total_mae.append(mae) 69 | 70 | if self.total_bers.device != bers.device: 71 | self.total_bers = self.total_bers.to(bers.device) 72 | self.total_bers_count = self.total_bers_count.to(bers_count.device) 73 | self.total_bers += bers 74 | self.total_bers_count += bers_count 75 | 76 | if isinstance(preds, torch.Tensor): 77 | evaluate_worker(self, preds, labels) 78 | elif isinstance(preds, (list, tuple)): 79 | for (pred, label) in zip(preds, labels): 80 | evaluate_worker(self, pred, label) 81 | 82 | def get(self, return_category_iou=False): 83 | """Gets the current evaluation result. 
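        Note that the background class is excluded: pixAcc only counts labeled
        foreground pixels, and mIoU / mBer average over IoU[1:] / Ber[1:].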
84 |
85 |         Returns
86 |         -------
87 |         metrics : tuple of float
88 |             pixAcc, mIoU, mae and mBer
89 |         """
90 |         pixAcc = 1.0 * self.total_correct / (2.220446049250313e-16 + self.total_label)  # remove np.spacing(1)
91 |         IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union)
92 |         # mIoU = IoU.mean().item()
93 |         mIoU = IoU[1:].mean().item()
94 |         mae = 1.0 * torch.Tensor(self.total_mae).mean().item() / self.num_gpu
95 |
96 |         Ber = 1.0 * self.total_bers / self.total_bers_count
97 |         mBer = Ber[1:].mean().item()
98 |
99 |         if return_category_iou:
100 |             return pixAcc, mIoU, IoU.cpu().numpy(), mae, mBer, Ber.cpu().numpy()
101 |         return pixAcc, mIoU, mae, mBer
102 |
103 |     def reset(self):
104 |         """Resets the internal evaluation result to initial state."""
105 |         self.total_inter = torch.zeros(self.nclass)
106 |         self.total_union = torch.zeros(self.nclass)
107 |         self.total_correct = 0
108 |         self.total_label = 0
109 |         self.total_mae = []
110 |
111 |         self.total_bers = torch.zeros(3)
112 |         self.total_bers_count = torch.zeros(3)
113 |
114 |
115 | def batch_pix_accuracy(output, target):
116 |     """PixAcc"""
117 |     # inputs are numpy array, output 4D, target 3D
118 |     predict = torch.argmax(output.long(), 1) + 1
119 |     target = target.long() + 1
120 |
121 |     '''do not care background'''
122 |     # pixel_labeled = torch.sum(target > 0)
123 |     # pixel_correct = torch.sum((predict == target) * (target > 0))
124 |
125 |     pixel_labeled = torch.sum(target > 1)
126 |     pixel_correct = torch.sum((predict == target) * (target > 1))
127 |     assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled"
128 |     return pixel_correct, pixel_labeled
129 |
130 | def batch_mae(output, target):
131 |     """Mean Absolute Error"""
132 |     # inputs are numpy array, output 4D, target 3D
133 |     predict = (torch.argmax(output, 1)).float()
134 |     target = target.float()
135 |
136 |     mae = (predict - target).abs().mean()
137 |     return mae
138 |
139 | def batch_ber(output, target, class_ids=[1, 2]):
140 |     predict = torch.argmax(output.long(), 1)
141 |     target = target.long()
142 |     bers = torch.zeros(3)
143 |     bers_count = torch.zeros(3)
144 |     bers_count[0] = 1
145 |
146 |     for class_id in class_ids:
147 |         valid = target == class_id
148 |         if valid.sum() == 0:
149 |             continue
150 |         N_p = torch.sum(target == class_id)
151 |         N_n = torch.sum(target != class_id)
152 |         TP = torch.sum((predict == target) * valid)
153 |         TN = torch.sum((predict == target) * (1 - valid))
154 |
155 |         N_p = N_p.float(); N_n = N_n.float(); TP = TP.float(); TN = TN.float()
156 |         ber = 1 - 1/2 * (TP / N_p + TN / N_n)
157 |         ber = ber * 100
158 |
159 |         bers[class_id] = ber
160 |         bers_count[class_id] = 1.0
161 |
162 |     return bers, bers_count
163 |
164 | def batch_intersection_union(output, target, nclass):
165 |     """mIoU"""
166 |     # inputs are numpy array, output 4D, target 3D
167 |     mini = 1
168 |     maxi = nclass
169 |     nbins = nclass
170 |     predict = torch.argmax(output, 1) + 1
171 |     target = target.float() + 1
172 |
173 |     predict = predict.float() * (target > 0).float()
174 |     intersection = predict * (predict == target).float()
175 |     # areas of intersection and union
176 |     # element 0 in `intersection` marks mismatched pixels (the main difference from np.bincount), so the histogram range must start at 1.
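    # torch.histc counts elements falling into `nbins` equal bins over
    # [mini, maxi]; the zeros left in `intersection` by mismatched pixels fall
    # outside this range, so only true per-class overlaps are counted below.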
177 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 178 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 179 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 180 | area_union = area_pred + area_lab - area_inter 181 | assert torch.sum(area_inter > area_union).item() == 0, "Intersection area should be smaller than Union area" 182 | return area_inter.float(), area_union.float() 183 | 184 | 185 | def pixelAccuracy(imPred, imLab): 186 | """ 187 | This function takes the prediction and label of a single image, returns pixel-wise accuracy 188 | To compute over many images do: 189 | for i = range(Nimages): 190 | (pixel_accuracy[i], pixel_correct[i], pixel_labeled[i]) = \ 191 | pixelAccuracy(imPred[i], imLab[i]) 192 | mean_pixel_accuracy = 1.0 * np.sum(pixel_correct) / (np.spacing(1) + np.sum(pixel_labeled)) 193 | """ 194 | # Remove classes from unlabeled pixels in gt image. 195 | # We should not penalize detections in unlabeled portions of the image. 196 | # pixel_labeled = np.sum(imLab >= 0) 197 | # pixel_correct = np.sum((imPred == imLab) * (imLab >= 0)) 198 | 199 | '''do not care background''' 200 | pixel_labeled = np.sum(imLab > 0) 201 | pixel_correct = np.sum((imPred == imLab) * (imLab > 0)) 202 | pixel_accuracy = 1.0 * pixel_correct / pixel_labeled 203 | return (pixel_accuracy, pixel_correct, pixel_labeled) 204 | 205 | 206 | def intersectionAndUnion(imPred, imLab, numClass): 207 | """ 208 | This function takes the prediction and label of a single image, 209 | returns intersection and union areas for each class 210 | To compute over many images do: 211 | for i in range(Nimages): 212 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i]) 213 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1) 214 | """ 215 | # Remove classes from unlabeled pixels in gt image. 216 | # We should not penalize detections in unlabeled portions of the image. 
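    # Zeroing predictions on unlabeled pixels (imLab < 0) pushes them outside
    # the histogram range (1, numClass) used below, so they are not counted
    # towards any class.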
217 | imPred = imPred * (imLab >= 0) 218 | 219 | # Compute area intersection: 220 | intersection = imPred * (imPred == imLab) 221 | (area_intersection, _) = np.histogram(intersection, bins=numClass, range=(1, numClass)) 222 | 223 | # Compute area union: 224 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 225 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 226 | area_union = area_pred + area_lab - area_intersection 227 | return (area_intersection, area_union) 228 | 229 | 230 | def hist_info(pred, label, num_cls): 231 | assert pred.shape == label.shape 232 | k = (label >= 0) & (label < num_cls) 233 | labeled = np.sum(k) 234 | correct = np.sum((pred[k] == label[k])) 235 | 236 | return np.bincount(num_cls * label[k].astype(int) + pred[k], minlength=num_cls ** 2).reshape(num_cls, 237 | num_cls), labeled, correct 238 | 239 | 240 | def compute_score(hist, correct, labeled): 241 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 242 | mean_IU = np.nanmean(iu) 243 | mean_IU_no_back = np.nanmean(iu[1:]) 244 | freq = hist.sum(1) / hist.sum() 245 | # freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 246 | mean_pixel_acc = correct / labeled 247 | 248 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 249 | -------------------------------------------------------------------------------- /segmentron/utils/visualize.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import numpy as np 4 | import torch 5 | 6 | from PIL import Image 7 | #from torchsummary import summary 8 | from thop import profile 9 | 10 | __all__ = ['get_color_pallete', 'print_iou', 'set_img_color', 11 | 'show_prediction', 'show_colorful_images', 'save_colorful_images'] 12 | 13 | 14 | def print_iou(iu, mean_pixel_acc, class_names=None, show_no_back=False): 15 | n = iu.size 16 | lines = [] 17 | for i in range(n): 18 | if class_names is None: 19 | cls = 'Class %d:' % (i + 1) 20 | else: 21 | cls = '%d %s' % (i + 1, class_names[i]) 22 | # lines.append('%-8s: %.3f%%' % (cls, iu[i] * 100)) 23 | mean_IU = np.nanmean(iu) 24 | mean_IU_no_back = np.nanmean(iu[1:]) 25 | if show_no_back: 26 | lines.append('mean_IU: %.3f%% || mean_IU_no_back: %.3f%% || mean_pixel_acc: %.3f%%' % ( 27 | mean_IU * 100, mean_IU_no_back * 100, mean_pixel_acc * 100)) 28 | else: 29 | lines.append('mean_IU: %.3f%% || mean_pixel_acc: %.3f%%' % (mean_IU * 100, mean_pixel_acc * 100)) 30 | lines.append('=================================================') 31 | line = "\n".join(lines) 32 | 33 | print(line) 34 | 35 | 36 | def show_flops_params(model, device, input_shape=[1, 3, 512, 512]): 37 | #summary(model, tuple(input_shape[1:]), device=device) 38 | input = torch.randn(*input_shape).to(torch.device(device)) 39 | flops, params = profile(model, inputs=(input,), verbose=False) 40 | 41 | logging.info('{} flops: {:.3f}G input shape is {}, params: {:.3f}M'.format( 42 | model.__class__.__name__, flops / 1000000000, input_shape[1:], params / 1000000)) 43 | 44 | 45 | def set_img_color(img, label, colors, background=0, show255=False): 46 | for i in range(len(colors)): 47 | if i != background: 48 | img[np.where(label == i)] = colors[i] 49 | if show255: 50 | img[np.where(label == 255)] = 255 51 | 52 | return img 53 | 54 | 55 | def show_prediction(img, pred, colors, background=0): 56 | im = np.array(img, np.uint8) 57 | set_img_color(im, pred, colors, background) 58 | out = np.array(im) 59 | 60 | return out 61 | 62 | 63 | def 
64 |     im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
65 |     im.show()
66 | 
67 | 
68 | def save_colorful_images(prediction, filename, output_dir, palettes):
69 |     '''
70 |     :param prediction: [B, H, W, C]
71 |     '''
72 |     im = Image.fromarray(palettes[prediction.astype('uint8').squeeze()])
73 |     fn = os.path.join(output_dir, filename)
74 |     out_dir = os.path.split(fn)[0]
75 |     if not os.path.exists(out_dir):
76 |         os.makedirs(out_dir, exist_ok=True)  # makedirs also creates missing parent directories
77 |     im.save(fn)
78 | 
79 | 
80 | def get_color_pallete(npimg, dataset='cityscape'):
81 |     """Colorize a segmentation label map for visualization.
82 | 
83 |     Parameters
84 |     ----------
85 |     npimg : numpy.ndarray
86 |         Single channel image with shape `H, W, 1`.
87 |     dataset : str, default: 'cityscape'
88 |         The dataset the model was pretrained on ('pascal_voc', 'ade20k' or 'cityscape').
89 |     Returns
90 |     -------
91 |     out_img : PIL.Image
92 |         Image with the color palette applied
93 |     """
94 |     # map the ignored boundary label (-1) back to 255
95 |     if dataset in ('pascal_voc', 'pascal_aug'):
96 |         npimg[npimg == -1] = 255
97 |     # put colormap
98 |     if dataset == 'ade20k':
99 |         npimg = npimg + 1
100 |         out_img = Image.fromarray(npimg.astype('uint8'))
101 |         out_img.putpalette(adepallete)
102 |         return out_img
103 |     elif dataset == 'cityscape':
104 |         out_img = Image.fromarray(npimg.astype('uint8'))
105 |         out_img.putpalette(cityscapepallete)
106 |         return out_img
107 |     out_img = Image.fromarray(npimg.astype('uint8'))
108 |     out_img.putpalette(vocpallete)
109 |     return out_img
110 | 
111 | 
112 | def _getvocpallete(num_cls):
113 |     n = num_cls
114 |     pallete = [0] * (n * 3)
115 |     for j in range(0, n):
116 |         lab = j
117 |         pallete[j * 3 + 0] = 0
118 |         pallete[j * 3 + 1] = 0
119 |         pallete[j * 3 + 2] = 0
120 |         i = 0
121 |         while lab > 0:
122 |             pallete[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i))
123 |             pallete[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i))
124 |             pallete[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i))
125 |             i = i + 1
126 |             lab >>= 3
127 |     return pallete
128 | 
129 | 
130 | vocpallete = _getvocpallete(256)
131 | 
132 | adepallete = [
133 |     0, 0, 0, 120, 120, 120, 180, 120, 120, 6, 230, 230, 80, 50, 50, 4, 200, 3, 120, 120, 80, 140, 140, 140, 204,
134 |     5, 255, 230, 230, 230, 4, 250, 7, 224, 5, 255, 235, 255, 7, 150, 5, 61, 120, 120, 70, 8, 255, 51, 255, 6, 82,
135 |     143, 255, 140, 204, 255, 4, 255, 51, 7, 204, 70, 3, 0, 102, 200, 61, 230, 250, 255, 6, 51, 11, 102, 255, 255,
136 |     7, 71, 255, 9, 224, 9, 7, 230, 220, 220, 220, 255, 9, 92, 112, 9, 255, 8, 255, 214, 7, 255, 224, 255, 184, 6,
137 |     10, 255, 71, 255, 41, 10, 7, 255, 255, 224, 255, 8, 102, 8, 255, 255, 61, 6, 255, 194, 7, 255, 122, 8, 0, 255,
138 |     20, 255, 8, 41, 255, 5, 153, 6, 51, 255, 235, 12, 255, 160, 150, 20, 0, 163, 255, 140, 140, 140, 250, 10, 15,
139 |     20, 255, 0, 31, 255, 0, 255, 31, 0, 255, 224, 0, 153, 255, 0, 0, 0, 255, 255, 71, 0, 0, 235, 255, 0, 173, 255,
140 |     31, 0, 255, 11, 200, 200, 255, 82, 0, 0, 255, 245, 0, 61, 255, 0, 255, 112, 0, 255, 133, 255, 0, 0, 255, 163,
141 |     0, 255, 102, 0, 194, 255, 0, 0, 143, 255, 51, 255, 0, 0, 82, 255, 0, 255, 41, 0, 255, 173, 10, 0, 255, 173, 255,
142 |     0, 0, 255, 153, 255, 92, 0, 255, 0, 255, 255, 0, 245, 255, 0, 102, 255, 173, 0, 255, 0, 20, 255, 184, 184, 0,
143 |     31, 255, 0, 255, 61, 0, 71, 255, 255, 0, 204, 0, 255, 194, 0, 255, 82, 0, 10, 255, 0, 112, 255, 51, 0, 255, 0,
144 |     194, 255, 0, 122, 255, 0, 255, 163, 255, 153, 0, 0, 255, 10, 255, 112, 0, 143, 255, 0, 82, 0, 255, 163, 255,
145 |     0, 255, 235, 0, 8, 184, 170, 133, 0, 255, 0, 255, 92, 184, 0, 255, 255, 0, 31, 0, 184, 255, 0, 214, 255, 255,
146 |     0, 112, 92, 255, 0, 0, 224, 255, 112, 224, 255, 70, 184, 160, 163, 0, 255, 153, 0, 255, 71, 255, 0, 255, 0,
147 |     163, 255, 204, 0, 255, 0, 143, 0, 255, 235, 133, 255, 0, 255, 0, 235, 245, 0, 255, 255, 0, 122, 255, 245, 0,
148 |     10, 190, 212, 214, 255, 0, 0, 204, 255, 20, 0, 255, 255, 255, 0, 0, 153, 255, 0, 41, 255, 0, 255, 204, 41, 0,
149 |     255, 41, 255, 0, 173, 0, 255, 0, 245, 255, 71, 0, 255, 122, 0, 255, 0, 255, 184, 0, 92, 255, 184, 255, 0, 0,
150 |     133, 255, 255, 214, 0, 25, 194, 194, 102, 255, 0, 92, 0, 255]
151 | 
152 | cityscapepallete = [
153 |     128, 64, 128,
154 |     244, 35, 232,
155 |     70, 70, 70,
156 |     102, 102, 156,
157 |     190, 153, 153,
158 |     153, 153, 153,
159 |     250, 170, 30,
160 |     220, 220, 0,
161 |     107, 142, 35,
162 |     152, 251, 152,
163 |     0, 130, 180,
164 |     220, 20, 60,
165 |     255, 0, 0,
166 |     0, 0, 142,
167 |     0, 0, 70,
168 |     0, 60, 100,
169 |     0, 80, 100,
170 |     0, 0, 230,
171 |     119, 11, 32,
172 | ]
173 | 
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import glob
4 | import os
5 | from setuptools import find_packages, setup
6 | import torch
7 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
8 | 
9 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]]
10 | assert torch_ver >= [1, 1], "Requires PyTorch >= 1.1"
11 | 
12 | 
13 | def get_extensions():
14 |     this_dir = os.path.dirname(os.path.abspath(__file__))
15 |     extensions_dir = os.path.join(this_dir, "segmentron", "modules", "csrc")
16 | 
17 |     main_source = os.path.join(extensions_dir, "vision.cpp")
18 |     sources = glob.glob(os.path.join(extensions_dir, "**", "*.cpp"))
19 |     source_cuda = glob.glob(os.path.join(extensions_dir, "**", "*.cu")) + glob.glob(
20 |         os.path.join(extensions_dir, "*.cu")
21 |     )
22 | 
23 |     sources = [main_source] + sources
24 | 
25 |     extension = CppExtension
26 | 
27 |     extra_compile_args = {"cxx": []}
28 |     define_macros = []
29 | 
30 |     if (torch.cuda.is_available() and CUDA_HOME is not None) or os.getenv("FORCE_CUDA", "0") == "1":
31 |         extension = CUDAExtension
32 |         sources += source_cuda
33 |         define_macros += [("WITH_CUDA", None)]
34 |         extra_compile_args["nvcc"] = [
35 |             "-DCUDA_HAS_FP16=1",
36 |             "-D__CUDA_NO_HALF_OPERATORS__",
37 |             "-D__CUDA_NO_HALF_CONVERSIONS__",
38 |             "-D__CUDA_NO_HALF2_OPERATORS__",
39 |         ]
40 | 
41 |         # It's better if pytorch can do this by default ..
42 |         CC = os.environ.get("CC", None)
43 |         if CC is not None:
44 |             extra_compile_args["nvcc"].append("-ccbin={}".format(CC))
45 | 
46 |     sources = [os.path.join(extensions_dir, s) for s in sources]
47 | 
48 |     include_dirs = [extensions_dir]
49 | 
50 |     ext_modules = [
51 |         extension(
52 |             "segmentron._C",
53 |             sources,
54 |             include_dirs=include_dirs,
55 |             define_macros=define_macros,
56 |             extra_compile_args=extra_compile_args,
57 |         )
58 |     ]
59 | 
60 |     return ext_modules
61 | 
62 | 
63 | setup(
64 |     name="segmentron",
65 |     version="0.1",
66 |     author="LikeLy-Journey",
67 |     url="https://github.com/LikeLy-Journey/SegmenTron",
68 |     description="Platform for semantic segmentation based on PyTorch.",
69 |     # packages=find_packages(exclude=("configs", "tests")),
70 |     # python_requires=">=3.6",
71 |     # install_requires=[
72 |     #     "termcolor>=1.1",
73 |     #     "Pillow",
74 |     #     "yacs>=0.1.6",
75 |     #     "tabulate",
76 |     #     "cloudpickle",
77 |     #     "matplotlib",
78 |     #     "tqdm>4.29.0",
79 |     #     "tensorboard",
80 |     # ],
81 |     # extras_require={"all": ["shapely", "psutil"]},
82 |     ext_modules=get_extensions(),
83 |     cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
84 | )
85 | 
--------------------------------------------------------------------------------
/tools/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | PYTHON=${PYTHON:-"python"}
4 | 
5 | CONFIG=$1
6 | GPUS=$2
7 | 
8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
9 |     $(dirname "$0")/test_translab.py --config-file $CONFIG ${@:3}
10 | 
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | PYTHON=${PYTHON:-"python"}
4 | 
5 | CONFIG=$1
6 | GPUS=$2
7 | 
8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \
9 |     $(dirname "$0")/train.py --config-file $CONFIG ${@:3}
10 | 
--------------------------------------------------------------------------------
/tools/test_demo.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import os
4 | import sys
5 | 
6 | cur_path = os.path.abspath(os.path.dirname(__file__))
7 | root_path = os.path.split(cur_path)[0]
8 | sys.path.append(root_path)
9 | 
10 | import logging
11 | import torch
12 | import torch.nn as nn
13 | import torch.utils.data as data
14 | import torch.nn.functional as F
15 | 
16 | from tabulate import tabulate
17 | from torchvision import transforms
18 | from segmentron.data.dataloader import get_segmentation_dataset
19 | from segmentron.models.model_zoo import get_segmentation_model
20 | from segmentron.utils.distributed import synchronize, make_data_sampler, make_batch_data_sampler
21 | from segmentron.config import cfg
22 | from segmentron.utils.options import parse_args
23 | from segmentron.utils.default_setup import default_setup
24 | from IPython import embed
25 | from collections import OrderedDict
26 | from segmentron.utils.filesystem import makedirs
27 | import cv2
28 | import numpy as np
29 | 
30 | class Evaluator(object):
31 |     def __init__(self, args):
32 |         self.args = args
33 |         self.device = torch.device(args.device)
34 | 
35 |         # image transform
36 |         input_transform = transforms.Compose([
37 |             transforms.ToTensor(),
38 |             transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
39 |         ])
40 | 
41 |         # dataset and dataloader
42 |         val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
43 |                                                root=cfg.DEMO_DIR,
44 |                                                split='val',
45 |                                                mode='val',
46 |                                                transform=input_transform,
47 |                                                base_size=cfg.TRAIN.BASE_SIZE)
48 | 
49 |         val_sampler = make_data_sampler(val_dataset, shuffle=False, distributed=args.distributed)
50 |         val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
51 | 
52 |         self.val_loader = data.DataLoader(dataset=val_dataset,
53 |                                           batch_sampler=val_batch_sampler,
54 |                                           num_workers=cfg.DATASET.WORKERS,
55 |                                           pin_memory=True)
56 |         self.classes = val_dataset.classes
57 |         # create network
58 |         self.model = get_segmentation_model().to(self.device)
59 | 
60 |         if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
61 |             logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
62 |             self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
63 | 
64 |         if args.distributed:
65 |             self.model = nn.parallel.DistributedDataParallel(self.model,
66 |                 device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
67 | 
68 |         self.model.to(self.device)
69 |         self.count_easy = 0
70 |         self.count_hard = 0
71 |     def set_batch_norm_attr(self, named_modules, attr, value):
72 |         for m in named_modules:
73 |             if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm):
74 |                 setattr(m[1], attr, value)
75 | 
76 |     def eval(self):
77 |         self.model.eval()
78 |         if self.args.distributed:
79 |             model = self.model.module
80 |         else:
81 |             model = self.model
82 | 
83 |         for i, (image, _, filename) in enumerate(self.val_loader):
84 |             image = image.to(self.device)
85 |             filename = filename[0]
86 |             save_name = os.path.basename(filename).replace('.jpg', '').replace('.png', '')
87 | 
88 |             with torch.no_grad():
89 |                 output, output_boundary = model.evaluate(image)
90 |             ori_img = cv2.imread(filename)
91 |             h, w, _ = ori_img.shape
92 | 
93 |             glass_res = output.argmax(1)[0].data.cpu().numpy().astype('uint8') * 127
94 |             # boundary_res = output_boundary[0,0].data.cpu().numpy().astype('uint8') * 255
95 |             glass_res = cv2.resize(glass_res, (w, h), interpolation=cv2.INTER_NEAREST)
96 |             # boundary_res = cv2.resize(boundary_res, (w, h), interpolation=cv2.INTER_NEAREST)
97 | 
98 |             save_path = os.path.join('/'.join(cfg.DEMO_DIR.split('/')[:-2]), 'result')
99 |             makedirs(save_path)
100 |             cv2.imwrite(os.path.join(save_path, '{}_glass.png'.format(save_name)), glass_res)
101 |             # cv2.imwrite(os.path.join(save_path, '{}_boundary.png'.format(save_name)), boundary_res)
102 |             print('save {}'.format(save_name))
103 | 
104 | 
105 | 
106 | 
107 | 
108 | if __name__ == '__main__':
109 |     args = parse_args()
110 |     cfg.update_from_file(args.config_file)
111 |     cfg.update_from_list(args.opts)
112 |     cfg.PHASE = 'test'
113 |     cfg.ROOT_PATH = root_path
114 |     cfg.DATASET.NAME = 'trans10k_extra'
115 |     cfg.check_and_freeze()
116 | 
117 |     default_setup(args)
118 | 
119 |     evaluator = Evaluator(args)
120 |     evaluator.eval()
121 | 
--------------------------------------------------------------------------------
/tools/test_translab.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | import os
4 | import sys
5 | 
6 | cur_path = os.path.abspath(os.path.dirname(__file__))
7 | root_path = os.path.split(cur_path)[0]
8 | sys.path.append(root_path)
9 | 
10 | import logging
11 | import torch
12 | import torch.nn as nn
13 | import torch.utils.data as data
14 | import torch.nn.functional as F
15 | 
16 | from tabulate import tabulate
17 | from torchvision import transforms
18 | from segmentron.data.dataloader import get_segmentation_dataset
19 | from segmentron.models.model_zoo import get_segmentation_model
20 | from segmentron.utils.score import SegmentationMetric
21 | from segmentron.utils.distributed import synchronize, make_data_sampler, make_batch_data_sampler
22 | from segmentron.config import cfg
23 | from segmentron.utils.options import parse_args
24 | from segmentron.utils.default_setup import default_setup
25 | from IPython import embed
26 | from collections import OrderedDict
27 | from segmentron.utils.filesystem import makedirs
28 | from progressbar import *
29 | 
30 | 
31 | class Evaluator(object):
32 |     def __init__(self, args):
33 |         self.args = args
34 |         self.device = torch.device(args.device)
35 | 
36 |         # image transform
37 |         input_transform = transforms.Compose([
38 |             transforms.ToTensor(),
39 |             transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
40 |         ])
41 | 
42 |         # test dataloader
43 |         val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
44 |                                                split='test',
45 |                                                mode='val',
46 |                                                transform=input_transform,
47 |                                                base_size=cfg.TRAIN.BASE_SIZE)
48 | 
49 |         # validation dataloader
50 |         # val_dataset = get_segmentation_dataset(cfg.DATASET.NAME,
51 |         #                                        split='validation',
52 |         #                                        mode='val',
53 |         #                                        transform=input_transform,
54 |         #                                        base_size=cfg.TRAIN.BASE_SIZE)
55 | 
56 | 
57 |         val_sampler = make_data_sampler(val_dataset, shuffle=False, distributed=args.distributed)
58 |         val_batch_sampler = make_batch_data_sampler(val_sampler, images_per_batch=cfg.TEST.BATCH_SIZE, drop_last=False)
59 | 
60 |         self.val_loader = data.DataLoader(dataset=val_dataset,
61 |                                           batch_sampler=val_batch_sampler,
62 |                                           num_workers=cfg.DATASET.WORKERS,
63 |                                           pin_memory=True)
64 |         logging.info('**** number of test batches: {} ****'.format(len(self.val_loader)))
65 | 
66 |         self.classes = val_dataset.classes
67 |         # create network
68 |         self.model = get_segmentation_model().to(self.device)
69 | 
70 |         if hasattr(self.model, 'encoder') and cfg.MODEL.BN_EPS_FOR_ENCODER:
71 |             logging.info('set bn custom eps for bn in encoder: {}'.format(cfg.MODEL.BN_EPS_FOR_ENCODER))
72 |             self.set_batch_norm_attr(self.model.encoder.named_modules(), 'eps', cfg.MODEL.BN_EPS_FOR_ENCODER)
73 | 
74 |         if args.distributed:
75 |             self.model = nn.parallel.DistributedDataParallel(self.model,
76 |                 device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)
77 | 
78 |         self.model.to(self.device)
79 |         num_gpu = args.num_gpus
80 | 
81 |         # metric of easy and hard images
82 |         self.metric = SegmentationMetric(val_dataset.num_class, args.distributed, num_gpu)
83 |         self.metric_easy = SegmentationMetric(val_dataset.num_class, args.distributed, num_gpu)
84 |         self.metric_hard = SegmentationMetric(val_dataset.num_class, args.distributed, num_gpu)
85 | 
86 |         # number of easy and hard images
87 |         self.count_easy = 0
88 |         self.count_hard = 0
89 | 
90 |     def set_batch_norm_attr(self, named_modules, attr, value):
91 |         for m in named_modules:
92 |             if isinstance(m[1], nn.BatchNorm2d) or isinstance(m[1], nn.SyncBatchNorm):
93 |                 setattr(m[1], attr, value)
94 | 
95 |     def eval(self):
96 |         self.metric.reset()
97 |         self.model.eval()
98 |         if self.args.distributed:
99 |             model = self.model.module
100 |         else:
101 |             model = self.model
102 | 
103 |         logging.info("Start validation, total batches: {:d}".format(len(self.val_loader)))
104 |         import time
105 |         time_start = time.time()
106 |         widgets = ['Inference: ', Percentage(), ' ', Bar('#'), ' ', Timer(),
107 |                    ' ', ETA(), ' ', FileTransferSpeed()]
108 |         pbar = ProgressBar(widgets=widgets, maxval=10 * len(self.val_loader)).start()
109 | 
110 |         for i, (image, target, boundary, filename) in enumerate(self.val_loader):
111 |             image = image.to(self.device)
112 |             target = target.to(self.device)
113 |             boundary = boundary.to(self.device)
114 | 
115 |             filename = filename[0]
116 |             with torch.no_grad():
117 |                 output, output_boundary = model.evaluate(image)
118 | 
119 |             if 'hard' in filename:
120 |                 self.metric_hard.update(output, target)
121 |                 self.count_hard += 1
122 |             elif 'easy' in filename:
123 |                 self.metric_easy.update(output, target)
124 |                 self.count_easy += 1
125 |             else:
126 |                 logging.warning('filename has no easy/hard tag, skipped: {}'.format(filename))
127 |                 continue
128 | 
129 |             self.metric.update(output, target)
130 |             pbar.update(10 * i + 1)
131 | 
132 |         pbar.finish()
133 |         synchronize()
134 |         pixAcc, mIoU, category_iou, mae, mBer, category_Ber = self.metric.get(return_category_iou=True)
135 |         pixAcc_e, mIoU_e, category_iou_e, mae_e, mBer_e, category_Ber_e = self.metric_easy.get(return_category_iou=True)
136 |         pixAcc_h, mIoU_h, category_iou_h, mae_h, mBer_h, category_Ber_h = self.metric_hard.get(return_category_iou=True)
137 | 
138 |         logging.info('Eval time: {:.3f} seconds'.format(time.time() - time_start))
139 |         logging.info('End validation pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'.format(
140 |             pixAcc * 100, mIoU * 100, mae, mBer))
141 |         logging.info('End validation easy pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'.format(
142 |             pixAcc_e * 100, mIoU_e * 100, mae_e, mBer_e))
143 |         logging.info('End validation hard pixAcc: {:.2f}, mIoU: {:.2f}, mae: {:.3f}, mBer: {:.2f}'.format(
144 |             pixAcc_h * 100, mIoU_h * 100, mae_h, mBer_h))
145 | 
146 |         headers = ['class id', 'class name', 'iou', 'iou_easy', 'iou_hard', 'ber', 'ber_easy', 'ber_hard']
147 |         table = []
148 |         for i, cls_name in enumerate(self.classes):
149 |             table.append([
150 |                 cls_name, category_iou[i], category_iou_e[i], category_iou_h[i],
151 |                 category_Ber[i], category_Ber_e[i], category_Ber_h[i]
152 |             ])
153 |         logging.info('Category iou: \n {}'.format(tabulate(table, headers, tablefmt='grid', showindex="always",
154 |                                                            numalign='center', stralign='center')))
155 |         logging.info('easy images: {}, hard images: {}'.format(self.count_easy, self.count_hard))
156 | 
157 | 
158 | if __name__ == '__main__':
159 |     args = parse_args()
160 |     cfg.update_from_file(args.config_file)
161 |     cfg.update_from_list(args.opts)
162 |     cfg.PHASE = 'test'
163 |     cfg.ROOT_PATH = root_path
164 |     cfg.check_and_freeze()
165 | 
166 |     default_setup(args)
167 | 
168 |     evaluator = Evaluator(args)
169 |     evaluator.eval()
170 | 
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import datetime
3 | import os
4 | import sys
5 | 
6 | cur_path = os.path.abspath(os.path.dirname(__file__))
7 | root_path = os.path.split(cur_path)[0]
8 | sys.path.append(root_path)
9 | 
10 | import logging
11 | import torch
12 | import torch.nn as nn
13 | import torch.utils.data as data
14 | import torch.nn.functional as F
15 | 
16 | from torchvision import transforms
17 | from segmentron.data.dataloader import get_segmentation_dataset
18 | from segmentron.models.model_zoo import get_segmentation_model
19 | from segmentron.solver.loss import get_segmentation_loss
20 | from segmentron.solver.optimizer import get_optimizer
21 | from segmentron.solver.lr_scheduler import get_scheduler
22 | from segmentron.utils.distributed import *
23 | from segmentron.utils.score import SegmentationMetric
24 | from segmentron.utils.filesystem import save_checkpoint
25 | from segmentron.utils.options import parse_args
26 | from segmentron.utils.default_setup import default_setup
27 | from segmentron.utils.visualize import show_flops_params
28 | from segmentron.config import cfg
29 | from IPython import embed
30 | 
31 | class Trainer(object):
32 |     def __init__(self, args):
33 |         self.args = args
34 |         self.device = torch.device(args.device)
35 | 
36 |         # image transform
37 |         input_transform = transforms.Compose([
38 |             transforms.ToTensor(),
39 |             transforms.Normalize(cfg.DATASET.MEAN, cfg.DATASET.STD),
40 |         ])
41 |         # dataset and dataloader
42 |         data_kwargs = {'transform': input_transform, 'base_size': cfg.TRAIN.BASE_SIZE,
43 |                        'crop_size': cfg.TRAIN.CROP_SIZE}
44 |         train_dataset = get_segmentation_dataset(cfg.DATASET.NAME, split='train', mode='train', **data_kwargs)
45 |         # #debug code
46 |         # import cv2
47 |         # for i in range(10):
48 |         #     img, mask, _ = train_dataset[i]
49 |         #     print(img.shape, mask.shape)
50 |         #     mask = mask.data.cpu().numpy()*127
51 |         #     # mask = cv2.resize(mask, (500,500))
52 |         #     # cv2.imwrite('./trash/{}.jpg'.format(i), mask)
53 |         # embed(header='check loader')
54 |         self.iters_per_epoch = len(train_dataset) // (args.num_gpus * cfg.TRAIN.BATCH_SIZE)
55 |         self.max_iters = cfg.TRAIN.EPOCHS * self.iters_per_epoch
56 | 
57 |         train_sampler = make_data_sampler(train_dataset, shuffle=True, distributed=args.distributed)
58 |         train_batch_sampler = make_batch_data_sampler(train_sampler, cfg.TRAIN.BATCH_SIZE, self.max_iters, drop_last=True)
59 | 
60 |         self.train_loader = data.DataLoader(dataset=train_dataset,
61 |                                             batch_sampler=train_batch_sampler,
62 |                                             num_workers=cfg.DATASET.WORKERS,
63 |                                             pin_memory=True)
64 | 
65 | 
66 |         # create network
67 |         self.model = get_segmentation_model().to(self.device)
68 |         # print params and flops
69 |         if get_rank() == 0:
70 |             try:
71 |                 show_flops_params(self.model, args.device)
72 |             except Exception as e:
73 |                 logging.warning('get flops and params error: {}'.format(e))
74 | 
75 |         if cfg.MODEL.BN_TYPE not in ['BN']:
76 |             logging.info('Batch norm type is {}, convert_sync_batchnorm is not effective'.format(cfg.MODEL.BN_TYPE))
77 |         elif args.distributed and cfg.TRAIN.SYNC_BATCH_NORM:
78 |             self.model = nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
79 |             logging.info('SyncBatchNorm is effective!')
80 |         else:
81 |             logging.info('SyncBatchNorm is not used.')
82 | 
83 |         # create criterion
84 |         self.criterion = get_segmentation_loss(cfg.MODEL.MODEL_NAME, use_ohem=cfg.SOLVER.OHEM,
85 |                                                aux=cfg.SOLVER.AUX, aux_weight=cfg.SOLVER.AUX_WEIGHT,
86 |                                                ignore_index=cfg.DATASET.IGNORE_INDEX).to(self.device)
87 | 
88 |         cfg.SOLVER.LOSS_NAME = 'dice'  # the boundary stream is supervised with a dice loss
89 |         self.criterion_b = get_segmentation_loss(cfg.MODEL.MODEL_NAME).to(self.device)
90 | 
91 |         # optimizer; the model consists of an encoder and a decoder (head and auxlayer)
92 |         self.optimizer = get_optimizer(self.model)
93 | 
94 |         # lr scheduling
95 |         self.lr_scheduler = get_scheduler(self.optimizer, max_iters=self.max_iters,
96 |                                           iters_per_epoch=self.iters_per_epoch)
97 | 
98 |         # resume checkpoint if needed
99 |         self.start_epoch = 0
100 |         if args.resume and os.path.isfile(args.resume):
101 |             name, ext = os.path.splitext(args.resume)
102 |             assert ext in ('.pkl', '.pth'), 'Sorry, only .pth and .pkl files are supported.'
103 |             logging.info('Resuming training, loading {}...'.format(args.resume))
104 |             resume_state = torch.load(args.resume)
105 |             self.model.load_state_dict(resume_state['state_dict'])
106 |             self.start_epoch = resume_state['epoch']
107 |             logging.info('resume train from epoch: {}'.format(self.start_epoch))
108 |             if resume_state['optimizer'] is not None and resume_state['lr_scheduler'] is not None:
109 |                 logging.info('resume optimizer and lr scheduler from resume state..')
110 |                 self.optimizer.load_state_dict(resume_state['optimizer'])
111 |                 self.lr_scheduler.load_state_dict(resume_state['lr_scheduler'])
112 | 
113 |         if args.distributed:
114 |             self.model = nn.parallel.DistributedDataParallel(self.model, device_ids=[args.local_rank],
115 |                                                              output_device=args.local_rank,
116 |                                                              find_unused_parameters=True)
117 | 
118 |         # # evaluation metrics
119 |         # self.metric = SegmentationMetric(train_dataset.num_class, args.distributed)
120 |         # self.best_pred = 0.0
121 | 
122 | 
123 |     def train(self):
124 |         self.save_to_disk = get_rank() == 0
125 |         epochs, max_iters, iters_per_epoch = cfg.TRAIN.EPOCHS, self.max_iters, self.iters_per_epoch
126 |         log_per_iters, val_per_iters = self.args.log_iter, self.args.val_epoch * self.iters_per_epoch
127 | 
128 |         start_time = time.time()
129 |         logging.info('Start training, Total Epochs: {:d}, Total Iterations: {:d}'.format(epochs, max_iters))
130 | 
131 |         self.model.train()
132 |         iteration = self.start_epoch * iters_per_epoch if self.start_epoch > 0 else 0
133 |         for (images, targets, boundary, _) in self.train_loader:
134 |             epoch = iteration // iters_per_epoch + 1
135 |             iteration += 1
136 | 
137 |             images = images.to(self.device)
138 |             targets = targets.to(self.device)
139 |             boundarys = boundary.to(self.device)
140 | 
141 |             outputs, outputs_boundary = self.model(images)
142 | 
143 |             loss_dict = self.criterion(outputs, targets)
144 |             # embed(header='check loss')
145 |             boundarys = boundarys.float()
146 |             valid = torch.ones_like(boundarys)  # every boundary pixel is treated as valid for the dice loss
147 |             lossb_dict = self.criterion_b(outputs_boundary[0], boundarys, valid)
148 | 
149 |             weight_boundary = 5  # weight of the auxiliary boundary loss relative to the segmentation loss
150 |             lossb_dict['loss'] = weight_boundary * lossb_dict['loss']
151 | 
152 |             losses = sum(loss for loss in loss_dict.values()) + \
153 |                      sum(loss for loss in lossb_dict.values())
154 | 
155 |             # reduce losses over all GPUs for logging purposes
156 |             loss_dict_reduced = reduce_loss_dict(loss_dict)
157 |             losses_reduced = sum(loss for loss in loss_dict_reduced.values())
158 | 
159 |             lossb_dict_reduced = reduce_loss_dict(lossb_dict)
160 |             lossesb_reduced = sum(loss for loss in lossb_dict_reduced.values())
161 | 
162 |             # embed(header='check loader')
163 | 
164 | 
165 |             self.optimizer.zero_grad()
166 |             losses.backward()
167 |             self.optimizer.step()
168 |             self.lr_scheduler.step()
169 | 
170 |             eta_seconds = ((time.time() - start_time) / iteration) * (max_iters - iteration)
171 |             eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
172 | 
173 |             if iteration % log_per_iters == 0 and self.save_to_disk:
174 |                 logging.info(
175 |                     "Epoch: {:d}/{:d} || Iters: {:d}/{:d} || Lr: {:.6f} || "
176 |                     "Loss: {:.4f} || Loss_b: {:.4f} || Elapsed Time: {} || Estimated Time: {}".format(
177 |                         epoch, epochs, iteration % iters_per_epoch, iters_per_epoch,
178 |                         self.optimizer.param_groups[0]['lr'], losses_reduced.item(), lossesb_reduced.item(),
179 |                         str(datetime.timedelta(seconds=int(time.time() - start_time))),
180 |                         eta_string))
181 | 
182 |             if iteration % self.iters_per_epoch == 0 and self.save_to_disk:
183 |                 save_checkpoint(self.model, epoch, self.optimizer, self.lr_scheduler, is_best=False)
184 | 
185 |         total_training_time = time.time() - start_time
186 |         total_training_str = str(datetime.timedelta(seconds=total_training_time))
187 |         logging.info(
188 |             "Total training time: {} ({:.4f}s / it)".format(
189 |                 total_training_str, total_training_time / max_iters))
190 | 
191 | 
192 | 
193 | if __name__ == '__main__':
194 |     args = parse_args()
195 |     # get config
196 |     cfg.update_from_file(args.config_file)
197 |     cfg.update_from_list(args.opts)
198 |     cfg.PHASE = 'train'
199 |     cfg.ROOT_PATH = root_path
200 |     cfg.check_and_freeze()
201 | 
202 |     # set up the Python training environment, logger and random seed
203 |     default_setup(args)
204 | 
205 |     # create a trainer and start training
206 |     trainer = Trainer(args)
207 |     trainer.train()
208 | 
--------------------------------------------------------------------------------
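
Usage note: a minimal sketch of how the metric helpers in segmentron/utils/score.py compose. hist_info builds a num_cls x num_cls confusion matrix (rows are ground-truth classes, columns are predictions) via the flattened-index trick num_cls * label + pred, and compute_score turns that matrix into per-class IoU, mean IoU (with and without the background row) and pixel accuracy. The toy arrays below are hypothetical, and the import assumes the repo root is on PYTHONPATH (the scripts under tools/ append it themselves).

import numpy as np

from segmentron.utils.score import compute_score, hist_info

# hypothetical 4x4 prediction/label pair with 3 classes (0 = background)
pred = np.array([[0, 1, 1, 2],
                 [0, 1, 2, 2],
                 [0, 0, 1, 2],
                 [0, 1, 1, 2]])
label = np.array([[0, 1, 1, 2],
                  [0, 1, 1, 2],
                  [0, 0, 1, 2],
                  [0, 2, 1, 2]])

# confusion matrix plus the number of labeled and correctly predicted pixels
hist, labeled, correct = hist_info(pred, label, num_cls=3)

# per-class IoU and the aggregate scores derived from the confusion matrix
iu, mean_IU, mean_IU_no_back, mean_pixel_acc = compute_score(hist, correct, labeled)
print(iu, mean_IU, mean_IU_no_back, mean_pixel_acc)
# print_iou(iu, mean_pixel_acc)  # optional pretty-printer from segmentron/utils/visualize.py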