├── .gitignore ├── LICENSE ├── README.md ├── configs ├── aognet-12m-an-imagenet-200e-lbsmth-mixup.yaml ├── aognet-12m-an-imagenet.yaml ├── aognet-12m-imagenet-200e-lbsmth-mixup.yaml ├── aognet-12m-imagenet.yaml ├── aognet-40m-an-imagenet-200e-lbsmth-mixup.yaml ├── aognet-40m-an-imagenet.yaml ├── aognet-40m-imagenet-200e-lbsmth-mixup.yaml ├── aognet-40m-imagenet.yaml ├── mobilenetv2-an-imagenet.yaml ├── resnet-101-an-imagenet-200e-lbsmth-mixup.yaml ├── resnet-101-an-imagenet.yaml ├── resnet-101-imagenet-200e-lbsmth-mixup.yaml ├── resnet-101-imagenet.yaml ├── resnet-50-an-imagenet-200e-lbsmth-mixup.yaml ├── resnet-50-an-imagenet.yaml ├── resnet-50-imagenet-200e-lbsmth-mixup.yaml └── resnet-50-imagenet.yaml ├── images ├── an-comparison.png ├── teaser-imagenet-dissection.png └── teaser.png ├── models ├── __init__.py ├── aognet │ ├── AOG.py │ ├── __init__.py │ ├── aognet.py │ ├── operator_basic.py │ └── operator_singlescale.py ├── config.py ├── mobilenet.py └── resnet.py ├── requirements.txt ├── scripts ├── test_fp16.sh └── train_fp16.sh └── tools ├── __init__.py ├── main_fp16.py └── smoothing.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | *.egg-info/ 23 | .installed.cfg 24 | *.egg 25 | 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 29 | *.manifest 30 | *.spec 31 | 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | 36 | # Unit test / coverage reports 37 | htmlcov/ 38 | .tox/ 39 | .coverage 40 | .coverage.* 41 | .cache 42 | .idea 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | 48 | # Translations 49 | *.mo 50 | *.pot 51 | 52 | # Django stuff: 53 | #*.log 54 | temp*.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # IPython Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # dotenv 80 | .env 81 | 82 | # virtualenv 83 | venv/ 84 | ENV/ 85 | 86 | # Spyder project settings 87 | .spyderproject 88 | 89 | # Rope project settings 90 | .ropeproject 91 | 92 | # vscode 93 | .vscode 94 | 95 | # this project 96 | data/ 97 | datasets/ 98 | results/ 99 | nohup.out 100 | *.tar 101 | *.log 102 | .DS_Store 103 | pretrained_models/ 104 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. 
Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 |
2 | # AOGNet-v2
3 |
4 | **Please check our refactored code at [iVMCL-Release](https://github.com/iVMCL/iVMCL-Release).**
5 |
6 | This project provides source code for [AOGNets: Compositional Grammatical Architectures for Deep Learning
7 | ](https://arxiv.org/abs/1711.05847) (CVPR 2019) and [Attentive Normalization](https://arxiv.org/abs/1908.01259).
8 |
9 | <img src="images/teaser.png" alt="drawing"/>
10 |
11 | ## Installation
12 |
13 | 1. Create a conda environment and install all dependencies.
14 | ```
15 | conda create -n aognet-v2 python=3.7
16 | conda activate aognet-v2
17 | git clone https://github.com/iVMCL/AOGNet-v2
18 | cd AOGNet-v2
19 | pip install -r requirements.txt
20 | ```
21 | 2. Install PyTorch 1.0+ following [https://pytorch.org/](https://pytorch.org/)
22 |
23 | 3. Install apex following [https://github.com/NVIDIA/apex](https://github.com/NVIDIA/apex)
24 |
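After step 3, a quick sanity check (a minimal sketch, not part of the repository) can confirm that PyTorch, CUDA, and apex are all importable before running the `*_fp16` scripts, which presumably rely on `apex.amp` for mixed precision:

```
# Environment sanity check; assumes installation steps 1-3 above succeeded.
import torch
from apex import amp  # noqa: F401 -- this import fails if apex is missing

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
```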
25 |
26 | ## ImageNet dataset preparation
27 |
28 | - Download the [ImageNet dataset](http://image-net.org/download) to YOUR_IMAGENET_PATH and move the validation images into labeled subfolders
29 |   - The [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) may be helpful.
30 |
31 | - Create a datasets subfolder under your cloned AOGNet-v2 and add a symbolic link to the ImageNet dataset
32 |
33 | ```
34 | $ cd AOGNet-v2
35 | $ mkdir datasets
36 | $ ln -s PATH_TO_YOUR_IMAGENET ./datasets/
37 | ```
38 |
39 | ## Pretrained models for evaluation
40 |
41 | ```
42 | $ cd AOGNet-v2
43 | $ mkdir pretrained_models
44 | ```
45 | Download the pretrained models from the links provided in the tables below, unzip them, and place them in the pretrained_models directory. Then run:
46 | ```
47 | $ ./scripts/test_fp16.sh
48 | ```
49 |
50 | - Models trained with the normal training setup: 120 epochs, cosine lr scheduling, SGD, ... (MobileNet-v2 is trained for 150 epochs). All models are trained with 8 Nvidia V100 GPUs.
51 |
52 | | Method | #Params | FLOPS | top-1 error (%) | top-5 error (%) | Link |
53 | |---|---|---|---|---|---|
54 | | ResNet-50-BN | 25.56M | 4.09G | 23.01 | 6.68 | [Google Drive](https://drive.google.com/open?id=1NqGbN5XjEWufq-V3WXz0eXJ2HKDL2wGV) |
55 | | ResNet-50-GN* | 25.56M | 4.09G | 23.52 | 6.85 | - |
56 | | ResNet-50-SN* | 25.56M | - | 22.43 | 6.35 | - |
57 | | ResNet-50-SE* | 28.09M | - | 22.37 | 6.36 | - |
58 | | ResNet-50-AN (w/ BN) | 25.76M | 4.09G | 21.59 | 5.88 | [Google Drive](https://drive.google.com/open?id=17MNrFTBJS__1cW6PNcLS1lJ7590QyA6m) |
59 | | ResNet-101-BN | 44.57M | 8.12G | 21.33 | 5.85 | [Google Drive](https://drive.google.com/open?id=1txeVNkDDKd45dIrJAZyW6Si1UgGvqtXu) |
60 | | ResNet-101-AN (w/ BN) | 45.00M | 8.12G | 20.61 | 5.41 | [Google Drive](https://drive.google.com/open?id=1Cq-D2Gm2QeZW2WqfCrSeqou4AhYqH9yN) |
61 | | MobileNet-v2-BN | 3.50M | 0.34G | 28.69 | 9.33 | [Google Drive]() |
62 | | MobileNet-v2-AN (w/ BN) | 3.56M | 0.34G | 26.67 | 8.56 | [Google Drive](https://drive.google.com/open?id=1pD-fHdzyVW5ufC8FB4R7yPtjN4Z2St4t) |
63 | | AOGNet-12M-BN | 12.26M | 2.19G | 22.22 | 6.06 | [Google Drive](https://drive.google.com/open?id=1PsA2EvEw7wsCGhzp65Lfq3w1pvkU65o0) |
64 | | AOGNet-12M-AN (w/ BN) | 12.37M | 2.19G | 21.28 | 5.76 | [Google Drive](https://drive.google.com/open?id=1t4Oa0vZuakNfR-PWhMIiOWsVTCZ8Z-G-) |
65 | | AOGNet-40M-BN | 40.15M | 7.51G | 19.84 | 4.94 | [Google Drive](https://drive.google.com/open?id=1u-ToLniZVEkBlbSGQL49A3V72cXS5Dwd) |
66 | | AOGNet-40M-AN (w/ BN) | 40.39M | 7.51G | 19.33 | 4.72 | [Google Drive](https://drive.google.com/open?id=1LWvdhjxQ259_Gq-YMNT70fAraDGsIWKo) |
67 |
68 | \* From original paper
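The "AN" rows use Attentive Normalization, selected in the configs as norm_name: 'MixtureBatchNorm2d'. As a rough illustration of the mechanism described in the Attentive Normalization paper (BN-style standardization followed by an attention-weighted mixture of K candidate affine transforms), here is a minimal, hypothetical sketch; the repository's actual implementation may differ in detail:

```
import torch
import torch.nn as nn

class MixtureBatchNorm2dSketch(nn.BatchNorm2d):
    """Hypothetical sketch of Attentive Normalization (Li et al., 2019)."""
    def __init__(self, num_features, k=10):              # k plays the role of norm_k in the configs
        super().__init__(num_features, affine=False)     # BN standardization only, no fixed affine
        self.gamma = nn.Parameter(torch.ones(k, num_features))   # K candidate scales
        self.beta = nn.Parameter(torch.zeros(k, num_features))   # K candidate shifts
        self.attn = nn.Sequential(                       # tiny per-instance gating network
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
            nn.Linear(num_features, k), nn.Sigmoid())

    def forward(self, x):
        xn = super().forward(x)        # (N, C, H, W), standardized
        lam = self.attn(x)             # (N, K) instance-dependent attention weights
        g = lam @ self.gamma           # (N, C) instance-specific gamma
        b = lam @ self.beta            # (N, C) instance-specific beta
        return xn * g[:, :, None, None] + b[:, :, None, None]
```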
69 |
70 | - Models trained with the advanced training setup: 200 epochs, 0.1 label smoothing, 0.2 mixup
71 |
72 | | Method | #Params | FLOPS | top-1 error (%) | top-5 error (%) | Link |
73 | |---|---|---|---|---|---|
74 | | ResNet-50-BN | 25.56M | 4.09G | 21.08 | 5.56 | [Google Drive](https://drive.google.com/open?id=1SoE0U9W5ghpEhmCCYqClq9Io1WkZts_C) |
75 | | ResNet-50-AN (w/ BN) | 25.76M | 4.09G | 19.92 | 5.04 | [Google Drive](https://drive.google.com/open?id=1qWSN-95Blq-MBCFCzh1DpmmB7kk9VHbU) |
76 | | ResNet-101-BN | 44.57M | 8.12G | 19.71 | 4.89 | [Google Drive](https://drive.google.com/open?id=1oqPQG7Oc0REvrLABAjsrmOpP4GdGgXEG) |
77 | | ResNet-101-AN (w/ BN) | 45.00M | 8.12G | 18.85 | 4.63 | [Google Drive](https://drive.google.com/open?id=1habUSoSotE8-fEq60IhnR2CjYM14wssv) |
78 | | AOGNet-12M-BN | 12.26M | 2.19G | 21.63 | 5.60 | [Google Drive](https://drive.google.com/open?id=14gZ18L4mqHWI79P3d-gY01jqqGVsA9hB) |
79 | | AOGNet-12M-AN (w/ BN) | 12.37M | 2.19G | 20.57 | 5.38 | [Google Drive](https://drive.google.com/open?id=1GY6TbanFrXSBcFD-kjmBixf4wG2Vp1sJ) |
80 | | AOGNet-40M-BN | 40.15M | 7.51G | 18.70 | 4.47 | [Google Drive](https://drive.google.com/open?id=1MugHc_9rn5wR7d1Hyx4oaGjc64l8v_aG) |
81 | | AOGNet-40M-AN (w/ BN) | 40.39M | 7.51G | 18.13 | 4.26 | [Google Drive](https://drive.google.com/open?id=1w5W12mDgni0DPCANuvPoIbmKKPND-Kdh) |
82 |
83 | - Remarks: The accuracy of the pretrained models may differ slightly from the accuracy reported during training (and used in our papers). The reason is still unclear.
84 |
85 |
86 | ## Perform training on ImageNet dataset
87 |
88 | ```
89 | $ cd AOGNet-v2
90 | $ ./scripts/train_fp16.sh
91 | ```
92 |
93 | E.g., to train AOGNet-12M with AN:
94 | ```
95 | $ ./scripts/train_fp16.sh configs/aognet-12m-an-imagenet.yaml
96 | ```
97 |
98 | See more configuration files in [configs](https://github.com/iVMCL/AOGNet-v2/tree/master/configs). Adjust the GPU settings in scripts/train_fp16.sh as needed.
99 |
100 |
101 |
102 | ## Object Detection and Instance Segmentation on COCO
103 | We performed the object detection and instance segmentation tasks on COCO with our models pretrained on ImageNet. Our implementation is based on the [mmdetection](https://github.com/open-mmlab/mmdetection) framework. The code is released at [https://github.com/iVMCL/AttentiveNorm_Detection/tree/master/configs/attn_norm](https://github.com/iVMCL/AttentiveNorm_Detection/tree/master/configs/attn_norm).
104 |
105 | ## Citations
106 | Please consider citing the AOGNets or Attentive Normalization papers in your publications if they help your research.
107 | ```
108 | @inproceedings{li2019aognets,
109 |   title={AOGNets: Compositional Grammatical Architectures for Deep Learning},
110 |   author={Li, Xilai and Song, Xi and Wu, Tianfu},
111 |   booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
112 |   pages={6220--6230},
113 |   year={2019}
114 | }
115 |
116 | @article{li2019attentive,
117 |   title={Attentive Normalization},
118 |   author={Li, Xilai and Sun, Wei and Wu, Tianfu},
119 |   journal={arXiv preprint arXiv:1908.01259},
120 |   year={2019}
121 | }
122 | ```
123 |
124 | ## Contact
125 |
126 | Please feel free to report issues and any related problems to Xilai Li (xli47 at ncsu dot edu) and Tianfu Wu (twu19 at ncsu dot edu).
127 |
128 | ## License
129 |
130 | AOGNets-related code is under the [RESEARCH ONLY LICENSE](./LICENSE).
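The configs that follow implement the two training setups described above; the "advanced" variants set labelsmoothing_rate: 0.1 and mixup_rate: 0.2. As a reference for how such a combined objective is commonly implemented, here is a minimal sketch; the repository's own logic lives in tools/ (e.g. tools/smoothing.py) and may differ in detail:

```
import numpy as np
import torch
import torch.nn.functional as F

def smoothed_ce(logits, target, eps=0.1):
    # Cross-entropy against a label-smoothed target distribution.
    logp = F.log_softmax(logits, dim=-1)
    nll = -logp.gather(1, target.unsqueeze(1)).squeeze(1)
    return ((1.0 - eps) * nll - eps * logp.mean(dim=-1)).mean()

def mixup_batch(x, y, alpha=0.2):
    # Pair each image with a shuffled partner; mix inputs and, later, losses.
    lam = float(np.random.beta(alpha, alpha))
    idx = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1.0 - lam) * x[idx], y, y[idx], lam

# Inside a training step (hypothetical usage):
#   mixed, y_a, y_b, lam = mixup_batch(images, labels, alpha=0.2)
#   logits = model(mixed)
#   loss = lam * smoothed_ce(logits, y_a) + (1 - lam) * smoothed_ce(logits, y_b)
```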
131 | -------------------------------------------------------------------------------- /configs/aognet-12m-an-imagenet-200e-lbsmth-mixup.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | batch_size: 256 3 | num_epoch: 200 4 | dataset: 'imagenet' 5 | num_classes: 1000 6 | crop_size: 224 7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 8 | use_cosine_lr: True ### 9 | cosine_lr_min: 0.0 10 | warmup_epochs: 5 11 | lr: 0.1 12 | lr_scale_factor: 256 13 | lr_milestones: [30, 60, 90, 100] 14 | momentum: 0.9 15 | wd: 0.0001 16 | nesterov: False 17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU 18 | init_mode: 'kaiming' 19 | norm_name: 'MixtureBatchNorm2d' 20 | #norm_name: 'BatchNorm2d' 21 | norm_groups: 0 22 | norm_k: [10, 10, 20, 20] ### per stage 23 | norm_attention_mode: 2 24 | norm_zero_gamma_init: False 25 | dataaug: 26 | imagenet_extra_aug: False ### ColorJitter and PCA 27 | labelsmoothing_rate: 0.1 28 | mixup_rate: 0.2 29 | stem: 30 | imagenet_head7x7: False 31 | replace_maxpool_with_res_bottleneck: False 32 | aognet: 33 | max_split: [2, 2, 2, 2] 34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection 35 | remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used 36 | terminal_node_no_slice: [0, 0, 0, 0] 37 | stride: [1, 2, 2, 2] 38 | drop_rate: [0.0, 0.0, 0.1, 0.1] 39 | bottleneck_ratio: 0.25 40 | handle_dbl_cnt: True 41 | handle_tnode_dbl_cnt: False 42 | handle_dbl_cnt_in_param_init: False 43 | use_group_conv: False 44 | width_per_group: 0 45 | when_downsample: 1 46 | replace_stride_with_avgpool: True ### in shortcut 47 | use_elem_max_for_ORNodes: False 48 | filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... 
except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 2, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
-------------------------------------------------------------------------------- /configs/aognet-12m-an-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 |   labelsmoothing_rate: 0.0
28 |   mixup_rate: 0.0
29 | stem:
30 |   imagenet_head7x7: False
31 |   replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 |   max_split: [2, 2, 2, 2]
34 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
36 |   terminal_node_no_slice: [0, 0, 0, 0]
37 |   stride: [1, 2, 2, 2]
38 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
39 |   bottleneck_ratio: 0.25
40 |   handle_dbl_cnt: True
41 |   handle_tnode_dbl_cnt: False
42 |   handle_dbl_cnt_in_param_init: False
43 |   use_group_conv: False
44 |   width_per_group: 0
45 |   when_downsample: 1
46 |   replace_stride_with_avgpool: True ### in shortcut
47 |   use_elem_max_for_ORNodes: False
48 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 2, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
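All of the AOGNet configs share the same schedule fields: use_cosine_lr: True with warmup_epochs: 5, lr: 0.1, and lr_scale_factor: 256 (lr_milestones presumably matters only for step decay when use_cosine_lr is False). A plausible reading, sketched below under the assumptions that the base lr is linearly scaled by the global batch size and annealed with a cosine schedule after a linear warmup; the actual logic lives in tools/main_fp16.py and may differ:

```
import math

def lr_at_epoch(epoch, lr=0.1, num_epoch=120, warmup_epochs=5,
                cosine_lr_min=0.0, batch_size=256, lr_scale_factor=256):
    """Hypothetical reading of the schedule fields; not the repository's code."""
    peak = lr * batch_size / lr_scale_factor           # linear scaling rule (assumed)
    if epoch < warmup_epochs:                          # linear warmup
        return peak * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) / float(num_epoch - warmup_epochs)
    return cosine_lr_min + 0.5 * (peak - cosine_lr_min) * (1.0 + math.cos(math.pi * t))
```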
-------------------------------------------------------------------------------- /configs/aognet-12m-imagenet-200e-lbsmth-mixup.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | #norm_name: 'MixtureBatchNorm2d'
20 | norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 |   labelsmoothing_rate: 0.1
28 |   mixup_rate: 0.2
29 | stem:
30 |   imagenet_head7x7: False
31 |   replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 |   max_split: [2, 2, 2, 2]
34 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
36 |   terminal_node_no_slice: [0, 0, 0, 0]
37 |   stride: [1, 2, 2, 2]
38 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
39 |   bottleneck_ratio: 0.25
40 |   handle_dbl_cnt: True
41 |   handle_tnode_dbl_cnt: False
42 |   handle_dbl_cnt_in_param_init: False
43 |   use_group_conv: False
44 |   width_per_group: 0
45 |   when_downsample: 1
46 |   replace_stride_with_avgpool: True ### in shortcut
47 |   use_elem_max_for_ORNodes: False
48 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 2, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
-------------------------------------------------------------------------------- /configs/aognet-12m-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 3 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 0
21 | norm_k: [10, 10, 20, 20] ### per stage
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: False
24 | dataaug:
25 |   imagenet_extra_aug: False ### ColorJitter and PCA
26 |   labelsmoothing_rate: 0.0
27 |   mixup_rate: 0.0
28 | stem:
29 |   imagenet_head7x7: False
30 |   replace_maxpool_with_res_bottleneck: False
31 | aognet:
32 |   max_split: [2, 2, 2, 2]
33 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
34 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
35 |   terminal_node_no_slice: [0, 0, 0, 0]
36 |   stride: [1, 2, 2, 2]
37 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
38 |   bottleneck_ratio: 0.25
39 |   handle_dbl_cnt: True
40 |   handle_tnode_dbl_cnt: False
41 |   handle_dbl_cnt_in_param_init: False
42 |   use_group_conv: False
43 |   width_per_group: 0
44 |   when_downsample: 1
45 |   replace_stride_with_avgpool: True ### in shortcut
46 |   use_elem_max_for_ORNodes: False
47 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
48 |   out_channels: [0, 0]
49 |   blocks: [2, 2, 2, 1]
50 |   dims: [2, 2, 4, 4]
51 |
52 |
53 |
54 |
-------------------------------------------------------------------------------- /configs/aognet-40m-an-imagenet-200e-lbsmth-mixup.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 |   labelsmoothing_rate: 0.1
28 |   mixup_rate: 0.2
29 | stem:
30 |   imagenet_head7x7: False
31 |   replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 |   max_split: [2, 2, 2, 2]
34 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
36 |   terminal_node_no_slice: [0, 0, 0, 0]
37 |   stride: [1, 2, 2, 2]
38 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
39 |   bottleneck_ratio: 0.25
40 |   handle_dbl_cnt: True
41 |   handle_tnode_dbl_cnt: False
42 |   handle_dbl_cnt_in_param_init: False
43 |   use_group_conv: False
44 |   width_per_group: 0
45 |   when_downsample: 1
46 |   replace_stride_with_avgpool: True ### in shortcut
47 |   use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 3, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
-------------------------------------------------------------------------------- /configs/aognet-40m-an-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 |   labelsmoothing_rate: 0.0
28 |   mixup_rate: 0.0
29 | stem:
30 |   imagenet_head7x7: False
31 |   replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 |   max_split: [2, 2, 2, 2]
34 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
36 |   terminal_node_no_slice: [0, 0, 0, 0]
37 |   stride: [1, 2, 2, 2]
38 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
39 |   bottleneck_ratio: 0.25
40 |   handle_dbl_cnt: True
41 |   handle_tnode_dbl_cnt: False
42 |   handle_dbl_cnt_in_param_init: False
43 |   use_group_conv: False
44 |   width_per_group: 0
45 |   when_downsample: 1
46 |   replace_stride_with_avgpool: True ### in shortcut
47 |   use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 3, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
-------------------------------------------------------------------------------- /configs/aognet-40m-imagenet-200e-lbsmth-mixup.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | #norm_name: 'MixtureBatchNorm2d'
20 | norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 |   labelsmoothing_rate: 0.1
28 |   mixup_rate: 0.2
29 | stem:
30 |   imagenet_head7x7: False
31 |   replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 |   max_split: [2, 2, 2, 2]
34 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
36 |   terminal_node_no_slice: [0, 0, 0, 0]
37 |   stride: [1, 2, 2, 2]
38 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
39 |   bottleneck_ratio: 0.25
40 |   handle_dbl_cnt: True
41 |   handle_tnode_dbl_cnt: False
42 |   handle_dbl_cnt_in_param_init: False
43 |   use_group_conv: False
44 |   width_per_group: 0
45 |   when_downsample: 1
46 |   replace_stride_with_avgpool: True ### in shortcut
47 |   use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
49 |   out_channels: [0, 0]
50 |   blocks: [2, 2, 3, 1]
51 |   dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
-------------------------------------------------------------------------------- /configs/aognet-40m-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: True
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 0
21 | norm_k: [10, 10, 20, 20] ### per stage
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: False
24 | dataaug:
25 |   imagenet_extra_aug: False ### ColorJitter and PCA
26 |   labelsmoothing_rate: 0.0
27 |   mixup_rate: 0.0
28 | stem:
29 |   imagenet_head7x7: False
30 |   replace_maxpool_with_res_bottleneck: False
31 | aognet:
32 |   max_split: [2, 2, 2, 2]
33 |   extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
34 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### if true, aog structure is much simplified and bigger filters and more units can be used
35 |   terminal_node_no_slice: [0, 0, 0, 0]
36 |   stride: [1, 2, 2, 2]
37 |   drop_rate: [0.0, 0.0, 0.1, 0.1]
38 |   bottleneck_ratio: 0.25
39 |   handle_dbl_cnt: True
40 |   handle_tnode_dbl_cnt: False
41 |   handle_dbl_cnt_in_param_init: False
42 |   use_group_conv: False
43 |   width_per_group: 0
44 |   when_downsample: 1
45 |   replace_stride_with_avgpool: False ### in shortcut
46 |   use_elem_max_for_ORNodes: False
47 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ...
except for the final stage which can be adjusted for fitting the model size
48 |   out_channels: [0, 0]
49 |   blocks: [2, 2, 3, 1]
50 |   dims: [2, 2, 4, 4]
51 |
52 |
53 |
54 |
-------------------------------------------------------------------------------- /configs/mobilenetv2-an-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | arch: 'mobilenet_v2'
3 | batch_size: 128
4 | num_epoch: 150
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.05
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.00004
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU
19 | #norm_name: 'BatchNorm2d'
20 | norm_name: 'MixtureBatchNorm2d'
21 | norm_groups: 32
22 | norm_k: [5, 5, 5, 5, 10, 10, 10]
23 | norm_attention_mode: 2
24 | dataaug:
25 |   imagenet_extra_aug: False ###ColorJitter and PCA
26 |   labelsmoothing_rate: 0.0
27 |   mixup_rate: 0.0
28 | mobilenet:
29 |   rm_se: False
30 |   use_mn_in_se: False
31 |
-------------------------------------------------------------------------------- /configs/resnet-101-an-imagenet-200e-lbsmth-mixup.yaml: --------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 200
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ###ColorJitter and PCA
27 |   labelsmoothing_rate: 0.1
28 |   mixup_rate: 0.2
29 | stem:
30 |   imagenet_head7x7: False
31 |   stem_kernel_size: 7
32 |   stem_stride: 2
33 |   replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 |   base_inplanes: 64
36 |   replace_stride_with_dilation: [False, False, False]
37 |   replace_stride_with_avgpool: True
38 |   extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
-------------------------------------------------------------------------------- /configs/resnet-101-an-imagenet.yaml: --------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 120
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ###ColorJitter and PCA
27 |   labelsmoothing_rate: 0.0
28 |   mixup_rate: 0.0
29 | stem:
30 |   imagenet_head7x7: True
31 |   stem_kernel_size: 7
32 |   stem_stride: 2
33 |   replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 |   base_inplanes: 64
36 |
replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: False 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-101-imagenet-200e-lbsmth-mixup.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet101' 3 | batch_size: 128 4 | num_epoch: 200 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'BatchNorm2d' 20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 27 | labelsmoothing_rate: 0.1 28 | mixup_rate: 0.2 29 | stem: 30 | imagenet_head7x7: False 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: True 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-101-imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet101' 3 | batch_size: 128 4 | num_epoch: 120 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'BatchNorm2d' 20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 27 | labelsmoothing_rate: 0.0 28 | mixup_rate: 0.0 29 | stem: 30 | imagenet_head7x7: True 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: False 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-50-an-imagenet-200e-lbsmth-mixup.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet50' 3 | batch_size: 128 4 | num_epoch: 200 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'MixtureBatchNorm2d' 20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 
27 | labelsmoothing_rate: 0.1 28 | mixup_rate: 0.2 29 | stem: 30 | imagenet_head7x7: False 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: True 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-50-an-imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet50' 3 | batch_size: 128 4 | num_epoch: 120 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'MixtureBatchNorm2d' 20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 27 | labelsmoothing_rate: 0.0 28 | mixup_rate: 0.0 29 | stem: 30 | imagenet_head7x7: True 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: False 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-50-imagenet-200e-lbsmth-mixup.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet50' 3 | batch_size: 128 4 | num_epoch: 200 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'BatchNorm2d' 20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 27 | labelsmoothing_rate: 0.1 28 | mixup_rate: 0.2 29 | stem: 30 | imagenet_head7x7: False 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: True 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /configs/resnet-50-imagenet.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | arch: 'resnet50' 3 | batch_size: 128 4 | num_epoch: 120 5 | dataset: 'imagenet' 6 | num_classes: 1000 7 | crop_size: 224 8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC 9 | use_cosine_lr: True 10 | cosine_lr_min: 0.0 11 | warmup_epochs: 5 12 | lr: 0.1 13 | lr_scale_factor: 256 14 | lr_milestones: [30, 60, 90, 100] 15 | momentum: 0.9 16 | wd: 0.0001 17 | nesterov: False 18 | activation_mode: 0 ### 1: leakyReLU, 2: , other: ReLU 19 | norm_name: 'BatchNorm2d' 
20 | norm_groups: 32 21 | norm_k: [10, 10, 20, 20] 22 | norm_attention_mode: 2 23 | norm_zero_gamma_init: True 24 | norm_all_mix: False 25 | dataaug: 26 | imagenet_extra_aug: False ###ColorJitter and PCA 27 | labelsmoothing_rate: 0.0 28 | mixup_rate: 0.0 29 | stem: 30 | imagenet_head7x7: True 31 | stem_kernel_size: 7 32 | stem_stride: 2 33 | replace_maxpool_with_res_bottleneck: False 34 | resnet: 35 | base_inplanes: 64 36 | replace_stride_with_dilation: [False, False, False] 37 | replace_stride_with_avgpool: False 38 | extra_norm_ac: False 39 | 40 | 41 | 42 | 43 | -------------------------------------------------------------------------------- /images/an-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/an-comparison.png -------------------------------------------------------------------------------- /images/teaser-imagenet-dissection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/teaser-imagenet-dissection.png -------------------------------------------------------------------------------- /images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/teaser.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/models/__init__.py -------------------------------------------------------------------------------- /models/aognet/AOG.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. 
The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | # The system is protected via patent (pending) 35 | # Written by Tianfu Wu and Xi Song 36 | # Contact: tianfu_wu@ncsu.edu, xsong.lhi@gmail.com 37 | 38 | # -*- coding: utf-8 -*- 39 | from __future__ import absolute_import 40 | from __future__ import division 41 | from __future__ import print_function # force to use print as function print(args) 42 | from __future__ import unicode_literals 43 | 44 | from math import ceil, floor 45 | from collections import deque 46 | import numpy as np 47 | import os 48 | import random 49 | import math 50 | import copy 51 | 52 | 53 | def get_aog(grid_ht, grid_wd, min_size=1, max_split=2, 54 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=1./4., 55 | not_use_intermediate_TerminalNode=False, 56 | use_root_TerminalNode=True, use_tnode_as_alpha_channel=0, 57 | use_tnode_topdown_connection=False, 58 | use_tnode_bottomup_connection=False, 59 | use_tnode_bottomup_connection_layerwise=False, 60 | use_tnode_bottomup_connection_sequential=False, 61 | use_node_lateral_connection=False, # not include T-nodes 62 | use_node_lateral_connection_1=False, # include T-nodes 63 | use_super_OrNode=False, 64 | remove_single_child_or_node=False, 65 | remove_symmetric_children_of_or_node=0, 66 | mark_symmetric_syntatic_subgraph = False, 67 | max_children_kept_for_or=1000): 68 | aog_param = Param(grid_ht=grid_ht, grid_wd=grid_wd, min_size=min_size, max_split=max_split, 69 | not_use_large_TerminalNode=not_use_large_TerminalNode, 70 | turn_off_size_ratio_TerminalNode=turn_off_size_ratio_TerminalNode, 71 | not_use_intermediate_TerminalNode=not_use_intermediate_TerminalNode, 72 | use_root_TerminalNode=use_root_TerminalNode, 73 | use_tnode_as_alpha_channel=use_tnode_as_alpha_channel, 74 | use_tnode_topdown_connection=use_tnode_topdown_connection, 75 | use_tnode_bottomup_connection=use_tnode_bottomup_connection, 76 | use_tnode_bottomup_connection_layerwise=use_tnode_bottomup_connection_layerwise, 77 | use_tnode_bottomup_connection_sequential=use_tnode_bottomup_connection_sequential, 78 | use_node_lateral_connection=use_node_lateral_connection, 79 | use_node_lateral_connection_1=use_node_lateral_connection_1, 80 | use_super_OrNode=use_super_OrNode, 81 | remove_single_child_or_node=remove_single_child_or_node, 82 | mark_symmetric_syntatic_subgraph = 
mark_symmetric_syntatic_subgraph, 83 | remove_symmetric_children_of_or_node=remove_symmetric_children_of_or_node, 84 | max_children_kept_for_or=max_children_kept_for_or) 85 | aog = AOGrid(aog_param) 86 | aog.Create() 87 | return aog 88 | 89 | 90 | class NodeType(object): 91 | OrNode = "OrNode" 92 | AndNode = "AndNode" 93 | TerminalNode = "TerminalNode" 94 | Unknow = "Unknown" 95 | 96 | 97 | class SplitType(object): 98 | HorSplit = "Hor" 99 | VerSplit = "Ver" 100 | Unknown = "Unknown" 101 | 102 | 103 | class Param(object): 104 | """Input parameters for creating an AOG 105 | """ 106 | 107 | def __init__(self, grid_ht=3, grid_wd=3, max_split=2, min_size=1, control_side_length=False, 108 | overlap_ratio=0., use_root_TerminalNode=False, 109 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=0.5, 110 | not_use_intermediate_TerminalNode= False, 111 | use_tnode_as_alpha_channel=0, 112 | use_tnode_topdown_connection=False, 113 | use_tnode_bottomup_connection=False, 114 | use_tnode_bottomup_connection_layerwise=False, 115 | use_tnode_bottomup_connection_sequential=False, 116 | use_node_lateral_connection=False, 117 | use_node_lateral_connection_1=False, 118 | use_super_OrNode=False, 119 | remove_single_child_or_node=False, 120 | remove_symmetric_children_of_or_node=0, 121 | mark_symmetric_syntatic_subgraph=False, 122 | max_children_kept_for_or=100): 123 | self.grid_ht = grid_ht 124 | self.grid_wd = grid_wd 125 | self.max_split = max_split # maximum number of child nodes when splitting an AND-node 126 | self.min_size = min_size # minimum side length or minimum area allowed 127 | self.control_side_length = control_side_length 128 | self.overlap_ratio = overlap_ratio 129 | self.use_root_terminal_node = use_root_TerminalNode 130 | self.not_use_large_terminal_node = not_use_large_TerminalNode 131 | self.turn_off_size_ratio_terminal_node = turn_off_size_ratio_TerminalNode 132 | self.not_use_intermediate_TerminalNode = not_use_intermediate_TerminalNode 133 | self.use_tnode_as_alpha_channel = use_tnode_as_alpha_channel 134 | self.use_tnode_topdown_connection = use_tnode_topdown_connection 135 | self.use_tnode_bottomup_connection = use_tnode_bottomup_connection 136 | self.use_tnode_bottomup_connection_layerwise = use_tnode_bottomup_connection_layerwise 137 | self.use_node_lateral_connection = use_node_lateral_connection 138 | self.use_node_lateral_connection_1 = use_node_lateral_connection_1 139 | self.use_tnode_bottomup_connection_sequential = use_tnode_bottomup_connection_sequential 140 | assert 1 >= self.use_node_lateral_connection_1 + self.use_node_lateral_connection + \ 141 | self.use_tnode_topdown_connection + self.use_tnode_bottomup_connection + \ 142 | self.use_tnode_bottomup_connection_layerwise + self.use_tnode_bottomup_connection_sequential, \ 143 | 'only one type of node hierarchy can be used' 144 | self.use_super_OrNode = use_super_OrNode 145 | self.remove_single_child_or_node = remove_single_child_or_node 146 | self.remove_symmetric_children_of_or_node = remove_symmetric_children_of_or_node #0: not, 1: keep left, 2: keep right 147 | self.mark_symmetric_syntatic_subgraph = mark_symmetric_syntatic_subgraph # true, only mark the nodes which will be removed based on remove_symmetric_children_of_or_node 148 | self.max_children_kept_for_or = max_children_kept_for_or # how many child nodes kept for an OR-node 149 | 150 | self.get_tag() 151 | 152 | def get_tag(self): 153 | # identifier useful for naming a particular aog 154 | self.tag = 
'{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}'.format( 155 | self.grid_wd, self.grid_ht, self.max_split, 156 | self.min_size, self.control_side_length, 157 | self.overlap_ratio, 158 | self.use_root_terminal_node, 159 | self.not_use_large_terminal_node, 160 | self.turn_off_size_ratio_terminal_node, 161 | self.not_use_intermediate_TerminalNode, 162 | self.use_tnode_as_alpha_channel, 163 | self.use_tnode_topdown_connection, 164 | self.use_tnode_bottomup_connection, 165 | self.use_tnode_bottomup_connection_layerwise, 166 | self.use_tnode_bottomup_connection_sequential, 167 | self.use_node_lateral_connection, 168 | self.use_node_lateral_connection_1, 169 | self.use_super_OrNode, 170 | self.remove_single_child_or_node, 171 | self.remove_symmetric_children_of_or_node, 172 | self.mark_symmetric_syntatic_subgraph, 173 | self.max_children_kept_for_or) 174 | 175 | 176 | class Rect(object): 177 | """A simple rectangle 178 | """ 179 | 180 | def __init__(self, x1=0, y1=0, x2=0, y2=0): 181 | self.x1 = x1 182 | self.y1 = y1 183 | self.x2 = x2 184 | self.y2 = y2 185 | 186 | def __eq__(self, other): 187 | """Override the default Equals behavior""" 188 | if isinstance(other, self.__class__): 189 | return self.__dict__ == other.__dict__ 190 | return NotImplemented 191 | 192 | def __ne__(self, other): 193 | """Define a non-equality test""" 194 | if isinstance(other, self.__class__): 195 | return not self.__eq__(other) 196 | return NotImplemented 197 | 198 | def __hash__(self): 199 | """Override the default hash behavior (that returns id or the object)""" 200 | return hash(tuple(sorted(self.__dict__.items()))) 201 | 202 | def Width(self): 203 | return self.x2 - self.x1 + 1 204 | 205 | def Height(self): 206 | return self.y2 - self.y1 + 1 207 | 208 | def Area(self): 209 | return self.Width() * self.Height() 210 | 211 | def MinLength(self): 212 | return min(self.Width(), self.Height()) 213 | 214 | def IsOverlap(self, other): 215 | assert isinstance(other, self.__class__) 216 | 217 | x1 = max(self.x1, other.x1) 218 | x2 = min(self.x2, other.x2) 219 | if x1 > x2: 220 | return False 221 | 222 | y1 = max(self.y1, other.y1) 223 | y2 = min(self.y2, other.y2) 224 | if y1 > y2: 225 | return False 226 | 227 | return True 228 | 229 | def IsSame(self, other): 230 | assert isinstance(other, self.__class__) 231 | 232 | return self.Width() == other.Width() and self.Height() == other.Height() 233 | 234 | 235 | class Node(object): 236 | """Types of nodes in an AOG 237 | AND-node (structural decomposition), 238 | OR-node (alternative decompositions), 239 | TERMINAL-node (link to data). 
240 | """ 241 | 242 | def __init__(self, node_id=-1, node_type=NodeType.Unknow, rect_idx=-1, 243 | child_ids=None, parent_ids=None, 244 | split_type=SplitType.Unknown, split_step1=0, split_step2=0, is_symm=False, 245 | ancestors_ids=None): 246 | self.id = node_id 247 | self.node_type = node_type 248 | self.rect_idx = rect_idx 249 | self.child_ids = child_ids if child_ids is not None else [] 250 | self.parent_ids = parent_ids if parent_ids is not None else [] 251 | self.ancestors_ids = ancestors_ids if ancestors_ids is not None else [] # root or-node exlusive 252 | self.split_type = split_type 253 | self.split_step1 = split_step1 254 | self.split_step2 = split_step2 255 | 256 | # some utility variables used in object detection models 257 | self.on_off = True 258 | self.out_edge_visited_count = [] 259 | self.which_classes_visited = {} # key=class_name, val=frequency 260 | self.npaths = 0.0 261 | self.is_symmetric = False 262 | self.has_dbl_counting = False 263 | 264 | def __eq__(self, other): 265 | if isinstance(other, self.__class__): 266 | res = ((self.node_type == other.node_type) and (self.rect_idx == other.rect_idx)) 267 | if res: 268 | if self.node_type != NodeType.AndNode: 269 | return True 270 | else: 271 | if self.split_type != SplitType.Unknown: 272 | return (self.split_type == other.split_type) and (self.split_step1 == other.split_step1) and \ 273 | (self.split_step2 == other.split_step2) 274 | else: 275 | return (set(self.child_ids) == set(other.child_ids)) 276 | 277 | return False 278 | 279 | return NotImplemented 280 | 281 | def __ne__(self, other): 282 | if isinstance(other, self.__class__): 283 | return not self.__eq__(other) 284 | return NotImplemented 285 | 286 | def __hash__(self): 287 | """Override the default hash behavior (that returns id or the object)""" 288 | return hash(tuple(sorted(self.__dict__.items()))) 289 | 290 | 291 | class AOGrid(object): 292 | """The AOGrid defines a Directed Acyclic And-Or Graph 293 | which is used to explore/unfold the space of latent structures 294 | of a grid (e.g., a 7 * 7 grid for a 100 * 200 lattice) 295 | """ 296 | 297 | def __init__(self, param_): 298 | assert isinstance(param_, Param) 299 | self.param = param_ 300 | assert self.param.max_split > 1 301 | self.primitive_set = [] 302 | self.node_set = [] 303 | self.num_TNodes = 0 304 | self.num_AndNodes = 0 305 | self.num_OrNodes = 0 306 | self.DFS = [] 307 | self.BFS = [] 308 | self.node_DFS = {} 309 | self.node_BFS = {} 310 | self.OrNodeIdxInBFS = {} 311 | self.TNodeIdxInBFS = {} 312 | 313 | # for color consistency in viz 314 | self.TNodeColors = {} 315 | 316 | def _AddPrimitve(self, rect): 317 | assert isinstance(rect, Rect) 318 | 319 | if rect in self.primitive_set: 320 | return self.primitive_set.index(rect) 321 | 322 | self.primitive_set.append(rect) 323 | 324 | return len(self.primitive_set) - 1 325 | 326 | def _AddNode(self, node, not_create_if_existed=True): 327 | assert isinstance(node, Node) 328 | 329 | if node in self.node_set and not_create_if_existed: 330 | node = self.node_set[self.node_set.index(node)] 331 | return False, node 332 | 333 | node.id = len(self.node_set) 334 | if node.node_type == NodeType.AndNode: 335 | self.num_AndNodes += 1 336 | elif node.node_type == NodeType.OrNode: 337 | self.num_OrNodes += 1 338 | elif node.node_type == NodeType.TerminalNode: 339 | self.num_TNodes += 1 340 | else: 341 | raise NotImplementedError 342 | 343 | self.node_set.append(node) 344 | 345 | return True, node 346 | 347 | def _DoSplit(self, rect): 348 | assert isinstance(rect, 
Rect) 349 | 350 | if self.param.control_side_length: 351 | return rect.Width() >= self.param.min_size and rect.Height() >= self.param.min_size 352 | 353 | return rect.Area() > self.param.min_size 354 | 355 | def _SplitStep(self, sz): 356 | if self.param.control_side_length: 357 | return self.param.min_size 358 | 359 | if sz >= self.param.min_size: 360 | return 1 361 | else: 362 | return int(ceil(self.param.min_size / sz)) 363 | 364 | def _DFS(self, id, q, visited): 365 | if visited[id] == 1: 366 | raise RuntimeError 367 | 368 | visited[id] = 1 369 | for i in self.node_set[id].child_ids: 370 | if visited[i] < 2: 371 | q, visited = self._DFS(i, q, visited) 372 | 373 | if self.node_set[id].on_off: 374 | q.append(id) 375 | 376 | visited[id] = 2 377 | 378 | return q, visited 379 | 380 | def _BFS(self, id, q, visited): 381 | q = [id] 382 | visited[id] = 0 # count indegree 383 | i = 0 384 | while i < len(q): 385 | node = self.node_set[q[i]] 386 | for j in node.child_ids: 387 | visited[j] += 1 388 | if j not in q: 389 | q.append(j) 390 | 391 | i += 1 392 | 393 | q = [id] 394 | i = 0 395 | while i < len(q): 396 | node = self.node_set[q[i]] 397 | for j in node.child_ids: 398 | visited[j] -= 1 399 | if visited[j] == 0: 400 | q.append(j) 401 | i += 1 402 | 403 | return q, visited 404 | 405 | def _countPaths(self, s, t, npaths): 406 | if s.id == t.id: 407 | return 1.0 408 | else: 409 | if not npaths[s.id]: 410 | rect = self.primitive_set[s.rect_idx] 411 | #ids1 = set(s.ancestors_ids) 412 | ids1 = set(s.parent_ids) 413 | num_shared = 0 414 | for c in s.child_ids: 415 | ch = self.node_set[c] 416 | ch_rect = self.primitive_set[ch.rect_idx] 417 | #ids2 = ch.ancestors_ids 418 | ids2 = ch.parent_ids 419 | 420 | if s.node_type == NodeType.AndNode and ch.node_type == NodeType.AndNode and \ 421 | rect.Width() == ch_rect.Width() and rect.Height() == ch_rect.Height(): 422 | continue 423 | if s.node_type == NodeType.OrNode and \ 424 | ((ch.node_type == NodeType.OrNode) or \ 425 | (ch.node_type == NodeType.TerminalNode and (rect.Area() < ch_rect.Area()) )): 426 | continue 427 | 428 | npaths[s.id] += self._countPaths(ch, t, npaths) 429 | return npaths[s.id] 430 | 431 | def _AssignParentIds(self): 432 | for i in range(len(self.node_set)): 433 | self.node_set[i].parent_ids = [] 434 | 435 | for node in self.node_set: 436 | for i in node.child_ids: 437 | self.node_set[i].parent_ids.append(node.id) 438 | 439 | for i in range(len(self.node_set)): 440 | self.node_set[i].parent_ids = list(set(self.node_set[i].parent_ids)) 441 | 442 | def _AssignAncestorsIds(self): 443 | self._AssignParentIds() 444 | 445 | assert len(self.BFS) == len(self.node_set) 446 | self.node_set[self.BFS[0]].ancestors_ids = [] 447 | 448 | for nid in self.BFS[1:]: 449 | node = self.node_set[nid] 450 | rect = self.primitive_set[node.rect_idx] 451 | ancestors = [] 452 | for pid in node.parent_ids: 453 | p = self.node_set[pid] 454 | p_rect = self.primitive_set[p.rect_idx] 455 | equal_size = rect.Width() == p_rect.Width() and \ 456 | rect.Height() == p_rect.Height() 457 | # AND-to-AND lateral path 458 | if node.node_type == NodeType.AndNode and p.node_type == NodeType.AndNode and \ 459 | equal_size: 460 | continue 461 | # OR-to-OR/T lateral path 462 | if node.node_type == NodeType.OrNode and \ 463 | ((p.node_type == NodeType.OrNode and equal_size) or \ 464 | (p.node_type == NodeType.TerminalNode and (rect.Area() < p_rect.Area()) )): 465 | continue 466 | for ppid in p.ancestors_ids: 467 | if ppid != self.BFS[0] and ppid not in ancestors: 468 | 
ancestors.append(ppid) 469 | ancestors.append(pid) 470 | self.node_set[nid].ancestors_ids = list(set(ancestors)) 471 | 472 | def _Postprocessing(self, root_id): 473 | self.DFS = [] 474 | self.BFS = [] 475 | visited = np.zeros(len(self.node_set)) 476 | self.DFS, _ = self._DFS(root_id, self.DFS, visited) 477 | visited = np.zeros(len(self.node_set)) 478 | self.BFS, _ = self._BFS(root_id, self.BFS, visited) 479 | self._AssignAncestorsIds() 480 | 481 | def _FindNodeIdWithGivenRect(self, rect, node_type): 482 | for node in self.node_set: 483 | if node.node_type != node_type: 484 | continue 485 | if rect == self.primitive_set[node.rect_idx]: 486 | return node.id 487 | 488 | return -1 489 | 490 | def _add_tnode_topdown_connection(self): 491 | 492 | assert self.param.use_root_terminal_node 493 | 494 | prim_type = [self.param.grid_ht, self.param.grid_wd] 495 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type) 496 | assert len(tnode_queue) == 1 497 | 498 | i = 0 499 | while i < len(tnode_queue): 500 | id_ = tnode_queue[i] 501 | node = self.node_set[id_] 502 | i += 1 503 | 504 | rect = self.primitive_set[node.rect_idx] 505 | 506 | ids = [] 507 | for y in range(0, rect.Height()): 508 | for x in range(0, rect.Width()): 509 | if y == 0 and x == 0: 510 | continue 511 | prim_type = [rect.Height()-y, rect.Width()-x] 512 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect) 513 | 514 | ids = list(set(ids)) 515 | tnode_queue += ids 516 | 517 | for pid in ids: 518 | if id_ not in self.node_set[pid].child_ids: 519 | self.node_set[pid].child_ids.append(id_) 520 | 521 | def _add_onode_topdown_connection(self): 522 | assert self.param.use_root_terminal_node 523 | 524 | prim_type = [self.param.grid_ht, self.param.grid_wd] 525 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type) 526 | assert len(tnode_queue) == 1 527 | 528 | i = 0 529 | while i < len(tnode_queue): 530 | id_ = tnode_queue[i] 531 | node = self.node_set[id_] 532 | i += 1 533 | 534 | rect = self.primitive_set[node.rect_idx] 535 | 536 | ids = [] 537 | ids_t = [] 538 | for y in range(0, rect.Height()): 539 | for x in range(0, rect.Width()): 540 | if y == 0 and x == 0: 541 | continue 542 | prim_type = [rect.Height()-y, rect.Width()-x] 543 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect, NodeType.OrNode) 544 | ids_t += self.find_node_ids_with_given_prim_type(prim_type, rect) 545 | 546 | ids = list(set(ids)) 547 | ids_t = list(set(ids_t)) 548 | 549 | for pid in ids: 550 | if id_ not in self.node_set[pid].child_ids: 551 | self.node_set[pid].child_ids.append(id_) 552 | 553 | def _add_tnode_bottomup_connection(self): 554 | assert self.param.use_root_terminal_node 555 | 556 | # primitive tnodes 557 | prim_type = [1, 1] 558 | t_ids = self.find_node_ids_with_given_prim_type(prim_type) 559 | assert len(t_ids) == self.param.grid_wd * self.param.grid_ht 560 | 561 | # other tnodes will be converted to and-nodes 562 | for h in range(1, self.param.grid_ht+1): 563 | for w in range(1, self.param.grid_wd+1): 564 | if h == 1 and w == 1: 565 | continue 566 | prim_type = [h, w] 567 | ids = self.find_node_ids_with_given_prim_type(prim_type) 568 | for id_ in ids: 569 | self.node_set[id_].node_type = NodeType.AndNode 570 | node = self.node_set[id_] 571 | rect = self.primitive_set[node.rect_idx] 572 | prim_type = [1, 1] 573 | for y in range(rect.y1, rect.y2+1): 574 | for x in range(rect.x1, rect.x2+1): 575 | parent_rect = Rect(x, y, x, y) 576 | ch_ids = self.find_node_ids_with_given_prim_type(prim_type, parent_rect) 577 | 
assert len(ch_ids) == 1 578 | if ch_ids[0] not in self.node_set[id_].child_ids: 579 | self.node_set[id_].child_ids.append(ch_ids[0]) 580 | 581 | def _add_lateral_connection(self): 582 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.AndNode, direction=1) 583 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.OrNode, direction=0) 584 | 585 | if not self.param.use_node_lateral_connection_1: 586 | return self.BFS[0] 587 | 588 | # or for all or nodes 589 | for node in self.node_set: 590 | if node.node_type != NodeType.OrNode: 591 | continue 592 | 593 | ch_ids = node.child_ids 594 | numCh = len(ch_ids) 595 | 596 | hasLateral = False 597 | for id_ in ch_ids: 598 | if self.node_set[id_].node_type == NodeType.OrNode: 599 | hasLateral = True 600 | numCh -= 1 601 | 602 | minNumCh = 3 if hasLateral else 2 603 | if len(ch_ids) < minNumCh: 604 | continue 605 | 606 | # find t-node child 607 | ch0 = -1 608 | for id_ in ch_ids: 609 | if self.node_set[id_].node_type == NodeType.TerminalNode: 610 | ch0 = id_ 611 | break 612 | assert ch0 != -1 613 | 614 | added = False 615 | for id_ in ch_ids: 616 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode: 617 | continue 618 | 619 | if len(self.node_set[id_].child_ids) == 2 or numCh == 2: 620 | assert ch0 not in self.node_set[id_].child_ids 621 | self.node_set[id_].child_ids.append(ch0) 622 | added = True 623 | 624 | if not added: 625 | for id_ in ch_ids: 626 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode: 627 | continue 628 | 629 | found = True 630 | for id__ in ch_ids: 631 | if id_ in self.node_set[id__].child_ids: 632 | found = False 633 | if found: 634 | assert ch0 not in self.node_set[id_].child_ids 635 | self.node_set[id_].child_ids.append(ch0) 636 | 637 | return self.BFS[0] 638 | 639 | def _add_node_bottomup_connection_layerwise(self, node_type=NodeType.TerminalNode, direction=0): 640 | 641 | prim_types = [] 642 | for node in self.node_set: 643 | if node.node_type == node_type: 644 | rect = self.primitive_set[node.rect_idx] 645 | p = [rect.Height(), rect.Width()] 646 | if p not in prim_types: 647 | prim_types.append(p) 648 | 649 | change_direction = False 650 | 651 | prim_types.sort() 652 | 653 | for p in prim_types: 654 | ids = self.find_node_ids_with_given_prim_type(p, node_type=node_type) 655 | if len(ids) < 2: 656 | change_direction = True 657 | continue 658 | 659 | if change_direction: 660 | direction = 1 - direction 661 | 662 | yx = np.empty((0, 4 if node_type==NodeType.AndNode else 2), dtype=np.float32) 663 | for id_ in ids: 664 | node = self.node_set[id_] 665 | rect = self.primitive_set[node.rect_idx] 666 | 667 | if node_type == NodeType.AndNode: 668 | ch_node = self.node_set[node.child_ids[0]] 669 | ch_rect = self.primitive_set[ch_node.rect_idx] 670 | if ch_rect.x1 != rect.x1 or ch_rect.y1 != rect.y1: 671 | ch_node = self.node_set[node.child_ids[1]] 672 | ch_rect = self.primitive_set[ch_node.rect_idx] 673 | pos = (rect.y1, rect.x1, ch_rect.y2, ch_rect.x2) 674 | else: 675 | pos = (rect.y1, rect.x1) 676 | yx = np.vstack((yx, np.array(pos))) 677 | 678 | if node_type == NodeType.AndNode: 679 | ind = np.lexsort((yx[:, 1], yx[:, 0], yx[:, 3], yx[:, 2])) 680 | else: 681 | ind = np.lexsort((yx[:, 1], yx[:, 0])) 682 | 683 | istart = len(ind) - 1 if direction == 0 else 0 684 | iend = 0 if direction == 0 else len(ind) - 1 685 | step = -1 if direction == 0 else 1 686 | for i in range(istart, iend, step): 687 | id_ = ids[ind[i]] 688 | chid = ids[ind[i - 1]] if direction==0 else ids[ind[i+1]] 689 
| if chid not in self.node_set[id_].child_ids: 690 | self.node_set[id_].child_ids.append(chid) 691 | 692 | if change_direction: 693 | direction = 1 - direction 694 | change_direction = False 695 | 696 | def _add_tnode_bottomup_connection_sequential(self): 697 | 698 | assert self.param.grid_wd > 1 and self.param.grid_ht == 1 699 | 700 | self._add_node_bottomup_connection_layerwise() 701 | 702 | for node in self.node_set: 703 | if node.node_type != NodeType.TerminalNode: 704 | continue 705 | rect = self.primitive_set[node.rect_idx] 706 | if rect.Width() == 1: 707 | continue 708 | 709 | rect1 = copy.deepcopy(rect) 710 | rect1.x1 += 1 711 | chid = self._FindNodeIdWithGivenRect(rect1, NodeType.TerminalNode) 712 | if chid != -1: 713 | self.node_set[node.id].child_ids.append(chid) 714 | 715 | def _mark_symmetric_subgraph(self): 716 | 717 | for i in self.BFS: 718 | node = self.node_set[i] 719 | 720 | if node.is_symmetric or node.node_type == NodeType.TerminalNode: 721 | continue 722 | 723 | if i != self.BFS[0]: 724 | is_symmetric = True 725 | for j in node.parent_ids: 726 | p = self.node_set[j] 727 | if not p.is_symmetric: 728 | is_symmetric = False 729 | break 730 | if is_symmetric: 731 | self.node_set[i].is_symmetric = True 732 | continue 733 | 734 | rect = self.primitive_set[node.rect_idx] 735 | Wd = rect.Width() 736 | Ht = rect.Height() 737 | 738 | if node.node_type == NodeType.OrNode: 739 | # mark symmetric children 740 | useSplitWds = [] 741 | useSplitHts = [] 742 | if self.param.remove_symmetric_children_of_or_node == 2: 743 | child_ids = node.child_ids[::-1] 744 | else: 745 | child_ids = node.child_ids 746 | 747 | for j in child_ids: 748 | ch = self.node_set[j] 749 | if ch.node_type == NodeType.TerminalNode: 750 | continue 751 | 752 | if ch.split_type == SplitType.VerSplit: 753 | if (Wd-ch.split_step2, ch.split_step1) not in useSplitWds: 754 | useSplitWds.append((ch.split_step1, Wd-ch.split_step2)) 755 | else: 756 | self.node_set[j].is_symmetric = True 757 | 758 | elif ch.split_type == SplitType.HorSplit: 759 | if (Ht-ch.split_step2, ch.split_step1) not in useSplitHts: 760 | useSplitHts.append((ch.split_step1, Ht-ch.split_step2)) 761 | else: 762 | self.node_set[j].is_symmetric = True 763 | 764 | def _find_dbl_counting_or_nodes(self): 765 | for node in self.node_set: 766 | if node.node_type != NodeType.OrNode or len(node.child_ids) < 2: 767 | continue 768 | for i in self.node_BFS[node.id][1:]: 769 | npaths = { x : 0 for x in self.node_BFS[node.id] } 770 | n = self._countPaths(node, self.node_set[i], npaths) 771 | if n > 1: 772 | self.node_set[node.id].has_dbl_counting = True 773 | break 774 | 775 | def find_node_ids_with_given_prim_type(self, prim_type, parent_rect=None, node_type=NodeType.TerminalNode): 776 | ids = [] 777 | for node in self.node_set: 778 | if node.node_type != node_type: 779 | continue 780 | rect = self.primitive_set[node.rect_idx] 781 | if [rect.Height(), rect.Width()] == prim_type: 782 | if parent_rect is not None: 783 | if rect.x1 >= parent_rect.x1 and rect.y1 >= parent_rect.y1 and \ 784 | rect.x2 <= parent_rect.x2 and rect.y2 <= parent_rect.y2: 785 | ids.append(node.id) 786 | else: 787 | ids.append(node.id) 788 | return ids 789 | 790 | def Create(self): 791 | # print("======= creating AOGrid {}, could take a while".format(self.param.tag)) 792 | # FIXME: when remove_symmetric_children_of_or_node is true, top-down hierarchy is not correctly created. 
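        # Construction overview (the BFS loop below):
        #   1) seed the queue with the root OR-node spanning the whole
        #      grid_ht x grid_wd grid;
        #   2) each popped OR-node gets an optional terminal-node child
        #      (gated by the not_use_*_terminal_node options) plus one
        #      AND-node child per sampled horizontal/vertical binary split,
        #      with _AddNode de-duplicating repeated nodes;
        #   3) each popped AND-node gets two OR-node children covering the
        #      two sub-rectangles of its split -- shared sub-rectangles reuse
        #      existing OR-nodes, which is what makes the graph a DAG rather
        #      than a tree.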
793 | 794 | # the root OrNode 795 | rect = Rect(0, 0, self.param.grid_wd - 1, self.param.grid_ht - 1) 796 | self.primitive_set.append(rect) 797 | node = Node(node_type=NodeType.OrNode, rect_idx=0) 798 | self._AddNode(node) 799 | 800 | BFS = deque() 801 | BFS.append(0) 802 | keepLeft = True 803 | keepTop = True 804 | while len(BFS) > 0: 805 | curId = BFS.popleft() 806 | curNode = self.node_set[curId] 807 | curRect = self.primitive_set[curNode.rect_idx] 808 | curWd = curRect.Width() 809 | curHt = curRect.Height() 810 | 811 | childIds = [] 812 | 813 | if curNode.node_type == NodeType.OrNode: 814 | num_child_node_kept = 0 815 | # add a terminal node for a non-root OrNode 816 | allowTerminate = not ((self.param.not_use_large_terminal_node and 817 | float(curWd * curHt) / float(self.param.grid_ht * self.param.grid_wd) > 818 | self.param.turn_off_size_ratio_terminal_node) or 819 | (self.param.not_use_intermediate_TerminalNode and (curWd > self.param.min_size or curHt > self.param.min_size))) 820 | 821 | if (curId > 0 and allowTerminate) or (curId==0 and self.param.use_root_terminal_node): 822 | node = Node(node_type=NodeType.TerminalNode, rect_idx=curNode.rect_idx) 823 | suc, node = self._AddNode(node) 824 | childIds.append(node.id) 825 | num_child_node_kept += 1 826 | 827 | # add all AndNodes for horizontal and vertical binary splits 828 | if not self._DoSplit(curRect): 829 | childIds = list(set(childIds)) 830 | self.node_set[curId].child_ids = childIds 831 | continue 832 | 833 | num_child_node_to_add = self.param.max_children_kept_for_or - num_child_node_kept 834 | stepH = self._SplitStep(curWd) 835 | stepV = self._SplitStep(curHt) 836 | num_stepH = curHt - stepH + 1 - stepH 837 | num_stepV = curWd - stepV + 1 - stepV 838 | if num_stepH == 0 and num_stepV == 0: 839 | childIds = list(set(childIds)) 840 | self.node_set[curId].child_ids = childIds 841 | continue 842 | 843 | num_child_node_to_add_H = num_stepH / float(num_stepH + num_stepV) * num_child_node_to_add 844 | num_child_node_to_add_V = num_child_node_to_add - num_child_node_to_add_H 845 | 846 | stepH_step = int( 847 | max(1, floor(float(num_stepH) / num_child_node_to_add_H) if num_child_node_to_add_H > 0 else 0)) 848 | stepV_step = int( 849 | max(1, floor(float(num_stepV) / num_child_node_to_add_V) if num_child_node_to_add_V > 0 else 0)) 850 | 851 | # horizontal splits 852 | step = stepH 853 | num_child_node_added_H = 0 854 | 855 | splitHts = [] 856 | for topHt in range(step, curHt - step + 1, stepH_step): 857 | if num_child_node_added_H >= num_child_node_to_add_H: 858 | break 859 | 860 | bottomHt = curHt - topHt 861 | if self.param.overlap_ratio > 0: 862 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio)) 863 | else: 864 | numSplit = 1 865 | for b in range(0, numSplit): 866 | splitHts.append((topHt, bottomHt)) 867 | bottomHt += 1 868 | num_child_node_added_H += 1 869 | 870 | if self.param.remove_symmetric_children_of_or_node == 1 and self.param.mark_symmetric_syntatic_subgraph == False: 871 | useSplitHts = [] 872 | for (topHt, bottomHt) in splitHts: 873 | if (bottomHt, topHt) not in useSplitHts: 874 | useSplitHts.append((topHt, bottomHt)) 875 | elif self.param.remove_symmetric_children_of_or_node == 2 and self.param.mark_symmetric_syntatic_subgraph == False: 876 | useSplitHts = [] 877 | for (topHt, bottomHt) in reversed(splitHts): 878 | if (bottomHt, topHt) not in useSplitHts: 879 | useSplitHts.append((topHt, bottomHt)) 880 | else: 881 | useSplitHts = splitHts 882 | 883 | for (topHt, bottomHt) in useSplitHts: 884 | node = 
Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx, 885 | split_type=SplitType.HorSplit, 886 | split_step1=topHt, split_step2=curHt - bottomHt) 887 | suc, node = self._AddNode(node) 888 | if suc: 889 | BFS.append(node.id) 890 | childIds.append(node.id) 891 | 892 | # vertical splits 893 | step = stepV 894 | num_child_node_added_V = 0 895 | 896 | splitWds = [] 897 | for leftWd in range(step, curWd - step + 1, stepV_step): 898 | if num_child_node_added_V >= num_child_node_to_add_V: 899 | break 900 | 901 | rightWd = curWd - leftWd 902 | if self.param.overlap_ratio > 0: 903 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio)) 904 | else: 905 | numSplit = 1 906 | for r in range(0, numSplit): 907 | splitWds.append((leftWd, rightWd)) 908 | rightWd += 1 909 | num_child_node_added_V += 1 910 | 911 | if self.param.remove_symmetric_children_of_or_node == 1 and self.param.mark_symmetric_syntatic_subgraph == False: 912 | useSplitWds = [] 913 | for (leftWd, rightWd) in splitWds: 914 | if (rightWd, leftWd) not in useSplitWds: 915 | useSplitWds.append((leftWd, rightWd)) 916 | elif self.param.remove_symmetric_children_of_or_node == 2 and self.param.mark_symmetric_syntatic_subgraph == False: 917 | useSplitWds = [] 918 | for (leftWd, rightWd) in reversed(splitWds): 919 | if (rightWd, leftWd) not in useSplitWds: 920 | useSplitWds.append((leftWd, rightWd)) 921 | else: 922 | useSplitWds = splitWds 923 | 924 | for (leftWd, rightWd) in useSplitWds: 925 | node = Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx, 926 | split_type=SplitType.VerSplit, 927 | split_step1=leftWd, split_step2=curWd - rightWd) 928 | suc, node = self._AddNode(node) 929 | if suc: 930 | BFS.append(node.id) 931 | childIds.append(node.id) 932 | 933 | elif curNode.node_type == NodeType.AndNode: 934 | # add two child OrNodes 935 | if curNode.split_type == SplitType.HorSplit: 936 | top = Rect(x1=curRect.x1, y1=curRect.y1, 937 | x2=curRect.x2, y2=curRect.y1 + curNode.split_step1 - 1) 938 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(top)) 939 | suc, node = self._AddNode(node) 940 | if suc: 941 | BFS.append(node.id) 942 | childIds.append(node.id) 943 | 944 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + curNode.split_step2, 945 | x2=curRect.x2, y2=curRect.y2) 946 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(bottom)) 947 | suc, node = self._AddNode(node) 948 | if suc: 949 | BFS.append(node.id) 950 | childIds.append(node.id) 951 | elif curNode.split_type == SplitType.VerSplit: 952 | left = Rect(curRect.x1, curRect.y1, 953 | curRect.x1 + curNode.split_step1 - 1, curRect.y2) 954 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(left)) 955 | suc, node = self._AddNode(node) 956 | if suc: 957 | BFS.append(node.id) 958 | childIds.append(node.id) 959 | 960 | right = Rect(curRect.x1 + curNode.split_step2, curRect.y1, 961 | curRect.x2, curRect.y2) 962 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(right)) 963 | suc, node = self._AddNode(node) 964 | if suc: 965 | BFS.append(node.id) 966 | childIds.append(node.id) 967 | 968 | childIds = list(set(childIds)) 969 | self.node_set[curId].child_ids = childIds 970 | 971 | # add root terminal node if allowed 972 | root_id = 0 973 | 974 | # create And-nodes with more than 2 children 975 | # TODO: handle remove_symmetric_child_node 976 | if self.param.max_split > 2: 977 | for branch in range(3, self.param.max_split + 1): 978 | for node in self.node_set: 979 | if node.node_type != NodeType.OrNode: 980 | continue 
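                # Widening step for max_split > 2: each existing AND-node with
                # (branch - 1) children is extended into branch-way AND-nodes
                # by replacing one OR-node child with the two OR-nodes of a
                # further binary split, but only when both halves already
                # exist in the graph; the new AND-nodes are collected in
                # new_and_ids and attached to the parent OR-node at the end of
                # the scan.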
981 | 982 | new_and_ids = [] 983 | 984 | for cur_id in node.child_ids: 985 | cur_and = self.node_set[cur_id] 986 | if len(cur_and.child_ids) != branch - 1: 987 | continue 988 | assert cur_and.node_type == NodeType.AndNode 989 | 990 | for ch_id in cur_and.child_ids: 991 | ch = self.node_set[ch_id] 992 | curRect = self.primitive_set[ch.rect_idx] 993 | curWd = curRect.Width() 994 | curHt = curRect.Height() 995 | 996 | # split ch into two to create new And-nodes 997 | 998 | # add all AndNodes for horizontal and vertical binary splits 999 | if not self._DoSplit(curRect): 1000 | continue 1001 | 1002 | # horizontal splits 1003 | step = self._SplitStep(curWd) 1004 | for topHt in range(step, curHt - step + 1): 1005 | bottomHt = curHt - topHt 1006 | if self.param.overlap_ratio > 0: 1007 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio)) 1008 | else: 1009 | numSplit = 1 1010 | for b in range(0, numSplit): 1011 | split_step1 = topHt 1012 | split_step2 = curHt - bottomHt 1013 | 1014 | top = Rect(x1=curRect.x1, y1=curRect.y1, 1015 | x2=curRect.x2, y2=curRect.y1 + split_step1 - 1) 1016 | top_id = self._FindNodeIdWithGivenRect(top, NodeType.OrNode) 1017 | if top_id == -1: 1018 | continue 1019 | # assert top_id != -1 1020 | 1021 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + split_step2, 1022 | x2=curRect.x2, y2=curRect.y2) 1023 | bottom_id = self._FindNodeIdWithGivenRect(bottom, NodeType.OrNode) 1024 | if bottom_id == -1: 1025 | continue 1026 | # assert bottom_id != -1 1027 | 1028 | # add a new And-node 1029 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx) 1030 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [top_id, 1031 | bottom_id] 1032 | 1033 | suc, new_and = self._AddNode(new_and) 1034 | new_and_ids.append(new_and.id) 1035 | 1036 | bottomHt += 1 1037 | 1038 | # vertical splits 1039 | step = self._SplitStep(curHt) 1040 | for leftWd in range(step, curWd - step + 1): 1041 | rightWd = curWd - leftWd 1042 | 1043 | if self.param.overlap_ratio > 0: 1044 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio)) 1045 | else: 1046 | numSplit = 1 1047 | for r in range(0, numSplit): 1048 | split_step1 = leftWd 1049 | split_step2 = curWd - rightWd 1050 | 1051 | left = Rect(curRect.x1, curRect.y1, 1052 | curRect.x1 + split_step1 - 1, curRect.y2) 1053 | left_id = self._FindNodeIdWithGivenRect(left, NodeType.OrNode) 1054 | if left_id == -1: 1055 | continue 1056 | # assert left_id != -1 1057 | 1058 | right = Rect(curRect.x1 + split_step2, curRect.y1, 1059 | curRect.x2, curRect.y2) 1060 | right_id = self._FindNodeIdWithGivenRect(right, NodeType.OrNode) 1061 | if right_id == -1: 1062 | continue 1063 | # assert right_id != -1 1064 | 1065 | # add a new And-node 1066 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx) 1067 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [left_id, 1068 | right_id] 1069 | 1070 | suc, new_and = self._AddNode(new_and) 1071 | new_and_ids.append(new_and.id) 1072 | 1073 | rightWd += 1 1074 | 1075 | self.node_set[node.id].child_ids = list(set(self.node_set[node.id].child_ids + new_and_ids)) 1076 | 1077 | self._Postprocessing(root_id) 1078 | 1079 | # change tnodes to child nodes of and-nodes / or-nodes 1080 | if self.param.use_tnode_as_alpha_channel > 0: 1081 | node_type = NodeType.OrNode if self.param.use_tnode_as_alpha_channel==1 else NodeType.AndNode 1082 | not_create_if_existed = not self.param.use_tnode_as_alpha_channel==1 1083 | for id_ in self.BFS: 1084 | node = self.node_set[id_] 1085 | if 
node.node_type == NodeType.OrNode and len(node.child_ids) > 1: 1086 | for ch in node.child_ids: 1087 | ch_node = self.node_set[ch] 1088 | if ch_node.node_type == NodeType.TerminalNode: 1089 | new_parent_node = Node(node_type=node_type, rect_idx=ch_node.rect_idx) 1090 | _, new_parent_node = self._AddNode(new_parent_node, not_create_if_existed) 1091 | new_parent_node.child_ids = [ch_node.id, node.id] 1092 | 1093 | for pr in node.parent_ids: 1094 | pr_node = self.node_set[pr] 1095 | for i, pr_ch in enumerate(pr_node.child_ids): 1096 | if pr_ch == node.id: 1097 | pr_node.child_ids[i] = new_parent_node.id 1098 | break 1099 | 1100 | self.node_set[id_].child_ids.remove(ch) 1101 | if id_ == self.BFS[0]: 1102 | root_id = new_parent_node.id 1103 | break 1104 | 1105 | self._Postprocessing(root_id) 1106 | 1107 | # add super-or node 1108 | if self.param.use_super_OrNode: 1109 | super_or_node = Node(node_type=NodeType.OrNode, rect_idx=-1) 1110 | _, super_or_node = self._AddNode(super_or_node) 1111 | super_or_node.child_ids = [] 1112 | for node in self.node_set: 1113 | if node.node_type == NodeType.OrNode and node.rect_idx != -1: 1114 | rect = self.primitive_set[node.rect_idx] 1115 | r = float(rect.Area()) / float(self.param.grid_ht * self.param.grid_wd) 1116 | if r > 0.5: 1117 | super_or_node.child_ids.append(node.id) 1118 | 1119 | root_id = super_or_node.id 1120 | 1121 | self._Postprocessing(root_id) 1122 | 1123 | # remove or-nodes with single child node 1124 | if self.param.remove_single_child_or_node: 1125 | remove_ids = [] 1126 | for node in self.node_set: 1127 | if node.node_type == NodeType.OrNode and len(node.child_ids) == 1: 1128 | for pr in node.parent_ids: 1129 | pr_node = self.node_set[pr] 1130 | for i, pr_ch in enumerate(pr_node.child_ids): 1131 | if pr_ch == node.id: 1132 | pr_node.child_ids[i] = node.child_ids[0] 1133 | break 1134 | 1135 | remove_ids.append(node.id) 1136 | node.child_ids = [] 1137 | 1138 | remove_ids.sort() 1139 | remove_ids.reverse() 1140 | 1141 | for id_ in remove_ids: 1142 | for node in self.node_set: 1143 | if node.id > id_: 1144 | node.id -= 1 1145 | for i, ch in enumerate(node.child_ids): 1146 | if ch > id_: 1147 | node.child_ids[i] -= 1 1148 | 1149 | if root_id > id_: 1150 | root_id -= 1 1151 | 1152 | for id_ in remove_ids: 1153 | del self.node_set[id_] 1154 | 1155 | self._Postprocessing(root_id) 1156 | 1157 | # mark symmetric nodes 1158 | if self.param.mark_symmetric_syntatic_subgraph: 1159 | self._mark_symmetric_subgraph() 1160 | 1161 | # add tnode hierarchy 1162 | if self.param.use_tnode_topdown_connection: 1163 | self._add_tnode_topdown_connection() 1164 | self._Postprocessing(root_id) 1165 | elif self.param.use_tnode_bottomup_connection: 1166 | self._add_tnode_bottomup_connection() 1167 | self._Postprocessing(root_id) 1168 | elif self.param.use_tnode_bottomup_connection_layerwise: 1169 | self._add_node_bottomup_connection_layerwise() 1170 | self._Postprocessing(root_id) 1171 | elif self.param.use_tnode_bottomup_connection_sequential: 1172 | self._add_tnode_bottomup_connection_sequential() 1173 | self._Postprocessing(root_id) 1174 | elif self.param.use_node_lateral_connection or self.param.use_node_lateral_connection_1: 1175 | root_id = self._add_lateral_connection() 1176 | self._Postprocessing(root_id) 1177 | 1178 | # index of Or-nodes in BFS 1179 | self.OrNodeIdxInBFS = {} 1180 | self.TNodeIdxInBFS = {} 1181 | idx_or = 0 1182 | idx_t = 0 1183 | for id_ in self.BFS: 1184 | node = self.node_set[id_] 1185 | if node.node_type == NodeType.OrNode: 1186 | 
self.OrNodeIdxInBFS[node.id] = idx_or 1187 | idx_or += 1 1188 | elif node.node_type == NodeType.TerminalNode: 1189 | self.TNodeIdxInBFS[node.id] = idx_t 1190 | idx_t += 1 1191 | 1192 | # get DFS and BFS rooted at each node 1193 | for node in self.node_set: 1194 | if node.node_type == NodeType.TerminalNode: 1195 | continue 1196 | visited = np.zeros(len(self.node_set)) 1197 | self.node_DFS[node.id] = [] 1198 | self.node_DFS[node.id], _ = self._DFS(node.id, self.node_DFS[node.id], visited) 1199 | 1200 | visited = np.zeros(len(self.node_set)) 1201 | self.node_BFS[node.id] = [] 1202 | self.node_BFS[node.id], _ = self._BFS(node.id, self.node_BFS[node.id], visited) 1203 | 1204 | # count paths between nodes and root node 1205 | for n in self.node_set: 1206 | npaths = { x.id : 0 for x in self.node_set } 1207 | self.node_set[n.id].npaths = self._countPaths(self.node_set[self.BFS[0]], n, npaths) 1208 | 1209 | # find ornode with double-counting children 1210 | self._find_dbl_counting_or_nodes() 1211 | 1212 | # generate colors for terminal nodes for consistency in visualization 1213 | self.TNodeColors = {} 1214 | for node in self.node_set: 1215 | if node.node_type == NodeType.TerminalNode: 1216 | self.TNodeColors[node.id] = ( 1217 | random.random(), random.random(), random.random()) # generate a random color 1218 | 1219 | 1220 | def TurnOnOffNodes(self, on_off): 1221 | for i in range(len(self.node_set)): 1222 | self.node_set[i].on_off = on_off 1223 | 1224 | def UpdateOnOffNodes(self, pg, offset_using_part_type, class_name=''): 1225 | BFS = [self.BFS[0]] 1226 | pg_used = np.ones((1, len(pg)), dtype=np.int) * -1 1227 | configuration = [] 1228 | tnode_offset_indx = [] 1229 | while len(BFS): 1230 | id = BFS.pop() 1231 | node = self.node_set[id] 1232 | self.node_set[id].on_off = True 1233 | if len(class_name): 1234 | if class_name in node.which_classes_visited.keys(): 1235 | self.node_set[id].which_classes_visited[class_name] += 1.0 1236 | else: 1237 | self.node_set[id].which_classes_visited[class_name] = 0 1238 | 1239 | if node.node_type == NodeType.OrNode: 1240 | idx = self.OrNodeIdxInBFS[node.id] 1241 | BFS.append(node.child_ids[int(pg[idx])]) 1242 | pg_used[0, idx] = int(pg[idx]) 1243 | if len(self.node_set[id].out_edge_visited_count): 1244 | self.node_set[id].out_edge_visited_count[int(pg[idx])] += 1.0 1245 | else: 1246 | self.node_set[id].out_edge_visited_count = np.zeros((len(node.child_ids),), dtype=np.float32) 1247 | elif node.node_type == NodeType.AndNode: 1248 | BFS += node.child_ids 1249 | 1250 | else: 1251 | configuration.append(node.id) 1252 | 1253 | offset_ind = 0 1254 | if not offset_using_part_type: 1255 | for node1 in self.node_set: 1256 | if node1.node_type == NodeType.TerminalNode: # change to BFS after _part_instance is changed to BFS 1257 | if node1.id == node.id: 1258 | break 1259 | offset_ind += 1 1260 | else: 1261 | rect = self.primitive_set[node.rect_idx] 1262 | offset_ind = self.part_type.index([rect.Height(), rect.Width()]) 1263 | 1264 | tnode_offset_indx.append(offset_ind) 1265 | 1266 | configuration.sort() 1267 | cfg = np.ones((1, self.num_TNodes), dtype=np.int) * -1 1268 | cfg[0, :len(configuration)] = configuration 1269 | return pg_used, cfg, tnode_offset_indx 1270 | 1271 | def ResetOutEdgeVisitedCountNodes(self): 1272 | for i in range(len(self.node_set)): 1273 | self.node_set[i].out_edge_visited_count = [] 1274 | 1275 | def NormalizeOutEdgeVisitedCountNodes(self, count=0): 1276 | if count == 0: 1277 | for i in range(len(self.node_set)): 1278 | if 
len(self.node_set[i].out_edge_visited_count): 1279 | count = max(count, max(self.node_set[i].out_edge_visited_count)) 1280 | 1281 | if count == 0: 1282 | return 1283 | 1284 | for i in range(len(self.node_set)): 1285 | if len(self.node_set[i].out_edge_visited_count): 1286 | self.node_set[i].out_edge_visited_count /= count 1287 | 1288 | def ResetWhichClassesVisitedNodes(self): 1289 | for i in range(len(self.node_set)): 1290 | self.node_set[i].which_classes_visited = {} 1291 | 1292 | def NormalizeWhichClassesVisitedNodes(self, class_name, count): 1293 | assert count > 0 1294 | for i in range(len(self.node_set)): 1295 | if class_name in self.node_set[i].which_classes_visited.keys(): 1296 | self.node_set[i].which_classes_visited[class_name] /= count 1297 | -------------------------------------------------------------------------------- /models/aognet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/models/aognet/__init__.py -------------------------------------------------------------------------------- /models/aognet/aognet.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | # The system is protected via patent (pending) 35 | # Written by Tianfu Wu and Xilai Li 36 | # Contact: {xli47, tianfu_wu}@ncsu.edu 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function # force to use print as function print(args) 40 | from __future__ import unicode_literals 41 | 42 | import torch 43 | import torch.nn as nn 44 | import torch.nn.functional as F 45 | from torch.autograd import Variable 46 | 47 | import scipy.stats as stats 48 | 49 | from models.config import cfg 50 | from .AOG import * 51 | from .operator_basic import * 52 | from .operator_singlescale import * 53 | 54 | ### AOG building block 55 | class AOGBlock(nn.Module): 56 | def __init__(self, stage, block, aog, in_channels, out_channels, drop_rate, stride): 57 | super(AOGBlock, self).__init__() 58 | self.stage = stage 59 | self.block = block 60 | self.aog = aog 61 | self.in_channels = in_channels 62 | self.out_channels = out_channels 63 | self.drop_rate = drop_rate 64 | self.stride = stride 65 | 66 | self.dim = aog.param.grid_wd 67 | self.in_slices = self._calculate_slices(self.dim, in_channels) 68 | self.out_slices = self._calculate_slices(self.dim, out_channels) 69 | 70 | self.node_set = aog.node_set 71 | self.primitive_set = aog.primitive_set 72 | self.BFS = aog.BFS 73 | self.DFS = aog.DFS 74 | 75 | self.hasLateral = {} 76 | self.hasDblCnt = {} 77 | 78 | self.primitiveDblCnt = None 79 | self._set_primitive_dbl_cnt() 80 | 81 | if "BatchNorm2d" in cfg.norm_name: 82 | self.norm_name_base = "BatchNorm2d" 83 | elif "GroupNorm" in cfg.norm_name: 84 | self.norm_name_base = "GroupNorm" 85 | else: 86 | raise ValueError("Unknown norm layer") 87 | 88 | self._set_weights_attr() 89 | 90 | self.extra_norm_ac = self._extra_norm_ac() 91 | 92 | def _calculate_slices(self, dim, channels): 93 | slices = [0] * dim 94 | for i in range(channels): 95 | slices[i % dim] += 1 96 | for d in range(1, dim): 97 | slices[d] += slices[d - 1] 98 | slices = [0] + slices 99 | return slices 100 | 101 | def _set_primitive_dbl_cnt(self): 102 | self.primitiveDblCnt = [0.0 for i in range(self.dim)] 103 | for id_ in self.DFS: 104 | node = self.node_set[id_] 105 | arr = self.primitive_set[node.rect_idx] 106 | if node.node_type == NodeType.TerminalNode: 107 | for i in range(arr.x1, arr.x2+1): 108 | self.primitiveDblCnt[i] += node.npaths 109 | for i in range(self.dim): 110 | assert self.primitiveDblCnt[i] >= 1.0 111 | 112 | def _create_op(self, node_id, cin, cout, stride, groups=1, 113 | keep_norm_base=False, norm_k=0): 114 | replace_stride = cfg.aognet.replace_stride_with_avgpool 115 | setattr(self, 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node_id), 116 | NodeOpSingleScale(cin, cout, stride, 117 | groups=groups, drop_rate=self.drop_rate, 118 | ac_mode=cfg.activation_mode, 119 | bn_ratio=cfg.aognet.bottleneck_ratio, 120 | norm_name=self.norm_name_base if keep_norm_base else cfg.norm_name, 121 | norm_groups=cfg.norm_groups, 122 
| norm_k = norm_k,
123 |                                             norm_attention_mode=cfg.norm_attention_mode,
124 |                                             replace_stride_with_avgpool=replace_stride))
125 | 
126 |     def _set_weights_attr(self):
127 |         for id_ in self.DFS:
128 |             node = self.node_set[id_]
129 |             arr = self.primitive_set[node.rect_idx]
130 |             bn_ratio = cfg.aognet.bottleneck_ratio
131 |             width_per_group = cfg.aognet.width_per_group
# [lines 132-161 arrived garbled in this copy; reconstructed below from the
#  parallel OR-node branch and forward() -- indicative, not verbatim]
132 |             keep_norm_base = arr.Width() < self.dim
133 |             norm_k = cfg.norm_k[self.stage]
134 | 
135 |             if node.node_type == NodeType.TerminalNode:
136 |                 # T-nodes slice the block input/output along the channel dim
137 |                 inplane = self.in_channels if cfg.aognet.terminal_node_no_slice[self.stage] else \
138 |                     self.in_slices[arr.x2 + 1] - self.in_slices[arr.x1]
139 |                 outplane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
140 |                 stride = self.stride
141 |                 groups = max(1, to_int(outplane * bn_ratio / width_per_group)) \
142 |                     if cfg.aognet.use_group_conv else 1
143 |                 self._create_op(node.id, inplane, outplane, stride, groups=groups,
144 |                                 keep_norm_base=keep_norm_base, norm_k=norm_k)
145 | 
146 |             elif node.node_type == NodeType.AndNode:
147 |                 plane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
148 |                 stride = 1
149 |                 groups = max(1, to_int(plane * bn_ratio / width_per_group)) \
150 |                     if cfg.aognet.use_group_conv else 1
151 |                 self.hasLateral[node.id] = False
152 |                 self.hasDblCnt[node.id] = False
153 |                 for chid in node.child_ids:
154 |                     ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
155 |                     if arr.Width() == ch_arr.Width():
156 |                         self.hasLateral[node.id] = True
157 |                         break
158 |                 if cfg.aognet.handle_dbl_cnt:
159 |                     for chid in node.child_ids:
160 |                         ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
161 |                         if arr.Width() > ch_arr.Width():
162 |                             if node.npaths / self.node_set[chid].npaths != 1.0:
163 |                                 self.hasDblCnt[node.id] = True
164 |                                 break
165 |                 self._create_op(node.id, plane, plane, stride, groups=groups,
166 |                                 keep_norm_base=keep_norm_base, norm_k=norm_k)
167 | 
168 |             elif node.node_type == NodeType.OrNode:
169 |                 assert self.node_set[node.child_ids[0]].node_type != NodeType.OrNode
170 |                 plane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
171 |                 stride = 1
172 |                 groups = max(1, to_int(plane * bn_ratio / width_per_group)) \
173 |                     if cfg.aognet.use_group_conv else 1
174 |                 self.hasLateral[node.id] = False
175 |                 self.hasDblCnt[node.id] = False
176 |                 for chid in node.child_ids:
177 |                     ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
178 |                     if self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width():
179 |                         self.hasLateral[node.id] = True
180 |                         break
181 |                 if cfg.aognet.handle_dbl_cnt:
182 |                     for chid in node.child_ids:
183 |                         ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
184 |                         if not (self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width()):
185 |                             if node.npaths / self.node_set[chid].npaths != 1.0:
186 |                                 self.hasDblCnt[node.id] = True
187 |                                 break
188 |                 self._create_op(node.id, plane, plane, stride, groups=groups,
189 |                                 keep_norm_base=keep_norm_base, norm_k=norm_k)
190 | 
191 |     def _extra_norm_ac(self):
192 |         return nn.Sequential(FeatureNorm(self.norm_name_base, self.out_channels,
193 |                                          cfg.norm_groups, cfg.norm_k[self.stage],
194 |                                          cfg.norm_attention_mode),
195 |                              AC(cfg.activation_mode))
196 | 
197 |     def forward(self, x):
198 |         NodeIdTensorDict = {}
199 | 
200 |         # handle input x
201 |         tnode_dblcnt = False
202 |         if cfg.aognet.handle_tnode_dbl_cnt and self.in_channels==self.out_channels:
203 |             x_scaled = []
204 |             for i in range(self.dim):
205 |                 left, right = self.in_slices[i], self.in_slices[i+1]
206 |                 x_scaled.append(x[:, left:right, :, :].div(self.primitiveDblCnt[i]))
207 |             xx = torch.cat(x_scaled, 1)
208 |             tnode_dblcnt = True
209 | 
210 |         # T-nodes, (hope they will be computed in parallel by pytorch)
211 |         for id_ in self.DFS:
212 |             node = self.node_set[id_]
213 |             op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id)
214 |             if node.node_type == NodeType.TerminalNode:
215 |                 arr = self.primitive_set[node.rect_idx]
216 |                 right, left = self.in_slices[arr.x2 + 1], self.in_slices[arr.x1]
217 |                 tnode_tensor_op = x if cfg.aognet.terminal_node_no_slice[self.stage] else x[:, left:right, :, :].contiguous()
218 |                 # assert tnode_tensor.requires_grad, 'slice needs to retain grad'
219 |                 if tnode_dblcnt:
220 |                     tnode_tensor_res = xx[:, left:right, :, :].mul(node.npaths)
221 |                     tnode_output = getattr(self, op_name)(tnode_tensor_op, tnode_tensor_res)
222 |                 else:
223 |                     tnode_output = getattr(self, op_name)(tnode_tensor_op)
224 |                 NodeIdTensorDict[node.id] = tnode_output
225 | 
226 |         # AND- and OR-nodes
227 |         for id_ in self.DFS:
228 |             node = self.node_set[id_]
229 |             arr = self.primitive_set[node.rect_idx]
230 |             op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id)
231 |             if node.node_type == NodeType.AndNode:
232 |                 if
self.hasDblCnt[node.id]: 233 | child_tensor_res = [] 234 | child_tensor_op = [] 235 | for chid in node.child_ids: 236 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 237 | if arr.Width() > ch_arr.Width(): 238 | factor = node.npaths / self.node_set[chid].npaths 239 | if factor == 1.0: 240 | child_tensor_res.append(NodeIdTensorDict[chid]) 241 | else: 242 | child_tensor_res.append(NodeIdTensorDict[chid].mul(factor)) 243 | child_tensor_op.append(NodeIdTensorDict[chid]) 244 | 245 | anode_tensor_res = torch.cat(child_tensor_res, 1) 246 | anode_tensor_op = torch.cat(child_tensor_op, 1) 247 | 248 | if self.hasLateral[node.id]: 249 | ids1 = set(node.parent_ids) 250 | num_shared = 0 251 | for chid in node.child_ids: 252 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 253 | ids2 = self.node_set[chid].parent_ids 254 | if arr.Width() == ch_arr.Width(): 255 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 256 | if len(ids1.intersection(ids2)) == num_shared: 257 | anode_tensor_res = anode_tensor_res + NodeIdTensorDict[chid] 258 | 259 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res) 260 | else: 261 | child_tensor = [] 262 | for chid in node.child_ids: 263 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 264 | if arr.Width() > ch_arr.Width(): 265 | child_tensor.append(NodeIdTensorDict[chid]) 266 | 267 | anode_tensor_op = torch.cat(child_tensor, 1) 268 | 269 | if self.hasLateral[node.id]: 270 | ids1 = set(node.parent_ids) 271 | num_shared = 0 272 | for chid in node.child_ids: 273 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 274 | ids2 = self.node_set[chid].parent_ids 275 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) == num_shared: 276 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 277 | 278 | anode_tensor_res = anode_tensor_op 279 | 280 | for chid in node.child_ids: 281 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 282 | ids2 = self.node_set[chid].parent_ids 283 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) > num_shared: 284 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid] 285 | 286 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res) 287 | else: 288 | anode_output = getattr(self, op_name)(anode_tensor_op) 289 | 290 | NodeIdTensorDict[node.id] = anode_output 291 | 292 | elif node.node_type == NodeType.OrNode: 293 | if self.hasDblCnt[node.id]: 294 | factor = node.npaths / self.node_set[node.child_ids[0]].npaths 295 | if factor == 1.0: 296 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]] 297 | else: 298 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]].mul(factor) 299 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 300 | for chid in node.child_ids[1:]: 301 | if self.node_set[chid].node_type != NodeType.OrNode: 302 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 303 | if arr.Width() == ch_arr.Width(): 304 | factor = node.npaths / self.node_set[chid].npaths 305 | if factor == 1.0: 306 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 307 | else: 308 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid].mul(factor) 309 | if cfg.aognet.use_elem_max_for_ORNodes: 310 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 311 | else: 312 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 313 | 314 | if self.hasLateral[node.id]: 315 | ids1 = set(node.parent_ids) 316 | num_shared = 0 317 | for chid in node.child_ids[1:]: 318 | ids2 = 
self.node_set[chid].parent_ids 319 | if self.node_set[chid].node_type == NodeType.OrNode and \ 320 | len(ids1.intersection(ids2)) == num_shared: 321 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 322 | if cfg.aognet.use_elem_max_for_ORNodes: 323 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 324 | else: 325 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 326 | 327 | for chid in node.child_ids[1:]: 328 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 329 | ids2 = self.node_set[chid].parent_ids 330 | if self.node_set[chid].node_type == NodeType.OrNode and \ 331 | len(ids1.intersection(ids2)) > num_shared: 332 | if cfg.aognet.use_elem_max_for_ORNodes: 333 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 334 | else: 335 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 336 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 337 | arr.Width() < ch_arr.Width(): 338 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 339 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 340 | if cfg.aognet.use_elem_max_for_ORNodes: 341 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]) 342 | else: 343 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]#.contiguous() 344 | 345 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 346 | else: 347 | if cfg.aognet.use_elem_max_for_ORNodes: 348 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 349 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]] 350 | for chid in node.child_ids[1:]: 351 | if self.node_set[chid].node_type != NodeType.OrNode: 352 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 353 | if arr.Width() == ch_arr.Width(): 354 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 355 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 356 | 357 | if self.hasLateral[node.id]: 358 | ids1 = set(node.parent_ids) 359 | num_shared = 0 360 | for chid in node.child_ids[1:]: 361 | ids2 = self.node_set[chid].parent_ids 362 | if self.node_set[chid].node_type == NodeType.OrNode and \ 363 | len(ids1.intersection(ids2)) == num_shared: 364 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 365 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid] 366 | 367 | for chid in node.child_ids[1:]: 368 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 369 | ids2 = self.node_set[chid].parent_ids 370 | if self.node_set[chid].node_type == NodeType.OrNode and \ 371 | len(ids1.intersection(ids2)) > num_shared: 372 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid]) 373 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 374 | arr.Width() < ch_arr.Width(): 375 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 376 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 377 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]) 378 | 379 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 380 | else: 381 | onode_output = getattr(self, op_name)(onode_tensor_op) 382 | else: 383 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]] 384 | for chid in node.child_ids[1:]: 385 | if self.node_set[chid].node_type != NodeType.OrNode: 386 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 387 | if arr.Width() == ch_arr.Width(): 
388 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 389 | 390 | if self.hasLateral[node.id]: 391 | ids1 = set(node.parent_ids) 392 | num_shared = 0 393 | for chid in node.child_ids[1:]: 394 | ids2 = self.node_set[chid].parent_ids 395 | if self.node_set[chid].node_type == NodeType.OrNode and \ 396 | len(ids1.intersection(ids2)) == num_shared: 397 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 398 | 399 | onode_tensor_res = onode_tensor_op 400 | 401 | for chid in node.child_ids[1:]: 402 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx] 403 | ids2 = self.node_set[chid].parent_ids 404 | if self.node_set[chid].node_type == NodeType.OrNode and \ 405 | len(ids1.intersection(ids2)) > num_shared: 406 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid] 407 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \ 408 | arr.Width() < ch_arr.Width(): 409 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1] 410 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1] 411 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :].contiguous() 412 | 413 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res) 414 | else: 415 | onode_output = getattr(self, op_name)(onode_tensor_op) 416 | 417 | NodeIdTensorDict[node.id] = onode_output 418 | 419 | out = NodeIdTensorDict[self.aog.BFS[0]] 420 | out = self.extra_norm_ac(out) #TODO: Why this? Analyze it in depth 421 | return out 422 | 423 | ### AOGNet 424 | class AOGNet(nn.Module): 425 | def __init__(self, block=AOGBlock): 426 | super(AOGNet, self).__init__() 427 | filter_list = cfg.aognet.filter_list 428 | self.aogs = self._create_aogs() 429 | self.block = block 430 | if "BatchNorm2d" in cfg.norm_name: 431 | self.norm_name_base = "BatchNorm2d" 432 | elif "GroupNorm" in cfg.norm_name: 433 | self.norm_name_base = "GroupNorm" 434 | else: 435 | raise ValueError("Unknown norm layer") 436 | 437 | if "Mixture" in cfg.norm_name: 438 | assert len(cfg.norm_k) == len(filter_list)-1 and any(cfg.norm_k), \ 439 | "Wrong mixture component specification (cfg.norm_k)" 440 | else: 441 | cfg.norm_k = [0 for i in range(len(filter_list)-1)] 442 | 443 | self.stem = self._stem(filter_list[0]) 444 | 445 | self.stage0 = self._make_stage(stage=0, in_channels=filter_list[0], out_channels=filter_list[1]) 446 | self.stage1 = self._make_stage(stage=1, in_channels=filter_list[1], out_channels=filter_list[2]) 447 | self.stage2 = self._make_stage(stage=2, in_channels=filter_list[2], out_channels=filter_list[3]) 448 | self.stage3 = None 449 | outchannels = filter_list[3] 450 | if cfg.dataset == 'imagenet': 451 | self.stage3 = self._make_stage(stage=3, in_channels=filter_list[3], out_channels=filter_list[4]) 452 | outchannels = filter_list[4] 453 | 454 | self.conv_head = None 455 | if any(cfg.aognet.out_channels): 456 | assert len(cfg.aognet.out_channels) == 2 457 | self.conv_head = nn.Sequential(Conv_Norm_AC(outchannels, cfg.aognet.out_channels[0], 1, 1, 0, 458 | ac_mode=cfg.activation_mode, 459 | norm_name=self.norm_name_base, 460 | norm_groups=cfg.norm_groups, 461 | norm_k=cfg.norm_k[-1], 462 | norm_attention_mode=cfg.norm_attention_mode), 463 | nn.AdaptiveAvgPool2d((1, 1)), 464 | Conv_Norm_AC(cfg.aognet.out_channels[0], cfg.aognet.out_channels[1], 1, 1, 0, 465 | ac_mode=cfg.activation_mode, 466 | norm_name=self.norm_name_base, 467 | norm_groups=cfg.norm_groups, 468 | norm_k=cfg.norm_k[-1], 469 | norm_attention_mode=cfg.norm_attention_mode) 470 | ) 471 | outchannels 
= cfg.aognet.out_channels[1] 472 | else: 473 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 474 | self.fc = nn.Linear(outchannels, cfg.num_classes) 475 | 476 | ## initialize 477 | self._init_params() 478 | 479 | def _stem(self, cout): 480 | layers = [] 481 | if cfg.dataset == 'imagenet': 482 | if cfg.stem.imagenet_head7x7: 483 | layers.append( Conv_Norm_AC(3, cout, 7, 2, 3, 484 | ac_mode=cfg.activation_mode, 485 | norm_name=self.norm_name_base, 486 | norm_groups=cfg.norm_groups, 487 | norm_k=cfg.norm_k[0], 488 | norm_attention_mode=cfg.norm_attention_mode) ) 489 | else: 490 | plane = cout // 2 491 | layers.append( Conv_Norm_AC(3, plane, 3, 2, 1, 492 | ac_mode=cfg.activation_mode, 493 | norm_name=self.norm_name_base, 494 | norm_groups=cfg.norm_groups, 495 | norm_k=cfg.norm_k[0], 496 | norm_attention_mode=cfg.norm_attention_mode) ) 497 | layers.append( Conv_Norm_AC(plane, plane, 3, 1, 1, 498 | ac_mode=cfg.activation_mode, 499 | norm_name=self.norm_name_base, 500 | norm_groups=cfg.norm_groups, 501 | norm_k=cfg.norm_k[0], 502 | norm_attention_mode=cfg.norm_attention_mode) ) 503 | layers.append( Conv_Norm_AC(plane, cout, 3, 1, 1, 504 | ac_mode=cfg.activation_mode, 505 | norm_name=self.norm_name_base, 506 | norm_groups=cfg.norm_groups, 507 | norm_k=cfg.norm_k[0], 508 | norm_attention_mode=cfg.norm_attention_mode) ) 509 | if cfg.stem.replace_maxpool_with_res_bottleneck: 510 | layers.append( NodeOpSingleScale(cout, cout, 2, 511 | ac_mode=cfg.activation_mode, 512 | bn_ratio=cfg.aognet.bottleneck_ratio, 513 | norm_name=self.norm_name_base, 514 | norm_groups=cfg.norm_groups, 515 | norm_k = cfg.norm_k[0], 516 | norm_attention_mode=cfg.norm_attention_mode, 517 | replace_stride_with_avgpool=True) ) # used in OctConv 518 | else: 519 | layers.append( nn.MaxPool2d(2, 2) ) 520 | elif cfg.dataset == 'cifar10' or cfg.dataset == 'cifar100': 521 | layers.append( Conv_Norm_AC(3, cout, 3, 1, 1, 522 | ac_mode=cfg.activation_mode, 523 | norm_name=self.norm_name_base, 524 | norm_groups=cfg.norm_groups, 525 | norm_k=cfg.norm_k[0], 526 | norm_attention_mode=cfg.norm_attention_mode) ) 527 | else: 528 | raise NotImplementedError 529 | 530 | return nn.Sequential(*layers) 531 | 532 | def _init_params(self): 533 | for m in self.modules(): 534 | if isinstance(m, nn.Conv2d): 535 | if cfg.init_mode == 'xavier': 536 | nn.init.xavier_normal_(m.weight) 537 | elif cfg.init_mode == 'avg': 538 | n = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels + m.out_channels) / 2 539 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 540 | else: # cfg.init_mode == 'kaiming': as default 541 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 542 | 543 | for name, _ in m.named_parameters(): 544 | if name in ['bias']: 545 | nn.init.constant_(m.bias, 0.0) 546 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)): # before BatchNorm2d 547 | nn.init.normal_(m.weight_, 1, 0.1) 548 | nn.init.normal_(m.bias_, 0, 0.1) 549 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 550 | nn.init.constant_(m.weight, 1.0) 551 | nn.init.constant_(m.bias, 0.0) 552 | 553 | # handle dbl cnt in init 554 | if cfg.aognet.handle_dbl_cnt_in_param_init: 555 | import re 556 | for name_, m in self.named_modules(): 557 | if 'node' in name_: 558 | idx = re.findall(r'\d+', name_) 559 | sid = int(idx[0]) 560 | nid = int(idx[2]) 561 | npaths = self.aogs[sid].node_set[nid].npaths 562 | if npaths > 1: 563 | scale = 1.0 / npaths 564 | with torch.no_grad(): 565 | for ch in m.modules(): 566 | if isinstance(ch, nn.Conv2d): 567 | ch.weight.mul_(scale) 568 | 569 | # TODO: handle zero-gamma in the last norm layer of bottleneck op 570 | 571 | def _create_aogs(self): 572 | aogs = [] 573 | num_stages = len(cfg.aognet.filter_list) - 1 574 | for i in range(num_stages): 575 | grid_ht = 1 576 | grid_wd = int(cfg.aognet.dims[i]) 577 | aogs.append(get_aog(grid_ht=grid_ht, grid_wd=grid_wd, max_split=cfg.aognet.max_split[i], 578 | use_tnode_topdown_connection= cfg.aognet.extra_node_hierarchy[i] == 1, 579 | use_tnode_bottomup_connection_layerwise= cfg.aognet.extra_node_hierarchy[i] == 2, 580 | use_tnode_bottomup_connection_sequential= cfg.aognet.extra_node_hierarchy[i] == 3, 581 | use_node_lateral_connection= cfg.aognet.extra_node_hierarchy[i] == 4, 582 | use_tnode_bottomup_connection= cfg.aognet.extra_node_hierarchy[i] == 5, 583 | use_node_lateral_connection_1= cfg.aognet.extra_node_hierarchy[i] == 6, 584 | remove_symmetric_children_of_or_node=cfg.aognet.remove_symmetric_children_of_or_node[i] 585 | )) 586 | 587 | return aogs 588 | 589 | def _make_stage(self, stage, in_channels, out_channels): 590 | blocks = nn.Sequential() 591 | dim = cfg.aognet.dims[stage] 592 | assert in_channels % dim == 0 and out_channels % dim == 0 593 | step_channels = (out_channels - in_channels) // cfg.aognet.blocks[stage] 594 | if step_channels % dim != 0: 595 | low = (step_channels // dim) * dim 596 | high = (step_channels // dim + 1) * dim 597 | if (step_channels-low) <= (high-step_channels): 598 | step_channels = low 599 | else: 600 | step_channels = high 601 | 602 | aog = self.aogs[stage] 603 | for j in range(cfg.aognet.blocks[stage]): 604 | name_ = 'stage_{}_block_{}'.format(stage, j) 605 | drop_rate = cfg.aognet.drop_rate[stage] 606 | stride = cfg.aognet.stride[stage] if j==0 else 1 607 | outchannels = (in_channels + step_channels) if j < cfg.aognet.blocks[stage]-1 else out_channels 608 | if stride > 1 and cfg.aognet.when_downsample == 1: 609 | blocks.add_module(name_ + '_transition', 610 | nn.Sequential( Conv_Norm_AC(in_channels, in_channels, 1, 1, 0, 611 | ac_mode=cfg.activation_mode, 612 | norm_name=self.norm_name_base, 613 | norm_groups=cfg.norm_groups, 614 | norm_k=cfg.norm_k[stage], 615 | norm_attention_mode=cfg.norm_attention_mode, 616 | replace_stride_with_avgpool=False), 617 | nn.AvgPool2d(kernel_size=(stride, stride), stride=stride) 618 | ) 619 | ) 620 | stride = 1 621 | elif (stride > 1 or in_channels != outchannels) and cfg.aognet.when_downsample == 2: 622 | trans_op = [Conv_Norm_AC(in_channels, outchannels, 1, 1, 0, 623 | 
ac_mode=cfg.activation_mode, 624 | norm_name=self.norm_name_base, 625 | norm_groups=cfg.norm_groups, 626 | norm_k=cfg.norm_k[stage], 627 | norm_attention_mode=cfg.norm_attention_mode, 628 | replace_stride_with_avgpool=False)] 629 | if stride > 1: 630 | trans_op.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride)) 631 | blocks.add_module(name_ + '_transition', nn.Sequential(*trans_op)) 632 | stride = 1 633 | in_channels = outchannels 634 | 635 | blocks.add_module(name_, self.block(stage, j, aog, in_channels, outchannels, drop_rate, stride)) 636 | in_channels = outchannels 637 | 638 | return blocks 639 | 640 | def forward(self, x): 641 | y = self.stem(x) 642 | 643 | y = self.stage0(y) 644 | y = self.stage1(y) 645 | y = self.stage2(y) 646 | if self.stage3 is not None: 647 | y = self.stage3(y) 648 | if self.conv_head is not None: 649 | y = self.conv_head(y) 650 | else: 651 | y = self.avgpool(y) 652 | y = y.view(y.size(0), -1) 653 | y = self.fc(y) 654 | 655 | return y 656 | 657 | def aognet(**kwargs): 658 | ''' 659 | Construct a single scale AOGNet model 660 | ''' 661 | return AOGNet(**kwargs) 662 | -------------------------------------------------------------------------------- /models/aognet/operator_basic.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | # The system is protected via patent (pending) 35 | # Written by Tianfu Wu and Xilai Li 36 | # Contact: {tianfu_wu, xli47}@ncsu.edu 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function # force to use print as function print(args) 40 | from __future__ import unicode_literals 41 | 42 | import torch 43 | import torch.nn as nn 44 | import torch.nn.functional as F 45 | 46 | _inplace = True 47 | _norm_eps = 1e-5 48 | 49 | def to_int(x): 50 | if x - int(x) < 0.5: 51 | return int(x) 52 | else: 53 | return int(x) + 1 54 | 55 | ### Activation 56 | class AC(nn.Module): 57 | def __init__(self, mode): 58 | super(AC, self).__init__() 59 | if mode == 1: 60 | self.ac = nn.LeakyReLU(inplace=_inplace) 61 | elif mode == 2: 62 | self.ac = nn.ReLU6(inplace=_inplace) 63 | else: 64 | self.ac = nn.ReLU(inplace=_inplace) 65 | 66 | def forward(self, x): 67 | x = self.ac(x) 68 | return x 69 | 70 | ### 71 | class hsigmoid(nn.Module): 72 | def forward(self, x): 73 | out = F.relu6(x + 3, inplace=True) / 6 74 | return out 75 | 76 | ### Feature Norm 77 | def FeatureNorm(norm_name, num_channels, num_groups, num_k, attention_mode): 78 | if norm_name == "BatchNorm2d": 79 | return nn.BatchNorm2d(num_channels, eps=_norm_eps) 80 | elif norm_name == "GroupNorm": 81 | assert num_groups > 1 82 | if num_channels % num_groups != 0: 83 | raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups)) 84 | return nn.GroupNorm(num_groups, num_channels, eps=_norm_eps) # nn.GroupNorm takes (num_groups, num_channels) 85 | elif norm_name == "MixtureBatchNorm2d": 86 | assert num_k > 1 87 | return MixtureBatchNorm2d(num_channels, num_k, attention_mode) 88 | elif norm_name == "MixtureGroupNorm": 89 | assert num_groups > 1 and num_k > 1 90 | if num_channels % num_groups != 0: 91 | raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups)) 92 | return MixtureGroupNorm(num_channels, num_groups, num_k, attention_mode) 93 | else: 94 | raise NotImplementedError("Unknown feature norm name") 95 | 96 | ### Attention weights for mixture norm 97 | class AttentionWeights(nn.Module): 98 | expansion = 2 99 | def __init__(self, attention_mode, num_channels, k, 100 | norm_name=None, norm_groups=0): 101 | super(AttentionWeights, self).__init__() 102 | #num_channels *= 2 103 | self.k = k 104 | self.avgpool = nn.AdaptiveAvgPool2d(1) 105 | layers = [] 106 | if attention_mode == 0: 107 | layers = [ nn.Conv2d(num_channels, k, 1), 108 | nn.Sigmoid() ] 109 | elif attention_mode == 4: 110 | layers = [ nn.Conv2d(num_channels, k, 1), 111 | hsigmoid() ] 112 | elif attention_mode == 1: 113 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1), 114 | nn.ReLU(inplace=True), 115 | nn.Conv2d(k*self.expansion, k, 1), 116 | nn.Sigmoid() ] 117 | elif attention_mode == 2: 118 | assert norm_name is not None 119 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False), 120 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 121 | hsigmoid() ] 122 | elif
attention_mode == 5: 123 | assert norm_name is not None 124 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False), 125 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 126 | nn.Sigmoid() ] 127 | elif attention_mode == 6: 128 | assert norm_name is not None 129 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False), 130 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 131 | nn.Softmax(dim=1) ] 132 | elif attention_mode == 3: 133 | assert norm_name is not None 134 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1, bias=False), 135 | FeatureNorm(norm_name, k*self.expansion, norm_groups, 0, 0), 136 | nn.ReLU(inplace=True), 137 | nn.Conv2d(k*self.expansion, k, 1, bias=False), 138 | FeatureNorm(norm_name, k, norm_groups, 0, 0), 139 | hsigmoid() ] 140 | else: 141 | raise NotImplementedError("Unknown attention weight type") 142 | self.attention = nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | b, c, _, _ = x.size() 146 | y = self.avgpool(x)#.view(b, c) 147 | var = torch.var(x, dim=(2, 3)).view(b, c, 1, 1) 148 | y *= (var + 1e-3).rsqrt() 149 | #y = torch.cat((y, var), dim=1) 150 | return self.attention(y).view(b, self.k) 151 | 152 | 153 | ### Mixture Norm 154 | # TODO: keep this in FP32 always; need to figure out how to enforce that with apex 155 | class MixtureBatchNorm2d(nn.BatchNorm2d): 156 | def __init__(self, num_channels, k, attention_mode, eps=_norm_eps, momentum=0.1, 157 | track_running_stats=True): 158 | super(MixtureBatchNorm2d, self).__init__(num_channels, eps=eps, 159 | momentum=momentum, affine=False, track_running_stats=track_running_stats) 160 | self.k = k 161 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels)) 162 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels)) 163 | 164 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k, 165 | norm_name='BatchNorm2d') 166 | 167 | self._init_params() 168 | 169 | def _init_params(self): 170 | nn.init.normal_(self.weight_, 1, 0.1) 171 | nn.init.normal_(self.bias_, 0, 0.1) 172 | 173 | def forward(self, x): 174 | output = super(MixtureBatchNorm2d, self).forward(x) 175 | size = output.size() 176 | y = self.attention_weights(x) # bxk # or use output as attention input 177 | 178 | weight = y @ self.weight_ # bxc 179 | bias = y @ self.bias_ # bxc 180 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 181 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 182 | 183 | return weight * output + bias 184 | 185 | 186 | # Modified on top of nn.GroupNorm 187 | # TODO: keep this in FP32 always; need to figure out how to enforce that with apex
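# --------------------------------------------------------------------------
# Illustrative sketch (added note, not part of the original source): both
# mixture norms in this file follow the Attentive Normalization recipe:
# standardize with affine=False, then apply an instance-specific affine
# transform whose per-channel (gamma, beta) are a data-dependent mixture of
# k learned pairs, mixed by the B x k attention weights from
# AttentionWeights. A minimal shape check, assuming attention_mode=0
# (1x1 conv + sigmoid) and the classes defined above in this file:
#
#   import torch
#   norm = MixtureBatchNorm2d(num_channels=64, k=10, attention_mode=0)
#   x = torch.randn(8, 64, 32, 32)
#   out = norm(x)
#   assert out.shape == x.shape        # B x C x H x W preserved
#   y = norm.attention_weights(x)
#   assert y.shape == (8, 10)          # one weight per mixture component
#   # effective per-instance gamma: y @ norm.weight_  ->  shape (8, 64)
# --------------------------------------------------------------------------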
188 | class MixtureGroupNorm(nn.Module): 189 | __constants__ = ['num_groups', 'num_channels', 'k', 'eps', 'weight', 190 | 'bias'] 191 | 192 | def __init__(self, num_channels, num_groups, k, attention_mode, eps=_norm_eps): 193 | super(MixtureGroupNorm, self).__init__() 194 | self.num_groups = num_groups 195 | self.num_channels = num_channels 196 | self.k = k 197 | self.eps = eps 198 | self.affine = True 199 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels)) 200 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels)) 201 | self.register_parameter('weight', None) 202 | self.register_parameter('bias', None) 203 | 204 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k, 205 | norm_name='GroupNorm', norm_groups=1) 206 | 207 | self.reset_parameters() 208 | 209 | def reset_parameters(self): 210 | nn.init.normal_(self.weight_, 1, 0.1) 211 | nn.init.normal_(self.bias_, 0, 0.1) 212 | 213 | def forward(self, x): 214 | output = F.group_norm( 215 | x, self.num_groups, self.weight, self.bias, self.eps) 216 | size = output.size() 217 | 218 | y = self.attention_weights(x) # TODO: use output as attention input 219 | 220 | weight = y @ self.weight_ 221 | bias = y @ self.bias_ 222 | 223 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size) 224 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size) 225 | 226 | return weight * output + bias 227 | 228 | def extra_repr(self): 229 | return '{num_groups}, {num_channels}, eps={eps}, ' \ 230 | 'affine={affine}'.format(**self.__dict__) 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /models/aognet/operator_singlescale.py: -------------------------------------------------------------------------------- 1 | """ RESEARCH ONLY LICENSE 2 | Copyright (c) 2018-2019 North Carolina State University. 3 | All rights reserved. 4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided 5 | that the following conditions are met: 6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use 7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the 8 | Office of Research Commercialization at North Carolina State University, 919-215-7199, 9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu . 10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a 11 | third party for financial gain, income generation or other commercial purposes of any kind, whether 12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain, 13 | income generation or other commercial purposes of any kind, whether direct or indirect. 14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and 15 | the following disclaimer. 16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions 17 | and the following disclaimer in the documentation and/or other materials provided with the 18 | distribution. 19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name, 20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or 21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote 22 | products derived from this software without prior written permission. 
For written permission, please 23 | contact trademarks@ncsu.edu. 24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES, 25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE 27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF 30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR 31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | """ 34 | # The system is protected via patent (pending) 35 | # Written by Tianfu Wu and Xilai Li 36 | # Contact: {tianfu_wu, xli47}@ncsu.edu 37 | from __future__ import absolute_import 38 | from __future__ import division 39 | from __future__ import print_function # force to use print as function print(args) 40 | from __future__ import unicode_literals 41 | 42 | import torch 43 | import torch.nn as nn 44 | import torch.nn.functional as F 45 | from torch.autograd import Variable 46 | 47 | from .operator_basic import * 48 | 49 | _bias = False 50 | _inplace = True 51 | 52 | ### Conv_Norm 53 | class Conv_Norm(nn.Module): 54 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 55 | groups=1, drop_rate=0.0, 56 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 57 | replace_stride_with_avgpool=False): 58 | super(Conv_Norm, self).__init__() 59 | 60 | layers = [] 61 | if stride > 1 and replace_stride_with_avgpool: 62 | layers.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride)) 63 | stride = 1 64 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, 65 | stride=stride, padding=padding, 66 | groups=groups, bias=_bias)) 67 | layers.append(FeatureNorm(norm_name, out_channels, norm_groups, norm_k, norm_attention_mode)) 68 | if drop_rate > 0.0: 69 | layers.append(nn.Dropout2d(p=drop_rate, inplace=_inplace)) 70 | self.conv_norm = nn.Sequential(*layers) 71 | 72 | def forward(self, x): 73 | y = self.conv_norm(x) 74 | return y 75 | 76 | ### Conv_Norm_AC 77 | class Conv_Norm_AC(nn.Module): 78 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding, 79 | groups=1, drop_rate=0., ac_mode=0, 80 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 81 | replace_stride_with_avgpool=False): 82 | super(Conv_Norm_AC, self).__init__() 83 | 84 | self.conv_norm = Conv_Norm(in_channels, out_channels, kernel_size, stride, padding, 85 | groups=groups, drop_rate=drop_rate, 86 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 87 | replace_stride_with_avgpool=replace_stride_with_avgpool) 88 | self.ac = AC(ac_mode) 89 | 90 | def forward(self, x): 91 | y = self.conv_norm(x) 92 | y = self.ac(y) 93 | return y 94 | 95 | ### NodeOpSingleScale 96 | class NodeOpSingleScale(nn.Module): 97 | def __init__(self, in_channels, out_channels, stride, 98 | groups=1, drop_rate=0., ac_mode=0, bn_ratio=0.25, 99 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0, 100 | replace_stride_with_avgpool=True): 101 | super(NodeOpSingleScale, self).__init__() 102 | if "BatchNorm2d" in norm_name: 103 | norm_name_base = "BatchNorm2d" 
104 | elif "GroupNorm" in norm_name: 105 | norm_name_base = "GroupNorm" 106 | else: 107 | raise ValueError("Unknown norm layer") 108 | 109 | mid_channels = max(4, to_int(out_channels * bn_ratio / groups) * groups) 110 | self.conv_norm_ac_1 = Conv_Norm_AC(in_channels, mid_channels, 1, 1, 0, 111 | ac_mode=ac_mode, 112 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode) 113 | self.conv_norm_ac_2 = Conv_Norm_AC(mid_channels, mid_channels, 3, stride, 1, 114 | groups=groups, ac_mode=ac_mode, 115 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 116 | replace_stride_with_avgpool=False) 117 | self.conv_norm_3 = Conv_Norm(mid_channels, out_channels, 1, 1, 0, 118 | drop_rate=drop_rate, 119 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode) 120 | 121 | self.shortcut = None 122 | if in_channels != out_channels or stride > 1: 123 | self.shortcut = Conv_Norm(in_channels, out_channels, 1, stride, 0, 124 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode, 125 | replace_stride_with_avgpool=replace_stride_with_avgpool) 126 | 127 | self.ac = AC(ac_mode) 128 | 129 | def forward(self, x, res=None): 130 | residual = x if res is None else res 131 | y = self.conv_norm_ac_1(x) 132 | y = self.conv_norm_ac_2(y) 133 | y = self.conv_norm_3(y) 134 | 135 | if self.shortcut is not None: 136 | residual = self.shortcut(residual) 137 | 138 | y += residual 139 | y = self.ac(y) 140 | return y 141 | 142 | ### TODO: write a unit test for NodeOpSingleScale in a standalone way 143 | 144 | 145 | -------------------------------------------------------------------------------- /models/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | _C.arch = 'aognet' 5 | _C.batch_size = 128 6 | _C.num_epoch = 300 7 | _C.dataset = 'cifar10' 8 | _C.num_classes = 10 9 | _C.crop_size = 224 # imagenet 10 | _C.crop_interpolation = 2 # 2=BILINEAR, default; 3=BICUBIC 11 | _C.optimizer = 'SGD' 12 | _C.gamma = 0.1 # decay_rate 13 | _C.use_cosine_lr = False 14 | _C.cosine_lr_min = 0.0 15 | _C.warmup_epochs = 5 16 | _C.lr = 0.1 17 | _C.lr_scale_factor = 256 # per nvidia apex 18 | _C.lr_milestones = [150, 225] 19 | _C.momentum = 0.9 20 | _C.wd = 5e-4 21 | _C.nesterov = False 22 | _C.activation_mode = 0 # 1: leakyReLU, 2: ReLU6 , other: ReLU 23 | _C.init_mode = 'kaiming' 24 | _C.norm_name = 'BatchNorm2d' 25 | _C.norm_groups = 0 26 | _C.norm_k = [0] 27 | _C.norm_attention_mode = 0 28 | _C.norm_zero_gamma_init = False 29 | _C.norm_all_mix = False 30 | 31 | # data augmentation 32 | _C.dataaug = CN() 33 | _C.dataaug.imagenet_extra_aug = False 34 | _C.dataaug.labelsmoothing_rate = 0. # 0.1 35 | _C.dataaug.mixup_rate = 0. 
# 0.2 36 | 37 | # stem 38 | _C.stem = CN() 39 | _C.stem.imagenet_head7x7 = False 40 | _C.stem.replace_maxpool_with_res_bottleneck = False 41 | _C.stem.stem_kernel_size = 7 42 | _C.stem.stem_stride = 2 43 | 44 | # resnet 45 | _C.resnet = CN() 46 | _C.resnet.base_inplanes = 16 47 | _C.resnet.replace_stride_with_dilation = [False, False, False] 48 | _C.resnet.replace_stride_with_avgpool = False 49 | _C.resnet.extra_norm_ac = False 50 | 51 | # mobilenet 52 | _C.mobilenet = CN() 53 | _C.mobilenet.rm_se = False # for mobilenetv3 with mixture norm 54 | _C.mobilenet.use_mn_in_se = False 55 | 56 | # aognet 57 | _C.aognet = CN() 58 | _C.aognet.filter_list = [16, 64, 128, 256] 59 | _C.aognet.out_channels = [0,0] 60 | _C.aognet.blocks = [1, 1, 1] 61 | _C.aognet.dims = [4, 4, 4] 62 | _C.aognet.max_split = [2, 2, 2] # must >= 2 63 | _C.aognet.extra_node_hierarchy = [0, 0, 0] # 0: none, 1: tnode topdown, 2: tnode bottomup layerwise, 3: tnode bottomup sequential, 4: non-term node lateral, 5: tnode bottomup, 6: non-term node lateral (variant 1) 64 | _C.aognet.remove_symmetric_children_of_or_node = [0, 0, 0] 65 | _C.aognet.terminal_node_no_slice = [0, 0, 0] 66 | _C.aognet.stride = [1, 2, 2] 67 | _C.aognet.drop_rate = [0.0, 0.0, 0.0] 68 | _C.aognet.bottleneck_ratio = 0.25 69 | _C.aognet.handle_dbl_cnt = True 70 | _C.aognet.handle_tnode_dbl_cnt = False 71 | _C.aognet.handle_dbl_cnt_in_param_init = False 72 | _C.aognet.use_group_conv = False 73 | _C.aognet.width_per_group = 0 74 | _C.aognet.when_downsample = 0 # 0: at T-nodes, 1: before an AOG block (conv_norm_ac + avgpool), 2: per-block transition (conv_norm_ac, plus avgpool if stride > 1) 75 | _C.aognet.replace_stride_with_avgpool = True # for downsample in node op. 76 | _C.aognet.use_elem_max_for_ORNodes = False 77 | 78 | cfg = _C 79 | -------------------------------------------------------------------------------- /models/mobilenet.py: -------------------------------------------------------------------------------- 1 | # from https://raw.githubusercontent.com/pytorch/vision/master/torchvision/models/mobilenet.py 2 | # Modified by Tianfu Wu 3 | # Contact: tianfu_wu@ncsu.edu 4 | 5 | from torch import nn 6 | from .aognet.operator_basic import FeatureNorm, MixtureBatchNorm2d, MixtureGroupNorm 7 | from .config import cfg 8 | 9 | 10 | __all__ = ['MobileNetV2', 'mobilenet_v2'] 11 | 12 | 13 | class ConvBNReLU(nn.Sequential): 14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1, 15 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0): 16 | if norm_name is None: 17 | norm_name = "BatchNorm2d" 18 | padding = (kernel_size - 1) // 2 19 | super(ConvBNReLU, self).__init__( 20 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), 21 | #nn.BatchNorm2d(out_planes), 22 | FeatureNorm(norm_name, out_planes, 23 | num_groups=norm_groups, num_k=norm_k, 24 | attention_mode=norm_attention_mode), 25 | nn.ReLU6(inplace=True) 26 | ) 27 | 28 | 29 | class InvertedResidual(nn.Module): 30 | def __init__(self, inp, oup, stride, expand_ratio, 31 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0): 32 | super(InvertedResidual, self).__init__() 33 | self.stride = stride 34 | assert stride in [1, 2] 35 | if norm_name is None: 36 | norm_name = "BatchNorm2d" 37 | if "BatchNorm2d" in norm_name: 38 | norm_name_base = "BatchNorm2d" 39 | elif "GroupNorm" in norm_name: 40 | norm_name_base = "GroupNorm" 41 | else: 42 | raise ValueError("Unknown norm.") 43 | 44 | hidden_dim = int(round(inp * expand_ratio)) 45 | self.use_res_connect = self.stride == 1 and inp == oup 46 | 47 | layers = [] 48 | if expand_ratio != 1:
49 | # pw 50 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1, 51 | norm_name=norm_name_base, norm_groups=norm_groups, 52 | norm_k=norm_k, norm_attention_mode=norm_attention_mode)) 53 | layers.extend([ 54 | # dw 55 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim, 56 | norm_name=norm_name, norm_groups=norm_groups, 57 | norm_k=norm_k, norm_attention_mode=norm_attention_mode), 58 | # pw-linear 59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), 60 | #nn.BatchNorm2d(oup), 61 | FeatureNorm(norm_name_base, oup, 62 | num_groups=norm_groups, num_k=norm_k, 63 | attention_mode=norm_attention_mode), 64 | ]) 65 | self.conv = nn.Sequential(*layers) 66 | 67 | def forward(self, x): 68 | if self.use_res_connect: 69 | return x + self.conv(x) 70 | else: 71 | return self.conv(x) 72 | 73 | 74 | class MobileNetV2(nn.Module): 75 | def __init__(self, num_classes=1000, width_mult=1.0): 76 | super(MobileNetV2, self).__init__() 77 | block = InvertedResidual 78 | input_channel = 32 79 | last_channel = 1280 80 | inverted_residual_setting = [ 81 | # t, c, n, s 82 | [1, 16, 1, 1], 83 | [6, 24, 2, 2], 84 | [6, 32, 3, 2], 85 | [6, 64, 4, 2], 86 | [6, 96, 3, 1], 87 | [6, 160, 3, 2], 88 | [6, 320, 1, 1], 89 | ] 90 | norm_name = cfg.norm_name 91 | norm_groups = cfg.norm_groups 92 | norm_k = cfg.norm_k 93 | norm_attention_mode = cfg.norm_attention_mode 94 | if norm_name is None: 95 | norm_name = "BatchNorm2d" 96 | if "BatchNorm2d" in norm_name: 97 | norm_name_base = "BatchNorm2d" 98 | elif "GroupNorm" in norm_name: 99 | norm_name_base = "GroupNorm" 100 | else: 101 | raise ValueError("Unknown norm.") 102 | 103 | # building first layer 104 | input_channel = int(input_channel * width_mult) 105 | self.last_channel = int(last_channel * max(1.0, width_mult)) 106 | features = [ConvBNReLU(3, input_channel, stride=2, 107 | norm_name=norm_name_base, norm_groups=norm_groups, 108 | norm_k=-1, norm_attention_mode=norm_attention_mode)] 109 | # building inverted residual blocks 110 | for j, (t, c, n, s) in enumerate(inverted_residual_setting): 111 | output_channel = int(c * width_mult) 112 | for i in range(n): 113 | stride = s if i == 0 else 1 114 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, 115 | norm_name=norm_name, norm_groups=norm_groups, 116 | norm_k=norm_k[j], norm_attention_mode=norm_attention_mode)) 117 | input_channel = output_channel 118 | # building last several layers 119 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1, 120 | norm_name=norm_name_base, norm_groups=norm_groups, 121 | norm_k=-1, norm_attention_mode=norm_attention_mode)) 122 | # make it nn.Sequential 123 | self.features = nn.Sequential(*features) 124 | 125 | # building classifier 126 | self.classifier = nn.Sequential( 127 | nn.Dropout(0.2), 128 | nn.Linear(self.last_channel, num_classes), 129 | ) 130 | 131 | # weight initialization 132 | for m in self.modules(): 133 | if isinstance(m, nn.Conv2d): 134 | nn.init.kaiming_normal_(m.weight, mode='fan_out') 135 | if m.bias is not None: 136 | nn.init.zeros_(m.bias) 137 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)): 138 | nn.init.normal_(m.weight_, 1, 0.1) 139 | nn.init.normal_(m.bias_, 0, 0.1) 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.ones_(m.weight) 142 | nn.init.zeros_(m.bias) 143 | elif isinstance(m, nn.Linear): 144 | nn.init.normal_(m.weight, 0, 0.01) 145 | nn.init.zeros_(m.bias) 146 | 147 | def forward(self, x): 148 | x = self.features(x) 149 | x = x.mean([2, 3]) 150 | x = 
self.classifier(x) 151 | return x 152 | 153 | 154 | def mobilenet_v2(pretrained=False, progress=True, **kwargs): 155 | """ 156 | Constructs a MobileNetV2 architecture from 157 | `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_. 158 | 159 | Args: 160 | pretrained (bool): If True, returns a model pre-trained on ImageNet 161 | progress (bool): If True, displays a progress bar of the download to stderr 162 | """ 163 | model = MobileNetV2(**kwargs) 164 | if pretrained: 165 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], 166 | progress=progress) 167 | model.load_state_dict(state_dict) 168 | return model 169 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 2 | # By Tianfu Wu 3 | # Contact: tianfu_wu@ncsu.edu 4 | import torch.nn as nn 5 | from .aognet.operator_basic import FeatureNorm, MixtureBatchNorm2d, MixtureGroupNorm 6 | from .config import cfg 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d'] 11 | 12 | 13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 14 | """3x3 convolution with padding""" 15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 16 | padding=dilation, groups=groups, bias=False, dilation=dilation) 17 | 18 | 19 | def conv1x1(in_planes, out_planes, stride=1): 20 | """1x1 convolution""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 22 | 23 | 24 | class BasicBlock(nn.Module): 25 | expansion = 1 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 28 | base_width=64, dilation=1, 29 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0): 30 | super(BasicBlock, self).__init__() 31 | if norm_name is None: 32 | norm_name = "BatchNorm2d" 33 | if groups != 1 or base_width != 64: 34 | raise ValueError( 35 | 'BasicBlock only supports groups=1 and base_width=64') 36 | if dilation > 1: 37 | raise NotImplementedError( 38 | "Dilation > 1 not supported in BasicBlock") 39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = FeatureNorm(norm_name, planes, 42 | num_groups=norm_groups, num_k=norm_k, 43 | attention_mode=norm_attention_mode) 44 | self.relu = nn.ReLU(inplace=True) 45 | self.conv2 = conv3x3(planes, planes) 46 | self.bn2 = FeatureNorm(norm_name, planes, 47 | num_groups=norm_groups, num_k=norm_k, 48 | attention_mode=norm_attention_mode) 49 | self.downsample = downsample 50 | self.stride = stride 51 | 52 | def forward(self, x): 53 | identity = x 54 | 55 | out = self.conv1(x) 56 | out = self.bn1(out) 57 | out = self.relu(out) 58 | 59 | out = self.conv2(out) 60 | out = self.bn2(out) 61 | 62 | if self.downsample is not None: 63 | identity = self.downsample(x) 64 | 65 | out += identity 66 | out = self.relu(out) 67 | 68 | return out 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | expansion = 4 73 | 74 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 75 | base_width=64, dilation=1, 76 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0, 77 | norm_all_mix=False): 78 | super(Bottleneck, self).__init__() 79 | if norm_name is None: 80 | norm_name = "BatchNorm2d" 81 | if
norm_all_mix: 82 | norm_name_base = norm_name 83 | else: 84 | if "BatchNorm2d" in norm_name: 85 | norm_name_base = "BatchNorm2d" 86 | elif "GroupNorm" in norm_name: 87 | norm_name_base = "GroupNorm" 88 | else: 89 | raise ValueError("Unknown norm.") 90 | width = int(planes * (base_width / 64.)) * groups 91 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 92 | self.conv1 = conv1x1(inplanes, width) 93 | self.bn1 = FeatureNorm(norm_name_base, width, 94 | num_groups=norm_groups, num_k=norm_k, 95 | attention_mode=norm_attention_mode) 96 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 97 | self.bn2 = FeatureNorm(norm_name, width, 98 | num_groups=norm_groups, num_k=norm_k, 99 | attention_mode=norm_attention_mode) 100 | self.conv3 = conv1x1(width, planes * self.expansion) 101 | self.bn3 = FeatureNorm(norm_name_base, planes * self.expansion, 102 | num_groups=norm_groups, num_k=norm_k, 103 | attention_mode=norm_attention_mode) 104 | self.relu = nn.ReLU(inplace=True) 105 | self.downsample = downsample 106 | self.stride = stride 107 | 108 | def forward(self, x): 109 | identity = x 110 | 111 | out = self.conv1(x) 112 | out = self.bn1(out) 113 | out = self.relu(out) 114 | 115 | out = self.conv2(out) 116 | out = self.bn2(out) 117 | out = self.relu(out) 118 | 119 | out = self.conv3(out) 120 | out = self.bn3(out) 121 | 122 | if self.downsample is not None: 123 | identity = self.downsample(x) 124 | 125 | out += identity 126 | out = self.relu(out) 127 | 128 | return out 129 | 130 | 131 | class ResNet(nn.Module): 132 | def __init__(self, block, layers, groups=1, width_per_group=64): 133 | super(ResNet, self).__init__() 134 | replace_stride_with_dilation = cfg.resnet.replace_stride_with_dilation 135 | base_inplanes = cfg.resnet.base_inplanes 136 | norm_name = cfg.norm_name 137 | norm_groups = cfg.norm_groups 138 | norm_attention_mode = cfg.norm_attention_mode 139 | norm_all_mix = cfg.norm_all_mix 140 | replace_stride_with_avgpool = cfg.resnet.replace_stride_with_avgpool 141 | self.norm_name = norm_name 142 | self.norm_groups = norm_groups 143 | self.norm_ks = cfg.norm_k 144 | self.norm_attention_mode = norm_attention_mode 145 | self.norm_all_mix = norm_all_mix 146 | if norm_all_mix: 147 | self.norm_name_base = norm_name 148 | else: 149 | if "BatchNorm2d" in norm_name: 150 | self.norm_name_base = "BatchNorm2d" 151 | elif "GroupNorm" in norm_name: 152 | self.norm_name_base = "GroupNorm" 153 | else: 154 | raise ValueError("Unknown norm layer") 155 | if "Mixture" in norm_name: 156 | assert len(self.norm_ks) == len(layers) and any(self.norm_ks), \ 157 | "Wrong mixture component specification (cfg.norm_k)" 158 | else: 159 | self.norm_ks = [0 for i in range(len(layers))] 160 | 161 | self.inplanes = base_inplanes 162 | self.extra_norm_ac = cfg.resnet.extra_norm_ac 163 | self.dilation = 1 164 | self.norm_k = self.norm_ks[0] 165 | if replace_stride_with_dilation is None: 166 | # each element in the tuple indicates if we should replace 167 | # the 2x2 stride with a dilated convolution instead 168 | replace_stride_with_dilation = [False, False, False] 169 | if len(replace_stride_with_dilation) != 3: 170 | raise ValueError("replace_stride_with_dilation should be None " 171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 172 | self.groups = groups 173 | self.base_width = width_per_group 174 | if cfg.stem.imagenet_head7x7: 175 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=cfg.stem.stem_kernel_size, 176 | stride=cfg.stem.stem_stride, padding=( 
177 | cfg.stem.stem_kernel_size-1)//2, 178 | bias=False) 179 | self.bn1 = FeatureNorm(self.norm_name_base, self.inplanes, 180 | num_groups=norm_groups, num_k=self.norm_k, 181 | attention_mode=norm_attention_mode) 182 | self.relu = nn.ReLU(inplace=True) 183 | self.maxpool = nn.MaxPool2d( 184 | kernel_size=3, stride=2, padding=1) if cfg.dataset == 'imagenet' else None 185 | else: 186 | plane = self.inplanes // 2 187 | self.conv1 = nn.Sequential( 188 | nn.Conv2d(3, plane, kernel_size=3, 189 | stride=2, padding=1, bias=False), 190 | FeatureNorm(self.norm_name_base, plane, 191 | num_groups=norm_groups, num_k=self.norm_k, 192 | attention_mode=norm_attention_mode), 193 | nn.ReLU(inplace=True), 194 | nn.Conv2d(plane, plane, kernel_size=3, 195 | stride=1, padding=1, bias=False), 196 | FeatureNorm(self.norm_name_base, plane, 197 | num_groups=norm_groups, num_k=self.norm_k, 198 | attention_mode=norm_attention_mode), 199 | nn.ReLU(inplace=True), 200 | nn.Conv2d(plane, self.inplanes, kernel_size=3, 201 | stride=1, padding=1, bias=False) 202 | ) 203 | self.bn1 = FeatureNorm(self.norm_name_base, self.inplanes, 204 | num_groups=norm_groups, num_k=self.norm_k, 205 | attention_mode=norm_attention_mode) 206 | self.relu = nn.ReLU(inplace=True) 207 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 208 | 209 | self.layer1 = self._make_layer(block, base_inplanes, layers[0], 210 | replace_stride_with_avgpool=replace_stride_with_avgpool) 211 | self.norm_k = self.norm_ks[1] 212 | self.layer2 = self._make_layer(block, base_inplanes*2, layers[1], stride=2, 213 | dilate=replace_stride_with_dilation[0], 214 | replace_stride_with_avgpool=replace_stride_with_avgpool) 215 | self.norm_k = self.norm_ks[2] 216 | self.layer3 = self._make_layer(block, base_inplanes*4, layers[2], stride=2, 217 | dilate=replace_stride_with_dilation[1], 218 | replace_stride_with_avgpool=replace_stride_with_avgpool) 219 | self.layer4 = None 220 | outplanes = base_inplanes*4*block.expansion 221 | if len(layers) > 3: 222 | self.norm_k = self.norm_ks[3] 223 | self.layer4 = self._make_layer(block, base_inplanes*8, layers[3], stride=2, 224 | dilate=replace_stride_with_dilation[2], 225 | replace_stride_with_avgpool=replace_stride_with_avgpool) 226 | outplanes = base_inplanes*8*block.expansion 227 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 228 | self.fc = nn.Linear(outplanes, cfg.num_classes) 229 | 230 | for m in self.modules(): 231 | if isinstance(m, nn.Conv2d): 232 | nn.init.kaiming_normal_( 233 | m.weight, mode='fan_out', nonlinearity='relu') 234 | for name, _ in m.named_parameters(): 235 | if name in ['bias']: 236 | nn.init.constant_(m.bias, 0) 237 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)): 238 | nn.init.normal_(m.weight_, 1, 0.1) 239 | nn.init.normal_(m.bias_, 0, 0.1) 240 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 241 | nn.init.constant_(m.weight, 1) 242 | nn.init.constant_(m.bias, 0) 243 | 244 | # Zero-initialize the last BN in each residual branch, 245 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
246 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 247 | if cfg.norm_zero_gamma_init: 248 | for m in self.modules(): 249 | if isinstance(m, Bottleneck): 250 | if isinstance(m.bn3, (MixtureBatchNorm2d, MixtureGroupNorm)): 251 | nn.init.constant_(m.bn3.weight_, 0) 252 | nn.init.constant_(m.bn3.bias_, 0) 253 | else: 254 | nn.init.constant_(m.bn3.weight, 0) 255 | elif isinstance(m, BasicBlock): 256 | # TODO: handle mixture norm 257 | nn.init.constant_(m.bn2.weight, 0) 258 | 259 | def _extra_norm_ac(self, out_channels, norm_k): 260 | return nn.Sequential(FeatureNorm(self.norm_name_base, out_channels, 261 | self.norm_groups, norm_k, 262 | self.norm_attention_mode), 263 | nn.ReLU(inplace=True)) 264 | 265 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False, 266 | replace_stride_with_avgpool=False): 267 | norm_name = self.norm_name 268 | norm_name_base = self.norm_name_base 269 | norm_groups = self.norm_groups 270 | norm_k = self.norm_k 271 | norm_attention_mode = self.norm_attention_mode 272 | norm_all_mix = self.norm_all_mix 273 | downsample = None 274 | previous_dilation = self.dilation 275 | if dilate: 276 | self.dilation *= stride 277 | stride = 1 278 | 279 | downsample_op = [] 280 | if stride != 1 or self.inplanes != planes * block.expansion: 281 | downsample_stride = stride 282 | if replace_stride_with_avgpool and stride > 1: 283 | downsample_op.append(nn.AvgPool2d((stride, stride), stride)) 284 | downsample_stride = 1 285 | 286 | downsample_op.append( 287 | conv1x1(self.inplanes, planes * block.expansion, downsample_stride)) 288 | downsample_op.append(FeatureNorm(norm_name_base, planes * block.expansion, 289 | num_groups=norm_groups, num_k=norm_k, 290 | attention_mode=norm_attention_mode)) 291 | 292 | if len(downsample_op) > 0: 293 | downsample = nn.Sequential(*downsample_op) 294 | 295 | layers = [] 296 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 297 | base_width=self.base_width, dilation=previous_dilation, 298 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, 299 | norm_attention_mode=norm_attention_mode, 300 | norm_all_mix=norm_all_mix)) 301 | self.inplanes = planes * block.expansion 302 | for _ in range(1, blocks): 303 | layers.append(block(self.inplanes, planes, groups=self.groups, 304 | base_width=self.base_width, dilation=self.dilation, 305 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, 306 | norm_attention_mode=norm_attention_mode, 307 | norm_all_mix=norm_all_mix)) 308 | 309 | if self.extra_norm_ac: 310 | layers.append(self._extra_norm_ac(self.inplanes, norm_k)) 311 | 312 | return nn.Sequential(*layers) 313 | 314 | def forward(self, x): 315 | x = self.conv1(x) 316 | x = self.bn1(x) 317 | x = self.relu(x) 318 | if self.maxpool is not None: 319 | x = self.maxpool(x) 320 | 321 | x = self.layer1(x) 322 | x = self.layer2(x) 323 | x = self.layer3(x) 324 | if self.layer4 is not None: 325 | x = self.layer4(x) 326 | 327 | x = self.avgpool(x) 328 | x = x.reshape(x.size(0), -1) 329 | x = self.fc(x) 330 | 331 | return x 332 | 333 | 334 | def _resnet(arch, inplanes, planes, pretrained, progress, **kwargs): 335 | model = ResNet(inplanes, planes, **kwargs) 336 | if pretrained: 337 | state_dict = load_state_dict_from_url(model_urls[arch], 338 | progress=progress) 339 | model.load_state_dict(state_dict) 340 | return model 341 | 342 | 343 | def resnet18(pretrained=False, progress=True, **kwargs): 344 | """Constructs a ResNet-18 model. 
345 | Args: 346 | pretrained (bool): If True, returns a model pre-trained on ImageNet 347 | progress (bool): If True, displays a progress bar of the download to stderr 348 | """ 349 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 350 | **kwargs) 351 | 352 | 353 | def resnet34(pretrained=False, progress=True, **kwargs): 354 | """Constructs a ResNet-34 model. 355 | Args: 356 | pretrained (bool): If True, returns a model pre-trained on ImageNet 357 | progress (bool): If True, displays a progress bar of the download to stderr 358 | """ 359 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 360 | **kwargs) 361 | 362 | 363 | def resnet50(pretrained=False, progress=True, **kwargs): 364 | """Constructs a ResNet-50 model. 365 | Args: 366 | pretrained (bool): If True, returns a model pre-trained on ImageNet 367 | progress (bool): If True, displays a progress bar of the download to stderr 368 | """ 369 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 370 | **kwargs) 371 | 372 | 373 | def resnet101(pretrained=False, progress=True, **kwargs): 374 | """Constructs a ResNet-101 model. 375 | Args: 376 | pretrained (bool): If True, returns a model pre-trained on ImageNet 377 | progress (bool): If True, displays a progress bar of the download to stderr 378 | """ 379 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 380 | **kwargs) 381 | 382 | 383 | def resnet152(pretrained=False, progress=True, **kwargs): 384 | """Constructs a ResNet-152 model. 385 | Args: 386 | pretrained (bool): If True, returns a model pre-trained on ImageNet 387 | progress (bool): If True, displays a progress bar of the download to stderr 388 | """ 389 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 390 | **kwargs) 391 | 392 | 393 | def resnext50_32x4d(**kwargs): 394 | kwargs['groups'] = 32 395 | kwargs['width_per_group'] = 4 396 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 397 | pretrained=False, progress=True, **kwargs) 398 | 399 | 400 | def resnext101_32x8d(**kwargs): 401 | kwargs['groups'] = 32 402 | kwargs['width_per_group'] = 8 403 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 404 | pretrained=False, progress=True, **kwargs) 405 | 406 | 407 | def resnext101_64x4d(**kwargs): 408 | kwargs['groups'] = 64 409 | kwargs['width_per_group'] = 4 410 | return _resnet('resnext101_64x4d', Bottleneck, [3, 4, 23, 3], 411 | pretrained=False, progress=True, **kwargs) 412 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | thop 3 | yacs 4 | scipy 5 | -------------------------------------------------------------------------------- /scripts/test_fp16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: test_fp16.sh pretrained_model_folder 4 | 5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | 7 | if [ "$#" -ne 1 ]; then 8 | echo "Usage: test_fp16.sh pretrained_model_folder" 9 | exit 10 | fi 11 | 12 | PRETRAINED_MODEL_PAHT=$1 13 | CONFIG_FILE=$PRETRAINED_MODEL_PAHT/config.yaml 14 | PRETRAINED_MODEL_FILE=$PRETRAINED_MODEL_PAHT/model_best.pth.tar 15 | 16 | ### Change accordingly 17 | GPUS=0,1,2,3,4,5,6,7 18 | NUM_GPUS=8 19 | NUM_WORKERS=8 20 | MASTER_PORT=1245 21 | 22 | # ImageNet 23 | 
DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/ 24 | 25 | # test 26 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port $MASTER_PORT \ 27 | $DIR/../tools/main_fp16.py --cfg $CONFIG_FILE --workers $NUM_WORKERS \ 28 | --fp16 \ 29 | -p 100 --save-dir $PRETRAINED_MODEL_PAHT --pretrained --evaluate $DATA 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /scripts/train_fp16.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Usage: train_fp16.sh config_filename 4 | 5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 6 | 7 | if [ "$#" -ne 1 ]; then 8 | echo "Usage: train_fp16.sh relative_config_filename" 9 | exit 10 | fi 11 | 12 | CONFIG_FILE=$DIR/../$1 13 | 14 | ### Change accordingly 15 | GPUS=0,1,2,3,4,5,6,7 16 | NUM_GPUS=8 17 | NUM_WORKERS=8 18 | MASTER_PORT=1234 19 | 20 | CONFIG_FILENAME="$(cut -d'/' -f2 <<<$1)" 21 | CONFIG_BASE="${CONFIG_FILENAME%.*}" 22 | SAVE_DIR=$DIR/../results/$CONFIG_BASE 23 | mkdir -p $SAVE_DIR 24 | 25 | # backup for reproducing results 26 | cp $CONFIG_FILE $SAVE_DIR/config.yaml 27 | cp -r $DIR/../models $SAVE_DIR 28 | cp $DIR/../tools/main_fp16.py $SAVE_DIR 29 | 30 | # ImageNet 31 | DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/ 32 | 33 | # train 34 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port $MASTER_PORT \ 35 | $DIR/../tools/main_fp16.py --cfg $CONFIG_FILE --workers $NUM_WORKERS \ 36 | --fp16 --static-loss-scale 128 \ 37 | -p 100 --save-dir $SAVE_DIR $DATA \ 38 | 2>&1 | tee $SAVE_DIR/log.txt 39 | 40 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/tools/__init__.py -------------------------------------------------------------------------------- /tools/main_fp16.py: -------------------------------------------------------------------------------- 1 | # From: https://github.com/NVIDIA/apex 2 | 3 | ### some tweaks 4 | # USE pillow-simd to speed up pytorch image loader 5 | # pip uninstall pillow 6 | # conda uninstall --force jpeg libtiff -y 7 | # conda install -c conda-forge libjpeg-turbo 8 | # CC="cc -mavx2" pip install --no-cache-dir -U --force-reinstall --no-binary :all: --compile pillow-simd 9 | 10 | # Install NCCL https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html 11 | 12 | import argparse 13 | import os 14 | import sys 15 | import shutil 16 | import time 17 | import copy 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.parallel 22 | import torch.backends.cudnn as cudnn 23 | import torch.distributed as dist 24 | import torch.optim 25 | import torch.utils.data 26 | import torch.utils.data.distributed 27 | import torchvision.transforms as transforms 28 | import torchvision.datasets as datasets 29 | 30 | import numpy as np 31 | 32 | try: 33 | from apex.parallel import DistributedDataParallel as DDP 34 | from apex.fp16_utils import * 35 | except ImportError: 36 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") 37 | 38 | 39 | import math 40 | import sys 41 | import re 42 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) 43 | from models.aognet.operator_basic import 
MixtureBatchNorm2d, MixtureGroupNorm 44 | from models.aognet.aognet import aognet 45 | from models.config import cfg 46 | import models.resnet as resnets 47 | import models.mobilenet as mobilenets 48 | from smoothing import LabelSmoothing 49 | 50 | parser = argparse.ArgumentParser(description='PyTorch Image Classification Training') 51 | parser.add_argument('data', metavar='DIR', 52 | help='path to dataset') 53 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 54 | help='number of data loading workers (default: 4)') 55 | parser.add_argument('--epochs', default=90, type=int, metavar='N', 56 | help='number of total epochs to run') 57 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N', 58 | help='manual epoch number (useful on restarts)') 59 | parser.add_argument('-b', '--batch-size', default=256, type=int, 60 | metavar='N', help='mini-batch size per process (default: 256)') 61 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float, 62 | metavar='LR', help='Initial learning rate. \ 63 | Will be scaled by /256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. \ 64 | A warmup schedule will also be applied over the first 5 epochs.') 65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M', 66 | help='momentum') 67 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, 68 | metavar='W', help='weight decay (default: 1e-4)') 69 | parser.add_argument('--print-freq', '-p', default=10, type=int, 70 | metavar='N', help='print frequency (default: 10)') 71 | parser.add_argument('--resume', default='', type=str, metavar='PATH', 72 | help='path to latest checkpoint (default: none)') 73 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true', 74 | help='evaluate model on validation set') 75 | parser.add_argument('--pretrained', dest='pretrained', action='store_true', 76 | help='use pre-trained model') 77 | 78 | parser.add_argument('--fp16', action='store_true', 79 | help='Run model fp16 mode.') 80 | parser.add_argument('--static-loss-scale', type=float, default=1, 81 | help='Static loss scale, positive power of 2 values can improve fp16 convergence.') 82 | parser.add_argument('--dynamic-loss-scale', action='store_true', 83 | help='Use dynamic loss scaling. 
If supplied, this argument supersedes ' + 84 | '--static-loss-scale.') 85 | parser.add_argument('--prof', dest='prof', action='store_true', 86 | help='Only run 10 iterations for profiling.') 87 | parser.add_argument('--deterministic', action='store_true') 88 | 89 | parser.add_argument("--local_rank", default=0, type=int) 90 | parser.add_argument('--sync_bn', action='store_true', 91 | help='enabling apex sync BN.') 92 | 93 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str) 94 | parser.add_argument('--save-dir', type=str, default='/tmp/models') 95 | parser.add_argument('--nesterov', type=str, default=None) 96 | parser.add_argument('--remove-norm-weight-decay', type=str, default=None) 97 | 98 | cudnn.benchmark = True 99 | 100 | def fast_collate(batch): 101 | imgs = [img[0] for img in batch] 102 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64) 103 | w = imgs[0].size[0] 104 | h = imgs[0].size[1] 105 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 ) 106 | for i, img in enumerate(imgs): 107 | nump_array = np.asarray(img, dtype=np.uint8) 108 | if(nump_array.ndim < 3): 109 | nump_array = np.expand_dims(nump_array, axis=-1) 110 | nump_array = np.rollaxis(nump_array, 2) 111 | 112 | tensor[i] += torch.from_numpy(nump_array) 113 | 114 | return tensor, targets 115 | 116 | best_prec1 = 0 117 | best_prec1_val = 0 118 | prec5_val = 0 119 | best_prec5_val = 0 120 | 121 | args = parser.parse_args() 122 | 123 | if args.local_rank == 0: 124 | print("PyTorch VERSION: {}".format(torch.__version__)) # PyTorch version 125 | print("CUDA VERSION: {}".format(torch.version.cuda)) # Corresponding CUDA version 126 | print("CUDNN VERSION: {}".format(torch.backends.cudnn.version())) # Corresponding cuDNN version 127 | print("GPU TYPE: {}".format(torch.cuda.get_device_name(0))) # GPU type 128 | 129 | if args.deterministic: 130 | cudnn.benchmark = False 131 | cudnn.deterministic = True 132 | torch.manual_seed(args.local_rank) 133 | torch.set_printoptions(precision=10) 134 | 135 | def main(): 136 | global best_prec1, args 137 | 138 | args.distributed = False 139 | if 'WORLD_SIZE' in os.environ: 140 | args.distributed = int(os.environ['WORLD_SIZE']) > 1 141 | 142 | args.gpu = 0 143 | args.world_size = 1 144 | 145 | if args.distributed: 146 | args.gpu = args.local_rank 147 | torch.cuda.set_device(args.gpu) 148 | torch.distributed.init_process_group(backend='nccl', 149 | init_method='env://') 150 | args.world_size = torch.distributed.get_world_size() 151 | 152 | if args.fp16: 153 | assert torch.backends.cudnn.enabled, "fp16 requires cudnn backend to be enabled." 
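# --------------------------------------------------------------------------
# Note (added sketch, not in the original script): loss scaling is what makes
# fp16 training converge here. Conceptually, the apex FP16_Optimizer wrapper
# used below does roughly the following with a scale factor S
# (--static-loss-scale, e.g. 128 in scripts/train_fp16.sh):
#
#   scaled_loss = loss * S        # shift tiny gradients into fp16 range
#   scaled_loss.backward()        # fp16 grads are now S times too large
#   # grads are copied to fp32 master weights and divided by S before the step
#
# With --dynamic-loss-scale, S is instead adjusted on the fly (backed off on
# overflow), which is why that flag supersedes the static value.
# --------------------------------------------------------------------------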
154 | if args.static_loss_scale != 1.0: 155 | if not args.fp16: 156 | print("Warning: if --fp16 is not used, static_loss_scale will be ignored.") 157 | 158 | # create model 159 | if args.pretrained: 160 | cfg.merge_from_file(os.path.join(args.save_dir, 'config.yaml')) 161 | else: 162 | cfg.merge_from_file(args.cfg) 163 | args.arch = cfg.arch 164 | # update args 165 | args.batch_size = cfg.batch_size 166 | args.lr = cfg.lr 167 | args.momentum = cfg.momentum 168 | args.weight_decay = cfg.wd 169 | args.nesterov = cfg.nesterov 170 | args.epochs = cfg.num_epoch 171 | 172 | if args.local_rank == 0: 173 | print("=> creating {}".format(args.arch)) 174 | if args.arch.startswith('aognet'): 175 | model = aognet() 176 | elif args.arch.startswith('resnet') or args.arch.startswith('resnext'): 177 | model = resnets.__dict__[args.arch]() 178 | elif args.arch.startswith('mobilenet'): 179 | model = mobilenets.__dict__[args.arch]() 180 | else: 181 | raise NotImplementedError("Unknown network arch.") 182 | 183 | if args.pretrained: 184 | if args.local_rank == 0: 185 | print("=> loading pre-trained model '{}'".format(args.arch)) 186 | checkpoint = torch.load(os.path.join(args.save_dir, 'model_best.pth.tar'), map_location='cpu') 187 | st_dict = {k[15:]: v for k, v in checkpoint['state_dict'].items()} 188 | model.load_state_dict(st_dict) 189 | 190 | if args.local_rank == 0: 191 | print('=> Params (double-check): %.6fM' % (sum(p.numel() for p in model.parameters()) / 1e6)) 192 | 193 | #sys.exit() 194 | 195 | if args.sync_bn: 196 | import apex 197 | if args.local_rank == 0: 198 | print("using apex synced BN") 199 | model = apex.parallel.convert_syncbn_model(model) 200 | 201 | model = model.cuda() 202 | if args.fp16: 203 | model = FP16Model(model) 204 | if args.distributed: 205 | # By default, apex.parallel.DistributedDataParallel overlaps communication with 206 | # computation in the backward pass. 207 | # model = DDP(model) 208 | # delay_allreduce delays all communication to the end of the backward pass. 209 | model = DDP(model, delay_allreduce=True) 210 | 211 | # Scale learning rate based on global batch size 212 | args.lr = args.lr*float(args.batch_size*args.world_size)/cfg.lr_scale_factor #TODO: control the maximum? 213 | 214 | if args.remove_norm_weight_decay: 215 | if args.local_rank == 0: 216 | print("=> ! Weight decay NOT applied to FeatNorm parameters ") 217 | norm_params=set() #TODO: need to check this via experiments 218 | rest_params=set() 219 | for m in model.modules(): 220 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, MixtureBatchNorm2d, MixtureGroupNorm)): 221 | for param in m.parameters(False): 222 | norm_params.add(param) 223 | else: 224 | for param in m.parameters(False): 225 | rest_params.add(param) 226 | 227 | optimizer = torch.optim.SGD([{'params': list(norm_params), 'weight_decay' : 0.0}, 228 | {'params': list(rest_params)}], 229 | args.lr, 230 | momentum=args.momentum, 231 | weight_decay=args.weight_decay, 232 | nesterov=args.nesterov) 233 | else: 234 | if args.local_rank == 0: 235 | print("=> !
Weight decay applied to FeatNorm parameters ") 236 | optimizer = torch.optim.SGD(model.parameters(), args.lr, 237 | momentum=args.momentum, 238 | weight_decay=args.weight_decay, 239 | nesterov=args.nesterov) 240 | 241 | if args.fp16: 242 | optimizer = FP16_Optimizer(optimizer, 243 | static_loss_scale=args.static_loss_scale, 244 | dynamic_loss_scale=args.dynamic_loss_scale) 245 | 246 | # define loss function (criterion) and optimizer 247 | criterion_train = nn.CrossEntropyLoss().cuda() if cfg.dataaug.labelsmoothing_rate == 0.0 \ 248 | else LabelSmoothing(cfg.dataaug.labelsmoothing_rate).cuda() 249 | criterion_val = nn.CrossEntropyLoss().cuda() 250 | 251 | # Optionally resume from a checkpoint 252 | if args.resume: 253 | # Use a local scope to avoid dangling references 254 | def resume(): 255 | if os.path.isfile(args.resume): 256 | if args.local_rank == 0: 257 | print("=> loading checkpoint '{}'".format(args.resume)) 258 | checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu)) 259 | args.start_epoch = checkpoint['epoch'] 260 | best_prec1 = checkpoint['best_prec1'] 261 | model.load_state_dict(checkpoint['state_dict']) 262 | optimizer.load_state_dict(checkpoint['optimizer']) 263 | if args.local_rank == 0: 264 | print("=> loaded checkpoint '{}' (epoch {})" 265 | .format(args.resume, checkpoint['epoch'])) 266 | else: 267 | if args.local_rank == 0: 268 | print("=> no checkpoint found at '{}'".format(args.resume)) 269 | resume() 270 | 271 | # Data loading code 272 | lr_milestones = None 273 | if cfg.dataset == "cifar10": 274 | train_transform = transforms.Compose([ 275 | transforms.RandomCrop(32, padding=4), 276 | transforms.RandomHorizontalFlip() 277 | ]) 278 | train_dataset = datasets.CIFAR10('./datasets', train=True, download=False, transform=train_transform) 279 | val_dataset = datasets.CIFAR10('./datasets', train=False, download=False) 280 | lr_milestones = cfg.lr_milestones 281 | elif cfg.dataset == "cifar100": 282 | train_transform = transforms.Compose([ 283 | transforms.RandomCrop(32, padding=4), 284 | transforms.RandomHorizontalFlip() 285 | ]) 286 | train_dataset = datasets.CIFAR100('./datasets', train=True, download=False, transform=train_transform) 287 | val_dataset = datasets.CIFAR100('./datasets', train=False, download=False) 288 | lr_milestones = cfg.lr_milestones 289 | elif cfg.dataset == "imagenet": 290 | traindir = os.path.join(args.data, 'train') 291 | valdir = os.path.join(args.data, 'val') 292 | 293 | crop_size = cfg.crop_size # 224 294 | val_size = cfg.crop_size + 32 # 256 295 | 296 | train_dataset = datasets.ImageFolder( 297 | traindir, 298 | transforms.Compose([ 299 | transforms.RandomResizedCrop(crop_size, interpolation=cfg.crop_interpolation), 300 | transforms.RandomHorizontalFlip(), 301 | # transforms.ToTensor(), Too slow 302 | # normalize, 303 | ])) 304 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose([ 305 | transforms.Resize(val_size, interpolation=cfg.crop_interpolation), 306 | transforms.CenterCrop(crop_size), 307 | ])) 308 | 309 | train_sampler = None 310 | val_sampler = None 311 | if args.distributed: 312 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 313 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset) 314 | 315 | train_loader = torch.utils.data.DataLoader( 316 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None), 317 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate) 318 | 319 | 
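# --------------------------------------------------------------------------
# Note (added sketch, not in the original script): fast_collate deliberately
# returns uint8 tensors with no normalization, so the CPU side stays cheap;
# data_prefetcher (defined below) finishes the job on the GPU, overlapped
# with compute on a side CUDA stream. Per batch it effectively runs:
#
#   input = input.cuda(non_blocking=True).float()   # .half() under --fp16
#   input = input.sub_(mean).div_(std)              # mean/std are 1x3x1x1 tensors on a 0..255 scale
# --------------------------------------------------------------------------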
val_loader = torch.utils.data.DataLoader( 320 | val_dataset, 321 | batch_size=args.batch_size, shuffle=False, 322 | num_workers=args.workers, pin_memory=True, 323 | sampler=val_sampler, 324 | collate_fn=fast_collate) 325 | 326 | if args.evaluate: 327 | validate(val_loader, model, criterion_val) 328 | return 329 | 330 | scheduler = CosineAnnealingLR(optimizer.optimizer if args.fp16 else optimizer, 331 | args.epochs, len(train_loader), 332 | eta_min=cfg.cosine_lr_min, warmup=cfg.warmup_epochs) if cfg.use_cosine_lr else None 333 | 334 | for epoch in range(args.start_epoch, args.epochs): 335 | if args.distributed: 336 | train_sampler.set_epoch(epoch) 337 | 338 | # train for one epoch 339 | train(train_loader, model, criterion_train, optimizer, epoch, scheduler, lr_milestones, cfg.warmup_epochs, 340 | cfg.dataaug.mixup_rate, cfg.dataaug.labelsmoothing_rate) 341 | if args.prof: 342 | break 343 | # evaluate on validation set 344 | prec1 = validate(val_loader, model, criterion_val) 345 | 346 | # remember best prec@1 and save checkpoint 347 | if args.local_rank == 0: 348 | is_best = prec1 > best_prec1 349 | best_prec1 = max(prec1, best_prec1) 350 | save_checkpoint({ 351 | 'epoch': epoch + 1, 352 | 'arch': args.arch, 353 | 'state_dict': model.state_dict(), 354 | 'best_prec1': best_prec1, 355 | 'optimizer' : optimizer.state_dict(), 356 | }, is_best, args.save_dir) 357 | 358 | class data_prefetcher(): 359 | def __init__(self, loader): 360 | self.loader = iter(loader) 361 | self.stream = torch.cuda.Stream() 362 | if cfg.dataset == 'cifar10': 363 | self.mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1) 364 | self.std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1) 365 | elif cfg.dataset == 'cifar100': 366 | self.mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1) 367 | self.std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 255]).cuda().view(1,3,1,1) 368 | elif cfg.dataset == 'imagenet': 369 | self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1) 370 | self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1) 371 | else: 372 | raise NotImplementedError 373 | if args.fp16: 374 | self.mean = self.mean.half() 375 | self.std = self.std.half() 376 | self.preload() 377 | 378 | def preload(self): 379 | try: 380 | self.next_input, self.next_target = next(self.loader) 381 | except StopIteration: 382 | self.next_input = None 383 | self.next_target = None 384 | return 385 | with torch.cuda.stream(self.stream): 386 | self.next_input = self.next_input.cuda(non_blocking=True) 387 | self.next_target = self.next_target.cuda(non_blocking=True) 388 | if args.fp16: 389 | self.next_input = self.next_input.half() 390 | else: 391 | self.next_input = self.next_input.float() 392 | self.next_input = self.next_input.sub_(self.mean).div_(self.std) 393 | 394 | def next(self): 395 | torch.cuda.current_stream().wait_stream(self.stream) 396 | input = self.next_input 397 | target = self.next_target 398 | self.preload() 399 | return input, target 400 | 401 | # from NVIDIA DL Examples 402 | def prefetched_loader(loader): 403 | if cfg.dataset == 'cifar10': 404 | self.mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1) 405 | self.std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1) 406 | elif cfg.dataset == 'cifar100': 407 | self.mean = torch.tensor([0.5071 * 255, 
401 | # from NVIDIA DL Examples: a generator-style alternative to the data_prefetcher class above
402 | def prefetched_loader(loader):
403 |     if cfg.dataset == 'cifar10':
404 |         mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1)
405 |         std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1)
406 |     elif cfg.dataset == 'cifar100':
407 |         mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1)
408 |         std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 255]).cuda().view(1,3,1,1)
409 |     elif cfg.dataset == 'imagenet':
410 |         mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
411 |         std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
412 |     else:
413 |         raise NotImplementedError
414 | 
415 |     stream = torch.cuda.Stream()
416 |     first = True
417 | 
418 |     for next_input, next_target in loader:
419 |         with torch.cuda.stream(stream):
420 |             next_input = next_input.cuda(non_blocking=True)
421 |             next_target = next_target.cuda(non_blocking=True)
422 |             next_input = next_input.float()
423 |             next_input = next_input.sub_(mean).div_(std)
424 | 
425 |         if not first:
426 |             yield input, target
427 |         else:
428 |             first = False
429 | 
430 |         torch.cuda.current_stream().wait_stream(stream)
431 |         input = next_input
432 |         target = next_target
433 | 
434 |     yield input, target
435 | 
436 | 
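In `train()` below, mixup draws `lambda ~ Beta(mixup_rate, mixup_rate)`, blends each batch with a shuffled copy of itself (`lambda * x + (1 - lambda) * x[index]`), and takes the same convex combination of the two losses. When label smoothing is also enabled, the soft targets used for the metrics are built the same way. A numeric sketch with illustrative values (`C = 5` classes, `labelsmoothing_rate = 0.1`, so `off_prob = 0.02` and the on-class value is `1.0 - 0.1 + 0.02 = 0.92`):

```python
import torch

C, smoothing, lam = 5, 0.1, 0.3  # illustrative values
off = smoothing / C              # 0.02
on = 1.0 - smoothing + off       # 0.92
t1 = torch.full((1, C), off)
t1[0, 0] = on                    # smoothed target for true class 0
t2 = torch.full((1, C), off)
t2[0, 3] = on                    # smoothed target for true class 3
mixed = lam * t1 + (1 - lam) * t2
print(mixed)  # [[0.29, 0.02, 0.02, 0.65, 0.02]], still sums to 1
```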
437 | def train(train_loader, model, criterion, optimizer, epoch, scheduler=None, lr_milestones=None, warmup_epoch=0,
438 |           mixup_rate=0.0, labelsmoothing_rate=0.0):
439 |     batch_time = AverageMeter()
440 |     data_time = AverageMeter()
441 |     losses = AverageMeter()
442 |     top1 = AverageMeter()
443 |     top5 = AverageMeter()
444 | 
445 |     # switch to train mode
446 |     model.train()
447 |     end = time.time()
448 | 
449 |     prefetcher = data_prefetcher(train_loader)
450 |     input, target = prefetcher.next()
451 |     i = -1
452 |     beta_distribution = torch.distributions.beta.Beta(mixup_rate, mixup_rate) if mixup_rate > 0.0 else None  # Beta(0, 0) is invalid
453 |     while input is not None:
454 |         i += 1
455 | 
456 |         if scheduler is None:
457 |             lr = adjust_learning_rate(optimizer, epoch, i, len(train_loader), lr_milestones, warmup_epoch)
458 |         else:
459 |             lr = scheduler.update(epoch, i)
460 | 
461 |         if args.prof:
462 |             if i > 10:
463 |                 break
464 |         # measure data loading time
465 |         data_time.update(time.time() - end)
466 | 
467 |         # Mixup input
468 |         if mixup_rate > 0.0:
469 |             lambda_ = beta_distribution.sample([]).item()
470 |             index = torch.randperm(input.size(0)).cuda()
471 |             input = lambda_ * input + (1 - lambda_) * input[index, :]
472 | 
473 |         # compute output
474 |         if args.prof: torch.cuda.nvtx.range_push("forward")
475 |         output = model(input)
476 |         if args.prof: torch.cuda.nvtx.range_pop()
477 | 
478 |         # Mixup loss
479 |         if mixup_rate > 0.0:
480 |             # Mixup loss
481 |             loss = (lambda_ * criterion(output, target)
482 |                     + (1 - lambda_) * criterion(output, target[index]))
483 | 
484 |             # Mixup target (soft labels, used only for the metrics below)
485 |             if labelsmoothing_rate > 0.0:
486 |                 N = output.size(0)
487 |                 C = output.size(1)
488 |                 off_prob = labelsmoothing_rate / C
489 |                 target_1 = torch.full(size=(N, C), fill_value=off_prob).cuda()
490 |                 target_2 = torch.full(size=(N, C), fill_value=off_prob).cuda()
491 |                 target_1.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1.0-labelsmoothing_rate+off_prob)
492 |                 target_2.scatter_(dim=1, index=torch.unsqueeze(target[index], dim=1), value=1.0-labelsmoothing_rate+off_prob)
493 |                 target = lambda_ * target_1 + (1 - lambda_) * target_2
494 |             else:
495 |                 target = lambda_ * target + (1 - lambda_) * target[index]
496 |         else:
497 |             loss = criterion(output, target)
498 | 
499 |         # compute gradient and do SGD step
500 |         optimizer.zero_grad()
501 | 
502 |         if args.prof: torch.cuda.nvtx.range_push("backward")
503 |         if args.fp16:
504 |             optimizer.backward(loss)
505 |         else:
506 |             loss.backward()
507 |         if args.prof: torch.cuda.nvtx.range_pop()
508 | 
509 |         # debug
510 |         # if args.local_rank == 0:
511 |         #     for name_, param in model.named_parameters():
512 |         #         print(name_, param.data.double().sum().item(), param.grad.data.double().sum().item())
513 | 
514 |         if args.prof: torch.cuda.nvtx.range_push("step")
515 |         optimizer.step()
516 |         if args.prof: torch.cuda.nvtx.range_pop()
517 | 
518 |         # Measure accuracy
519 |         if mixup_rate > 0.0:
520 |             prec1 = rmse(output.data, target)  # targets are soft under mixup, so report RMSE as a proxy metric
521 |             prec5 = prec1
522 |         else:
523 |             prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
524 | 
525 |         # Average loss and accuracy across processes for logging
526 |         if args.distributed:
527 |             reduced_loss = reduce_tensor(loss.data)
528 |             prec1 = reduce_tensor(prec1)
529 |             prec5 = reduce_tensor(prec5)
530 |         else:
531 |             reduced_loss = loss.data
532 | 
533 |         # to_python_float incurs a host<->device sync
534 |         losses.update(to_python_float(reduced_loss), input.size(0))
535 |         top1.update(to_python_float(prec1), input.size(0))
536 |         top5.update(to_python_float(prec5), input.size(0))
537 | 
538 |         # torch.cuda.synchronize()  # not in the torchvision example, and it caused NaN loss problems in deep fp16 models
539 | 
540 |         batch_time.update(time.time() - end)
541 |         end = time.time()
542 |         input, target = prefetcher.next()
543 | 
544 |         if i % args.print_freq == 0 and args.local_rank == 0:
545 |             # Every print_freq iterations, check the loss, accuracy, and speed.
546 |             print('Epoch: [{0}][{1}/{2}]\t'
547 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
548 |                   'Speed {3:.3f} ({4:.3f})\t'
549 |                   'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
550 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
551 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
552 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
553 |                   'lr {lr:.6f}\t'.format(
554 |                   epoch, i, len(train_loader),
555 |                   args.world_size*args.batch_size/batch_time.val,
556 |                   args.world_size*args.batch_size/batch_time.avg,
557 |                   batch_time=batch_time,
558 |                   data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr[0]))
559 | 
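Both loops report Prec@k via the `accuracy()` helper defined near the end of the file. A small worked example of its top-k logic, with hypothetical logits:

```python
import torch

output = torch.tensor([[2.0, 1.0, 0.1],     # batch of 2 samples, 3 classes
                       [0.5, 2.5, 0.2]])
target = torch.tensor([0, 0])

_, pred = output.topk(2, 1, True, True)     # pred = [[0, 1], [1, 0]]
correct = pred.t().eq(target.view(1, -1))   # row r: is the (r+1)-th guess right?
print(100.0 * correct[:1].float().sum() / 2)  # Prec@1 = 50.0 (sample 1 misses)
print(100.0 * correct[:2].float().sum() / 2)  # Prec@2 = 100.0 (class 0 is its 2nd guess)
```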
560 | def validate(val_loader, model, criterion):
561 |     global best_prec1_val, prec5_val, best_prec5_val
562 |     batch_time = AverageMeter()
563 |     losses = AverageMeter()
564 |     top1 = AverageMeter()
565 |     top5 = AverageMeter()
566 | 
567 |     # switch to evaluate mode
568 |     model.eval()
569 | 
570 |     end = time.time()
571 | 
572 |     prefetcher = data_prefetcher(val_loader)
573 |     input, target = prefetcher.next()
574 |     i = -1
575 |     while input is not None:
576 |         i += 1
577 | 
578 |         # compute output
579 |         with torch.no_grad():
580 |             output = model(input)
581 |             loss = criterion(output, target)
582 | 
583 |         # measure accuracy and record loss
584 |         prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
585 | 
586 |         if args.distributed:
587 |             reduced_loss = reduce_tensor(loss.data)
588 |             prec1 = reduce_tensor(prec1)
589 |             prec5 = reduce_tensor(prec5)
590 |         else:
591 |             reduced_loss = loss.data
592 | 
593 |         losses.update(to_python_float(reduced_loss), input.size(0))
594 |         top1.update(to_python_float(prec1), input.size(0))
595 |         top5.update(to_python_float(prec5), input.size(0))
596 | 
597 |         # measure elapsed time
598 |         batch_time.update(time.time() - end)
599 |         end = time.time()
600 | 
601 |         # TODO: Change timings to mirror train().
602 |         if args.local_rank == 0 and i > 0 and i % args.print_freq == 0:
603 |             print('Test: [{0}/{1}]\t'
604 |                   'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
605 |                   'Speed {2:.3f} ({3:.3f})\t'
606 |                   'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
607 |                   'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
608 |                   'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
609 |                   i, len(val_loader),
610 |                   args.world_size * args.batch_size / batch_time.val,
611 |                   args.world_size * args.batch_size / batch_time.avg,
612 |                   batch_time=batch_time, loss=losses,
613 |                   top1=top1, top5=top5))
614 | 
615 |         input, target = prefetcher.next()
616 | 
617 |     if args.local_rank == 0:
618 |         if top1.avg >= best_prec1_val:
619 |             best_prec1_val = top1.avg
620 |             prec5_val = top5.avg  # Prec@5 at the epoch with the best Prec@1
621 |         best_prec5_val = max(best_prec5_val, top5.avg)
622 |         print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\t Best_Prec@1 {best:.3f}\t Prec@5 {prec5_val:.3f}\t Best_Prec@5 {bestprec5_val:.3f}'
623 |               .format(top1=top1, top5=top5, best=best_prec1_val, prec5_val=prec5_val, bestprec5_val=best_prec5_val))
624 | 
625 |     return top1.avg
626 | 
627 | 
628 | def save_checkpoint(state, is_best, save_dir='./'):
629 |     filename = os.path.join(save_dir, 'checkpoint.pth.tar')
630 |     best_file = os.path.join(save_dir, 'model_best.pth.tar')
631 |     torch.save(state, filename)
632 |     if is_best:
633 |         shutil.copyfile(filename, best_file)
634 | 
635 | class AverageMeter(object):
636 |     """Computes and stores the average and current value"""
637 |     def __init__(self):
638 |         self.reset()
639 | 
640 |     def reset(self):
641 |         self.val = 0
642 |         self.avg = 0
643 |         self.sum = 0
644 |         self.count = 0
645 | 
646 |     def update(self, val, n=1):
647 |         self.val = val
648 |         self.sum += val * n
649 |         self.count += n
650 |         self.avg = self.sum / self.count
651 | 
652 | 
653 | def adjust_learning_rate(optimizer, epoch, step, len_epoch, lr_milestones=None, warmup_epoch=0):
654 |     """LR schedule that should yield 76% converged accuracy with batch size 256"""
655 |     # if not isinstance(optimizer, torch.optim.Optimizer):
656 |     #     raise TypeError('{} is not an Optimizer'.format(
657 |     #         type(optimizer).__name__))
658 |     if lr_milestones is None:
659 |         factor = epoch // 30  # decay 10x every 30 epochs, with an extra drop from epoch 80 on
660 | 
661 |         if epoch >= 80:
662 |             factor = factor + 1
663 | 
664 |         lr = args.lr * (0.1 ** factor)
665 | 
666 |         # Warmup
667 |         if epoch < 5:
668 |             lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)
669 | 
670 |     else:
671 |         factor = 0
672 |         for m in lr_milestones:
673 |             if epoch >= m:
674 |                 factor += 1
675 | 
676 |         lr = args.lr * (0.1 ** factor)
677 | 
678 |         # Warmup
679 |         if epoch < warmup_epoch:
680 |             lr = lr * float(1 + step + epoch * len_epoch) / (warmup_epoch * len_epoch)
681 | 
682 | 
683 |     # if(args.local_rank == 0):
684 |     #     print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
685 | 
686 |     for param_group in optimizer.param_groups:
687 |         param_group['lr'] = lr
688 | 
689 |     return [lr]
690 | 
691 | 
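The `CosineAnnealingLR` class below (selected via `cfg.use_cosine_lr`) applies a linear warmup followed by cosine decay. Stripped of the optimizer bookkeeping, the schedule is the function sketched here, with `t` the fractional epoch `epoch + batch / batches_per_epoch`:

```python
import math

def cosine_lr(base_lr, t, T_max, warmup=0, eta_min=0.0):
    if t < warmup:                              # linear warmup
        return base_lr * t / warmup
    progress = (t - warmup) / (T_max - warmup)  # 0 at end of warmup, 1 at T_max
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * progress)) / 2

print(cosine_lr(0.1, 0.0, 120, warmup=5))    # 0.0  (start of warmup)
print(cosine_lr(0.1, 5.0, 120, warmup=5))    # 0.1  (cos(0) = 1: full base lr)
print(cosine_lr(0.1, 62.5, 120, warmup=5))   # 0.05 (cos(pi/2) = 0: halfway down)
print(cosine_lr(0.1, 120.0, 120, warmup=5))  # 0.0  (cos(pi) = -1: fully annealed)
```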
692 | class CosineAnnealingLR(object):
693 |     def __init__(self, optimizer, T_max, N_batch, eta_min=0, last_epoch=-1, warmup=0):
694 |         if not isinstance(optimizer, torch.optim.Optimizer):
695 |             raise TypeError('{} is not an Optimizer'.format(
696 |                 type(optimizer).__name__))
697 |         self.optimizer = optimizer
698 |         self.T_max = T_max
699 |         self.N_batch = N_batch
700 |         self.eta_min = eta_min
701 |         self.warmup = warmup
702 | 
703 |         if last_epoch == -1:
704 |             for group in optimizer.param_groups:
705 |                 group.setdefault('initial_lr', group['lr'])
706 |         else:
707 |             for i, group in enumerate(optimizer.param_groups):
708 |                 if 'initial_lr' not in group:
709 |                     raise KeyError("param 'initial_lr' is not specified "
710 |                                    "in param_groups[{}] when resuming an optimizer".format(i))
711 |         self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
712 |         self.update(last_epoch + 1)
713 |         self.last_epoch = last_epoch
714 |         self.iter = 0
715 | 
716 |     def state_dict(self):
717 |         return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
718 | 
719 |     def load_state_dict(self, state_dict):
720 |         self.__dict__.update(state_dict)
721 | 
722 |     def get_lr(self):
723 |         if self.last_epoch < self.warmup:
724 |             lrs = [base_lr * (self.last_epoch + self.iter / self.N_batch) / self.warmup for base_lr in self.base_lrs]
725 |         else:
726 |             lrs = [self.eta_min + (base_lr - self.eta_min) *
727 |                    (1 + math.cos(math.pi * (self.last_epoch - self.warmup + self.iter / self.N_batch) / (self.T_max - self.warmup))) / 2
728 |                    for base_lr in self.base_lrs]
729 |         return lrs
730 | 
731 |     def update(self, epoch, batch=0):
732 |         self.last_epoch = epoch
733 |         self.iter = batch + 1
734 |         lrs = self.get_lr()
735 |         for param_group, lr in zip(self.optimizer.param_groups, lrs):
736 |             param_group['lr'] = lr
737 | 
738 |         return lrs
739 | 
740 | 
741 | def accuracy(output, target, topk=(1,)):
742 |     """Computes the precision@k for the specified values of k"""
743 |     maxk = max(topk)
744 |     batch_size = target.size(0)
745 | 
746 |     _, pred = output.topk(maxk, 1, True, True)
747 |     pred = pred.t()
748 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
749 | 
750 |     res = []
751 |     for k in topk:
752 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape, not view: correct may be non-contiguous
753 |         res.append(correct_k.mul_(100.0 / batch_size))
754 |     return res
755 | 
756 | def rmse(yhat, y):
757 |     if args.fp16:
758 |         res = torch.sqrt(torch.mean((yhat.float() - y.float())**2))
759 |     else:
760 |         res = torch.sqrt(torch.mean((yhat - y)**2))
761 |     return res
762 | 
763 | def reduce_tensor(tensor):
764 |     rt = tensor.clone()
765 |     dist.all_reduce(rt, op=dist.ReduceOp.SUM)
766 |     rt /= args.world_size
767 |     return rt
768 | 
769 | if __name__ == '__main__':
770 |     # to suppress annoying warnings
771 |     import warnings
772 |     warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
773 | 
774 |     main()
775 | 
--------------------------------------------------------------------------------
/tools/smoothing.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5
2 | # commit a1aff31
3 | # Date: 05/01/2019
4 | # Note: check the updates in NVIDIA DeepLearningExamples regularly
5 | import torch
6 | import torch.nn as nn
7 | 
8 | class LabelSmoothing(nn.Module):
9 |     """
10 |     NLL loss with label smoothing.
11 |     """
12 |     def __init__(self, smoothing=0.0):
13 |         """
14 |         Constructor for the LabelSmoothing module.
15 | 
16 |         :param smoothing: label smoothing factor
17 |         """
18 |         super(LabelSmoothing, self).__init__()
19 |         self.confidence = 1.0 - smoothing
20 |         self.smoothing = smoothing
21 | 
22 |     def forward(self, x, target):
23 |         logprobs = torch.nn.functional.log_softmax(x, dim=-1)
24 | 
25 |         nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
26 |         nll_loss = nll_loss.squeeze(1)
27 |         smooth_loss = -logprobs.mean(dim=-1)
28 |         loss = self.confidence * nll_loss + self.smoothing * smooth_loss
29 |         return loss.mean()
30 | 
--------------------------------------------------------------------------------
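A quick sanity check of `LabelSmoothing` (a hedged sketch, assuming `tools/smoothing.py` is on the import path): with `smoothing=0.0` the `confidence` weight is 1.0, so the loss reduces to plain cross-entropy; with `smoothing>0` it blends in the mean negative log-probability over all classes.

```python
import torch
import torch.nn.functional as F
from smoothing import LabelSmoothing  # assumes tools/ is on sys.path

torch.manual_seed(0)
logits = torch.randn(4, 10)
target = torch.randint(0, 10, (4,))

# smoothing=0.0: the result equals the standard cross-entropy loss
assert torch.allclose(LabelSmoothing(0.0)(logits, target),
                      F.cross_entropy(logits, target))

print(LabelSmoothing(0.1)(logits, target))  # 0.9 * NLL + 0.1 * mean(-log p) over classes
```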