├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── aognet-12m-an-imagenet-200e-lbsmth-mixup.yaml
│   ├── aognet-12m-an-imagenet.yaml
│   ├── aognet-12m-imagenet-200e-lbsmth-mixup.yaml
│   ├── aognet-12m-imagenet.yaml
│   ├── aognet-40m-an-imagenet-200e-lbsmth-mixup.yaml
│   ├── aognet-40m-an-imagenet.yaml
│   ├── aognet-40m-imagenet-200e-lbsmth-mixup.yaml
│   ├── aognet-40m-imagenet.yaml
│   ├── mobilenetv2-an-imagenet.yaml
│   ├── resnet-101-an-imagenet-200e-lbsmth-mixup.yaml
│   ├── resnet-101-an-imagenet.yaml
│   ├── resnet-101-imagenet-200e-lbsmth-mixup.yaml
│   ├── resnet-101-imagenet.yaml
│   ├── resnet-50-an-imagenet-200e-lbsmth-mixup.yaml
│   ├── resnet-50-an-imagenet.yaml
│   ├── resnet-50-imagenet-200e-lbsmth-mixup.yaml
│   └── resnet-50-imagenet.yaml
├── images
│   ├── an-comparison.png
│   ├── teaser-imagenet-dissection.png
│   └── teaser.png
├── models
│   ├── __init__.py
│   ├── aognet
│   │   ├── AOG.py
│   │   ├── __init__.py
│   │   ├── aognet.py
│   │   ├── operator_basic.py
│   │   └── operator_singlescale.py
│   ├── config.py
│   ├── mobilenet.py
│   └── resnet.py
├── requirements.txt
├── scripts
│   ├── test_fp16.sh
│   └── train_fp16.sh
└── tools
    ├── __init__.py
    ├── main_fp16.py
    └── smoothing.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | *.egg-info/
23 | .installed.cfg
24 | *.egg
25 |
26 | # PyInstaller
27 | # Usually these files are written by a python script from a template
28 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
29 | *.manifest
30 | *.spec
31 |
32 | # Installer logs
33 | pip-log.txt
34 | pip-delete-this-directory.txt
35 |
36 | # Unit test / coverage reports
37 | htmlcov/
38 | .tox/
39 | .coverage
40 | .coverage.*
41 | .cache
42 | .idea
43 | nosetests.xml
44 | coverage.xml
45 | *,cover
46 | .hypothesis/
47 |
48 | # Translations
49 | *.mo
50 | *.pot
51 |
52 | # Django stuff:
53 | #*.log
54 | temp*.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # IPython Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # dotenv
80 | .env
81 |
82 | # virtualenv
83 | venv/
84 | ENV/
85 |
86 | # Spyder project settings
87 | .spyderproject
88 |
89 | # Rope project settings
90 | .ropeproject
91 |
92 | # vscode
93 | .vscode
94 |
95 | # this project
96 | data/
97 | datasets/
98 | results/
99 | nohup.out
100 | *.tar
101 | *.log
102 | .DS_Store
103 | pretrained_models/
104 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | RESEARCH ONLY LICENSE
2 | Copyright (c) 2018-2019 North Carolina State University.
3 | All rights reserved.
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided
5 | that the following conditions are met:
6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use
7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # AOGNet-v2
3 |
4 | **Please check our refactored code at [iVMCL-Release](https://github.com/iVMCL/iVMCL-Release).**
5 |
6 | This project provides source code for [AOGNets: Compositional Grammatical Architectures for Deep Learning](https://arxiv.org/abs/1711.05847) (CVPR 2019)
7 | and [Attentive Normalization](https://arxiv.org/abs/1908.01259).
8 |
9 |
10 |
11 | ## Installation
12 |
13 | 1. Create a conda environment and install all dependencies.
14 | ```
15 | conda create -n aognet-v2 python=3.7
16 | conda activate aognet-v2
17 | git clone https://github.com/iVMCL/AOGNet-v2
18 | cd AOGNet-v2
19 | pip install -r requirements.txt
20 | ```
21 | 2. Install PyTorch 1.0+ following [https://pytorch.org/](https://pytorch.org/)
22 |
23 | 3. Install apex following [https://github.com/NVIDIA/apex](https://github.com/NVIDIA/apex)
24 |
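For reference, apex is typically built from source. A minimal sketch of the usual commands (check the apex README for the current procedure; the optional CUDA/C++ extension builds need extra flags that change between releases):

```
$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir ./
```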
25 |
26 | ## ImageNet dataset preparation
27 |
28 | - Download the [ImageNet dataset](http://image-net.org/download) to YOUR_IMAGENET_PATH and move the validation images into labeled subfolders.
29 |   - This [script](https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh) may be helpful.
30 |
31 | - Create a `datasets` subfolder under your cloned AOGNet-v2 directory and a symbolic link to the ImageNet dataset:
32 |
33 | ```
34 | $ cd AOGNet-v2
35 | $ mkdir datasets
36 | $ ln -s PATH_TO_YOUR_IMAGENET ./datasets/
37 | ```
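After this step, the loaders are assumed to see the standard PyTorch ImageFolder layout (the synset folder name below is illustrative):

```
datasets/
└── imagenet/
    ├── train/
    │   ├── n01440764/   # one subfolder per class, containing JPEGs
    │   └── ...
    └── val/
        ├── n01440764/
        └── ...
```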
38 |
39 | ## Pretrained models for evaluation
40 |
41 | ```
42 | $ cd AOGNet-v2
43 | $ mkdir pretrained_models
44 | ```
45 | Download the pretrained models from the links provided in the tables below, unzip them, and place them in the pretrained_models directory. Then run:
46 | ```
47 | $ ./scripts/test_fp16.sh
48 | ```
49 |
50 | - Models trained with the standard training setup: 120 epochs, cosine learning-rate scheduling, SGD, ... (MobileNet-v2 is trained for 150 epochs). All models are trained with 8 NVIDIA V100 GPUs.
51 |
52 | | Method | #Params | FLOPS | top-1 error (%) | top-5 error (%)| Link |
53 | |---|---|---|---|---|---|
54 | | ResNet-50-BN | 25.56M | 4.09G | 23.01 | 6.68 | [Google Drive](https://drive.google.com/open?id=1NqGbN5XjEWufq-V3WXz0eXJ2HKDL2wGV) |
55 | | ResNet-50-GN* | 25.56M | 4.09G | 23.52 | 6.85 | - |
56 | | ResNet-50-SN* | 25.56M | - | 22.43 | 6.35 | - |
57 | | ResNet-50-SE* | 28.09M | - | 22.37 | 6.36 | - |
58 | | ResNet-50-AN (w/ BN) | 25.76M | 4.09G | 21.59 | 5.88 | [Google Drive](https://drive.google.com/open?id=17MNrFTBJS__1cW6PNcLS1lJ7590QyA6m) |
59 | | ResNet-101-BN | 44.57M | 8.12G | 21.33 | 5.85 | [Google Drive](https://drive.google.com/open?id=1txeVNkDDKd45dIrJAZyW6Si1UgGvqtXu)|
60 | | ResNet-101-AN (w/ BN) | 45.00M | 8.12G | 20.61 | 5.41 |[Google Drive](https://drive.google.com/open?id=1Cq-D2Gm2QeZW2WqfCrSeqou4AhYqH9yN) |
61 | | MobileNet-v2-BN | 3.50M | 0.34G | 28.69 | 9.33 | [Google Drive]()|
62 | | MobileNet-v2-AN (w/ BN) | 3.56M | 0.34G | 26.67 | 8.56 | [Google Drive](https://drive.google.com/open?id=1pD-fHdzyVW5ufC8FB4R7yPtjN4Z2St4t)|
63 | | AOGNet-12M-BN | 12.26M | 2.19G | 22.22 | 6.06 | [Google Drive](https://drive.google.com/open?id=1PsA2EvEw7wsCGhzp65Lfq3w1pvkU65o0) |
64 | | AOGNet-12M-AN (w/ BN) | 12.37M | 2.19G | 21.28 | 5.76 | [Google Drive](https://drive.google.com/open?id=1t4Oa0vZuakNfR-PWhMIiOWsVTCZ8Z-G-) |
65 | | AOGNet-40M-BN | 40.15M | 7.51G | 19.84 | 4.94 | [Google Drive](https://drive.google.com/open?id=1u-ToLniZVEkBlbSGQL49A3V72cXS5Dwd) |
66 | | AOGNet-40M-AN (w/ BN) | 40.39M | 7.51G | 19.33 | 4.72 | [Google Drive](https://drive.google.com/open?id=1LWvdhjxQ259_Gq-YMNT70fAraDGsIWKo) |
67 |
68 | \* Results from the original papers
69 |
70 | - Models trained with the advanced training setup: 200 epochs, label smoothing rate 0.1, mixup rate 0.2
71 |
72 | | Method | #Params | FLOPS | top-1 error (%) | top-5 error (%)| Link |
73 | |---|---|---|---|---|---|
74 | | ResNet-50-BN | 25.56M | 4.09G | 21.08 | 5.56 | [Google Drive](https://drive.google.com/open?id=1SoE0U9W5ghpEhmCCYqClq9Io1WkZts_C) |
75 | | ResNet-50-AN (w/ BN) | 25.76M | 4.09G | 19.92 | 5.04 | [Google Drive](https://drive.google.com/open?id=1qWSN-95Blq-MBCFCzh1DpmmB7kk9VHbU) |
76 | | ResNet-101-BN | 44.57M | 8.12G | 19.71 | 4.89 | [Google Drive](https://drive.google.com/open?id=1oqPQG7Oc0REvrLABAjsrmOpP4GdGgXEG)|
77 | | ResNet-101-AN (w/ BN) | 45.00M | 8.12G | 18.85 | 4.63 |[Google Drive](https://drive.google.com/open?id=1habUSoSotE8-fEq60IhnR2CjYM14wssv) |
78 | | AOGNet-12M-BN | 12.26M | 2.19G | 21.63 | 5.60 | [Google Drive](https://drive.google.com/open?id=14gZ18L4mqHWI79P3d-gY01jqqGVsA9hB) |
79 | | AOGNet-12M-AN (w/ BN) | 12.37M | 2.19G | 20.57 | 5.38 | [Google Drive](https://drive.google.com/open?id=1GY6TbanFrXSBcFD-kjmBixf4wG2Vp1sJ) |
80 | | AOGNet-40M-BN | 40.15M | 7.51G | 18.70 | 4.47 | [Google Drive](https://drive.google.com/open?id=1MugHc_9rn5wR7d1Hyx4oaGjc64l8v_aG) |
81 | | AOGNet-40M-AN (w/ BN) | 40.39M | 7.51G | 18.13 | 4.26 | [Google Drive](https://drive.google.com/open?id=1w5W12mDgni0DPCANuvPoIbmKKPND-Kdh) |
82 |
83 | - Remarks: The accuracy of the released pretrained models may differ slightly from that recorded during training (and reported in our papers). The reason is still unclear.
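For readers unfamiliar with the two regularizers used in the advanced setup, below is a minimal PyTorch sketch of mixup combined with label smoothing. It is an illustration only, not the repository's implementation (see tools/smoothing.py and tools/main_fp16.py, which presumably hold the actual version); it assumes `mixup_rate` is used as the Beta(α, α) parameter and `labelsmoothing_rate` as the smoothing mass.

```
import numpy as np
import torch
import torch.nn.functional as F

def mixup_data(x, y, alpha=0.2):
    # Blend each sample with a randomly permuted partner from the same batch.
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1.0 - lam) * x[index], y, y[index], lam

def smoothed_cross_entropy(logits, target, smoothing=0.1):
    # (1 - smoothing) mass on the true class, smoothing spread uniformly over classes.
    log_probs = F.log_softmax(logits, dim=-1)
    nll = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
    uniform = -log_probs.mean(dim=-1)
    return ((1.0 - smoothing) * nll + smoothing * uniform).mean()

# One training step:
#   mixed, y_a, y_b, lam = mixup_data(images, labels, alpha=0.2)
#   logits = model(mixed)
#   loss = lam * smoothed_cross_entropy(logits, y_a) + (1 - lam) * smoothed_cross_entropy(logits, y_b)
```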
84 |
85 |
86 | ## Perform training on the ImageNet dataset
87 |
88 | ```
89 | $ cd AOGNet-v2
90 | $ ./scripts/train_fp16.sh
91 | ```
92 |
93 | E.g., to train AOGNet-12M with AN:
94 | ```
95 | $ ./scripts/train_fp16.sh configs/aognet-12m-an-imagenet.yaml
96 | ```
97 |
98 | See more configuration files in [configs](https://github.com/iVMCL/AOGNet-v2/tree/master/configs). Change the GPU settings in scripts/train_fp16.sh as needed.
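For example, assuming the script honors CUDA_VISIBLE_DEVICES (otherwise edit the GPU list inside scripts/train_fp16.sh directly), a run restricted to four GPUs would look like:

```
$ CUDA_VISIBLE_DEVICES=0,1,2,3 ./scripts/train_fp16.sh configs/aognet-12m-an-imagenet.yaml
```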
99 |
100 |
101 |
102 | ## Object Detection and Instance Segmentation on COCO
103 | We performed object detection and instance segmentation tasks on COCO with our models pretrained on ImageNet. The implementation is based on the [mmdetection](https://github.com/open-mmlab/mmdetection) framework, and the code is released at [https://github.com/iVMCL/AttentiveNorm_Detection/tree/master/configs/attn_norm](https://github.com/iVMCL/AttentiveNorm_Detection/tree/master/configs/attn_norm).
104 |
105 | ## Citations
106 | Please consider citing the AOGNets or Attentive Normalization papers in your publications if they help your research.
107 | ```
108 | @inproceedings{li2019aognets,
109 | title={AOGNets: Compositional Grammatical Architectures for Deep Learning},
110 | author={Li, Xilai and Song, Xi and Wu, Tianfu},
111 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
112 | pages={6220--6230},
113 | year={2019}
114 | }
115 |
116 | @article{li2019attentive,
117 | title={Attentive Normalization},
118 | author={Li, Xilai and Sun, Wei and Wu, Tianfu},
119 | journal={arXiv preprint arXiv:1908.01259},
120 | year={2019}
121 | }
122 | ```
123 |
124 | ## Contact
125 |
126 | Please feel free to report issues and any related problems to Xilai Li (xli47 at ncsu dot edu) and Tianfu Wu (twu19 at ncsu dot edu).
127 |
128 | ## License
129 |
130 | AOGNet-related code is released under the [RESEARCH ONLY LICENSE](./LICENSE).
131 |
--------------------------------------------------------------------------------
/configs/aognet-12m-an-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 2, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
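As an illustration of how a config file like the one above can be consumed, here is a minimal sketch using PyYAML (hypothetical; the repository's actual config handling presumably lives in models/config.py):

```
import yaml

# Load the training configuration into a plain nested dict.
with open('configs/aognet-12m-an-imagenet-200e-lbsmth-mixup.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['num_epoch'])              # 200
print(cfg['dataaug']['mixup_rate'])  # 0.2
print(cfg['aognet']['filter_list'])  # [32, 128, 256, 512, 824]
```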
/configs/aognet-12m-an-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 2, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/aognet-12m-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | #norm_name: 'MixtureBatchNorm2d'
20 | norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 2, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/aognet-12m-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 256
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 3 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 0
21 | norm_k: [10, 10, 20, 20] ### per stage
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: False
24 | dataaug:
25 | imagenet_extra_aug: False ### ColorJitter and PCA
26 | labelsmoothing_rate: 0.0
27 | mixup_rate: 0.0
28 | stem:
29 | imagenet_head7x7: False
30 | replace_maxpool_with_res_bottleneck: False
31 | aognet:
32 | max_split: [2, 2, 2, 2]
33 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
34 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
35 | terminal_node_no_slice: [0, 0, 0, 0]
36 | stride: [1, 2, 2, 2]
37 | drop_rate: [0.0, 0.0, 0.1, 0.1]
38 | bottleneck_ratio: 0.25
39 | handle_dbl_cnt: True
40 | handle_tnode_dbl_cnt: False
41 | handle_dbl_cnt_in_param_init: False
42 | use_group_conv: False
43 | width_per_group: 0
44 | when_downsample: 1
45 | replace_stride_with_avgpool: True ### in shortcut
46 | use_elem_max_for_ORNodes: False
47 |   filter_list: [32, 128, 256, 512, 824] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
48 | out_channels: [0, 0]
49 | blocks: [2, 2, 2, 1]
50 | dims: [2, 2, 4, 4]
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/configs/aognet-40m-an-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 3, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/aognet-40m-an-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'MixtureBatchNorm2d'
20 | #norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 3, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/aognet-40m-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 200
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: False
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | #norm_name: 'MixtureBatchNorm2d'
20 | norm_name: 'BatchNorm2d'
21 | norm_groups: 0
22 | norm_k: [10, 10, 20, 20] ### per stage
23 | norm_attention_mode: 2
24 | norm_zero_gamma_init: False
25 | dataaug:
26 | imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | replace_maxpool_with_res_bottleneck: False
32 | aognet:
33 | max_split: [2, 2, 2, 2]
34 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
35 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
36 | terminal_node_no_slice: [0, 0, 0, 0]
37 | stride: [1, 2, 2, 2]
38 | drop_rate: [0.0, 0.0, 0.1, 0.1]
39 | bottleneck_ratio: 0.25
40 | handle_dbl_cnt: True
41 | handle_tnode_dbl_cnt: False
42 | handle_dbl_cnt_in_param_init: False
43 | use_group_conv: False
44 | width_per_group: 0
45 | when_downsample: 1
46 | replace_stride_with_avgpool: True ### in shortcut
47 | use_elem_max_for_ORNodes: False
48 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
49 | out_channels: [0, 0]
50 | blocks: [2, 2, 3, 1]
51 | dims: [2, 2, 4, 4]
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/configs/aognet-40m-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | batch_size: 128
3 | num_epoch: 120
4 | dataset: 'imagenet'
5 | num_classes: 1000
6 | crop_size: 224
7 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
8 | use_cosine_lr: True ###
9 | cosine_lr_min: 0.0
10 | warmup_epochs: 5
11 | lr: 0.1
12 | lr_scale_factor: 256
13 | lr_milestones: [30, 60, 90, 100]
14 | momentum: 0.9
15 | wd: 0.0001
16 | nesterov: True
17 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
18 | init_mode: 'kaiming'
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 0
21 | norm_k: [10, 10, 20, 20] ### per stage
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: False
24 | dataaug:
25 | imagenet_extra_aug: False ### ColorJitter and PCA
26 | labelsmoothing_rate: 0.0
27 | mixup_rate: 0.0
28 | stem:
29 | imagenet_head7x7: False
30 | replace_maxpool_with_res_bottleneck: False
31 | aognet:
32 | max_split: [2, 2, 2, 2]
33 | extra_node_hierarchy: [4, 4, 4, 4] ### 0: none, 4: lateral connection
34 |   remove_symmetric_children_of_or_node: [1, 2, 1, 2] ### 0: keep all children, 1: keep left child, 2: keep right child; removing symmetric children simplifies the AOG so bigger filters and more units can be used
35 | terminal_node_no_slice: [0, 0, 0, 0]
36 | stride: [1, 2, 2, 2]
37 | drop_rate: [0.0, 0.0, 0.1, 0.1]
38 | bottleneck_ratio: 0.25
39 | handle_dbl_cnt: True
40 | handle_tnode_dbl_cnt: False
41 | handle_dbl_cnt_in_param_init: False
42 | use_group_conv: False
43 | width_per_group: 0
44 | when_downsample: 1
45 | replace_stride_with_avgpool: False ### in shortcut
46 | use_elem_max_for_ORNodes: False
47 |   filter_list: [56, 224, 448, 896, 1400] ### try to keep 1:4:2:2 ... except for the final stage, which can be adjusted to fit the model size
48 | out_channels: [0, 0]
49 | blocks: [2, 2, 3, 1]
50 | dims: [2, 2, 4, 4]
51 |
52 |
53 |
54 |
--------------------------------------------------------------------------------
/configs/mobilenetv2-an-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'mobilenet_v2'
3 | batch_size: 128
4 | num_epoch: 150
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.05
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.00004
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | #norm_name: 'BatchNorm2d'
20 | norm_name: 'MixtureBatchNorm2d'
21 | norm_groups: 32
22 | norm_k: [5, 5, 5, 5, 10, 10, 10]
23 | norm_attention_mode: 2
24 | dataaug:
25 |   imagenet_extra_aug: False ### ColorJitter and PCA
26 | labelsmoothing_rate: 0.0
27 | mixup_rate: 0.0
28 | mobilenet:
29 | rm_se: False
30 | use_mn_in_se: False
31 |
--------------------------------------------------------------------------------
/configs/resnet-101-an-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 200
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: True
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-101-an-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 120
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: True
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: False
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-101-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 200
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: True
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-101-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet101'
3 | batch_size: 128
4 | num_epoch: 120
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: True
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: False
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-50-an-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet50'
3 | batch_size: 128
4 | num_epoch: 200
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: True
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-50-an-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet50'
3 | batch_size: 128
4 | num_epoch: 120
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'MixtureBatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: True
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: False
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-50-imagenet-200e-lbsmth-mixup.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet50'
3 | batch_size: 128
4 | num_epoch: 200
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.1
28 | mixup_rate: 0.2
29 | stem:
30 | imagenet_head7x7: False
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: True
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/configs/resnet-50-imagenet.yaml:
--------------------------------------------------------------------------------
1 | ---
2 | arch: 'resnet50'
3 | batch_size: 128
4 | num_epoch: 120
5 | dataset: 'imagenet'
6 | num_classes: 1000
7 | crop_size: 224
8 | crop_interpolation: 2 ### 2: BILINEAR, 3:CUBIC
9 | use_cosine_lr: True
10 | cosine_lr_min: 0.0
11 | warmup_epochs: 5
12 | lr: 0.1
13 | lr_scale_factor: 256
14 | lr_milestones: [30, 60, 90, 100]
15 | momentum: 0.9
16 | wd: 0.0001
17 | nesterov: False
18 | activation_mode: 0 ### 1: leakyReLU, 2: ReLU6, other: ReLU
19 | norm_name: 'BatchNorm2d'
20 | norm_groups: 32
21 | norm_k: [10, 10, 20, 20]
22 | norm_attention_mode: 2
23 | norm_zero_gamma_init: True
24 | norm_all_mix: False
25 | dataaug:
26 |   imagenet_extra_aug: False ### ColorJitter and PCA
27 | labelsmoothing_rate: 0.0
28 | mixup_rate: 0.0
29 | stem:
30 | imagenet_head7x7: True
31 | stem_kernel_size: 7
32 | stem_stride: 2
33 | replace_maxpool_with_res_bottleneck: False
34 | resnet:
35 | base_inplanes: 64
36 | replace_stride_with_dilation: [False, False, False]
37 | replace_stride_with_avgpool: False
38 | extra_norm_ac: False
39 |
40 |
41 |
42 |
43 |
--------------------------------------------------------------------------------
/images/an-comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/an-comparison.png
--------------------------------------------------------------------------------
/images/teaser-imagenet-dissection.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/teaser-imagenet-dissection.png
--------------------------------------------------------------------------------
/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/images/teaser.png
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/models/__init__.py
--------------------------------------------------------------------------------
/models/aognet/AOG.py:
--------------------------------------------------------------------------------
1 | """ RESEARCH ONLY LICENSE
2 | Copyright (c) 2018-2019 North Carolina State University.
3 | All rights reserved.
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided
5 | that the following conditions are met:
6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use
7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 | """
34 | # The system is protected via patent (pending)
35 | # Written by Tianfu Wu and Xi Song
36 | # Contact: tianfu_wu@ncsu.edu, xsong.lhi@gmail.com
37 |
38 | # -*- coding: utf-8 -*-
39 | from __future__ import absolute_import
40 | from __future__ import division
41 | from __future__ import print_function # force to use print as function print(args)
42 | from __future__ import unicode_literals
43 |
44 | from math import ceil, floor
45 | from collections import deque
46 | import numpy as np
47 | import os
48 | import random
49 | import math
50 | import copy
51 |
52 |
53 | def get_aog(grid_ht, grid_wd, min_size=1, max_split=2,
54 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=1./4.,
55 | not_use_intermediate_TerminalNode=False,
56 | use_root_TerminalNode=True, use_tnode_as_alpha_channel=0,
57 | use_tnode_topdown_connection=False,
58 | use_tnode_bottomup_connection=False,
59 | use_tnode_bottomup_connection_layerwise=False,
60 | use_tnode_bottomup_connection_sequential=False,
61 |             use_node_lateral_connection=False, # does not include T-nodes
62 |             use_node_lateral_connection_1=False, # includes T-nodes
63 | use_super_OrNode=False,
64 | remove_single_child_or_node=False,
65 | remove_symmetric_children_of_or_node=0,
66 |             mark_symmetric_syntatic_subgraph=False,
67 | max_children_kept_for_or=1000):
68 | aog_param = Param(grid_ht=grid_ht, grid_wd=grid_wd, min_size=min_size, max_split=max_split,
69 | not_use_large_TerminalNode=not_use_large_TerminalNode,
70 | turn_off_size_ratio_TerminalNode=turn_off_size_ratio_TerminalNode,
71 | not_use_intermediate_TerminalNode=not_use_intermediate_TerminalNode,
72 | use_root_TerminalNode=use_root_TerminalNode,
73 | use_tnode_as_alpha_channel=use_tnode_as_alpha_channel,
74 | use_tnode_topdown_connection=use_tnode_topdown_connection,
75 | use_tnode_bottomup_connection=use_tnode_bottomup_connection,
76 | use_tnode_bottomup_connection_layerwise=use_tnode_bottomup_connection_layerwise,
77 | use_tnode_bottomup_connection_sequential=use_tnode_bottomup_connection_sequential,
78 | use_node_lateral_connection=use_node_lateral_connection,
79 | use_node_lateral_connection_1=use_node_lateral_connection_1,
80 | use_super_OrNode=use_super_OrNode,
81 | remove_single_child_or_node=remove_single_child_or_node,
82 |                       mark_symmetric_syntatic_subgraph=mark_symmetric_syntatic_subgraph,
83 | remove_symmetric_children_of_or_node=remove_symmetric_children_of_or_node,
84 | max_children_kept_for_or=max_children_kept_for_or)
85 | aog = AOGrid(aog_param)
86 | aog.Create()
87 | return aog
88 |
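# Illustrative usage (a minimal sketch, not part of the original file): building
# a 4x4 AOG like those used by the ImageNet configs (dims: [2, 2, 4, 4]).
# Mapping 'extra_node_hierarchy: 4' (lateral connection) onto
# use_node_lateral_connection_1 below is an assumption for illustration only.
#
#   aog = get_aog(grid_ht=4, grid_wd=4, max_split=2,
#                 use_node_lateral_connection_1=True)
#   print(len(aog.node_set), aog.num_TNodes, aog.num_AndNodes, aog.num_OrNodes)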
89 |
90 | class NodeType(object):
91 | OrNode = "OrNode"
92 | AndNode = "AndNode"
93 | TerminalNode = "TerminalNode"
94 |     Unknown = Unknow = "Unknown"  # 'Unknow' kept as an alias for the original misspelling
95 |
96 |
97 | class SplitType(object):
98 | HorSplit = "Hor"
99 | VerSplit = "Ver"
100 | Unknown = "Unknown"
101 |
102 |
103 | class Param(object):
104 | """Input parameters for creating an AOG
105 | """
106 |
107 | def __init__(self, grid_ht=3, grid_wd=3, max_split=2, min_size=1, control_side_length=False,
108 | overlap_ratio=0., use_root_TerminalNode=False,
109 | not_use_large_TerminalNode=False, turn_off_size_ratio_TerminalNode=0.5,
110 |                  not_use_intermediate_TerminalNode=False,
111 | use_tnode_as_alpha_channel=0,
112 | use_tnode_topdown_connection=False,
113 | use_tnode_bottomup_connection=False,
114 | use_tnode_bottomup_connection_layerwise=False,
115 | use_tnode_bottomup_connection_sequential=False,
116 | use_node_lateral_connection=False,
117 | use_node_lateral_connection_1=False,
118 | use_super_OrNode=False,
119 | remove_single_child_or_node=False,
120 | remove_symmetric_children_of_or_node=0,
121 | mark_symmetric_syntatic_subgraph=False,
122 | max_children_kept_for_or=100):
123 | self.grid_ht = grid_ht
124 | self.grid_wd = grid_wd
125 | self.max_split = max_split # maximum number of child nodes when splitting an AND-node
126 | self.min_size = min_size # minimum side length or minimum area allowed
127 | self.control_side_length = control_side_length
128 | self.overlap_ratio = overlap_ratio
129 | self.use_root_terminal_node = use_root_TerminalNode
130 | self.not_use_large_terminal_node = not_use_large_TerminalNode
131 | self.turn_off_size_ratio_terminal_node = turn_off_size_ratio_TerminalNode
132 | self.not_use_intermediate_TerminalNode = not_use_intermediate_TerminalNode
133 | self.use_tnode_as_alpha_channel = use_tnode_as_alpha_channel
134 | self.use_tnode_topdown_connection = use_tnode_topdown_connection
135 | self.use_tnode_bottomup_connection = use_tnode_bottomup_connection
136 | self.use_tnode_bottomup_connection_layerwise = use_tnode_bottomup_connection_layerwise
137 | self.use_node_lateral_connection = use_node_lateral_connection
138 | self.use_node_lateral_connection_1 = use_node_lateral_connection_1
139 | self.use_tnode_bottomup_connection_sequential = use_tnode_bottomup_connection_sequential
140 | assert 1 >= self.use_node_lateral_connection_1 + self.use_node_lateral_connection + \
141 | self.use_tnode_topdown_connection + self.use_tnode_bottomup_connection + \
142 | self.use_tnode_bottomup_connection_layerwise + self.use_tnode_bottomup_connection_sequential, \
143 | 'only one type of node hierarchy can be used'
144 | self.use_super_OrNode = use_super_OrNode
145 | self.remove_single_child_or_node = remove_single_child_or_node
146 |         self.remove_symmetric_children_of_or_node = remove_symmetric_children_of_or_node # 0: keep all, 1: keep left, 2: keep right
147 |         self.mark_symmetric_syntatic_subgraph = mark_symmetric_syntatic_subgraph # if true, only mark the nodes that would be removed based on remove_symmetric_children_of_or_node
148 | self.max_children_kept_for_or = max_children_kept_for_or # how many child nodes kept for an OR-node
149 |
150 | self.get_tag()
151 |
152 | def get_tag(self):
153 | # identifier useful for naming a particular aog
154 | self.tag = '{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}-{}'.format(
155 | self.grid_wd, self.grid_ht, self.max_split,
156 | self.min_size, self.control_side_length,
157 | self.overlap_ratio,
158 | self.use_root_terminal_node,
159 | self.not_use_large_terminal_node,
160 | self.turn_off_size_ratio_terminal_node,
161 | self.not_use_intermediate_TerminalNode,
162 | self.use_tnode_as_alpha_channel,
163 | self.use_tnode_topdown_connection,
164 | self.use_tnode_bottomup_connection,
165 | self.use_tnode_bottomup_connection_layerwise,
166 | self.use_tnode_bottomup_connection_sequential,
167 | self.use_node_lateral_connection,
168 | self.use_node_lateral_connection_1,
169 | self.use_super_OrNode,
170 | self.remove_single_child_or_node,
171 | self.remove_symmetric_children_of_or_node,
172 | self.mark_symmetric_syntatic_subgraph,
173 | self.max_children_kept_for_or)
174 |
175 |
176 | class Rect(object):
177 | """A simple rectangle
178 | """
179 |
180 | def __init__(self, x1=0, y1=0, x2=0, y2=0):
181 | self.x1 = x1
182 | self.y1 = y1
183 | self.x2 = x2
184 | self.y2 = y2
185 |
186 | def __eq__(self, other):
187 | """Override the default Equals behavior"""
188 | if isinstance(other, self.__class__):
189 | return self.__dict__ == other.__dict__
190 | return NotImplemented
191 |
192 | def __ne__(self, other):
193 | """Define a non-equality test"""
194 | if isinstance(other, self.__class__):
195 | return not self.__eq__(other)
196 | return NotImplemented
197 |
198 | def __hash__(self):
199 | """Override the default hash behavior (that returns id or the object)"""
200 | return hash(tuple(sorted(self.__dict__.items())))
201 |
202 | def Width(self):
203 | return self.x2 - self.x1 + 1
204 |
205 | def Height(self):
206 | return self.y2 - self.y1 + 1
207 |
208 | def Area(self):
209 | return self.Width() * self.Height()
210 |
211 | def MinLength(self):
212 | return min(self.Width(), self.Height())
213 |
214 | def IsOverlap(self, other):
215 | assert isinstance(other, self.__class__)
216 |
217 | x1 = max(self.x1, other.x1)
218 | x2 = min(self.x2, other.x2)
219 | if x1 > x2:
220 | return False
221 |
222 | y1 = max(self.y1, other.y1)
223 | y2 = min(self.y2, other.y2)
224 | if y1 > y2:
225 | return False
226 |
227 | return True
228 |
229 | def IsSame(self, other):
230 | assert isinstance(other, self.__class__)
231 |
232 | return self.Width() == other.Width() and self.Height() == other.Height()
233 |
234 |
235 | class Node(object):
236 | """Types of nodes in an AOG
237 | AND-node (structural decomposition),
238 | OR-node (alternative decompositions),
239 | TERMINAL-node (link to data).
240 | """
241 |
242 |     def __init__(self, node_id=-1, node_type=NodeType.Unknown, rect_idx=-1,
243 | child_ids=None, parent_ids=None,
244 | split_type=SplitType.Unknown, split_step1=0, split_step2=0, is_symm=False,
245 | ancestors_ids=None):
246 | self.id = node_id
247 | self.node_type = node_type
248 | self.rect_idx = rect_idx
249 | self.child_ids = child_ids if child_ids is not None else []
250 | self.parent_ids = parent_ids if parent_ids is not None else []
251 |         self.ancestors_ids = ancestors_ids if ancestors_ids is not None else [] # root or-node exclusive
252 | self.split_type = split_type
253 | self.split_step1 = split_step1
254 | self.split_step2 = split_step2
255 |
256 | # some utility variables used in object detection models
257 | self.on_off = True
258 | self.out_edge_visited_count = []
259 | self.which_classes_visited = {} # key=class_name, val=frequency
260 | self.npaths = 0.0
261 | self.is_symmetric = False
262 | self.has_dbl_counting = False
263 |
264 | def __eq__(self, other):
265 | if isinstance(other, self.__class__):
266 | res = ((self.node_type == other.node_type) and (self.rect_idx == other.rect_idx))
267 | if res:
268 | if self.node_type != NodeType.AndNode:
269 | return True
270 | else:
271 | if self.split_type != SplitType.Unknown:
272 | return (self.split_type == other.split_type) and (self.split_step1 == other.split_step1) and \
273 | (self.split_step2 == other.split_step2)
274 | else:
275 | return (set(self.child_ids) == set(other.child_ids))
276 |
277 | return False
278 |
279 | return NotImplemented
280 |
281 | def __ne__(self, other):
282 | if isinstance(other, self.__class__):
283 | return not self.__eq__(other)
284 | return NotImplemented
285 |
286 | def __hash__(self):
287 | """Override the default hash behavior (that returns id or the object)"""
288 | return hash(tuple(sorted(self.__dict__.items())))
289 |
290 |
291 | class AOGrid(object):
292 | """The AOGrid defines a Directed Acyclic And-Or Graph
293 | which is used to explore/unfold the space of latent structures
294 | of a grid (e.g., a 7 * 7 grid for a 100 * 200 lattice)
295 | """
296 |
297 | def __init__(self, param_):
298 | assert isinstance(param_, Param)
299 | self.param = param_
300 | assert self.param.max_split > 1
301 | self.primitive_set = []
302 | self.node_set = []
303 | self.num_TNodes = 0
304 | self.num_AndNodes = 0
305 | self.num_OrNodes = 0
306 | self.DFS = []
307 | self.BFS = []
308 | self.node_DFS = {}
309 | self.node_BFS = {}
310 | self.OrNodeIdxInBFS = {}
311 | self.TNodeIdxInBFS = {}
312 |
313 | # for color consistency in viz
314 | self.TNodeColors = {}
315 |
316 | def _AddPrimitve(self, rect):
317 | assert isinstance(rect, Rect)
318 |
319 | if rect in self.primitive_set:
320 | return self.primitive_set.index(rect)
321 |
322 | self.primitive_set.append(rect)
323 |
324 | return len(self.primitive_set) - 1
325 |
326 | def _AddNode(self, node, not_create_if_existed=True):
327 | assert isinstance(node, Node)
328 |
329 | if node in self.node_set and not_create_if_existed:
330 | node = self.node_set[self.node_set.index(node)]
331 | return False, node
332 |
333 | node.id = len(self.node_set)
334 | if node.node_type == NodeType.AndNode:
335 | self.num_AndNodes += 1
336 | elif node.node_type == NodeType.OrNode:
337 | self.num_OrNodes += 1
338 | elif node.node_type == NodeType.TerminalNode:
339 | self.num_TNodes += 1
340 | else:
341 | raise NotImplementedError
342 |
343 | self.node_set.append(node)
344 |
345 | return True, node
346 |
347 | def _DoSplit(self, rect):
348 | assert isinstance(rect, Rect)
349 |
350 | if self.param.control_side_length:
351 | return rect.Width() >= self.param.min_size and rect.Height() >= self.param.min_size
352 |
353 | return rect.Area() > self.param.min_size
354 |
355 | def _SplitStep(self, sz):
356 | if self.param.control_side_length:
357 | return self.param.min_size
358 |
359 | if sz >= self.param.min_size:
360 | return 1
361 | else:
362 | return int(ceil(self.param.min_size / sz))
363 |
364 | def _DFS(self, id, q, visited):
365 | if visited[id] == 1:
366 |             raise RuntimeError('cycle detected: the AOG must be a DAG')
367 |
368 | visited[id] = 1
369 | for i in self.node_set[id].child_ids:
370 | if visited[i] < 2:
371 | q, visited = self._DFS(i, q, visited)
372 |
373 | if self.node_set[id].on_off:
374 | q.append(id)
375 |
376 | visited[id] = 2
377 |
378 | return q, visited
379 |
380 | def _BFS(self, id, q, visited):
381 | q = [id]
382 |         visited[id] = 0  # first pass: count the indegree of each node reachable from id
383 | i = 0
384 | while i < len(q):
385 | node = self.node_set[q[i]]
386 | for j in node.child_ids:
387 | visited[j] += 1
388 | if j not in q:
389 | q.append(j)
390 |
391 | i += 1
392 |
393 |         q = [id]  # second pass (Kahn's algorithm): emit nodes whose indegree drops to zero
394 | i = 0
395 | while i < len(q):
396 | node = self.node_set[q[i]]
397 | for j in node.child_ids:
398 | visited[j] -= 1
399 | if visited[j] == 0:
400 | q.append(j)
401 | i += 1
402 |
403 | return q, visited
404 |
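    # _countPaths memoizes in `npaths` the number of distinct s->t paths in
    # the DAG, skipping the lateral edges between same-sized And-nodes and the
    # Or-to-Or / Or-to-larger-terminal shortcuts; the resulting per-node
    # counts (`node.npaths`) drive the double-counting compensation used in
    # the network code.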
405 | def _countPaths(self, s, t, npaths):
406 | if s.id == t.id:
407 | return 1.0
408 | else:
409 | if not npaths[s.id]:
410 | rect = self.primitive_set[s.rect_idx]
411 | #ids1 = set(s.ancestors_ids)
412 | ids1 = set(s.parent_ids)
413 | num_shared = 0
414 | for c in s.child_ids:
415 | ch = self.node_set[c]
416 | ch_rect = self.primitive_set[ch.rect_idx]
417 | #ids2 = ch.ancestors_ids
418 | ids2 = ch.parent_ids
419 |
420 | if s.node_type == NodeType.AndNode and ch.node_type == NodeType.AndNode and \
421 | rect.Width() == ch_rect.Width() and rect.Height() == ch_rect.Height():
422 | continue
423 | if s.node_type == NodeType.OrNode and \
424 | ((ch.node_type == NodeType.OrNode) or \
425 | (ch.node_type == NodeType.TerminalNode and (rect.Area() < ch_rect.Area()) )):
426 | continue
427 |
428 | npaths[s.id] += self._countPaths(ch, t, npaths)
429 | return npaths[s.id]
430 |
431 | def _AssignParentIds(self):
432 | for i in range(len(self.node_set)):
433 | self.node_set[i].parent_ids = []
434 |
435 | for node in self.node_set:
436 | for i in node.child_ids:
437 | self.node_set[i].parent_ids.append(node.id)
438 |
439 | for i in range(len(self.node_set)):
440 | self.node_set[i].parent_ids = list(set(self.node_set[i].parent_ids))
441 |
442 | def _AssignAncestorsIds(self):
443 | self._AssignParentIds()
444 |
445 | assert len(self.BFS) == len(self.node_set)
446 | self.node_set[self.BFS[0]].ancestors_ids = []
447 |
448 | for nid in self.BFS[1:]:
449 | node = self.node_set[nid]
450 | rect = self.primitive_set[node.rect_idx]
451 | ancestors = []
452 | for pid in node.parent_ids:
453 | p = self.node_set[pid]
454 | p_rect = self.primitive_set[p.rect_idx]
455 | equal_size = rect.Width() == p_rect.Width() and \
456 | rect.Height() == p_rect.Height()
457 | # AND-to-AND lateral path
458 | if node.node_type == NodeType.AndNode and p.node_type == NodeType.AndNode and \
459 | equal_size:
460 | continue
461 | # OR-to-OR/T lateral path
462 | if node.node_type == NodeType.OrNode and \
463 | ((p.node_type == NodeType.OrNode and equal_size) or \
464 | (p.node_type == NodeType.TerminalNode and (rect.Area() < p_rect.Area()) )):
465 | continue
466 | for ppid in p.ancestors_ids:
467 | if ppid != self.BFS[0] and ppid not in ancestors:
468 | ancestors.append(ppid)
469 | ancestors.append(pid)
470 | self.node_set[nid].ancestors_ids = list(set(ancestors))
471 |
472 | def _Postprocessing(self, root_id):
473 | self.DFS = []
474 | self.BFS = []
475 | visited = np.zeros(len(self.node_set))
476 | self.DFS, _ = self._DFS(root_id, self.DFS, visited)
477 | visited = np.zeros(len(self.node_set))
478 | self.BFS, _ = self._BFS(root_id, self.BFS, visited)
479 | self._AssignAncestorsIds()
480 |
481 | def _FindNodeIdWithGivenRect(self, rect, node_type):
482 | for node in self.node_set:
483 | if node.node_type != node_type:
484 | continue
485 | if rect == self.primitive_set[node.rect_idx]:
486 | return node.id
487 |
488 | return -1
489 |
490 | def _add_tnode_topdown_connection(self):
491 |
492 | assert self.param.use_root_terminal_node
493 |
494 | prim_type = [self.param.grid_ht, self.param.grid_wd]
495 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type)
496 | assert len(tnode_queue) == 1
497 |
498 | i = 0
499 | while i < len(tnode_queue):
500 | id_ = tnode_queue[i]
501 | node = self.node_set[id_]
502 | i += 1
503 |
504 | rect = self.primitive_set[node.rect_idx]
505 |
506 | ids = []
507 | for y in range(0, rect.Height()):
508 | for x in range(0, rect.Width()):
509 | if y == 0 and x == 0:
510 | continue
511 | prim_type = [rect.Height()-y, rect.Width()-x]
512 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect)
513 |
514 | ids = list(set(ids))
515 | tnode_queue += ids
516 |
517 | for pid in ids:
518 | if id_ not in self.node_set[pid].child_ids:
519 | self.node_set[pid].child_ids.append(id_)
520 |
521 | def _add_onode_topdown_connection(self):
522 | assert self.param.use_root_terminal_node
523 |
524 | prim_type = [self.param.grid_ht, self.param.grid_wd]
525 | tnode_queue = self.find_node_ids_with_given_prim_type(prim_type)
526 | assert len(tnode_queue) == 1
527 |
528 | i = 0
529 | while i < len(tnode_queue):
530 | id_ = tnode_queue[i]
531 | node = self.node_set[id_]
532 | i += 1
533 |
534 | rect = self.primitive_set[node.rect_idx]
535 |
536 | ids = []
537 | ids_t = []
538 | for y in range(0, rect.Height()):
539 | for x in range(0, rect.Width()):
540 | if y == 0 and x == 0:
541 | continue
542 | prim_type = [rect.Height()-y, rect.Width()-x]
543 | ids += self.find_node_ids_with_given_prim_type(prim_type, rect, NodeType.OrNode)
544 | ids_t += self.find_node_ids_with_given_prim_type(prim_type, rect)
545 |
546 | ids = list(set(ids))
547 | ids_t = list(set(ids_t))
548 |
549 | for pid in ids:
550 | if id_ not in self.node_set[pid].child_ids:
551 | self.node_set[pid].child_ids.append(id_)
552 |
553 | def _add_tnode_bottomup_connection(self):
554 | assert self.param.use_root_terminal_node
555 |
556 | # primitive tnodes
557 | prim_type = [1, 1]
558 | t_ids = self.find_node_ids_with_given_prim_type(prim_type)
559 | assert len(t_ids) == self.param.grid_wd * self.param.grid_ht
560 |
561 | # other tnodes will be converted to and-nodes
562 | for h in range(1, self.param.grid_ht+1):
563 | for w in range(1, self.param.grid_wd+1):
564 | if h == 1 and w == 1:
565 | continue
566 | prim_type = [h, w]
567 | ids = self.find_node_ids_with_given_prim_type(prim_type)
568 | for id_ in ids:
569 | self.node_set[id_].node_type = NodeType.AndNode
570 | node = self.node_set[id_]
571 | rect = self.primitive_set[node.rect_idx]
572 | prim_type = [1, 1]
573 | for y in range(rect.y1, rect.y2+1):
574 | for x in range(rect.x1, rect.x2+1):
575 | parent_rect = Rect(x, y, x, y)
576 | ch_ids = self.find_node_ids_with_given_prim_type(prim_type, parent_rect)
577 | assert len(ch_ids) == 1
578 | if ch_ids[0] not in self.node_set[id_].child_ids:
579 | self.node_set[id_].child_ids.append(ch_ids[0])
580 |
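    # _add_lateral_connection adds lateral child links layer by layer
    # (same-shape And-nodes scanned in one direction, same-shape Or-nodes in
    # the other); when use_node_lateral_connection_1 is set, it additionally
    # wires each qualifying Or-node's terminal child into one of that
    # Or-node's And-node children, and it returns the (unchanged) root id.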
581 | def _add_lateral_connection(self):
582 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.AndNode, direction=1)
583 | self._add_node_bottomup_connection_layerwise(node_type=NodeType.OrNode, direction=0)
584 |
585 | if not self.param.use_node_lateral_connection_1:
586 | return self.BFS[0]
587 |
588 | # or for all or nodes
589 | for node in self.node_set:
590 | if node.node_type != NodeType.OrNode:
591 | continue
592 |
593 | ch_ids = node.child_ids
594 | numCh = len(ch_ids)
595 |
596 | hasLateral = False
597 | for id_ in ch_ids:
598 | if self.node_set[id_].node_type == NodeType.OrNode:
599 | hasLateral = True
600 | numCh -= 1
601 |
602 | minNumCh = 3 if hasLateral else 2
603 | if len(ch_ids) < minNumCh:
604 | continue
605 |
606 | # find t-node child
607 | ch0 = -1
608 | for id_ in ch_ids:
609 | if self.node_set[id_].node_type == NodeType.TerminalNode:
610 | ch0 = id_
611 | break
612 | assert ch0 != -1
613 |
614 | added = False
615 | for id_ in ch_ids:
616 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode:
617 | continue
618 |
619 | if len(self.node_set[id_].child_ids) == 2 or numCh == 2:
620 | assert ch0 not in self.node_set[id_].child_ids
621 | self.node_set[id_].child_ids.append(ch0)
622 | added = True
623 |
624 | if not added:
625 | for id_ in ch_ids:
626 | if id_ == ch0 or self.node_set[id_].node_type == NodeType.OrNode:
627 | continue
628 |
629 | found = True
630 | for id__ in ch_ids:
631 | if id_ in self.node_set[id__].child_ids:
632 | found = False
633 | if found:
634 | assert ch0 not in self.node_set[id_].child_ids
635 | self.node_set[id_].child_ids.append(ch0)
636 |
637 | return self.BFS[0]
638 |
639 | def _add_node_bottomup_connection_layerwise(self, node_type=NodeType.TerminalNode, direction=0):
640 |
641 | prim_types = []
642 | for node in self.node_set:
643 | if node.node_type == node_type:
644 | rect = self.primitive_set[node.rect_idx]
645 | p = [rect.Height(), rect.Width()]
646 | if p not in prim_types:
647 | prim_types.append(p)
648 |
649 | change_direction = False
650 |
651 | prim_types.sort()
652 |
653 | for p in prim_types:
654 | ids = self.find_node_ids_with_given_prim_type(p, node_type=node_type)
655 | if len(ids) < 2:
656 | change_direction = True
657 | continue
658 |
659 | if change_direction:
660 | direction = 1 - direction
661 |
662 | yx = np.empty((0, 4 if node_type==NodeType.AndNode else 2), dtype=np.float32)
663 | for id_ in ids:
664 | node = self.node_set[id_]
665 | rect = self.primitive_set[node.rect_idx]
666 |
667 | if node_type == NodeType.AndNode:
668 | ch_node = self.node_set[node.child_ids[0]]
669 | ch_rect = self.primitive_set[ch_node.rect_idx]
670 | if ch_rect.x1 != rect.x1 or ch_rect.y1 != rect.y1:
671 | ch_node = self.node_set[node.child_ids[1]]
672 | ch_rect = self.primitive_set[ch_node.rect_idx]
673 | pos = (rect.y1, rect.x1, ch_rect.y2, ch_rect.x2)
674 | else:
675 | pos = (rect.y1, rect.x1)
676 | yx = np.vstack((yx, np.array(pos)))
677 |
678 | if node_type == NodeType.AndNode:
679 | ind = np.lexsort((yx[:, 1], yx[:, 0], yx[:, 3], yx[:, 2]))
680 | else:
681 | ind = np.lexsort((yx[:, 1], yx[:, 0]))
682 |
683 | istart = len(ind) - 1 if direction == 0 else 0
684 | iend = 0 if direction == 0 else len(ind) - 1
685 | step = -1 if direction == 0 else 1
686 | for i in range(istart, iend, step):
687 | id_ = ids[ind[i]]
688 | chid = ids[ind[i - 1]] if direction==0 else ids[ind[i+1]]
689 | if chid not in self.node_set[id_].child_ids:
690 | self.node_set[id_].child_ids.append(chid)
691 |
692 | if change_direction:
693 | direction = 1 - direction
694 | change_direction = False
695 |
696 | def _add_tnode_bottomup_connection_sequential(self):
697 |
698 | assert self.param.grid_wd > 1 and self.param.grid_ht == 1
699 |
700 | self._add_node_bottomup_connection_layerwise()
701 |
702 | for node in self.node_set:
703 | if node.node_type != NodeType.TerminalNode:
704 | continue
705 | rect = self.primitive_set[node.rect_idx]
706 | if rect.Width() == 1:
707 | continue
708 |
709 | rect1 = copy.deepcopy(rect)
710 | rect1.x1 += 1
711 | chid = self._FindNodeIdWithGivenRect(rect1, NodeType.TerminalNode)
712 | if chid != -1:
713 | self.node_set[node.id].child_ids.append(chid)
714 |
715 | def _mark_symmetric_subgraph(self):
716 |
717 | for i in self.BFS:
718 | node = self.node_set[i]
719 |
720 | if node.is_symmetric or node.node_type == NodeType.TerminalNode:
721 | continue
722 |
723 | if i != self.BFS[0]:
724 | is_symmetric = True
725 | for j in node.parent_ids:
726 | p = self.node_set[j]
727 | if not p.is_symmetric:
728 | is_symmetric = False
729 | break
730 | if is_symmetric:
731 | self.node_set[i].is_symmetric = True
732 | continue
733 |
734 | rect = self.primitive_set[node.rect_idx]
735 | Wd = rect.Width()
736 | Ht = rect.Height()
737 |
738 | if node.node_type == NodeType.OrNode:
739 | # mark symmetric children
740 | useSplitWds = []
741 | useSplitHts = []
742 | if self.param.remove_symmetric_children_of_or_node == 2:
743 | child_ids = node.child_ids[::-1]
744 | else:
745 | child_ids = node.child_ids
746 |
747 | for j in child_ids:
748 | ch = self.node_set[j]
749 | if ch.node_type == NodeType.TerminalNode:
750 | continue
751 |
752 | if ch.split_type == SplitType.VerSplit:
753 | if (Wd-ch.split_step2, ch.split_step1) not in useSplitWds:
754 | useSplitWds.append((ch.split_step1, Wd-ch.split_step2))
755 | else:
756 | self.node_set[j].is_symmetric = True
757 |
758 | elif ch.split_type == SplitType.HorSplit:
759 | if (Ht-ch.split_step2, ch.split_step1) not in useSplitHts:
760 | useSplitHts.append((ch.split_step1, Ht-ch.split_step2))
761 | else:
762 | self.node_set[j].is_symmetric = True
763 |
764 | def _find_dbl_counting_or_nodes(self):
765 | for node in self.node_set:
766 | if node.node_type != NodeType.OrNode or len(node.child_ids) < 2:
767 | continue
768 | for i in self.node_BFS[node.id][1:]:
769 | npaths = { x : 0 for x in self.node_BFS[node.id] }
770 | n = self._countPaths(node, self.node_set[i], npaths)
771 | if n > 1:
772 | self.node_set[node.id].has_dbl_counting = True
773 | break
774 |
775 | def find_node_ids_with_given_prim_type(self, prim_type, parent_rect=None, node_type=NodeType.TerminalNode):
776 | ids = []
777 | for node in self.node_set:
778 | if node.node_type != node_type:
779 | continue
780 | rect = self.primitive_set[node.rect_idx]
781 | if [rect.Height(), rect.Width()] == prim_type:
782 | if parent_rect is not None:
783 | if rect.x1 >= parent_rect.x1 and rect.y1 >= parent_rect.y1 and \
784 | rect.x2 <= parent_rect.x2 and rect.y2 <= parent_rect.y2:
785 | ids.append(node.id)
786 | else:
787 | ids.append(node.id)
788 | return ids
789 |
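    # Create() builds the grid And-Or graph in phases: (1) BFS expansion from
    # the root Or-node, adding a terminal child where allowed plus And-nodes
    # for every kept horizontal/vertical binary split; (2) optionally,
    # And-nodes with 3..max_split children, formed by re-splitting a child of
    # an existing And-node; (3) optional rewrites (t-nodes as alpha channels,
    # a super Or-node, removal of single-child Or-nodes, symmetry marking,
    # extra node hierarchies), each followed by _Postprocessing to refresh the
    # DFS/BFS orders and ancestor ids; it finishes by caching per-node
    # traversals, path counts, and visualization colors.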
790 | def Create(self):
791 | # print("======= creating AOGrid {}, could take a while".format(self.param.tag))
792 | # FIXME: when remove_symmetric_children_of_or_node is true, top-down hierarchy is not correctly created.
793 |
794 | # the root OrNode
795 | rect = Rect(0, 0, self.param.grid_wd - 1, self.param.grid_ht - 1)
796 | self.primitive_set.append(rect)
797 | node = Node(node_type=NodeType.OrNode, rect_idx=0)
798 | self._AddNode(node)
799 |
800 | BFS = deque()
801 | BFS.append(0)
802 | keepLeft = True
803 | keepTop = True
804 | while len(BFS) > 0:
805 | curId = BFS.popleft()
806 | curNode = self.node_set[curId]
807 | curRect = self.primitive_set[curNode.rect_idx]
808 | curWd = curRect.Width()
809 | curHt = curRect.Height()
810 |
811 | childIds = []
812 |
813 | if curNode.node_type == NodeType.OrNode:
814 | num_child_node_kept = 0
815 | # add a terminal node for a non-root OrNode
816 | allowTerminate = not ((self.param.not_use_large_terminal_node and
817 | float(curWd * curHt) / float(self.param.grid_ht * self.param.grid_wd) >
818 | self.param.turn_off_size_ratio_terminal_node) or
819 | (self.param.not_use_intermediate_TerminalNode and (curWd > self.param.min_size or curHt > self.param.min_size)))
820 |
821 | if (curId > 0 and allowTerminate) or (curId==0 and self.param.use_root_terminal_node):
822 | node = Node(node_type=NodeType.TerminalNode, rect_idx=curNode.rect_idx)
823 | suc, node = self._AddNode(node)
824 | childIds.append(node.id)
825 | num_child_node_kept += 1
826 |
827 | # add all AndNodes for horizontal and vertical binary splits
828 | if not self._DoSplit(curRect):
829 | childIds = list(set(childIds))
830 | self.node_set[curId].child_ids = childIds
831 | continue
832 |
833 | num_child_node_to_add = self.param.max_children_kept_for_or - num_child_node_kept
834 | stepH = self._SplitStep(curWd)
835 | stepV = self._SplitStep(curHt)
836 | num_stepH = curHt - stepH + 1 - stepH
837 | num_stepV = curWd - stepV + 1 - stepV
838 | if num_stepH == 0 and num_stepV == 0:
839 | childIds = list(set(childIds))
840 | self.node_set[curId].child_ids = childIds
841 | continue
842 |
843 | num_child_node_to_add_H = num_stepH / float(num_stepH + num_stepV) * num_child_node_to_add
844 | num_child_node_to_add_V = num_child_node_to_add - num_child_node_to_add_H
845 |
846 | stepH_step = int(
847 | max(1, floor(float(num_stepH) / num_child_node_to_add_H) if num_child_node_to_add_H > 0 else 0))
848 | stepV_step = int(
849 | max(1, floor(float(num_stepV) / num_child_node_to_add_V) if num_child_node_to_add_V > 0 else 0))
850 |
851 | # horizontal splits
852 | step = stepH
853 | num_child_node_added_H = 0
854 |
855 | splitHts = []
856 | for topHt in range(step, curHt - step + 1, stepH_step):
857 | if num_child_node_added_H >= num_child_node_to_add_H:
858 | break
859 |
860 | bottomHt = curHt - topHt
861 | if self.param.overlap_ratio > 0:
862 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio))
863 | else:
864 | numSplit = 1
865 | for b in range(0, numSplit):
866 | splitHts.append((topHt, bottomHt))
867 | bottomHt += 1
868 | num_child_node_added_H += 1
869 |
870 |                 if self.param.remove_symmetric_children_of_or_node == 1 and not self.param.mark_symmetric_syntatic_subgraph:
871 | useSplitHts = []
872 | for (topHt, bottomHt) in splitHts:
873 | if (bottomHt, topHt) not in useSplitHts:
874 | useSplitHts.append((topHt, bottomHt))
875 |                 elif self.param.remove_symmetric_children_of_or_node == 2 and not self.param.mark_symmetric_syntatic_subgraph:
876 | useSplitHts = []
877 | for (topHt, bottomHt) in reversed(splitHts):
878 | if (bottomHt, topHt) not in useSplitHts:
879 | useSplitHts.append((topHt, bottomHt))
880 | else:
881 | useSplitHts = splitHts
882 |
883 | for (topHt, bottomHt) in useSplitHts:
884 | node = Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx,
885 | split_type=SplitType.HorSplit,
886 | split_step1=topHt, split_step2=curHt - bottomHt)
887 | suc, node = self._AddNode(node)
888 | if suc:
889 | BFS.append(node.id)
890 | childIds.append(node.id)
891 |
892 | # vertical splits
893 | step = stepV
894 | num_child_node_added_V = 0
895 |
896 | splitWds = []
897 | for leftWd in range(step, curWd - step + 1, stepV_step):
898 | if num_child_node_added_V >= num_child_node_to_add_V:
899 | break
900 |
901 | rightWd = curWd - leftWd
902 | if self.param.overlap_ratio > 0:
903 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio))
904 | else:
905 | numSplit = 1
906 | for r in range(0, numSplit):
907 | splitWds.append((leftWd, rightWd))
908 | rightWd += 1
909 | num_child_node_added_V += 1
910 |
911 |                 if self.param.remove_symmetric_children_of_or_node == 1 and not self.param.mark_symmetric_syntatic_subgraph:
912 | useSplitWds = []
913 | for (leftWd, rightWd) in splitWds:
914 | if (rightWd, leftWd) not in useSplitWds:
915 | useSplitWds.append((leftWd, rightWd))
916 |                 elif self.param.remove_symmetric_children_of_or_node == 2 and not self.param.mark_symmetric_syntatic_subgraph:
917 | useSplitWds = []
918 | for (leftWd, rightWd) in reversed(splitWds):
919 | if (rightWd, leftWd) not in useSplitWds:
920 | useSplitWds.append((leftWd, rightWd))
921 | else:
922 | useSplitWds = splitWds
923 |
924 | for (leftWd, rightWd) in useSplitWds:
925 | node = Node(node_type=NodeType.AndNode, rect_idx=curNode.rect_idx,
926 | split_type=SplitType.VerSplit,
927 | split_step1=leftWd, split_step2=curWd - rightWd)
928 | suc, node = self._AddNode(node)
929 | if suc:
930 | BFS.append(node.id)
931 | childIds.append(node.id)
932 |
933 | elif curNode.node_type == NodeType.AndNode:
934 | # add two child OrNodes
935 | if curNode.split_type == SplitType.HorSplit:
936 | top = Rect(x1=curRect.x1, y1=curRect.y1,
937 | x2=curRect.x2, y2=curRect.y1 + curNode.split_step1 - 1)
938 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(top))
939 | suc, node = self._AddNode(node)
940 | if suc:
941 | BFS.append(node.id)
942 | childIds.append(node.id)
943 |
944 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + curNode.split_step2,
945 | x2=curRect.x2, y2=curRect.y2)
946 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(bottom))
947 | suc, node = self._AddNode(node)
948 | if suc:
949 | BFS.append(node.id)
950 | childIds.append(node.id)
951 | elif curNode.split_type == SplitType.VerSplit:
952 | left = Rect(curRect.x1, curRect.y1,
953 | curRect.x1 + curNode.split_step1 - 1, curRect.y2)
954 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(left))
955 | suc, node = self._AddNode(node)
956 | if suc:
957 | BFS.append(node.id)
958 | childIds.append(node.id)
959 |
960 | right = Rect(curRect.x1 + curNode.split_step2, curRect.y1,
961 | curRect.x2, curRect.y2)
962 | node = Node(node_type=NodeType.OrNode, rect_idx=self._AddPrimitve(right))
963 | suc, node = self._AddNode(node)
964 | if suc:
965 | BFS.append(node.id)
966 | childIds.append(node.id)
967 |
968 | childIds = list(set(childIds))
969 | self.node_set[curId].child_ids = childIds
970 |
971 | # add root terminal node if allowed
972 | root_id = 0
973 |
974 | # create And-nodes with more than 2 children
975 | # TODO: handle remove_symmetric_child_node
976 | if self.param.max_split > 2:
977 | for branch in range(3, self.param.max_split + 1):
978 | for node in self.node_set:
979 | if node.node_type != NodeType.OrNode:
980 | continue
981 |
982 | new_and_ids = []
983 |
984 | for cur_id in node.child_ids:
985 | cur_and = self.node_set[cur_id]
986 | if len(cur_and.child_ids) != branch - 1:
987 | continue
988 | assert cur_and.node_type == NodeType.AndNode
989 |
990 | for ch_id in cur_and.child_ids:
991 | ch = self.node_set[ch_id]
992 | curRect = self.primitive_set[ch.rect_idx]
993 | curWd = curRect.Width()
994 | curHt = curRect.Height()
995 |
996 | # split ch into two to create new And-nodes
997 |
998 | # add all AndNodes for horizontal and vertical binary splits
999 | if not self._DoSplit(curRect):
1000 | continue
1001 |
1002 | # horizontal splits
1003 | step = self._SplitStep(curWd)
1004 | for topHt in range(step, curHt - step + 1):
1005 | bottomHt = curHt - topHt
1006 | if self.param.overlap_ratio > 0:
1007 | numSplit = int(1 + floor(topHt * self.param.overlap_ratio))
1008 | else:
1009 | numSplit = 1
1010 | for b in range(0, numSplit):
1011 | split_step1 = topHt
1012 | split_step2 = curHt - bottomHt
1013 |
1014 | top = Rect(x1=curRect.x1, y1=curRect.y1,
1015 | x2=curRect.x2, y2=curRect.y1 + split_step1 - 1)
1016 | top_id = self._FindNodeIdWithGivenRect(top, NodeType.OrNode)
1017 | if top_id == -1:
1018 | continue
1019 | # assert top_id != -1
1020 |
1021 | bottom = Rect(x1=curRect.x1, y1=curRect.y1 + split_step2,
1022 | x2=curRect.x2, y2=curRect.y2)
1023 | bottom_id = self._FindNodeIdWithGivenRect(bottom, NodeType.OrNode)
1024 | if bottom_id == -1:
1025 | continue
1026 | # assert bottom_id != -1
1027 |
1028 | # add a new And-node
1029 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx)
1030 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [top_id,
1031 | bottom_id]
1032 |
1033 | suc, new_and = self._AddNode(new_and)
1034 | new_and_ids.append(new_and.id)
1035 |
1036 | bottomHt += 1
1037 |
1038 | # vertical splits
1039 | step = self._SplitStep(curHt)
1040 | for leftWd in range(step, curWd - step + 1):
1041 | rightWd = curWd - leftWd
1042 |
1043 | if self.param.overlap_ratio > 0:
1044 | numSplit = int(1 + floor(leftWd * self.param.overlap_ratio))
1045 | else:
1046 | numSplit = 1
1047 | for r in range(0, numSplit):
1048 | split_step1 = leftWd
1049 | split_step2 = curWd - rightWd
1050 |
1051 | left = Rect(curRect.x1, curRect.y1,
1052 | curRect.x1 + split_step1 - 1, curRect.y2)
1053 | left_id = self._FindNodeIdWithGivenRect(left, NodeType.OrNode)
1054 | if left_id == -1:
1055 | continue
1056 | # assert left_id != -1
1057 |
1058 | right = Rect(curRect.x1 + split_step2, curRect.y1,
1059 | curRect.x2, curRect.y2)
1060 | right_id = self._FindNodeIdWithGivenRect(right, NodeType.OrNode)
1061 | if right_id == -1:
1062 | continue
1063 | # assert right_id != -1
1064 |
1065 | # add a new And-node
1066 | new_and = Node(node_type=NodeType.AndNode, rect_idx=cur_and.rect_idx)
1067 | new_and.child_ids = list(set(cur_and.child_ids) - set([ch_id])) + [left_id,
1068 | right_id]
1069 |
1070 | suc, new_and = self._AddNode(new_and)
1071 | new_and_ids.append(new_and.id)
1072 |
1073 | rightWd += 1
1074 |
1075 | self.node_set[node.id].child_ids = list(set(self.node_set[node.id].child_ids + new_and_ids))
1076 |
1077 | self._Postprocessing(root_id)
1078 |
1079 | # change tnodes to child nodes of and-nodes / or-nodes
1080 | if self.param.use_tnode_as_alpha_channel > 0:
1081 | node_type = NodeType.OrNode if self.param.use_tnode_as_alpha_channel==1 else NodeType.AndNode
1082 | not_create_if_existed = not self.param.use_tnode_as_alpha_channel==1
1083 | for id_ in self.BFS:
1084 | node = self.node_set[id_]
1085 | if node.node_type == NodeType.OrNode and len(node.child_ids) > 1:
1086 | for ch in node.child_ids:
1087 | ch_node = self.node_set[ch]
1088 | if ch_node.node_type == NodeType.TerminalNode:
1089 | new_parent_node = Node(node_type=node_type, rect_idx=ch_node.rect_idx)
1090 | _, new_parent_node = self._AddNode(new_parent_node, not_create_if_existed)
1091 | new_parent_node.child_ids = [ch_node.id, node.id]
1092 |
1093 | for pr in node.parent_ids:
1094 | pr_node = self.node_set[pr]
1095 | for i, pr_ch in enumerate(pr_node.child_ids):
1096 | if pr_ch == node.id:
1097 | pr_node.child_ids[i] = new_parent_node.id
1098 | break
1099 |
1100 | self.node_set[id_].child_ids.remove(ch)
1101 | if id_ == self.BFS[0]:
1102 | root_id = new_parent_node.id
1103 | break
1104 |
1105 | self._Postprocessing(root_id)
1106 |
1107 | # add super-or node
1108 | if self.param.use_super_OrNode:
1109 | super_or_node = Node(node_type=NodeType.OrNode, rect_idx=-1)
1110 | _, super_or_node = self._AddNode(super_or_node)
1111 | super_or_node.child_ids = []
1112 | for node in self.node_set:
1113 | if node.node_type == NodeType.OrNode and node.rect_idx != -1:
1114 | rect = self.primitive_set[node.rect_idx]
1115 | r = float(rect.Area()) / float(self.param.grid_ht * self.param.grid_wd)
1116 | if r > 0.5:
1117 | super_or_node.child_ids.append(node.id)
1118 |
1119 | root_id = super_or_node.id
1120 |
1121 | self._Postprocessing(root_id)
1122 |
1123 | # remove or-nodes with single child node
1124 | if self.param.remove_single_child_or_node:
1125 | remove_ids = []
1126 | for node in self.node_set:
1127 | if node.node_type == NodeType.OrNode and len(node.child_ids) == 1:
1128 | for pr in node.parent_ids:
1129 | pr_node = self.node_set[pr]
1130 | for i, pr_ch in enumerate(pr_node.child_ids):
1131 | if pr_ch == node.id:
1132 | pr_node.child_ids[i] = node.child_ids[0]
1133 | break
1134 |
1135 | remove_ids.append(node.id)
1136 | node.child_ids = []
1137 |
1138 | remove_ids.sort()
1139 | remove_ids.reverse()
1140 |
1141 | for id_ in remove_ids:
1142 | for node in self.node_set:
1143 | if node.id > id_:
1144 | node.id -= 1
1145 | for i, ch in enumerate(node.child_ids):
1146 | if ch > id_:
1147 | node.child_ids[i] -= 1
1148 |
1149 | if root_id > id_:
1150 | root_id -= 1
1151 |
1152 | for id_ in remove_ids:
1153 | del self.node_set[id_]
1154 |
1155 | self._Postprocessing(root_id)
1156 |
1157 | # mark symmetric nodes
1158 | if self.param.mark_symmetric_syntatic_subgraph:
1159 | self._mark_symmetric_subgraph()
1160 |
1161 | # add tnode hierarchy
1162 | if self.param.use_tnode_topdown_connection:
1163 | self._add_tnode_topdown_connection()
1164 | self._Postprocessing(root_id)
1165 | elif self.param.use_tnode_bottomup_connection:
1166 | self._add_tnode_bottomup_connection()
1167 | self._Postprocessing(root_id)
1168 | elif self.param.use_tnode_bottomup_connection_layerwise:
1169 | self._add_node_bottomup_connection_layerwise()
1170 | self._Postprocessing(root_id)
1171 | elif self.param.use_tnode_bottomup_connection_sequential:
1172 | self._add_tnode_bottomup_connection_sequential()
1173 | self._Postprocessing(root_id)
1174 | elif self.param.use_node_lateral_connection or self.param.use_node_lateral_connection_1:
1175 | root_id = self._add_lateral_connection()
1176 | self._Postprocessing(root_id)
1177 |
1178 | # index of Or-nodes in BFS
1179 | self.OrNodeIdxInBFS = {}
1180 | self.TNodeIdxInBFS = {}
1181 | idx_or = 0
1182 | idx_t = 0
1183 | for id_ in self.BFS:
1184 | node = self.node_set[id_]
1185 | if node.node_type == NodeType.OrNode:
1186 | self.OrNodeIdxInBFS[node.id] = idx_or
1187 | idx_or += 1
1188 | elif node.node_type == NodeType.TerminalNode:
1189 | self.TNodeIdxInBFS[node.id] = idx_t
1190 | idx_t += 1
1191 |
1192 | # get DFS and BFS rooted at each node
1193 | for node in self.node_set:
1194 | if node.node_type == NodeType.TerminalNode:
1195 | continue
1196 | visited = np.zeros(len(self.node_set))
1197 | self.node_DFS[node.id] = []
1198 | self.node_DFS[node.id], _ = self._DFS(node.id, self.node_DFS[node.id], visited)
1199 |
1200 | visited = np.zeros(len(self.node_set))
1201 | self.node_BFS[node.id] = []
1202 | self.node_BFS[node.id], _ = self._BFS(node.id, self.node_BFS[node.id], visited)
1203 |
1204 | # count paths between nodes and root node
1205 | for n in self.node_set:
1206 | npaths = { x.id : 0 for x in self.node_set }
1207 | self.node_set[n.id].npaths = self._countPaths(self.node_set[self.BFS[0]], n, npaths)
1208 |
1209 | # find ornode with double-counting children
1210 | self._find_dbl_counting_or_nodes()
1211 |
1212 | # generate colors for terminal nodes for consistency in visualization
1213 | self.TNodeColors = {}
1214 | for node in self.node_set:
1215 | if node.node_type == NodeType.TerminalNode:
1216 | self.TNodeColors[node.id] = (
1217 | random.random(), random.random(), random.random()) # generate a random color
1218 |
1219 |
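    # TurnOnOffNodes() resets every node's on/off flag; UpdateOnOffNodes()
    # then activates a single parse graph: `pg` stores, for each Or-node
    # (indexed via OrNodeIdxInBFS), which child branch to follow, and the
    # traversal switches the visited nodes on, tallies per-class visit and
    # out-edge statistics, and returns the used choices, the sorted terminal
    # configuration, and per-terminal offset indices.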
1220 | def TurnOnOffNodes(self, on_off):
1221 | for i in range(len(self.node_set)):
1222 | self.node_set[i].on_off = on_off
1223 |
1224 | def UpdateOnOffNodes(self, pg, offset_using_part_type, class_name=''):
1225 | BFS = [self.BFS[0]]
1226 |         pg_used = np.ones((1, len(pg)), dtype=int) * -1
1227 | configuration = []
1228 | tnode_offset_indx = []
1229 | while len(BFS):
1230 | id = BFS.pop()
1231 | node = self.node_set[id]
1232 | self.node_set[id].on_off = True
1233 | if len(class_name):
1234 | if class_name in node.which_classes_visited.keys():
1235 | self.node_set[id].which_classes_visited[class_name] += 1.0
1236 | else:
1237 | self.node_set[id].which_classes_visited[class_name] = 0
1238 |
1239 | if node.node_type == NodeType.OrNode:
1240 | idx = self.OrNodeIdxInBFS[node.id]
1241 | BFS.append(node.child_ids[int(pg[idx])])
1242 | pg_used[0, idx] = int(pg[idx])
1243 | if len(self.node_set[id].out_edge_visited_count):
1244 | self.node_set[id].out_edge_visited_count[int(pg[idx])] += 1.0
1245 | else:
1246 | self.node_set[id].out_edge_visited_count = np.zeros((len(node.child_ids),), dtype=np.float32)
1247 | elif node.node_type == NodeType.AndNode:
1248 | BFS += node.child_ids
1249 |
1250 | else:
1251 | configuration.append(node.id)
1252 |
1253 | offset_ind = 0
1254 | if not offset_using_part_type:
1255 | for node1 in self.node_set:
1256 | if node1.node_type == NodeType.TerminalNode: # change to BFS after _part_instance is changed to BFS
1257 | if node1.id == node.id:
1258 | break
1259 | offset_ind += 1
1260 | else:
1261 | rect = self.primitive_set[node.rect_idx]
1262 | offset_ind = self.part_type.index([rect.Height(), rect.Width()])
1263 |
1264 | tnode_offset_indx.append(offset_ind)
1265 |
1266 | configuration.sort()
1267 |         cfg = np.ones((1, self.num_TNodes), dtype=int) * -1
1268 | cfg[0, :len(configuration)] = configuration
1269 | return pg_used, cfg, tnode_offset_indx
1270 |
1271 | def ResetOutEdgeVisitedCountNodes(self):
1272 | for i in range(len(self.node_set)):
1273 | self.node_set[i].out_edge_visited_count = []
1274 |
1275 | def NormalizeOutEdgeVisitedCountNodes(self, count=0):
1276 | if count == 0:
1277 | for i in range(len(self.node_set)):
1278 | if len(self.node_set[i].out_edge_visited_count):
1279 | count = max(count, max(self.node_set[i].out_edge_visited_count))
1280 |
1281 | if count == 0:
1282 | return
1283 |
1284 | for i in range(len(self.node_set)):
1285 | if len(self.node_set[i].out_edge_visited_count):
1286 | self.node_set[i].out_edge_visited_count /= count
1287 |
1288 | def ResetWhichClassesVisitedNodes(self):
1289 | for i in range(len(self.node_set)):
1290 | self.node_set[i].which_classes_visited = {}
1291 |
1292 | def NormalizeWhichClassesVisitedNodes(self, class_name, count):
1293 | assert count > 0
1294 | for i in range(len(self.node_set)):
1295 | if class_name in self.node_set[i].which_classes_visited.keys():
1296 | self.node_set[i].which_classes_visited[class_name] /= count
1297 |
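# A minimal usage sketch (hedged: `Param` is defined in the earlier part of
# this file and is assumed here to take the grid fields as keyword arguments):
#
#   param = Param(grid_ht=1, grid_wd=4, max_split=2)
#   grid = AOGrid(param)
#   grid.Create()                        # build the And-Or DAG
#   root = grid.node_set[grid.BFS[0]]    # BFS[0] is the root node id after Create()
#   print(grid.num_TNodes, grid.num_AndNodes, grid.num_OrNodes)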
--------------------------------------------------------------------------------
/models/aognet/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/models/aognet/__init__.py
--------------------------------------------------------------------------------
/models/aognet/aognet.py:
--------------------------------------------------------------------------------
1 | """ RESEARCH ONLY LICENSE
2 | Copyright (c) 2018-2019 North Carolina State University.
3 | All rights reserved.
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided
5 | that the following conditions are met:
6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use
7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 | """
34 | # The system is protected via patent (pending)
35 | # Written by Tianfu Wu and Xilai Li
36 | # Contact: {xli47, tianfu_wu}@ncsu.edu
37 | from __future__ import absolute_import
38 | from __future__ import division
39 | from __future__ import print_function # force to use print as function print(args)
40 | from __future__ import unicode_literals
41 |
42 | import torch
43 | import torch.nn as nn
44 | import torch.nn.functional as F
45 | from torch.autograd import Variable
46 |
47 | import scipy.stats as stats
48 |
49 | from models.config import cfg
50 | from .AOG import *
51 | from .operator_basic import *
52 | from .operator_singlescale import *
53 |
54 | ### AOG building block
55 | class AOGBlock(nn.Module):
56 | def __init__(self, stage, block, aog, in_channels, out_channels, drop_rate, stride):
57 | super(AOGBlock, self).__init__()
58 | self.stage = stage
59 | self.block = block
60 | self.aog = aog
61 | self.in_channels = in_channels
62 | self.out_channels = out_channels
63 | self.drop_rate = drop_rate
64 | self.stride = stride
65 |
66 | self.dim = aog.param.grid_wd
67 | self.in_slices = self._calculate_slices(self.dim, in_channels)
68 | self.out_slices = self._calculate_slices(self.dim, out_channels)
69 |
70 | self.node_set = aog.node_set
71 | self.primitive_set = aog.primitive_set
72 | self.BFS = aog.BFS
73 | self.DFS = aog.DFS
74 |
75 | self.hasLateral = {}
76 | self.hasDblCnt = {}
77 |
78 | self.primitiveDblCnt = None
79 | self._set_primitive_dbl_cnt()
80 |
81 | if "BatchNorm2d" in cfg.norm_name:
82 | self.norm_name_base = "BatchNorm2d"
83 | elif "GroupNorm" in cfg.norm_name:
84 | self.norm_name_base = "GroupNorm"
85 | else:
86 | raise ValueError("Unknown norm layer")
87 |
88 | self._set_weights_attr()
89 |
90 | self.extra_norm_ac = self._extra_norm_ac()
91 |
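    # _calculate_slices partitions `channels` into `dim` nearly equal groups
    # and returns the cumulative boundaries. Worked example: dim=4 and
    # channels=10 give per-group counts [3, 3, 2, 2], so the returned slices
    # are [0, 3, 6, 8, 10] and group i spans channels slices[i]:slices[i+1].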
92 | def _calculate_slices(self, dim, channels):
93 | slices = [0] * dim
94 | for i in range(channels):
95 | slices[i % dim] += 1
96 | for d in range(1, dim):
97 | slices[d] += slices[d - 1]
98 | slices = [0] + slices
99 | return slices
100 |
101 | def _set_primitive_dbl_cnt(self):
102 | self.primitiveDblCnt = [0.0 for i in range(self.dim)]
103 | for id_ in self.DFS:
104 | node = self.node_set[id_]
105 | arr = self.primitive_set[node.rect_idx]
106 | if node.node_type == NodeType.TerminalNode:
107 | for i in range(arr.x1, arr.x2+1):
108 | self.primitiveDblCnt[i] += node.npaths
109 | for i in range(self.dim):
110 | assert self.primitiveDblCnt[i] >= 1.0
111 |
112 | def _create_op(self, node_id, cin, cout, stride, groups=1,
113 | keep_norm_base=False, norm_k=0):
114 | replace_stride = cfg.aognet.replace_stride_with_avgpool
115 | setattr(self, 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node_id),
116 | NodeOpSingleScale(cin, cout, stride,
117 | groups=groups, drop_rate=self.drop_rate,
118 | ac_mode=cfg.activation_mode,
119 | bn_ratio=cfg.aognet.bottleneck_ratio,
120 | norm_name=self.norm_name_base if keep_norm_base else cfg.norm_name,
121 | norm_groups=cfg.norm_groups,
122 | norm_k = norm_k,
123 | norm_attention_mode=cfg.norm_attention_mode,
124 | replace_stride_with_avgpool=replace_stride))
125 |
126 | def _set_weights_attr(self):
127 | for id_ in self.DFS:
128 | node = self.node_set[id_]
129 | arr = self.primitive_set[node.rect_idx]
130 | bn_ratio = cfg.aognet.bottleneck_ratio
131 | width_per_group = cfg.aognet.width_per_group
132 |             keep_norm_base = arr.Width() < self.dim
133 |             norm_k = cfg.norm_k[self.stage]
134 | 
135 |             if node.node_type == NodeType.TerminalNode:
136 |                 plane = self.in_slices[arr.x2 + 1] - self.in_slices[arr.x1]
137 |                 outplane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
138 |                 stride = self.stride
139 |                 groups = max(1, to_int(plane * bn_ratio / width_per_group)) \
140 |                     if cfg.aognet.use_group_conv else 1
141 |                 if cfg.aognet.terminal_node_no_slice[self.stage]:
142 |                     plane = self.in_channels
143 |                 self._create_op(node.id, plane, outplane, stride, groups=groups,
144 |                                 keep_norm_base=keep_norm_base, norm_k=norm_k)
145 | 
146 |             elif node.node_type == NodeType.AndNode:
147 |                 plane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
148 |                 stride = 1
149 |                 groups = max(1, to_int(plane * bn_ratio / width_per_group)) \
150 |                     if cfg.aognet.use_group_conv else 1
151 |                 self.hasLateral[node.id] = False
152 |                 self.hasDblCnt[node.id] = False
153 |                 for chid in node.child_ids:
154 |                     ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
155 |                     if arr.Width() == ch_arr.Width():
156 |                         self.hasLateral[node.id] = True
157 |                         break
158 |                 if cfg.aognet.handle_dbl_cnt:
159 |                     for chid in node.child_ids:
160 |                         ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
161 |                         if arr.Width() > ch_arr.Width():
162 | if node.npaths / self.node_set[chid].npaths != 1.0:
163 | self.hasDblCnt[node.id] = True
164 | break
165 | self._create_op(node.id, plane, plane, stride, groups=groups,
166 | keep_norm_base=keep_norm_base, norm_k=norm_k)
167 |
168 | elif node.node_type == NodeType.OrNode:
169 | assert self.node_set[node.child_ids[0]].node_type != NodeType.OrNode
170 | plane = self.out_slices[arr.x2 + 1] - self.out_slices[arr.x1]
171 | stride = 1
172 | groups = max(1, to_int(plane * bn_ratio / width_per_group)) \
173 | if cfg.aognet.use_group_conv else 1
174 | self.hasLateral[node.id] = False
175 | self.hasDblCnt[node.id] = False
176 | for chid in node.child_ids:
177 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
178 | if self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width():
179 | self.hasLateral[node.id] = True
180 | break
181 | if cfg.aognet.handle_dbl_cnt:
182 | for chid in node.child_ids:
183 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
184 | if not (self.node_set[chid].node_type == NodeType.OrNode or arr.Width() < ch_arr.Width()):
185 | if node.npaths / self.node_set[chid].npaths != 1.0:
186 | self.hasDblCnt[node.id] = True
187 | break
188 | self._create_op(node.id, plane, plane, stride, groups=groups,
189 | keep_norm_base=keep_norm_base, norm_k=norm_k)
190 |
191 | def _extra_norm_ac(self):
192 | return nn.Sequential(FeatureNorm(self.norm_name_base, self.out_channels,
193 | cfg.norm_groups, cfg.norm_k[self.stage],
194 | cfg.norm_attention_mode),
195 | AC(cfg.activation_mode))
196 |
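    # forward() evaluates the block in one topological (DFS) sweep over the
    # AOG: terminal nodes consume channel slices of the input (pre-divided by
    # their path counts when handle_tnode_dbl_cnt applies), And-nodes
    # concatenate their children's features, and Or-nodes sum (or, with
    # use_elem_max_for_ORNodes, elementwise-max) theirs; lateral children are
    # folded in separately, and `npaths` ratios rescale branches that would
    # otherwise be double counted.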
197 | def forward(self, x):
198 | NodeIdTensorDict = {}
199 |
200 | # handle input x
201 | tnode_dblcnt = False
202 | if cfg.aognet.handle_tnode_dbl_cnt and self.in_channels==self.out_channels:
203 | x_scaled = []
204 | for i in range(self.dim):
205 | left, right = self.in_slices[i], self.in_slices[i+1]
206 | x_scaled.append(x[:, left:right, :, :].div(self.primitiveDblCnt[i]))
207 | xx = torch.cat(x_scaled, 1)
208 | tnode_dblcnt = True
209 |
210 |         # T-nodes (these independent ops may be scheduled in parallel by PyTorch)
211 | for id_ in self.DFS:
212 | node = self.node_set[id_]
213 | op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id)
214 | if node.node_type == NodeType.TerminalNode:
215 | arr = self.primitive_set[node.rect_idx]
216 | right, left = self.in_slices[arr.x2 + 1], self.in_slices[arr.x1]
217 | tnode_tensor_op = x if cfg.aognet.terminal_node_no_slice[self.stage] else x[:, left:right, :, :].contiguous()
218 | # assert tnode_tensor.requires_grad, 'slice needs to retain grad'
219 | if tnode_dblcnt:
220 | tnode_tensor_res = xx[:, left:right, :, :].mul(node.npaths)
221 | tnode_output = getattr(self, op_name)(tnode_tensor_op, tnode_tensor_res)
222 | else:
223 | tnode_output = getattr(self, op_name)(tnode_tensor_op)
224 | NodeIdTensorDict[node.id] = tnode_output
225 |
226 | # AND- and OR-nodes
227 | for id_ in self.DFS:
228 | node = self.node_set[id_]
229 | arr = self.primitive_set[node.rect_idx]
230 | op_name = 'stage_{}_block_{}_node_{}_op'.format(self.stage, self.block, node.id)
231 | if node.node_type == NodeType.AndNode:
232 | if self.hasDblCnt[node.id]:
233 | child_tensor_res = []
234 | child_tensor_op = []
235 | for chid in node.child_ids:
236 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
237 | if arr.Width() > ch_arr.Width():
238 | factor = node.npaths / self.node_set[chid].npaths
239 | if factor == 1.0:
240 | child_tensor_res.append(NodeIdTensorDict[chid])
241 | else:
242 | child_tensor_res.append(NodeIdTensorDict[chid].mul(factor))
243 | child_tensor_op.append(NodeIdTensorDict[chid])
244 |
245 | anode_tensor_res = torch.cat(child_tensor_res, 1)
246 | anode_tensor_op = torch.cat(child_tensor_op, 1)
247 |
248 | if self.hasLateral[node.id]:
249 | ids1 = set(node.parent_ids)
250 | num_shared = 0
251 | for chid in node.child_ids:
252 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
253 | ids2 = self.node_set[chid].parent_ids
254 | if arr.Width() == ch_arr.Width():
255 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid]
256 | if len(ids1.intersection(ids2)) == num_shared:
257 | anode_tensor_res = anode_tensor_res + NodeIdTensorDict[chid]
258 |
259 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res)
260 | else:
261 | child_tensor = []
262 | for chid in node.child_ids:
263 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
264 | if arr.Width() > ch_arr.Width():
265 | child_tensor.append(NodeIdTensorDict[chid])
266 |
267 | anode_tensor_op = torch.cat(child_tensor, 1)
268 |
269 | if self.hasLateral[node.id]:
270 | ids1 = set(node.parent_ids)
271 | num_shared = 0
272 | for chid in node.child_ids:
273 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
274 | ids2 = self.node_set[chid].parent_ids
275 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) == num_shared:
276 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid]
277 |
278 | anode_tensor_res = anode_tensor_op
279 |
280 | for chid in node.child_ids:
281 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
282 | ids2 = self.node_set[chid].parent_ids
283 | if arr.Width() == ch_arr.Width() and len(ids1.intersection(ids2)) > num_shared:
284 | anode_tensor_op = anode_tensor_op + NodeIdTensorDict[chid]
285 |
286 | anode_output = getattr(self, op_name)(anode_tensor_op, anode_tensor_res)
287 | else:
288 | anode_output = getattr(self, op_name)(anode_tensor_op)
289 |
290 | NodeIdTensorDict[node.id] = anode_output
291 |
292 | elif node.node_type == NodeType.OrNode:
293 | if self.hasDblCnt[node.id]:
294 | factor = node.npaths / self.node_set[node.child_ids[0]].npaths
295 | if factor == 1.0:
296 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]]
297 | else:
298 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]].mul(factor)
299 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]]
300 | for chid in node.child_ids[1:]:
301 | if self.node_set[chid].node_type != NodeType.OrNode:
302 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
303 | if arr.Width() == ch_arr.Width():
304 | factor = node.npaths / self.node_set[chid].npaths
305 | if factor == 1.0:
306 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid]
307 | else:
308 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid].mul(factor)
309 | if cfg.aognet.use_elem_max_for_ORNodes:
310 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
311 | else:
312 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
313 |
314 | if self.hasLateral[node.id]:
315 | ids1 = set(node.parent_ids)
316 | num_shared = 0
317 | for chid in node.child_ids[1:]:
318 | ids2 = self.node_set[chid].parent_ids
319 | if self.node_set[chid].node_type == NodeType.OrNode and \
320 | len(ids1.intersection(ids2)) == num_shared:
321 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid]
322 | if cfg.aognet.use_elem_max_for_ORNodes:
323 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
324 | else:
325 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
326 |
327 | for chid in node.child_ids[1:]:
328 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
329 | ids2 = self.node_set[chid].parent_ids
330 | if self.node_set[chid].node_type == NodeType.OrNode and \
331 | len(ids1.intersection(ids2)) > num_shared:
332 | if cfg.aognet.use_elem_max_for_ORNodes:
333 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
334 | else:
335 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
336 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \
337 | arr.Width() < ch_arr.Width():
338 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1]
339 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1]
340 | if cfg.aognet.use_elem_max_for_ORNodes:
341 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :])
342 | else:
343 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :]#.contiguous()
344 |
345 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res)
346 | else:
347 | if cfg.aognet.use_elem_max_for_ORNodes:
348 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]]
349 | onode_tensor_res = NodeIdTensorDict[node.child_ids[0]]
350 | for chid in node.child_ids[1:]:
351 | if self.node_set[chid].node_type != NodeType.OrNode:
352 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
353 | if arr.Width() == ch_arr.Width():
354 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
355 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid]
356 |
357 | if self.hasLateral[node.id]:
358 | ids1 = set(node.parent_ids)
359 | num_shared = 0
360 | for chid in node.child_ids[1:]:
361 | ids2 = self.node_set[chid].parent_ids
362 | if self.node_set[chid].node_type == NodeType.OrNode and \
363 | len(ids1.intersection(ids2)) == num_shared:
364 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
365 | onode_tensor_res = onode_tensor_res + NodeIdTensorDict[chid]
366 |
367 | for chid in node.child_ids[1:]:
368 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
369 | ids2 = self.node_set[chid].parent_ids
370 | if self.node_set[chid].node_type == NodeType.OrNode and \
371 | len(ids1.intersection(ids2)) > num_shared:
372 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid])
373 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \
374 | arr.Width() < ch_arr.Width():
375 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1]
376 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1]
377 | onode_tensor_op = torch.max(onode_tensor_op, NodeIdTensorDict[chid][:, ch_left:ch_right, :, :])
378 |
379 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res)
380 | else:
381 | onode_output = getattr(self, op_name)(onode_tensor_op)
382 | else:
383 | onode_tensor_op = NodeIdTensorDict[node.child_ids[0]]
384 | for chid in node.child_ids[1:]:
385 | if self.node_set[chid].node_type != NodeType.OrNode:
386 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
387 | if arr.Width() == ch_arr.Width():
388 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
389 |
390 | if self.hasLateral[node.id]:
391 | ids1 = set(node.parent_ids)
392 | num_shared = 0
393 | for chid in node.child_ids[1:]:
394 | ids2 = self.node_set[chid].parent_ids
395 | if self.node_set[chid].node_type == NodeType.OrNode and \
396 | len(ids1.intersection(ids2)) == num_shared:
397 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
398 |
399 | onode_tensor_res = onode_tensor_op
400 |
401 | for chid in node.child_ids[1:]:
402 | ch_arr = self.primitive_set[self.node_set[chid].rect_idx]
403 | ids2 = self.node_set[chid].parent_ids
404 | if self.node_set[chid].node_type == NodeType.OrNode and \
405 | len(ids1.intersection(ids2)) > num_shared:
406 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid]
407 | elif self.node_set[chid].node_type == NodeType.TerminalNode and \
408 | arr.Width() < ch_arr.Width():
409 | ch_left = self.out_slices[arr.x1] - self.out_slices[ch_arr.x1]
410 | ch_right = self.out_slices[arr.x2 + 1] - self.out_slices[ch_arr.x1]
411 | onode_tensor_op = onode_tensor_op + NodeIdTensorDict[chid][:, ch_left:ch_right, :, :].contiguous()
412 |
413 | onode_output = getattr(self, op_name)(onode_tensor_op, onode_tensor_res)
414 | else:
415 | onode_output = getattr(self, op_name)(onode_tensor_op)
416 |
417 | NodeIdTensorDict[node.id] = onode_output
418 |
419 | out = NodeIdTensorDict[self.aog.BFS[0]]
420 | out = self.extra_norm_ac(out) #TODO: Why this? Analyze it in depth
421 | return out
422 |
423 | ### AOGNet
424 | class AOGNet(nn.Module):
425 | def __init__(self, block=AOGBlock):
426 | super(AOGNet, self).__init__()
427 | filter_list = cfg.aognet.filter_list
428 | self.aogs = self._create_aogs()
429 | self.block = block
430 | if "BatchNorm2d" in cfg.norm_name:
431 | self.norm_name_base = "BatchNorm2d"
432 | elif "GroupNorm" in cfg.norm_name:
433 | self.norm_name_base = "GroupNorm"
434 | else:
435 | raise ValueError("Unknown norm layer")
436 |
437 | if "Mixture" in cfg.norm_name:
438 | assert len(cfg.norm_k) == len(filter_list)-1 and any(cfg.norm_k), \
439 | "Wrong mixture component specification (cfg.norm_k)"
440 | else:
441 | cfg.norm_k = [0 for i in range(len(filter_list)-1)]
442 |
443 | self.stem = self._stem(filter_list[0])
444 |
445 | self.stage0 = self._make_stage(stage=0, in_channels=filter_list[0], out_channels=filter_list[1])
446 | self.stage1 = self._make_stage(stage=1, in_channels=filter_list[1], out_channels=filter_list[2])
447 | self.stage2 = self._make_stage(stage=2, in_channels=filter_list[2], out_channels=filter_list[3])
448 | self.stage3 = None
449 | outchannels = filter_list[3]
450 | if cfg.dataset == 'imagenet':
451 | self.stage3 = self._make_stage(stage=3, in_channels=filter_list[3], out_channels=filter_list[4])
452 | outchannels = filter_list[4]
453 |
454 | self.conv_head = None
455 | if any(cfg.aognet.out_channels):
456 | assert len(cfg.aognet.out_channels) == 2
457 | self.conv_head = nn.Sequential(Conv_Norm_AC(outchannels, cfg.aognet.out_channels[0], 1, 1, 0,
458 | ac_mode=cfg.activation_mode,
459 | norm_name=self.norm_name_base,
460 | norm_groups=cfg.norm_groups,
461 | norm_k=cfg.norm_k[-1],
462 | norm_attention_mode=cfg.norm_attention_mode),
463 | nn.AdaptiveAvgPool2d((1, 1)),
464 | Conv_Norm_AC(cfg.aognet.out_channels[0], cfg.aognet.out_channels[1], 1, 1, 0,
465 | ac_mode=cfg.activation_mode,
466 | norm_name=self.norm_name_base,
467 | norm_groups=cfg.norm_groups,
468 | norm_k=cfg.norm_k[-1],
469 | norm_attention_mode=cfg.norm_attention_mode)
470 | )
471 | outchannels = cfg.aognet.out_channels[1]
472 | else:
473 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
474 | self.fc = nn.Linear(outchannels, cfg.num_classes)
475 |
476 | ## initialize
477 | self._init_params()
478 |
479 | def _stem(self, cout):
480 | layers = []
481 | if cfg.dataset == 'imagenet':
482 | if cfg.stem.imagenet_head7x7:
483 | layers.append( Conv_Norm_AC(3, cout, 7, 2, 3,
484 | ac_mode=cfg.activation_mode,
485 | norm_name=self.norm_name_base,
486 | norm_groups=cfg.norm_groups,
487 | norm_k=cfg.norm_k[0],
488 | norm_attention_mode=cfg.norm_attention_mode) )
489 | else:
490 | plane = cout // 2
491 | layers.append( Conv_Norm_AC(3, plane, 3, 2, 1,
492 | ac_mode=cfg.activation_mode,
493 | norm_name=self.norm_name_base,
494 | norm_groups=cfg.norm_groups,
495 | norm_k=cfg.norm_k[0],
496 | norm_attention_mode=cfg.norm_attention_mode) )
497 | layers.append( Conv_Norm_AC(plane, plane, 3, 1, 1,
498 | ac_mode=cfg.activation_mode,
499 | norm_name=self.norm_name_base,
500 | norm_groups=cfg.norm_groups,
501 | norm_k=cfg.norm_k[0],
502 | norm_attention_mode=cfg.norm_attention_mode) )
503 | layers.append( Conv_Norm_AC(plane, cout, 3, 1, 1,
504 | ac_mode=cfg.activation_mode,
505 | norm_name=self.norm_name_base,
506 | norm_groups=cfg.norm_groups,
507 | norm_k=cfg.norm_k[0],
508 | norm_attention_mode=cfg.norm_attention_mode) )
509 | if cfg.stem.replace_maxpool_with_res_bottleneck:
510 | layers.append( NodeOpSingleScale(cout, cout, 2,
511 | ac_mode=cfg.activation_mode,
512 | bn_ratio=cfg.aognet.bottleneck_ratio,
513 | norm_name=self.norm_name_base,
514 | norm_groups=cfg.norm_groups,
515 | norm_k = cfg.norm_k[0],
516 | norm_attention_mode=cfg.norm_attention_mode,
517 | replace_stride_with_avgpool=True) ) # used in OctConv
518 | else:
519 | layers.append( nn.MaxPool2d(2, 2) )
520 | elif cfg.dataset == 'cifar10' or cfg.dataset == 'cifar100':
521 | layers.append( Conv_Norm_AC(3, cout, 3, 1, 1,
522 | ac_mode=cfg.activation_mode,
523 | norm_name=self.norm_name_base,
524 | norm_groups=cfg.norm_groups,
525 | norm_k=cfg.norm_k[0],
526 | norm_attention_mode=cfg.norm_attention_mode) )
527 | else:
528 | raise NotImplementedError
529 |
530 | return nn.Sequential(*layers)
531 |
532 | def _init_params(self):
533 | for m in self.modules():
534 | if isinstance(m, nn.Conv2d):
535 | if cfg.init_mode == 'xavier':
536 | nn.init.xavier_normal_(m.weight)
537 | elif cfg.init_mode == 'avg':
538 | n = m.kernel_size[0] * m.kernel_size[1] * (m.in_channels + m.out_channels) / 2
539 | m.weight.data.normal_(0, math.sqrt(2. / n))
540 | else: # cfg.init_mode == 'kaiming': as default
541 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
542 |
543 | for name, _ in m.named_parameters():
544 | if name in ['bias']:
545 | nn.init.constant_(m.bias, 0.0)
546 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)): # before BatchNorm2d
547 | nn.init.normal_(m.weight_, 1, 0.1)
548 | nn.init.normal_(m.bias_, 0, 0.1)
549 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
550 | nn.init.constant_(m.weight, 1.0)
551 | nn.init.constant_(m.bias, 0.0)
552 |
553 | # handle dbl cnt in init
554 | if cfg.aognet.handle_dbl_cnt_in_param_init:
555 | import re
556 | for name_, m in self.named_modules():
557 | if 'node' in name_:
558 | idx = re.findall(r'\d+', name_)
559 | sid = int(idx[0])
560 | nid = int(idx[2])
561 | npaths = self.aogs[sid].node_set[nid].npaths
562 | if npaths > 1:
563 | scale = 1.0 / npaths
564 | with torch.no_grad():
565 | for ch in m.modules():
566 | if isinstance(ch, nn.Conv2d):
567 | ch.weight.mul_(scale)
568 |
569 | # TODO: handle zero-gamma in the last norm layer of bottleneck op
570 |
571 | def _create_aogs(self):
572 | aogs = []
573 | num_stages = len(cfg.aognet.filter_list) - 1
574 | for i in range(num_stages):
575 | grid_ht = 1
576 | grid_wd = int(cfg.aognet.dims[i])
577 | aogs.append(get_aog(grid_ht=grid_ht, grid_wd=grid_wd, max_split=cfg.aognet.max_split[i],
578 | use_tnode_topdown_connection= cfg.aognet.extra_node_hierarchy[i] == 1,
579 | use_tnode_bottomup_connection_layerwise= cfg.aognet.extra_node_hierarchy[i] == 2,
580 | use_tnode_bottomup_connection_sequential= cfg.aognet.extra_node_hierarchy[i] == 3,
581 | use_node_lateral_connection= cfg.aognet.extra_node_hierarchy[i] == 4,
582 | use_tnode_bottomup_connection= cfg.aognet.extra_node_hierarchy[i] == 5,
583 | use_node_lateral_connection_1= cfg.aognet.extra_node_hierarchy[i] == 6,
584 | remove_symmetric_children_of_or_node=cfg.aognet.remove_symmetric_children_of_or_node[i]
585 | ))
586 |
587 | return aogs
588 |
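    # _make_stage grows channels linearly across a stage's blocks, rounding
    # the per-block increment to a multiple of the AOG grid dimension so every
    # intermediate width stays divisible by `dim`. Worked example (values
    # illustrative): in_channels=256, out_channels=512, blocks=3 gives
    # step_channels=85; with dim=4 the candidates are low=84 and high=88, and
    # since 85-84 <= 88-85 the increment rounds down to 84.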
589 | def _make_stage(self, stage, in_channels, out_channels):
590 | blocks = nn.Sequential()
591 | dim = cfg.aognet.dims[stage]
592 | assert in_channels % dim == 0 and out_channels % dim == 0
593 | step_channels = (out_channels - in_channels) // cfg.aognet.blocks[stage]
594 | if step_channels % dim != 0:
595 | low = (step_channels // dim) * dim
596 | high = (step_channels // dim + 1) * dim
597 | if (step_channels-low) <= (high-step_channels):
598 | step_channels = low
599 | else:
600 | step_channels = high
601 |
602 | aog = self.aogs[stage]
603 | for j in range(cfg.aognet.blocks[stage]):
604 | name_ = 'stage_{}_block_{}'.format(stage, j)
605 | drop_rate = cfg.aognet.drop_rate[stage]
606 | stride = cfg.aognet.stride[stage] if j==0 else 1
607 | outchannels = (in_channels + step_channels) if j < cfg.aognet.blocks[stage]-1 else out_channels
608 | if stride > 1 and cfg.aognet.when_downsample == 1:
609 | blocks.add_module(name_ + '_transition',
610 | nn.Sequential( Conv_Norm_AC(in_channels, in_channels, 1, 1, 0,
611 | ac_mode=cfg.activation_mode,
612 | norm_name=self.norm_name_base,
613 | norm_groups=cfg.norm_groups,
614 | norm_k=cfg.norm_k[stage],
615 | norm_attention_mode=cfg.norm_attention_mode,
616 | replace_stride_with_avgpool=False),
617 | nn.AvgPool2d(kernel_size=(stride, stride), stride=stride)
618 | )
619 | )
620 | stride = 1
621 | elif (stride > 1 or in_channels != outchannels) and cfg.aognet.when_downsample == 2:
622 | trans_op = [Conv_Norm_AC(in_channels, outchannels, 1, 1, 0,
623 | ac_mode=cfg.activation_mode,
624 | norm_name=self.norm_name_base,
625 | norm_groups=cfg.norm_groups,
626 | norm_k=cfg.norm_k[stage],
627 | norm_attention_mode=cfg.norm_attention_mode,
628 | replace_stride_with_avgpool=False)]
629 | if stride > 1:
630 | trans_op.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride))
631 | blocks.add_module(name_ + '_transition', nn.Sequential(*trans_op))
632 | stride = 1
633 | in_channels = outchannels
634 |
635 | blocks.add_module(name_, self.block(stage, j, aog, in_channels, outchannels, drop_rate, stride))
636 | in_channels = outchannels
637 |
638 | return blocks
639 |
640 | def forward(self, x):
641 | y = self.stem(x)
642 |
643 | y = self.stage0(y)
644 | y = self.stage1(y)
645 | y = self.stage2(y)
646 | if self.stage3 is not None:
647 | y = self.stage3(y)
648 | if self.conv_head is not None:
649 | y = self.conv_head(y)
650 | else:
651 | y = self.avgpool(y)
652 | y = y.view(y.size(0), -1)
653 | y = self.fc(y)
654 |
655 | return y
656 |
657 | def aognet(**kwargs):
658 | '''
659 | Construct a single scale AOGNet model
660 | '''
661 | return AOGNet(**kwargs)
662 |
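# A minimal construction sketch (assumptions: the cfg defaults in
# models/config.py plus one of the shipped YAMLs; the config path is
# illustrative). The architecture is read entirely from the global cfg, so a
# config is merged first. Run from the repo root: `python -m models.aognet.aognet`
if __name__ == '__main__':
    import torch
    from models.config import cfg
    cfg.merge_from_file('configs/aognet-12m-imagenet.yaml')
    model = aognet()
    logits = model(torch.randn(2, 3, cfg.crop_size, cfg.crop_size))
    print(logits.shape)  # expected: torch.Size([2, cfg.num_classes])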
--------------------------------------------------------------------------------
/models/aognet/operator_basic.py:
--------------------------------------------------------------------------------
1 | """ RESEARCH ONLY LICENSE
2 | Copyright (c) 2018-2019 North Carolina State University.
3 | All rights reserved.
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided
5 | that the following conditions are met:
6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use
7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 | """
34 | # The system is protected via patent (pending)
35 | # Written by Tianfu Wu and Xilai Li
36 | # Contact: {tianfu_wu, xli47}@ncsu.edu
37 | from __future__ import absolute_import
38 | from __future__ import division
39 | from __future__ import print_function # force to use print as function print(args)
40 | from __future__ import unicode_literals
41 |
42 | import torch
43 | import torch.nn as nn
44 | import torch.nn.functional as F
45 |
46 | _inplace = True
47 | _norm_eps = 1e-5
48 |
49 | def to_int(x):  # round half up to the nearest integer (for non-negative x)
50 | if x - int(x) < 0.5:
51 | return int(x)
52 | else:
53 | return int(x) + 1
54 |
55 | ### Activation
56 | class AC(nn.Module):
57 | def __init__(self, mode):
58 | super(AC, self).__init__()
59 | if mode == 1:
60 | self.ac = nn.LeakyReLU(inplace=_inplace)
61 | elif mode == 2:
62 | self.ac = nn.ReLU6(inplace=_inplace)
63 | else:
64 | self.ac = nn.ReLU(inplace=_inplace)
65 |
66 | def forward(self, x):
67 | x = self.ac(x)
68 | return x
69 |
70 | ###
71 | class hsigmoid(nn.Module):  # hard sigmoid (as in MobileNetV3): relu6(x + 3) / 6
72 | def forward(self, x):
73 | out = F.relu6(x + 3, inplace=True) / 6
74 | return out
75 |
76 | ### Feature Norm
77 | def FeatureNorm(norm_name, num_channels, num_groups, num_k, attention_mode):
78 | if norm_name == "BatchNorm2d":
79 | return nn.BatchNorm2d(num_channels, eps=_norm_eps)
80 | elif norm_name == "GroupNorm":
81 | assert num_groups > 1
82 | if num_channels % num_groups != 0:
83 |             raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups))
84 |         return nn.GroupNorm(num_groups, num_channels, eps=_norm_eps)  # note: nn.GroupNorm takes (num_groups, num_channels)
85 | elif norm_name == "MixtureBatchNorm2d":
86 | assert num_k > 1
87 | return MixtureBatchNorm2d(num_channels, num_k, attention_mode)
88 | elif norm_name == "MixtureGroupNorm":
89 | assert num_groups > 1 and num_k > 1
90 | if num_channels % num_groups != 0:
91 |             raise ValueError("channels {} not divisible by groups {}".format(num_channels, num_groups))
92 | return MixtureGroupNorm(num_channels, num_groups, num_k, attention_mode)
93 | else:
94 | raise NotImplementedError("Unknown feature norm name")
95 |
96 | ### Attention weights for mixture norm
97 | class AttentionWeights(nn.Module):
98 | expansion = 2
99 | def __init__(self, attention_mode, num_channels, k,
100 | norm_name=None, norm_groups=0):
101 | super(AttentionWeights, self).__init__()
102 | #num_channels *= 2
103 | self.k = k
104 | self.avgpool = nn.AdaptiveAvgPool2d(1)
105 | layers = []
106 | if attention_mode == 0:
107 | layers = [ nn.Conv2d(num_channels, k, 1),
108 | nn.Sigmoid() ]
109 | elif attention_mode == 4:
110 | layers = [ nn.Conv2d(num_channels, k, 1),
111 | hsigmoid() ]
112 | elif attention_mode == 1:
113 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1),
114 | nn.ReLU(inplace=True),
115 | nn.Conv2d(k*self.expansion, k, 1),
116 | nn.Sigmoid() ]
117 | elif attention_mode == 2:
118 | assert norm_name is not None
119 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False),
120 | FeatureNorm(norm_name, k, norm_groups, 0, 0),
121 | hsigmoid() ]
122 | elif attention_mode == 5:
123 | assert norm_name is not None
124 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False),
125 | FeatureNorm(norm_name, k, norm_groups, 0, 0),
126 | nn.Sigmoid() ]
127 | elif attention_mode == 6:
128 | assert norm_name is not None
129 | layers = [ nn.Conv2d(num_channels, k, 1, bias=False),
130 | FeatureNorm(norm_name, k, norm_groups, 0, 0),
131 | nn.Softmax(dim=1) ]
132 | elif attention_mode == 3:
133 | assert norm_name is not None
134 | layers = [ nn.Conv2d(num_channels, k*self.expansion, 1, bias=False),
135 | FeatureNorm(norm_name, k*self.expansion, norm_groups, 0, 0),
136 | nn.ReLU(inplace=True),
137 | nn.Conv2d(k*self.expansion, k, 1, bias=False),
138 | FeatureNorm(norm_name, k, norm_groups, 0, 0),
139 | hsigmoid() ]
140 | else:
141 |             raise NotImplementedError("Unknown attention weight type")
142 | self.attention = nn.Sequential(*layers)
143 |
144 | def forward(self, x):
145 | b, c, _, _ = x.size()
146 | y = self.avgpool(x)#.view(b, c)
147 | var = torch.var(x, dim=(2, 3)).view(b, c, 1, 1)
148 | y *= (var + 1e-3).rsqrt()
149 | #y = torch.cat((y, var), dim=1)
150 | return self.attention(y).view(b, self.k)
151 |
152 |
153 | ### Mixture Norm
154 | # TODO: keep this module in FP32 always; need to figure out how to enforce that with apex
155 | class MixtureBatchNorm2d(nn.BatchNorm2d):
156 | def __init__(self, num_channels, k, attention_mode, eps=_norm_eps, momentum=0.1,
157 | track_running_stats=True):
158 | super(MixtureBatchNorm2d, self).__init__(num_channels, eps=eps,
159 | momentum=momentum, affine=False, track_running_stats=track_running_stats)
160 | self.k = k
161 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels))
162 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels))
163 |
164 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k,
165 | norm_name='BatchNorm2d')
166 |
167 | self._init_params()
168 |
169 | def _init_params(self):
170 | nn.init.normal_(self.weight_, 1, 0.1)
171 | nn.init.normal_(self.bias_, 0, 0.1)
172 |
173 | def forward(self, x):
174 | output = super(MixtureBatchNorm2d, self).forward(x)
175 | size = output.size()
176 | y = self.attention_weights(x) # bxk # or use output as attention input
177 |
178 | weight = y @ self.weight_ # bxc
179 | bias = y @ self.bias_ # bxc
180 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
181 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
182 |
183 | return weight * output + bias
184 |
185 |
186 | # Modified on top of nn.GroupNorm
187 | # TODO: keep this module in FP32 always; need to figure out how to enforce that with apex
188 | class MixtureGroupNorm(nn.Module):
189 | __constants__ = ['num_groups', 'num_channels', 'k', 'eps', 'weight',
190 | 'bias']
191 |
192 | def __init__(self, num_channels, num_groups, k, attention_mode, eps=_norm_eps):
193 | super(MixtureGroupNorm, self).__init__()
194 | self.num_groups = num_groups
195 | self.num_channels = num_channels
196 | self.k = k
197 | self.eps = eps
198 | self.affine = True
199 | self.weight_ = nn.Parameter(torch.Tensor(k, num_channels))
200 | self.bias_ = nn.Parameter(torch.Tensor(k, num_channels))
201 | self.register_parameter('weight', None)
202 | self.register_parameter('bias', None)
203 |
204 | self.attention_weights = AttentionWeights(attention_mode, num_channels, k,
205 | norm_name='GroupNorm', norm_groups=1)
206 |
207 | self.reset_parameters()
208 |
209 | def reset_parameters(self):
210 | nn.init.normal_(self.weight_, 1, 0.1)
211 | nn.init.normal_(self.bias_, 0, 0.1)
212 |
213 | def forward(self, x):
214 | output = F.group_norm(
215 | x, self.num_groups, self.weight, self.bias, self.eps)
216 | size = output.size()
217 |
218 | y = self.attention_weights(x) # TODO: use output as attention input
219 |
220 | weight = y @ self.weight_
221 | bias = y @ self.bias_
222 |
223 | weight = weight.unsqueeze(-1).unsqueeze(-1).expand(size)
224 | bias = bias.unsqueeze(-1).unsqueeze(-1).expand(size)
225 |
226 | return weight * output + bias
227 |
228 | def extra_repr(self):
229 | return '{num_groups}, {num_channels}, eps={eps}, ' \
230 | 'affine={affine}'.format(**self.__dict__)
231 |
232 |
233 |
234 |
235 |
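# A shape-only smoke test for the mixture norms above (sizes illustrative;
# attention_mode=0 is the plain sigmoid gate, so no auxiliary norm is built).
# operator_basic.py has no relative imports, so it can be run directly.
if __name__ == '__main__':
    x = torch.randn(2, 32, 8, 8)
    mbn = MixtureBatchNorm2d(32, k=4, attention_mode=0)
    mgn = MixtureGroupNorm(32, num_groups=4, k=4, attention_mode=0)
    assert mbn(x).shape == x.shape and mgn(x).shape == x.shape
    print('mixture norm smoke test passed')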
--------------------------------------------------------------------------------
/models/aognet/operator_singlescale.py:
--------------------------------------------------------------------------------
1 | """ RESEARCH ONLY LICENSE
2 | Copyright (c) 2018-2019 North Carolina State University.
3 | All rights reserved.
4 | Redistribution and use in source and binary forms, with or without modification, are permitted provided
5 | that the following conditions are met:
6 | 1. Redistributions and use are permitted for internal research purposes only, and commercial use
7 | is strictly prohibited under this license. Inquiries regarding commercial use should be directed to the
8 | Office of Research Commercialization at North Carolina State University, 919-215-7199,
9 | https://research.ncsu.edu/commercialization/contact/, commercialization@ncsu.edu .
10 | 2. Commercial use means the sale, lease, export, transfer, conveyance or other distribution to a
11 | third party for financial gain, income generation or other commercial purposes of any kind, whether
12 | direct or indirect. Commercial use also means providing a service to a third party for financial gain,
13 | income generation or other commercial purposes of any kind, whether direct or indirect.
14 | 3. Redistributions of source code must retain the above copyright notice, this list of conditions and
15 | the following disclaimer.
16 | 4. Redistributions in binary form must reproduce the above copyright notice, this list of conditions
17 | and the following disclaimer in the documentation and/or other materials provided with the
18 | distribution.
19 | 5. The names “North Carolina State University”, “NCSU” and any trade-name, personal name,
20 | trademark, trade device, service mark, symbol, image, icon, or any abbreviation, contraction or
21 | simulation thereof owned by North Carolina State University must not be used to endorse or promote
22 | products derived from this software without prior written permission. For written permission, please
23 | contact trademarks@ncsu.edu.
24 | Disclaimer: THIS SOFTWARE IS PROVIDED “AS IS” AND ANY EXPRESSED OR IMPLIED WARRANTIES,
25 | INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
26 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NORTH CAROLINA STATE UNIVERSITY BE
27 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
28 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
30 | LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
31 | OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
32 | POSSIBILITY OF SUCH DAMAGE.
33 | """
34 | # The system is protected via patent (pending)
35 | # Written by Tianfu Wu and Xilai Li
36 | # Contact: {tianfu_wu, xli47}@ncsu.edu
37 | from __future__ import absolute_import
38 | from __future__ import division
39 | from __future__ import print_function # force to use print as function print(args)
40 | from __future__ import unicode_literals
41 |
42 | import torch
43 | import torch.nn as nn
44 | import torch.nn.functional as F
45 | from torch.autograd import Variable
46 |
47 | from .operator_basic import *
48 |
49 | _bias = False
50 | _inplace = True
51 |
52 | ### Conv_Norm
53 | class Conv_Norm(nn.Module):
54 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
55 | groups=1, drop_rate=0.0,
56 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0,
57 | replace_stride_with_avgpool=False):
58 | super(Conv_Norm, self).__init__()
59 |
60 | layers = []
61 | if stride > 1 and replace_stride_with_avgpool:
62 | layers.append(nn.AvgPool2d(kernel_size=(stride, stride), stride=stride))
63 | stride = 1
64 | layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
65 | stride=stride, padding=padding,
66 | groups=groups, bias=_bias))
67 | layers.append(FeatureNorm(norm_name, out_channels, norm_groups, norm_k, norm_attention_mode))
68 | if drop_rate > 0.0:
69 | layers.append(nn.Dropout2d(p=drop_rate, inplace=_inplace))
70 | self.conv_norm = nn.Sequential(*layers)
71 |
72 | def forward(self, x):
73 | y = self.conv_norm(x)
74 | return y
75 |
76 | ### Conv_Norm_AC
77 | class Conv_Norm_AC(nn.Module):
78 | def __init__(self, in_channels, out_channels, kernel_size, stride, padding,
79 | groups=1, drop_rate=0., ac_mode=0,
80 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0,
81 | replace_stride_with_avgpool=False):
82 | super(Conv_Norm_AC, self).__init__()
83 |
84 | self.conv_norm = Conv_Norm(in_channels, out_channels, kernel_size, stride, padding,
85 | groups=groups, drop_rate=drop_rate,
86 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode,
87 | replace_stride_with_avgpool=replace_stride_with_avgpool)
88 | self.ac = AC(ac_mode)
89 |
90 | def forward(self, x):
91 | y = self.conv_norm(x)
92 | y = self.ac(y)
93 | return y
94 |
95 | ### NodeOpSingleScale
96 | class NodeOpSingleScale(nn.Module):
97 | def __init__(self, in_channels, out_channels, stride,
98 | groups=1, drop_rate=0., ac_mode=0, bn_ratio=0.25,
99 | norm_name='BatchNorm2d', norm_groups=0, norm_k=0, norm_attention_mode=0,
100 | replace_stride_with_avgpool=True):
101 | super(NodeOpSingleScale, self).__init__()
102 | if "BatchNorm2d" in norm_name:
103 | norm_name_base = "BatchNorm2d"
104 | elif "GroupNorm" in norm_name:
105 | norm_name_base = "GroupNorm"
106 | else:
107 | raise ValueError("Unknown norm layer")
108 |
109 | mid_channels = max(4, to_int(out_channels * bn_ratio / groups) * groups)
110 | self.conv_norm_ac_1 = Conv_Norm_AC(in_channels, mid_channels, 1, 1, 0,
111 | ac_mode=ac_mode,
112 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode)
113 | self.conv_norm_ac_2 = Conv_Norm_AC(mid_channels, mid_channels, 3, stride, 1,
114 | groups=groups, ac_mode=ac_mode,
115 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode,
116 | replace_stride_with_avgpool=False)
117 | self.conv_norm_3 = Conv_Norm(mid_channels, out_channels, 1, 1, 0,
118 | drop_rate=drop_rate,
119 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode)
120 |
121 | self.shortcut = None
122 | if in_channels != out_channels or stride > 1:
123 | self.shortcut = Conv_Norm(in_channels, out_channels, 1, stride, 0,
124 | norm_name=norm_name_base, norm_groups=norm_groups, norm_k=norm_k, norm_attention_mode=norm_attention_mode,
125 | replace_stride_with_avgpool=replace_stride_with_avgpool)
126 |
127 | self.ac = AC(ac_mode)
128 |
129 | def forward(self, x, res=None):
130 | residual = x if res is None else res
131 | y = self.conv_norm_ac_1(x)
132 | y = self.conv_norm_ac_2(y)
133 | y = self.conv_norm_3(y)
134 |
135 | if self.shortcut is not None:
136 | residual = self.shortcut(residual)
137 |
138 | y += residual
139 | y = self.ac(y)
140 | return y
141 |
142 | ### TODO: write a unit test for NodeOpSingleScale in a standalone way
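# A standalone smoke-test sketch along the lines of the TODO above (shape
# check only; channel/stride values illustrative). Run from the repo root:
# `python -m models.aognet.operator_singlescale`
if __name__ == '__main__':
    op = NodeOpSingleScale(32, 64, stride=2)
    x = torch.randn(2, 32, 16, 16)
    y = op(x)
    assert y.shape == (2, 64, 8, 8)  # the 3x3 stride-2 conv halves the spatial size
    print('NodeOpSingleScale smoke test passed')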
143 |
144 |
145 |
--------------------------------------------------------------------------------
/models/config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _C = CN()
4 | _C.arch = 'aognet'
5 | _C.batch_size = 128
6 | _C.num_epoch = 300
7 | _C.dataset = 'cifar10'
8 | _C.num_classes = 10
9 | _C.crop_size = 224 # imagenet
10 | _C.crop_interpolation = 2 # 2=BILINEAR, default; 3=BICUBIC
11 | _C.optimizer = 'SGD'
12 | _C.gamma = 0.1 # decay_rate
13 | _C.use_cosine_lr = False
14 | _C.cosine_lr_min = 0.0
15 | _C.warmup_epochs = 5
16 | _C.lr = 0.1
17 | _C.lr_scale_factor = 256 # linear LR scaling divisor: lr = lr * global_batch / 256 (per the NVIDIA apex example)
18 | _C.lr_milestones = [150, 225]
19 | _C.momentum = 0.9
20 | _C.wd = 5e-4
21 | _C.nesterov = False
22 | _C.activation_mode = 0 # 1: leakyReLU, 2: ReLU6 , other: ReLU
23 | _C.init_mode = 'kaiming'
24 | _C.norm_name = 'BatchNorm2d'
25 | _C.norm_groups = 0
26 | _C.norm_k = [0]
27 | _C.norm_attention_mode = 0
28 | _C.norm_zero_gamma_init = False
29 | _C.norm_all_mix = False
30 |
31 | # data augmentation
32 | _C.dataaug = CN()
33 | _C.dataaug.imagenet_extra_aug = False
34 | _C.dataaug.labelsmoothing_rate = 0. # 0.1
35 | _C.dataaug.mixup_rate = 0. # 0.2
36 |
37 | # stem
38 | _C.stem = CN()
39 | _C.stem.imagenet_head7x7 = False
40 | _C.stem.replace_maxpool_with_res_bottleneck = False
41 | _C.stem.stem_kernel_size = 7
42 | _C.stem.stem_stride = 2
43 |
44 | # resnet
45 | _C.resnet = CN()
46 | _C.resnet.base_inplanes = 16
47 | _C.resnet.replace_stride_with_dilation = [False, False, False]
48 | _C.resnet.replace_stride_with_avgpool = False
49 | _C.resnet.extra_norm_ac = False
50 |
51 | # mobilenet
52 | _C.mobilenet = CN()
53 | _C.mobilenet.rm_se = False # for mobilenetv3 with mixture norm
54 | _C.mobilenet.use_mn_in_se = False
55 |
56 | # aognet
57 | _C.aognet = CN()
58 | _C.aognet.filter_list = [16, 64, 128, 256]
59 | _C.aognet.out_channels = [0,0]
60 | _C.aognet.blocks = [1, 1, 1]
61 | _C.aognet.dims = [4, 4, 4]
62 | _C.aognet.max_split = [2, 2, 2] # must be >= 2
63 | _C.aognet.extra_node_hierarchy = [0, 0, 0] # 0: none, 1: tnode topdown, 2: tnode bottomup layerwise, 3: tnode bottomup sequential, 4: non-term node lateral, 5: tnode bottomup, 6: non-term node lateral (variant 1, cf. AOG.py)
64 | _C.aognet.remove_symmetric_children_of_or_node = [0, 0, 0]
65 | _C.aognet.terminal_node_no_slice = [0, 0, 0]
66 | _C.aognet.stride = [1, 2, 2]
67 | _C.aognet.drop_rate = [0.0, 0.0, 0.0]
68 | _C.aognet.bottleneck_ratio = 0.25
69 | _C.aognet.handle_dbl_cnt = True
70 | _C.aognet.handle_tnode_dbl_cnt = False
71 | _C.aognet.handle_dbl_cnt_in_param_init = False
72 | _C.aognet.use_group_conv = False
73 | _C.aognet.width_per_group = 0
74 | _C.aognet.when_downsample = 0 # 0: at T-nodes; 1: before an AOG block, by conv_norm_ac + avgpool; 2: 1x1 conv transition (plus avgpool when strided) whenever stride or width changes
75 | _C.aognet.replace_stride_with_avgpool = True # for downsample in node op.
76 | _C.aognet.use_elem_max_for_ORNodes = False
77 |
78 | cfg = _C
79 |
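# A short usage sketch (standard yacs pattern; the YAML path and override
# values are illustrative). Run from the repo root, e.g. `python -m models.config`:
if __name__ == '__main__':
    cfg.merge_from_file('configs/resnet-50-an-imagenet.yaml')  # YAML overrides
    cfg.merge_from_list(['batch_size', 64, 'lr', 0.05])        # key/value overrides
    print(cfg.arch, cfg.norm_name, cfg.norm_k)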
--------------------------------------------------------------------------------
/models/mobilenet.py:
--------------------------------------------------------------------------------
1 | # from https://raw.githubusercontent.com/pytorch/vision/master/torchvision/models/mobilenet.py
2 | # Modified by Tianfu Wu
3 | # Contact: tianfu_wu@ncsu.edu
4 |
5 | from torch import nn
6 | from torch.hub import load_state_dict_from_url  # used by the pretrained branch in mobilenet_v2()
7 | from .aognet.operator_basic import FeatureNorm, MixtureBatchNorm2d, MixtureGroupNorm
8 | from .config import cfg
9 |
10 | __all__ = ['MobileNetV2', 'mobilenet_v2']
11 |
12 |
13 | class ConvBNReLU(nn.Sequential):
14 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1,
15 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0):
16 | if norm_name is None:
17 | norm_name = "BatchNorm2d"
18 | padding = (kernel_size - 1) // 2
19 | super(ConvBNReLU, self).__init__(
20 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False),
21 | #nn.BatchNorm2d(out_planes),
22 | FeatureNorm(norm_name, out_planes,
23 | num_groups=norm_groups, num_k=norm_k,
24 | attention_mode=norm_attention_mode),
25 | nn.ReLU6(inplace=True)
26 | )
27 |
28 |
29 | class InvertedResidual(nn.Module):
30 | def __init__(self, inp, oup, stride, expand_ratio,
31 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0):
32 | super(InvertedResidual, self).__init__()
33 | self.stride = stride
34 | assert stride in [1, 2]
35 | if norm_name is None:
36 | norm_name = "BatchNorm2d"
37 | if "BatchNorm2d" in norm_name:
38 | norm_name_base = "BatchNorm2d"
39 | elif "GroupNorm" in norm_name:
40 | norm_name_base = "GroupNorm"
41 | else:
42 | raise ValueError("Unknown norm.")
43 |
44 | hidden_dim = int(round(inp * expand_ratio))
45 | self.use_res_connect = self.stride == 1 and inp == oup
46 |
47 | layers = []
48 | if expand_ratio != 1:
49 | # pw
50 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1,
51 | norm_name=norm_name_base, norm_groups=norm_groups,
52 | norm_k=norm_k, norm_attention_mode=norm_attention_mode))
53 | layers.extend([
54 | # dw
55 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim,
56 | norm_name=norm_name, norm_groups=norm_groups,
57 | norm_k=norm_k, norm_attention_mode=norm_attention_mode),
58 | # pw-linear
59 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
60 | #nn.BatchNorm2d(oup),
61 | FeatureNorm(norm_name_base, oup,
62 | num_groups=norm_groups, num_k=norm_k,
63 | attention_mode=norm_attention_mode),
64 | ])
65 | self.conv = nn.Sequential(*layers)
66 |
67 | def forward(self, x):
68 | if self.use_res_connect:
69 | return x + self.conv(x)
70 | else:
71 | return self.conv(x)
72 |
73 |
74 | class MobileNetV2(nn.Module):
75 | def __init__(self, num_classes=1000, width_mult=1.0):
76 | super(MobileNetV2, self).__init__()
77 | block = InvertedResidual
78 | input_channel = 32
79 | last_channel = 1280
80 | inverted_residual_setting = [
81 | # t, c, n, s
82 | [1, 16, 1, 1],
83 | [6, 24, 2, 2],
84 | [6, 32, 3, 2],
85 | [6, 64, 4, 2],
86 | [6, 96, 3, 1],
87 | [6, 160, 3, 2],
88 | [6, 320, 1, 1],
89 | ]
90 | norm_name = cfg.norm_name
91 | norm_groups = cfg.norm_groups
92 | norm_k = cfg.norm_k
93 | norm_attention_mode = cfg.norm_attention_mode
94 | if norm_name is None:
95 | norm_name = "BatchNorm2d"
96 | if "BatchNorm2d" in norm_name:
97 | norm_name_base = "BatchNorm2d"
98 | elif "GroupNorm" in norm_name:
99 | norm_name_base = "GroupNorm"
100 | else:
101 | raise ValueError("Unknown norm.")
102 |
103 | # building first layer
104 | input_channel = int(input_channel * width_mult)
105 | self.last_channel = int(last_channel * max(1.0, width_mult))
106 | features = [ConvBNReLU(3, input_channel, stride=2,
107 | norm_name=norm_name_base, norm_groups=norm_groups,
108 | norm_k=-1, norm_attention_mode=norm_attention_mode)]
109 | # building inverted residual blocks
110 | for j, (t, c, n, s) in enumerate(inverted_residual_setting):
111 | output_channel = int(c * width_mult)
112 | for i in range(n):
113 | stride = s if i == 0 else 1
114 | features.append(block(input_channel, output_channel, stride, expand_ratio=t,
115 | norm_name=norm_name, norm_groups=norm_groups,
116 | norm_k=norm_k[j], norm_attention_mode=norm_attention_mode))
117 | input_channel = output_channel
118 | # building last several layers
119 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1,
120 | norm_name=norm_name_base, norm_groups=norm_groups,
121 | norm_k=-1, norm_attention_mode=norm_attention_mode))
122 | # make it nn.Sequential
123 | self.features = nn.Sequential(*features)
124 |
125 | # building classifier
126 | self.classifier = nn.Sequential(
127 | nn.Dropout(0.2),
128 | nn.Linear(self.last_channel, num_classes),
129 | )
130 |
131 | # weight initialization
132 | for m in self.modules():
133 | if isinstance(m, nn.Conv2d):
134 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
135 | if m.bias is not None:
136 | nn.init.zeros_(m.bias)
137 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)):
138 | nn.init.normal_(m.weight_, 1, 0.1)
139 | nn.init.normal_(m.bias_, 0, 0.1)
140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
141 | nn.init.ones_(m.weight)
142 | nn.init.zeros_(m.bias)
143 | elif isinstance(m, nn.Linear):
144 | nn.init.normal_(m.weight, 0, 0.01)
145 | nn.init.zeros_(m.bias)
146 |
147 | def forward(self, x):
148 | x = self.features(x)
149 | x = x.mean([2, 3])
150 | x = self.classifier(x)
151 | return x
152 |
153 |
154 | def mobilenet_v2(pretrained=False, progress=True, **kwargs):
155 | """
156 | Constructs a MobileNetV2 architecture from
157 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
158 |
159 | Args:
160 | pretrained (bool): If True, returns a model pre-trained on ImageNet
161 | progress (bool): If True, displays a progress bar of the download to stderr
162 | """
163 |     model = MobileNetV2(**kwargs)
164 |     if pretrained:  # stock torchvision weights; they only fit the plain-BatchNorm2d configuration
165 |         from torchvision.models.mobilenet import model_urls  # torchvision's URL table
166 |         state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'], progress=progress)
167 |         model.load_state_dict(state_dict)
168 |     return model
169 |
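# Sketch: selecting the mixture-norm variant through cfg (K values are
# illustrative). Each inverted-residual row j in MobileNetV2 indexes
# cfg.norm_k[j], so a Mixture norm needs at least seven entries here.
# Run from the repo root: `python -m models.mobilenet`
if __name__ == '__main__':
    cfg.norm_name = 'MixtureBatchNorm2d'
    cfg.norm_k = [5, 5, 5, 5, 5, 5, 5]
    model = mobilenet_v2(num_classes=1000)
    print('%.2fM params' % (sum(p.numel() for p in model.parameters()) / 1e6))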
--------------------------------------------------------------------------------
/models/resnet.py:
--------------------------------------------------------------------------------
1 | # Modified from: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
2 | # By Tianfu Wu
3 | # Contact: tianfu_wu@ncsu.edu
4 | import torch.nn as nn
5 | from torch.hub import load_state_dict_from_url  # used by _resnet() when pretrained=True
6 | from .aognet.operator_basic import FeatureNorm, MixtureBatchNorm2d, MixtureGroupNorm
7 | from .config import cfg
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 'resnext101_64x4d']
11 |
12 |
13 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
14 | """3x3 convolution with padding"""
15 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
16 | padding=dilation, groups=groups, bias=False, dilation=dilation)
17 |
18 |
19 | def conv1x1(in_planes, out_planes, stride=1):
20 | """1x1 convolution"""
21 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
22 |
23 |
24 | class BasicBlock(nn.Module):
25 | expansion = 1
26 |
27 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
28 | base_width=64, dilation=1,
29 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0):
30 | super(BasicBlock, self).__init__()
31 | if norm_name is None:
32 | norm_name = "BatchNorm2d"
33 | if groups != 1 or base_width != 64:
34 | raise ValueError(
35 | 'BasicBlock only supports groups=1 and base_width=64')
36 | if dilation > 1:
37 | raise NotImplementedError(
38 | "Dilation > 1 not supported in BasicBlock")
39 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
40 | self.conv1 = conv3x3(inplanes, planes, stride)
41 | self.bn1 = FeatureNorm(norm_name, planes,
42 | num_groups=norm_groups, num_k=norm_k,
43 | attention_mode=norm_attention_mode)
44 | self.relu = nn.ReLU(inplace=True)
45 | self.conv2 = conv3x3(planes, planes)
46 | self.bn2 = FeatureNorm(norm_name, planes,
47 | num_groups=norm_groups, num_k=norm_k,
48 | attention_mode=norm_attention_mode)
49 | self.downsample = downsample
50 | self.stride = stride
51 |
52 | def forward(self, x):
53 | identity = x
54 |
55 | out = self.conv1(x)
56 | out = self.bn1(out)
57 | out = self.relu(out)
58 |
59 | out = self.conv2(out)
60 | out = self.bn2(out)
61 |
62 | if self.downsample is not None:
63 | identity = self.downsample(x)
64 |
65 | out += identity
66 | out = self.relu(out)
67 |
68 | return out
69 |
70 |
71 | class Bottleneck(nn.Module):
72 | expansion = 4
73 |
74 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
75 | base_width=64, dilation=1,
76 | norm_name=None, norm_groups=0, norm_k=0, norm_attention_mode=0,
77 | norm_all_mix=False):
78 | super(Bottleneck, self).__init__()
79 | if norm_name is None:
80 | norm_name = "BatchNorm2d"
81 | if norm_all_mix:
82 | norm_name_base = norm_name
83 | else:
84 | if "BatchNorm2d" in norm_name:
85 | norm_name_base = "BatchNorm2d"
86 | elif "GroupNorm" in norm_name:
87 | norm_name_base = "GroupNorm"
88 | else:
89 | raise ValueError("Unknown norm.")
90 | width = int(planes * (base_width / 64.)) * groups
91 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
92 | self.conv1 = conv1x1(inplanes, width)
93 | self.bn1 = FeatureNorm(norm_name_base, width,
94 | num_groups=norm_groups, num_k=norm_k,
95 | attention_mode=norm_attention_mode)
96 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
97 | self.bn2 = FeatureNorm(norm_name, width,
98 | num_groups=norm_groups, num_k=norm_k,
99 | attention_mode=norm_attention_mode)
100 | self.conv3 = conv1x1(width, planes * self.expansion)
101 | self.bn3 = FeatureNorm(norm_name_base, planes * self.expansion,
102 | num_groups=norm_groups, num_k=norm_k,
103 | attention_mode=norm_attention_mode)
104 | self.relu = nn.ReLU(inplace=True)
105 | self.downsample = downsample
106 | self.stride = stride
107 |
108 | def forward(self, x):
109 | identity = x
110 |
111 | out = self.conv1(x)
112 | out = self.bn1(out)
113 | out = self.relu(out)
114 |
115 | out = self.conv2(out)
116 | out = self.bn2(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv3(out)
120 | out = self.bn3(out)
121 |
122 | if self.downsample is not None:
123 | identity = self.downsample(x)
124 |
125 | out += identity
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 |
131 | class ResNet(nn.Module):
132 | def __init__(self, block, layers, groups=1, width_per_group=64):
133 | super(ResNet, self).__init__()
134 | replace_stride_with_dilation = cfg.resnet.replace_stride_with_dilation
135 | base_inplanes = cfg.resnet.base_inplanes
136 | norm_name = cfg.norm_name
137 | norm_groups = cfg.norm_groups
138 | norm_attention_mode = cfg.norm_attention_mode
139 | norm_all_mix = cfg.norm_all_mix
140 | replace_stride_with_avgpool = cfg.resnet.replace_stride_with_avgpool
141 | self.norm_name = norm_name
142 | self.norm_groups = norm_groups
143 | self.norm_ks = cfg.norm_k
144 | self.norm_attention_mode = norm_attention_mode
145 | self.norm_all_mix = norm_all_mix
146 | if norm_all_mix:
147 | self.norm_name_base = norm_name
148 | else:
149 | if "BatchNorm2d" in norm_name:
150 | self.norm_name_base = "BatchNorm2d"
151 | elif "GroupNorm" in norm_name:
152 | self.norm_name_base = "GroupNorm"
153 | else:
154 | raise ValueError("Unknown norm layer")
155 | if "Mixture" in norm_name:
156 | assert len(self.norm_ks) == len(layers) and any(self.norm_ks), \
157 | "Wrong mixture component specification (cfg.norm_k)"
158 | else:
159 | self.norm_ks = [0 for i in range(len(layers))]
160 |
161 | self.inplanes = base_inplanes
162 | self.extra_norm_ac = cfg.resnet.extra_norm_ac
163 | self.dilation = 1
164 | self.norm_k = self.norm_ks[0]
165 | if replace_stride_with_dilation is None:
166 | # each element in the tuple indicates if we should replace
167 | # the 2x2 stride with a dilated convolution instead
168 | replace_stride_with_dilation = [False, False, False]
169 | if len(replace_stride_with_dilation) != 3:
170 | raise ValueError("replace_stride_with_dilation should be None "
171 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
172 | self.groups = groups
173 | self.base_width = width_per_group
174 | if cfg.stem.imagenet_head7x7:
175 |         self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=cfg.stem.stem_kernel_size,
176 |                                stride=cfg.stem.stem_stride,
177 |                                padding=(cfg.stem.stem_kernel_size - 1) // 2,
178 |                                bias=False)
179 | self.bn1 = FeatureNorm(self.norm_name_base, self.inplanes,
180 | num_groups=norm_groups, num_k=self.norm_k,
181 | attention_mode=norm_attention_mode)
182 | self.relu = nn.ReLU(inplace=True)
183 | self.maxpool = nn.MaxPool2d(
184 | kernel_size=3, stride=2, padding=1) if cfg.dataset == 'imagenet' else None
185 | else:
186 | plane = self.inplanes // 2
187 | self.conv1 = nn.Sequential(
188 | nn.Conv2d(3, plane, kernel_size=3,
189 | stride=2, padding=1, bias=False),
190 | FeatureNorm(self.norm_name_base, plane,
191 | num_groups=norm_groups, num_k=self.norm_k,
192 | attention_mode=norm_attention_mode),
193 | nn.ReLU(inplace=True),
194 | nn.Conv2d(plane, plane, kernel_size=3,
195 | stride=1, padding=1, bias=False),
196 | FeatureNorm(self.norm_name_base, plane,
197 | num_groups=norm_groups, num_k=self.norm_k,
198 | attention_mode=norm_attention_mode),
199 | nn.ReLU(inplace=True),
200 | nn.Conv2d(plane, self.inplanes, kernel_size=3,
201 | stride=1, padding=1, bias=False)
202 | )
203 | self.bn1 = FeatureNorm(self.norm_name_base, self.inplanes,
204 | num_groups=norm_groups, num_k=self.norm_k,
205 | attention_mode=norm_attention_mode)
206 | self.relu = nn.ReLU(inplace=True)
207 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
208 |
209 | self.layer1 = self._make_layer(block, base_inplanes, layers[0],
210 | replace_stride_with_avgpool=replace_stride_with_avgpool)
211 | self.norm_k = self.norm_ks[1]
212 | self.layer2 = self._make_layer(block, base_inplanes*2, layers[1], stride=2,
213 | dilate=replace_stride_with_dilation[0],
214 | replace_stride_with_avgpool=replace_stride_with_avgpool)
215 | self.norm_k = self.norm_ks[2]
216 | self.layer3 = self._make_layer(block, base_inplanes*4, layers[2], stride=2,
217 | dilate=replace_stride_with_dilation[1],
218 | replace_stride_with_avgpool=replace_stride_with_avgpool)
219 | self.layer4 = None
220 | outplanes = base_inplanes*4*block.expansion
221 | if len(layers) > 3:
222 | self.norm_k = self.norm_ks[3]
223 | self.layer4 = self._make_layer(block, base_inplanes*8, layers[3], stride=2,
224 | dilate=replace_stride_with_dilation[2],
225 | replace_stride_with_avgpool=replace_stride_with_avgpool)
226 | outplanes = base_inplanes*8*block.expansion
227 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
228 | self.fc = nn.Linear(outplanes, cfg.num_classes)
229 |
230 | for m in self.modules():
231 | if isinstance(m, nn.Conv2d):
232 | nn.init.kaiming_normal_(
233 | m.weight, mode='fan_out', nonlinearity='relu')
234 |                 # zero the conv bias if present (convs here are bias-free by default)
235 |                 if m.bias is not None:
236 |                     nn.init.constant_(m.bias, 0)
237 | elif isinstance(m, (MixtureBatchNorm2d, MixtureGroupNorm)):
238 | nn.init.normal_(m.weight_, 1, 0.1)
239 | nn.init.normal_(m.bias_, 0, 0.1)
240 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
241 | nn.init.constant_(m.weight, 1)
242 | nn.init.constant_(m.bias, 0)
243 |
244 | # Zero-initialize the last BN in each residual branch,
245 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
246 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
247 | if cfg.norm_zero_gamma_init:
248 | for m in self.modules():
249 | if isinstance(m, Bottleneck):
250 | if isinstance(m.bn3, (MixtureBatchNorm2d, MixtureGroupNorm)):
251 | nn.init.constant_(m.bn3.weight_, 0)
252 | nn.init.constant_(m.bn3.bias_, 0)
253 | else:
254 | nn.init.constant_(m.bn3.weight, 0)
255 | elif isinstance(m, BasicBlock):
256 | # TODO: handle mixture norm
257 | nn.init.constant_(m.bn2.weight, 0)
258 |
259 | def _extra_norm_ac(self, out_channels, norm_k):
260 | return nn.Sequential(FeatureNorm(self.norm_name_base, out_channels,
261 | self.norm_groups, norm_k,
262 | self.norm_attention_mode),
263 | nn.ReLU(inplace=True))
264 |
265 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False,
266 | replace_stride_with_avgpool=False):
267 | norm_name = self.norm_name
268 | norm_name_base = self.norm_name_base
269 | norm_groups = self.norm_groups
270 | norm_k = self.norm_k
271 | norm_attention_mode = self.norm_attention_mode
272 | norm_all_mix = self.norm_all_mix
273 | downsample = None
274 | previous_dilation = self.dilation
275 | if dilate:
276 | self.dilation *= stride
277 | stride = 1
278 |
279 | downsample_op = []
280 | if stride != 1 or self.inplanes != planes * block.expansion:
281 | downsample_stride = stride
282 | if replace_stride_with_avgpool and stride > 1:
283 | downsample_op.append(nn.AvgPool2d((stride, stride), stride))
284 | downsample_stride = 1
285 |
286 | downsample_op.append(
287 | conv1x1(self.inplanes, planes * block.expansion, downsample_stride))
288 | downsample_op.append(FeatureNorm(norm_name_base, planes * block.expansion,
289 | num_groups=norm_groups, num_k=norm_k,
290 | attention_mode=norm_attention_mode))
291 |
292 | if len(downsample_op) > 0:
293 | downsample = nn.Sequential(*downsample_op)
294 |
295 | layers = []
296 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
297 | base_width=self.base_width, dilation=previous_dilation,
298 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k,
299 | norm_attention_mode=norm_attention_mode,
300 | norm_all_mix=norm_all_mix))
301 | self.inplanes = planes * block.expansion
302 | for _ in range(1, blocks):
303 | layers.append(block(self.inplanes, planes, groups=self.groups,
304 | base_width=self.base_width, dilation=self.dilation,
305 | norm_name=norm_name, norm_groups=norm_groups, norm_k=norm_k,
306 | norm_attention_mode=norm_attention_mode,
307 | norm_all_mix=norm_all_mix))
308 |
309 | if self.extra_norm_ac:
310 | layers.append(self._extra_norm_ac(self.inplanes, norm_k))
311 |
312 | return nn.Sequential(*layers)
313 |
314 | def forward(self, x):
315 | x = self.conv1(x)
316 | x = self.bn1(x)
317 | x = self.relu(x)
318 | if self.maxpool is not None:
319 | x = self.maxpool(x)
320 |
321 | x = self.layer1(x)
322 | x = self.layer2(x)
323 | x = self.layer3(x)
324 | if self.layer4 is not None:
325 | x = self.layer4(x)
326 |
327 | x = self.avgpool(x)
328 | x = x.reshape(x.size(0), -1)
329 | x = self.fc(x)
330 |
331 | return x
332 |
333 |
334 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
335 |     model = ResNet(block, layers, **kwargs)
336 |     if pretrained:  # stock torchvision weights; they only fit the stock configuration
337 |         from torchvision.models.resnet import model_urls  # torchvision's URL table
338 |         state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
339 |         model.load_state_dict(state_dict)
340 |     return model
341 |
342 |
343 | def resnet18(pretrained=False, progress=True, **kwargs):
344 | """Constructs a ResNet-18 model.
345 | Args:
346 | pretrained (bool): If True, returns a model pre-trained on ImageNet
347 | progress (bool): If True, displays a progress bar of the download to stderr
348 | """
349 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
350 | **kwargs)
351 |
352 |
353 | def resnet34(pretrained=False, progress=True, **kwargs):
354 | """Constructs a ResNet-34 model.
355 | Args:
356 | pretrained (bool): If True, returns a model pre-trained on ImageNet
357 | progress (bool): If True, displays a progress bar of the download to stderr
358 | """
359 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
360 | **kwargs)
361 |
362 |
363 | def resnet50(pretrained=False, progress=True, **kwargs):
364 | """Constructs a ResNet-50 model.
365 | Args:
366 | pretrained (bool): If True, returns a model pre-trained on ImageNet
367 | progress (bool): If True, displays a progress bar of the download to stderr
368 | """
369 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
370 | **kwargs)
371 |
372 |
373 | def resnet101(pretrained=False, progress=True, **kwargs):
374 | """Constructs a ResNet-101 model.
375 | Args:
376 | pretrained (bool): If True, returns a model pre-trained on ImageNet
377 | progress (bool): If True, displays a progress bar of the download to stderr
378 | """
379 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
380 | **kwargs)
381 |
382 |
383 | def resnet152(pretrained=False, progress=True, **kwargs):
384 | """Constructs a ResNet-152 model.
385 | Args:
386 | pretrained (bool): If True, returns a model pre-trained on ImageNet
387 | progress (bool): If True, displays a progress bar of the download to stderr
388 | """
389 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
390 | **kwargs)
391 |
392 |
393 | def resnext50_32x4d(**kwargs):
394 | kwargs['groups'] = 32
395 | kwargs['width_per_group'] = 4
396 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
397 | pretrained=False, progress=True, **kwargs)
398 |
399 |
400 | def resnext101_32x8d(**kwargs):
401 | kwargs['groups'] = 32
402 | kwargs['width_per_group'] = 8
403 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
404 | pretrained=False, progress=True, **kwargs)
405 |
406 |
407 | def resnext101_64x4d(**kwargs):
408 | kwargs['groups'] = 64
409 | kwargs['width_per_group'] = 4
410 | return _resnet('resnext101_64x4d', Bottleneck, [3, 4, 23, 3],
411 | pretrained=False, progress=True, **kwargs)
412 |
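# Sketch: the attentive-norm (AN) variants are likewise driven by cfg (K values
# illustrative). With a Mixture norm, cfg.norm_k needs one entry per layer
# group -- four for resnet50 -- per the assertion in ResNet.__init__.
# Run from the repo root: `python -m models.resnet`
if __name__ == '__main__':
    cfg.num_classes = 1000
    cfg.norm_name = 'MixtureBatchNorm2d'
    cfg.norm_k = [10, 10, 10, 10]
    model = resnet50()
    print('%.2fM params' % (sum(p.numel() for p in model.parameters()) / 1e6))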
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | thop
3 | yacs
4 | scipy
5 |
--------------------------------------------------------------------------------
/scripts/test_fp16.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Usage: test_fp16.sh pretrained_model_folder
4 |
5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
6 |
7 | if [ "$#" -ne 1 ]; then
8 | echo "Usage: test_fp16.sh pretrained_model_folder"
9 |   exit 1
10 | fi
11 |
12 | PRETRAINED_MODEL_PATH=$1
13 | CONFIG_FILE=$PRETRAINED_MODEL_PATH/config.yaml
14 | PRETRAINED_MODEL_FILE=$PRETRAINED_MODEL_PATH/model_best.pth.tar
15 |
16 | ### Change accordingly
17 | GPUS=0,1,2,3,4,5,6,7
18 | NUM_GPUS=8
19 | NUM_WORKERS=8
20 | MASTER_PORT=1245
21 |
22 | # ImageNet
23 | DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/
24 |
25 | # test
26 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port $MASTER_PORT \
27 | $DIR/../tools/main_fp16.py --cfg $CONFIG_FILE --workers $NUM_WORKERS \
28 | --fp16 \
29 |   -p 100 --save-dir $PRETRAINED_MODEL_PATH --pretrained --evaluate $DATA
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/scripts/train_fp16.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Usage: train_fp16.sh config_filename
4 |
5 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
6 |
7 | if [ "$#" -ne 1 ]; then
8 | echo "Usage: train_fp16.sh relative_config_filename"
9 |   exit 1
10 | fi
11 |
12 | CONFIG_FILE=$DIR/../$1
13 |
14 | ### Change accordingly
15 | GPUS=0,1,2,3,4,5,6,7
16 | NUM_GPUS=8
17 | NUM_WORKERS=8
18 | MASTER_PORT=1234
19 |
20 | CONFIG_FILENAME="$(cut -d'/' -f2 <<<$1)"
21 | CONFIG_BASE="${CONFIG_FILENAME%.*}"
22 | SAVE_DIR=$DIR/../results/$CONFIG_BASE
23 | mkdir -p $SAVE_DIR
24 |
25 | # backup for reproducing results
26 | cp $CONFIG_FILE $SAVE_DIR/config.yaml
27 | cp -r $DIR/../models $SAVE_DIR
28 | cp $DIR/../tools/main_fp16.py $SAVE_DIR
29 |
30 | # ImageNet
31 | DATA=$DIR/../datasets/ILSVRC2015/Data/CLS-LOC/
32 |
33 | # train
34 | CUDA_VISIBLE_DEVICES=$GPUS python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPUS --master_port $MASTER_PORT \
35 | $DIR/../tools/main_fp16.py --cfg $CONFIG_FILE --workers $NUM_WORKERS \
36 | --fp16 --static-loss-scale 128 \
37 | -p 100 --save-dir $SAVE_DIR $DATA \
38 | 2>&1 | tee $SAVE_DIR/log.txt
39 |
40 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iVMCL/AOGNet-v2/a95a8696c131331607e81bb31eeae3405a76b969/tools/__init__.py
--------------------------------------------------------------------------------
/tools/main_fp16.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/NVIDIA/apex
2 |
3 | ### some tweaks
4 | # USE pillow-simd to speed up pytorch image loader
5 | # pip uninstall pillow
6 | # conda uninstall --force jpeg libtiff -y
7 | # conda install -c conda-forge libjpeg-turbo
8 | # CC="cc -mavx2" pip install --no-cache-dir -U --force-reinstall --no-binary :all: --compile pillow-simd
9 |
10 | # Install NCCL https://docs.nvidia.com/deeplearning/sdk/nccl-install-guide/index.html
11 |
12 | import argparse
13 | import os
14 | import sys
15 | import shutil
16 | import time
17 | import copy
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.parallel
22 | import torch.backends.cudnn as cudnn
23 | import torch.distributed as dist
24 | import torch.optim
25 | import torch.utils.data
26 | import torch.utils.data.distributed
27 | import torchvision.transforms as transforms
28 | import torchvision.datasets as datasets
29 |
30 | import numpy as np
31 |
32 | try:
33 | from apex.parallel import DistributedDataParallel as DDP
34 | from apex.fp16_utils import *
35 | except ImportError:
36 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
37 |
38 |
39 | import math
40 | import sys
41 | import re
42 | sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
43 | from models.aognet.operator_basic import MixtureBatchNorm2d, MixtureGroupNorm
44 | from models.aognet.aognet import aognet
45 | from models.config import cfg
46 | import models.resnet as resnets
47 | import models.mobilenet as mobilenets
48 | from smoothing import LabelSmoothing
49 |
50 | parser = argparse.ArgumentParser(description='PyTorch Image Classification Training')
51 | parser.add_argument('data', metavar='DIR',
52 | help='path to dataset')
53 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
54 | help='number of data loading workers (default: 4)')
55 | parser.add_argument('--epochs', default=90, type=int, metavar='N',
56 | help='number of total epochs to run')
57 | parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
58 | help='manual epoch number (useful on restarts)')
59 | parser.add_argument('-b', '--batch-size', default=256, type=int,
60 | metavar='N', help='mini-batch size per process (default: 256)')
61 | parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
62 | metavar='LR', help='Initial learning rate. \
63 | Will be scaled by /256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. \
64 | A warmup schedule will also be applied over the first 5 epochs.')
65 | parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
66 | help='momentum')
67 | parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
68 | metavar='W', help='weight decay (default: 1e-4)')
69 | parser.add_argument('--print-freq', '-p', default=10, type=int,
70 | metavar='N', help='print frequency (default: 10)')
71 | parser.add_argument('--resume', default='', type=str, metavar='PATH',
72 | help='path to latest checkpoint (default: none)')
73 | parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
74 | help='evaluate model on validation set')
75 | parser.add_argument('--pretrained', dest='pretrained', action='store_true',
76 | help='use pre-trained model')
77 |
78 | parser.add_argument('--fp16', action='store_true',
79 | help='Run model fp16 mode.')
80 | parser.add_argument('--static-loss-scale', type=float, default=1,
81 | help='Static loss scale, positive power of 2 values can improve fp16 convergence.')
82 | parser.add_argument('--dynamic-loss-scale', action='store_true',
83 | help='Use dynamic loss scaling. If supplied, this argument supersedes ' +
84 | '--static-loss-scale.')
85 | parser.add_argument('--prof', dest='prof', action='store_true',
86 | help='Only run 10 iterations for profiling.')
87 | parser.add_argument('--deterministic', action='store_true')
88 |
89 | parser.add_argument("--local_rank", default=0, type=int)
90 | parser.add_argument('--sync_bn', action='store_true',
91 | help='enabling apex sync BN.')
92 |
93 | parser.add_argument('--cfg', help='experiment configure file name', required=True, type=str)
94 | parser.add_argument('--save-dir', type=str, default='/tmp/models')
95 | parser.add_argument('--nesterov', type=str, default=None)
96 | parser.add_argument('--remove-norm-weight-decay', type=str, default=None)
97 |
98 | cudnn.benchmark = True
99 |
100 | def fast_collate(batch):
101 | imgs = [img[0] for img in batch]
102 | targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
103 | w = imgs[0].size[0]
104 | h = imgs[0].size[1]
105 | tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
106 | for i, img in enumerate(imgs):
107 | nump_array = np.asarray(img, dtype=np.uint8)
108 | if(nump_array.ndim < 3):
109 | nump_array = np.expand_dims(nump_array, axis=-1)
110 | nump_array = np.rollaxis(nump_array, 2)
111 |
112 | tensor[i] += torch.from_numpy(nump_array)
113 |
114 | return tensor, targets
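# NOTE: fast_collate returns uint8 CHW tensors and deliberately skips
# ToTensor/normalize on the CPU; in the apex example this script derives from,
# float conversion and mean/std normalization are deferred to the GPU
# (data-prefetcher pattern).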
115 |
116 | best_prec1 = 0
117 | best_prec1_val = 0
118 | prec5_val = 0
119 | best_prec5_val = 0
120 |
121 | args = parser.parse_args()
122 |
123 | if args.local_rank == 0:
124 | print("PyTorch VERSION: {}".format(torch.__version__)) # PyTorch version
125 | print("CUDA VERSION: {}".format(torch.version.cuda)) # Corresponding CUDA version
126 | print("CUDNN VERSION: {}".format(torch.backends.cudnn.version())) # Corresponding cuDNN version
127 | print("GPU TYPE: {}".format(torch.cuda.get_device_name(0))) # GPU type
128 |
129 | if args.deterministic:
130 | cudnn.benchmark = False
131 | cudnn.deterministic = True
132 | torch.manual_seed(args.local_rank)
133 | torch.set_printoptions(precision=10)
134 |
135 | def main():
136 | global best_prec1, args
137 |
138 | args.distributed = False
139 | if 'WORLD_SIZE' in os.environ:
140 | args.distributed = int(os.environ['WORLD_SIZE']) > 1
141 |
142 | args.gpu = 0
143 | args.world_size = 1
144 |
145 | if args.distributed:
146 | args.gpu = args.local_rank
147 | torch.cuda.set_device(args.gpu)
148 | torch.distributed.init_process_group(backend='nccl',
149 | init_method='env://')
150 | args.world_size = torch.distributed.get_world_size()
151 |
152 | if args.fp16:
153 | assert torch.backends.cudnn.enabled, "fp16 requires cudnn backend to be enabled."
154 | if args.static_loss_scale != 1.0:
155 | if not args.fp16:
156 | print("Warning: if --fp16 is not used, static_loss_scale will be ignored.")
157 |
158 | # create model
159 | if args.pretrained:
160 | cfg.merge_from_file(os.path.join(args.save_dir, 'config.yaml'))
161 | else:
162 | cfg.merge_from_file(args.cfg)
163 | args.arch = cfg.arch
164 | # update args
165 | args.batch_size = cfg.batch_size
166 | args.lr = cfg.lr
167 | args.momentum = cfg.momentum
168 | args.weight_decay = cfg.wd
169 | args.nesterov = cfg.nesterov
170 | args.epochs = cfg.num_epoch
171 |
172 | if args.local_rank == 0:
173 | print("=> creating {}".format(args.arch))
174 | if args.arch.startswith('aognet'):
175 | model = aognet()
176 | elif args.arch.startswith('resnet') or args.arch.startswith('resnext'):
177 | model = resnets.__dict__[args.arch]()
178 | elif args.arch.startswith('mobilenet'):
179 | model = mobilenets.__dict__[args.arch]()
180 | else:
181 |         raise NotImplementedError("Unknown network arch.")
182 |
183 | if args.pretrained:
184 | if args.local_rank == 0:
185 | print("=> loading pre-trained model '{}'".format(args.arch))
186 | checkpoint = torch.load(os.path.join(args.save_dir, 'model_best.pth.tar'), map_location='cpu')
187 |         st_dict = {k[15:]: v for k, v in checkpoint['state_dict'].items()}  # strip the 15-char 'module.network.' prefix added by DDP + FP16Model
188 | model.load_state_dict(st_dict)
189 |
190 | if args.local_rank == 0:
191 | print('=> Params (double-check): %.6fM' % (sum(p.numel() for p in model.parameters()) / 1e6))
192 |
193 | #sys.exit()
194 |
195 | if args.sync_bn:
196 | import apex
197 | if args.local_rank == 0:
198 | print("using apex synced BN")
199 | model = apex.parallel.convert_syncbn_model(model)
200 |
201 | model = model.cuda()
202 | if args.fp16:
203 | model = FP16Model(model)
204 | if args.distributed:
205 | # By default, apex.parallel.DistributedDataParallel overlaps communication with
206 | # computation in the backward pass.
207 | # model = DDP(model)
208 | # delay_allreduce delays all communication to the end of the backward pass.
209 | model = DDP(model, delay_allreduce=True)
210 |
211 | # Scale learning rate based on global batch size
212 | args.lr = args.lr*float(args.batch_size*args.world_size)/cfg.lr_scale_factor #TODO: control the maximum?
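    # e.g. cfg.lr = 0.1 with 128 images/GPU on 8 GPUs: 0.1 * (128*8) / 256 = 0.4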
213 |
214 | if args.remove_norm_weight_decay:
215 | if args.local_rank == 0:
216 | print("=> ! Weight decay NOT applied to FeatNorm parameters ")
217 | norm_params=set() #TODO: need to check this via experiments
218 | rest_params=set()
219 | for m in model.modules():
220 | if isinstance(m, (nn.BatchNorm2d, nn.GroupNorm, MixtureBatchNorm2d, MixtureGroupNorm)):
221 | for param in m.parameters(False):
222 | norm_params.add(param)
223 | else:
224 | for param in m.parameters(False):
225 | rest_params.add(param)
226 |
227 | optimizer = torch.optim.SGD([{'params': list(norm_params), 'weight_decay' : 0.0},
228 | {'params': list(rest_params)}],
229 | args.lr,
230 | momentum=args.momentum,
231 | weight_decay=args.weight_decay,
232 | nesterov=args.nesterov)
233 | else:
234 | if args.local_rank == 0:
235 | print("=> ! Weight decay applied to FeatNorm parameters ")
236 | optimizer = torch.optim.SGD(model.parameters(), args.lr,
237 | momentum=args.momentum,
238 | weight_decay=args.weight_decay,
239 | nesterov=args.nesterov)
240 |
241 | if args.fp16:
242 | optimizer = FP16_Optimizer(optimizer,
243 | static_loss_scale=args.static_loss_scale,
244 | dynamic_loss_scale=args.dynamic_loss_scale)
245 |
246 | # define loss function (criterion) and optimizer
247 | criterion_train = nn.CrossEntropyLoss().cuda() if cfg.dataaug.labelsmoothing_rate == 0.0 \
248 | else LabelSmoothing(cfg.dataaug.labelsmoothing_rate).cuda()
249 | criterion_val = nn.CrossEntropyLoss().cuda()
250 |
251 | # Optionally resume from a checkpoint
252 | if args.resume:
253 | # Use a local scope to avoid dangling references
254 |         def resume():
255 |             global best_prec1  # without this, the assignment below creates a dead local and the restored best is lost
256 |             if os.path.isfile(args.resume):
257 |                 if args.local_rank == 0: print("=> loading checkpoint '{}'".format(args.resume))
258 |                 checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda(args.gpu))
259 |                 args.start_epoch = checkpoint['epoch']
260 |                 best_prec1 = checkpoint['best_prec1']
261 | model.load_state_dict(checkpoint['state_dict'])
262 | optimizer.load_state_dict(checkpoint['optimizer'])
263 | if args.local_rank == 0:
264 | print("=> loaded checkpoint '{}' (epoch {})"
265 | .format(args.resume, checkpoint['epoch']))
266 | else:
267 | if args.local_rank == 0:
268 | print("=> no checkpoint found at '{}'".format(args.resume))
269 | resume()
270 |
271 | # Data loading code
272 | lr_milestones = None
273 | if cfg.dataset == "cifar10":
274 | train_transform = transforms.Compose([
275 | transforms.RandomCrop(32, padding=4),
276 | transforms.RandomHorizontalFlip()
277 | ])
278 | train_dataset = datasets.CIFAR10('./datasets', train=True, download=False, transform=train_transform)
279 | val_dataset = datasets.CIFAR10('./datasets', train=False, download=False)
280 | lr_milestones = cfg.lr_milestones
281 | elif cfg.dataset == "cifar100":
282 | train_transform = transforms.Compose([
283 | transforms.RandomCrop(32, padding=4),
284 | transforms.RandomHorizontalFlip()
285 | ])
286 | train_dataset = datasets.CIFAR100('./datasets', train=True, download=False, transform=train_transform)
287 | val_dataset = datasets.CIFAR100('./datasets', train=False, download=False)
288 | lr_milestones = cfg.lr_milestones
289 | elif cfg.dataset == "imagenet":
290 | traindir = os.path.join(args.data, 'train')
291 | valdir = os.path.join(args.data, 'val')
292 |
293 | crop_size = cfg.crop_size # 224
294 | val_size = cfg.crop_size + 32 # 256
295 |
296 | train_dataset = datasets.ImageFolder(
297 | traindir,
298 | transforms.Compose([
299 | transforms.RandomResizedCrop(crop_size, interpolation=cfg.crop_interpolation),
300 | transforms.RandomHorizontalFlip(),
301 |             # transforms.ToTensor() is too slow here; fast_collate keeps the
302 |             # images as uint8 and data_prefetcher normalizes them on the GPU
303 | ]))
304 | val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
305 | transforms.Resize(val_size, interpolation=cfg.crop_interpolation),
306 | transforms.CenterCrop(crop_size),
307 | ]))
308 |
309 | train_sampler = None
310 | val_sampler = None
311 | if args.distributed:
312 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
313 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
314 |
315 | train_loader = torch.utils.data.DataLoader(
316 | train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
317 | num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
318 |
319 | val_loader = torch.utils.data.DataLoader(
320 | val_dataset,
321 | batch_size=args.batch_size, shuffle=False,
322 | num_workers=args.workers, pin_memory=True,
323 | sampler=val_sampler,
324 | collate_fn=fast_collate)
325 |
326 | if args.evaluate:
327 | validate(val_loader, model, criterion_val)
328 | return
329 |
330 | scheduler = CosineAnnealingLR(optimizer.optimizer if args.fp16 else optimizer,
331 | args.epochs, len(train_loader),
332 | eta_min=cfg.cosine_lr_min, warmup=cfg.warmup_epochs) if cfg.use_cosine_lr else None
333 |
334 | for epoch in range(args.start_epoch, args.epochs):
335 | if args.distributed:
336 | train_sampler.set_epoch(epoch)
337 |
338 | # train for one epoch
339 | train(train_loader, model, criterion_train, optimizer, epoch, scheduler, lr_milestones, cfg.warmup_epochs,
340 | cfg.dataaug.mixup_rate, cfg.dataaug.labelsmoothing_rate)
341 | if args.prof:
342 | break
343 | # evaluate on validation set
344 | prec1 = validate(val_loader, model, criterion_val)
345 |
346 | # remember best prec@1 and save checkpoint
347 | if args.local_rank == 0:
348 | is_best = prec1 > best_prec1
349 | best_prec1 = max(prec1, best_prec1)
350 | save_checkpoint({
351 | 'epoch': epoch + 1,
352 | 'arch': args.arch,
353 | 'state_dict': model.state_dict(),
354 | 'best_prec1': best_prec1,
355 | 'optimizer' : optimizer.state_dict(),
356 | }, is_best, args.save_dir)
357 |
358 | class data_prefetcher():
359 | def __init__(self, loader):
360 | self.loader = iter(loader)
361 | self.stream = torch.cuda.Stream()
362 | if cfg.dataset == 'cifar10':
363 | self.mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1)
364 | self.std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1)
365 | elif cfg.dataset == 'cifar100':
366 | self.mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1)
367 | self.std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 255]).cuda().view(1,3,1,1)
368 | elif cfg.dataset == 'imagenet':
369 | self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
370 | self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
371 | else:
372 | raise NotImplementedError
373 | if args.fp16:
374 | self.mean = self.mean.half()
375 | self.std = self.std.half()
376 | self.preload()
377 |
378 | def preload(self):
379 | try:
380 | self.next_input, self.next_target = next(self.loader)
381 | except StopIteration:
382 | self.next_input = None
383 | self.next_target = None
384 | return
385 | with torch.cuda.stream(self.stream):
386 | self.next_input = self.next_input.cuda(non_blocking=True)
387 | self.next_target = self.next_target.cuda(non_blocking=True)
388 | if args.fp16:
389 | self.next_input = self.next_input.half()
390 | else:
391 | self.next_input = self.next_input.float()
392 | self.next_input = self.next_input.sub_(self.mean).div_(self.std)
393 |
394 | def next(self):
395 | torch.cuda.current_stream().wait_stream(self.stream)
396 | input = self.next_input
397 | target = self.next_target
398 | self.preload()
399 | return input, target
400 |
401 | # from NVIDIA DL Examples: generator-based alternative to data_prefetcher (unused by train/validate above)
402 | def prefetched_loader(loader):
403 |     if cfg.dataset == 'cifar10':
404 |         mean = torch.tensor([0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255]).cuda().view(1,3,1,1)
405 |         std = torch.tensor([0.24703233 * 255, 0.24348505 * 255, 0.26158768 * 255]).cuda().view(1,3,1,1)
406 |     elif cfg.dataset == 'cifar100':
407 |         mean = torch.tensor([0.5071 * 255, 0.4867 * 255, 0.4408 * 255]).cuda().view(1,3,1,1)
408 |         std = torch.tensor([0.2675 * 255, 0.2565 * 255, 0.2761 * 255]).cuda().view(1,3,1,1)
409 |     elif cfg.dataset == 'imagenet':
410 |         mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
411 |         std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
412 | else:
413 | raise NotImplementedError
414 |
415 | stream = torch.cuda.Stream()
416 | first = True
417 |
418 | for next_input, next_target in loader:
419 | with torch.cuda.stream(stream):
420 | next_input = next_input.cuda(non_blocking=True)
421 | next_target = next_target.cuda(non_blocking=True)
422 | next_input = next_input.float()
423 | next_input = next_input.sub_(mean).div_(std)
424 |
425 | if not first:
426 | yield input, target
427 | else:
428 | first = False
429 |
430 | torch.cuda.current_stream().wait_stream(stream)
431 | input = next_input
432 | target = next_target
433 |
434 | yield input, target
435 |
436 |
437 | def train(train_loader, model, criterion, optimizer, epoch, scheduler=None, lr_milestones=None, warmup_epoch=0,
438 | mixup_rate=0.0, labelsmoothing_rate=0.0):
439 | batch_time = AverageMeter()
440 | data_time = AverageMeter()
441 | losses = AverageMeter()
442 | top1 = AverageMeter()
443 | top5 = AverageMeter()
444 |
445 | # switch to train mode
446 | model.train()
447 | end = time.time()
448 |
449 | prefetcher = data_prefetcher(train_loader)
450 | input, target = prefetcher.next()
451 | i = -1
452 |     beta_distribution = torch.distributions.beta.Beta(mixup_rate, mixup_rate) if mixup_rate > 0.0 else None  # Beta needs positive concentration; only sampled when mixup is on
453 | while input is not None:
454 | i += 1
455 |
456 | if scheduler is None:
457 | lr = adjust_learning_rate(optimizer, epoch, i, len(train_loader), lr_milestones, warmup_epoch)
458 | else:
459 | lr = scheduler.update(epoch, i)
460 |
461 | if args.prof:
462 | if i > 10:
463 | break
464 | # measure data loading time
465 | data_time.update(time.time() - end)
466 |
467 | # Mixup input
468 | if mixup_rate > 0.0:
469 | lambda_ = beta_distribution.sample([]).item()
470 | index = torch.randperm(input.size(0)).cuda()
471 | input = lambda_ * input + (1 - lambda_) * input[index, :]
472 |
473 | # compute output
474 | if args.prof: torch.cuda.nvtx.range_push("forward")
475 | output = model(input)
476 | if args.prof: torch.cuda.nvtx.range_pop()
477 |
478 | # Mixup loss
479 | if mixup_rate > 0.0:
480 |             # combine the losses for the two mixed targets
481 | loss = (lambda_ * criterion(output, target)
482 | + (1 - lambda_) * criterion(output, target[index]))
483 |
484 | # Mixup target
485 | if labelsmoothing_rate > 0.0:
486 | N = output.size(0)
487 | C = output.size(1)
488 | off_prob = labelsmoothing_rate / C
489 |                 target_1 = torch.full(size=(N, C), fill_value=off_prob).cuda()
490 |                 target_2 = torch.full(size=(N, C), fill_value=off_prob).cuda()
491 | target_1.scatter_(dim=1, index=torch.unsqueeze(target, dim=1), value=1.0-labelsmoothing_rate+off_prob)
492 | target_2.scatter_(dim=1, index=torch.unsqueeze(target[index], dim=1), value=1.0-labelsmoothing_rate+off_prob)
493 | target = lambda_ * target_1 + (1 - lambda_) * target_2
494 | else:
495 | target = lambda_ * target + (1 - lambda_) * target[index]
496 | else:
497 | loss = criterion(output, target)
498 |
499 | # compute gradient and do SGD step
500 | optimizer.zero_grad()
501 |
502 | if args.prof: torch.cuda.nvtx.range_push("backward")
503 | if args.fp16:
504 | optimizer.backward(loss)
505 | else:
506 | loss.backward()
507 | if args.prof: torch.cuda.nvtx.range_pop()
508 |
509 | # debug
510 | # if args.local_rank == 0:
511 | # for name_, param in model.named_parameters():
512 | # print(name_, param.data.double().sum().item(), param.grad.data.double().sum().item())
513 |
514 | if args.prof: torch.cuda.nvtx.range_push("step")
515 | optimizer.step()
516 | if args.prof: torch.cuda.nvtx.range_pop()
517 |
518 | # Measure accuracy
519 | if mixup_rate > 0.0:
520 | prec1 = rmse(output.data, target)
521 | prec5 = prec1
522 | else:
523 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
524 |
525 | # Average loss and accuracy across processes for logging
526 | if args.distributed:
527 | reduced_loss = reduce_tensor(loss.data)
528 | prec1 = reduce_tensor(prec1)
529 | prec5 = reduce_tensor(prec5)
530 | else:
531 | reduced_loss = loss.data
532 |
533 | # to_python_float incurs a host<->device sync
534 | losses.update(to_python_float(reduced_loss), input.size(0))
535 | top1.update(to_python_float(prec1), input.size(0))
536 | top5.update(to_python_float(prec5), input.size(0))
537 |
538 |         # torch.cuda.synchronize()  # omitted: not in the torchvision example, and it caused NaN-loss problems in deep fp16 models
539 |
540 | batch_time.update(time.time() - end)
541 | end = time.time()
542 | input, target = prefetcher.next()
543 |
544 |         if i % args.print_freq == 0 and args.local_rank == 0:
545 | # Every print_freq iterations, check the loss, accuracy, and speed.
546 | print('Epoch: [{0}][{1}/{2}]\t'
547 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
548 | 'Speed {3:.3f} ({4:.3f})\t'
549 | 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
550 | 'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
551 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
552 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'
553 | 'lr {lr:.6f}\t'.format(
554 | epoch, i, len(train_loader),
555 | args.world_size*args.batch_size/batch_time.val,
556 | args.world_size*args.batch_size/batch_time.avg,
557 | batch_time=batch_time,
558 | data_time=data_time, loss=losses, top1=top1, top5=top5, lr=lr[0]))
559 |
560 | def validate(val_loader, model, criterion):
561 | global best_prec1_val, prec5_val, best_prec5_val
562 | batch_time = AverageMeter()
563 | losses = AverageMeter()
564 | top1 = AverageMeter()
565 | top5 = AverageMeter()
566 |
567 | # switch to evaluate mode
568 | model.eval()
569 |
570 | end = time.time()
571 |
572 | prefetcher = data_prefetcher(val_loader)
573 | input, target = prefetcher.next()
574 | i = -1
575 | while input is not None:
576 | i += 1
577 |
578 | # compute output
579 | with torch.no_grad():
580 | output = model(input)
581 | loss = criterion(output, target)
582 |
583 | # measure accuracy and record loss
584 | prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
585 |
586 | if args.distributed:
587 | reduced_loss = reduce_tensor(loss.data)
588 | prec1 = reduce_tensor(prec1)
589 | prec5 = reduce_tensor(prec5)
590 | else:
591 | reduced_loss = loss.data
592 |
593 | losses.update(to_python_float(reduced_loss), input.size(0))
594 | top1.update(to_python_float(prec1), input.size(0))
595 | top5.update(to_python_float(prec5), input.size(0))
596 |
597 | # measure elapsed time
598 | batch_time.update(time.time() - end)
599 | end = time.time()
600 |
601 | # TODO: Change timings to mirror train().
602 | if args.local_rank == 0 and i > 0 and i % args.print_freq == 0:
603 | print('Test: [{0}/{1}]\t'
604 | 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
605 | 'Speed {2:.3f} ({3:.3f})\t'
606 | 'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
607 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
608 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
609 | i, len(val_loader),
610 | args.world_size * args.batch_size / batch_time.val,
611 | args.world_size * args.batch_size / batch_time.avg,
612 | batch_time=batch_time, loss=losses,
613 | top1=top1, top5=top5))
614 |
615 | input, target = prefetcher.next()
616 |
617 | if args.local_rank == 0:
618 | if top1.avg >= best_prec1_val:
619 | best_prec1_val = top1.avg
620 | prec5_val = top5.avg
621 | best_prec5_val = max(best_prec5_val, top5.avg)
622 |         print('Test: Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}\t Best_Prec@1 {best:.3f}\t Prec@5_at_best_Prec@1 {prec5_val:.3f}\t Best_Prec@5 {bestprec5_val:.3f}'
623 |               .format(top1=top1, top5=top5, best=best_prec1_val, prec5_val=prec5_val, bestprec5_val=best_prec5_val))
624 |
625 | return top1.avg
626 |
627 |
628 | def save_checkpoint(state, is_best, save_dir='./'):
629 | filename = os.path.join(save_dir, 'checkpoint.pth.tar')
630 | best_file = os.path.join(save_dir, 'model_best.pth.tar')
631 | torch.save(state, filename)
632 | if is_best:
633 | shutil.copyfile(filename, best_file)
634 |
635 | class AverageMeter(object):
636 | """Computes and stores the average and current value"""
637 | def __init__(self):
638 | self.reset()
639 |
640 | def reset(self):
641 | self.val = 0
642 | self.avg = 0
643 | self.sum = 0
644 | self.count = 0
645 |
646 | def update(self, val, n=1):
647 | self.val = val
648 | self.sum += val * n
649 | self.count += n
650 | self.avg = self.sum / self.count
651 |
652 |
653 | def adjust_learning_rate(optimizer, epoch, step, len_epoch, lr_milestones=None, warmup_epoch=0):
654 | """LR schedule that should yield 76% converged accuracy with batch size 256"""
655 | # if not isinstance(optimizer, torch.optim.Optimizer):
656 | # raise TypeError('{} is not an Optimizer'.format(
657 | # type(optimizer).__name__))
658 | if lr_milestones is None:
659 | factor = epoch // 30
660 |
661 | if epoch >= 80:
662 | factor = factor + 1
663 |
664 | lr = args.lr*(0.1**factor)
665 |
666 |         # Warmup: ramp lr linearly over the first 5 epochs
667 | if epoch < 5:
668 | lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
669 |
670 | else:
671 | factor = 0
672 | for m in lr_milestones:
673 | if epoch >= m:
674 | factor += 1
675 |
676 | lr = args.lr*(0.1**factor)
677 |
678 |         # Warmup: ramp lr linearly over the first warmup_epoch epochs
679 | if epoch < warmup_epoch:
680 | lr = lr*float(1 + step + epoch*len_epoch)/(warmup_epoch*len_epoch)
681 |
682 |
683 | # if(args.local_rank == 0):
684 | # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
685 |
686 | for param_group in optimizer.param_groups:
687 | param_group['lr'] = lr
688 |
689 | return [lr]
690 |
691 |
692 | class CosineAnnealingLR(object):
693 | def __init__(self, optimizer, T_max, N_batch, eta_min=0, last_epoch=-1, warmup=0):
694 | if not isinstance(optimizer, torch.optim.Optimizer):
695 | raise TypeError('{} is not an Optimizer'.format(
696 | type(optimizer).__name__))
697 | self.optimizer = optimizer
698 | self.T_max = T_max
699 | self.N_batch = N_batch
700 | self.eta_min = eta_min
701 | self.warmup = warmup
702 |
703 | if last_epoch == -1:
704 | for group in optimizer.param_groups:
705 | group.setdefault('initial_lr', group['lr'])
706 | else:
707 | for i, group in enumerate(optimizer.param_groups):
708 | if 'initial_lr' not in group:
709 | raise KeyError("param 'initial_lr' is not specified "
710 | "in param_groups[{}] when resuming an optimizer".format(i))
711 | self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
712 | self.update(last_epoch+1)
713 | self.last_epoch = last_epoch
714 | self.iter = 0
715 |
716 | def state_dict(self):
717 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}
718 |
719 | def load_state_dict(self, state_dict):
720 | self.__dict__.update(state_dict)
721 |
722 | def get_lr(self):
723 | if self.last_epoch < self.warmup:
724 | lrs = [base_lr * (self.last_epoch + self.iter / self.N_batch) / self.warmup for base_lr in self.base_lrs]
725 | else:
726 | lrs = [self.eta_min + (base_lr - self.eta_min) *
727 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup + self.iter / self.N_batch) / (self.T_max - self.warmup))) / 2
728 | for base_lr in self.base_lrs]
729 | return lrs
730 |
731 | def update(self, epoch, batch=0):
732 | self.last_epoch = epoch
733 | self.iter = batch + 1
734 | lrs = self.get_lr()
735 | for param_group, lr in zip(self.optimizer.param_groups, lrs):
736 | param_group['lr'] = lr
737 |
738 | return lrs
739 |
740 |
741 | def accuracy(output, target, topk=(1,)):
742 | """Computes the precision@k for the specified values of k"""
743 | maxk = max(topk)
744 | batch_size = target.size(0)
745 |
746 | _, pred = output.topk(maxk, 1, True, True)
747 | pred = pred.t()
748 | correct = pred.eq(target.view(1, -1).expand_as(pred))
749 |
750 | res = []
751 | for k in topk:
752 |         correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)  # reshape(): the slice may be non-contiguous
753 | res.append(correct_k.mul_(100.0 / batch_size))
754 | return res
755 |
756 | def rmse(yhat, y):  # rough accuracy proxy used when mixup produces soft targets
757 | if args.fp16:
758 | res = torch.sqrt(torch.mean((yhat.float()-y.float())**2))
759 | else:
760 | res = torch.sqrt(torch.mean((yhat-y)**2))
761 | return res
762 |
763 | def reduce_tensor(tensor):
764 | rt = tensor.clone()
765 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
766 | rt /= args.world_size
767 | return rt
768 |
769 | if __name__ == '__main__':
770 | # to suppress annoying warnings
771 | import warnings
772 | warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)
773 |
774 | main()
775 |
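776 | # Hedged sketch (illustrative addition, not exercised by main() above):
777 | # after the linear warmup ramp, CosineAnnealingLR.get_lr implements
778 | #     lr = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T)) / 2
779 | # with t = last_epoch - warmup + iter / N_batch and T = T_max - warmup.
780 | # A minimal standalone check, using a dummy one-parameter SGD optimizer
781 | # (all names below are defined in this file or in torch):
782 | #
783 | #     opt = torch.optim.SGD([nn.Parameter(torch.zeros(1))], lr=0.1)
784 | #     sched = CosineAnnealingLR(opt, T_max=120, N_batch=100, eta_min=0.0, warmup=5)
785 | #     for e in range(120):
786 | #         for it in range(100):
787 | #             sched.update(e, it)  # writes the new lr into opt.param_groups
788 | #     # lr ramps linearly for 5 epochs, then decays cosinely toward eta_min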
--------------------------------------------------------------------------------
/tools/smoothing.py:
--------------------------------------------------------------------------------
1 | # From: https://github.com/NVIDIA/DeepLearningExamples/tree/master/PyTorch/Classification/RN50v1.5
2 | # commit a1aff31
3 | # Date: 05/01/2019
4 | # Note: check for updates in NVIDIA DeepLearningExamples regularly
5 | import torch
6 | import torch.nn as nn
7 |
8 | class LabelSmoothing(nn.Module):
9 | """
10 | NLL loss with label smoothing.
11 | """
12 | def __init__(self, smoothing=0.0):
13 | """
14 | Constructor for the LabelSmoothing module.
15 |
16 | :param smoothing: label smoothing factor
17 | """
18 | super(LabelSmoothing, self).__init__()
19 | self.confidence = 1.0 - smoothing
20 | self.smoothing = smoothing
21 |
22 | def forward(self, x, target):
23 | logprobs = torch.nn.functional.log_softmax(x, dim=-1)
24 |
25 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
26 | nll_loss = nll_loss.squeeze(1)
27 | smooth_loss = -logprobs.mean(dim=-1)
28 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss
29 | return loss.mean()
30 |
31 |
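32 | # Hedged usage sketch (illustrative addition, not in the NVIDIA source):
33 | # with smoothing = 0.1 and C classes, the loss equals cross-entropy against
34 | # a target that puts 1 - 0.1 + 0.1/C on the true class and 0.1/C elsewhere.
35 | #
36 | #     criterion = LabelSmoothing(smoothing=0.1)
37 | #     logits = torch.randn(4, 10)          # batch of 4, C = 10 classes
38 | #     labels = torch.randint(0, 10, (4,))  # integer class indices
39 | #     loss = criterion(logits, labels)     # scalar; call loss.backward() to train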
--------------------------------------------------------------------------------