├── .gitignore
├── Decoding
│   ├── decode_autodeeplab.py
│   └── decoding_formulas.py
├── LICENSE
├── README.md
├── dataloaders
│   ├── __init__.py
│   ├── custom_transforms.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cityscapes.py
│   │   └── pascal.py
│   └── utils.py
├── eval.py
├── eval_edm.py
├── modeling
│   ├── .DS_Store
│   ├── ADD.py
│   ├── __init__.py
│   ├── aspp_train.py
│   ├── autodeeplab.py
│   ├── baseline_model.py
│   ├── cell_level_search.py
│   ├── decoder.py
│   ├── genotypes.py
│   ├── model_baseline_path_search.py
│   ├── model_net_search.py
│   ├── model_search.py
│   ├── operations.py
│   └── sync_batchnorm
│       ├── __init__.py
│       ├── batchnorm.py
│       ├── comm.py
│       ├── replicate.py
│       └── unittest.py
├── mypath.py
├── scripts
│   ├── eval.sh
│   ├── search_cityscapes.sh
│   ├── train_dist.sh
│   └── train_edm.sh
├── search.py
├── search_layer.py
├── searched_arch
│   ├── 40_5e_38_lr
│   │   ├── genotype_1.npy
│   │   └── genotype_2.npy
│   ├── autodeeplab
│   │   └── genotype.npy
│   └── searched_baseline
│       ├── genotype_1.npy
│       ├── genotype_2.npy
│       ├── network_path.npy
│       └── network_path_space.npy
├── train.py
├── train_edm.py
└── utils
    ├── calculate_weights.py
    ├── copy_state_dict.py
    ├── eval_utils.py
    ├── loss.py
    ├── lr_scheduler.py
    ├── metrics.py
    ├── multadds_count.py
    ├── saver.py
    └── summaries.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/Decoding/decode_autodeeplab.py:
--------------------------------------------------------------------------------
import argparse
import os
import numpy as np
from tqdm import tqdm
import sys
import torch
from collections import OrderedDict
from mypath import Path
from modeling.sync_batchnorm.replicate import patch_replication_callback
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator
from modeling.model_search import AutoDeeplab
from Decoding.decoding_formulas import Decoder


class Loader(object):
    def __init__(self, args):
        self.args = args
        if self.args.dataset == 'cityscapes':
            self.nclass = 19

        # NOTE: Model_search, Model_layer_search and Model_search_baseline are
        # assumed to be provided by the modeling package (model_search.py,
        # model_net_search.py and model_baseline_path_search.py); import them
        # from there if they are not already in scope.
        if self.args.network == 'supernet':
            model = Model_search(num_classes=self.nclass, num_layers=12)
        elif self.args.network == 'layer_supernet':
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            model = Model_layer_search(num_classes=self.nclass, num_layers=12)
        else:
            model = Model_search_baseline(num_classes=self.nclass, num_layers=12)
        self.model = model

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
            print('cuda finished')

        # Resuming checkpoint
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)
            else:
                if torch.cuda.device_count() > 1:
                    self.model.module.load_state_dict(checkpoint['state_dict'])
                else:
                    self.model.load_state_dict(checkpoint['state_dict'])

        self.decoder = Decoder(self.model.alphas,
                               self.model.betas,
                               5)
        print(self.model.betas)
    def retrieve_alphas_betas(self):
        return self.model.alphas, self.model.bottom_betas, self.model.betas8, self.model.betas16, self.model.top_betas

    def decode_architecture(self):
        paths, paths_space = self.decoder.viterbi_decode()
        return paths, paths_space

    def decode_cell(self):
        # genotype_decode() returns a single cell genotype
        # (see Decoding/decoding_formulas.py below)
        genotype = self.decoder.genotype_decode()
        return genotype


def get_new_network_cell():
    parser = argparse.ArgumentParser(description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--dataset', type=str, default='cityscapes',
                        choices=['pascal', 'coco', 'cityscapes', 'kd'],
                        help='dataset name (default: cityscapes)')
    parser.add_argument('--autodeeplab', type=str, default='train',
                        choices=['search', 'train'])
    parser.add_argument('--load-parallel', type=int, default=0)
    parser.add_argument('--clean-module', type=int, default=0)
    parser.add_argument('--test-batch-size', type=int, default=None,
                        metavar='N', help='input batch size for testing (default: auto)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')

    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()

    load_model = Loader(args)
    result_paths, result_paths_space = load_model.decode_architecture()
    network_path = result_paths
    network_path_space = result_paths_space
    genotype = load_model.decode_cell()
    print('architecture search results:', network_path)
    print('new cell structure:', genotype)

    dir_name = os.path.dirname(args.resume)
    network_path_filename = os.path.join(dir_name, 'network_path')
    network_path_space_filename = os.path.join(dir_name, 'network_path_space')
    genotype_filename = os.path.join(dir_name, 'genotype')

    np.save(network_path_filename, network_path)
    np.save(network_path_space_filename, network_path_space)
    np.save(genotype_filename, genotype)

    print('saved to:', dir_name)


if __name__ == '__main__':
    get_new_network_cell()

--------------------------------------------------------------------------------
/Decoding/decoding_formulas.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn.functional as F
from modeling.genotypes import PRIMITIVES
from modeling.genotypes import Genotype


def network_layer_to_space(net_arch):
    """Convert a list of per-layer downsampling-level indices into a one-hot
    (num_layers, 4, 3) network-space encoding."""
    for i, layer in enumerate(net_arch):
        if i == 0:
            space = np.zeros((1, 4, 3))
            space[0][layer][0] = 1
            prev = layer
        else:
            if layer == prev + 1:
                sample = 0
            elif layer == prev:
                sample = 1
            elif layer == prev - 1:
                sample = 2
            space1 = np.zeros((1, 4, 3))
            space1[0][layer][sample] = 1
            space = np.concatenate([space, space1], axis=0)
            prev = layer
    return space
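
# Example (hypothetical input): a 4-layer path maps to a one-hot (4, 4, 3)
# array whose last axis records how each layer's level changed relative to
# the previous layer (index 0: one level deeper, 1: unchanged, 2: one level up).
#
#   >>> space = network_layer_to_space([0, 0, 1, 2])
#   >>> space.shape
#   (4, 4, 3)
#   >>> space[2][1][0]   # layer 2 sits at level 1, reached from level 0
#   1.0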

class Decoder(object):
    def __init__(self, alphas=None, betas=None, B=None):
        self._betas = betas
        self._alphas = alphas
        self._B = B
        self._num_layers = len(self._betas)
        self.network_space = torch.zeros(12, 4, 3)
        for layer in range(len(self._betas)):
            if layer == 0:
                self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3)
            elif layer == 1:
                self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3)
                self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1)
            elif layer == 2:
                self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3)
                self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1)
                self.network_space[layer][2] = F.softmax(self._betas[layer][2], dim=-1)
            else:
                self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3)
                self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1)
                self.network_space[layer][2] = F.softmax(self._betas[layer][2], dim=-1)
                self.network_space[layer][3][:2] = F.softmax(self._betas[layer][3][:2], dim=-1) * (2/3)

    def viterbi_decode(self):
        prob_space = np.zeros(self.network_space.shape[:2])
        path_space = np.zeros(self.network_space.shape[:2]).astype('int8')

        for layer in range(self.network_space.shape[0]):
            if layer == 0:
                prob_space[layer][0] = self.network_space[layer][0][1]
                prob_space[layer][1] = self.network_space[layer][0][2]
                path_space[layer][0] = 0
                path_space[layer][1] = -1
            else:
                for sample in range(self.network_space.shape[1]):
                    if layer - sample < -1:
                        continue
                    local_prob = []
                    for rate in range(self.network_space.shape[2]):  # k[0 : ➚, 1: ➙, 2 : ➘]
                        if (sample == 0 and rate == 2) or (sample == 3 and rate == 0):
                            continue
                        else:
                            local_prob.append(prob_space[layer - 1][sample + 1 - rate] *
                                              self.network_space[layer][sample + 1 - rate][rate])
                    prob_space[layer][sample] = np.max(local_prob, axis=0)
                    rate = np.argmax(local_prob, axis=0)
                    path = 1 - rate if sample != 3 else -rate
                    path_space[layer][sample] = path  # path[1 : ➚, 0: ➙, -1 : ➘]

        output_sample = prob_space[-1, :].argmax(axis=-1)
        actual_path = np.zeros(12).astype('uint8')
        actual_path[-1] = output_sample
        for i in range(1, self._num_layers):
            actual_path[-i - 1] = actual_path[-i] + path_space[self._num_layers - i, actual_path[-i]]
        return actual_path, network_layer_to_space(actual_path)

    def genotype_decode(self):

        def _parse(alphas, steps):
            gene = []
            start = 0
            n = 2
            for i in range(steps):
                end = start + n
                edges = sorted(range(start, end), key=lambda x: -np.max(alphas[x, 1:]))  # ignore none value
                top2edges = edges[:2]
                for j in top2edges:
                    best_op_index = np.argmax(alphas[j])  # this can include none op
                    gene.append([j, best_op_index])
                start = end
                n += 1
            return np.array(gene)

        normalized_alphas = F.softmax(self._alphas, dim=-1).data.cpu().numpy()

        gene_cell = _parse(normalized_alphas, self._B)
        return gene_cell
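
A minimal smoke test of the decoder above, with random architecture parameters. The tensor shapes are assumptions: 12 layers x 4 levels x 3 transitions for the betas, and sum(2..6) = 20 cell edges x 8 candidate operations for the alphas when B = 5:

```python
import torch
from Decoding.decoding_formulas import Decoder

alphas = torch.randn(20, 8)    # 20 edges for B = 5 intermediate nodes (assumed op count: 8)
betas = torch.randn(12, 4, 3)  # per-layer transition logits
decoder = Decoder(alphas, betas, 5)

path, space = decoder.viterbi_decode()
print(path)          # 12 downsampling-level indices, e.g. [0 1 1 2 ...]
print(space.shape)   # (12, 4, 3)

gene = decoder.genotype_decode()
print(gene.shape)    # (10, 2): two [edge, op] pairs per intermediate node
```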

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner.
      For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution.
      You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE.
      You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# AutoDynamicDeepLab
This repository is for the IROS 2021 paper [ADD: A Fine-grained Dynamic Inference Architecture for Semantic Image Segmentation](https://ieeexplore.ieee.org/abstract/document/9636650).

Dynamic-Auto-DeepLab is trained in three stages: first, search for the network architecture; second, train the model with the searched architecture; third, train the early-exit decision maker.

**Modify the path to your Cityscapes dataset in mypath.py**
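
`mypath.py` itself is not shown in this snapshot. A minimal sketch that is compatible with the dataloaders (they call `Path.db_root_dir('cityscapes')`) could look like the following; the directory values are placeholders you must adapt:

```python
# Hypothetical mypath.py sketch; point the roots at your local datasets.
class Path(object):
    @staticmethod
    def db_root_dir(dataset):
        if dataset == 'cityscapes':
            # must contain the leftImg8bit/ and gtFine/ folders
            return '/path/to/datasets/cityscapes/'
        elif dataset == 'pascal':
            return '/path/to/datasets/VOCdevkit/'
        else:
            raise NotImplementedError('Dataset {} not available.'.format(dataset))
```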

## Neural Architecture Search

**We search for the architecture on Cityscapes**

```
cd scripts
bash search_cityscapes.sh
```

**The searched architecture and the search progress can be inspected with:**
```
tensorboard --logdir path-to-your-exp
```

## Train model:
**Choose the network to train by modifying the .sh file. Note that, since we use torch.distributed, the batch size is split across GPUs (16/#GPU per GPU).**

```
bash train_dist.sh
```

## Train the early-exit decision maker (EDM) with the features produced by the model we just trained:
```
bash train_edm.sh
```

## Evaluation on Cityscapes:
```
bash eval.sh
```

## Requirements

* PyTorch 1.0+

* Python 3

* tensorboardX

* pycocotools

* tqdm

* apex

## Citation

## Acknowledgement
[Auto-DeepLab](https://github.com/NoamRosenberg/AutoML)

[pytorch-deeplab-xception](https://github.com/jfzhang95/pytorch-deeplab-xception)

[DeepLabv3.pytorch](https://github.com/chenxi116/DeepLabv3.pytorch)

[Synchronized-BatchNorm-PyTorch](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch)

--------------------------------------------------------------------------------
/dataloaders/__init__.py:
--------------------------------------------------------------------------------
from dataloaders.datasets import cityscapes, pascal
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def make_data_loader(args, **kwargs):

    if args.dataset == 'pascal':
        train_set = pascal.VOCSegmentation('../../../Pascal/VOCdevkit', train=True)
        val_set = pascal.VOCSegmentation('../../../Pascal/VOCdevkit', train=False)

        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)
        test_loader = None

        return train_loader, val_loader, test_loader, num_class

    elif 'cityscapes' in args.dataset:
        if args.dataset == 'cityscapes_edm':
            train_set = cityscapes.CityscapesSegmentation(args, split='train', full=True)
            num_class = train_set.NUM_CLASSES
            train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        else:
            if 'supernet' in args.network:
                # architecture search splits the training set into two halves
                train_set1, train_set2 = cityscapes.twoTrainSeg(args)
                num_class = train_set1.NUM_CLASSES
                train_loader1 = DataLoader(train_set1, batch_size=args.batch_size, shuffle=True, **kwargs)
                train_loader2 = DataLoader(train_set2, batch_size=args.batch_size, shuffle=True, **kwargs)
            else:
                train_set = cityscapes.CityscapesSegmentation(args, split='train')
                num_class = train_set.NUM_CLASSES
                if args.dist:
                    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, sampler=DistributedSampler(train_set), **kwargs)
                else:
                    train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)

        if args.network is not None:
            val_set = cityscapes.CityscapesSegmentation(args, split='val')
            test_set = cityscapes.CityscapesSegmentation(args, split='test')
        else:
            raise Exception('autodeeplab param not set properly')

        val_loader = DataLoader(val_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)
        test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs)

        if 'supernet' in args.network:
            return train_loader1, train_loader2, val_loader, test_loader, num_class
        else:
            return train_loader, val_loader, test_loader, num_class

    elif args.dataset == 'coco':
        # NOTE: this branch requires a dataloaders.datasets.coco module,
        # which is not included in this repository.
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, train_loader, val_loader, test_loader, num_class

    else:
        raise NotImplementedError
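
A hypothetical call, mirroring how the training and evaluation scripts drive `make_data_loader` (the argument names are taken from the code above; the Cityscapes data must already be available under the `mypath.py` root):

```python
import argparse
from dataloaders import make_data_loader

# Namespace standing in for the argparse result of the real scripts.
args = argparse.Namespace(dataset='cityscapes', network='searched-dense',
                          batch_size=8, test_batch_size=1, dist=False)
kwargs = {'num_workers': 4, 'pin_memory': True, 'drop_last': True}

train_loader, val_loader, test_loader, num_class = make_data_loader(args, **kwargs)
print(num_class)  # 19 for Cityscapes
```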

--------------------------------------------------------------------------------
/dataloaders/custom_transforms.py:
--------------------------------------------------------------------------------
import torch
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageFilter
import math


class Normalize(object):
    """Normalize a tensor image with mean and standard deviation.
    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """
    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32)
        mask = np.array(mask).astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std

        return {'image': img,
                'label': mask}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self, sample):
        # swap color axis because
        # numpy image: H x W x C
        # torch image: C x H x W
        img = sample['image']
        mask = sample['label']
        img = np.array(img).astype(np.float32).transpose((2, 0, 1))
        mask = np.array(mask).astype(np.float32)

        img = torch.from_numpy(img).float()
        mask = torch.from_numpy(mask).float()

        return {'image': img,
                'label': mask}


class RandomHorizontalFlip(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)

        return {'image': img,
                'label': mask}


class RandomRotate(object):
    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)

        return {'image': img,
                'label': mask}


class RandomGaussianBlur(object):
    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))

        return {'image': img,
                'label': mask}

class RandomScaleCrop(object):
    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}


class FixScaleCrop(object):
    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))

        return {'image': img,
                'label': mask}


# pad and randomly crop to a fixed size (e.g. 512x1024)
class FixedResize(object):
    """Pad the image/label pair and randomly crop it to a fixed (H, W) size."""
    def __init__(self, resize=(512, 1024)):
        # resize must be an (H, W) pair; the indexing below assumes two entries
        self.resize = resize

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        w, h = img.size
        pad_tb = max(0, self.resize[0] - h)
        pad_lr = max(0, self.resize[1] - w)
        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        img = data_transforms(img)
        mask = torch.LongTensor(np.array(mask).astype(np.int64))

        img = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(img)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        h, w = img.shape[1], img.shape[2]
        i = random.randint(0, h - self.resize[0])
        j = random.randint(0, w - self.resize[1])
        img = img[:, i:i + self.resize[0], j:j + self.resize[1]]
        mask = mask[i:i + self.resize[0], j:j + self.resize[1]]

        return {'image': img,
                'label': mask}

# random crop to crop_size (default 769x769)
class RandomCrop(object):
    def __init__(self, crop_size=769):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        return {'image': img,
                'label': mask}


class FixedResize_Search(object):
    """change the short edge length to size"""

    def __init__(self, resize=512):
        self.size1 = resize

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size

        w, h = img.size
        if w > h:
            oh = self.size1
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.size1
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        return {'image': img,
                'label': mask}


class Crop_for_eval(object):
    def __init__(self):
        self.fill = 255

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        img = ImageOps.expand(img, border=(0, 0, 1, 1), fill=0)
        mask = ImageOps.expand(mask, border=(0, 0, 1, 1), fill=self.fill)

        return {'image': img,
                'label': mask}


class train_preprocess(object):
    def __init__(self, crop_size, mean, std, scale=0):
        self.crop_size = crop_size
        self.mean = mean
        self.std = std
        self.scale = scale

    def __call__(self, sample):
        image = sample['image']
        mask = sample['label']

        if random.random() < 0.5:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if self.scale == 0:
            # scale == 0 means: sample a random scale factor in [0.5, 2.0]
            scale = (0.5, 2.0)
            w, h = image.size
            rand_log_scale = math.log(scale[0], 2) + random.random() * (math.log(scale[1], 2) - math.log(scale[0], 2))
            random_scale = math.pow(2, rand_log_scale)
            new_size = (int(round(w * random_scale)), int(round(h * random_scale)))
            image = image.resize(new_size, Image.ANTIALIAS)
            mask = mask.resize(new_size, Image.NEAREST)
        else:
            w, h = image.size
            new_size = (int(round(w * self.scale)), int(round(h * self.scale)))
            image = image.resize(new_size, Image.ANTIALIAS)
            mask = mask.resize(new_size, Image.NEAREST)

        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])
        image = data_transforms(image)
        mask = torch.LongTensor(np.array(mask).astype(np.int64))

        h, w = image.shape[1], image.shape[2]
        pad_tb = max(0, self.crop_size[0] - h)
        pad_lr = max(0, self.crop_size[1] - w)
        image = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(image)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        h, w = image.shape[1], image.shape[2]
        i = random.randint(0, h - self.crop_size[0])
        j = random.randint(0, w - self.crop_size[1])
        image = image[:, i:i + self.crop_size[0], j:j + self.crop_size[1]]
        mask = mask[i:i + self.crop_size[0], j:j + self.crop_size[1]]

        return {'image': image,
                'label': mask}


class eval_preprocess(object):
    def __init__(self, crop_size, mean, std):
        self.crop_size = crop_size
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        image = sample['image']
        mask = sample['label']

        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        image = data_transforms(image)
        mask = torch.LongTensor(np.array(mask).astype(np.int64))

        h, w = image.shape[1], image.shape[2]
        pad_tb = max(0, self.crop_size[0] - h)
        pad_lr = max(0, self.crop_size[1] - w)
        image = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(image)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        h, w = image.shape[1], image.shape[2]
        i = random.randint(0, h - self.crop_size[0])
        j = random.randint(0, w - self.crop_size[1])
        image = image[:, i:i + self.crop_size[0], j:j + self.crop_size[1]]
        mask = mask[i:i + self.crop_size[0], j:j + self.crop_size[1]]

        return {'image': image,
                'label': mask}


class full_image_eval_preprocess(object):
    def __init__(self, crop_size, mean, std):
        self.crop_size = crop_size
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        image = sample['image']
        mask = sample['label']

        data_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)
        ])

        image = data_transforms(image)
        mask = torch.LongTensor(np.array(mask).astype(np.int64))

        h, w = image.shape[1], image.shape[2]
        pad_tb = max(0, self.crop_size[0] - h)
        pad_lr = max(0, self.crop_size[1] - w)
        image = torch.nn.ZeroPad2d((0, pad_lr, 0, pad_tb))(image)
        mask = torch.nn.ConstantPad2d((0, pad_lr, 0, pad_tb), 255)(mask)

        return {'image': image,
                'label': mask}
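
A usage sketch for the training pipeline above, on a synthetic image/label pair (hypothetical sizes; requires a Pillow version that still provides `Image.ANTIALIAS`, i.e. Pillow < 10):

```python
from PIL import Image
from dataloaders import custom_transforms as tr

img = Image.new('RGB', (2048, 1024))
lbl = Image.new('L', (2048, 1024))

t = tr.train_preprocess((769, 769), mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))
out = t({'image': img, 'label': lbl})
print(out['image'].shape)  # torch.Size([3, 769, 769])
print(out['label'].shape)  # torch.Size([769, 769])
```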

--------------------------------------------------------------------------------
/dataloaders/datasets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/dataloaders/datasets/cityscapes.py:
--------------------------------------------------------------------------------
import os
import numpy as np
from PIL import Image
from torch.utils import data
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr


def twoTrainSeg(args, root=Path.db_root_dir('cityscapes')):
    images_base = os.path.join(root, 'leftImg8bit', 'train')
    train_files = [os.path.join(looproot, filename) for looproot, _, filenames in os.walk(images_base)
                   for filename in filenames if filename.endswith('.png')]
    number_images = len(train_files)
    permuted_indices_ls = np.random.permutation(number_images)
    indices_1 = permuted_indices_ls[: int(0.5 * number_images) + 1]
    indices_2 = permuted_indices_ls[int(0.5 * number_images):]
    if len(indices_1) % 2 != 0 or len(indices_2) % 2 != 0:
        raise Exception('indices lists need to be even numbers for batch norm')
    return CityscapesSegmentation(args, split='train', indices_for_split=indices_1, search=True), \
           CityscapesSegmentation(args, split='train', indices_for_split=indices_2, search=True)


class CityscapesSegmentation(data.Dataset):
    NUM_CLASSES = 19

    def __init__(self, args, root=Path.db_root_dir('cityscapes'), split="train", indices_for_split=None, search=False, full=False):

        self.root = root
        self.full = full
        self.split = split
        self.args = args
        self.files = {}
        self.search = search
        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(self.root, 'gtFine', self.split)

        self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png')

        if indices_for_split is not None:
            self.files[split] = np.array(self.files[split])[indices_for_split].tolist()

        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_names = ['unlabelled', 'road', 'sidewalk', 'building', 'wall', 'fence',
                            'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain',
                            'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
                            'motorcycle', 'bicycle']

        self.ignore_index = 255
        self.class_map = dict(zip(self.valid_classes, range(self.NUM_CLASSES)))
        self.mean = (0.29866842, 0.30135223, 0.30561872)
        self.std = (0.23925215, 0.23859318, 0.2385942)

        if not self.files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))

        print("Found %d %s images" % (len(self.files[split]), split))

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):

        img_path = self.files[self.split][index].rstrip()
        lbl_path = os.path.join(self.annotations_base,
                                img_path.split(os.sep)[-2],
                                os.path.basename(img_path)[:-15] + 'gtFine_labelIds.png')

        _img = Image.open(img_path).convert('RGB')
        _tmp = np.array(Image.open(lbl_path), dtype=np.uint8)
        _tmp = self.encode_segmap(_tmp)
        _target = Image.fromarray(_tmp)

        sample = {'image': _img, 'label': _target}

        if self.split == 'train' and self.full != True:
            return self.transform_tr(sample)
        elif self.split == 'val' or self.full == True:
            return self.transform_val(sample)
        elif self.split == 'test':
            return self.transform_ts(sample)

    def encode_segmap(self, mask):
        # Map all void classes to the ignore index, then renumber valid classes
        for _voidc in self.void_classes:
            mask[mask == _voidc] = self.ignore_index
        for _validc in self.valid_classes:
            mask[mask == _validc] = self.class_map[_validc]
        return mask

    def recursive_glob(self, rootdir='.', suffix=''):
        """Performs recursive glob with given suffix and rootdir
        :param rootdir is the root directory
        :param suffix is the suffix to be searched
        """
        return [os.path.join(looproot, filename)
                for looproot, _, filenames in os.walk(rootdir)
                for filename in filenames if filename.endswith(suffix)]

    def transform_tr(self, sample):
        if self.search:
            transform = tr.train_preprocess((321, 321), self.mean, self.std, scale=0.5)
        else:
            transform = tr.train_preprocess((769, 769), self.mean, self.std)
        return transform(sample)

    def transform_val(self, sample):
        # search and default validation use the same full-image preprocessing
        transform = tr.full_image_eval_preprocess((1025, 2049), self.mean, self.std)
        return transform(sample)

    def transform_ts(self, sample):
        transform = tr.full_image_eval_preprocess((1025, 2049), self.mean, self.std)
        return transform(sample)

if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt
    import argparse

    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.base_size = 513
    args.crop_size = 513

    cityscapes_train = CityscapesSegmentation(args, split='train')

    dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='cityscapes')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)
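
A small, self-contained illustration (hypothetical pixel values) of what `encode_segmap` above does to a raw `gtFine_labelIds` mask: void ids collapse to the ignore index 255 and the 19 valid ids are renumbered to the train ids 0-18:

```python
import numpy as np

void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
class_map = dict(zip(valid_classes, range(19)))

mask = np.array([[7, 8, 0],
                 [26, 33, 4]], dtype=np.uint8)
for v in void_classes:
    mask[mask == v] = 255
for v in valid_classes:
    mask[mask == v] = class_map[v]
print(mask)  # [[  0   1 255]
             #  [ 13  18 255]]
```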

--------------------------------------------------------------------------------
/dataloaders/datasets/pascal.py:
--------------------------------------------------------------------------------
from __future__ import print_function, division
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
from mypath import Path
from torchvision import transforms
from dataloaders import custom_transforms as tr


class VOCSegmentation(Dataset):
    """
    PascalVoc dataset
    """
    CLASSES = [
        'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
        'motorbike', 'person', 'potted-plant', 'sheep', 'sofa', 'train',
        'tv/monitor'
    ]

    NUM_CLASSES = 21

    def __init__(self,
                 root,
                 train=True
                 ):
        """
        :param root: path to the VOC dataset directory
        :param train: use the train_aug split if True, else the val split
        """
        super().__init__()
        self.root = root
        self.train = train

        _voc_root = os.path.join(self.root, 'VOC2012')
        _list_dir = os.path.join(_voc_root, 'list')

        if self.train:
            _list_f = os.path.join(_list_dir, 'train_aug.txt')
        else:
            _list_f = os.path.join(_list_dir, 'val.txt')

        self.images = []
        self.masks = []
        with open(_list_f, 'r') as lines:
            for line in lines:
                _image = _voc_root + line.split()[0]
                _mask = _voc_root + line.split()[1]
                assert os.path.isfile(_image)
                assert os.path.isfile(_mask)
                self.images.append(_image)
                self.masks.append(_mask)

        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

        # Display stats
        print('Number of images : {:d}'.format(len(self.images)))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        _img = Image.open(self.images[index]).convert('RGB')
        _target = Image.open(self.masks[index])
        sample = {'image': _img, 'label': _target}

        if self.train:
            return self.transform_tr(sample)
        else:
            return self.transform_val(sample)

    def transform_tr(self, sample):
        transform = tr.train_preprocess((513, 513), self.mean, self.std)
        return transform(sample)

    def transform_val(self, sample):
        transform = tr.eval_preprocess((513, 513), self.mean, self.std)
        return transform(sample)

    def __str__(self):
        return 'VOC2012(train=' + str(self.train) + ')'


if __name__ == '__main__':
    from dataloaders.utils import decode_segmap
    from torch.utils.data import DataLoader
    import matplotlib.pyplot as plt

    voc_train = VOCSegmentation('../../../Pascal/VOCdevkit', train=True)

    dataloader = DataLoader(voc_train, batch_size=5, shuffle=True, num_workers=0)

    for ii, sample in enumerate(dataloader):
        for jj in range(sample["image"].size()[0]):
            img = sample['image'].numpy()
            gt = sample['label'].numpy()
            tmp = np.array(gt[jj]).astype(np.uint8)
            segmap = decode_segmap(tmp, dataset='pascal')
            img_tmp = np.transpose(img[jj], axes=[1, 2, 0])
            img_tmp *= (0.229, 0.224, 0.225)
            img_tmp += (0.485, 0.456, 0.406)
            img_tmp *= 255.0
            img_tmp = img_tmp.astype(np.uint8)
            plt.figure()
            plt.title('display')
            plt.subplot(211)
            plt.imshow(img_tmp)
            plt.subplot(212)
            plt.imshow(segmap)

        if ii == 1:
            break

    plt.show(block=True)

--------------------------------------------------------------------------------
/dataloaders/utils.py:
--------------------------------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np
import torch


def decode_seg_map_sequence(label_masks, dataset='pascal'):
    rgb_masks = []
    for label_mask in label_masks:
        rgb_mask = decode_segmap(label_mask, dataset)
        rgb_masks.append(rgb_mask)
    rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2]))
    return rgb_masks


def decode_segmap(label_mask, dataset, plot=False):
    """Decode segmentation class labels into a color image
    Args:
        label_mask (np.ndarray): an (M,N) array of integer values denoting
          the class label at each spatial location.
        plot (bool, optional): whether to show the resulting color image
          in a figure.
    Returns:
        (np.ndarray, optional): the resulting decoded color image.
    """
    if dataset == 'pascal' or dataset == 'coco':
        n_classes = 21
        label_colours = get_pascal_labels()
    elif dataset == 'cityscapes':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    elif dataset == 'kd':
        n_classes = 19
        label_colours = get_cityscapes_labels()
    else:
        raise NotImplementedError

    r = label_mask.copy()
    g = label_mask.copy()
    b = label_mask.copy()
    for ll in range(0, n_classes):
        r[label_mask == ll] = label_colours[ll, 0]
        g[label_mask == ll] = label_colours[ll, 1]
        b[label_mask == ll] = label_colours[ll, 2]
    rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3))
    rgb[:, :, 0] = r / 255.0
    rgb[:, :, 1] = g / 255.0
    rgb[:, :, 2] = b / 255.0
    if plot:
        plt.imshow(rgb)
        plt.show()
    else:
        return rgb

def encode_segmap(mask):
    """Encode segmentation label images as pascal classes
    Args:
        mask (np.ndarray): raw segmentation label image of dimension
          (M, N, 3), in which the Pascal classes are encoded as colours.
    Returns:
        (np.ndarray): class map with dimensions (M,N), where the value at
          a given location is the integer denoting the class index.
    """
    mask = mask.astype(int)
    label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
    for ii, label in enumerate(get_pascal_labels()):
        label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
    label_mask = label_mask.astype(int)
    return label_mask


def get_cityscapes_labels():
    return np.array([
        [128, 64, 128],
        [244, 35, 232],
        [70, 70, 70],
        [102, 102, 156],
        [190, 153, 153],
        [153, 153, 153],
        [250, 170, 30],
        [220, 220, 0],
        [107, 142, 35],
        [152, 251, 152],
        [0, 130, 180],
        [220, 20, 60],
        [255, 0, 0],
        [0, 0, 142],
        [0, 0, 70],
        [0, 60, 100],
        [0, 80, 100],
        [0, 0, 230],
        [119, 11, 32]])


def get_pascal_labels():
    """Load the mapping that associates pascal classes with label colors
    Returns:
        np.ndarray with dimensions (21, 3)
    """
    return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                       [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                       [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                       [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                       [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                       [0, 64, 128]])

--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
import argparse
import os
import time
import numpy as np
from collections import OrderedDict

from mypath import Path
from dataloaders import make_data_loader

from utils.loss import SegmentationLosses
from utils.calculate_weights import calculate_weigths_labels
from utils.lr_scheduler import LR_Scheduler
from utils.saver import Saver
from utils.summaries import TensorboardSummary
from utils.metrics import Evaluator
from utils.eval_utils import AverageMeter

from modeling.baseline_model import *
from modeling.ADD import *
from modeling.autodeeplab import *
from modeling.operations import normalized_shannon_entropy
from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
from modeling.sync_batchnorm.replicate import patch_replication_callback

from tqdm import tqdm
from torchviz import make_dot, make_dot_from_trace
from ptflops import get_model_complexity_info


class Evaluation(object):
    def __init__(self, args):

        self.args = args
        self.saver = Saver(args)
        self.saver.save_experiment_config()
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
        _, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        if args.network == 'searched-dense':
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            if self.args.C == 2:
                C_index = [5]
                # 4_15_80e_40a_03-lr_5e-4wd_6e-4alr_1e-3awd 513x513 batch 4
                network_arch = [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2]
                low_level_layer = 0
            elif self.args.C == 3:
                C_index = [3, 7]
                network_arch = [1, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3]
                low_level_layer = 0
            elif self.args.C == 4:
                C_index = [2, 5, 8]
                network_arch = [1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2]
                low_level_layer = 0

            model = ADD(network_arch,
                        C_index,
                        cell_arch,
                        self.nclass,
                        args,
                        low_level_layer)

        elif args.network.startswith('autodeeplab'):
            network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
            cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
            cell_arch = np.load(cell_path)
            low_level_layer = 2
            if self.args.C == 2:
                C_index = [5]
            elif self.args.C == 3:
                C_index = [3, 7]
            elif self.args.C == 4:
                C_index = [2, 5, 8]

            if args.network == 'autodeeplab-dense':
                model = ADD(network_arch,
                            C_index,
                            cell_arch,
                            self.nclass,
                            args,
                            low_level_layer)

            elif args.network == 'autodeeplab-baseline':
                model = Baselin_Model(network_arch,
                                      C_index,
                                      cell_arch,
                                      self.nclass,
                                      args,
                                      low_level_layer)
            else:
                # plain 'autodeeplab' has no model constructor in this script
                raise NotImplementedError('unsupported network: {}'.format(args.network))

        if args.use_balanced_weights:
            # NOTE: recomputing class weights requires a train loader,
            # which this evaluation script does not build.
            classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None

        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
        self.model = model

        # Define Evaluator
        self.evaluator = []
        for num in range(self.args.C):
            self.evaluator.append(Evaluator(self.nclass))

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()
        if args.confidence == 'edm':
            self.edm = EDM()
            self.edm = self.edm.cuda()
        else:
            self.edm = False

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']

            # if the weights are wrapped in a module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of DataParallel
                    new_state_dict[name] = v
                self.model.load_state_dict(new_state_dict)
            else:
                self.model.load_state_dict(checkpoint['state_dict'])

            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))

        if args.resume_edm is not None:
            if not os.path.isfile(args.resume_edm):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume_edm))
            checkpoint = torch.load(args.resume_edm)

            # if the weights are wrapped in a module object we have to clean them
            if args.clean_module:
                state_dict = checkpoint['state_dict']
                new_state_dict = OrderedDict()
                for k, v in state_dict.items():
                    name = k[7:]  # remove 'module.' of DataParallel
                    new_state_dict[name] = v
                self.edm.load_state_dict(new_state_dict)
            else:
                self.edm.load_state_dict(checkpoint['state_dict'])
of dataparallel 158 | new_state_dict[name] = v 159 | self.edm.load_state_dict(new_state_dict) 160 | 161 | else: 162 | self.edm.load_state_dict(checkpoint['state_dict']) 163 | 164 | 165 | def validation(self): 166 | self.model.eval() 167 | for e in self.evaluator: 168 | e.reset() 169 | tbar = tqdm(self.val_loader, desc='\r') 170 | test_loss = 0.0 171 | 172 | for i, sample in enumerate(tbar): 173 | image, target = sample['image'], sample['label'] 174 | if self.args.cuda: 175 | image, target = image.cuda(), target.cuda() 176 | 177 | with torch.no_grad(): 178 | outputs = self.model(image) 179 | 180 | prediction = [] 181 | """ Add batch sample into evaluator """ 182 | for classifier_i in range(self.args.C): 183 | pred = torch.argmax(outputs[classifier_i], axis=1) 184 | prediction.append(pred) 185 | self.evaluator[classifier_i].add_batch(target, prediction[classifier_i]) 186 | 187 | 188 | # Add batch sample into evaluator 189 | mIoU = [] 190 | for classifier_i, e in enumerate(self.evaluator): 191 | mIoU.append(e.Mean_Intersection_over_Union()) 192 | 193 | print("classifier_1_mIoU:{}, classifier_2_mIoU: {}".format(mIoU[0], mIoU[1])) 194 | 195 | def dynamic_inference(self, threshold, confidence): 196 | self.model.eval() 197 | self.evaluator[0].reset() 198 | if confidence == 'edm': 199 | self.edm.eval() 200 | time_meter = AverageMeter() 201 | 202 | tbar = tqdm(self.val_loader, desc='\r') 203 | test_loss = 0.0 204 | total_earlier_exit = 0 205 | confidence_value_avg = 0.0 206 | for i, sample in enumerate(tbar): 207 | image, target = sample['image'], sample['label'] 208 | if self.args.cuda: 209 | image, target = image.cuda(), target.cuda() 210 | 211 | with torch.no_grad(): 212 | output, earlier_exit, tic, confidence_value = self.model.dynamic_inference(image, threshold=threshold, confidence=confidence, edm=self.edm) 213 | total_earlier_exit += earlier_exit 214 | confidence_value_avg += confidence_value 215 | time_meter.update(tic) 216 | 217 | loss = self.criterion(output, target) 218 | pred = torch.argmax(output, axis=1) 219 | 220 | # Add batch sample into evaluator 221 | self.evaluator[0].add_batch(target, pred) 222 | 223 | mIoU = self.evaluator[0].Mean_Intersection_over_Union() 224 | 225 | print('Validation:') 226 | print("mIoU: {}".format(mIoU)) 227 | print("mean_inference_time: {}".format(time_meter.average())) 228 | print("fps: {}".format(1.0/time_meter.average())) 229 | print("num_earlier_exit: {}".format(total_earlier_exit/500*100)) 230 | print("avg_confidence: {}".format(confidence_value_avg/500)) 231 | 232 | 233 | def mac(self): 234 | self.model.eval() 235 | with torch.no_grad(): 236 | flops, params = get_model_complexity_info(self.model, (3, 1025, 2049), as_strings=True, print_per_layer_stat=False) 237 | print('{:<30} {:<8}'.format('Computational complexity: ', flops)) 238 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 239 | 240 | 241 | def main(): 242 | parser = argparse.ArgumentParser(description="Eval") 243 | 244 | """ model setting """ 245 | parser.add_argument('--network', type=str, default='searched-dense', \ 246 | choices=['searched-dense', 'autodeeplab-baseline', 'autodeeplab-dense', 'autodeeplab']) 247 | parser.add_argument('--F', type=int, default=20) 248 | parser.add_argument('--B', type=int, default=5) 249 | parser.add_argument('--C', type=int, default=2, help='num of classifiers') 250 | 251 | 252 | """ dynamic inference""" 253 | parser.add_argument('--dynamic', action='store_true', default=False) 254 | parser.add_argument('--threshold', type=float, 
default=None)
255 |     parser.add_argument('--confidence', type=str, default='pool', choices=['edm', 'pool', 'entropy', 'max'])  # 'pool' added to choices: it is the default here and eval_edm.py accepts it
256 | 
257 | 
258 |     """ dataset config"""
259 |     parser.add_argument('--dataset', type=str, default='cityscapes')
260 |     parser.add_argument('--workers', type=int, default=1, metavar='N')
261 | 
262 | 
263 |     """ training config """
264 |     parser.add_argument('--sync-bn', type=bool, default=None)
265 |     parser.add_argument('--freeze-bn', type=bool, default=False)
266 | 
267 |     parser.add_argument('--batch-size', type=int, default=1, metavar='N')
268 |     parser.add_argument('--test-batch-size', type=int, default=1, metavar='N')
269 |     parser.add_argument('--use-balanced-weights', action='store_true', default=False)
270 |     parser.add_argument('--clean-module', type=int, default=0)
271 |     parser.add_argument('--dist', action='store_true', default=False)
272 | 
273 |     """ cuda, seed and logging """
274 |     parser.add_argument('--no-cuda', action='store_true', default=False)
275 |     parser.add_argument('--gpu-ids', type=str, default='0')
276 |     parser.add_argument('--seed', type=int, default=1, metavar='S')
277 | 
278 | 
279 |     """ checkpoint """
280 |     parser.add_argument('--resume', type=str, default=None)
281 |     parser.add_argument('--resume_edm', type=str, default=None)
282 |     parser.add_argument('--saved-arch-path', type=str, default='searched_arch/')
283 |     parser.add_argument('--checkname', type=str, default='testing')
284 | 
285 | 
286 |     args = parser.parse_args()
287 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
288 |     if args.cuda:
289 |         try:
290 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
291 |         except ValueError:
292 |             raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
293 | 
294 |     if args.cuda and len(args.gpu_ids) > 1:
295 |         args.sync_bn = True
296 |     else:
297 |         args.sync_bn = False
298 | 
299 |     if args.checkname is None:
300 |         args.checkname = 'evaluation'
301 |     print(args)
302 |     torch.manual_seed(args.seed)
303 |     torch.cuda.manual_seed(args.seed)
304 |     evaluation = Evaluation(args)
305 |     evaluation.mac()
306 |     if args.dynamic:
307 |         evaluation.dynamic_inference(threshold=args.threshold, confidence=args.confidence)
308 |     else:
309 |         evaluation.validation()
310 | 
311 |     evaluation.writer.close()
312 | 
313 | if __name__ == "__main__":
314 |     main()
315 | 
316 | 
--------------------------------------------------------------------------------
/eval_edm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import numpy as np
5 | import torch
6 | from mypath import Path
7 | from dataloaders import make_data_loader
8 | 
9 | from utils.loss import SegmentationLosses
10 | from utils.calculate_weights import calculate_weigths_labels
11 | from utils.lr_scheduler import LR_Scheduler
12 | from utils.saver import Saver
13 | from utils.summaries import TensorboardSummary
14 | from utils.metrics import Evaluator
15 | from utils.eval_utils import AverageMeter
16 | 
17 | from modeling.baseline_model import *
18 | from modeling.dense_model import *  # NOTE: dense_model.py is not in the repo tree; modeling/ADD.py looks like its current home
19 | from modeling.autodeeplab import *
20 | from modeling.operations import normalized_shannon_entropy
21 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
22 | from modeling.sync_batchnorm.replicate import patch_replication_callback
23 | 
24 | from tqdm import tqdm
25 | from torchviz import make_dot, make_dot_from_trace
26 | from apex import amp
27 | from ptflops import get_model_complexity_info
28 | import torch.nn as nn  # needed by the CrossEntropyLoss and CosineSimilarity below
29 | from collections import OrderedDict  # needed to strip 'module.' prefixes on resume
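# A minimal sketch (not part of the original script) of the early-exit rule that
# dynamic_inference applies when --confidence is 'entropy': answer with the shallow
# classifier whenever its prediction is already confident enough. The helper name
# `entropy_early_exit` and the comparison direction are editorial assumptions.
def entropy_early_exit(logits, threshold):
    # normalized_shannon_entropy (modeling/operations.py) returns a scalar in [0, 1];
    # values near 0 mean peaked, confident per-pixel distributions.
    return normalized_shannon_entropy(logits) < threshold


30 | class Evaluation(object):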
31 |     def __init__(self, args):
32 | 
33 |         self.args = args
34 |         self.saver = Saver(args)
35 |         self.saver.save_experiment_config()
36 |         self.summary = TensorboardSummary(self.saver.experiment_dir)
37 |         self.writer = self.summary.create_summary()
38 | 
39 |         # Define Dataloader
40 |         kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
41 |         self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)  # keep the train loader; calculate_weigths_labels below relies on it
42 | 
43 |         if args.network == 'searched-dense':
44 |             """ 40_5e_lr_38_31.91 """
45 |             cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
46 |             cell_arch = np.load(cell_path)
47 |             network_arch = [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2]
48 |             low_level_layer = 0
49 | 
50 |             model = Model_2(network_arch,
51 |                             cell_arch,
52 |                             self.nclass,
53 |                             args,
54 |                             low_level_layer)
55 | 
56 |         elif args.network == 'searched-baseline':
57 |             cell_path = os.path.join(args.saved_arch_path, 'searched_baseline', 'genotype.npy')  # NOTE: the tree ships genotype_1.npy / genotype_2.npy here, not genotype.npy
58 |             cell_arch = np.load(cell_path)
59 |             network_arch = [0, 1, 2, 2, 3, 2, 2, 1, 2, 1, 1, 2]
60 |             low_level_layer = 1
61 |             model = Model_2_baseline(network_arch,
62 |                                      cell_arch,
63 |                                      self.nclass,
64 |                                      args,
65 |                                      low_level_layer)
66 | 
67 |         elif args.network.startswith('autodeeplab'):
68 |             network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
69 |             cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
70 |             cell_arch = np.load(cell_path)
71 |             low_level_layer = 2
72 | 
73 |             if args.network == 'autodeeplab-dense':
74 |                 model = Model_2(network_arch,
75 |                                 cell_arch,
76 |                                 self.nclass,
77 |                                 args,
78 |                                 low_level_layer)
79 | 
80 |             elif args.network == 'autodeeplab-baseline':
81 |                 model = Model_2_baseline(network_arch,
82 |                                          cell_arch,
83 |                                          self.nclass,
84 |                                          args,
85 |                                          low_level_layer)
86 |             elif args.network == 'autodeeplab':
87 |                 model = AutoDeepLab(network_arch,
88 |                                     cell_arch,
89 |                                     self.nclass,
90 |                                     args,
91 |                                     low_level_layer)
92 | 
93 |         if args.use_balanced_weights:
94 |             classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset + '_classes_weights.npy')
95 |             if os.path.isfile(classes_weights_path):
96 |                 weight = np.load(classes_weights_path)
97 |             else:
98 |                 weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
99 |             weight = torch.from_numpy(weight.astype(np.float32))
100 |         else:
101 |             weight = None
102 | 
103 |         self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
104 |         self.model = model
105 |         # Define Evaluator
106 |         self.evaluator_1 = Evaluator(self.nclass)
107 |         self.evaluator_2 = Evaluator(self.nclass)
108 | 
109 |         # Using cuda
110 |         if args.cuda:
111 |             self.model = self.model.cuda()
112 |             if args.confidence == 'edm':
113 |                 self.edm = EDM()
114 |                 self.edm = self.edm.cuda()
115 |             else:
116 |                 self.edm = False
117 | 
118 |         # Resuming checkpoint
119 |         self.best_pred = 0.0
120 |         if args.resume is not None:
121 |             if not os.path.isfile(args.resume):
122 |                 raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
123 |             checkpoint = torch.load(args.resume)
124 |             args.start_epoch = checkpoint['epoch']
125 | 
126 |             # if the weights are wrapped in module object we have to clean it
127 |             if args.clean_module:
128 |                 self.model.load_state_dict(checkpoint['state_dict'])
129 |                 state_dict = checkpoint['state_dict']
130 |                 new_state_dict = OrderedDict()
131 |                 for k, v in state_dict.items():
132 |                     name = k[7:]  # remove 'module.' of dataparallel
133 |                     new_state_dict[name] = v
134 |                 self.model.load_state_dict(new_state_dict)
135 | 
136 |             else:
137 |                 self.model.load_state_dict(checkpoint['state_dict'])
138 | 
139 | 
140 |             self.best_pred = checkpoint['best_pred']
141 |             print("=> loaded checkpoint '{}' (epoch {})"
142 |                   .format(args.resume, checkpoint['epoch']))
143 |         if args.resume_edm is not None:
144 |             if not os.path.isfile(args.resume_edm):
145 |                 raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume_edm))
146 |             checkpoint = torch.load(args.resume_edm)
147 | 
148 |             # if the weights are wrapped in module object we have to clean it
149 |             if args.clean_module:
150 |                 self.edm.load_state_dict(checkpoint['state_dict'])
151 |                 state_dict = checkpoint['state_dict']
152 |                 new_state_dict = OrderedDict()
153 |                 for k, v in state_dict.items():
154 |                     name = k[7:]  # remove 'module.' of dataparallel
155 |                     new_state_dict[name] = v
156 |                 self.edm.load_state_dict(new_state_dict)
157 | 
158 |             else:
159 |                 self.edm.load_state_dict(checkpoint['state_dict'])
160 | 
161 | 
162 |     def validation(self):
163 |         self.model.eval()
164 |         self.evaluator_1.reset()
165 |         self.evaluator_2.reset()
166 |         tbar = tqdm(self.val_loader, desc='\r')
167 |         test_loss = 0.0
168 | 
169 |         for i, sample in enumerate(tbar):
170 |             image, target = sample['image'], sample['label']
171 |             if self.args.cuda:
172 |                 image, target = image.cuda(), target.cuda()
173 | 
174 |             with torch.no_grad():
175 |                 output_1, output_2 = self.model(image)
176 | 
177 |             loss_1 = self.criterion(output_1, target)
178 |             loss_2 = self.criterion(output_2, target)
179 | 
180 | 
181 |             pred_1 = torch.argmax(output_1, axis=1)
182 |             pred_2 = torch.argmax(output_2, axis=1)
183 | 
184 | 
185 |             # Add batch sample into evaluator
186 |             self.evaluator_1.add_batch(target, pred_1)
187 |             self.evaluator_2.add_batch(target, pred_2)
188 | 
189 |         mIoU_1 = self.evaluator_1.Mean_Intersection_over_Union()
190 |         mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union()
191 | 
192 |         print('Validation:')
193 |         print("mIoU_1:{}, mIoU_2: {}".format(mIoU_1, mIoU_2))
194 | 
195 | 
196 |     def testing_entropy(self, threshold, confidence):  # threshold/confidence are used below but were missing from the original signature
197 |         self.model.eval()
198 |         self.evaluator_1.reset()
199 |         self.evaluator_2.reset()
200 |         tbar = tqdm(self.val_loader, desc='\r')
201 |         test_loss = 0.0
202 |         pool_vec = np.zeros(500)
203 |         entropy_vec = np.zeros(500)
204 |         loss_vec = np.zeros(500)
205 |         for i, sample in enumerate(tbar):
206 |             image, target = sample['image'], sample['label']
207 |             if self.args.cuda:
208 |                 image, target = image.cuda(), target.cuda()
209 | 
210 |             with torch.no_grad():
211 |                 output_1, output_2, pool = self.model.dynamic_inference(image, threshold=threshold, confidence=confidence)
212 | 
213 |             loss_1 = self.criterion(output_1, target)
214 |             loss_2 = self.criterion(output_2, target)
215 | 
216 | 
217 |             pred_1 = torch.argmax(output_1, axis=1)
218 |             pred_2 = torch.argmax(output_2, axis=1)
219 | 
220 |             entropy = normalized_shannon_entropy(output_1)
221 | 
222 |             # Add batch sample into evaluator
223 |             self.evaluator_1.add_batch(target, pred_1)
224 |             self.evaluator_2.add_batch(target, pred_2)
225 | 
226 |             self.writer.add_scalar('pool/i', pool.item(), i)
227 |             self.writer.add_scalar('entropy/i', entropy, i)
228 |             self.writer.add_scalar('loss/i', loss_1.item(), i)
229 | 
230 |             pool_vec[i] = pool.item()
231 |             entropy_vec[i] = entropy
232 |             loss_vec[i] = loss_1.item()
233 | 
234 |         pool_vec = torch.from_numpy(pool_vec)
235 |         entropy_vec = torch.from_numpy(entropy_vec)
236 |         loss_vec = torch.from_numpy(loss_vec)
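        # The three 500-entry vectors above hold one value per Cityscapes val image.
        # Presumably the pairwise cosine similarities printed below are meant to show
        # how well the learned pooling confidence tracks the cheaper entropy signal
        # and the actual loss, i.e. whether 'pool' is a usable early-exit criterion.
237 | 
238 |         mIoU_1 = 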
self.evaluator_1.Mean_Intersection_over_Union() 239 | mIoU_2 = self.evaluator_2.Mean_Intersection_over_Union() 240 | 241 | cos = nn.CosineSimilarity(dim=-1) 242 | cos_sim = cos(pool_vec, entropy_vec) 243 | print("pool-entropy_cosine similarity: {}".format(cos_sim)) 244 | cos_sim = cos(pool_vec, loss_vec) 245 | print("pool-loss_cosine similarity: {}".format(cos_sim)) 246 | cos_sim = cos(entropy_vec, loss_vec) 247 | print("-entropy-loss_cosine similarity: {}".format(cos_sim)) 248 | 249 | print('Validation:') 250 | print("mIoU_1:{}, mIoU_2: {}".format(mIoU_1, mIoU_2)) 251 | 252 | def dynamic_inference(self, threshold, confidence): 253 | self.model.eval() 254 | self.evaluator_1.reset() 255 | time_meter = AverageMeter() 256 | if confidence == 'edm': 257 | self.edm.eval() 258 | tbar = tqdm(self.val_loader, desc='\r') 259 | test_loss = 0.0 260 | total_earlier_exit = 0 261 | confidence_value_avg = 0.0 262 | for i, sample in enumerate(tbar): 263 | image, target = sample['image'], sample['label'] 264 | if self.args.cuda: 265 | image, target = image.cuda(), target.cuda() 266 | 267 | with torch.no_grad(): 268 | output, earlier_exit, tic, confidence_value = \ 269 | self.model.dynamic_inference(image, threshold=threshold, confidence=confidence, edm=self.edm) 270 | total_earlier_exit += earlier_exit 271 | confidence_value_avg += confidence_value 272 | time_meter.update(tic) 273 | 274 | loss = self.criterion(output, target) 275 | pred = torch.argmax(output, axis=1) 276 | 277 | # Add batch sample into evaluator 278 | self.evaluator_1.add_batch(target, pred) 279 | tbar.set_description('earlier_exit_num: %.1f' % (total_earlier_exit)) 280 | mIoU = self.evaluator_1.Mean_Intersection_over_Union() 281 | 282 | print('Validation:') 283 | print("mIoU: {}".format(mIoU)) 284 | print("mean_inference_time: {}".format(time_meter.average())) 285 | print("fps: {}".format(1.0/time_meter.average())) 286 | print("num_earlier_exit: {}".format(total_earlier_exit/500*100)) 287 | print("avg_confidence: {}".format(confidence_value_avg/500)) 288 | 289 | def time_measure(self): 290 | time_meter_1 = AverageMeter() 291 | time_meter_2 = AverageMeter() 292 | self.model.eval() 293 | self.evaluator_1.reset() 294 | tbar = tqdm(self.val_loader, desc='\r') 295 | test_loss = 0.0 296 | 297 | for i, sample in enumerate(tbar): 298 | image, target = sample['image'], sample['label'] 299 | if self.args.cuda: 300 | image, target = image.cuda(), target.cuda() 301 | 302 | with torch.no_grad(): 303 | _, _, t1, t2 = self.model.time_measure(image) 304 | if t1 != None: 305 | time_meter_1.update(t1) 306 | time_meter_2.update(t2) 307 | if t1 != None: 308 | print(time_meter_1.average()) 309 | print(time_meter_2.average()) 310 | 311 | 312 | 313 | def mac(self): 314 | self.model.eval() 315 | with torch.no_grad(): 316 | flops, params = get_model_complexity_info(self.model, (3, 1025, 2049), as_strings=True, print_per_layer_stat=False) 317 | print('{:<30} {:<8}'.format('Computational complexity: ', flops)) 318 | print('{:<30} {:<8}'.format('Number of parameters: ', params)) 319 | 320 | 321 | def main(): 322 | parser = argparse.ArgumentParser(description="Eval") 323 | """ model setting """ 324 | parser.add_argument('--network', type=str, default='searched-dense', \ 325 | choices=['searched-dense', 'searched-baseline', 'autodeeplab-baseline', 'autodeeplab-dense', 'autodeeplab', 'supernet']) 326 | parser.add_argument('--num_model_1_layers', type=int, default=6) 327 | parser.add_argument('--F', type=int, default=20) 328 | parser.add_argument('--B', type=int, 
default=5) 329 | parser.add_argument('--use-map', type=bool, default=False) 330 | 331 | 332 | """ dynamic inference""" 333 | parser.add_argument('--threshold', type=float, default=None) 334 | parser.add_argument('--confidence', type=str, default='pool', choices=['edm', 'pool', 'entropy', 'max']) 335 | 336 | """ dataset config""" 337 | parser.add_argument('--dataset', type=str, default='cityscapes') 338 | parser.add_argument('--workers', type=int, default=1, metavar='N') 339 | 340 | 341 | """ training config """ 342 | parser.add_argument('--use-amp', type=bool, default=False) 343 | parser.add_argument('--dist', action='store_true', default=False) 344 | 345 | parser.add_argument('--sync-bn', type=bool, default=None) 346 | parser.add_argument('--freeze-bn', type=bool, default=False) 347 | 348 | parser.add_argument('--batch-size', type=int, default=1, metavar='N') 349 | parser.add_argument('--test-batch-size', type=int, default=1, metavar='N') 350 | parser.add_argument('--use-balanced-weights', action='store_true', default=False) 351 | parser.add_argument('--clean-module', type=int, default=0) 352 | 353 | 354 | """ cuda, seed and logging """ 355 | parser.add_argument('--no-cuda', action='store_true', default=False) 356 | parser.add_argument('--gpu-ids', type=str, default='0') 357 | parser.add_argument('--seed', type=int, default=1, metavar='S') 358 | 359 | 360 | """ checking point """ 361 | parser.add_argument('--resume', type=str, default=None) 362 | parser.add_argument('--resume_edm', type=str, default=None) 363 | parser.add_argument('--saved-arch-path', type=str, default='searched_arch/') 364 | parser.add_argument('--checkname', type=str, default='testing') 365 | 366 | 367 | args = parser.parse_args() 368 | args.cuda = not args.no_cuda and torch.cuda.is_available() 369 | if args.cuda: 370 | try: 371 | args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] 372 | except ValueError: 373 | raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') 374 | 375 | if args.cuda and len(args.gpu_ids) > 1: 376 | args.sync_bn = True 377 | else: 378 | args.sync_bn = False 379 | 380 | if args.checkname is None: 381 | args.checkname = 'evaluation' 382 | print(args) 383 | torch.manual_seed(args.seed) 384 | torch.cuda.manual_seed(args.seed) 385 | evaluation = Evaluation(args) 386 | # evaluation.mac() 387 | evaluation.dynamic_inference(threshold=args.threshold, confidence=args.confidence) 388 | #evaluation.validation() 389 | evaluation.writer.close() 390 | 391 | if __name__ == "__main__": 392 | main() -------------------------------------------------------------------------------- /modeling/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/modeling/.DS_Store -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/modeling/__init__.py -------------------------------------------------------------------------------- /modeling/aspp_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | from modeling.operations import 
ReLUConvBN 7 | 8 | class ASPP_train(nn.Module): 9 | def __init__(self, C, out, BatchNorm, depth=256, conv=nn.Conv2d, eps=1e-5, momentum=0.1, mult=1): 10 | super(ASPP_train, self).__init__() 11 | self._C = C 12 | self._depth = depth 13 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.relu_non_inplace = nn.ReLU() 16 | self.aspp1 = conv(C, depth, kernel_size=1, stride=1, bias=False) 17 | self.aspp2 = conv(C, depth, kernel_size=3, stride=1, 18 | dilation=int(6*mult), padding=int(6*mult), bias=False) 19 | self.aspp3 = conv(C, depth, kernel_size=3, stride=1, 20 | dilation=int(12*mult), padding=int(12*mult), bias=False) 21 | self.aspp4 = conv(C, depth, kernel_size=3, stride=1, 22 | dilation=int(18*mult), padding=int(18*mult), bias=False) 23 | self.aspp5 = conv(C, depth, kernel_size=1, stride=1, bias=False) 24 | 25 | self.conv1 = conv(depth * 5, out, kernel_size=1, stride=1, bias=False) 26 | 27 | self.bn1 = BatchNorm(out, eps=eps, momentum=momentum) 28 | self.aspp1_bn = BatchNorm(depth, eps=eps, momentum=momentum) 29 | self.aspp2_bn = BatchNorm(depth, eps=eps, momentum=momentum) 30 | self.aspp3_bn = BatchNorm(depth, eps=eps, momentum=momentum) 31 | self.aspp4_bn = BatchNorm(depth, eps=eps, momentum=momentum) 32 | self.aspp5_bn = BatchNorm(depth, eps=eps, momentum=momentum) 33 | 34 | def forward(self, x, confidence_map=None, iter_rate=1.0): 35 | batch_size, h, w = x.size(0), x.size(2), x.size(3) 36 | 37 | x = self.relu_non_inplace(x) 38 | x1 = self.aspp1(x) 39 | x1 = self.aspp1_bn(x1) 40 | x1 = self.relu(x1) 41 | x2 = self.aspp2(x) 42 | x2 = self.aspp2_bn(x2) 43 | x2 = self.relu(x2) 44 | x3 = self.aspp3(x) 45 | x3 = self.aspp3_bn(x3) 46 | x3 = self.relu(x3) 47 | x4 = self.aspp4(x) 48 | x4 = self.aspp4_bn(x4) 49 | x4 = self.relu(x4) 50 | x5 = self.global_pooling(x) 51 | x5 = self.aspp5(x5) 52 | x5 = self.aspp5_bn(x5) 53 | x5 = self.relu(x5) 54 | x5 = nn.Upsample((x.shape[2], x.shape[3]), mode='bilinear', 55 | align_corners=True)(x5) 56 | 57 | x = torch.cat((x1, x2, x3, x4, x5), 1) 58 | 59 | x = self.conv1(x) 60 | x = self.bn1(x) 61 | return x 62 | 63 | 64 | class ASPP_Lite(nn.Module): 65 | def __init__(self, in_channels, low_level_channels, mid_channels, num_classes, BatchNorm): 66 | super().__init__() 67 | self._1x1_TL = ReLUConvBN(in_channels, mid_channels, 1, 1, 0, BatchNorm) 68 | self._1x1_BL = nn.Conv2d(in_channels, mid_channels, kernel_size=1) # TODO: bias=False? 
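        # On the TODO above: the top branch (_1x1_TL, a ReLUConvBN) drops its bias
        # because the BatchNorm that follows makes it redundant, while _1x1_BL feeds
        # a sigmoid gate with no BN behind it, so keeping the default bias=True there
        # is arguably correct and the TODO can be closed as "no".
        # Forward-pass sketch (see forward below): t1 = conv(x) stays at (B, mid, H, W);
        # the gate t2 = sigmoid(conv(avgpool(x))) is upsampled back to (H, W) and
        # multiplies t1 elementwise; the two 1x1 heads are then summed. This is
        # essentially the R-ASPP / LR-ASPP head popularized by MobileNetV3.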
69 | self._1x1_TR = nn.Conv2d(mid_channels, num_classes, kernel_size=1) 70 | self._1x1_BR = nn.Conv2d(low_level_channels, num_classes, kernel_size=1) 71 | self.avgpool = torch.nn.AvgPool2d(kernel_size=49, stride=[16, 20], count_include_pad=False) 72 | 73 | def forward(self, x, low_level_feature): 74 | t1 = self._1x1_TL(x) 75 | B, C, H, W = t1.shape 76 | t2 = self.avgpool(x) 77 | t2 = self._1x1_BL(t2) 78 | t2 = torch.sigmoid(t2) 79 | t2 = F.interpolate(t2, size=(H, W), mode='bilinear', align_corners=False) 80 | t3 = t1 * t2 81 | h , w = int((float(t3.shape[2]) - 1.0) * 2 + 1.0), int((float(t3.shape[3]) - 1.0) * 2 + 1.0) 82 | t3 = F.interpolate(t3, [h, w], mode='bilinear', align_corners=False) 83 | t3 = self._1x1_TR(t3) 84 | t4 = self._1x1_BR(low_level_feature) 85 | return t3 + t4 86 | -------------------------------------------------------------------------------- /modeling/autodeeplab.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | from modeling.genotypes import PRIMITIVES 8 | from modeling.aspp_train import ASPP_train 9 | from modeling.decoder import Decoder 10 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 11 | from modeling.operations import * 12 | 13 | import time 14 | 15 | class Cell_AutoDeepLab(nn.Module): 16 | 17 | def __init__(self, 18 | BatchNorm, 19 | B, 20 | prev_prev_C, 21 | prev_C, 22 | cell_arch, 23 | network_arch, 24 | C_out, 25 | downup_sample): 26 | 27 | super(Cell_AutoDeepLab, self).__init__() 28 | eps = 1e-5 29 | momentum = 0.1 30 | 31 | self.cell_arch = cell_arch 32 | self.downup_sample = downup_sample 33 | self.B = B 34 | 35 | self.pre_preprocess = ReLUConvBN( 36 | prev_prev_C, C_out, 1, 1, 0, BatchNorm, eps=eps, momentum=momentum, affine=True) 37 | self.preprocess = ReLUConvBN( 38 | prev_C, C_out, 1, 1, 0, BatchNorm, eps=eps, momentum=momentum, affine=True) 39 | 40 | self._ops = nn.ModuleList() 41 | if downup_sample == -1: 42 | self.preprocess = FactorizedReduce(prev_C, C_out, BatchNorm, eps=eps, momentum=momentum) 43 | elif downup_sample == 1: 44 | self.scale = 2 45 | 46 | for x in self.cell_arch: 47 | primitive = PRIMITIVES[x[1]] 48 | op = OPS[primitive](C_out, 1, BatchNorm, eps=eps, momentum=momentum, affine=True) 49 | self._ops.append(op) 50 | 51 | 52 | def scale_dimension(self, dim, scale): 53 | return int((float(dim) - 1.0) * scale + 1.0) 54 | 55 | 56 | def forward(self, prev_prev_input, prev_input): 57 | s1 = prev_input 58 | if self.downup_sample == 1: 59 | feature_size_h = self.scale_dimension( 60 | s1.shape[2], self.scale) 61 | feature_size_w = self.scale_dimension( 62 | s1.shape[3], self.scale) 63 | s1 = F.interpolate( 64 | s1, [feature_size_h, feature_size_w], mode='bilinear') 65 | s1 = self.preprocess(s1) 66 | 67 | s0 = prev_prev_input 68 | del prev_prev_input 69 | 70 | s0 = F.interpolate(s0, [s1.shape[2], s1.shape[3]], mode='bilinear') \ 71 | if s0.shape[2] != s1.shape[2] else s0 72 | s0 = self.pre_preprocess(s0) 73 | 74 | states = [s0, s1] 75 | offset = 0 76 | ops_index = 0 77 | for i in range(self.B): 78 | new_states = [] 79 | for j, h in enumerate(states): 80 | branch_index = offset + j 81 | if branch_index in self.cell_arch[:, 0]: 82 | new_state = self._ops[ops_index](h) 83 | new_states.append(new_state) 84 | ops_index += 1 85 | 86 | s = sum(new_states) 87 | offset += len(states) 88 | states.append(s) 89 | 90 | concat_feature = torch.cat(states[-self.B:], dim=1) 91 | return 
prev_input, concat_feature 92 | 93 | 94 | class AutoDeepLab (nn.Module): 95 | def __init__(self, network_arch, cell_arch, num_classes, args, low_level_layer): 96 | super(AutoDeepLab, self).__init__() 97 | BatchNorm = SynchronizedBatchNorm2d if args.sync_bn == True else nn.BatchNorm2d 98 | F = args.F 99 | B = args.B 100 | self.num_model_layers = len(network_arch) 101 | self.cells = nn.ModuleList() 102 | self.model_network = network_arch 103 | self.cell_arch = torch.from_numpy(cell_arch) 104 | self.low_level_layer = low_level_layer 105 | self._num_classes = num_classes 106 | 107 | FB = F * B 108 | fm = {0: 1, 1: 2, 2: 4, 3: 8} 109 | self.stem0 = nn.Sequential( 110 | nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), 111 | BatchNorm(64), 112 | nn.ReLU(inplace=True) 113 | 114 | ) 115 | self.stem1 = nn.Sequential( 116 | nn.Conv2d(64, 64, 3, padding=1, bias=False), 117 | BatchNorm(64), 118 | ) 119 | 120 | self.stem2 = nn.Sequential( 121 | nn.ReLU(inplace=True), 122 | nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), 123 | BatchNorm(128), 124 | ) 125 | 126 | for i in range(self.num_model_layers): 127 | level = self.model_network[i] 128 | prev_level = self.model_network[i-1] 129 | prev_prev_level = self.model_network[i-2] 130 | 131 | downup_sample = int(prev_level - level) 132 | if i == 0: 133 | downup_sample = int(0 - level) 134 | pre_downup_sample = int(-1 - level) 135 | _cell = Cell_AutoDeepLab(BatchNorm, 136 | B, 137 | 64, 138 | 128, 139 | self.cell_arch, 140 | self.model_network[i], 141 | F * fm[level], 142 | downup_sample) 143 | 144 | elif i == 1: 145 | pre_downup_sample = int(0 - level) 146 | _cell = Cell_AutoDeepLab(BatchNorm, 147 | B, 148 | 128, 149 | FB * fm[prev_level], 150 | self.cell_arch, 151 | self.model_network[i], 152 | F * fm[level], 153 | downup_sample) 154 | else: 155 | _cell = Cell_AutoDeepLab(BatchNorm, 156 | B, 157 | FB * fm[prev_prev_level], 158 | FB * fm[prev_level], 159 | self.cell_arch, 160 | self.model_network[i], 161 | F * fm[level], 162 | downup_sample) 163 | 164 | self.cells += [_cell] 165 | 166 | if self.model_network[-1] == 1: 167 | mult = 2 168 | elif self.model_network[-1] == 2: 169 | mult =1 170 | 171 | self.low_level_conv = nn.Sequential( 172 | nn.ReLU(), 173 | nn.Conv2d(F * B * 2**self.model_network[low_level_layer], 48, 1, bias=False), 174 | BatchNorm(48) 175 | ) 176 | 177 | self.aspp = ASPP_train(F * B * fm[self.model_network[-1]], 178 | 256, 179 | BatchNorm, 180 | mult=mult) 181 | 182 | self.decoder = Decoder(num_classes, BatchNorm) 183 | self._init_weight() 184 | 185 | 186 | def forward(self, x, iter_rate=1.0): 187 | size = (x.shape[2], x.shape[3]) 188 | stem = self.stem0(x) 189 | stem0 = self.stem1(stem) 190 | stem1 = self.stem2(stem0) 191 | two_last_inputs = (stem0, stem1) 192 | 193 | for i in range(self.num_model_layers): 194 | two_last_inputs = self.cells[i]( 195 | two_last_inputs[0], two_last_inputs[1]) 196 | 197 | if i == self.low_level_layer: 198 | low_level = two_last_inputs[1] 199 | low_level = self.low_level_conv(low_level) 200 | y = two_last_inputs[-1] 201 | y = self.aspp(y) 202 | y = self.decoder(y, low_level, size) 203 | 204 | return None, y 205 | 206 | def time_measure(self, x): 207 | size = (x.shape[2], x.shape[3]) 208 | torch.cuda.synchronize() 209 | tic = time.perf_counter() 210 | stem = self.stem0(x) 211 | stem0 = self.stem1(stem) 212 | stem1 = self.stem2(stem0) 213 | two_last_inputs = (stem0, stem1) 214 | 215 | for i in range(self.num_model_layers): 216 | two_last_inputs = self.cells[i]( 217 | two_last_inputs[0], 
two_last_inputs[1]) 218 | if i == self.low_level_layer: 219 | low_level = two_last_inputs[1] 220 | low_level = self.low_level_conv(low_level) 221 | 222 | y = two_last_inputs[-1] 223 | y = self.aspp(y) 224 | y = self.decoder(y, low_level, size) 225 | 226 | torch.cuda.synchronize() 227 | tic_1 = time.perf_counter() 228 | 229 | return None, None, None, tic_1 - tic 230 | 231 | def _init_weight(self): 232 | for m in self.modules(): 233 | if isinstance(m, nn.Conv2d): 234 | torch.nn.init.kaiming_normal_(m.weight) 235 | elif isinstance(m, SynchronizedBatchNorm2d): 236 | m.weight.data.fill_(1) 237 | m.bias.data.zero_() 238 | elif isinstance(m, nn.BatchNorm2d): 239 | m.weight.data.fill_(1) 240 | m.bias.data.zero_() -------------------------------------------------------------------------------- /modeling/baseline_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from modeling.genotypes import PRIMITIVES 7 | from modeling.aspp_train import ASPP_train 8 | from modeling.decoder import Decoder 9 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 10 | from modeling.operations import * 11 | 12 | import time 13 | 14 | class Cell_baseline(nn.Module): 15 | 16 | def __init__(self, 17 | BatchNorm, 18 | B, 19 | prev_prev_C, 20 | prev_C, 21 | cell_arch, 22 | network_arch, 23 | C_out, 24 | downup_sample): 25 | 26 | super(Cell_baseline, self).__init__() 27 | eps = 1e-5 28 | momentum = 0.1 29 | 30 | self.cell_arch = cell_arch 31 | self.downup_sample = downup_sample 32 | self.B = B 33 | 34 | self.pre_preprocess = ReLUConvBN( 35 | prev_prev_C, C_out, 1, 1, 0, BatchNorm, eps=eps, momentum=momentum, affine=True) 36 | self.preprocess = ReLUConvBN( 37 | prev_C, C_out, 1, 1, 0, BatchNorm, eps=eps, momentum=momentum, affine=True) 38 | 39 | self._ops = nn.ModuleList() 40 | if downup_sample == -1: 41 | self.preprocess = FactorizedReduce(prev_C, C_out, BatchNorm, eps=eps, momentum=momentum) 42 | elif downup_sample == 1: 43 | self.scale = 2 44 | 45 | for x in self.cell_arch: 46 | primitive = PRIMITIVES[x[1]] 47 | op = OPS[primitive](C_out, 1, BatchNorm, eps=eps, momentum=momentum, affine=True) 48 | self._ops.append(op) 49 | 50 | 51 | def scale_dimension(self, dim, scale): 52 | return int((float(dim) - 1.0) * scale + 1.0) 53 | 54 | 55 | def forward(self, prev_prev_input, prev_input): 56 | s1 = prev_input 57 | if self.downup_sample == 1: 58 | feature_size_h = self.scale_dimension( 59 | s1.shape[2], self.scale) 60 | feature_size_w = self.scale_dimension( 61 | s1.shape[3], self.scale) 62 | s1 = F.interpolate( 63 | s1, [feature_size_h, feature_size_w], mode='bilinear') 64 | s1 = self.preprocess(s1) 65 | 66 | s0 = prev_prev_input 67 | del prev_prev_input 68 | 69 | s0 = F.interpolate(s0, [s1.shape[2], s1.shape[3]], mode='bilinear') \ 70 | if s0.shape[2] != s1.shape[2] else s0 71 | s0 = self.pre_preprocess(s0) 72 | 73 | states = [s0, s1] 74 | offset = 0 75 | ops_index = 0 76 | for i in range(self.B): 77 | new_states = [] 78 | for j, h in enumerate(states): 79 | branch_index = offset + j 80 | if branch_index in self.cell_arch[:, 0]: 81 | new_state = self._ops[ops_index](h) 82 | new_states.append(new_state) 83 | ops_index += 1 84 | 85 | s = sum(new_states) 86 | offset += len(states) 87 | states.append(s) 88 | 89 | concat_feature = torch.cat(states[-self.B:], dim=1) 90 | return prev_input, concat_feature 91 | 92 | 93 | class Baselin_Model (nn.Module): 94 | def 
__init__(self, 95 | network_arch, 96 | C_index, 97 | cell_arch, 98 | num_classes, 99 | args, 100 | low_level_layer): 101 | 102 | super(Baselin_Model, self).__init__() 103 | BatchNorm = SynchronizedBatchNorm2d if args.sync_bn == True else nn.BatchNorm2d 104 | F = args.F 105 | B = args.B 106 | 107 | eps = 1e-5 108 | momentum = 0.1 109 | 110 | self.args = args 111 | 112 | self.cells = nn.ModuleList() 113 | # self.model_2_network = network_arch[num_model_1_layers:] 114 | self.cell_arch = torch.from_numpy(cell_arch) 115 | self._num_classes = num_classes 116 | self.low_level_layer = low_level_layer 117 | # model_1_network = network_arch[:args.num_model_1_layers] 118 | self.decoder = Decoder(num_classes, BatchNorm) 119 | 120 | self.network_arch = network_arch 121 | self.num_net = len(network_arch) 122 | self.C_index = C_index 123 | 124 | 125 | FB = F * B 126 | fm = {0: 1, 1: 2, 2: 4, 3: 8} 127 | 128 | eps = 1e-5 129 | momentum = 0.1 130 | 131 | self.stem0 = nn.Sequential( 132 | nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False), 133 | BatchNorm(64, eps=eps, momentum=momentum), 134 | nn.ReLU(inplace=True) 135 | ) 136 | 137 | self.stem1 = nn.Sequential( 138 | nn.Conv2d(64, 64, 3, padding=1, bias=False), 139 | BatchNorm(64, eps=eps, momentum=momentum), 140 | ) 141 | 142 | self.stem2 = nn.Sequential( 143 | nn.ReLU(inplace=True), 144 | nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False), 145 | BatchNorm(128, eps=eps, momentum=momentum) 146 | ) 147 | 148 | for i in range(self.num_net): 149 | level = self.network_arch[i] 150 | prev_level = self.network_arch[i-1] 151 | prev_prev_level = self.network_arch[i-2] 152 | 153 | downup_sample = int(prev_level - level) 154 | if i == 0: 155 | downup_sample = int(0 - level) 156 | pre_downup_sample = int(-1 - level) 157 | _cell = Cell_baseline(BatchNorm, 158 | B, 159 | 64, 160 | 128, 161 | self.cell_arch, 162 | self.network_arch[i], 163 | F * fm[level], 164 | downup_sample) 165 | 166 | elif i == 1: 167 | pre_downup_sample = int(0 - level) 168 | _cell = Cell_baseline(BatchNorm, 169 | B, 170 | 128, 171 | FB * fm[prev_level], 172 | self.cell_arch, 173 | self.network_arch[i], 174 | F * fm[level], 175 | downup_sample) 176 | else: 177 | _cell = Cell_baseline(BatchNorm, 178 | B, 179 | FB * fm[prev_prev_level], 180 | FB * fm[prev_level], 181 | self.cell_arch, 182 | self.network_arch[i], 183 | F * fm[level], 184 | downup_sample) 185 | 186 | self.cells += [_cell] 187 | 188 | if self.network_arch[-1] == 1: 189 | mult = 2 190 | elif self.network_arch[-1] == 2: 191 | mult = 1 192 | elif self.network_arch[-1] == 3: 193 | mult = 0.5 194 | 195 | self._init_weight() 196 | self.pooling = nn.MaxPool2d(3, stride=2) 197 | self.gap = nn.AdaptiveAvgPool2d(1) 198 | self.relu = nn.ReLU() 199 | 200 | 201 | self.low_level_conv = nn.Sequential( 202 | nn.ReLU(), 203 | nn.Conv2d(F * B * 2**self.network_arch[low_level_layer], 48, 1, bias=False), 204 | BatchNorm(48, eps=eps, momentum=momentum), 205 | ) 206 | self.aspp = ASPP_train(F * B * fm[self.network_arch[-1]], 207 | 256, 208 | BatchNorm, 209 | mult=mult, 210 | ) 211 | 212 | self.conv_aspp = nn.ModuleList() 213 | for c in self.C_index: 214 | if self.network_arch[c] - self.network_arch[-1] == -1: 215 | self.conv_aspp.append(FactorizedReduce(FB*2**self.network_arch[c], FB*2**self.network_arch[-1], BatchNorm, eps=eps, momentum=momentum)) 216 | elif self.network_arch[c] - self.network_arch[-1] == -2: 217 | self.conv_aspp.append(DoubleFactorizedReduce(FB*2**self.network_arch[c], FB*2**self.network_arch[-1], BatchNorm, eps=eps, 
momentum=momentum)) 218 | elif self.network_arch[c] - self.network_arch[-1] > 0: 219 | self.conv_aspp.append(ReLUConvBN( 220 | FB*2**self.network_arch[c], FB*2**self.network_arch[-1], 1, 1, 0, BatchNorm, eps=eps, momentum=momentum, affine=True)) 221 | self._init_weight() 222 | 223 | 224 | def forward(self, x): 225 | size = (x.shape[2], x.shape[3]) 226 | aspp_size = (int((float(size[0]) - 1.0) * (2**(-1*(self.network_arch[-1]+2))) + 1.0), 227 | int((float(size[1]) - 1.0) * (2**(-1*(self.network_arch[-1]+2))) + 1.0)) 228 | conv_aspp_iter = 0 229 | 230 | stem = self.stem0(x) 231 | stem0 = self.stem1(stem) 232 | stem1 = self.stem2(stem0) 233 | two_last_inputs = [stem0, stem1] 234 | out = [] 235 | 236 | for i in range(self.num_net): 237 | 238 | two_last_inputs = self.cells[i]( 239 | two_last_inputs[0], two_last_inputs[1]) 240 | 241 | if i == self.low_level_layer: 242 | low_level = self.low_level_conv(two_last_inputs[1]) 243 | if i in self.C_index or i == self.num_net - 1: 244 | y = two_last_inputs[1] 245 | if y.shape[2] < aspp_size[0] or y.shape[3] < aspp_size[1]: 246 | y = F.interpolate(y, aspp_size, mode='bilinear') 247 | if self.network_arch[i] != self.network_arch[-1]: 248 | y = self.conv_aspp[conv_aspp_iter](y) 249 | conv_aspp_iter += 1 250 | y = self.aspp(y) 251 | y = self.decoder(y, low_level, size) 252 | out.append(y) 253 | 254 | return out 255 | 256 | def _init_weight(self): 257 | for m in self.modules(): 258 | if isinstance(m, nn.Conv2d): 259 | torch.nn.init.kaiming_normal_(m.weight) 260 | elif isinstance(m, SynchronizedBatchNorm2d): 261 | m.weight.data.fill_(1) 262 | m.bias.data.zero_() 263 | elif isinstance(m, nn.BatchNorm2d): 264 | m.weight.data.fill_(1) 265 | m.bias.data.zero_() 266 | -------------------------------------------------------------------------------- /modeling/cell_level_search.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from modeling.operations import * 6 | from modeling.genotypes import PRIMITIVES 7 | from modeling.genotypes import Genotype 8 | 9 | 10 | class MixedOp (nn.Module): 11 | 12 | def __init__(self, C, stride, BatchNorm): 13 | super(MixedOp, self).__init__() 14 | eps=1e-5 15 | momentum=0.1 16 | self._ops = nn.ModuleList() 17 | 18 | for i, primitive in enumerate(PRIMITIVES): 19 | op = OPS[primitive](C, stride, BatchNorm, eps, momentum, False) 20 | if 'pool' in primitive: 21 | op = nn.Sequential(op, BatchNorm(C, eps=eps, momentum=momentum, affine=False)) 22 | self._ops.append(op) 23 | 24 | def forward(self, x, weights, training=True): 25 | if training: 26 | return sum(w * op(x) for w, op in zip(weights, self._ops)) 27 | else: 28 | w = torch.argmax(weights) 29 | return self._ops[w](x) 30 | 31 | 32 | class Cell(nn.Module): 33 | 34 | def __init__(self, 35 | B, 36 | prev_prev_C, 37 | prev_C_down, 38 | prev_C_same, 39 | prev_C_up, 40 | C_out, 41 | BatchNorm=nn.BatchNorm2d, 42 | pre_preprocess_sample_rate=1): 43 | 44 | super(Cell, self).__init__() 45 | 46 | if prev_C_down is not None: 47 | self.preprocess_down = FactorizedReduce( 48 | prev_C_down, C_out, BatchNorm=BatchNorm, affine=False) 49 | if prev_C_same is not None: 50 | self.preprocess_same = ReLUConvBN( 51 | prev_C_same, C_out, 1, 1, 0, BatchNorm=BatchNorm, affine=False) 52 | if prev_C_up is not None: 53 | self.preprocess_up = ReLUConvBN( 54 | prev_C_up, C_out, 1, 1, 0, BatchNorm=BatchNorm, affine=False) 55 | 56 | if prev_prev_C != -1: 57 | if 
pre_preprocess_sample_rate >= 1: 58 | self.pre_preprocess = ReLUConvBN( 59 | prev_prev_C, C_out, 1, 1, 0, BatchNorm=BatchNorm, affine=False) 60 | elif pre_preprocess_sample_rate == 0.5: 61 | self.pre_preprocess = FactorizedReduce( 62 | prev_prev_C, C_out, BatchNorm=BatchNorm, affine=False) 63 | elif pre_preprocess_sample_rate == 0.25: 64 | self.pre_preprocess = DoubleFactorizedReduce( 65 | prev_prev_C, C_out, BatchNorm=BatchNorm, affine=False) 66 | 67 | self.B = B 68 | self._ops = nn.ModuleList() 69 | 70 | for i in range(self.B): 71 | for j in range(2+i): 72 | stride = 1 73 | if prev_prev_C == -1 and j == 0: 74 | op = None 75 | else: 76 | op = MixedOp(C_out, stride, BatchNorm) 77 | self._ops.append(op) 78 | 79 | 80 | def scale_dimension(self, dim, scale): 81 | assert isinstance(dim, int) 82 | return int((float(dim) - 1.0) * scale + 1.0) if dim % 2 else int(dim * scale) 83 | 84 | 85 | def prev_feature_resize(self, prev_feature, mode): 86 | if mode == 'down': 87 | feature_size_h = self.scale_dimension(prev_feature.shape[2], 0.5) 88 | feature_size_w = self.scale_dimension(prev_feature.shape[3], 0.5) 89 | elif mode == 'up': 90 | feature_size_h = self.scale_dimension(prev_feature.shape[2], 2) 91 | feature_size_w = self.scale_dimension(prev_feature.shape[3], 2) 92 | return F.interpolate(prev_feature, (feature_size_h, feature_size_w), mode='bilinear') 93 | 94 | 95 | def forward(self, s0, s1_down, s1_same, s1_up, n_alphas): 96 | if s1_down is not None: 97 | s1_down = self.preprocess_down(s1_down) 98 | size_h, size_w = s1_down.shape[2], s1_down.shape[3] 99 | if s1_same is not None: 100 | s1_same = self.preprocess_same(s1_same) 101 | size_h, size_w = s1_same.shape[2], s1_same.shape[3] 102 | if s1_up is not None: 103 | s1_up = self.prev_feature_resize(s1_up, 'up') 104 | s1_up = self.preprocess_up(s1_up) 105 | size_h, size_w = s1_up.shape[2], s1_up.shape[3] 106 | 107 | all_states = [] 108 | if s0 is not None: 109 | s0 = F.interpolate(s0, (size_h, size_w), mode='bilinear') if ( 110 | s0.shape[2] < size_h) or (s0.shape[3] < size_w) else s0 111 | s0 = self.pre_preprocess(s0) 112 | if s1_down is not None: 113 | states_down = [s0, s1_down] 114 | all_states.append(states_down) 115 | del s1_down 116 | if s1_same is not None: 117 | states_same = [s0, s1_same] 118 | all_states.append(states_same) 119 | del s1_same 120 | if s1_up is not None: 121 | states_up = [s0, s1_up] 122 | all_states.append(states_up) 123 | del s1_up 124 | else: 125 | if s1_down is not None: 126 | states_down = [0, s1_down] 127 | all_states.append(states_down) 128 | if s1_same is not None: 129 | states_same = [0, s1_same] 130 | all_states.append(states_same) 131 | if s1_up is not None: 132 | states_up = [0, s1_up] 133 | all_states.append(states_up) 134 | del s0 135 | final_concates = [] 136 | for states in all_states: 137 | offset = 0 138 | for i in range(self.B): 139 | new_states = [] 140 | for j, h in enumerate(states): 141 | branch_index = offset + j 142 | if self._ops[branch_index] is None: 143 | continue 144 | new_state = self._ops[branch_index]( 145 | h, n_alphas[branch_index]) 146 | new_states.append(new_state) 147 | 148 | s = sum(new_states) 149 | offset += len(states) 150 | states.append(s) 151 | 152 | concat_feature = torch.cat(states[-self.B:], dim=1) 153 | final_concates.append(concat_feature) 154 | 155 | return final_concates 156 | -------------------------------------------------------------------------------- /modeling/decoder.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 5 | 6 | class Decoder(nn.Module): 7 | 8 | def __init__(self, n_class, BatchNorm): 9 | super(Decoder, self).__init__() 10 | eps = 1e-5 11 | momentum = 0.1 12 | self._conv = nn.Sequential( 13 | nn.ReLU(inplace=True), 14 | nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 15 | BatchNorm(256, eps=eps, momentum=momentum), 16 | nn.ReLU(inplace=True), 17 | # 3x3 conv to refine the features 18 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 19 | BatchNorm(256, eps=eps, momentum=momentum), 20 | nn.ReLU(inplace=True), 21 | nn.Conv2d(256, n_class, kernel_size=1, stride=1)) 22 | 23 | def forward(self, x, low_level, size): 24 | x = F.interpolate(x, [low_level.shape[2], low_level.shape[3]], mode='bilinear') \ 25 | if x.shape[2] != low_level.shape[2] else x 26 | x = torch.cat((x, low_level), 1) 27 | x = self._conv(x) 28 | x = F.interpolate(x, size, mode='bilinear') 29 | 30 | return x -------------------------------------------------------------------------------- /modeling/genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'cell cell_concat') 4 | 5 | PRIMITIVES = [ 6 | 'none', 7 | 'max_pool_3x3', 8 | 'avg_pool_3x3', 9 | 'skip_connect', 10 | 'sep_conv_3x3', 11 | 'sep_conv_5x5', 12 | 'dil_conv_3x3', 13 | 'dil_conv_5x5' 14 | ] 15 | 16 | 17 | -------------------------------------------------------------------------------- /modeling/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import math 5 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d 6 | 7 | OPS = { 8 | 'none' : lambda C, stride, BatchNorm, eps, momentum, affine: Zero(stride), 9 | 'avg_pool_3x3' : lambda C, stride, BatchNorm, eps, momentum, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 10 | 'max_pool_3x3' : lambda C, stride, BatchNorm, eps, momentum, affine: nn.MaxPool2d(3, stride=stride, padding=1), 11 | 'skip_connect' : lambda C, stride, BatchNorm, eps, momentum, affine: Identity(), 12 | 'sep_conv_3x3' : lambda C, stride, BatchNorm, eps, momentum, affine: SepConv(C, C, 3, stride, 1, BatchNorm, eps=eps, momentum=momentum, affine=affine), 13 | 'sep_conv_5x5' : lambda C, stride, BatchNorm, eps, momentum, affine: SepConv(C, C, 5, stride, 2, BatchNorm, eps=eps, momentum=momentum, affine=affine), 14 | 'dil_conv_3x3' : lambda C, stride, BatchNorm, eps, momentum, affine: DilConv(C, C, 3, stride, 2, 2, BatchNorm, eps=eps, momentum=momentum, affine=affine), 15 | 'dil_conv_5x5' : lambda C, stride, BatchNorm, eps, momentum, affine: DilConv(C, C, 5, stride, 4, 2, BatchNorm, eps=eps, momentum=momentum, affine=affine), 16 | } 17 | 18 | class ReLUConvBN(nn.Module): 19 | 20 | def __init__(self, C_in, C_out, kernel_size, stride, padding, BatchNorm, eps=1e-5, momentum=0.1, affine=True): 21 | super(ReLUConvBN, self).__init__() 22 | self.op = nn.Sequential( 23 | nn.ReLU(inplace=False), 24 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 25 | BatchNorm(C_out, eps=eps, momentum=momentum, affine=affine) 26 | ) 27 | 28 | def forward(self, x): 29 | return self.op(x) 30 | 31 | 32 | class DilConv(nn.Module): 33 | 34 | def __init__(self, C_in, C_out, 
kernel_size, stride, padding, dilation, BatchNorm, eps=1e-5, momentum=0.1, affine=True): 35 | super(DilConv, self).__init__() 36 | self.op = nn.Sequential( 37 | nn.ReLU(inplace=False), 38 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False), 39 | BatchNorm(C_out, eps=eps, momentum=momentum, affine=affine) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.op(x) 44 | 45 | 46 | class SepConv(nn.Module): 47 | 48 | def __init__(self, C_in, C_out, kernel_size, stride, padding, BatchNorm, eps=1e-5, momentum=0.1, affine=True): 49 | super(SepConv, self).__init__() 50 | self.op = nn.Sequential( 51 | nn.ReLU(inplace=False), 52 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 53 | nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False), 54 | BatchNorm(C_out, eps=eps, momentum=momentum, affine=affine), 55 | nn.ReLU(inplace=False), 56 | nn.Conv2d(C_out, C_out, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 57 | nn.Conv2d(C_out, C_out, kernel_size=1, padding=0, bias=False), 58 | BatchNorm(C_out, eps=eps, momentum=momentum, affine=affine), 59 | ) 60 | 61 | def forward(self, x): 62 | return self.op(x) 63 | 64 | 65 | class Identity(nn.Module): 66 | 67 | def __init__(self): 68 | super(Identity, self).__init__() 69 | 70 | def forward(self, x): 71 | return x 72 | 73 | 74 | class Zero(nn.Module): 75 | 76 | def __init__(self, stride): 77 | super(Zero, self).__init__() 78 | self.stride = stride 79 | 80 | def forward(self, x): 81 | if self.stride == 1: 82 | return x.mul(0.) 83 | return x[:,:,::self.stride,::self.stride].mul(0.) 84 | 85 | 86 | class FactorizedReduce(nn.Module): 87 | def __init__(self, C_in, C_out, BatchNorm, eps=1e-5, momentum=0.1, affine=True): 88 | super(FactorizedReduce, self).__init__() 89 | assert C_out % 2 == 0 90 | self.relu = nn.ReLU(inplace=False) 91 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 92 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 93 | self.bn = BatchNorm(C_out, eps=eps, momentum=momentum, affine=affine) 94 | self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0) 95 | 96 | def forward(self, x): 97 | x = self.relu(x) 98 | y = self.pad(x) 99 | out = torch.cat([self.conv_1(x), self.conv_2(y[:,:,1:,1:])], dim=1) 100 | out = self.bn(out) 101 | return out 102 | 103 | 104 | class DoubleFactorizedReduce(nn.Module): 105 | def __init__(self, C_in, C_out, BatchNorm, eps=1e-5, momentum=0.1, affine=True): 106 | super(DoubleFactorizedReduce, self).__init__() 107 | assert C_out % 2 == 0 108 | self.relu = nn.ReLU(inplace=False) 109 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 110 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 111 | self.bn = BatchNorm(C_out, affine=affine) 112 | self.pad = nn.ConstantPad2d((0, 2, 0, 2), 0) 113 | 114 | def forward(self, x): 115 | x = self.relu(x) 116 | y = self.pad(x) 117 | out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 2:, 2:])], dim=1) 118 | out = self.bn(out) 119 | return out 120 | 121 | 122 | class ASPP(nn.Module): 123 | def __init__(self, in_channels, out_channels, paddings, dilations, BatchNorm=nn.BatchNorm2d, momentum=0.0003): 124 | 125 | super(ASPP, self).__init__() 126 | self.relu = nn.ReLU() 127 | self.conv11 = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, bias=False), 128 | BatchNorm(in_channels), 129 | nn.ReLU(inplace=True)) 130 | self.conv33 = 
nn.Sequential(nn.Conv2d(in_channels, in_channels, 3, 131 | padding=paddings, dilation=dilations, bias=False), 132 | BatchNorm(in_channels), 133 | nn.ReLU(inplace=True)) 134 | self.conv_p = nn.Sequential(nn.Conv2d(in_channels, in_channels, 1, bias=False), 135 | BatchNorm(in_channels), 136 | nn.ReLU(inplace=True)) 137 | 138 | self.concate_conv = nn.Sequential(nn.Conv2d(in_channels * 3, in_channels, 1, bias=False, stride=1, padding=0), 139 | BatchNorm(in_channels), 140 | nn.ReLU(inplace=True)) 141 | self.final_conv = nn.Conv2d(in_channels, out_channels, 1, bias=False, stride=1, padding=0) 142 | 143 | def forward(self, x): 144 | x = self.relu(x) 145 | conv11 = self.conv11(x) 146 | conv33 = self.conv33(x) 147 | 148 | # image pool and upsample 149 | image_pool = nn.AdaptiveAvgPool2d(1) 150 | upsample = nn.Upsample(size=x.size()[2:], mode='bilinear', align_corners=True) 151 | image_pool = image_pool(x) 152 | conv_image_pool = self.conv_p(image_pool) 153 | upsample = upsample(conv_image_pool) 154 | 155 | # concate 156 | concate = torch.cat([conv11, conv33, upsample], dim=1) 157 | concate = self.concate_conv(concate) 158 | return self.final_conv(concate) 159 | 160 | 161 | def normalized_shannon_entropy(x, num_class=19): 162 | size = (x.shape[2], x.shape[3]) 163 | x = F.softmax(x, dim=1).permute(0, 2, 3, 1) * F.log_softmax(x, dim=1).permute(0, 2, 3, 1) 164 | x = torch.sum(x, dim=3) 165 | x = x / math.log(num_class) 166 | x = -x 167 | 168 | x = x.sum() 169 | x = x / (size[0] * size[1]) 170 | return x.item() 171 | 172 | def confidence_max(x, thresold, num_class=19): 173 | x = F.softmax(x, dim=1) 174 | size = (x.shape[2], x.shape[3]) 175 | max_map = torch.max(x, 1) 176 | max_map = max_map[0] 177 | max_map = max_map[max_map > thresold] 178 | num_max = max_map.shape[0] 179 | num_max = num_max / (size[0] * size[1]) 180 | return num_max 181 | -------------------------------------------------------------------------------- /modeling/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback -------------------------------------------------------------------------------- /modeling/sync_batchnorm/batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
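# Typical wiring in this repo (a sketch; `build_model` is illustrative, the rest is
# this package's real API, re-exported from modeling/sync_batchnorm/__init__.py):
#
#   import torch
#   from modeling.sync_batchnorm import SynchronizedBatchNorm2d, patch_replication_callback
#
#   model = build_model(BatchNorm=SynchronizedBatchNorm2d)
#   model = torch.nn.DataParallel(model, device_ids=[0, 1])
#   patch_replication_callback(model)  # re-wire replicate() so BN stats sync across GPUs
#
# Without the callback every replica keeps private running statistics and the
# synchronization implemented below is never exercised under nn.DataParallel.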
10 | 11 | import collections 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | 16 | from torch.nn.modules.batchnorm import _BatchNorm 17 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 18 | 19 | from .comm import SyncMaster 20 | 21 | __all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d'] 22 | 23 | 24 | def _sum_ft(tensor): 25 | """sum over the first and last dimention""" 26 | return tensor.sum(dim=0).sum(dim=-1) 27 | 28 | 29 | def _unsqueeze_ft(tensor): 30 | """add new dementions at the front and the tail""" 31 | return tensor.unsqueeze(0).unsqueeze(-1) 32 | 33 | 34 | _ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size']) 35 | _MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std']) 36 | 37 | 38 | class _SynchronizedBatchNorm(_BatchNorm): 39 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True): 40 | super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine) 41 | 42 | self._sync_master = SyncMaster(self._data_parallel_master) 43 | 44 | self._is_parallel = False 45 | self._parallel_id = None 46 | self._slave_pipe = None 47 | 48 | def forward(self, input): 49 | # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation. 50 | if not (self._is_parallel and self.training): 51 | return F.batch_norm( 52 | input, self.running_mean, self.running_var, self.weight, self.bias, 53 | self.training, self.momentum, self.eps) 54 | 55 | # Resize the input to (B, C, -1). 56 | input_shape = input.size() 57 | input = input.view(input.size(0), self.num_features, -1) 58 | 59 | # Compute the sum and square-sum. 60 | sum_size = input.size(0) * input.size(2) 61 | input_sum = _sum_ft(input) 62 | input_ssum = _sum_ft(input ** 2) 63 | 64 | # Reduce-and-broadcast the statistics. 65 | if self._parallel_id == 0: 66 | mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size)) 67 | else: 68 | mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size)) 69 | 70 | # Compute the output. 71 | if self.affine: 72 | # MJY:: Fuse the multiplication for speed. 73 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias) 74 | else: 75 | output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std) 76 | 77 | # Reshape it. 78 | return output.view(input_shape) 79 | 80 | def __data_parallel_replicate__(self, ctx, copy_id): 81 | self._is_parallel = True 82 | self._parallel_id = copy_id 83 | 84 | # parallel_id == 0 means master device. 85 | if self._parallel_id == 0: 86 | ctx.sync_master = self._sync_master 87 | else: 88 | self._slave_pipe = ctx.sync_master.register_slave(copy_id) 89 | 90 | def _data_parallel_master(self, intermediates): 91 | """Reduce the sum and square-sum, compute the statistics, and broadcast it.""" 92 | 93 | # Always using same "device order" makes the ReduceAdd operation faster. 
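        # (Flow sketch, for orientation: each replica contributes a _ChildMessage of
        # (sum, ssum, sum_size); ReduceAddCoalesced collapses them onto the master
        # device, _compute_mean_std turns the totals into (mean, inv_std) and updates
        # the running stats, and Broadcast hands every replica its own pair back,
        # which is why the result is sliced two tensors at a time further down.)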
94 |         # Thanks to: Tete Xiao (http://tetexiao.com/)
95 |         intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
96 | 
97 |         to_reduce = [i[1][:2] for i in intermediates]
98 |         to_reduce = [j for i in to_reduce for j in i]  # flatten
99 |         target_gpus = [i[1].sum.get_device() for i in intermediates]
100 | 
101 |         sum_size = sum([i[1].sum_size for i in intermediates])
102 |         sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
103 |         mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
104 | 
105 |         broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
106 | 
107 |         outputs = []
108 |         for i, rec in enumerate(intermediates):
109 |             outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))
110 | 
111 |         return outputs
112 | 
113 |     def _compute_mean_std(self, sum_, ssum, size):
114 |         """Compute the mean and standard-deviation with sum and square-sum. This method
115 |         also maintains the moving average on the master device."""
116 |         assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
117 |         mean = sum_ / size
118 |         sumvar = ssum - sum_ * mean
119 |         unbias_var = sumvar / (size - 1)
120 |         bias_var = sumvar / size
121 | 
122 |         self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
123 |         self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
124 | 
125 |         return mean, bias_var.clamp(self.eps) ** -0.5
126 | 
127 | 
128 | class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
129 |     r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
130 |     mini-batch.
131 |     .. math::
132 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
133 |     This module differs from the built-in PyTorch BatchNorm1d as the mean and
134 |     standard-deviation are reduced across all devices during training.
135 |     For example, when one uses `nn.DataParallel` to wrap the network during
136 |     training, PyTorch's implementation normalizes the tensor on each device using
137 |     the statistics only on that device, which accelerates the computation and
138 |     is also easy to implement, but the statistics might be inaccurate.
139 |     Instead, in this synchronized version, the statistics will be computed
140 |     over all training samples distributed on multiple devices.
141 | 
142 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
143 |     as the built-in PyTorch implementation.
144 |     The mean and standard-deviation are calculated per-dimension over
145 |     the mini-batches and gamma and beta are learnable parameter vectors
146 |     of size C (where C is the input size).
147 |     During training, this layer keeps a running estimate of its computed mean
148 |     and variance. The running sum is kept with a default momentum of 0.1.
149 |     During evaluation, this running mean/variance is used for normalization.
150 |     Because the BatchNorm is done over the `C` dimension, computing statistics
151 |     on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.
152 |     Args:
153 |         num_features: num_features from an expected input of size
154 |             `batch_size x num_features [x width]`
155 |         eps: a value added to the denominator for numerical stability.
156 |             Default: 1e-5
157 |         momentum: the value used for the running_mean and running_var
158 |             computation. Default: 0.1
159 |         affine: a boolean value that when set to ``True``, gives the layer learnable
160 |             affine parameters. Default: ``True``
161 |     Shape:
162 |         - Input: :math:`(N, C)` or :math:`(N, C, L)`
163 |         - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
164 |     Examples:
165 |         >>> # With Learnable Parameters
166 |         >>> m = SynchronizedBatchNorm1d(100)
167 |         >>> # Without Learnable Parameters
168 |         >>> m = SynchronizedBatchNorm1d(100, affine=False)
169 |         >>> input = torch.randn(20, 100)
170 |         >>> output = m(input)
171 |     """
172 | 
173 |     def _check_input_dim(self, input):
174 |         if input.dim() != 2 and input.dim() != 3:
175 |             raise ValueError('expected 2D or 3D input (got {}D input)'
176 |                              .format(input.dim()))
177 |         super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
178 | 
179 | 
180 | class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
181 |     r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
182 |     of 3d inputs
183 |     .. math::
184 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
185 |     This module differs from the built-in PyTorch BatchNorm2d as the mean and
186 |     standard-deviation are reduced across all devices during training.
187 |     For example, when one uses `nn.DataParallel` to wrap the network during
188 |     training, PyTorch's implementation normalizes the tensor on each device using
189 |     the statistics only on that device, which accelerates the computation and
190 |     is also easy to implement, but the statistics might be inaccurate.
191 |     Instead, in this synchronized version, the statistics will be computed
192 |     over all training samples distributed on multiple devices.
193 | 
194 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
195 |     as the built-in PyTorch implementation.
196 |     The mean and standard-deviation are calculated per-dimension over
197 |     the mini-batches and gamma and beta are learnable parameter vectors
198 |     of size C (where C is the input size).
199 |     During training, this layer keeps a running estimate of its computed mean
200 |     and variance. The running sum is kept with a default momentum of 0.1.
201 |     During evaluation, this running mean/variance is used for normalization.
202 |     Because the BatchNorm is done over the `C` dimension, computing statistics
203 |     on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.
204 |     Args:
205 |         num_features: num_features from an expected input of
206 |             size batch_size x num_features x height x width
207 |         eps: a value added to the denominator for numerical stability.
208 |             Default: 1e-5
209 |         momentum: the value used for the running_mean and running_var
210 |             computation. Default: 0.1
211 |         affine: a boolean value that when set to ``True``, gives the layer learnable
212 |             affine parameters. Default: ``True``
213 |     Shape:
214 |         - Input: :math:`(N, C, H, W)`
215 |         - Output: :math:`(N, C, H, W)` (same shape as input)
216 |     Examples:
217 |         >>> # With Learnable Parameters
218 |         >>> m = SynchronizedBatchNorm2d(100)
219 |         >>> # Without Learnable Parameters
220 |         >>> m = SynchronizedBatchNorm2d(100, affine=False)
221 |         >>> input = torch.randn(20, 100, 35, 45)
222 |         >>> output = m(input)
223 |     """
224 | 
225 |     def _check_input_dim(self, input):
226 |         if input.dim() != 4:
227 |             raise ValueError('expected 4D input (got {}D input)'
228 |                              .format(input.dim()))
229 |         super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
230 | 
231 | 
232 | class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
233 |     r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
234 |     of 4d inputs
235 |     .. math::
236 |         y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
237 |     This module differs from the built-in PyTorch BatchNorm3d as the mean and
238 |     standard-deviation are reduced across all devices during training.
239 |     For example, when one uses `nn.DataParallel` to wrap the network during
240 |     training, PyTorch's implementation normalizes the tensor on each device using
241 |     the statistics only on that device, which accelerates the computation and
242 |     is also easy to implement, but the statistics might be inaccurate.
243 |     Instead, in this synchronized version, the statistics will be computed
244 |     over all training samples distributed on multiple devices.
245 | 
246 |     Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
247 |     as the built-in PyTorch implementation.
248 |     The mean and standard-deviation are calculated per-dimension over
249 |     the mini-batches and gamma and beta are learnable parameter vectors
250 |     of size C (where C is the input size).
251 |     During training, this layer keeps a running estimate of its computed mean
252 |     and variance. The running sum is kept with a default momentum of 0.1.
253 |     During evaluation, this running mean/variance is used for normalization.
254 |     Because the BatchNorm is done over the `C` dimension, computing statistics
255 |     on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
256 |     or Spatio-temporal BatchNorm.
257 |     Args:
258 |         num_features: num_features from an expected input of
259 |             size batch_size x num_features x depth x height x width
260 |         eps: a value added to the denominator for numerical stability.
261 |             Default: 1e-5
262 |         momentum: the value used for the running_mean and running_var
263 |             computation. Default: 0.1
264 |         affine: a boolean value that when set to ``True``, gives the layer learnable
265 |             affine parameters. Default: ``True``
266 |     Shape:
267 |         - Input: :math:`(N, C, D, H, W)`
268 |         - Output: :math:`(N, C, D, H, W)` (same shape as input)
269 |     Examples:
270 |         >>> # With Learnable Parameters
271 |         >>> m = SynchronizedBatchNorm3d(100)
272 |         >>> # Without Learnable Parameters
273 |         >>> m = SynchronizedBatchNorm3d(100, affine=False)
274 |         >>> input = torch.randn(20, 100, 35, 45, 10)
275 |         >>> output = m(input)
276 |     """
277 | 
278 |     def _check_input_dim(self, input):
279 |         if input.dim() != 5:
280 |             raise ValueError('expected 5D input (got {}D input)'
281 |                              .format(input.dim()))
282 |         super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import queue
12 | import collections
13 | import threading
14 | 
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 | 
17 | 
18 | class FutureResult(object):
19 |     """A thread-safe future implementation. Used only as one-to-one pipe."""
20 | 
21 |     def __init__(self):
22 |         self._result = None
23 |         self._lock = threading.Lock()
24 |         self._cond = threading.Condition(self._lock)
25 | 
26 |     def put(self, result):
27 |         with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 |             self._result = result
30 |             self._cond.notify()
31 | 
32 |     def get(self):
33 |         with self._lock:
34 |             if self._result is None:
35 |                 self._cond.wait()
36 | 
37 |             res = self._result
38 |             self._result = None
39 |             return res
40 | 
41 | 
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 | 
45 | 
46 | class SlavePipe(_SlavePipeBase):
47 |     """Pipe for master-slave communication."""
48 | 
49 |     def run_slave(self, msg):
50 |         self.queue.put((self.identifier, msg))
51 |         ret = self.result.get()
52 |         self.queue.put(True)
53 |         return ret
54 | 
55 | 
56 | class SyncMaster(object):
57 |     """An abstract `SyncMaster` object.
58 |     - During the replication, as data parallel will trigger a callback on each module, all slave devices should
59 |       call `register(id)` and obtain a `SlavePipe` to communicate with the master.
60 |     - During the forward pass, the master device invokes `run_master`, all messages from slave devices will be collected,
61 |       and passed to a registered callback.
62 |     - After receiving the messages, the master device should gather the information and determine the message to be
63 |       passed back to each slave device.
64 |     """
65 | 
66 |     def __init__(self, master_callback):
67 |         """
68 |         Args:
69 |             master_callback: a callback to be invoked after having collected messages from slave devices.
70 |         """
71 |         self._master_callback = master_callback
72 |         self._queue = queue.Queue()
73 |         self._registry = collections.OrderedDict()
74 |         self._activated = False
75 | 
76 |     def __getstate__(self):
77 |         return {'master_callback': self._master_callback}
78 | 
79 |     def __setstate__(self, state):
80 |         self.__init__(state['master_callback'])
81 | 
82 |     def register_slave(self, identifier):
83 |         """
84 |         Register a slave device.
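        (Editor's sketch of the round-trip, using the names defined in this file:
            pipe = sync_master.register_slave(copy_id)    # once, at replication time
            out = pipe.run_slave(msg)                     # once per forward pass)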
85 |         Args:
86 |             identifier: an identifier, usually is the device id.
87 |         Returns: a `SlavePipe` object which can be used to communicate with the master device.
88 |         """
89 |         if self._activated:
90 |             assert self._queue.empty(), 'Queue is not clean before next initialization.'
91 |             self._activated = False
92 |             self._registry.clear()
93 |         future = FutureResult()
94 |         self._registry[identifier] = _MasterRegistry(future)
95 |         return SlavePipe(identifier, self._queue, future)
96 | 
97 |     def run_master(self, master_msg):
98 |         """
99 |         Main entry for the master device in each forward pass.
100 |         The messages are first collected from each device (including the master device), and then
101 |         a callback will be invoked to compute the message to be sent back to each device
102 |         (including the master device).
103 |         Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 |             message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |         Returns: the message to be sent back to the master device.
107 |         """
108 |         self._activated = True
109 | 
110 |         intermediates = [(0, master_msg)]
111 |         for i in range(self.nr_slaves):
112 |             intermediates.append(self._queue.get())
113 | 
114 |         results = self._master_callback(intermediates)
115 |         assert results[0][0] == 0, 'The first result should belong to the master.'
116 | 
117 |         for i, res in results:
118 |             if i == 0:
119 |                 continue
120 |             self._registry[i].result.put(res)
121 | 
122 |         for i in range(self.nr_slaves):
123 |             assert self._queue.get() is True
124 | 
125 |         return results[0][1]
126 | 
127 |     @property
128 |     def nr_slaves(self):
129 |         return len(self._registry)
130 | 
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import functools
12 | 
13 | from torch.nn.parallel.data_parallel import DataParallel
14 | 
15 | __all__ = [
16 |     'CallbackContext',
17 |     'execute_replication_callbacks',
18 |     'DataParallelWithCallback',
19 |     'patch_replication_callback'
20 | ]
21 | 
22 | 
23 | class CallbackContext(object):
24 |     pass
25 | 
26 | 
27 | def execute_replication_callbacks(modules):
28 |     """
29 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
30 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
31 |     Note that, as all modules are isomorphic, we assign each sub-module with a context
32 |     (shared among multiple copies of this module on different devices).
33 |     Through this context, different copies can share some information.
34 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
35 |     of any slave copies.
36 |     """
37 |     master_copy = modules[0]
38 |     nr_modules = len(list(master_copy.modules()))
39 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
40 | 
41 |     for i, module in enumerate(modules):
42 |         for j, m in enumerate(module.modules()):
43 |             if hasattr(m, '__data_parallel_replicate__'):
44 |                 m.__data_parallel_replicate__(ctxs[j], i)
45 | 
46 | 
47 | class DataParallelWithCallback(DataParallel):
48 |     """
49 |     Data Parallel with a replication callback.
50 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
51 |     the original `replicate` function.
52 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
53 |     Examples:
54 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
55 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
56 |         # sync_bn.__data_parallel_replicate__ will be invoked.
57 |     """
58 | 
59 |     def replicate(self, module, device_ids):
60 |         modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
61 |         execute_replication_callbacks(modules)
62 |         return modules
63 | 
64 | 
65 | def patch_replication_callback(data_parallel):
66 |     """
67 |     Monkey-patch an existing `DataParallel` object. Add the replication callback.
68 |     Useful when you have customized `DataParallel` implementation.
69 |     Examples:
70 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
71 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
72 |         > patch_replication_callback(sync_bn)
73 |         # this is equivalent to
74 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
75 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
76 |     """
77 | 
78 |     assert isinstance(data_parallel, DataParallel)
79 | 
80 |     old_replicate = data_parallel.replicate
81 | 
82 |     @functools.wraps(old_replicate)
83 |     def new_replicate(module, device_ids):
84 |         modules = old_replicate(module, device_ids)
85 |         execute_replication_callbacks(modules)
86 |         return modules
87 | 
88 |     data_parallel.replicate = new_replicate
--------------------------------------------------------------------------------
/modeling/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | # 
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 | 
11 | import unittest
12 | 
13 | import numpy as np
14 | from torch.autograd import Variable
15 | 
16 | 
17 | def as_numpy(v):
18 |     if isinstance(v, Variable):
19 |         v = v.data
20 |     return v.cpu().numpy()
21 | 
22 | 
23 | class TorchTestCase(unittest.TestCase):
24 |     def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 |         npa, npb = as_numpy(a), as_numpy(b)
26 |         self.assertTrue(
27 |                 np.allclose(npa, npb, atol=atol, rtol=rtol),  # pass rtol through as well
28 |                 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 |         )
30 | 
--------------------------------------------------------------------------------
/mypath.py:
--------------------------------------------------------------------------------
1 | class Path(object):
2 |     @staticmethod
3 |     def db_root_dir(dataset):
4 |         if dataset == 'pascal':
5 |             return 'your path to pascal dataset'  # folder that contains VOCdevkit/.
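            # e.g. return '/home/user/datasets/pascal'  (hypothetical path whose child directory is VOCdevkit/)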
6 |         elif dataset == 'sbd':
7 |             return 'your path to sbd dataset'  # folder that contains dataset/.
8 |         elif dataset == 'cityscapes':
9 |             return 'your path to cityscapes dataset'  # folder that contains leftImg8bit/
10 |         else:
11 |             print('Dataset {} not available.'.format(dataset))
12 |             raise NotImplementedError
13 | 
--------------------------------------------------------------------------------
/scripts/eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python ../eval.py \
2 | --checkname evaluation \
3 | --network searched-dense \
4 | --C 2 \
5 | --F 20 \
6 | --dataset cityscapes \
7 | --workers 1 \
8 | --gpu-ids 0 \
9 | --dynamic \
10 | --confidence edm \
11 | --threshold 0.0 \
12 | --saved-arch-path ../searched_arch \
13 | --resume path_to_searched_dense_checkpoint \
14 | --resume-edm path_to_edm_checkpoint
--------------------------------------------------------------------------------
/scripts/search_cityscapes.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python ../search.py \
2 | --checkname Search \
3 | --network net_supernet \
4 | --F 20 \
5 | --C 2 \
6 | --batch-size 16 \
7 | --workers 4 \
8 | --dataset cityscapes \
9 | --alpha-epoch 150 \
10 | --epoch 300 \
11 | --lr 0.05 \
12 | --min-lr 0.003 \
13 | --arch-lr 1e-3 \
14 | --weight-decay 8e-4 \
15 | --arch-weight-decay 1e-3 \
16 | --opt-level O1 \
17 | --seed 2
18 | 
--------------------------------------------------------------------------------
/scripts/train_dist.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.launch --nproc_per_node=4 ../train.py \
2 | --checkname c2_search_dense \
3 | --network searched-dense \
4 | --C 2 \
5 | --F 20 \
6 | --dataset cityscapes \
7 | --batch-size 4 \
8 | --workers 8 \
9 | --epoch 2689 \
10 | --use-balanced-weights \
11 | --dist \
12 | --use-amp True \
13 | --opt-level O1 \
14 | --lr 0.05 \
15 | --nesterov \
16 | --gpu-ids 0,1,2,3 \
17 | --saved-arch-path ../searched_arch
--------------------------------------------------------------------------------
/scripts/train_edm.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python ../train_edm.py \
2 | --checkname edm \
3 | --network searched-dense \
4 | --F 20 \
5 | --C 2 \
6 | --batch-size 1 \
7 | --train-batch 16 \
8 | --workers 4 \
9 | --dataset cityscapes_edm \
10 | --epoch 20 \
11 | --lr 0.001 \
12 | --resume path_to_searched-dense-checkpoint
--------------------------------------------------------------------------------
/search.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | import sys
6 | import torch
7 | import torch.nn as nn
8 | from collections import OrderedDict
9 | from mypath import Path
10 | from dataloaders import make_data_loader
11 | 
12 | from utils.loss import SegmentationLosses
13 | from utils.calculate_weights import calculate_weigths_labels
14 | from utils.lr_scheduler import LR_Scheduler
15 | from utils.saver import Saver
16 | from utils.summaries import TensorboardSummary
17 | from utils.metrics import Evaluator
18 | from utils.copy_state_dict import copy_state_dict
19 | from utils.eval_utils import *
20 | 
21 | from modeling.sync_batchnorm.replicate import patch_replication_callback
22 | from modeling.model_search import Model_search
23 | from modeling.model_net_search import *
24 | from decoding.decoding_formulas import Decoder
25 | 
26 | # apex itself is only needed for mixed-precision training; import it guardedly below
27 | 
28 | 
29 | try:
30 |     from apex import amp
31 |     APEX_AVAILABLE = True
32 | except ModuleNotFoundError:
33 |     APEX_AVAILABLE = False
34 | 
35 | 
36 | print('working with pytorch version {}'.format(torch.__version__))
37 | print('with cuda version {}'.format(torch.version.cuda))
38 | print('cudnn enabled: {}'.format(torch.backends.cudnn.enabled))
39 | print('cudnn version: {}'.format(torch.backends.cudnn.version()))
40 | 
41 | torch.backends.cudnn.benchmark = True
42 | 
43 | class Trainer(object):
44 |     def __init__(self, args):
45 |         self.args = args
46 | 
47 |         """ Define Saver """
48 |         self.saver = Saver(args)
49 |         self.saver.save_experiment_config()
50 | 
51 |         """ Define Tensorboard Summary """
52 |         self.summary = TensorboardSummary(self.saver.experiment_dir)
53 |         self.writer = self.summary.create_summary()
54 |         self.use_amp = args.use_amp
55 |         self.opt_level = args.opt_level
56 | 
57 |         kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
58 |         self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)
59 | 
60 |         if args.use_balanced_weights:
61 |             classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy')
62 |             if os.path.isfile(classes_weights_path):
63 |                 weight = np.load(classes_weights_path)
64 |             else:
65 |                 """ compute the weights from the weight-training split (loader A) """
66 |                 weight = calculate_weigths_labels(args.dataset, self.train_loaderA, self.nclass)
67 |             weight = torch.from_numpy(weight.astype(np.float32))
68 |         else:
69 |             weight = None
70 | 
71 |         self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=255).cuda()
72 | 
73 |         """ Define network """
74 |         if self.args.network == 'supernet':
75 |             model = Model_search(self.nclass, 12, self.args, exit_layer=5)
76 | 
77 |         elif self.args.network == 'net_supernet':
78 |             cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
79 |             cell_arch = np.load(cell_path)
80 | 
81 |             if self.args.C == 2:
82 |                 C_index = [5]
83 |             elif self.args.C == 3:
84 |                 C_index = [3, 7]
85 |             elif self.args.C == 4:
86 |                 C_index = [2, 5, 8]
87 | 
88 |             model = Model_net_search(self.nclass, 12, self.args, C_index=C_index, alphas=cell_arch)
89 | 
90 | 
91 |         optimizer = torch.optim.SGD(
92 |                 model.weight_parameters(),
93 |                 args.lr,
94 |                 momentum=args.momentum,
95 |                 weight_decay=args.weight_decay
96 |             )
97 | 
98 |         self.model, self.optimizer = model, optimizer
99 | 
100 |         self.architect_optimizer = torch.optim.Adam(self.model.arch_parameters(),
101 |                                                     lr=args.arch_lr, betas=(0.9, 0.999),
102 |                                                     weight_decay=args.arch_weight_decay)
103 | 
104 |         """ Define Evaluator """
105 |         self.evaluator = []
106 |         for num in range(self.args.C):
107 |             self.evaluator.append(Evaluator(self.nclass))
108 | 
109 | 
110 |         """ Define lr scheduler """
111 |         self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr,
112 |                                       args.epochs, len(self.train_loaderA), min_lr=args.min_lr)
113 |         """ Using cuda """
114 |         if args.cuda:
115 |             self.model = self.model.cuda()
116 | 
117 |         """ mixed precision """
118 |         if self.use_amp and args.cuda:
119 |             keep_batchnorm_fp32 = True if (self.opt_level == 'O2' or self.opt_level == 'O3') else None
120 | 
121 |             """ fix for current pytorch version with opt_level 'O1' """
122 |             if self.opt_level == 'O1' and torch.__version__ < '1.3':
123 |                 for module in self.model.modules():
124 |                     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
125 |                         """ Hack to fix BN fprop without affine transformation """
126 |                         if module.weight is None:
127 |                             module.weight = torch.nn.Parameter(
128 |                                 torch.ones(module.running_var.shape, dtype=module.running_var.dtype,
129 |                                            device=module.running_var.device), requires_grad=False)
130 |                         if module.bias is None:
131 |                             module.bias = torch.nn.Parameter(
132 |                                 torch.zeros(module.running_var.shape, dtype=module.running_var.dtype,
133 |                                             device=module.running_var.device), requires_grad=False)
134 | 
135 |             # print(keep_batchnorm_fp32)
136 |             self.model, [self.optimizer, self.architect_optimizer] = amp.initialize(
137 |                 self.model, [self.optimizer, self.architect_optimizer], opt_level=self.opt_level,
138 |                 keep_batchnorm_fp32=keep_batchnorm_fp32, loss_scale="dynamic")
139 | 
140 |             print('cuda finished')
141 | 
142 | 
143 |         """ Using data parallel """
144 |         if args.cuda and len(self.args.gpu_ids) > 1:
145 |             if self.opt_level == 'O2' or self.opt_level == 'O3':
146 |                 print('currently cannot run with nn.DataParallel and optimization level', self.opt_level)
147 |             self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
148 |             patch_replication_callback(self.model)
149 |             print('training on multiple-GPUs')
150 | 
151 |         """ Resuming checkpoint """
152 |         self.best_pred = 0.0
153 |         if args.resume is not None:
154 |             if not os.path.isfile(args.resume):
155 |                 raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
156 |             checkpoint = torch.load(args.resume)
157 |             args.start_epoch = checkpoint['epoch']
158 | 
159 |             """ if the weights are wrapped in module object we have to clean it """
160 |             if args.clean_module:
161 |                 # (loading the raw 'module.'-prefixed state dict directly would fail, so only the cleaned copy below is loaded)
162 |                 state_dict = checkpoint['state_dict']
163 |                 new_state_dict = OrderedDict()
164 |                 for k, v in state_dict.items():
165 |                     name = k[7:]  # remove 'module.' of dataparallel
166 |                     new_state_dict[name] = v
167 |                 copy_state_dict(self.model.state_dict(), new_state_dict)
168 | 
169 |             else:
170 |                 if (torch.cuda.device_count() > 1):
171 |                     copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
172 |                 else:
173 |                     copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])
174 | 
175 | 
176 |     def training(self, epoch):
177 |         train_loss = 0.0
178 |         search_loss = 0.0
179 |         self.model.train()
180 |         tbar = tqdm(self.train_loaderA)
181 |         num_img_tr = len(self.train_loaderA)
182 |         for i, sample in enumerate(tbar):
183 |             image, target = sample['image'], sample['label']
184 |             if self.args.cuda:
185 |                 image, target = image.cuda(), target.cuda()
186 |             self.scheduler(self.optimizer, i, epoch, self.best_pred)
187 |             self.optimizer.zero_grad()
188 |             outputs = self.model(image)
189 | 
190 |             loss = []
191 |             for classifier_i in range(self.args.C):
192 |                 loss.append(self.criterion(outputs[classifier_i], target))
193 | 
194 |             loss = sum(loss) / self.args.C
195 | 
196 |             if self.use_amp:
197 |                 with amp.scale_loss(loss, self.optimizer) as scaled_loss:
198 |                     scaled_loss.backward()
199 |             else:
200 |                 loss.backward()
201 |             self.optimizer.step()
202 | 
203 |             if epoch >= self.args.alpha_epoch:
204 |                 search = next(iter(self.train_loaderB))  # Python 3: use next(); this re-creates the iterator and draws one fresh batch
205 |                 image_search, target_search = search['image'], search['label']
206 |                 if self.args.cuda:
207 |                     image_search, target_search = image_search.cuda(), target_search.cuda()
208 | 
209 |                 self.architect_optimizer.zero_grad()
210 |                 outputs_search = self.model(image_search)
211 | 
212 |                 arch_loss = []
213 |                 for classifier_i in range(self.args.C):
214 |                     arch_loss.append(self.criterion(outputs_search[classifier_i], target_search))
215 | 
216 |                 arch_loss = sum(arch_loss) / self.args.C
217 | 
218 |                 if self.use_amp:
219 |                     with amp.scale_loss(arch_loss, self.architect_optimizer) as arch_scaled_loss:
220 |                         arch_scaled_loss.backward()
221 |                 else:
222 |                     arch_loss.backward()
223 | 
224 |                 self.architect_optimizer.step()
225 |                 search_loss += arch_loss.item()
226 | 
227 |             train_loss += loss.item()
228 |             tbar.set_description('Train loss: %.3f --Search loss: %.3f'
229 |                                  % (train_loss / (i + 1), search_loss / (i + 1)))
230 | 
231 |         self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
232 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0]))
233 |         print('Loss: %.3f' % train_loss)
234 | 
235 | 
236 |     def validation(self, epoch):
237 |         self.model.eval()
238 |         for e in self.evaluator:
239 |             e.reset()
240 | 
241 |         tbar = tqdm(self.val_loader, desc='\r')
242 |         test_loss = 0.0
243 |         for i, sample in enumerate(tbar):
244 |             image, target = sample['image'], sample['label']
245 |             if self.args.cuda:
246 |                 image, target = image.cuda(), target.cuda()
247 |             with torch.no_grad():
248 |                 outputs = self.model(image)
249 | 
250 |             loss = []
251 |             for classifier_i in range(self.args.C):
252 |                 loss.append(self.criterion(outputs[classifier_i], target))
253 | 
254 |             loss = sum(loss) / self.args.C
255 |             test_loss += loss.item()
256 | 
257 |             tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))
258 | 
259 |             """ Add batch sample into evaluator """
260 |             for classifier_i in range(self.args.C):
261 |                 outputs[classifier_i] = torch.argmax(outputs[classifier_i], dim=1)
262 |                 self.evaluator[classifier_i].add_batch(target, outputs[classifier_i])
263 | 
264 | 
265 |         mIoU = []
266 |         for classifier_i, e in enumerate(self.evaluator):
267 |             mIoU.append(e.Mean_Intersection_over_Union())
268 |             self.writer.add_scalar('val/classifier_' + str(classifier_i) + '/mIoU', mIoU[classifier_i], epoch)
269 | 
270 |         """ FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() """
271 |         self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
272 | 
273 | 
274 |         print('Validation:')
275 |         print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.test_batch_size + image.data.shape[0]))
276 |         print('Loss: %.3f' % test_loss)
277 |         new_pred = sum(mIoU) / self.args.C
278 |         if new_pred > self.best_pred:
279 |             is_best = True
280 |             self.best_pred = new_pred
281 |             if torch.cuda.device_count() > 1:
282 |                 state_dict = self.model.module.state_dict()
283 |             else:
284 |                 state_dict = self.model.state_dict()
285 |             self.saver.save_checkpoint({
286 |                 'epoch': epoch + 1,
287 |                 'state_dict': state_dict,
288 |                 'optimizer': self.optimizer.state_dict(),
289 |                 'best_pred': self.best_pred,
290 |             }, is_best)
291 | 
292 |         """ decode the arch """
293 |         self.decoder_save(epoch, miou=new_pred, evaluation=True)
294 | 
295 | 
296 |     def decoder_save(self, epoch, miou=None, evaluation=False):
297 |         num = str(epoch)
298 |         if evaluation:
299 |             num = num + '_eval'
300 |         # create the output folder for this epoch's decoded architecture
301 |         dir_name = os.path.join(self.saver.experiment_dir, num)
302 |         os.makedirs(dir_name, exist_ok=True)
303 | 
304 | 
305 | 
306 |         decoder = Decoder(None,
307 |                           self.model.betas,
308 |                           self.args.B)
309 | 
310 |         result_paths, result_paths_space = decoder.viterbi_decode()
311 | 
312 |         betas = self.model.betas.data.cpu().numpy()
313 | 
314 |         network_path_filename = os.path.join(dir_name, 'network_path')
315 |         beta_filename = os.path.join(dir_name, 'betas')
316 | 
317 |         np.save(network_path_filename, result_paths)
318 |         np.save(beta_filename, betas)
319 | 
320 |         if miou is not None:
321 |             with open(os.path.join(dir_name, 'miou.txt'), 'w') as f:
322 |                 f.write(str(miou))
323 |         if evaluation:
324 |             self.writer.add_text('network_path', str(result_paths), epoch)
325 |             self.writer.add_text('miou', str(miou), epoch)
326 |         else:
327 |             self.writer.add_text('network_path', str(result_paths), epoch)
328 | 
329 | 
330 | def main():
331 |     parser = argparse.ArgumentParser(description="The Search")
332 | 
333 |     """ Search Network """
334 |     parser.add_argument('--network', type=str, default='supernet', choices=['supernet', 'net_supernet'])
335 |     parser.add_argument('--F', type=int, default=8)
336 |     parser.add_argument('--B', type=int, default=5)
337 |     parser.add_argument('--C', type=int, default=2, help='num of classifiers')
338 | 
339 | 
340 | 
341 |     """ Training Setting """
342 |     parser.add_argument('--start-epoch', type=int, default=0, metavar='N', help='start epochs (default: 0)')
343 |     parser.add_argument('--epochs', type=int, default=40, metavar='N', help='number of epochs to train (default: 40)')
344 |     parser.add_argument('--alpha-epoch', type=int, default=20, metavar='N', help='epoch to start training alphas')
345 |     parser.add_argument('--sync-bn', type=bool, default=None, help='whether to use sync bn (default: auto)')
346 |     parser.add_argument('--clean-module', type=int, default=0)
347 | 
348 | 
349 |     """ Dataset Setting """
350 |     parser.add_argument('--dataset', type=str, default='cityscapes', choices=['pascal', 'coco', 'cityscapes', 'kd'])
351 |     parser.add_argument('--use-sbd', action='store_true', default=False, help='whether to use SBD dataset (default: False)')
352 |     parser.add_argument('--load-parallel', type=int, default=0)
353 |     parser.add_argument('--workers', type=int, default=2, metavar='N', help='dataloader threads')
354 |     parser.add_argument('--batch-size', type=int, default=2, metavar='N')
355 |     parser.add_argument('--test-batch-size', type=int, default=1, metavar='N')
356 |     parser.add_argument('--use-balanced-weights', action='store_true', default=False, help='whether to use balanced weights (default: False)')
357 | 
358 | 
359 |     """ optimizer params """
360 |     parser.add_argument('--lr', type=float, default=0.025, metavar='LR')
361 |     parser.add_argument('--min-lr', type=float, default=0.001)
362 |     parser.add_argument('--arch-lr', type=float, default=3e-3, metavar='LR', help='learning rate for alpha and beta in architect searching process')
363 |     parser.add_argument('--lr-scheduler', type=str, default='cos', choices=['poly', 'step', 'cos'])
364 |     parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)')
365 |     parser.add_argument('--weight-decay', type=float, default=5e-4, metavar='M', help='w-decay (default: 5e-4)')
366 |     parser.add_argument('--arch-weight-decay', type=float, default=1e-3, metavar='M', help='w-decay (default: 1e-3)')
367 |     parser.add_argument('--nesterov', action='store_true', default=False, help='whether use nesterov (default: False)')
368 |     parser.add_argument('--use-amp', action='store_true', default=True)
369 |     parser.add_argument('--opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2', 'O3'], help='opt level for half precision training (default: O0)')
370 | 
371 | 
372 |     """ cuda, seed and logging """
373 |     parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
374 |     parser.add_argument('--gpu-ids', type=str, default='0', help='use which gpu to train, must be a comma-separated list of integers only (default=0)')
375 |     parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
376 | 
377 | 
378 |     """ checkpointing """
379 |     parser.add_argument('--resume', type=str, default=None, help='put the path to resuming file if needed')
380 |     parser.add_argument('--checkname', type=str, default=None, help='set the checkpoint name')
381 |     parser.add_argument('--saved-arch-path', type=str, default='../searched_arch/')
382 | 
383 | 
384 |     """ evaluation option """
385 |     parser.add_argument('--eval-interval', type=int, default=10, help='evaluation interval (default: 10)')
386 |     parser.add_argument('--no-val', action='store_true', default=False, help='skip validation during training')
387 | 
388 | 
389 |     args = parser.parse_args()
390 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
391 |     if args.cuda:
392 |         try:
393 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
394 |         except ValueError:
395 |             raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
396 | 
397 |     if args.cuda and len(args.gpu_ids) > 1:
398 |         args.sync_bn = True
399 |     else:
400 |         args.sync_bn = False
401 | 
402 |     if args.test_batch_size is None:
403 |         args.test_batch_size = 1
404 | 
405 | 
406 |     if args.checkname is None:
407 |         args.checkname = 'deeplab-' + str(args.network)
408 |     print(args)
409 |     torch.manual_seed(args.seed)
410 |     trainer = Trainer(args)
411 |     print('Starting Epoch:', trainer.args.start_epoch)
412 |     print('Total Epochs:', trainer.args.epochs)
413 |     for epoch in range(trainer.args.start_epoch, trainer.args.epochs):
414 |         trainer.training(epoch)
415 |         if not trainer.args.no_val and epoch % args.eval_interval == (args.eval_interval - 1):
416 |             trainer.validation(epoch)
417 | 
418 |     trainer.writer.close()
419 | 
420 | if __name__ == "__main__":
421 |     main()
422 | 
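Editor's note: the loop in Trainer.training above realizes a first-order bilevel search: model weights are updated on data split A at every iteration, while the architecture parameters are updated on a held-out split B only once epoch >= alpha_epoch, each with its own optimizer. The following is a distilled, self-contained sketch of that alternating scheme on a toy problem; every name in it is a hypothetical stand-in, not one of the repository's classes.

import torch
import torch.nn as nn

torch.manual_seed(0)

# toy stand-ins: `net` plays the role of the supernet's weight parameters,
# `alphas` the role of its architecture parameters
net = nn.Linear(8, 3)
alphas = torch.zeros(4, requires_grad=True)

def forward(x):
    # a single soft "architecture gate" standing in for cell/path mixing
    return net(x) * torch.sigmoid(alphas).mean()

w_opt = torch.optim.SGD(net.parameters(), lr=0.05, momentum=0.9, weight_decay=5e-4)
a_opt = torch.optim.Adam([alphas], lr=3e-3, betas=(0.9, 0.999), weight_decay=1e-3)
criterion = nn.CrossEntropyLoss()

# two disjoint data splits, mirroring train_loaderA / train_loaderB
loader_a = [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(8)]
loader_b = [(torch.randn(4, 8), torch.randint(0, 3, (4,))) for _ in range(8)]

alpha_epoch, epochs = 1, 3
for epoch in range(epochs):
    for (xa, ya), (xb, yb) in zip(loader_a, loader_b):
        # 1) weight step on split A (only w_opt touches net's parameters)
        w_opt.zero_grad()
        criterion(forward(xa), ya).backward()
        w_opt.step()
        # 2) architecture step on split B, enabled after a warm-up period
        if epoch >= alpha_epoch:
            a_opt.zero_grad()
            criterion(forward(xb), yb).backward()
            a_opt.step()
print('final gate:', torch.sigmoid(alphas).mean().item())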
--------------------------------------------------------------------------------
/searched_arch/40_5e_38_lr/genotype_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/40_5e_38_lr/genotype_1.npy
--------------------------------------------------------------------------------
/searched_arch/40_5e_38_lr/genotype_2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/40_5e_38_lr/genotype_2.npy
--------------------------------------------------------------------------------
/searched_arch/autodeeplab/genotype.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/autodeeplab/genotype.npy
--------------------------------------------------------------------------------
/searched_arch/searched_baseline/genotype_1.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/searched_baseline/genotype_1.npy
--------------------------------------------------------------------------------
/searched_arch/searched_baseline/genotype_2.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/searched_baseline/genotype_2.npy
--------------------------------------------------------------------------------
/searched_arch/searched_baseline/network_path.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/searched_baseline/network_path.npy
--------------------------------------------------------------------------------
/searched_arch/searched_baseline/network_path_space.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HankKung/Auto-Dynamic-DeepLab/4150a19d632269f7ebcb63e92906a7f40e6a283b/searched_arch/searched_baseline/network_path_space.npy
--------------------------------------------------------------------------------
/train_edm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | from tqdm import tqdm
5 | import torch.nn as nn
6 | from torchviz import make_dot, make_dot_from_trace
7 | import torch  # explicit import: `import torch.nn` alone does not bind the name `torch`
8 | from mypath import Path
9 | from dataloaders import make_data_loader
10 | from collections import OrderedDict  # used below when cleaning 'module.'-prefixed checkpoints
11 | from utils.loss import SegmentationLosses
12 | from utils.calculate_weights import calculate_weigths_labels
13 | from utils.lr_scheduler import LR_Scheduler
14 | from utils.saver import Saver
15 | from utils.summaries import TensorboardSummary
16 | from utils.metrics import Evaluator
17 | from utils.copy_state_dict import copy_state_dict
18 | from utils.eval_utils import AverageMeter
19 | # from utils.encoding import *
20 | 
21 | from modeling.baseline_model import *
22 | # from modeling.ADD import *
23 | from modeling.ADD import *
24 | from modeling.operations import normalized_shannon_entropy
25 | from modeling.sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
26 | from modeling.sync_batchnorm.replicate import patch_replication_callback
27 | 
28 | from apex import amp
29 | from ptflops import get_model_complexity_info
30 | from torch.utils.data import TensorDataset, DataLoader
31 | 
32 | 
33 | torch.backends.cudnn.benchmark = True
34 | 
35 | 
36 | class trainNew(object):
37 |     def __init__(self, args):
38 |         self.args = args
39 | 
40 |         """ Define Saver """
41 |         self.saver = Saver(args)
42 |         self.saver.save_experiment_config()
43 | 
44 |         """ Define Tensorboard Summary """
45 |         self.summary = TensorboardSummary(self.saver.experiment_dir)
46 |         self.writer = self.summary.create_summary()
47 | 
48 | 
49 |         """ Define Dataloader """
50 |         kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last': True}
51 |         self.train_loader, self.val_loader, _, self.nclass = make_data_loader(args, **kwargs)
52 | 
53 | 
54 |         self.criterion = nn.L1Loss()
55 |         if args.network == 'searched-dense':
56 |             cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
57 |             cell_arch = np.load(cell_path)
58 | 
59 |             if self.args.C == 2:
60 |                 C_index = [5]
61 |                 network_arch = [1, 2, 2, 2, 3, 2, 2, 1, 1, 1, 1, 2]
62 |                 low_level_layer = 0
63 |             elif self.args.C == 3:
64 |                 C_index = [3, 7]
65 |                 network_arch = [1, 2, 3, 2, 2, 3, 2, 3, 2, 3, 2, 3]
66 |                 low_level_layer = 0
67 |             elif self.args.C == 4:
68 |                 C_index = [2, 5, 8]
69 |                 network_arch = [1, 2, 3, 3, 2, 3, 3, 3, 3, 3, 2, 2]
70 |                 low_level_layer = 0
71 | 
72 |             model = ADD(network_arch,
73 |                         C_index,
74 |                         cell_arch,
75 |                         self.nclass,
76 |                         args,
77 |                         low_level_layer)
78 | 
79 |         elif args.network.startswith('autodeeplab'):
80 |             network_arch = [0, 0, 0, 1, 2, 1, 2, 2, 3, 3, 2, 1]
81 |             cell_path = os.path.join(args.saved_arch_path, 'autodeeplab', 'genotype.npy')
82 |             cell_arch = np.load(cell_path)
83 |             low_level_layer = 2
84 |             if self.args.C == 2:
85 |                 C_index = [5]
86 |             elif self.args.C == 3:
87 |                 C_index = [3, 7]
88 |             elif self.args.C == 4:
89 |                 C_index = [2, 5, 8]
90 | 
91 |             if args.network == 'autodeeplab-dense':
92 |                 model = ADD(network_arch,
93 |                             C_index,
94 |                             cell_arch,
95 |                             self.nclass,
96 |                             args,
97 |                             low_level_layer)
98 | 
99 |             elif args.network == 'autodeeplab-baseline':
100 |                 model = Baselin_Model(network_arch,
101 |                                       C_index,
102 |                                       cell_arch,
103 |                                       self.nclass,
104 |                                       args,
105 |                                       low_level_layer)
106 | 
107 |         self.edm = EDM().cuda()
108 |         optimizer = torch.optim.Adam(self.edm.parameters(), lr=args.lr)
109 |         self.model, self.optimizer = model, optimizer
110 | 
111 |         if args.cuda:
112 |             self.model = self.model.cuda()
113 | 
114 |         """ Resuming checkpoint """
115 |         if args.resume is not None:
116 |             if not os.path.isfile(args.resume):
117 |                 raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
118 |             checkpoint = torch.load(args.resume)
119 |             args.start_epoch = checkpoint['epoch']
120 | 
121 |             """ if the weights are wrapped in module object we have to clean it """
122 |             if args.clean_module:
123 |                 # (loading the raw 'module.'-prefixed state dict directly would fail, so only the cleaned copy below is loaded)
124 |                 state_dict = checkpoint['state_dict']
125 |                 new_state_dict = OrderedDict()
126 |                 for k, v in state_dict.items():
127 |                     name = k[7:]  # remove 'module.' of dataparallel
128 |                     new_state_dict[name] = v
129 |                 copy_state_dict(self.model.state_dict(), new_state_dict)
130 | 
131 |             else:
132 |                 if (torch.cuda.device_count() > 1):
133 |                     copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict'])
134 |                 else:
135 |                     copy_state_dict(self.model.state_dict(), checkpoint['state_dict'])
136 | 
137 |         if os.path.isfile('feature.npy'):
138 |             train_feature = np.load('feature.npy')
139 |             train_entropy = np.load('entropy.npy')
140 |             train_set = TensorDataset(torch.tensor(train_feature), torch.tensor(train_entropy, dtype=torch.float))
141 |             train_set = DataLoader(train_set, batch_size=self.args.train_batch, shuffle=True, pin_memory=True)
142 |             self.train_set = train_set
143 |         else:
144 |             self.make_data(self.args.train_batch)
145 | 
146 |     def make_data(self, batch_size):
147 |         self.model.eval()
148 |         tbar = tqdm(self.train_loader, desc='\r')
149 |         train_feature = []
150 |         train_entropy = []
151 |         for i, sample in enumerate(tbar):
152 |             image, target = sample['image'], sample['label']
153 |             if self.args.cuda:
154 |                 image, target = image.cuda(), target.cuda()
155 | 
156 |             with torch.no_grad():
157 |                 output, feature = self.model.get_feature(image)
158 |             train_entropy.append(normalized_shannon_entropy(output))
159 |             train_feature.append(feature.cpu())
160 | 
161 |         train_feature = [t.numpy() for t in train_feature]
162 |         np_entropy = np.array(train_entropy)
163 |         np.save('feature', train_feature)
164 |         np.save('entropy', np_entropy)
165 |         train_set = TensorDataset(torch.tensor(train_feature, dtype=torch.float), torch.tensor(train_entropy, dtype=torch.float))
166 |         train_set = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)
167 |         self.train_set = train_set
168 | 
169 |     def training(self, epoch):
170 |         train_loss = 0.0
171 |         self.edm.train()
172 |         tbar = tqdm(self.train_set)
173 |         for i, (feature, entropy) in enumerate(tbar):
174 |             if self.args.cuda:
175 |                 feature, entropy = feature.cuda(), entropy.cuda()
176 |             output = self.edm(feature)
177 |             loss = self.criterion(output, entropy)
178 |             self.optimizer.zero_grad()
179 |             loss.backward()
180 |             self.optimizer.step()
181 |             train_loss += loss.item()
182 |             tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
183 |         self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
184 |         print('[Epoch: %d]' % epoch)
185 |         print('Loss: %.3f' % train_loss)
186 | 
187 | 
188 | def main():
189 |     parser = argparse.ArgumentParser(description="Train EDM")
190 | 
191 |     """ model setting """
192 |     parser.add_argument('--network', type=str, default='searched-dense', \
193 |                         choices=['searched-dense', 'autodeeplab-baseline', 'autodeeplab-dense'])
194 |     parser.add_argument('--F', type=int, default=20)
195 |     parser.add_argument('--B', type=int, default=5)
196 |     parser.add_argument('--C', type=int, default=2, help='num of classifiers')
197 | 
198 | 
199 |     """ dataset config """
200 |     parser.add_argument('--dataset', type=str, default='cityscapes', choices=['cityscapes', 'cityscapes_edm'], help='dataset name (default: cityscapes)')
201 |     parser.add_argument('--workers', type=int, default=4, metavar='N', help='dataloader threads')
202 | 
203 | 
204 |     """ training config """
205 |     parser.add_argument('--epochs', type=int, default=10, metavar='N')
206 |     parser.add_argument('--start_epoch', type=int, default=0)
207 |     parser.add_argument('--batch-size', type=int, default=1, metavar='N')
208 |     parser.add_argument('--test-batch-size', type=int, default=1, metavar='N')
209 |     parser.add_argument('--train-batch', type=int, default=16, metavar='N')
210 |     parser.add_argument('--dist', action='store_true', default=False)
211 | 
212 | 
213 |     """ optimizer params """
214 |     parser.add_argument('--lr', type=float, default=0.001, metavar='LR')
215 |     parser.add_argument('--clean-module', type=int, default=0)
216 |     parser.add_argument('--sync-bn', type=bool, default=False, help='whether to use sync bn (default: False)')
217 | 
218 | 
219 |     """ cuda, seed and logging """
220 |     parser.add_argument('--no-cuda', action='store_true', default=False)
221 |     parser.add_argument('--gpu-ids', type=str, default='0', help='use which gpu to train, must be a comma-separated list of integers only (default=0)')
222 |     parser.add_argument('--seed', type=int, default=1, metavar='S')
223 | 
224 | 
225 |     """ checkpointing """
226 |     parser.add_argument('--resume', type=str, default=None, help='put the path to resuming file if needed')
227 |     parser.add_argument('--saved-arch-path', type=str, default='searched_arch/')
228 |     parser.add_argument('--checkname', type=str, default='edm')
229 | 
230 |     args = parser.parse_args()
231 |     args.cuda = not args.no_cuda and torch.cuda.is_available()
232 |     if args.cuda:
233 |         try:
234 |             args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')]
235 |         except ValueError:
236 |             raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only')
237 | 
238 | 
239 |     if args.checkname is None:
240 |         args.checkname = 'deeplab-' + str(args.network)
241 | 
242 |     print(args)
243 |     torch.manual_seed(args.seed)
244 |     torch.cuda.manual_seed(args.seed)
245 | 
246 |     new_trainer = trainNew(args)
247 |     # new_trainer.mac()
248 | 
249 |     # new_trainer.make_data(args.train_batch)
250 |     print('start training')
251 |     for epoch in range(args.epochs):
252 |         new_trainer.training(epoch)
253 |     new_trainer.saver.save_checkpoint({
254 |                     'epoch': args.epochs,
255 |                     'state_dict': new_trainer.edm.state_dict(),
256 |                     'optimizer': new_trainer.optimizer.state_dict(),
257 |                     'best_pred': 1},
258 |                     True)
259 |     new_trainer.writer.close()
260 | 
261 | if __name__ == "__main__":
262 |     main()
263 | 
--------------------------------------------------------------------------------
/utils/calculate_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | from tqdm import tqdm
3 | import numpy as np
4 | from mypath import Path
5 | 
6 | def calculate_weigths_labels(dataset, dataloader, num_classes):
7 |     # Accumulate a per-class pixel histogram over the whole dataset
8 |     z = np.zeros((num_classes,))
9 |     # Initialize tqdm
10 |     tqdm_batch = tqdm(dataloader)
11 |     print('Calculating classes weights')
12 |     for sample in tqdm_batch:
13 |         y = sample['label']
14 |         y = y.detach().cpu().numpy()
15 |         mask = (y >= 0) & (y < num_classes)
16 |         labels = y[mask].astype(np.uint8)
17 |         count_l = np.bincount(labels, minlength=num_classes)
18 |         z += count_l
19 |     tqdm_batch.close()
20 |     total_frequency = np.sum(z)
21 |     class_weights = []
22 |     for frequency in z:
23 |         class_weight = 1 / (np.log(1.02 + (frequency / total_frequency)))  # ENet-style class balancing
24 |         class_weights.append(class_weight)
25 |     ret = np.array(class_weights)
26 |     classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy')
27 |     np.save(classes_weights_path, ret)
28 | 
29 |     return ret
--------------------------------------------------------------------------------
/utils/copy_state_dict.py:
--------------------------------------------------------------------------------
1 | def copy_state_dict(cur_state_dict, pre_state_dict, prefix=''):
2 |     def _get_params(key):
3 |         key = prefix + key
4 |         if key in pre_state_dict:
5 |             return pre_state_dict[key]
6 |         return None
7 | 
8 |     for k in cur_state_dict.keys():
9 |         v = _get_params(k)
10 |         try:
11 |             if v is None:
12 |                 print('parameter {} not found'.format(k))
13 |                 continue
14 |             cur_state_dict[k].copy_(v)
15 |         except Exception:
16 |             print('copy param {} failed'.format(k))
17 |             continue
--------------------------------------------------------------------------------
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | 
7 | 
8 | class AverageMeter(object):
9 |     """Computes and stores the average and current value"""
10 |     def __init__(self):
11 |         self.initialized = False
12 |         self.val = None
13 |         self.avg = None
14 |         self.sum = None
15 |         self.count = None
16 | 
17 |     def initialize(self, val, weight):
18 |         self.val = val
19 |         self.avg = val
20 |         self.sum = val * weight
21 |         self.count = weight
22 |         self.initialized = True
23 | 
24 |     def update(self, val, weight=1):
25 |         if not self.initialized:
26 |             self.initialize(val, weight)
27 |         else:
28 |             self.add(val, weight)
29 | 
30 |     def add(self, val, weight):
31 |         self.val = val
32 |         self.sum += val * weight
33 |         self.count += weight
34 |         self.avg = self.sum / self.count
35 | 
36 |     def value(self):
37 |         return self.val
38 | 
39 |     def average(self):
40 |         return self.avg
41 | 
42 | 
43 | def count_parameters_in_MB(model):
44 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name) / 1e6
45 | 
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | 
4 | class SegmentationLosses(object):
5 |     def __init__(self, weight=None, ignore_index=255, cuda=False):
6 |         self.ignore_index = ignore_index
7 |         self.weight = weight
8 |         self.cuda = cuda
9 | 
10 |     def build_loss(self, mode='ce'):
11 |         """Choices: ['ce' or 'focal']"""
12 |         if mode == 'ce':
13 |             return self.CrossEntropyLoss
14 |         else:
15 |             raise NotImplementedError
16 | 
17 |     def CrossEntropyLoss(self, logit, target):
18 |         n, c, h, w = logit.size()
19 |         criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index)
20 |         if self.cuda:
21 |             criterion = criterion.cuda()
22 | 
23 |         loss = criterion(logit, target.long())
24 | 
25 |         return loss
26 | 
27 | if __name__ == "__main__":
28 |     loss = SegmentationLosses(cuda=True)
29 |     a = torch.rand(1, 3, 7, 7).cuda()
30 |     b = torch.rand(1, 7, 7).cuda()
31 |     print(loss.CrossEntropyLoss(a, b).item())
32 | 
33 | 
34 | 
35 | 
36 | 
37 | 
--------------------------------------------------------------------------------
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
2 | ## Created by: Hang Zhang
3 | ## ECE Department, Rutgers University
4 | ## Email: zhang.hang@rutgers.edu
5 | ## Copyright (c) 2017
6 | ## 
7 | ## This source code is licensed under the MIT-style license found in the
8 | ## LICENSE file in the root directory of this source tree
9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
10 | 
11 | import math
12 | 
13 | class LR_Scheduler(object):
14 |     """Learning Rate Scheduler
15 | 
16 |     Step mode: ``lr = baselr * 0.1 ^ floor(epoch / lr_step)``
17 | 
18 |     Cosine mode: ``lr = baselr * ((1 - min_lr) * 0.5 * (1 + cos(pi * iter / maxiter)) + min_lr)``
19 | 
20 |     Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9``
21 | 
22 |     Args:
23 |         args:
24 |             :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`),
25 |             :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs,
26 |             :attr:`args.lr_step`
27 | 
28 |         iters_per_epoch: number of iterations per epoch
29 |     """
30 |     def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0,
31 |                  lr_step=0, warmup_epochs=0, min_lr=None):
32 |         self.mode = mode
33 |         print('Using {} LR Scheduler!'.format(self.mode))
34 |         self.lr = base_lr
35 |         if mode == 'step':
36 |             assert lr_step
37 |         self.lr_step = lr_step
38 |         self.iters_per_epoch = iters_per_epoch
39 |         self.N = num_epochs * iters_per_epoch
40 |         self.epoch = -1
41 |         self.warmup_iters = warmup_epochs * iters_per_epoch
42 |         self.min_lr = min_lr
43 | 
44 |     def __call__(self, optimizer, i, epoch, best_pred):
45 |         T = epoch * self.iters_per_epoch + i
46 |         if self.mode == 'cos':
47 |             cos = 0.5 * (1 + math.cos(1.0 * T / self.N * math.pi))
48 |             decay = (1 - self.min_lr) * cos + self.min_lr
49 |             lr = self.lr * decay
50 |         elif self.mode == 'poly':
51 |             lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9)
52 |         elif self.mode == 'step':
53 |             lr = self.lr * (0.1 ** (epoch // self.lr_step))
54 |         else:
55 |             raise NotImplementedError(self.mode)
56 |         # clamp to the configured floor
57 |         if self.min_lr is not None:
58 |             if lr < self.min_lr:
59 |                 lr = self.min_lr
60 |         if self.warmup_iters > 0 and T < self.warmup_iters:  # linear warm-up
61 |             lr = lr * 1.0 * T / self.warmup_iters
62 |         if epoch > self.epoch:
63 |             print('\n=>Epoch %i, learning rate = %.4f, '
64 |                   'previous best = %.4f' % (epoch, lr, best_pred))
65 |             self.epoch = epoch
66 |         assert lr >= 0
67 |         self._adjust_learning_rate(optimizer, lr)
68 | 
69 |     def _adjust_learning_rate(self, optimizer, lr):
70 |         if len(optimizer.param_groups) == 1:
71 |             optimizer.param_groups[0]['lr'] = lr
72 |         else:
73 |             # enlarge the lr at the head
74 |             optimizer.param_groups[0]['lr'] = lr
75 |             for i in range(1, len(optimizer.param_groups)):
76 |                 optimizer.param_groups[i]['lr'] = lr * 10
77 | 
--------------------------------------------------------------------------------
/utils/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | class Evaluator(object):
5 |     def __init__(self, num_class):
6 |         self.num_class = num_class
7 |         self.confusion_matrix = torch.zeros((self.num_class,)*2)  # reset() moves this onto the GPU before use
8 | 
9 |     def Pixel_Accuracy(self):
10 |         Acc = torch.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
11 |         return Acc
12 | 
13 |     def Pixel_Accuracy_Class(self):
14 |         Acc = torch.diag(self.confusion_matrix) / self.confusion_matrix.sum(dim=1)
15 |         Acc = self.torch_nanmean(Acc)
16 |         return Acc
17 | 
18 |     def Mean_Intersection_over_Union(self):
19 |         MIoU = torch.diag(self.confusion_matrix) / (
20 |                 torch.sum(self.confusion_matrix, dim=1) + torch.sum(self.confusion_matrix, dim=0) -
21 |                 torch.diag(self.confusion_matrix))
22 |         MIoU = self.torch_nanmean(MIoU)
23 |         return MIoU.item()
24 | 
25 |     def Frequency_Weighted_Intersection_over_Union(self):
26 |         freq = torch.sum(self.confusion_matrix, dim=1) / torch.sum(self.confusion_matrix)
27 |         iu = torch.diag(self.confusion_matrix) / (
28 |                 torch.sum(self.confusion_matrix, dim=1) + torch.sum(self.confusion_matrix, dim=0) -
29 |                 torch.diag(self.confusion_matrix))
30 | 
31 |         FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
32 |         return FWIoU
33 | 
34 |     def _generate_matrix(self, gt_image, pre_image):
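        # The encoding trick used below: every (gt, pred) pair is mapped to the single
        # index gt * num_class + pred, so one bincount over the flattened labels yields
        # the whole confusion matrix. E.g. with num_class = 3, gt = 1, pred = 2 gives
        # index 5, which lands in row 1, column 2 after the reshape.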
import numpy as np
import torch

class Evaluator(object):
    def __init__(self, num_class):
        self.num_class = num_class
        # keep the confusion matrix on the GPU, consistent with reset() below
        self.confusion_matrix = torch.zeros((self.num_class,)*2).cuda()

    def Pixel_Accuracy(self):
        Acc = torch.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = torch.diag(self.confusion_matrix) / self.confusion_matrix.sum(dim=1)
        Acc = self.torch_nanmean(Acc)
        return Acc

    def Mean_Intersection_over_Union(self):
        MIoU = torch.diag(self.confusion_matrix) / (
            torch.sum(self.confusion_matrix, dim=1) + torch.sum(self.confusion_matrix, dim=0) -
            torch.diag(self.confusion_matrix))
        MIoU = self.torch_nanmean(MIoU)
        return MIoU.item()

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = torch.sum(self.confusion_matrix, dim=1) / torch.sum(self.confusion_matrix)
        iu = torch.diag(self.confusion_matrix) / (
            torch.sum(self.confusion_matrix, dim=1) + torch.sum(self.confusion_matrix, dim=0) -
            torch.diag(self.confusion_matrix))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        # ignore labels outside [0, num_class), e.g. the 255 'void' label
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].int() + pre_image[mask]
        count = torch.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = torch.zeros((self.num_class,) * 2).cuda()

    def torch_nanmean(self, x):
        # mean over non-NaN entries: classes absent from both prediction and
        # ground truth produce 0/0 = NaN and are excluded from the average
        num = torch.where(torch.isnan(x), torch.full_like(x, 0), torch.full_like(x, 1)).sum()
        value = torch.where(torch.isnan(x), torch.full_like(x, 0), x).sum()
        return value / num

class Evaluator_cpu(object):
    def __init__(self, num_class):
        self.num_class = num_class
        self.confusion_matrix = np.zeros((self.num_class,)*2)

    def Pixel_Accuracy(self):
        Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum()
        return Acc

    def Pixel_Accuracy_Class(self):
        Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1)
        Acc = np.nanmean(Acc)
        return Acc

    def Mean_Intersection_over_Union(self):
        MIoU = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))
        MIoU = np.nanmean(MIoU)
        return MIoU

    def Frequency_Weighted_Intersection_over_Union(self):
        freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix)
        iu = np.diag(self.confusion_matrix) / (
            np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) -
            np.diag(self.confusion_matrix))

        FWIoU = (freq[freq > 0] * iu[freq > 0]).sum()
        return FWIoU

    def _generate_matrix(self, gt_image, pre_image):
        mask = (gt_image >= 0) & (gt_image < self.num_class)
        label = self.num_class * gt_image[mask].astype('int') + pre_image[mask]
        count = np.bincount(label, minlength=self.num_class**2)
        confusion_matrix = count.reshape(self.num_class, self.num_class)
        return confusion_matrix

    def add_batch(self, gt_image, pre_image):
        assert gt_image.shape == pre_image.shape
        self.confusion_matrix += self._generate_matrix(gt_image, pre_image)

    def reset(self):
        self.confusion_matrix = np.zeros((self.num_class,) * 2)
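
# A minimal sketch of the evaluation flow, using the CPU variant so it runs
# without a GPU; the tiny label arrays are made up for illustration.
if __name__ == '__main__':
    evaluator = Evaluator_cpu(num_class=3)
    gt = np.array([[0, 1], [1, 2]])
    pred = np.array([[0, 1], [0, 2]])
    evaluator.add_batch(gt, pred)
    print('mIoU: %.3f' % evaluator.Mean_Intersection_over_Union())
    evaluator.reset()  # start fresh for the next validation epoch
--------------------------------------------------------------------------------
/utils/multadds_count.py:
--------------------------------------------------------------------------------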
import torch
# Original implementation:
# https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py

# ---- Public functions

def comp_multadds(model, input_size=(3, 224, 224), half=False):
    input_size = (1,) + tuple(input_size)
    model = model.cuda()
    input_data = torch.randn(input_size).cuda()
    model = add_flops_counting_methods(model)
    if half:
        # keep model and input dtypes consistent
        model = model.half()
        input_data = input_data.half()
    model.start_flops_count()
    with torch.no_grad():
        _ = model(input_data)

    mult_adds = model.compute_average_flops_cost() / 1e6
    return mult_adds


def comp_multadds_fw(model, input_data):
    model = add_flops_counting_methods(model)
    model = model.cuda()
    model.start_flops_count()
    with torch.no_grad():
        output_data = model(input_data)

    mult_adds = model.compute_average_flops_cost() / 1e6
    return mult_adds, output_data


def add_flops_counting_methods(net_main_module):
    """Adds flops counting functions to an existing model. After that,
    the flops count should be activated and the model should be run on an
    input image.

    Example:
        fcn = add_flops_counting_methods(fcn)
        fcn = fcn.cuda().train()
        fcn.start_flops_count()
        _ = fcn(batch)
        fcn.compute_average_flops_cost() / 1e9 / 2  # GFLOPs per image in batch

    Important: dividing by 2 only works for resnet models -- see below for the
    details of the flops computation.

    Attention: we count a multiply-add as two flops in this work, because in
    most resnet models convolutions are bias-free (BN layers act as the bias
    there), so it makes sense to count multiply and add as separate flops.
    This is why the example above divides by 2, to be consistent with most
    modern benchmarks. For example, in "Spatially Adaptive Computation Time
    for Residual Networks" by Figurnov et al., a multiply-add was counted as
    two flops.

    This module computes the *average* flops, which is necessary for dynamic
    networks that execute a different number of layers per input. For static
    networks it is enough to run the network once and read off the statistics
    (as in the example above).

    Implementation:
    The module adds a batch counter to the main module which tracks the sum
    of all batch sizes run through the network. Each convolutional layer also
    tracks the overall number of flops it performed. These counters are
    updated by registered hook functions which are called each time the
    respective layer is executed.

    Parameters
    ----------
    net_main_module : torch.nn.Module
        Main module containing network

    Returns
    -------
    net_main_module : torch.nn.Module
        Updated main module with new methods/attributes that are used
        to compute flops.
    """

    # adding additional methods to the existing module object;
    # this is done this way so that each function has access to the self object
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module)

    net_main_module.reset_flops_count()

    # Adding variables necessary for masked flops computation
    net_main_module.apply(add_flops_mask_variable_or_reset)

    return net_main_module
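
# A hand-driven sketch of the counter for dynamic networks, averaging over
# several batches ('model' and 'loader' are hypothetical stand-ins):
#
#     model = add_flops_counting_methods(model).cuda().eval()
#     model.start_flops_count()
#     with torch.no_grad():
#         for batch in loader:
#             _ = model(batch.cuda())
#     print(model.compute_average_flops_cost() / 1e9 / 2, 'GFLOPs per image')
#     model.stop_flops_count()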
def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.
    Returns the current mean flops consumption per image.
    """

    batches_count = self.__batch_counter__

    flops_sum = 0
    for module in self.modules():
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            flops_sum += module.__flops__

    return flops_sum / batches_count


def start_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.
    Activates the computation of mean flops consumption per image.
    Call it before you run the network.
    """

    add_batch_counter_hook_function(self)
    self.apply(add_flops_counter_hook_function)


def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.
    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.
    """

    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is
    called on a desired net object.
    Resets the statistics computed so far.
    """

    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)


def add_flops_mask(module, mask):
    def add_flops_mask_func(module):
        if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
            module.__mask__ = mask

    module.apply(add_flops_mask_func)


def remove_flops_mask(module):
    module.apply(add_flops_mask_variable_or_reset)
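
# Masked counting restricts a convolution's flops to its spatially active
# output positions. A sketch with a hypothetical 'conv' layer whose output is
# 28x28 (note the hooks below only consult the mask for Conv2d layers):
#
#     mask = torch.ones(1, 1, 28, 28).cuda()
#     mask[:, :, 14:, :] = 0        # pretend the bottom half is never computed
#     add_flops_mask(conv, mask)
#     ...forward passes...
#     remove_flops_mask(conv)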
# ---- Internal functions


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_height, output_width = output.shape[2:]

    kernel_height, kernel_width = conv_module.kernel_size
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels

    # multiply-adds needed to produce one output position
    conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups

    active_elements_count = batch_size * output_height * output_width

    if conv_module.__mask__ is not None:
        # (b, 1, h, w)
        flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width)
        active_elements_count = flops_mask.sum()

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += overall_flops


def linear_flops_counter_hook(linear_module, input, output):
    input = input[0]
    batch_size = input.shape[0]
    # bias terms are not counted here, mirroring the bias-free assumption above
    overall_flops = linear_module.in_features * linear_module.out_features * batch_size

    linear_module.__flops__ += overall_flops


def batch_counter_hook(module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]
    batch_size = input.shape[0]
    module.__batch_counter__ += batch_size


def add_batch_counter_variables_or_reset(module):
    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__


def add_flops_counter_variable_or_reset(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        module.__flops__ = 0


def add_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d):
        if hasattr(module, '__flops_handle__'):
            return
        handle = module.register_forward_hook(conv_flops_counter_hook)
        module.__flops_handle__ = handle
    elif isinstance(module, torch.nn.Linear):
        if hasattr(module, '__flops_handle__'):
            return
        handle = module.register_forward_hook(linear_flops_counter_hook)
        module.__flops_handle__ = handle


def remove_flops_counter_hook_function(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        if hasattr(module, '__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__


# --- Masked flops counting


# Also being run in the initialization
def add_flops_mask_variable_or_reset(module):
    if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear):
        module.__mask__ = None
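
# A minimal sketch of measuring a network with the public helper above;
# torchvision's resnet18 stands in for a searched architecture, and a GPU is
# assumed because comp_multadds moves both model and input to CUDA.
if __name__ == '__main__':
    import torchvision

    net = torchvision.models.resnet18()
    print('%.2f M mult-adds' % comp_multadds(net, input_size=(3, 224, 224)))
--------------------------------------------------------------------------------
/utils/saver.py:
--------------------------------------------------------------------------------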
import os
import shutil
import torch
from collections import OrderedDict
import glob

class Saver(object):

    def __init__(self, args):
        self.args = args
        self.directory = os.path.join('run', args.dataset, args.checkname)
        self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*')))
        run_id = max([int(x.split('_')[-1]) for x in self.runs]) + 1 if self.runs else 0

        self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)))
        if not os.path.exists(self.experiment_dir):
            try:
                os.makedirs(self.experiment_dir)
            except OSError:
                print('folder already exists')

    def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'):
        """Saves checkpoint to disk"""
        filename = os.path.join(self.experiment_dir, filename)
        torch.save(state, filename)
        if is_best:
            best_pred = state['best_pred']
            with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f:
                f.write(str(best_pred))
            if self.runs:
                # copy to model_best.pth.tar only if this run beats the best
                # mIoU of every previous experiment under the same directory
                previous_miou = [0.0]
                for run in self.runs:
                    run_id = run.split('_')[-1]
                    path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt')
                    if os.path.exists(path):
                        with open(path, 'r') as f:
                            miou = float(f.readline())
                        previous_miou.append(miou)
                    else:
                        continue
                max_miou = max(previous_miou)
                if best_pred > max_miou:
                    shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))
            else:
                shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar'))

    def save_experiment_config(self):
        logfile = os.path.join(self.experiment_dir, 'parameters.txt')
        p = OrderedDict()
        p['network'] = self.args.network
        p['dataset'] = self.args.dataset
        # p['lr'] = self.args.lr
        # p['lr_scheduler'] = self.args.lr_scheduler
        # p['epoch'] = self.args.epochs

        with open(logfile, 'w') as log_file:
            for key, val in p.items():
                log_file.write(key + ':' + str(val) + '\n')
--------------------------------------------------------------------------------
/utils/summaries.py:
--------------------------------------------------------------------------------
import torch
from torchvision.utils import make_grid
from tensorboardX import SummaryWriter
from dataloaders.utils import decode_seg_map_sequence

class TensorboardSummary(object):
    def __init__(self, directory):
        self.directory = directory

    def create_summary(self):
        writer = SummaryWriter(logdir=self.directory)
        return writer

    def visualize_image(self, writer, dataset, image, target, output, global_step):
        # log the first three images of the batch, their predicted label maps
        # and the corresponding ground truth
        grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True)
        writer.add_image('Image', grid_image, global_step)
        grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(),
                                                       dataset=dataset), 3, normalize=False, range=(0, 255))
        writer.add_image('Predicted label', grid_image, global_step)
        grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(),
                                                       dataset=dataset), 3, normalize=False, range=(0, 255))
        writer.add_image('Groundtruth label', grid_image, global_step)
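
# A minimal sketch of the intended use from a training loop; 'model' and
# 'loader' are hypothetical stand-ins, and the run directory is illustrative:
#
#     summary = TensorboardSummary('run/cityscapes/deeplab/experiment_0')
#     writer = summary.create_summary()
#     image, target = next(iter(loader))
#     output = model(image)
#     summary.visualize_image(writer, 'cityscapes', image, target, output, global_step=0)
#     writer.close()
--------------------------------------------------------------------------------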