├── .gitignore
├── LICENSE
├── README.md
├── backbones
│   ├── __init__.py
│   └── aognet
│       ├── AOG.py
│       ├── __init__.py
│       ├── aognet_singlescale.py
│       ├── config.py
│       ├── operator_basic.py
│       └── operator_singlescale.py
├── configs
│   ├── aognet_cifar100_1M.yaml
│   ├── aognet_imagenet_12M.yaml
│   └── aognet_imagenet_40M.yaml
├── examples
│   ├── kill_all_python.sh
│   ├── test_fp16.sh
│   └── train_fp16.sh
├── images
│   ├── teaser-imagenet-dissection.png
│   └── teaser.png
├── requirements.txt
└── tools
    ├── __init__.py
    ├── get_cifar.py
    ├── main_fp16.py
    └── smoothing.py
/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | .idea 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | #*.log 54 | temp*.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | # vscode 93 | .vscode 94 | 95 | # this project 96 | data/ 97 | datasets/ 98 | results/ 99 | nohup.out 100 | *.tar 101 | *.log 102 | .DS_Store 103 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. 
Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # AOGNets: Compositional Grammatical Architectures for Deep Learning 3 | 4 | This project provides source code for our CVPR19 paper [AOGNets](https://arxiv.org/abs/1711.05847). 5 | The code is still under refactoring. Please stay tuned. 6 | 7 | ![alt text](images/teaser.png "AOG building block and ImageNet performance") 8 | 9 | 10 | ## Installation 11 | 12 | ### Requirements 13 | 14 | Ubuntu 16.04 LTS (other OSes have not been tested, but should work as long as PyTorch can be installed successfully) 15 | 16 | Python 3 ([Anaconda](https://www.anaconda.com/) is recommended) 17 | 18 | CUDA 9 or newer 19 | 20 | PyTorch 0.4 or newer 21 | 22 | NVIDIA [APEX](https://github.com/NVIDIA/apex) 23 | 24 | NVIDIA [NCCL](https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html) 25 | 26 | ### Clone the repo 27 | ``` 28 | $ git clone https://github.com/iVMCL/AOGNets.git 29 | $ cd AOGNets 30 | $ pip install -r requirements.txt 31 | ``` 32 | 33 | ### Some tweaks 34 | 35 | Use pillow-simd to speed up the PyTorch image loader (assumes Anaconda is used) 36 | 37 | ``` 38 | $ pip uninstall pillow 39 | $ conda uninstall --force jpeg libtiff -y 40 | $ conda install -c conda-forge libjpeg-turbo 41 | $ CC="cc -mavx2" pip install --no-cache-dir -U --force-reinstall --no-binary :all: --compile pillow-simd 42 | ``` 43 | 44 | ## ImageNet dataset preparation 45 | 46 | - Download the ImageNet dataset to YOUR_IMAGENET_PATH and move validation images into labeled subfolders 47 | - The [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) may be helpful. 
48 | 49 | - Create a datasets subfolder under your cloned AOGNets and a symlink to the ImageNet dataset 50 | 51 | ``` 52 | $ cd AOGNets 53 | $ mkdir datasets 54 | $ ln -s PATH_TO_YOUR_IMAGENET ./datasets/ 55 | ``` 56 | 57 | ## Perform training on ImageNet dataset 58 | 59 | NVIDIA Apex is used for FP16 training. 60 | 61 | E.g., 62 | 63 | ``` 64 | $ cd AOGNets 65 | $ ./examples/train_fp16.sh aognet_s configs/aognet_imagenet_12M.yaml first_try 66 | ``` 67 | 68 | See more configuration files in AOGNets/configs. Change the GPU settings in train_fp16.sh. 69 | 70 | 71 | ## Perform testing with pretrained models 72 | 73 | - Pretrained [AOGNet_12M on ImageNet](https://drive.google.com/open?id=1MTPFR8C9tCXFNeYgn9NqZ3wOJt8ZeMm7) on Google Drive 74 | 75 | - Pretrained [AOGNet_40M on ImageNet](https://drive.google.com/open?id=1t7gGiNcP8L6TSzLDHg8qcb8G_x-nlIfV) on Google Drive 76 | 77 | - More coming soon 78 | 79 | - Remarks: The provided pretrained models were obtained with the latest refactored code, so their performance is slightly different from the results reported in the paper. 80 | 81 | E.g., 82 | 83 | ``` 84 | $ cd AOGNets 85 | $ ./examples/test_fp16.sh aognet_s AOGNet_12M_PATH 86 | ``` 87 | 88 | ## Citations 89 | Please consider citing the AOGNets paper in your publications if it helps your research. 90 | ``` 91 | @inproceedings{AOGNets, 92 | author = {Xilai Li and Xi Song and Tianfu Wu}, 93 | title = {AOGNets: Compositional Grammatical Architectures for Deep Learning}, 94 | booktitle = {{IEEE} Conference on Computer Vision and Pattern Recognition, {CVPR}}, 95 | year = {2019}, 96 | url = {https://arxiv.org/pdf/1711.05847.pdf} 97 | } 98 | ``` 99 | 100 | ## Contact 101 | 102 | Please feel free to report issues and any related problems to Xilai Li (xli47 at ncsu dot edu), Xi Song (xsong.lhi at gmail.com) and Tianfu Wu (twu19 at ncsu dot edu). 103 | 104 | ## License 105 | 106 | AOGNets-related code is under the [RESEARCH ONLY LICENSE](./LICENSE). 107 | -------------------------------------------------------------------------------- /backbones/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNets/cddba00e97d6f74d2e7bfce50fd7aea9630ff996/backbones/__init__.py -------------------------------------------------------------------------------- /backbones/aognet/AOG.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. 
Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | 35 | # -*- coding: utf-8 -*- 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function # force to use print as function print(args) 39 | from __future__ import unicode_literals 40 | 41 | from math import ceil, floor 42 | from collections import deque 43 | import numpy as np 44 | import os 45 | import random 46 | import math 47 | import copy 48 | 49 | 50 | def get_aog(grid_ht, grid_wd, min_size=1, max_split=2, 51 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=1./4., 52 | not_use_intermediate_TerminalNode=False, 53 | use_root_TerminalNode=True, use_tnode_as_alpha_channel=0, 54 | use_tnode_topdown_connection=False, 55 | use_tnode_bottomup_connection=False, 56 | use_tnode_bottomup_connection_layerwise=False, 57 | use_tnode_bottomup_connection_sequential=False, 58 | use_node_lateral_connection=False, # not include T-nodes 59 | use_node_lateral_connection_1=False, # include T-nodes 60 | use_super_OrNode=False, 61 | remove_single_child_or_node=False, 62 | remove_symmetric_children_of_or_node=0, 63 | mark_symmetric_syntatic_subgraph = False, 64 | max_children_kept_for_or=1000): 65 | aog_param = Param(grid_ht=grid_ht, grid_wd=grid_wd, min_size=min_size, max_split=max_split, 66 | not_use_large_TerminalNode=not_use_large_TerminalNode, 67 | turn_off_size_ratio_TerminalNode=turn_off_size_ratio_TerminalNode, 68 | not_use_intermediate_TerminalNode=not_use_intermediate_TerminalNode, 69 | use_root_TerminalNode=use_root_TerminalNode, 70 | use_tnode_as_alpha_channel=use_tnode_as_alpha_channel, 71 | use_tnode_topdown_connection=use_tnode_topdown_connection, 72 | use_tnode_bottomup_connection=use_tnode_bottomup_connection, 73 | use_tnode_bottomup_connection_layerwise=use_tnode_bottomup_connection_layerwise, 74 | use_tnode_bottomup_connection_sequential=use_tnode_bottomup_connection_sequential, 75 | use_node_lateral_connection=use_node_lateral_connection, 76 | 
use_node_lateral_connection_1=use_node_lateral_connection_1, 77 | use_super_OrNode=use_super_OrNode, 78 | remove_single_child_or_node=remove_single_child_or_node, 79 | mark_symmetric_syntatic_subgraph = mark_symmetric_syntatic_subgraph, 80 | remove_symmetric_children_of_or_node=remove_symmetric_children_of_or_node, 81 | max_children_kept_for_or=max_children_kept_for_or) 82 | aog = AOGrid(aog_param) 83 | aog.Create() 84 | return aog 85 | 86 | 87 | class NodeType(object): 88 | OrNode = "OrNode" 89 | AndNode = "AndNode" 90 | TerminalNode = "TerminalNode" 91 | Unknow = "Unknown" 92 | 93 | 94 | class SplitType(object): 95 | HorSplit = "Hor" 96 | VerSplit = "Ver" 97 | Unknown = "Unknown" 98 | 99 | 100 | class Param(object): 101 | """Input parameters for creating an AOG 102 | """ 103 | 104 | def __init__(self, grid_ht=3, grid_wd=3, max_split=2, min_size=1, control_side_length=False, 105 | overlap_ratio=0., use_root_TerminalNode=False, 106 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=0.5, 107 | not_use_intermediate_TerminalNode= False, 108 | use_tnode_as_alpha_channel=0, 109 | use_tnode_topdown_connection=False, 110 | use_tnode_bottomup_connection=False, 111 | use_tnode_bottomup_connection_layerwise=False, 112 | use_tnode_bottomup_connection_sequential=False, 113 | use_node_lateral_connection=False, 114 | use_node_lateral_connection_1=False, 115 | use_super_OrNode=False, 116 | remove_single_child_or_node=False, 117 | remove_symmetric_children_of_or_node=0, 118 | mark_symmetric_syntatic_subgraph=False, 119 | max_children_kept_for_or=100): 120 | self.grid_ht = grid_ht 121 | self.grid_wd = grid_wd 122 | self.max_split = max_split # maximum number of child nodes when splitting an AND-node 123 | self.min_size = min_size # minimum side length or minimum area allowed 124 | self.control_side_length = control_side_length 125 | self.overlap_ratio = overlap_ratio 126 | self.use_root_terminal_node = use_root_TerminalNode 127 | self.not_use_large_terminal_node = not_use_large_TerminalNode 128 | self.turn_off_size_ratio_terminal_node = turn_off_size_ratio_TerminalNode 129 | self.not_use_intermediate_TerminalNode = not_use_intermediate_TerminalNode 130 | self.use_tnode_as_alpha_channel = use_tnode_as_alpha_channel 131 | self.use_tnode_topdown_connection = use_tnode_topdown_connection 132 | self.use_tnode_bottomup_connection = use_tnode_bottomup_connection 133 | self.use_tnode_bottomup_connection_layerwise = use_tnode_bottomup_connection_layerwise 134 | self.use_node_lateral_connection = use_node_lateral_connection 135 | self.use_node_lateral_connection_1 = use_node_lateral_connection_1 136 | self.use_tnode_bottomup_connection_sequential = use_tnode_bottomup_connection_sequential 137 | assert 1 >= self.use_node_lateral_connection_1 + self.use_node_lateral_connection + \ 138 | self.use_tnode_topdown_connection + self.use_tnode_bottomup_connection + \ 139 | self.use_tnode_bottomup_connection_layerwise + self.use_tnode_bottomup_connection_sequential, \ 140 | 'only one type of node hierarchy can be used' 141 | self.use_super_OrNode = use_super_OrNode 142 | self.remove_single_child_or_node = remove_single_child_or_node 143 | self.remove_symmetric_children_of_or_node = remove_symmetric_children_of_or_node #0: not, 1: keep left, 2: keep right 144 | self.mark_symmetric_syntatic_subgraph = mark_symmetric_syntatic_subgraph # true, only mark the nodes which will be removed based on remove_symmetric_children_of_or_node 145 | self.max_children_kept_for_or = max_children_kept_for_or # how many child 
nodes kept for an OR-node 146 | 147 | self.get_tag() 148 | 149 | def get_tag(self): 150 | # identifier useful for naming a particular aog 151 | self.tag = '{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}'.format( 152 | self.grid_wd, self.grid_ht, self.max_split, 153 | self.min_size, self.control_side_length, 154 | self.overlap_ratio, 155 | self.use_root_terminal_node, 156 | self.not_use_large_terminal_node, 157 | self.turn_off_size_ratio_terminal_node, 158 | self.not_use_intermediate_TerminalNode, 159 | self.use_tnode_as_alpha_channel, 160 | self.use_tnode_topdown_connection, 161 | self.use_tnode_bottomup_connection, 162 | self.use_tnode_bottomup_connection_layerwise, 163 | self.use_tnode_bottomup_connection_sequential, 164 | self.use_node_lateral_connection, 165 | self.use_node_lateral_connection_1, 166 | self.use_super_OrNode, 167 | self.remove_single_child_or_node, 168 | self.remove_symmetric_children_of_or_node, 169 | self.mark_symmetric_syntatic_subgraph, 170 | self.max_children_kept_for_or) 171 | 172 | 173 | class Rect(object): 174 | """A simple rectangle 175 | """ 176 | 177 | def __init__(self, x1=0, y1=0, x2=0, y2=0): 178 | self.x1 = x1 179 | self.y1 = y1 180 | self.x2 = x2 181 | self.y2 = y2 182 | 183 | def __eq__(self, other): 184 | """Override the default Equals behavior""" 185 | if isinstance(other, self.__class__): 186 | return self.__dict__ == other.__dict__ 187 | return NotImplemented 188 | 189 | def __ne__(self, other): 190 | """Define a non-equality test""" 191 | if isinstance(other, self.__class__): 192 | return not self.__eq__(other) 193 | return NotImplemented 194 | 195 | def __hash__(self): 196 | """Override the default hash behavior (that returns id or the object)""" 197 | return hash(tuple(sorted(self.__dict__.items()))) 198 | 199 | def Width(self): 200 | return self.x2 - self.x1 + 1 201 | 202 | def Height(self): 203 | return self.y2 - self.y1 + 1 204 | 205 | def Area(self): 206 | return self.Width() * self.Height() 207 | 208 | def MinLength(self): 209 | return min(self.Width(), self.Height()) 210 | 211 | def IsOverlap(self, other): 212 | assert isinstance(other, self.__class__) 213 | 214 | x1 = max(self.x1, other.x1) 215 | x2 = min(self.x2, other.x2) 216 | if x1 > x2: 217 | return False 218 | 219 | y1 = max(self.y1, other.y1) 220 | y2 = min(self.y2, other.y2) 221 | if y1 > y2: 222 | return False 223 | 224 | return True 225 | 226 | def IsSame(self, other): 227 | assert isinstance(other, self.__class__) 228 | 229 | return self.Width() == other.Width() and self.Height() == other.Height() 230 | 231 | 232 | class Node(object): 233 | """Types of nodes in an AOG 234 | AND-node (structural decomposition), 235 | OR-node (alternative decompositions), 236 | TERMINAL-node (link to data). 
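    (Added clarifying note:) An AND-node decomposes its rectangle via a binary split,
    an OR-node enumerates the alternative ways of decomposing the same rectangle,
    and a terminal node grounds a rectangle directly to the input data.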
237 | """ 238 | 239 | def __init__(self, node_id=-1, node_type=NodeType.Unknow, rect_idx=-1, 240 | child_ids=None, parent_ids=None, 241 | split_type=SplitType.Unknown, split_step1=0, split_step2=0, is_symm=False, 242 | ancestors_ids=None): 243 | self.id = node_id 244 | self.node_type = node_type 245 | self.rect_idx = rect_idx 246 | self.child_ids = child_ids if child_ids is not None else [] 247 | self.parent_ids = parent_ids if parent_ids is not None else [] 248 | self.ancestors_ids = ancestors_ids if ancestors_ids is not None else [] # root or-node exlusive 249 | self.split_type = split_type 250 | self.split_step1 = split_step1 251 | self.split_step2 = split_step2 252 | 253 | # some utility variables used in object detection models 254 | self.on_off = True 255 | self.out_edge_visited_count = [] 256 | self.which_classes_visited = {} # key=class_name, val=frequency 257 | self.npaths = 0.0 258 | self.is_symmetric = False 259 | self.has_dbl_counting = False 260 | 261 | def __eq__(self, other): 262 | if isinstance(other, self.__class__): 263 | res = ((self.node_type == other.node_type) and (self.rect_idx == other.rect_idx)) 264 | if res: 265 | if self.node_type != NodeType.AndNode: 266 | return True 267 | else: 268 | if self.split_type != SplitType.Unknown: 269 | return (self.split_type == other.split_type) and (self.split_step1 == other.split_step1) and \ 270 | (self.split_step2 == other.split_step2) 271 | else: 272 | return (set(self.child_ids) == set(other.child_ids)) 273 | 274 | return False 275 | 276 | return NotImplemented 277 | 278 | def __ne__(self, other): 279 | if isinstance(other, self.__class__): 280 | return not self.__eq__(other) 281 | return NotImplemented 282 | 283 | def __hash__(self): 284 | """Override the default hash behavior (that returns id or the object)""" 285 | return hash(tuple(sorted(self.__dict__.items()))) 286 | 287 | 288 | class AOGrid(object): 289 | """The AOGrid defines a Directed Acyclic And-Or Graph 290 | which is used to explore/unfold the space of latent structures 291 | of a grid (e.g., a 7 * 7 grid for a 100 * 200 lattice) 292 | """ 293 | 294 | def __init__(self, param_): 295 | assert isinstance(param_, Param) 296 | self.param = param_ 297 | assert self.param.max_split > 1 298 | self.primitive_set = [] 299 | self.node_set = [] 300 | self.num_TNodes = 0 301 | self.num_AndNodes = 0 302 | self.num_OrNodes = 0 303 | self.DFS = [] 304 | self.BFS = [] 305 | self.node_DFS = {} 306 | self.node_BFS = {} 307 | self.OrNodeIdxInBFS = {} 308 | self.TNodeIdxInBFS = {} 309 | 310 | # for color consistency in viz 311 | self.TNodeColors = {} 312 | 313 | def _AddPrimitve(self, rect): 314 | assert isinstance(rect, Rect) 315 | 316 | if rect in self.primitive_set: 317 | return self.primitive_set.index(rect) 318 | 319 | self.primitive_set.append(rect) 320 | 321 | return len(self.primitive_set) - 1 322 | 323 | def _AddNode(self, node, not_create_if_existed=True): 324 | assert isinstance(node, Node) 325 | 326 | if node in self.node_set and not_create_if_existed: 327 | node = self.node_set[self.node_set.index(node)] 328 | return False, node 329 | 330 | node.id = len(self.node_set) 331 | if node.node_type == NodeType.AndNode: 332 | self.num_AndNodes += 1 333 | elif node.node_type == NodeType.OrNode: 334 | self.num_OrNodes += 1 335 | elif node.node_type == NodeType.TerminalNode: 336 | self.num_TNodes += 1 337 | else: 338 | raise NotImplementedError 339 | 340 | self.node_set.append(node) 341 | 342 | return True, node 343 | 344 | def _DoSplit(self, rect): 345 | assert isinstance(rect, 
Rect) 346 | 347 | if self.param.control_side_length: 348 | return rect.Width() >= self.param.min_size and rect.Height() >= self.param.min_size 349 | 350 | return rect.Area() > self.param.min_size 351 | 352 | def _SplitStep(self, sz): 353 | if self.param.control_side_length: 354 | return self.param.min_size 355 | 356 | if sz >= self.param.min_size: 357 | return 1 358 | else: 359 | return int(ceil(self.param.min_size / sz)) 360 | 361 | def _DFS(self, id, q, visited): 362 | if visited[id] == 1: 363 | raise RuntimeError 364 | 365 | visited[id] = 1 366 | for i in self.node_set[id].child_ids: 367 | if visited[i] < 2: 368 | q, visited = self._DFS(i, q, visited) 369 | 370 | if self.node_set[id].on_off: 371 | q.append(id) 372 | 373 | visited[id] = 2 374 | 375 | return q, visited 376 | 377 | def _BFS(self, id, q, visited): 378 | q = [id] 379 | visited[id] = 0 # count indegree 380 | i = 0 381 | while i < len(q): 382 | node = self.node_set[q[i]] 383 | for j in node.child_ids: 384 | visited[j] += 1 385 | if j not in q: 386 | q.append(j) 387 | 388 | i += 1 389 | 390 | q = [id] 391 | i = 0 392 | while i < len(q): 393 | node = self.node_set[q[i]] 394 | for j in node.child_ids: 395 | visited[j] -= 1 396 | if visited[j] == 0: 397 | q.append(j) 398 | i += 1 399 | 400 | return q, visited 401 | 402 | def _countPaths(self, s, t, npaths): 403 | if s.id == t.id: 404 | return 1.0 405 | else: 406 | if not npaths[s.id]: 407 | rect = self.primitive_set[s.rect_idx] 408 | #ids1 = set(s.ancestors_ids) 409 | ids1 = set(s.parent_ids) 410 | num_shared = 0 411 | for c in s.child_ids: 412 | ch = self.node_set[c] 413 | ch_rect = self.primitive_set[ch.rect_idx] 414 | #ids2 = ch.ancestors_ids 415 | ids2 = ch.parent_ids 416 | 417 | if s.node_type == NodeType.AndNode and ch.node_type == NodeType.AndNode and \ 418 | rect.Width() == ch_rect.Width() and rect.Height() == ch_rect.Height(): 419 | continue 420 | if s.node_type == NodeType.OrNode and \ 421 | ((ch.node_type == NodeType.OrNode) or \ 422 | (ch.node_type == NodeType.TerminalNode and (rect.Area() < ch_rect.Area()) )): 423 | continue 424 | 425 | npaths[s.id] += self._countPaths(ch, t, npaths) 426 | return npaths[s.id] 427 | 428 | def _AssignParentIds(self): 429 | for i in range(len(self.node_set)): 430 | self.node_set[i].parent_ids = [] 431 | 432 | for node in self.node_set: 433 | for i in node.child_ids: 434 | self.node_set[i].parent_ids.append(node.id) 435 | 436 | for i in range(len(self.node_set)): 437 | self.node_set[i].parent_ids = list(set(self.node_set[i].parent_ids)) 438 | 439 | def _AssignAncestorsIds(self): 440 | self._AssignParentIds() 441 | 442 | assert len(self.BFS) == len(self.node_set) 443 | self.node_set[self.BFS[0]].ancestors_ids = [] 444 | 445 | for nid in self.BFS[1:]: 446 | node = self.node_set[nid] 447 | rect = self.primitive_set[node.rect_idx] 448 | ancestors = [] 449 | for pid in node.parent_ids: 450 | p = self.node_set[pid] 451 | p_rect = self.primitive_set[p.rect_idx] 452 | equal_size = rect.Width() == p_rect.Width() and \ 453 | rect.Height() == p_rect.Height() 454 | # AND-to-AND lateral path 455 | if node.node_type == NodeType.AndNode and p.node_type == NodeType.AndNode and \ 456 | equal_size: 457 | continue 458 | # OR-to-OR/T lateral path 459 | if node.node_type == NodeType.OrNode and \ 460 | ((p.node_type == NodeType.OrNode and equal_size) or \ 461 | (p.node_type == NodeType.TerminalNode and (rect.Area() < p_rect.Area()) )): 462 | continue 463 | for ppid in p.ancestors_ids: 464 | if ppid != self.BFS[0] and ppid not in ancestors: 465 | 
ancestors.append(ppid) 466 | ancestors.append(pid) 467 | self.node_set[nid].ancestors_ids = list(set(ancestors)) 468 | 469 | def _Postprocessing(self, root_id): 470 | self.DFS = [] 471 | self.BFS = [] 472 | visited = np.zeros(len(self.node_set)) 473 | self.DFS, _ = self._DFS(root_id, self.DFS, visited) 474 | visited = np.zeros(len(self.node_set)) 475 | self.BFS, _ = self._BFS(root_id, self.BFS, visited) 476 | self._AssignAncestorsIds() 477 | 478 | def _FindNodeIdWithGivenRect(self, rect, node_type): 479 | for node in self.node_set: 480 | if node.node_type != node_type: 481 | continue 482 | if rect == self.primitive_set[node.rect_idx]: 483 | return node.id 484 | 485 | return -1 486 | 487 | def _add_tnode_topdown_connection(self): 488 | 489 | assert self.param.use_root_terminal_node 490 | 491 | prim_type = [self.param.grid_ht, self.param.grid_wd] 492 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type) 493 | assert len(tnode_queue) == 1 494 | 495 | i = 0 496 | while i < len(tnode_queue): 497 | id_ = tnode_queue[i] 498 | node = self.node_set[id_] 499 | i += 1 500 | 501 | rect = self.primitive_set[node.rect_idx] 502 | 503 | ids = [] 504 | for y in range(0, rect.Height()): 505 | for x in range(0, rect.Width()): 506 | if y == 0 and x == 0: 507 | continue 508 | prim_type = [rect.Height()-y, rect.Width()-x] 509 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect) 510 | 511 | ids = list(set(ids)) 512 | tnode_queue += ids 513 | 514 | for pid in ids: 515 | if id_ not in self.node_set[pid].child_ids: 516 | self.node_set[pid].child_ids.append(id_) 517 | 518 | def _add_onode_topdown_connection(self): 519 | assert self.param.use_root_terminal_node 520 | 521 | prim_type = [self.param.grid_ht, self.param.grid_wd] 522 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type) 523 | assert len(tnode_queue) == 1 524 | 525 | i = 0 526 | while i < len(tnode_queue): 527 | id_ = tnode_queue[i] 528 | node = self.node_set[id_] 529 | i += 1 530 | 531 | rect = self.primitive_set[node.rect_idx] 532 | 533 | ids = [] 534 | ids_t = [] 535 | for y in range(0, rect.Height()): 536 | for x in range(0, rect.Width()): 537 | if y == 0 and x == 0: 538 | continue 539 | prim_type = [rect.Height()-y, rect.Width()-x] 540 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect, NodeType.OrNode) 541 | ids_t += self.find_node_ids_with_given_prim_type(prim_type, rect) 542 | 543 | ids = list(set(ids)) 544 | ids_t = list(set(ids_t)) 545 | 546 | for pid in ids: 547 | if id_ not in self.node_set[pid].child_ids: 548 | self.node_set[pid].child_ids.append(id_) 549 | 550 | def _add_tnode_bottomup_connection(self): 551 | assert self.param.use_root_terminal_node 552 | 553 | # primitive tnodes 554 | prim_type = [1, 1] 555 | t_ids = self.find_node_ids_with_given_prim_type(prim_type) 556 | assert len(t_ids) == self.param.grid_wd * self.param.grid_ht 557 | 558 | # other tnodes will be converted to and-nodes 559 | for h in range(1, self.param.grid_ht+1): 560 | for w in range(1, self.param.grid_wd+1): 561 | if h == 1 and w == 1: 562 | continue 563 | prim_type = [h, w] 564 | ids = self.find_node_ids_with_given_prim_type(prim_type) 565 | for id_ in ids: 566 | self.node_set[id_].node_type = NodeType.AndNode 567 | node = self.node_set[id_] 568 | rect = self.primitive_set[node.rect_idx] 569 | prim_type = [1, 1] 570 | for y in range(rect.y1, rect.y2+1): 571 | for x in range(rect.x1, rect.x2+1): 572 | parent_rect = Rect(x, y, x, y) 573 | ch_ids = self.find_node_ids_with_given_prim_type(prim_type, parent_rect) 574 | 
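# note (added): every 1x1 cell covered by this rectangle must resolve to exactly
# one primitive T-node, which then becomes a child of the converted AND-node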
assert len(ch_ids) == 1 575 | if ch_ids[0] not in self.node_set[id_].child_ids: 576 | self.node_set[id_].child_ids.append(ch_ids[0]) 577 | 578 | def _add_lateral_connection(self): 579 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.AndNode, direction=1) 580 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.OrNode, direction=0) 581 | 582 | if not self.param.use_node_lateral_connection_1: 583 | return self.BFS[0] 584 | 585 | # or for all or nodes 586 | for node in self.node_set: 587 | if node.node_type != NodeType.OrNode: 588 | continue 589 | 590 | ch_ids = node.child_ids 591 | numCh = len(ch_ids) 592 | 593 | hasLateral = False 594 | for id_ in ch_ids: 595 | if self.node_set[id_].node_type == NodeType.OrNode: 596 | hasLateral = True 597 | numCh -= 1 598 | 599 | minNumCh = 3 if hasLateral else 2 600 | if len(ch_ids) < minNumCh: 601 | continue 602 | 603 | # find t-node child 604 | ch0 = -1 605 | for id_ in ch_ids: 606 | if self.node_set[id_].node_type == NodeType.TerminalNode: 607 | ch0 = id_ 608 | break 609 | assert ch0 != -1 610 | 611 | added = False 612 | for id_ in ch_ids: 613 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode: 614 | continue 615 | 616 | if len(self.node_set[id_].child_ids) == 2 or numCh == 2: 617 | assert ch0 not in self.node_set[id_].child_ids 618 | self.node_set[id_].child_ids.append(ch0) 619 | added = True 620 | 621 | if not added: 622 | for id_ in ch_ids: 623 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode: 624 | continue 625 | 626 | found = True 627 | for id__ in ch_ids: 628 | if id_ in self.node_set[id__].child_ids: 629 | found = False 630 | if found: 631 | assert ch0 not in self.node_set[id_].child_ids 632 | self.node_set[id_].child_ids.append(ch0) 633 | 634 | return self.BFS[0] 635 | 636 | def _add_node_bottomup_connection_layerwise(self, node_type=NodeType.TerminalNode, direction=0): 637 | 638 | prim_types = [] 639 | for node in self.node_set: 640 | if node.node_type == node_type: 641 | rect = self.primitive_set[node.rect_idx] 642 | p = [rect.Height(), rect.Width()] 643 | if p not in prim_types: 644 | prim_types.append(p) 645 | 646 | change_direction = False 647 | 648 | prim_types.sort() 649 | 650 | for p in prim_types: 651 | ids = self.find_node_ids_with_given_prim_type(p, node_type=node_type) 652 | if len(ids) < 2: 653 | change_direction = True 654 | continue 655 | 656 | if change_direction: 657 | direction = 1 - direction 658 | 659 | yx = np.empty((0, 4 if node_type==NodeType.AndNode else 2), dtype=np.float32) 660 | for id_ in ids: 661 | node = self.node_set[id_] 662 | rect = self.primitive_set[node.rect_idx] 663 | 664 | if node_type == NodeType.AndNode: 665 | ch_node = self.node_set[node.child_ids[0]] 666 | ch_rect = self.primitive_set[ch_node.rect_idx] 667 | if ch_rect.x1 != rect.x1 or ch_rect.y1 != rect.y1: 668 | ch_node = self.node_set[node.child_ids[1]] 669 | ch_rect = self.primitive_set[ch_node.rect_idx] 670 | pos = (rect.y1, rect.x1, ch_rect.y2, ch_rect.x2) 671 | else: 672 | pos = (rect.y1, rect.x1) 673 | yx = np.vstack((yx, np.array(pos))) 674 | 675 | if node_type == NodeType.AndNode: 676 | ind = np.lexsort((yx[:, 1], yx[:, 0], yx[:, 3], yx[:, 2])) 677 | else: 678 | ind = np.lexsort((yx[:, 1], yx[:, 0])) 679 | 680 | istart = len(ind) - 1 if direction == 0 else 0 681 | iend = 0 if direction == 0 else len(ind) - 1 682 | step = -1 if direction == 0 else 1 683 | for i in range(istart, iend, step): 684 | id_ = ids[ind[i]] 685 | chid = ids[ind[i - 1]] if direction==0 else ids[ind[i+1]] 686 
| if chid not in self.node_set[id_].child_ids: 687 | self.node_set[id_].child_ids.append(chid) 688 | 689 | if change_direction: 690 | direction = 1 - direction 691 | change_direction = False 692 | 693 | def _add_tnode_bottomup_connection_sequential(self): 694 | 695 | assert self.param.grid_wd > 1 and self.param.grid_ht == 1 696 | 697 | self._add_node_bottomup_connection_layerwise() 698 | 699 | for node in self.node_set: 700 | if node.node_type != NodeType.TerminalNode: 701 | continue 702 | rect = self.primitive_set[node.rect_idx] 703 | if rect.Width() == 1: 704 | continue 705 | 706 | rect1 = copy.deepcopy(rect) 707 | rect1.x1 += 1 708 | chid = self._FindNodeIdWithGivenRect(rect1, NodeType.TerminalNode) 709 | if chid != -1: 710 | self.node_set[node.id].child_ids.append(chid) 711 | 712 | def _mark_symmetric_subgraph(self): 713 | 714 | for i in self.BFS: 715 | node = self.node_set[i] 716 | 717 | if node.is_symmetric or node.node_type == NodeType.TerminalNode: 718 | continue 719 | 720 | if i != self.BFS[0]: 721 | is_symmetric = True 722 | for j in node.parent_ids: 723 | p = self.node_set[j] 724 | if not p.is_symmetric: 725 | is_symmetric = False 726 | break 727 | if is_symmetric: 728 | self.node_set[i].is_symmetric = True 729 | continue 730 | 731 | rect = self.primitive_set[node.rect_idx] 732 | Wd = rect.Width() 733 | Ht = rect.Height() 734 | 735 | if node.node_type == NodeType.OrNode: 736 | # mark symmetric children 737 | useSplitWds = [] 738 | useSplitHts = [] 739 | if self.param.remove_symmetric_children_of_or_node == 2: 740 | child_ids = node.child_ids[::-1] 741 | else: 742 | child_ids = node.child_ids 743 | 744 | for j in child_ids: 745 | ch = self.node_set[j] 746 | if ch.node_type == NodeType.TerminalNode: 747 | continue 748 | 749 | if ch.split_type == SplitType.VerSplit: 750 | if (Wd-ch.split_step2, ch.split_step1) not in useSplitWds: 751 | useSplitWds.append((ch.split_step1, Wd-ch.split_step2)) 752 | else: 753 | self.node_set[j].is_symmetric = True 754 | 755 | elif ch.split_type == SplitType.HorSplit: 756 | if (Ht-ch.split_step2, ch.split_step1) not in useSplitHts: 757 | useSplitHts.append((ch.split_step1, Ht-ch.split_step2)) 758 | else: 759 | self.node_set[j].is_symmetric = True 760 | 761 | def _find_dbl_counting_or_nodes(self): 762 | for node in self.node_set: 763 | if node.node_type != NodeType.OrNode or len(node.child_ids) < 2: 764 | continue 765 | for i in self.node_BFS[node.id][1:]: 766 | npaths = { x : 0 for x in self.node_BFS[node.id] } 767 | n = self._countPaths(node, self.node_set[i], npaths) 768 | if n > 1: 769 | self.node_set[node.id].has_dbl_counting = True 770 | break 771 | 772 | def find_node_ids_with_given_prim_type(self, prim_type, parent_rect=None, node_type=NodeType.TerminalNode): 773 | ids = [] 774 | for node in self.node_set: 775 | if node.node_type != node_type: 776 | continue 777 | rect = self.primitive_set[node.rect_idx] 778 | if [rect.Height(), rect.Width()] == prim_type: 779 | if parent_rect is not None: 780 | if rect.x1 >= parent_rect.x1 and rect.y1 >= parent_rect.y1 and \ 781 | rect.x2 <= parent_rect.x2 and rect.y2 <= parent_rect.y2: 782 | ids.append(node.id) 783 | else: 784 | ids.append(node.id) 785 | return ids 786 | 787 | def Create(self): 788 | # print("======= creating AOGrid {}, could take a while".format(self.param.tag)) 789 | # FIXME: when remove_symmetric_children_of_or_node is true, top-down hierarchy is not correctly created. 
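        # Note (added for readability): Create() grows the AOG by BFS from the root
        # OR-node covering the whole grid_ht x grid_wd grid. Each OR-node expands into
        # (i) an optional terminal node and (ii) AND-nodes, one per valid horizontal or
        # vertical binary split of its rectangle; each AND-node expands into two child
        # OR-nodes for its sub-rectangles. _AddNode() deduplicates nodes, so the result
        # is a DAG rather than a tree. A minimal, hypothetical usage sketch:
        #   aog = get_aog(grid_ht=1, grid_wd=4)   # e.g., a 1x4 grid over channel slices
        #   print(len(aog.node_set), aog.BFS[0])  # node count and root node id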
790 | 791 | # the root OrNode 792 | rect = Rect(0, 0, self.param.grid_wd - 1, self.param.grid_ht - 1) 793 | self.primitive_set.append(rect) 794 | node = Node(node_type=NodeType.OrNode, rect_idx=0) 795 | self._AddNode(node) 796 | 797 | BFS = deque() 798 | BFS.append(0) 799 | keepLeft = True 800 | keepTop = True 801 | while len(BFS) > 0: 802 | curId = BFS.popleft() 803 | curNode = self.node_set[curId] 804 | curRect = self.primitive_set[curNode.rect_idx] 805 | curWd = curRect.Width() 806 | curHt = curRect.Height() 807 | 808 | childIds = [] 809 | 810 | if curNode.node_type == NodeType.OrNode: 811 | num_child_node_kept = 0 812 | # add a terminal node for a non-root OrNode 813 | allowTerminate = not ((self.param.not_use_large_terminal_node and 814 | float(curWd * curHt) / float(self.param.grid_ht * self.param.grid_wd) > 815 | self.param.turn_off_size_ratio_terminal_node) or 816 | (self.param.not_use_intermediate_TerminalNode and (curWd > self.param.min_size or curHt > self.param.min_size))) 817 | 818 | if (curId > 0 and allowTerminate) or (curId==0 and self.param.use_root_terminal_node): 819 | node = Node(node_type=NodeType.TerminalNode, rect_idx=curNode.rect_idx) 820 | suc, node = self._AddNode(node) 821 | childIds.append(node.id) 822 | num_child_node_kept += 1 823 | 824 | # add all AndNodes for horizontal and vertical binary splits 825 | if not self._DoSplit(curRect): 826 | childIds = list(set(childIds)) 827 | self.node_set[curId].child_ids = childIds 828 | continue 829 | 830 | num_child_node_to_add = self.param.max_children_kept_for_or - num_child_node_kept 831 | stepH = self._SplitStep(curWd) 832 | stepV = self._SplitStep(curHt) 833 | num_stepH = curHt - stepH + 1 - stepH 834 | num_stepV = curWd - stepV + 1 - stepV 835 | if num_stepH == 0 and num_stepV == 0: 836 | childIds = list(set(childIds)) 837 | self.node_set[curId].child_ids = childIds 838 | continue 839 | 840 | num_child_node_to_add_H = num_stepH / float(num_stepH + num_stepV) * num_child_node_to_add 841 | num_child_node_to_add_V = num_child_node_to_add - num_child_node_to_add_H 842 | 843 | stepH_step = int( 844 | max(1, floor(float(num_stepH) / num_child_node_to_add_H) if num_child_node_to_add_H > 0 else 0)) 845 | stepV_step = int( 846 | max(1, floor(float(num_stepV) / num_child_node_to_add_V) if num_child_node_to_add_V > 0 else 0)) 847 | 848 | # horizontal splits 849 | step = stepH 850 | num_child_node_added_H = 0 851 | 852 | splitHts = [] 853 | for topHt in range(step, curHt - step + 1, stepH_step): 854 | if num_child_node_added_H >= num_child_node_to_add_H: 855 | break 856 | 857 | bottomHt = curHt - topHt 858 | if self.param.overlap_ratio > 0: 859 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio)) 860 | else: 861 | numSplit = 1 862 | for b in range(0, numSplit): 863 | splitHts.append((topHt, bottomHt)) 864 | bottomHt += 1 865 | num_child_node_added_H += 1 866 | 867 | if self.param.remove_symmetric_children_of_or_node == 1 and self.param.mark_symmetric_syntatic_subgraph == False: 868 | useSplitHts = [] 869 | for (topHt, bottomHt) in splitHts: 870 | if (bottomHt, topHt) not in useSplitHts: 871 | useSplitHts.append((topHt, bottomHt)) 872 | elif self.param.remove_symmetric_children_of_or_node == 2 and self.param.mark_symmetric_syntatic_subgraph == False: 873 | useSplitHts = [] 874 | for (topHt, bottomHt) in reversed(splitHts): 875 | if (bottomHt, topHt) not in useSplitHts: 876 | useSplitHts.append((topHt, bottomHt)) 877 | else: 878 | useSplitHts = splitHts 879 | 880 | for (topHt, bottomHt) in useSplitHts: 881 | node = 
Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx, 882 | split_type=SplitType.HorSplit, 883 | split_step1=topHt, split_step2=curHt - bottomHt) 884 | suc, node = self._AddNode(node) 885 | if suc: 886 | BFS.append(node.id) 887 | childIds.append(node.id) 888 | 889 | # vertical splits 890 | step = stepV 891 | num_child_node_added_V = 0 892 | 893 | splitWds = [] 894 | for leftWd in range(step, curWd - step + 1, stepV_step): 895 | if num_child_node_added_V >= num_child_node_to_add_V: 896 | break 897 | 898 | rightWd = curWd - leftWd 899 | if self.param.overlap_ratio > 0: 900 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio)) 901 | else: 902 | numSplit = 1 903 | for r in range(0, numSplit): 904 | splitWds.append((leftWd, rightWd)) 905 | rightWd += 1 906 | num_child_node_added_V += 1 907 | 908 | if self.param.remove_symmetric_children_of_or_node == 1 and self.param.mark_symmetric_syntatic_subgraph == False: 909 | useSplitWds = [] 910 | for (leftWd, rightWd) in splitWds: 911 | if (rightWd, leftWd) not in useSplitWds: 912 | useSplitWds.append((leftWd, rightWd)) 913 | elif self.param.remove_symmetric_children_of_or_node == 2 and self.param.mark_symmetric_syntatic_subgraph == False: 914 | useSplitWds = [] 915 | for (leftWd, rightWd) in reversed(splitWds): 916 | if (rightWd, leftWd) not in useSplitWds: 917 | useSplitWds.append((leftWd, rightWd)) 918 | else: 919 | useSplitWds = splitWds 920 | 921 | for (leftWd, rightWd) in useSplitWds: 922 | node = Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx, 923 | split_type=SplitType.VerSplit, 924 | split_step1=leftWd, split_step2=curWd - rightWd) 925 | suc, node = self._AddNode(node) 926 | if suc: 927 | BFS.append(node.id) 928 | childIds.append(node.id) 929 | 930 | elif curNode.node_type == NodeType.AndNode: 931 | # add two child OrNodes 932 | if curNode.split_type == SplitType.HorSplit: 933 | top = Rect(x1=curRect.x1, y1=curRect.y1, 934 | x2=curRect.x2, y2=curRect.y1 + curNode.split_step1 - 1) 935 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(top)) 936 | suc, node = self._AddNode(node) 937 | if suc: 938 | BFS.append(node.id) 939 | childIds.append(node.id) 940 | 941 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + curNode.split_step2, 942 | x2=curRect.x2, y2=curRect.y2) 943 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(bottom)) 944 | suc, node = self._AddNode(node) 945 | if suc: 946 | BFS.append(node.id) 947 | childIds.append(node.id) 948 | elif curNode.split_type == SplitType.VerSplit: 949 | left = Rect(curRect.x1, curRect.y1, 950 | curRect.x1 + curNode.split_step1 - 1, curRect.y2) 951 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(left)) 952 | suc, node = self._AddNode(node) 953 | if suc: 954 | BFS.append(node.id) 955 | childIds.append(node.id) 956 | 957 | right = Rect(curRect.x1 + curNode.split_step2, curRect.y1, 958 | curRect.x2, curRect.y2) 959 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(right)) 960 | suc, node = self._AddNode(node) 961 | if suc: 962 | BFS.append(node.id) 963 | childIds.append(node.id) 964 | 965 | childIds = list(set(childIds)) 966 | self.node_set[curId].child_ids = childIds 967 | 968 | # add root terminal node if allowed 969 | root_id = 0 970 | 971 | # create And-nodes with more than 2 children 972 | # TODO: handle remove_symmetric_child_node 973 | if self.param.max_split > 2: 974 | for branch in range(3, self.param.max_split + 1): 975 | for node in self.node_set: 976 | if node.node_type != NodeType.OrNode: 977 | continue 
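                # Note (added): for max_split > 2, each (branch-1)-way AND-node is grown
                # into a branch-way AND-node by further splitting one of its child
                # OR-node rectangles; the new AND-node keeps the remaining children.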
978 | 979 | new_and_ids = [] 980 | 981 | for cur_id in node.child_ids: 982 | cur_and = self.node_set[cur_id] 983 | if len(cur_and.child_ids) != branch - 1: 984 | continue 985 | assert cur_and.node_type == NodeType.AndNode 986 | 987 | for ch_id in cur_and.child_ids: 988 | ch = self.node_set[ch_id] 989 | curRect = self.primitive_set[ch.rect_idx] 990 | curWd = curRect.Width() 991 | curHt = curRect.Height() 992 | 993 | # split ch into two to create new And-nodes 994 | 995 | # add all AndNodes for horizontal and vertical binary splits 996 | if not self._DoSplit(curRect): 997 | continue 998 | 999 | # horizontal splits 1000 | step = self._SplitStep(curWd) 1001 | for topHt in range(step, curHt - step + 1): 1002 | bottomHt = curHt - topHt 1003 | if self.param.overlap_ratio > 0: 1004 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio)) 1005 | else: 1006 | numSplit = 1 1007 | for b in range(0, numSplit): 1008 | split_step1 = topHt 1009 | split_step2 = curHt - bottomHt 1010 | 1011 | top = Rect(x1=curRect.x1, y1=curRect.y1, 1012 | x2=curRect.x2, y2=curRect.y1 + split_step1 - 1) 1013 | top_id = self._FindNodeIdWithGivenRect(top, NodeType.OrNode) 1014 | if top_id == -1: 1015 | continue 1016 | # assert top_id != -1 1017 | 1018 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + split_step2, 1019 | x2=curRect.x2, y2=curRect.y2) 1020 | bottom_id = self._FindNodeIdWithGivenRect(bottom, NodeType.OrNode) 1021 | if bottom_id == -1: 1022 | continue 1023 | # assert bottom_id != -1 1024 | 1025 | # add a new And-node 1026 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx) 1027 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [top_id, 1028 | bottom_id] 1029 | 1030 | suc, new_and = self._AddNode(new_and) 1031 | new_and_ids.append(new_and.id) 1032 | 1033 | bottomHt += 1 1034 | 1035 | # vertical splits 1036 | step = self._SplitStep(curHt) 1037 | for leftWd in range(step, curWd - step + 1): 1038 | rightWd = curWd - leftWd 1039 | 1040 | if self.param.overlap_ratio > 0: 1041 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio)) 1042 | else: 1043 | numSplit = 1 1044 | for r in range(0, numSplit): 1045 | split_step1 = leftWd 1046 | split_step2 = curWd - rightWd 1047 | 1048 | left = Rect(curRect.x1, curRect.y1, 1049 | curRect.x1 + split_step1 - 1, curRect.y2) 1050 | left_id = self._FindNodeIdWithGivenRect(left, NodeType.OrNode) 1051 | if left_id == -1: 1052 | continue 1053 | # assert left_id != -1 1054 | 1055 | right = Rect(curRect.x1 + split_step2, curRect.y1, 1056 | curRect.x2, curRect.y2) 1057 | right_id = self._FindNodeIdWithGivenRect(right, NodeType.OrNode) 1058 | if right_id == -1: 1059 | continue 1060 | # assert right_id != -1 1061 | 1062 | # add a new And-node 1063 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx) 1064 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [left_id, 1065 | right_id] 1066 | 1067 | suc, new_and = self._AddNode(new_and) 1068 | new_and_ids.append(new_and.id) 1069 | 1070 | rightWd += 1 1071 | 1072 | self.node_set[node.id].child_ids = list(set(self.node_set[node.id].child_ids + new_and_ids)) 1073 | 1074 | self._Postprocessing(root_id) 1075 | 1076 | # change tnodes to child nodes of and-nodes / or-nodes 1077 | if self.param.use_tnode_as_alpha_channel > 0: 1078 | node_type = NodeType.OrNode if self.param.use_tnode_as_alpha_channel==1 else NodeType.AndNode 1079 | not_create_if_existed = not self.param.use_tnode_as_alpha_channel==1 1080 | for id_ in self.BFS: 1081 | node = self.node_set[id_] 1082 | if 
node.node_type == NodeType.OrNode and len(node.child_ids) > 1: 1083 | for ch in node.child_ids: 1084 | ch_node = self.node_set[ch] 1085 | if ch_node.node_type == NodeType.TerminalNode: 1086 | new_parent_node = Node(node_type=node_type, rect_idx=ch_node.rect_idx) 1087 | _, new_parent_node = self._AddNode(new_parent_node, not_create_if_existed) 1088 | new_parent_node.child_ids = [ch_node.id, node.id] 1089 | 1090 | for pr in node.parent_ids: 1091 | pr_node = self.node_set[pr] 1092 | for i, pr_ch in enumerate(pr_node.child_ids): 1093 | if pr_ch == node.id: 1094 | pr_node.child_ids[i] = new_parent_node.id 1095 | break 1096 | 1097 | self.node_set[id_].child_ids.remove(ch) 1098 | if id_ == self.BFS[0]: 1099 | root_id = new_parent_node.id 1100 | break 1101 | 1102 | self._Postprocessing(root_id) 1103 | 1104 | # add super-or node 1105 | if self.param.use_super_OrNode: 1106 | super_or_node = Node(node_type=NodeType.OrNode, rect_idx=-1) 1107 | _, super_or_node = self._AddNode(super_or_node) 1108 | super_or_node.child_ids = [] 1109 | for node in self.node_set: 1110 | if node.node_type == NodeType.OrNode and node.rect_idx != -1: 1111 | rect = self.primitive_set[node.rect_idx] 1112 | r = float(rect.Area()) / float(self.param.grid_ht * self.param.grid_wd) 1113 | if r > 0.5: 1114 | super_or_node.child_ids.append(node.id) 1115 | 1116 | root_id = super_or_node.id 1117 | 1118 | self._Postprocessing(root_id) 1119 | 1120 | # remove or-nodes with single child node 1121 | if self.param.remove_single_child_or_node: 1122 | remove_ids = [] 1123 | for node in self.node_set: 1124 | if node.node_type == NodeType.OrNode and len(node.child_ids) == 1: 1125 | for pr in node.parent_ids: 1126 | pr_node = self.node_set[pr] 1127 | for i, pr_ch in enumerate(pr_node.child_ids): 1128 | if pr_ch == node.id: 1129 | pr_node.child_ids[i] = node.child_ids[0] 1130 | break 1131 | 1132 | remove_ids.append(node.id) 1133 | node.child_ids = [] 1134 | 1135 | remove_ids.sort() 1136 | remove_ids.reverse() 1137 | 1138 | for id_ in remove_ids: 1139 | for node in self.node_set: 1140 | if node.id > id_: 1141 | node.id -= 1 1142 | for i, ch in enumerate(node.child_ids): 1143 | if ch > id_: 1144 | node.child_ids[i] -= 1 1145 | 1146 | if root_id > id_: 1147 | root_id -= 1 1148 | 1149 | for id_ in remove_ids: 1150 | del self.node_set[id_] 1151 | 1152 | self._Postprocessing(root_id) 1153 | 1154 | # mark symmetric nodes 1155 | if self.param.mark_symmetric_syntatic_subgraph: 1156 | self._mark_symmetric_subgraph() 1157 | 1158 | # add tnode hierarchy 1159 | if self.param.use_tnode_topdown_connection: 1160 | self._add_tnode_topdown_connection() 1161 | self._Postprocessing(root_id) 1162 | elif self.param.use_tnode_bottomup_connection: 1163 | self._add_tnode_bottomup_connection() 1164 | self._Postprocessing(root_id) 1165 | elif self.param.use_tnode_bottomup_connection_layerwise: 1166 | self._add_node_bottomup_connection_layerwise() 1167 | self._Postprocessing(root_id) 1168 | elif self.param.use_tnode_bottomup_connection_sequential: 1169 | self._add_tnode_bottomup_connection_sequential() 1170 | self._Postprocessing(root_id) 1171 | elif self.param.use_node_lateral_connection or self.param.use_node_lateral_connection_1: 1172 | root_id = self._add_lateral_connection() 1173 | self._Postprocessing(root_id) 1174 | 1175 | # index of Or-nodes in BFS 1176 | self.OrNodeIdxInBFS = {} 1177 | self.TNodeIdxInBFS = {} 1178 | idx_or = 0 1179 | idx_t = 0 1180 | for id_ in self.BFS: 1181 | node = self.node_set[id_] 1182 | if node.node_type == NodeType.OrNode: 1183 | 
self.OrNodeIdxInBFS[node.id] = idx_or 1184 | idx_or += 1 1185 | elif node.node_type == NodeType.TerminalNode: 1186 | self.TNodeIdxInBFS[node.id] = idx_t 1187 | idx_t += 1 1188 | 1189 | # get DFS and BFS rooted at each node 1190 | for node in self.node_set: 1191 | if node.node_type == NodeType.TerminalNode: 1192 | continue 1193 | visited = np.zeros(len(self.node_set)) 1194 | self.node_DFS[node.id] = [] 1195 | self.node_DFS[node.id], _ = self._DFS(node.id, self.node_DFS[node.id], visited) 1196 | 1197 | visited = np.zeros(len(self.node_set)) 1198 | self.node_BFS[node.id] = [] 1199 | self.node_BFS[node.id], _ = self._BFS(node.id, self.node_BFS[node.id], visited) 1200 | 1201 | # count paths between nodes and root node 1202 | for n in self.node_set: 1203 | npaths = { x.id : 0 for x in self.node_set } 1204 | self.node_set[n.id].npaths = self._countPaths(self.node_set[self.BFS[0]], n, npaths) 1205 | 1206 | # find ornode with double-counting children 1207 | self._find_dbl_counting_or_nodes() 1208 | 1209 | # generate colors for terminal nodes for consistency in visualization 1210 | self.TNodeColors = {} 1211 | for node in self.node_set: 1212 | if node.node_type == NodeType.TerminalNode: 1213 | self.TNodeColors[node.id] = ( 1214 | random.random(), random.random(), random.random()) # generate a random color 1215 | 1216 | 1217 | def TurnOnOffNodes(self, on_off): 1218 | for i in range(len(self.node_set)): 1219 | self.node_set[i].on_off = on_off 1220 | 1221 | def UpdateOnOffNodes(self, pg, offset_using_part_type, class_name=''): 1222 | BFS = [self.BFS[0]] 1223 | pg_used = np.ones((1, len(pg)), dtype=np.int) * -1 1224 | configuration = [] 1225 | tnode_offset_indx = [] 1226 | while len(BFS): 1227 | id = BFS.pop() 1228 | node = self.node_set[id] 1229 | self.node_set[id].on_off = True 1230 | if len(class_name): 1231 | if class_name in node.which_classes_visited.keys(): 1232 | self.node_set[id].which_classes_visited[class_name] += 1.0 1233 | else: 1234 | self.node_set[id].which_classes_visited[class_name] = 0 1235 | 1236 | if node.node_type == NodeType.OrNode: 1237 | idx = self.OrNodeIdxInBFS[node.id] 1238 | BFS.append(node.child_ids[int(pg[idx])]) 1239 | pg_used[0, idx] = int(pg[idx]) 1240 | if len(self.node_set[id].out_edge_visited_count): 1241 | self.node_set[id].out_edge_visited_count[int(pg[idx])] += 1.0 1242 | else: 1243 | self.node_set[id].out_edge_visited_count = np.zeros((len(node.child_ids),), dtype=np.float32) 1244 | elif node.node_type == NodeType.AndNode: 1245 | BFS += node.child_ids 1246 | 1247 | else: 1248 | configuration.append(node.id) 1249 | 1250 | offset_ind = 0 1251 | if not offset_using_part_type: 1252 | for node1 in self.node_set: 1253 | if node1.node_type == NodeType.TerminalNode: # change to BFS after _part_instance is changed to BFS 1254 | if node1.id == node.id: 1255 | break 1256 | offset_ind += 1 1257 | else: 1258 | rect = self.primitive_set[node.rect_idx] 1259 | offset_ind = self.part_type.index([rect.Height(), rect.Width()]) 1260 | 1261 | tnode_offset_indx.append(offset_ind) 1262 | 1263 | configuration.sort() 1264 | cfg = np.ones((1, self.num_TNodes), dtype=np.int) * -1 1265 | cfg[0, :len(configuration)] = configuration 1266 | return pg_used, cfg, tnode_offset_indx 1267 | 1268 | def ResetOutEdgeVisitedCountNodes(self): 1269 | for i in range(len(self.node_set)): 1270 | self.node_set[i].out_edge_visited_count = [] 1271 | 1272 | def NormalizeOutEdgeVisitedCountNodes(self, count=0): 1273 | if count == 0: 1274 | for i in range(len(self.node_set)): 1275 | if 
len(self.node_set[i].out_edge_visited_count): 1276 | count = max(count, max(self.node_set[i].out_edge_visited_count)) 1277 | 1278 | if count == 0: 1279 | return 1280 | 1281 | for i in range(len(self.node_set)): 1282 | if len(self.node_set[i].out_edge_visited_count): 1283 | self.node_set[i].out_edge_visited_count /= count 1284 | 1285 | def ResetWhichClassesVisitedNodes(self): 1286 | for i in range(len(self.node_set)): 1287 | self.node_set[i].which_classes_visited = {} 1288 | 1289 | def NormalizeWhichClassesVisitedNodes(self, class_name, count): 1290 | assert count > 0 1291 | for i in range(len(self.node_set)): 1292 | if class_name in self.node_set[i].which_classes_visited.keys(): 1293 | self.node_set[i].which_classes_visited[class_name] /= count 1294 | -------------------------------------------------------------------------------- /backbones/aognet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNets/cddba00e97d6f74d2e7bfce50fd7aea9630ff996/backbones/aognet/__init__.py -------------------------------------------------------------------------------- /backbones/aognet/aognet_singlescale.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function # force to use print as function print(args) 38 | from __future__ import unicode_literals 39 | 40 | import torch 41 | import torch.nn as nn 42 | import torch.nn.functional as F 43 | from torch.autograd import Variable 44 | 45 | import scipy.stats as stats 46 | 47 | from .config import cfg 48 | from .AOG import * 49 | from .operator_basic import * 50 | from .operator_singlescale import * 51 | 52 | ### AOG building block 53 | class AOGBlock(nn.Module): 54 | def __init__(self, stage, block, aog, in_channels, out_channels, drop_rate, stride): 55 | super(AOGBlock, self).__init__() 56 | self.stage = stage 57 | self.block = block 58 | self.aog = aog 59 | self.in_channels = in_channels 60 | self.out_channels = out_channels 61 | self.drop_rate = drop_rate 62 | self.stride = stride 63 | 64 | self.dim = aog.param.grid_wd 65 | self.in_slices = self._calculate_slices(self.dim, in_channels) 66 | self.out_slices = self._calculate_slices(self.dim, out_channels) 67 | 68 | self.node_set = aog.node_set 69 | self.primitive_set = aog.primitive_set 70 | self.BFS = aog.BFS 71 | self.DFS = aog.DFS 72 | 73 | self.hasLateral = {} 74 | self.hasDblCnt = {} 75 | 76 | self.primitiveDblCnt = None 77 | self._set_primitive_dbl_cnt() 78 | 79 | if "BatchNorm2d" in cfg.norm_name: 80 | self.norm_name_base = "BatchNorm2d" 81 | elif "GroupNorm" in cfg.norm_name: 82 | self.norm_name_base = "GroupNorm" 83 | else: 84 | raise ValueError("Unknown norm layer") 85 | 86 | self._set_weights_attr() 87 | 88 | self.extra_norm_ac = self._extra_norm_ac() 89 | 90 | def _calculate_slices(self, dim, channels): 91 | slices = [0] * dim 92 | for i in range(channels): 93 | slices[i % dim] += 1 94 | for d in range(1, dim): 95 | slices[d] += slices[d - 1] 96 | slices = [0] + slices 97 | return slices 98 | 99 | def _set_primitive_dbl_cnt(self): 100 | self.primitiveDblCnt = [0.0 for i in range(self.dim)] 101 | for id_ in self.DFS: 102 | node = self.node_set[id_] 103 | arr = self.primitive_set[node.rect_idx] 104 | if node.node_type == NodeType.TerminalNode: 105 | for i in range(arr.x1, arr.x2+1): 106 | self.primitiveDblCnt[i] += node.npaths 107 | for i in range(self.dim): 108 | assert self.primitiveDblCnt[i] >= 1.0 109 | 110 | def _create_op(self, node_id, cin, cout, stride, groups=1, 111 | keep_norm_base=False, norm_k=0): 112 | replace_stride = cfg.aognet.replace_stride_with_avgpool 113 | setattr(self, 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node_id), 114 | NodeOpSingleScale(cin, cout, stride, 115 | groups=groups, drop_rate=self.drop_rate, 116 | ac_mode=cfg.activation_mode, 117 | bn_ratio=cfg.aognet.bottleneck_ratio, 118 | norm_name=self.norm_name_base if keep_norm_base else cfg.norm_name, 119 | norm_groups=cfg.norm_groups, 120 | norm_k = norm_k, 121 | norm_attention_mode=cfg.norm_attention_mode, 122 | replace_stride_with_avgpool=replace_stride)) 123 | 124 | def 
_set_weights_attr(self): 125 | for id_ in self.DFS: 126 | node = self.node_set[id_] 127 | arr = self.primitive_set[node.rect_idx] 128 | bn_ratio = cfg.aognet.bottleneck_ratio 129 | width_per_group = cfg.aognet.width_per_group 130 | keep_norm_base = arr.Width() ch_arr.Width(): 160 | if node.npaths / self.node_set[chid].npaths != 1.0: 161 | self.hasDblCnt[node.id] = True 162 | break 163 | self._create_op(node.id, plane, plane, stride, groups=groups, 164 | keep_norm_base=keep_norm_base, norm_k=norm_k) 165 | 166 | elif node.node_type == NodeType.OrNode: 167 | assert self.node_set[node.child_ids[0]].node_type != NodeType.OrNode 168 | plane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1] 169 | stride = 1 170 | groups = max(1, to_int(plane * bn_ratio / width_per_group)) \ 171 | if cfg.aognet.use_group_conv else 1 172 | self.hasLateral[node.id] = False 173 | self.hasDblCnt[node.id] = False 174 | for chid in node.child_ids: 175 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 176 | if self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width(): 177 | self.hasLateral[node.id] = True 178 | break 179 | if cfg.aognet.handle_dbl_cnt: 180 | for chid in node.child_ids: 181 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 182 | if not (self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width()): 183 | if node.npaths / self.node_set[chid].npaths != 1.0: 184 | self.hasDblCnt[node.id] = True 185 | break 186 | self._create_op(node.id, plane, plane, stride, groups=groups, 187 | keep_norm_base=keep_norm_base, norm_k=norm_k) 188 | 189 | def _extra_norm_ac(self): 190 | return nn.Sequential(FeatureNorm(self.norm_name_base, self.out_channels, 191 | cfg.norm_groups, cfg.norm_k[self.stage], 192 | cfg.norm_attention_mode), 193 | AC(cfg.activation_mode)) 194 | 195 | def forward(self, x): 196 | NodeIdTensorDict = {} 197 | 198 | # handle input x 199 | tnode_dblcnt = False 200 | if cfg.aognet.handle_tnode_dbl_cnt and self.in_channels==self.out_channels: 201 | x_scaled = [] 202 | for i in range(self.dim): 203 | left, right = self.in_slices[i], self.in_slices[i+1] 204 | x_scaled.append(x[:, left:right, :, :].div(self.primitiveDblCnt[i])) 205 | xx = torch.cat(x_scaled, 1) 206 | tnode_dblcnt = True 207 | 208 | # T-nodes, (hope they will be computed in parallel by pytorch) 209 | for id_ in self.DFS: 210 | node = self.node_set[id_] 211 | op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id) 212 | if node.node_type == NodeType.TerminalNode: 213 | arr = self.primitive_set[node.rect_idx] 214 | right, left = self.in_slices[arr.x2 + 1], self.in_slices[arr.x1] 215 | tnode_tensor_op = x if cfg.aognet.terminal_node_no_slice[self.stage] else x[:, left:right, :, :] #.contiguous() 216 | # assert tnode_tensor.requires_grad, 'slice needs to retain grad' 217 | if tnode_dblcnt: 218 | tnode_tensor_res = xx[:, left:right, :, :].mul(node.npaths) 219 | tnode_output = getattr(self, op_name)(tnode_tensor_op, tnode_tensor_res) 220 | else: 221 | tnode_output = getattr(self, op_name)(tnode_tensor_op) 222 | NodeIdTensorDict[node.id] = tnode_output 223 | 224 | # AND- and OR-nodes 225 | for id_ in self.DFS: 226 | node = self.node_set[id_] 227 | arr = self.primitive_set[node.rect_idx] 228 | op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id) 229 | if node.node_type == NodeType.AndNode: 230 | if self.hasDblCnt[node.id]: 231 | child_tensor_res = [] 232 | child_tensor_op = [] 233 | for chid in node.child_ids: 234 | ch_arr = 
self.primitive_set[self.node_set[chid].rect_idx] 235 | if arr.Width() > ch_arr.Width(): 236 | factor = node.npaths / self.node_set[chid].npaths 237 | if factor == 1.0: 238 | child_tensor_res.append(NodeIdTensorDict[chid]) 239 | else: 240 | child_tensor_res.append(NodeIdTensorDict[chid].mul(factor)) 241 | child_tensor_op.append(NodeIdTensorDict[chid]) 242 | 243 | anode_tensor_res = torch.cat(child_tensor_res, 1) 244 | anode_tensor_op = torch.cat(child_tensor_op, 1) 245 | 246 | if self.hasLateral[node.id]: 247 | ids1 = set(node.parent_ids) 248 | num_shared = 0 249 | for chid in node.child_ids: 250 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 251 | ids2 = self.node_set[chid].parent_ids 252 | if arr.Width() == ch_arr.Width(): 253 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 254 | if len(ids1.intersection(ids2)) == num_shared: 255 | anode_tensor_res = anode_tensor_res + NodeIdTensorDict[chid] 256 | 257 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res) 258 | else: 259 | child_tensor = [] 260 | for chid in node.child_ids: 261 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 262 | if arr.Width() > ch_arr.Width(): 263 | child_tensor.append(NodeIdTensorDict[chid]) 264 | 265 | anode_tensor_op = torch.cat(child_tensor, 1) 266 | 267 | if self.hasLateral[node.id]: 268 | ids1 = set(node.parent_ids) 269 | num_shared = 0 270 | for chid in node.child_ids: 271 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 272 | ids2 = self.node_set[chid].parent_ids 273 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) == num_shared: 274 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 275 | 276 | anode_tensor_res = anode_tensor_op 277 | 278 | for chid in node.child_ids: 279 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 280 | ids2 = self.node_set[chid].parent_ids 281 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) > num_shared: 282 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 283 | 284 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res) 285 | else: 286 | anode_output = getattr(self, op_name)(anode_tensor_op) 287 | 288 | NodeIdTensorDict[node.id] = anode_output 289 | 290 | elif node.node_type == NodeType.OrNode: 291 | if self.hasDblCnt[node.id]: 292 | factor = node.npaths / self.node_set[node.child_ids[0]].npaths 293 | if factor == 1.0: 294 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]] 295 | else: 296 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]].mul(factor) 297 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 298 | for chid in node.child_ids[1:]: 299 | if self.node_set[chid].node_type != NodeType.OrNode: 300 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 301 | if arr.Width() == ch_arr.Width(): 302 | factor = node.npaths / self.node_set[chid].npaths 303 | if factor == 1.0: 304 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 305 | else: 306 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid].mul(factor) 307 | if cfg.aognet.use_elem_max_for_ORNodes: 308 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 309 | else: 310 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 311 | 312 | if self.hasLateral[node.id]: 313 | ids1 = set(node.parent_ids) 314 | num_shared = 0 315 | for chid in node.child_ids[1:]: 316 | ids2 = self.node_set[chid].parent_ids 317 | if self.node_set[chid].node_type == NodeType.OrNode and \ 318 | len(ids1.intersection(ids2)) == 
num_shared: 319 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 320 | if cfg.aognet.use_elem_max_for_ORNodes: 321 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 322 | else: 323 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 324 | 325 | for chid in node.child_ids[1:]: 326 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 327 | ids2 = self.node_set[chid].parent_ids 328 | if self.node_set[chid].node_type == NodeType.OrNode and \ 329 | len(ids1.intersection(ids2)) > num_shared: 330 | if cfg.aognet.use_elem_max_for_ORNodes: 331 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 332 | else: 333 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 334 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 335 | arr.Width() < ch_arr.Width(): 336 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 337 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 338 | if cfg.aognet.use_elem_max_for_ORNodes: 339 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]) 340 | else: 341 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]#.contiguous() 342 | 343 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 344 | else: 345 | if cfg.aognet.use_elem_max_for_ORNodes: 346 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 347 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]] 348 | for chid in node.child_ids[1:]: 349 | if self.node_set[chid].node_type != NodeType.OrNode: 350 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 351 | if arr.Width() == ch_arr.Width(): 352 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 353 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 354 | 355 | if self.hasLateral[node.id]: 356 | ids1 = set(node.parent_ids) 357 | num_shared = 0 358 | for chid in node.child_ids[1:]: 359 | ids2 = self.node_set[chid].parent_ids 360 | if self.node_set[chid].node_type == NodeType.OrNode and \ 361 | len(ids1.intersection(ids2)) == num_shared: 362 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 363 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 364 | 365 | for chid in node.child_ids[1:]: 366 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 367 | ids2 = self.node_set[chid].parent_ids 368 | if self.node_set[chid].node_type == NodeType.OrNode and \ 369 | len(ids1.intersection(ids2)) > num_shared: 370 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 371 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 372 | arr.Width() < ch_arr.Width(): 373 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 374 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 375 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]) 376 | 377 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 378 | else: 379 | onode_output = getattr(self, op_name)(onode_tensor_op) 380 | else: 381 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 382 | for chid in node.child_ids[1:]: 383 | if self.node_set[chid].node_type != NodeType.OrNode: 384 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 385 | if arr.Width() == ch_arr.Width(): 386 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 387 | 388 | if self.hasLateral[node.id]: 389 | ids1 = 
set(node.parent_ids) 390 | num_shared = 0 391 | for chid in node.child_ids[1:]: 392 | ids2 = self.node_set[chid].parent_ids 393 | if self.node_set[chid].node_type == NodeType.OrNode and \ 394 | len(ids1.intersection(ids2)) == num_shared: 395 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 396 | 397 | onode_tensor_res = onode_tensor_op 398 | 399 | for chid in node.child_ids[1:]: 400 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 401 | ids2 = self.node_set[chid].parent_ids 402 | if self.node_set[chid].node_type == NodeType.OrNode and \ 403 | len(ids1.intersection(ids2)) > num_shared: 404 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 405 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 406 | arr.Width() < ch_arr.Width(): 407 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 408 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 409 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]#.contiguous() 410 | 411 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 412 | else: 413 | onode_output = getattr(self, op_name)(onode_tensor_op) 414 | 415 | NodeIdTensorDict[node.id] = onode_output 416 | 417 | out = NodeIdTensorDict[self.aog.BFS[0]] 418 | out = self.extra_norm_ac(out) #TODO: Why this? Analyze it in depth 419 | return out 420 | 421 | ### AOGNet 422 | class AOGNet(nn.Module): 423 | def __init__(self, block=AOGBlock): 424 | super(AOGNet, self).__init__() 425 | filter_list = cfg.aognet.filter_list 426 | self.aogs = self._create_aogs() 427 | self.block = block 428 | if "BatchNorm2d" in cfg.norm_name: 429 | self.norm_name_base = "BatchNorm2d" 430 | elif "GroupNorm" in cfg.norm_name: 431 | self.norm_name_base = "GroupNorm" 432 | else: 433 | raise ValueError("Unknown norm layer") 434 | 435 | if "Mixture" in cfg.norm_name: 436 | assert len(cfg.norm_k) == len(filter_list)-1 and any(cfg.norm_k), \ 437 | "Wrong mixture component specification (cfg.norm_k)" 438 | else: 439 | cfg.norm_k = [0 for i in range(len(filter_list)-1)] 440 | 441 | self.stem = self._stem(filter_list[0]) 442 | 443 | self.stage0 = self._make_stage(stage=0, in_channels=filter_list[0], out_channels=filter_list[1]) 444 | self.stage1 = self._make_stage(stage=1, in_channels=filter_list[1], out_channels=filter_list[2]) 445 | self.stage2 = self._make_stage(stage=2, in_channels=filter_list[2], out_channels=filter_list[3]) 446 | self.stage3 = None 447 | outchannels = filter_list[3] 448 | if cfg.dataset == 'imagenet': 449 | self.stage3 = self._make_stage(stage=3, in_channels=filter_list[3], out_channels=filter_list[4]) 450 | outchannels = filter_list[4] 451 | 452 | self.conv_head = None 453 | if any(cfg.aognet.out_channels): 454 | assert len(cfg.aognet.out_channels) == 2 455 | self.conv_head = nn.Sequential(Conv_Norm_AC(outchannels, cfg.aognet.out_channels[0], 1, 1, 0, 456 | ac_mode=cfg.activation_mode, 457 | norm_name=self.norm_name_base, 458 | norm_groups=cfg.norm_groups, 459 | norm_k=cfg.norm_k[-1], 460 | norm_attention_mode=cfg.norm_attention_mode), 461 | nn.AdaptiveAvgPool2d((1, 1)), 462 | Conv_Norm_AC(cfg.aognet.out_channels[0], cfg.aognet.out_channels[1], 1, 1, 0, 463 | ac_mode=cfg.activation_mode, 464 | norm_name=self.norm_name_base, 465 | norm_groups=cfg.norm_groups, 466 | norm_k=cfg.norm_k[-1], 467 | norm_attention_mode=cfg.norm_attention_mode) 468 | ) 469 | outchannels = cfg.aognet.out_channels[1] 470 | else: 471 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 472 | self.fc = 
nn.Linear(outchannels, cfg.num_classes) 473 | 474 | ## initialize 475 | self._init_params() 476 | 477 | def _stem(self, cout): 478 | layers = [] 479 | if cfg.dataset == 'imagenet': 480 | if cfg.stem.imagenet_head7x7: 481 | layers.append( Conv_Norm_AC(3, cout, 7, 2, 3, 482 | ac_mode=cfg.activation_mode, 483 | norm_name=self.norm_name_base, 484 | norm_groups=cfg.norm_groups, 485 | norm_k=cfg.norm_k[0], 486 | norm_attention_mode=cfg.norm_attention_mode) ) 487 | else: 488 | plane = cout // 2 489 | layers.append( Conv_Norm_AC(3, plane, 3, 2, 1, 490 | ac_mode=cfg.activation_mode, 491 | norm_name=self.norm_name_base, 492 | norm_groups=cfg.norm_groups, 493 | norm_k=cfg.norm_k[0], 494 | norm_attention_mode=cfg.norm_attention_mode) ) 495 | layers.append( Conv_Norm_AC(plane, plane, 3, 1, 1, 496 | ac_mode=cfg.activation_mode, 497 | norm_name=self.norm_name_base, 498 | norm_groups=cfg.norm_groups, 499 | norm_k=cfg.norm_k[0], 500 | norm_attention_mode=cfg.norm_attention_mode) ) 501 | layers.append( Conv_Norm_AC(plane, cout, 3, 1, 1, 502 | ac_mode=cfg.activation_mode, 503 | norm_name=self.norm_name_base, 504 | norm_groups=cfg.norm_groups, 505 | norm_k=cfg.norm_k[0], 506 | norm_attention_mode=cfg.norm_attention_mode) ) 507 | if cfg.stem.replace_maxpool_with_res_bottleneck: 508 | layers.append( NodeOpSingleScale(cout, cout, 2, 509 | ac_mode=cfg.activation_mode, 510 | bn_ratio=cfg.aognet.bottleneck_ratio, 511 | norm_name=self.norm_name_base, 512 | norm_groups=cfg.norm_groups, 513 | norm_k = cfg.norm_k[0], 514 | norm_attention_mode=cfg.norm_attention_mode, 515 | replace_stride_with_avgpool=True) ) # used in OctConv 516 | else: 517 | layers.append( nn.MaxPool2d(2, 2) ) 518 | elif cfg.dataset == 'cifar10' or cfg.dataset == 'cifar100': 519 | layers.append( Conv_Norm_AC(3, cout, 3, 1, 1, 520 | ac_mode=cfg.activation_mode, 521 | norm_name=self.norm_name_base, 522 | norm_groups=cfg.norm_groups, 523 | norm_k=cfg.norm_k[0], 524 | norm_attention_mode=cfg.norm_attention_mode) ) 525 | else: 526 | raise NotImplementedError 527 | 528 | return nn.Sequential(*layers) 529 | 530 | def _init_params(self): 531 | for m in self.modules(): 532 | if isinstance(m, nn.Conv2d): 533 | if cfg.init_mode == 'xavier': 534 | nn.init.xavier_normal_(m.weight) 535 | elif cfg.init_mode == 'avg': 536 | n = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels + m.out_channels) / 2 537 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 538 | else: # cfg.init_mode == 'kaiming': as default 539 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 540 | 541 | for name, _ in m.named_parameters(): 542 | if name in ['bias']: 543 | nn.init.constant_(m.bias, 0.0) 544 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)): # before BatchNorm2d 545 | nn.init.normal_(m.weight_, 1, 0.1) 546 | nn.init.normal_(m.bias_, 0, 0.1) 547 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 548 | nn.init.constant_(m.weight, 1.0) 549 | nn.init.constant_(m.bias, 0.0) 550 | 551 | # handle dbl cnt in init 552 | if cfg.aognet.handle_dbl_cnt_in_param_init: 553 | import re 554 | for name_, m in self.named_modules(): 555 | if 'node' in name_: 556 | idx = re.findall(r'\d+', name_) 557 | sid = int(idx[0]) 558 | nid = int(idx[2]) 559 | npaths = self.aogs[sid].node_set[nid].npaths 560 | if npaths > 1: 561 | scale = 1.0 / npaths 562 | with torch.no_grad(): 563 | for ch in m.modules(): 564 | if isinstance(ch, nn.Conv2d): 565 | ch.weight.mul_(scale) 566 | 567 | # TODO: handle zero-gamma in the last norm layer of bottleneck op 568 | 569 | def _create_aogs(self): 570 | aogs = [] 571 | num_stages = len(cfg.aognet.filter_list) - 1 572 | for i in range(num_stages): 573 | grid_ht = 1 574 | grid_wd = int(cfg.aognet.dims[i]) 575 | aogs.append(get_aog(grid_ht=grid_ht, grid_wd=grid_wd, max_split=cfg.aognet.max_split[i], 576 | use_tnode_topdown_connection= cfg.aognet.extra_node_hierarchy[i] == 1, 577 | use_tnode_bottomup_connection_layerwise= cfg.aognet.extra_node_hierarchy[i] == 2, 578 | use_tnode_bottomup_connection_sequential= cfg.aognet.extra_node_hierarchy[i] == 3, 579 | use_node_lateral_connection= cfg.aognet.extra_node_hierarchy[i] == 4, 580 | use_tnode_bottomup_connection= cfg.aognet.extra_node_hierarchy[i] == 5, 581 | use_node_lateral_connection_1= cfg.aognet.extra_node_hierarchy[i] == 6, 582 | remove_symmetric_children_of_or_node=cfg.aognet.remove_symmetric_children_of_or_node[i] 583 | )) 584 | 585 | return aogs 586 | 587 | def _make_stage(self, stage, in_channels, out_channels): 588 | blocks = nn.Sequential() 589 | dim = cfg.aognet.dims[stage] 590 | assert in_channels % dim == 0 and out_channels % dim == 0 591 | step_channels = (out_channels - in_channels) // cfg.aognet.blocks[stage] 592 | if step_channels % dim != 0: 593 | low = (step_channels // dim) * dim 594 | high = (step_channels // dim + 1) * dim 595 | if (step_channels-low) <= (high-step_channels): 596 | step_channels = low 597 | else: 598 | step_channels = high 599 | 600 | aog = self.aogs[stage] 601 | for j in range(cfg.aognet.blocks[stage]): 602 | name_ = 'stage_{}_block_{}'.format(stage, j) 603 | drop_rate = cfg.aognet.drop_rate[stage] 604 | stride = cfg.aognet.stride[stage] if j==0 else 1 605 | outchannels = (in_channels + step_channels) if j < cfg.aognet.blocks[stage]-1 else out_channels 606 | if stride > 1 and cfg.aognet.when_downsample == 1: 607 | blocks.add_module(name_ + '_transition', 608 | nn.Sequential( Conv_Norm_AC(in_channels, in_channels, 1, 1, 0, 609 | ac_mode=cfg.activation_mode, 610 | norm_name=self.norm_name_base, 611 | norm_groups=cfg.norm_groups, 612 | norm_k=cfg.norm_k[stage], 613 | norm_attention_mode=cfg.norm_attention_mode, 614 | replace_stride_with_avgpool=False), 615 | nn.AvgPool2d(kernel_size=(stride, stride), stride=stride) 616 | ) 617 | ) 618 | stride = 1 619 | elif (stride > 1 or in_channels != outchannels) and cfg.aognet.when_downsample == 2: 620 | trans_op = [Conv_Norm_AC(in_channels, outchannels, 1, 1, 0, 621 | 
ac_mode=cfg.activation_mode, 622 | norm_name=self.norm_name_base, 623 | norm_groups=cfg.norm_groups, 624 | norm_k=cfg.norm_k[stage], 625 | norm_attention_mode=cfg.norm_attention_mode, 626 | replace_stride_with_avgpool=False)] 627 | if stride > 1: 628 | trans_op.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride)) 629 | blocks.add_module(name_ + '_transition', nn.Sequential(*trans_op)) 630 | stride = 1 631 | in_channels = outchannels 632 | 633 | blocks.add_module(name_, self.block(stage, j, aog, in_channels, outchannels, drop_rate, stride)) 634 | in_channels = outchannels 635 | 636 | return blocks 637 | 638 | def forward(self, x): 639 | y = self.stem(x) 640 | 641 | y = self.stage0(y) 642 | y = self.stage1(y) 643 | y = self.stage2(y) 644 | if self.stage3 is not None: 645 | y = self.stage3(y) 646 | if self.conv_head is not None: 647 | y = self.conv_head(y) 648 | else: 649 | y = self.avgpool(y) 650 | y = y.view(y.size(0), -1) 651 | y = self.fc(y) 652 | 653 | return y 654 | 655 | def aognet_singlescale(**kwargs): 656 | ''' 657 | Construct a single scale AOGNet model 658 | ''' 659 | return AOGNet(**kwargs) 660 | -------------------------------------------------------------------------------- /backbones/aognet/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | _C.batch_size = 128 5 | _C.num_epoch = 300 6 | _C.dataset = 'cifar10' 7 | _C.num_classes = 10 8 | _C.crop_size = 224 # imagenet 9 | _C.crop_interpolation = 2 # 2=BILINEAR, default; 3=BICUBIC 10 | _C.optimizer = 'SGD' 11 | _C.gamma = 0.1 # decay_rate 12 | _C.use_cosine_lr = False 13 | _C.cosine_lr_min = 0.0 14 | _C.warmup_epochs = 5 15 | _C.lr = 0.1 16 | _C.lr_scale_factor = 256 # per nvidia apex 17 | _C.lr_milestones = [150, 225] 18 | _C.momentum = 0.9 19 | _C.wd = 5e-4 20 | _C.nesterov = False 21 | _C.activation_mode = 0 # 1: leakyReLU, 2: ReLU6 , other: ReLU 22 | _C.init_mode = 'kaiming' 23 | _C.norm_name = 'BatchNorm2d' 24 | _C.norm_groups = 0 25 | _C.norm_k = [0] 26 | _C.norm_attention_mode = 0 27 | _C.norm_zero_gamma_init = False 28 | _C.norm_all_mix = False 29 | 30 | # data augmentation 31 | _C.dataaug = CN() 32 | _C.dataaug.imagenet_extra_aug = False 33 | _C.dataaug.labelsmoothing_rate = 0. # 0.1 34 | _C.dataaug.mixup_rate = 0. 
# 0.2 35 | 36 | # stem 37 | _C.stem = CN() 38 | _C.stem.imagenet_head7x7 = False 39 | _C.stem.replace_maxpool_with_res_bottleneck = False 40 | _C.stem.stem_kernel_size = 7 41 | _C.stem.stem_stride = 2 42 | 43 | 44 | # aognet 45 | _C.aognet = CN() 46 | _C.aognet.filter_list = [16, 64, 128, 256] 47 | _C.aognet.out_channels = [0,0] 48 | _C.aognet.blocks = [1, 1, 1] 49 | _C.aognet.dims = [4, 4, 4] 50 | _C.aognet.max_split = [2, 2, 2] # must be >= 2 51 | _C.aognet.extra_node_hierarchy = [0, 0, 0] # 0: none, 1: tnode topdown, 2: tnode bottomup layerwise, 3: tnode bottomup sequential, 4: non-term node lateral, 5: tnode bottomup, 6: non-term node lateral (variant) 52 | _C.aognet.remove_symmetric_children_of_or_node = [0, 0, 0] 53 | _C.aognet.terminal_node_no_slice = [0, 0, 0] 54 | _C.aognet.stride = [1, 2, 2] 55 | _C.aognet.drop_rate = [0.0, 0.0, 0.0] 56 | _C.aognet.bottleneck_ratio = 0.25 57 | _C.aognet.handle_dbl_cnt = True 58 | _C.aognet.handle_tnode_dbl_cnt = False 59 | _C.aognet.handle_dbl_cnt_in_param_init = False 60 | _C.aognet.use_group_conv = False 61 | _C.aognet.width_per_group = 0 62 | _C.aognet.when_downsample = 0 # 0: at T-nodes, 1: before an AOG block via conv_norm_ac + avgpool, 2: transition conv_norm_ac (plus avgpool when strided) whenever stride > 1 or the channel count changes 63 | _C.aognet.replace_stride_with_avgpool = True # for downsample in node op. 64 | _C.aognet.use_elem_max_for_ORNodes = False 65 | 66 | cfg = _C 67 | -------------------------------------------------------------------------------- /backbones/aognet/operator_basic.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function # force to use print as function print(args) 38 | from __future__ import unicode_literals 39 | 40 | import torch 41 | import torch.nn as nn 42 | import torch.nn.functional as F 43 | 44 | _inplace = True 45 | _norm_eps = 1e-5 46 | 47 | def to_int(x): 48 | if x - int(x) < 0.5: 49 | return int(x) 50 | else: 51 | return int(x) + 1 52 | 53 | ### Activation 54 | class AC(nn.Module): 55 | def __init__(self, mode): 56 | super(AC, self).__init__() 57 | if mode == 1: 58 | self.ac = nn.LeakyReLU(inplace=_inplace) 59 | elif mode == 2: 60 | self.ac = nn.ReLU6(inplace=_inplace) 61 | else: 62 | self.ac = nn.ReLU(inplace=_inplace) 63 | 64 | def forward(self, x): 65 | x = self.ac(x) 66 | return x 67 | 68 | ### From mobilenet_v3 69 | class hsigmoid(nn.Module): 70 | def forward(self, x): 71 | out = F.relu6(x + 3, inplace=True) / 6 72 | return out 73 | 74 | ### Feature Norm 75 | def FeatureNorm(norm_name, num_channels, num_groups, num_k, attention_mode): 76 | if norm_name == "BatchNorm2d": 77 | return nn.BatchNorm2d(num_channels, eps=_norm_eps) 78 | elif norm_name == "GroupNorm": 79 | assert num_groups > 1 80 | if num_channels % num_groups != 0: 81 | raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups)) 82 | return nn.GroupNorm(num_groups, num_channels, eps=_norm_eps) # nn.GroupNorm expects (num_groups, num_channels) 83 | elif norm_name == "MixtureBatchNorm2d": 84 | assert num_k > 1 85 | return MixtureBatchNorm2d(num_channels, num_k, attention_mode) 86 | elif norm_name == "MixtureGroupNorm": 87 | assert num_groups > 1 and num_k > 1 88 | if num_channels % num_groups != 0: 89 | raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups)) 90 | return MixtureGroupNorm(num_channels, num_groups, num_k, attention_mode) 91 | else: 92 | raise NotImplementedError("Unknown feature norm name") 93 | 94 | ### Attention weights for mixture norm 95 | class AttentionWeights(nn.Module): 96 | expansion = 2 97 | def __init__(self, attention_mode, num_channels, k, 98 | norm_name=None, norm_groups=0): 99 | super(AttentionWeights, self).__init__() 100 | self.k = k 101 | self.avgpool = nn.AdaptiveAvgPool2d(1) 102 | layers = [] 103 | if attention_mode == 0: 104 | layers = [ nn.Conv2d(num_channels, k, 1), 105 | nn.Sigmoid() ] 106 | elif attention_mode == 1: 107 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1), 108 | nn.ReLU(inplace=True), 109 | nn.Conv2d(k*self.expansion, k, 1), 110 | nn.Sigmoid() ] 111 | elif attention_mode == 2: 112 | assert norm_name is not None 113 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False), 114 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 115 | hsigmoid() ] 116 | elif attention_mode
== 3: 117 | assert norm_name is not None 118 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1, bias=False), 119 | FeatureNorm(norm_name, k*self.expansion, norm_groups, 0, 0), 120 | nn.ReLU(inplace=True), 121 | nn.Conv2d(k*self.expansion, k, 1, bias=False), 122 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 123 | hsigmoid() ] 124 | else: 125 | raise NotImplementedError("Unknown attention weight type") 126 | self.attention = nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | b, c, _, _ = x.size() 130 | y = self.avgpool(x)#.view(b, c) 131 | return self.attention(y).view(b, self.k) 132 | 133 | 134 | ### Mixture Norm 135 | # TODO: keep it to use FP32 always, need to figure out how to set it using apex ? 136 | class MixtureBatchNorm2d(nn.BatchNorm2d): 137 | def __init__(self, num_channels, k, attention_mode, eps=_norm_eps, momentum=0.1, 138 | track_running_stats=True): 139 | super(MixtureBatchNorm2d, self).__init__(num_channels, eps=eps, 140 | momentum=momentum, affine=False, track_running_stats=track_running_stats) 141 | self.k = k 142 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels)) 143 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels)) 144 | 145 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k, 146 | norm_name='BatchNorm2d') 147 | 148 | self._init_params() 149 | 150 | def _init_params(self): 151 | nn.init.normal_(self.weight_, 1, 0.1) 152 | nn.init.normal_(self.bias_, 0, 0.1) 153 | 154 | def forward(self, x): 155 | output = super(MixtureBatchNorm2d, self).forward(x) 156 | size = output.size() 157 | y = self.attention_weights(x) # bxk # or use output as attention input 158 | 159 | weight = y @ self.weight_ # bxc 160 | bias = y @ self.bias_ # bxc 161 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 162 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 163 | 164 | return weight * output + bias 165 | 166 | 167 | # Modified on top of nn.GroupNorm 168 | # TODO: keep it to use FP32 always, need to figure out how to set it using apex ?
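# A minimal usage sketch for the mixture norms in this file (MixtureBatchNorm2d above,
# MixtureGroupNorm below); the tensor shape and k are hypothetical, chosen only for illustration:
#   x = torch.randn(8, 64, 32, 32)                        # NCHW feature map
#   mbn = MixtureBatchNorm2d(64, k=10, attention_mode=0)
#   y = mbn(x)                                            # same shape as x
# Per sample, the attention head emits k weights in [0, 1] (sigmoid for mode 0), and the
# affine transform applied after standardization is the attention-weighted mix of the k
# learned (weight_, bias_) rows.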
169 | class MixtureGroupNorm(nn.Module): 170 | __constants__ = ['num_groups', 'num_channels', 'k', 'eps', 'weight', 171 | 'bias'] 172 | 173 | def __init__(self, num_channels, num_groups, k, attention_mode, eps=_norm_eps): 174 | super(MixtureGroupNorm, self).__init__() 175 | self.num_groups = num_groups 176 | self.num_channels = num_channels 177 | self.k = k 178 | self.eps = eps 179 | self.affine = True 180 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels)) 181 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels)) 182 | self.register_parameter('weight', None) 183 | self.register_parameter('bias', None) 184 | 185 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k, 186 | norm_name='GroupNorm', norm_groups=1) 187 | 188 | self.reset_parameters() 189 | 190 | def reset_parameters(self): 191 | nn.init.normal_(self.weight_, 1, 0.1) 192 | nn.init.normal_(self.bias_, 0, 0.1) 193 | 194 | def forward(self, x): 195 | output = F.group_norm( 196 | x, self.num_groups, self.weight, self.bias, self.eps) 197 | size = output.size() 198 | 199 | y = self.attention_weights(x) # TODO: use output as attention input 200 | 201 | weight = y @ self.weight_ 202 | bias = y @ self.bias_ 203 | 204 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 205 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 206 | 207 | return weight * output + bias 208 | 209 | def extra_repr(self): 210 | return '{num_groups}, {num_channels}, eps={eps}, ' \ 211 | 'affine={affine}'.format(**self.__dict__) 212 | 213 | 214 | 215 | 216 | -------------------------------------------------------------------------------- /backbones/aognet/operator_singlescale.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. 
For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function # force to use print as function print(args) 38 | from __future__ import unicode_literals 39 | 40 | import torch 41 | import torch.nn as nn 42 | import torch.nn.functional as F 43 | from torch.autograd import Variable 44 | 45 | from .operator_basic import * 46 | 47 | _bias = False 48 | _inplace = True 49 | 50 | ### Conv_Norm 51 | class Conv_Norm(nn.Module): 52 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 53 | groups=1, drop_rate=0.0, 54 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 55 | replace_stride_with_avgpool=False): 56 | super(Conv_Norm, self).__init__() 57 | 58 | layers = [] 59 | if stride > 1 and replace_stride_with_avgpool: 60 | layers.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride)) 61 | stride = 1 62 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 63 | stride=stride, padding=padding, 64 | groups=groups, bias=_bias)) 65 | layers.append(FeatureNorm(norm_name, out_channels, norm_groups, norm_k, norm_attention_mode)) 66 | if drop_rate > 0.0: 67 | layers.append(nn.Dropout2d(p=drop_rate, inplace=_inplace)) 68 | self.conv_norm = nn.Sequential(*layers) 69 | 70 | def forward(self, x): 71 | y = self.conv_norm(x) 72 | return y 73 | 74 | ### Conv_Norm_AC 75 | class Conv_Norm_AC(nn.Module): 76 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 77 | groups=1, drop_rate=0., ac_mode=0, 78 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 79 | replace_stride_with_avgpool=False): 80 | super(Conv_Norm_AC, self).__init__() 81 | 82 | self.conv_norm = Conv_Norm(in_channels, out_channels, kernel_size, stride, padding, 83 | groups=groups, drop_rate=drop_rate, 84 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 85 | replace_stride_with_avgpool=replace_stride_with_avgpool) 86 | self.ac = AC(ac_mode) 87 | 88 | def forward(self, x): 89 | y = self.conv_norm(x) 90 | y = self.ac(y) 91 | return y 92 | 93 | ### NodeOpSingleScale 94 | class NodeOpSingleScale(nn.Module): 95 | def __init__(self, in_channels, out_channels, stride, 96 | groups=1, drop_rate=0., ac_mode=0, bn_ratio=0.25, 97 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 98 | replace_stride_with_avgpool=True): 99 | super(NodeOpSingleScale, self).__init__() 100 | if "BatchNorm2d" in norm_name: 101 | norm_name_base = "BatchNorm2d" 102 | elif "GroupNorm" in norm_name: 103 | norm_name_base = "GroupNorm" 104 | else: 105 | raise ValueError("Unknown norm layer") 106 | 
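# The op below is a ResNet-style bottleneck: 1x1 reduce -> 3x3 (stride and groups applied
# here) -> 1x1 expand, with a projection shortcut when the shape changes and an optional
# second input `res` feeding the residual path. A worked example with hypothetical numbers:
# out_channels=256, bn_ratio=0.25, groups=1 gives
# mid_channels = max(4, to_int(256 * 0.25 / 1) * 1) = 64.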
107 | mid_channels = max(4, to_int(out_channels * bn_ratio / groups) * groups) 108 | self.conv_norm_ac_1 = Conv_Norm_AC(in_channels, mid_channels, 1, 1, 0, 109 | ac_mode=ac_mode, 110 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode) 111 | self.conv_norm_ac_2 = Conv_Norm_AC(mid_channels, mid_channels, 3, stride, 1, 112 | groups=groups, ac_mode=ac_mode, 113 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 114 | replace_stride_with_avgpool=False) 115 | self.conv_norm_3 = Conv_Norm(mid_channels, out_channels, 1, 1, 0, 116 | drop_rate=drop_rate, 117 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode) 118 | 119 | self.shortcut = None 120 | if in_channels != out_channels or stride > 1: 121 | self.shortcut = Conv_Norm(in_channels, out_channels, 1, stride, 0, 122 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 123 | replace_stride_with_avgpool=replace_stride_with_avgpool) 124 | 125 | self.ac = AC(ac_mode) 126 | 127 | def forward(self, x, res=None): 128 | residual = x if res is None else res 129 | y = self.conv_norm_ac_1(x) 130 | y = self.conv_norm_ac_2(y) 131 | y = self.conv_norm_3(y) 132 | 133 | if self.shortcut is not None: 134 | residual = self.shortcut(residual) 135 | 136 | y += residual 137 | y = self.ac(y) 138 | return y 139 | 140 | ### TODO: write a unit test for NodeOpSingleScale in a standalone way 141 | 142 | 143 | -------------------------------------------------------------------------------- /configs/aognet_cifar100_1M.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | batch_size: 128 3 | num_epoch: 300 4 | dataset: 'cifar100' 5 | num_classes: 100 6 | use_cosine_lr: True 7 | cosine_lr_min: 0.0 8 | warmup_epochs: 5 9 | lr: 0.1 10 | lr_scale_factor: 256 11 | lr_milestones: [150, 225] 12 | momentum: 0.9 13 | wd: 0.0005 14 | nesterov: False 15 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU 16 | init_mode: 'kaiming' 17 | norm_name: 'MixtureBatchNorm2d' 18 | norm_groups: 0 19 | norm_k: [10, 10, 20] 20 | norm_attention_mode: 0 21 | norm_zero_gamma_init: False 22 | dataaug: 23 | labelsmoothing_rate: 0.0 24 | mixup_rate: 0.0 25 | stem: 26 | imagenet_head7x7: False 27 | aognet: 28 | filter_list: [32, 64, 128, 248] ### try to keep 1:2:2:2 ...
except for the final stage which can be adjusted for fitting the model size 29 | out_channels: [0, 0] 30 | blocks: [1, 1, 1, 1] 31 | dims: [4, 4, 4, 4] 32 | max_split: [2, 2, 2, 2] 33 | extra_node_hierarchy: [4, 4, 4, 4] # 0: none, 1: tnode topdown, 2: tnode bottomup layerwise, 3: tnode bottomup sequential, 4: or-node lateral connection, 5: tnode bottomup 34 | remove_symmetric_children_of_or_node: [1, 2, 1, 2] # if nonzero, aog structure is much simplified and bigger filters and more units can be used 35 | terminal_node_no_slice: [0, 0, 0, 0] 36 | stride: [1, 2, 2, 2] 37 | drop_rate: [0.0, 0.0, 0.1, 0.1] 38 | bottleneck_ratio: 0.25 39 | handle_dbl_cnt: True 40 | handle_tnode_dbl_cnt: False 41 | handle_dbl_cnt_in_param_init: False 42 | use_group_conv: False 43 | width_per_group: 0 44 | when_downsample: 1 45 | replace_stride_with_avgpool: True 46 | use_elem_max_for_ORNodes: False 47 | 48 | 49 | 50 | -------------------------------------------------------------------------------- /configs/aognet_imagenet_12M.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | batch_size: 256 3 | num_epoch: 120 4 | dataset: 'imagenet' 5 | num_classes: 1000 6 | crop_size: 224 7 | crop_interpolation: 2 ### 2: BILINEAR, 3: BICUBIC 8 | use_cosine_lr: True ### 9 | cosine_lr_min: 0.0 10 | warmup_epochs: 5 11 | lr: 0.1 12 | lr_scale_factor: 256 13 | lr_milestones: [30, 60, 90, 100] 14 | momentum: 0.9 15 | wd: 0.0001 16 | nesterov: False 17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU 18 | init_mode: 'kaiming' 19 | norm_name: 'BatchNorm2d' 20 | norm_groups: 0 21 | norm_k: [10, 10, 20, 20] ### per stage 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: False 24 | dataaug: 25 | imagenet_extra_aug: False ### ColorJitter and PCA 26 | labelsmoothing_rate: 0.0 27 | mixup_rate: 0.0 28 | stem: 29 | imagenet_head7x7: False 30 | replace_maxpool_with_res_bottleneck: False 31 | aognet: 32 | max_split: [2, 2, 2, 2] 33 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection 34 | remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if nonzero, aog structure is much simplified and bigger filters and more units can be used 35 | terminal_node_no_slice: [0, 0, 0, 0] 36 | stride: [1, 2, 2, 2] 37 | drop_rate: [0.0, 0.0, 0.1, 0.1] 38 | bottleneck_ratio: 0.25 39 | handle_dbl_cnt: True 40 | handle_tnode_dbl_cnt: False 41 | handle_dbl_cnt_in_param_init: False 42 | use_group_conv: False 43 | width_per_group: 0 44 | when_downsample: 1 45 | replace_stride_with_avgpool: True ### in shortcut 46 | use_elem_max_for_ORNodes: False 47 | filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size 48 | out_channels: [0, 0] 49 | blocks: [2, 2, 2, 1] 50 | dims: [2, 2, 4, 4] 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /configs/aognet_imagenet_40M.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | batch_size: 128 3 | num_epoch: 120 4 | dataset: 'imagenet' 5 | num_classes: 1000 6 | crop_size: 224 7 | crop_interpolation: 2 ### 2: BILINEAR, 3: BICUBIC 8 | use_cosine_lr: True ### 9 | cosine_lr_min: 0.0 10 | warmup_epochs: 5 11 | lr: 0.1 12 | lr_scale_factor: 256 13 | lr_milestones: [30, 60, 90, 100] 14 | momentum: 0.9 15 | wd: 0.0001 16 | nesterov: False 17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU 18 | init_mode: 'kaiming' 19 | norm_name: 'BatchNorm2d' 20 | norm_groups: 0 21 | norm_k: [10, 10, 20, 20] ### per stage 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: False 24 | dataaug: 25 | imagenet_extra_aug: False ### ColorJitter and PCA 26 | labelsmoothing_rate: 0.0 27 | mixup_rate: 0.0 28 | stem: 29 | imagenet_head7x7: False 30 | replace_maxpool_with_res_bottleneck: False 31 | aognet: 32 | max_split: [2, 2, 2, 2] 33 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection 34 | remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if nonzero, aog structure is much simplified and bigger filters and more units can be used 35 | terminal_node_no_slice: [0, 0, 0, 0] 36 | stride: [1, 2, 2, 2] 37 | drop_rate: [0.0, 0.0, 0.1, 0.1] 38 | bottleneck_ratio: 0.25 39 | handle_dbl_cnt: True 40 | handle_tnode_dbl_cnt: False 41 | handle_dbl_cnt_in_param_init: False 42 | use_group_conv: False 43 | width_per_group: 0 44 | when_downsample: 1 45 | replace_stride_with_avgpool: True ### in shortcut 46 | use_elem_max_for_ORNodes: False 47 | filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size 48 | out_channels: [0, 0] 49 | blocks: [2, 2, 3, 1] 50 | dims: [2, 2, 4, 4] 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /examples/kill_all_python.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | ps x | grep python | awk '{print $1}' | xargs kill -------------------------------------------------------------------------------- /examples/test_fp16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: test_fp16.sh arch_name pretrained_model_folder 4 | 5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | 7 | if [ "$#" -ne 2 ]; then 8 | echo "Usage: test_fp16.sh arch_name pretrained_model_folder" 9 | exit 10 | fi 11 | 12 | ARCH_NAME=$1 13 | PRETRAINED_MODEL_PATH=$2 14 | CONFIG_FILE=$PRETRAINED_MODEL_PATH/config.yaml 15 | PRETRAINED_MODEL_FILE=$PRETRAINED_MODEL_PATH/model.pth.tar 16 | 17 | ### Change accordingly 18 | GPUS=2,3,4,5,6,7 19 | NUM_GPUS=6 20 | NUM_WORKERS=8 21 | 22 | # ImageNet 23 | DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/ 24 | 25 | # test 26 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS \ 27 | $DIR/../tools/main_fp16.py -a $ARCH_NAME --cfg $CONFIG_FILE --workers $NUM_WORKERS \ 28 | --fp16 \ 29 | -p 100 --save-dir $PRETRAINED_MODEL_PATH --pretrained --evaluate $DATA \ 30 | 2>&1 | tee $PRETRAINED_MODEL_PATH/log_test.txt 31 | 32 | 33 | 34 | -------------------------------------------------------------------------------- /examples/train_fp16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: train_fp16.sh arch_name relative_config_filename name_tag 4 | 5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | 7 | if [ "$#" -ne 3 ]; then 8 | echo "Usage: train_fp16.sh arch_name relative_config_filename name_tag" 9 | exit 10 | fi 11 | 12 | ARCH_NAME=$1 13 | CONFIG_FILE=$DIR/../$2 14 | 15 | ### Change accordingly 16 | GPUS=0,1,2,3,4,5,6,7 17 | NUM_GPUS=8 18 | NUM_WORKERS=8 19 | 20 | CONFIG_FILENAME="$(cut -d'/' -f2 <<<$2)" 21 | CONFIG_BASE="${CONFIG_FILENAME%.*}" 22 | NAME_TAG=$3 23 | SAVE_DIR=$DIR/../results/$ARCH_NAME-$CONFIG_BASE-$NAME_TAG 24 | if [ -d $SAVE_DIR ]; then 25 | echo "$SAVE_DIR --- Already exists, try a different name tag or delete it first" 26 | exit 27 | else 28 | mkdir -p $SAVE_DIR 29 | fi 30 | 31 | # backup for reproducing results 32 | cp $CONFIG_FILE $SAVE_DIR/config.yaml 33 | cp -r $DIR/../backbones $SAVE_DIR 34 | cp $DIR/../tools/main_fp16.py $SAVE_DIR 35 | 36 | # ImageNet 37 | DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/ 38 | 39 | # train 40 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS \ 41 | $DIR/../tools/main_fp16.py -a $ARCH_NAME --cfg $CONFIG_FILE --workers $NUM_WORKERS \ 42 | --fp16 --static-loss-scale 128 \ 43 | -p 100 --save-dir $SAVE_DIR $DATA \ 44 | 2>&1 | tee $SAVE_DIR/log.txt 45 | -------------------------------------------------------------------------------- /images/teaser-imagenet-dissection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNets/cddba00e97d6f74d2e7bfce50fd7aea9630ff996/images/teaser-imagenet-dissection.png
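For reference, a typical end-to-end use of the example scripts above looks like the following (a sketch; the config choice and name tag are illustrative, and `aognet_s` is the default arch in tools/main_fp16.py):

```
$ ./examples/train_fp16.sh aognet_s configs/aognet_imagenet_12M.yaml run1
$ ./examples/test_fp16.sh aognet_s results/aognet_s-aognet_imagenet_12M-run1
```

train_fp16.sh writes results to results/<arch>-<config_basename>-<name_tag> and copies the config there as config.yaml, which is exactly the layout test_fp16.sh expects for its pretrained_model_folder argument.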
-------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNets/cddba00e97d6f74d2e7bfce50fd7aea9630ff996/images/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | thop 3 | yacs 4 | scipy 5 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNets/cddba00e97d6f74d2e7bfce50fd7aea9630ff996/tools/__init__.py -------------------------------------------------------------------------------- /tools/get_cifar.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function # force to use print as function print(args) 4 | from __future__ import unicode_literals 5 | 6 | import os 7 | import pickle 8 | import torchvision.datasets as datasets 9 | 10 | if __name__ == '__main__': 11 | dataset_folder = os.path.join(os.path.dirname(__file__), '../datasets') 12 | 13 | datasets.CIFAR10(dataset_folder, train=True, download=True) 14 | datasets.CIFAR10(dataset_folder, train=False, download=True) 15 | datasets.CIFAR100(dataset_folder, train=True, download=True) 16 | datasets.CIFAR100(dataset_folder, train=False, download=True) 17 | 18 | # get labels 19 | # file = os.path.join(os.path.dirname(__file__), '../datasets/cifar-100-python/meta') 20 | # with open(file, 'rb') as fo: 21 | # dict = pickle.load(fo, encoding='bytes') 22 | 23 | # print(dict) 24 | -------------------------------------------------------------------------------- /tools/main_fp16.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/NVIDIA/apex 2 | 3 | ### some tweaks 4 | # USE pillow-simd to speed up pytorch image loader 5 | # pip uninstall pillow 6 | # conda uninstall --force jpeg libtiff -y 7 | # conda install -c conda-forge libjpeg-turbo 8 | # CC="cc -mavx2" pip install --no-cache-dir -U --force-reinstall --no-binary :all: --compile pillow-simd 9 | 10 | # Install NCCL https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html 11 | 12 | import argparse 13 | import os 14 | import shutil 15 | import time 16 | import copy 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.parallel 21 | import torch.backends.cudnn as cudnn 22 | import torch.distributed as dist 23 | import torch.optim 24 | import torch.utils.data 25 | import torch.utils.data.distributed 26 | import torchvision.transforms as transforms 27 | import torchvision.datasets as datasets 28 | 29 | import numpy as np 30 | 31 | try: 32 | from apex.parallel import DistributedDataParallel as DDP 33 | from apex.fp16_utils import * 34 | except ImportError: 35 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 36 | 37 | 38 | try: 39 | from thop import profile as thop_profile # compute params and flops 40 | except ImportError: 41 | raise ImportError("Please install https://github.com/Lyken17/pytorch-OpCounter") 42 | import math 43 | import sys 44 | import re 45 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), 
'..')) 46 | from backbones.aognet.operator_basic import MixtureBatchNorm2d, MixtureGroupNorm 47 | from backbones.aognet.aognet_singlescale import aognet_singlescale as aognet_s 48 | from backbones.aognet.config import cfg 49 | from smoothing import LabelSmoothing 50 | 51 | parser = argparse.ArgumentParser(description='PyTorch Image Classification Training') 52 | parser.add_argument('data', metavar='DIR', 53 | help='path to dataset') 54 | parser.add_argument('--arch', '-a', metavar='ARCH', default='aognet_s', 55 | help='arch') 56 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 57 | help='number of data loading workers (default: 4)') 58 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 59 | help='number of total epochs to run') 60 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 61 | help='manual epoch number (useful on restarts)') 62 | parser.add_argument('-b', '--batch-size', default=256, type=int, 63 | metavar='N', help='mini-batch size per process (default: 256)') 64 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 65 | metavar='LR', help='Initial learning rate. \ 66 | Will be scaled by /256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. \ 67 | A warmup schedule will also be applied over the first 5 epochs.') 68 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 69 | help='momentum') 70 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 71 | metavar='W', help='weight decay (default: 1e-4)') 72 | parser.add_argument('--print-freq', '-p', default=10, type=int, 73 | metavar='N', help='print frequency (default: 10)') 74 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 75 | help='path to latest checkpoint (default: none)') 76 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 77 | help='evaluate model on validation set') 78 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 79 | help='use pre-trained model') 80 | 81 | parser.add_argument('--fp16', action='store_true', 82 | help='Run model fp16 mode.') 83 | parser.add_argument('--static-loss-scale', type=float, default=1, 84 | help='Static loss scale, positive power of 2 values can improve fp16 convergence.') 85 | parser.add_argument('--dynamic-loss-scale', action='store_true', 86 | help='Use dynamic loss scaling. 
If supplied, this argument supersedes ' +
87 |                     '--static-loss-scale.')
88 | parser.add_argument('--prof', dest='prof', action='store_true',
89 |                     help='Only run 10 iterations for profiling.')
90 | parser.add_argument('--deterministic', action='store_true')
91 | 
92 | parser.add_argument("--local_rank", default=0, type=int)
93 | parser.add_argument('--sync_bn', action='store_true',
94 |                     help='enabling apex sync BN.')
95 | 
96 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
97 | parser.add_argument('--save-dir', type=str, default='/tmp/models')
98 | parser.add_argument('--nesterov', type=str, default=None)
99 | parser.add_argument('--remove-norm-weight-decay', type=str, default=None)
100 | 
101 | cudnn.benchmark = True
102 | 
103 | def fast_collate(batch):
104 |     imgs = [img[0] for img in batch]
105 |     targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
106 |     w = imgs[0].size[0]
107 |     h = imgs[0].size[1]
108 |     tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
109 |     for i, img in enumerate(imgs):
110 |         nump_array = np.asarray(img, dtype=np.uint8)
111 |         if(nump_array.ndim < 3):
112 |             nump_array = np.expand_dims(nump_array, axis=-1)
113 |         nump_array = np.rollaxis(nump_array, 2)
114 | 
115 |         tensor[i] += torch.from_numpy(nump_array)
116 | 
117 |     return tensor, targets
118 | 
119 | best_prec1 = 0
120 | best_prec1_val = 0
121 | prec5_val = 0
122 | best_prec5_val = 0
123 | 
124 | args = parser.parse_args()
125 | 
126 | if args.local_rank == 0:
127 |     print("PyTorch VERSION: {}".format(torch.__version__))  # PyTorch version
128 |     print("CUDA VERSION: {}".format(torch.version.cuda))  # Corresponding CUDA version
129 |     print("CUDNN VERSION: {}".format(torch.backends.cudnn.version()))  # Corresponding cuDNN version
130 |     print("GPU TYPE: {}".format(torch.cuda.get_device_name(0)))  # GPU type
131 | 
132 | if args.deterministic:
133 |     cudnn.benchmark = False
134 |     cudnn.deterministic = True
135 |     torch.manual_seed(args.local_rank)
136 |     torch.set_printoptions(precision=10)
137 | 
138 | def main():
139 |     global best_prec1, args
140 | 
141 |     args.distributed = False
142 |     if 'WORLD_SIZE' in os.environ:
143 |         args.distributed = int(os.environ['WORLD_SIZE']) > 1
144 | 
145 |     args.gpu = 0
146 |     args.world_size = 1
147 | 
148 |     if args.distributed:
149 |         args.gpu = args.local_rank
150 |         torch.cuda.set_device(args.gpu)
151 |         torch.distributed.init_process_group(backend='nccl',
152 |                                              init_method='env://')
153 |         args.world_size = torch.distributed.get_world_size()
154 | 
155 |     if args.fp16:
156 |         assert torch.backends.cudnn.enabled, "fp16 requires cudnn backend to be enabled."
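    # A minimal sketch of what fp16 loss scaling amounts to (illustrative only; the real
    # implementation is apex.fp16_utils.FP16_Optimizer, which wraps the optimizer below):
    #
    #     scaled_loss = loss * S        # S = --static-loss-scale, e.g. 128 (a power of 2)
    #     scaled_loss.backward()        # fp16 gradients now carry the factor S
    #     for p in master_params:       # fp32 master copies of the weights
    #         p.grad.div_(S)            # unscale before optimizer.step()
    #
    # Scaling keeps small gradients from flushing to zero in fp16; --dynamic-loss-scale
    # instead grows S until an overflow is detected, then backs off.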
157 |     if args.static_loss_scale != 1.0:
158 |         if not args.fp16:
159 |             print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")
160 | 
161 |     # create model
162 |     if args.pretrained:
163 |         if args.local_rank == 0:
164 |             print("=> using pre-trained model '{}'".format(args.arch))
165 |         if args.arch.startswith('aognet'):
166 |             cfg.merge_from_file(os.path.join(args.save_dir, 'config.yaml'))
167 | 
168 |             model = aognet_s()  # NOTE: only the single-scale AOGNet is included in this repo; aognet_m is not imported
169 |             checkpoint = torch.load(os.path.join(args.save_dir, 'model_best.pth.tar'))
170 |             # model.load_state_dict(checkpoint['state_dict'])
171 |         elif args.arch.startswith('resnet'):  # requires a 'resnets' backbone module not shipped in this repo
172 |             model = resnets.__dict__[args.arch](pretrained=True)
173 |         elif args.arch.startswith('mobilenet'):  # requires a 'mobilenets' backbone module not shipped in this repo
174 |             model = mobilenets.__dict__[args.arch](pretrained=True)
175 |         else:
176 |             raise NotImplementedError("Unknown network arch.")
177 |     else:
178 |         if args.local_rank == 0:
179 |             print("=> creating {}".format(args.arch))
180 |         # update args
181 |         cfg.merge_from_file(args.cfg)
182 |         args.batch_size = cfg.batch_size
183 |         args.lr = cfg.lr
184 |         args.momentum = cfg.momentum
185 |         args.weight_decay = cfg.wd
186 |         args.nesterov = cfg.nesterov
187 |         args.epochs = cfg.num_epoch
188 |         if args.arch.startswith('aognet'):
189 |             model = aognet_s()  # see note above: aognet_m is not available in this release
190 |         elif args.arch.startswith('resnet'):
191 |             model = resnets.__dict__[args.arch](zero_init_residual=cfg.norm_zero_gamma_init, num_classes=cfg.num_classes,
192 |                                                 replace_stride_with_dilation=cfg.resnet.replace_stride_with_dilation,
193 |                                                 dataset=cfg.dataset, base_inplanes=cfg.resnet.base_inplanes,
194 |                                                 imagenet_head7x7=cfg.stem.imagenet_head7x7,
195 |                                                 stem_kernel_size=cfg.stem.stem_kernel_size, stem_stride=cfg.stem.stem_stride,
196 |                                                 norm_name=cfg.norm_name, norm_groups=cfg.norm_groups,
197 |                                                 norm_k=cfg.norm_k, norm_attention_mode=cfg.norm_attention_mode,
198 |                                                 norm_all_mix=cfg.norm_all_mix,
199 |                                                 extra_norm_ac=cfg.resnet.extra_norm_ac,
200 |                                                 replace_stride_with_avgpool=cfg.resnet.replace_stride_with_avgpool)
201 |         elif args.arch.startswith('MobileNetV3'):
202 |             model = mobilenetsv3.__dict__[args.arch](norm_name=cfg.norm_name,
203 |                                                      norm_groups=cfg.norm_groups,
204 |                                                      norm_k=cfg.norm_k,
205 |                                                      norm_attention_mode=cfg.norm_attention_mode,
206 |                                                      rm_se=cfg.mobilenet.rm_se,
207 |                                                      use_mn_in_se=cfg.mobilenet.use_mn_in_se)
208 |         elif args.arch.startswith('mobilenet'):
209 |             model = mobilenets.__dict__[args.arch](norm_name=cfg.norm_name,
210 |                                                    norm_groups=cfg.norm_groups,
211 |                                                    norm_k=cfg.norm_k,
212 |                                                    norm_attention_mode=cfg.norm_attention_mode)
213 |         elif args.arch.startswith('densenet'):
214 |             model = densenets.__dict__[args.arch](num_classes=cfg.num_classes,
215 |                                                   imagenet_head7x7=cfg.stem.imagenet_head7x7,
216 |                                                   norm_name=cfg.norm_name,
217 |                                                   norm_groups=cfg.norm_groups,
218 |                                                   norm_k=cfg.norm_k,
219 |                                                   norm_attention_mode=cfg.norm_attention_mode)
220 |         else:
221 |             raise NotImplementedError("Unknown network arch.")
222 | 
223 |     if args.local_rank == 0:
224 |         if cfg.dataset.startswith('cifar'):
225 |             H, W = 32, 32
226 |         elif cfg.dataset.startswith('imagenet'):
227 |             H, W = 224, 224
228 |         else:
229 |             raise NotImplementedError("Unknown dataset")
230 |         flops, params = thop_profile(copy.deepcopy(model), input_size=(1, 3, H, W))
231 |         print('=> FLOPs: {:.6f}G, Params: {:.6f}M'.format(flops/1e9, params/1e6))
232 |         print('=> Params (double-check): %.6fM' % (sum(p.numel() for p in model.parameters()) / 1e6))
233 | 
234 |     if args.sync_bn:
235 |         import apex
236 |         if args.local_rank
== 0: 237 | print("using apex synced BN") 238 | model = apex.parallel.convert_syncbn_model(model) 239 | 240 | model = model.cuda() 241 | if args.fp16: 242 | model = FP16Model(model) 243 | if args.distributed: 244 | # By default, apex.parallel.DistributedDataParallel overlaps communication with 245 | # computation in the backward pass. 246 | # model = DDP(model) 247 | # delay_allreduce delays all communication to the end of the backward pass. 248 | model = DDP(model, delay_allreduce=True) 249 | 250 | if args.pretrained: 251 | model.load_state_dict(checkpoint['state_dict']) 252 | 253 | # Scale learning rate based on global batch size 254 | args.lr = args.lr*float(args.batch_size*args.world_size)/cfg.lr_scale_factor #TODO: control the maximum? 255 | 256 | if args.remove_norm_weight_decay: 257 | if args.local_rank == 0: 258 | print("=> ! Weight decay NOT applied to FeatNorm parameters ") 259 | norm_params=set() #TODO: need to check this via experiments 260 | rest_params=set() 261 | for m in model.modules(): 262 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, MixtureBatchNorm2d, MixtureGroupNorm)): 263 | for param in m.parameters(False): 264 | norm_params.add(param) 265 | else: 266 | for param in m.parameters(False): 267 | rest_params.add(param) 268 | 269 | optimizer = torch.optim.SGD([{'params': list(norm_params), 'weight_decay' : 0.0}, 270 | {'params': list(rest_params)}], 271 | args.lr, 272 | momentum=args.momentum, 273 | weight_decay=args.weight_decay, 274 | nesterov=args.nesterov) 275 | else: 276 | if args.local_rank == 0: 277 | print("=> ! Weight decay applied to FeatNorm parameters ") 278 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 279 | momentum=args.momentum, 280 | weight_decay=args.weight_decay, 281 | nesterov=args.nesterov) 282 | 283 | if args.fp16: 284 | optimizer = FP16_Optimizer(optimizer, 285 | static_loss_scale=args.static_loss_scale, 286 | dynamic_loss_scale=args.dynamic_loss_scale) 287 | 288 | # define loss function (criterion) and optimizer 289 | criterion_train = nn.CrossEntropyLoss().cuda() if cfg.dataaug.labelsmoothing_rate == 0.0 \ 290 | else LabelSmoothing(cfg.dataaug.labelsmoothing_rate).cuda() 291 | criterion_val = nn.CrossEntropyLoss().cuda() 292 | 293 | # Optionally resume from a checkpoint 294 | if args.resume: 295 | # Use a local scope to avoid dangling references 296 | def resume(): 297 | if os.path.isfile(args.resume): 298 | if args.local_rank == 0: 299 | print("=> loading checkpoint '{}'".format(args.resume)) 300 | checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu)) 301 | args.start_epoch = checkpoint['epoch'] 302 | best_prec1 = checkpoint['best_prec1'] 303 | model.load_state_dict(checkpoint['state_dict']) 304 | optimizer.load_state_dict(checkpoint['optimizer']) 305 | if args.local_rank == 0: 306 | print("=> loaded checkpoint '{}' (epoch {})" 307 | .format(args.resume, checkpoint['epoch'])) 308 | else: 309 | if args.local_rank == 0: 310 | print("=> no checkpoint found at '{}'".format(args.resume)) 311 | resume() 312 | 313 | # Data loading code 314 | lr_milestones = None 315 | if cfg.dataset == "cifar10": 316 | train_transform = transforms.Compose([ 317 | transforms.RandomCrop(32, padding=4), 318 | transforms.RandomHorizontalFlip() 319 | ]) 320 | train_dataset = datasets.CIFAR10('./datasets', train=True, download=False, transform=train_transform) 321 | val_dataset = datasets.CIFAR10('./datasets', train=False, download=False) 322 | lr_milestones = cfg.lr_milestones 323 | elif cfg.dataset == "cifar100": 
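        # Standard CIFAR augmentation below: zero-pad the 32x32 image by 4 pixels, take a
        # random 32x32 crop, and flip horizontally with probability 0.5. ToTensor() and
        # normalization are deliberately omitted here: fast_collate keeps images as uint8
        # tensors and data_prefetcher normalizes them on the GPU.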
324 | train_transform = transforms.Compose([ 325 | transforms.RandomCrop(32, padding=4), 326 | transforms.RandomHorizontalFlip() 327 | ]) 328 | train_dataset = datasets.CIFAR100('./datasets', train=True, download=False, transform=train_transform) 329 | val_dataset = datasets.CIFAR100('./datasets', train=False, download=False) 330 | lr_milestones = cfg.lr_milestones 331 | elif cfg.dataset == "imagenet": 332 | traindir = os.path.join(args.data, 'train') 333 | valdir = os.path.join(args.data, 'val') 334 | 335 | crop_size = cfg.crop_size # 224 336 | val_size = cfg.crop_size + 32 # 256 337 | 338 | train_dataset = datasets.ImageFolder( 339 | traindir, 340 | transforms.Compose([ 341 | transforms.RandomResizedCrop(crop_size, interpolation=cfg.crop_interpolation), 342 | transforms.RandomHorizontalFlip(), 343 | # transforms.ToTensor(), Too slow 344 | # normalize, 345 | ])) 346 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ 347 | transforms.Resize(val_size, interpolation=cfg.crop_interpolation), 348 | transforms.CenterCrop(crop_size), 349 | ])) 350 | 351 | train_sampler = None 352 | val_sampler = None 353 | if args.distributed: 354 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 355 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 356 | 357 | train_loader = torch.utils.data.DataLoader( 358 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 359 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) 360 | 361 | val_loader = torch.utils.data.DataLoader( 362 | val_dataset, 363 | batch_size=args.batch_size, shuffle=False, 364 | num_workers=args.workers, pin_memory=True, 365 | sampler=val_sampler, 366 | collate_fn=fast_collate) 367 | 368 | if args.evaluate: 369 | validate(val_loader, model, criterion_val) 370 | return 371 | 372 | scheduler = CosineAnnealingLR(optimizer.optimizer if args.fp16 else optimizer, 373 | args.epochs, len(train_loader), 374 | eta_min=cfg.cosine_lr_min, warmup=cfg.warmup_epochs) if cfg.use_cosine_lr else None 375 | 376 | for epoch in range(args.start_epoch, args.epochs): 377 | if args.distributed: 378 | train_sampler.set_epoch(epoch) 379 | 380 | # train for one epoch 381 | train(train_loader, model, criterion_train, optimizer, epoch, scheduler, lr_milestones, cfg.warmup_epochs, 382 | cfg.dataaug.mixup_rate, cfg.dataaug.labelsmoothing_rate) 383 | if args.prof: 384 | break 385 | # evaluate on validation set 386 | prec1 = validate(val_loader, model, criterion_val) 387 | 388 | # remember best prec@1 and save checkpoint 389 | if args.local_rank == 0: 390 | is_best = prec1 > best_prec1 391 | best_prec1 = max(prec1, best_prec1) 392 | save_checkpoint({ 393 | 'epoch': epoch + 1, 394 | 'arch': args.arch, 395 | 'state_dict': model.state_dict(), 396 | 'best_prec1': best_prec1, 397 | 'optimizer' : optimizer.state_dict(), 398 | }, is_best, args.save_dir) 399 | 400 | class data_prefetcher(): 401 | def __init__(self, loader): 402 | self.loader = iter(loader) 403 | self.stream = torch.cuda.Stream() 404 | if cfg.dataset == 'cifar10': 405 | self.mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1) 406 | self.std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1) 407 | elif cfg.dataset == 'cifar100': 408 | self.mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1) 409 | self.std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 
255]).cuda().view(1,3,1,1)
410 |         elif cfg.dataset == 'imagenet':
411 |             self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
412 |             self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
413 |         else:
414 |             raise NotImplementedError
415 |         if args.fp16:
416 |             self.mean = self.mean.half()
417 |             self.std = self.std.half()
418 |         self.preload()
419 | 
420 |     def preload(self):
421 |         try:
422 |             self.next_input, self.next_target = next(self.loader)
423 |         except StopIteration:
424 |             self.next_input = None
425 |             self.next_target = None
426 |             return
427 |         with torch.cuda.stream(self.stream):
428 |             self.next_input = self.next_input.cuda(non_blocking=True)
429 |             self.next_target = self.next_target.cuda(non_blocking=True)
430 |             if args.fp16:
431 |                 self.next_input = self.next_input.half()
432 |             else:
433 |                 self.next_input = self.next_input.float()
434 |             self.next_input = self.next_input.sub_(self.mean).div_(self.std)
435 | 
436 |     def next(self):
437 |         torch.cuda.current_stream().wait_stream(self.stream)
438 |         input = self.next_input
439 |         target = self.next_target
440 |         self.preload()
441 |         return input, target
442 | 
443 | # from NVIDIA DL Examples; an unused, generator-based alternative to data_prefetcher
444 | def prefetched_loader(loader):
445 |     if cfg.dataset == 'cifar10':
446 |         mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1)
447 |         std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1)
448 |     elif cfg.dataset == 'cifar100':
449 |         mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1)
450 |         std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 255]).cuda().view(1,3,1,1)
451 |     elif cfg.dataset == 'imagenet':
452 |         mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
453 |         std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
454 |     else:
455 |         raise NotImplementedError
456 | 
457 |     stream = torch.cuda.Stream()
458 |     first = True
459 | 
460 |     for next_input, next_target in loader:
461 |         with torch.cuda.stream(stream):
462 |             next_input = next_input.cuda(non_blocking=True)
463 |             next_target = next_target.cuda(non_blocking=True)
464 |             next_input = next_input.float()
465 |             next_input = next_input.sub_(mean).div_(std)
466 | 
467 |         if not first:
468 |             yield input, target
469 |         else:
470 |             first = False
471 | 
472 |         torch.cuda.current_stream().wait_stream(stream)
473 |         input = next_input
474 |         target = next_target
475 | 
476 |     yield input, target
477 | 
478 | 
479 | def train(train_loader, model, criterion, optimizer, epoch, scheduler=None, lr_milestones=None, warmup_epoch=0,
480 |           mixup_rate=0.0, labelsmoothing_rate=0.0):
481 |     batch_time = AverageMeter()
482 |     data_time = AverageMeter()
483 |     losses = AverageMeter()
484 |     top1 = AverageMeter()
485 |     top5 = AverageMeter()
486 | 
487 |     # switch to train mode
488 |     model.train()
489 |     end = time.time()
490 | 
491 |     prefetcher = data_prefetcher(train_loader)
492 |     input, target = prefetcher.next()
493 |     i = -1
494 |     beta_distribution = torch.distributions.beta.Beta(mixup_rate, mixup_rate)
495 |     while input is not None:
496 |         i += 1
497 | 
498 |         if scheduler is None:
499 |             lr = adjust_learning_rate(optimizer, epoch, i, len(train_loader), lr_milestones, warmup_epoch)
500 |         else:
501 |             lr = scheduler.update(epoch, i)
502 | 
503 |         if args.prof:
504 |             if i > 10:
505 |                 break
506 |         # measure data loading time
507 |         data_time.update(time.time() - end)
508 | 
509 |         # Mixup input
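        # Mixup (Zhang et al., ICLR 2018) in brief: draw lambda ~ Beta(alpha, alpha) and
        # train on convex combinations of sample pairs,
        #     x_tilde = lambda * x_i + (1 - lambda) * x_j
        #     loss    = lambda * CE(f(x_tilde), y_i) + (1 - lambda) * CE(f(x_tilde), y_j)
        # which is what the block below implements, with alpha = cfg.dataaug.mixup_rate.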
510 | if mixup_rate > 0.0: 511 | lambda_ = beta_distribution.sample([]).item() 512 | index = torch.randperm(input.size(0)).cuda() 513 | input = lambda_ * input + (1 - lambda_) * input[index, :] 514 | 515 | # compute output 516 | if args.prof: torch.cuda.nvtx.range_push("forward") 517 | output = model(input) 518 | if args.prof: torch.cuda.nvtx.range_pop() 519 | 520 | # Mixup loss 521 | if mixup_rate > 0.0: 522 | # Mixup loss 523 | loss = (lambda_ * criterion(output, target) 524 | + (1 - lambda_) * criterion(output, target[index])) 525 | 526 | # Mixup target 527 | if labelsmoothing_rate > 0.0: 528 | N = output.size(0) 529 | C = output.size(1) 530 | off_prob = labelsmoothing_rate / C 531 | target_1 = torch.full(size=(N, C), fill_value=off_prob ).cuda() 532 | target_2 = torch.full(size=(N, C), fill_value=off_prob ).cuda() 533 | target_1.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1.0-labelsmoothing_rate+off_prob) 534 | target_2.scatter_(dim=1, index=torch.unsqueeze(target[index], dim=1), value=1.0-labelsmoothing_rate+off_prob) 535 | target = lambda_ * target_1 + (1 - lambda_) * target_2 536 | else: 537 | target = lambda_ * target + (1 - lambda_) * target[index] 538 | else: 539 | loss = criterion(output, target) 540 | 541 | # compute gradient and do SGD step 542 | optimizer.zero_grad() 543 | 544 | if args.prof: torch.cuda.nvtx.range_push("backward") 545 | if args.fp16: 546 | optimizer.backward(loss) 547 | else: 548 | loss.backward() 549 | if args.prof: torch.cuda.nvtx.range_pop() 550 | 551 | # debug 552 | # if args.local_rank == 0: 553 | # for name_, param in model.named_parameters(): 554 | # print(name_, param.data.double().sum().item(), param.grad.data.double().sum().item()) 555 | 556 | if args.prof: torch.cuda.nvtx.range_push("step") 557 | optimizer.step() 558 | if args.prof: torch.cuda.nvtx.range_pop() 559 | 560 | # Measure accuracy 561 | if mixup_rate > 0.0: 562 | prec1 = rmse(output.data, target) 563 | prec5 = prec1 564 | else: 565 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 566 | 567 | # Average loss and accuracy across processes for logging 568 | if args.distributed: 569 | reduced_loss = reduce_tensor(loss.data) 570 | prec1 = reduce_tensor(prec1) 571 | prec5 = reduce_tensor(prec5) 572 | else: 573 | reduced_loss = loss.data 574 | 575 | # to_python_float incurs a host<->device sync 576 | losses.update(to_python_float(reduced_loss), input.size(0)) 577 | top1.update(to_python_float(prec1), input.size(0)) 578 | top5.update(to_python_float(prec5), input.size(0)) 579 | 580 | # torch.cuda.synchronize() # no this in torchvision ex. and cause nan loss problems in deep models with fp16 581 | 582 | batch_time.update(time.time() - end) 583 | end = time.time() 584 | input, target = prefetcher.next() 585 | 586 | if i%args.print_freq == 0 and args.local_rank == 0: 587 | # Every print_freq iterations, check the loss, accuracy, and speed. 
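            # Speed is global throughput in images/sec: (world_size * batch_size) / batch_time.
            # For example, 8 processes x 32 images each / 0.25 s per iteration = 1024 images/sec.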
588 | print('Epoch: [{0}][{1}/{2}]\t' 589 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 590 | 'Speed {3:.3f} ({4:.3f})\t' 591 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' 592 | 'Loss {loss.val:.10f} ({loss.avg:.4f})\t' 593 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 594 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t' 595 | 'lr {lr:.6f}\t'.format( 596 | epoch, i, len(train_loader), 597 | args.world_size*args.batch_size/batch_time.val, 598 | args.world_size*args.batch_size/batch_time.avg, 599 | batch_time=batch_time, 600 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr[0])) 601 | 602 | def validate(val_loader, model, criterion): 603 | global best_prec1_val, prec5_val, best_prec5_val 604 | batch_time = AverageMeter() 605 | losses = AverageMeter() 606 | top1 = AverageMeter() 607 | top5 = AverageMeter() 608 | 609 | # switch to evaluate mode 610 | model.eval() 611 | 612 | end = time.time() 613 | 614 | prefetcher = data_prefetcher(val_loader) 615 | input, target = prefetcher.next() 616 | i = -1 617 | while input is not None: 618 | i += 1 619 | 620 | # compute output 621 | with torch.no_grad(): 622 | output = model(input) 623 | loss = criterion(output, target) 624 | 625 | # measure accuracy and record loss 626 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) 627 | 628 | if args.distributed: 629 | reduced_loss = reduce_tensor(loss.data) 630 | prec1 = reduce_tensor(prec1) 631 | prec5 = reduce_tensor(prec5) 632 | else: 633 | reduced_loss = loss.data 634 | 635 | losses.update(to_python_float(reduced_loss), input.size(0)) 636 | top1.update(to_python_float(prec1), input.size(0)) 637 | top5.update(to_python_float(prec5), input.size(0)) 638 | 639 | # measure elapsed time 640 | batch_time.update(time.time() - end) 641 | end = time.time() 642 | 643 | # TODO: Change timings to mirror train(). 
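        # As in train(), reduce_tensor (defined near the end of this file) averages each
        # metric across processes: dist.all_reduce with SUM followed by division by
        # world_size, so every rank logs the same global loss and accuracy.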
644 | if args.local_rank == 0 and i > 0 and i % args.print_freq == 0: 645 | print('Test: [{0}/{1}]\t' 646 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' 647 | 'Speed {2:.3f} ({3:.3f})\t' 648 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 649 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 650 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( 651 | i, len(val_loader), 652 | args.world_size * args.batch_size / batch_time.val, 653 | args.world_size * args.batch_size / batch_time.avg, 654 | batch_time=batch_time, loss=losses, 655 | top1=top1, top5=top5)) 656 | 657 | input, target = prefetcher.next() 658 | 659 | if args.local_rank == 0: 660 | if top1.avg >= best_prec1_val: 661 | best_prec1_val = top1.avg 662 | prec5_val = top5.avg 663 | best_prec5_val = max(best_prec5_val, top5.avg) 664 | print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\t Best_Prec@1 {best:.3f}\t Prec@5 {prec5_val:.3f}\t Best_Prec@5 {bestprec5_val:.3f}' 665 | .format(top1=top1, top5=top5, best=best_prec1_val, prec5_val=prec5_val, bestprec5_val=best_prec5_val)) 666 | 667 | return top1.avg 668 | 669 | 670 | def save_checkpoint(state, is_best, save_dir='./'): 671 | filename = os.path.join(save_dir, 'checkpoint.pth.tar') 672 | best_file = os.path.join(save_dir, 'model_best.pth.tar') 673 | torch.save(state, filename) 674 | if is_best: 675 | shutil.copyfile(filename, best_file) 676 | 677 | class AverageMeter(object): 678 | """Computes and stores the average and current value""" 679 | def __init__(self): 680 | self.reset() 681 | 682 | def reset(self): 683 | self.val = 0 684 | self.avg = 0 685 | self.sum = 0 686 | self.count = 0 687 | 688 | def update(self, val, n=1): 689 | self.val = val 690 | self.sum += val * n 691 | self.count += n 692 | self.avg = self.sum / self.count 693 | 694 | 695 | def adjust_learning_rate(optimizer, epoch, step, len_epoch, lr_milestones=None, warmup_epoch=0): 696 | """LR schedule that should yield 76% converged accuracy with batch size 256""" 697 | # if not isinstance(optimizer, torch.optim.Optimizer): 698 | # raise TypeError('{} is not an Optimizer'.format( 699 | # type(optimizer).__name__)) 700 | if lr_milestones is None: 701 | factor = epoch // 30 702 | 703 | if epoch >= 80: 704 | factor = factor + 1 705 | 706 | lr = args.lr*(0.1**factor) 707 | 708 | """Warmup""" 709 | if epoch < 5: 710 | lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch) 711 | 712 | else: 713 | factor = 0 714 | for m in lr_milestones: 715 | if epoch >= m: 716 | factor += 1 717 | 718 | lr = args.lr*(0.1**factor) 719 | 720 | """Warmup""" 721 | if epoch < warmup_epoch: 722 | lr = lr*float(1 + step + epoch*len_epoch)/(warmup_epoch*len_epoch) 723 | 724 | 725 | # if(args.local_rank == 0): 726 | # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr)) 727 | 728 | for param_group in optimizer.param_groups: 729 | param_group['lr'] = lr 730 | 731 | return [lr] 732 | 733 | 734 | class CosineAnnealingLR(object): 735 | def __init__(self, optimizer, T_max, N_batch, eta_min=0, last_epoch=-1, warmup=0): 736 | if not isinstance(optimizer, torch.optim.Optimizer): 737 | raise TypeError('{} is not an Optimizer'.format( 738 | type(optimizer).__name__)) 739 | self.optimizer = optimizer 740 | self.T_max = T_max 741 | self.N_batch = N_batch 742 | self.eta_min = eta_min 743 | self.warmup = warmup 744 | 745 | if last_epoch == -1: 746 | for group in optimizer.param_groups: 747 | group.setdefault('initial_lr', group['lr']) 748 | else: 749 | for i, group in enumerate(optimizer.param_groups): 750 | if 'initial_lr' not in group: 751 | raise 
KeyError("param 'initial_lr' is not specified " 752 | "in param_groups[{}] when resuming an optimizer".format(i)) 753 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 754 | self.update(last_epoch+1) 755 | self.last_epoch = last_epoch 756 | self.iter = 0 757 | 758 | def state_dict(self): 759 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 760 | 761 | def load_state_dict(self, state_dict): 762 | self.__dict__.update(state_dict) 763 | 764 | def get_lr(self): 765 | if self.last_epoch < self.warmup: 766 | lrs = [base_lr * (self.last_epoch + self.iter / self.N_batch) / self.warmup for base_lr in self.base_lrs] 767 | else: 768 | lrs = [self.eta_min + (base_lr - self.eta_min) * 769 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup + self.iter / self.N_batch) / (self.T_max - self.warmup))) / 2 770 | for base_lr in self.base_lrs] 771 | return lrs 772 | 773 | def update(self, epoch, batch=0): 774 | self.last_epoch = epoch 775 | self.iter = batch + 1 776 | lrs = self.get_lr() 777 | for param_group, lr in zip(self.optimizer.param_groups, lrs): 778 | param_group['lr'] = lr 779 | 780 | return lrs 781 | 782 | 783 | def accuracy(output, target, topk=(1,)): 784 | """Computes the precision@k for the specified values of k""" 785 | maxk = max(topk) 786 | batch_size = target.size(0) 787 | 788 | _, pred = output.topk(maxk, 1, True, True) 789 | pred = pred.t() 790 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 791 | 792 | res = [] 793 | for k in topk: 794 | correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) 795 | res.append(correct_k.mul_(100.0 / batch_size)) 796 | return res 797 | 798 | def rmse(yhat,y): 799 | if args.fp16: 800 | res = torch.sqrt(torch.mean((yhat.float()-y.float())**2)) 801 | else: 802 | res = torch.sqrt(torch.mean((yhat-y)**2)) 803 | return res 804 | 805 | def reduce_tensor(tensor): 806 | rt = tensor.clone() 807 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 808 | rt /= args.world_size 809 | return rt 810 | 811 | if __name__ == '__main__': 812 | # to suppress annoying warnings 813 | import warnings 814 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning) 815 | 816 | main() 817 | -------------------------------------------------------------------------------- /tools/smoothing.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5 2 | # commit a1aff31 3 | # Date: 05/01/2019 4 | # Note: check the updates in NVIDIA DeepLearningExamples regulary 5 | import torch 6 | import torch.nn as nn 7 | 8 | class LabelSmoothing(nn.Module): 9 | """ 10 | NLL loss with label smoothing. 11 | """ 12 | def __init__(self, smoothing=0.0): 13 | """ 14 | Constructor for the LabelSmoothing module. 15 | 16 | :param smoothing: label smoothing factor 17 | """ 18 | super(LabelSmoothing, self).__init__() 19 | self.confidence = 1.0 - smoothing 20 | self.smoothing = smoothing 21 | 22 | def forward(self, x, target): 23 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 24 | 25 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 26 | nll_loss = nll_loss.squeeze(1) 27 | smooth_loss = -logprobs.mean(dim=-1) 28 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 29 | return loss.mean() 30 | 31 | --------------------------------------------------------------------------------