├── LICENSE
├── NOTICE
├── README.md
├── evolution_search
│ ├── cand_evaluator.py
│ ├── config.py
│ ├── datasets
│ │ ├── DownsampledImageNet.py
│ │ ├── SearchDatasetWrap.py
│ │ ├── __init__.py
│ │ ├── config_utils
│ │ │ └── __init__.py
│ │ ├── configs
│ │ │ ├── cifar-split.txt
│ │ │ ├── cifar100-split.txt
│ │ │ ├── cifar100-test-split.txt
│ │ │ └── imagenet-16-120-test-split.txt
│ │ ├── get_dataset_with_transform.py
│ │ └── test_utils.py
│ ├── genotypes.py
│ ├── metrics
│ │ ├── tester_acc.py
│ │ └── tester_wlm.py
│ ├── operations.py
│ ├── search.py
│ ├── super_model.py
│ └── utils.py
├── repo_figures
│ ├── ABS_FBS_architecture_normal.png
│ ├── ABS_FBS_architecture_reduce.png
│ ├── ABS_architecture_normal.png
│ ├── ABS_architecture_reduce.png
│ ├── FBS_architecture_normal.png
│ ├── FBS_architecture_reduce.png
│ └── motivation.png
├── requirements.txt
├── retrain_architecture
│ ├── config.py
│ ├── genotypes.py
│ ├── model.py
│ ├── operations.py
│ ├── retrain.py
│ ├── thop
│ │ ├── __init__.py
│ │ ├── count_hooks.py
│ │ ├── profile.py
│ │ └── utils.py
│ ├── utils.py
│ └── visualize.py
└── train_supernet
    ├── config.py
    ├── datasets
    │ ├── DownsampledImageNet.py
    │ ├── SearchDatasetWrap.py
    │ ├── __init__.py
    │ ├── config_utils
    │ │ └── __init__.py
    │ ├── configs
    │ │ ├── cifar-split.txt
    │ │ ├── cifar100-split.txt
    │ │ ├── cifar100-test-split.txt
    │ │ └── imagenet-16-120-test-split.txt
    │ ├── get_dataset_with_transform.py
    │ └── test_utils.py
    ├── genotypes.py
    ├── operations.py
    ├── super_model.py
    ├── train.py
    └── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GeNAS
2 | Copyright (c) 2023-present NAVER Cloud Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | GeNAS
2 | Copyright (c) 2023-present NAVER Cloud Corp.
3 |
4 | Permission is hereby granted, free of charge, to any person obtaining a copy
5 | of this software and associated documentation files (the "Software"), to deal
6 | in the Software without restriction, including without limitation the rights
7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 | copies of the Software, and to permit persons to whom the Software is
9 | furnished to do so, subject to the following conditions:
10 |
11 | The above copyright notice and this permission notice shall be included in
12 | all copies or substantial portions of the Software.
13 |
14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 | THE SOFTWARE.
21 |
22 | --------------------------------------------------------------------------------------
23 |
24 | This project contains subcomponents with separate copyright notices and license terms.
25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
26 |
27 | =====
28 |
29 | D-X-Y/AutoDL-Projects
30 | https://github.com/D-X-Y/AutoDL-Projects
31 |
32 |
33 | MIT License
34 |
35 | Copyright (c) since 2019.01.01, author: Xuanyi Dong (GitHub: https://github.com/D-X-Y)
36 |
37 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
38 |
39 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
40 |
41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
42 |
43 | =====
44 |
45 | megvii-model/RLNAS
46 | https://github.com/megvii-model/RLNAS
47 |
48 |
49 | MIT License
50 |
51 | Copyright (c) 2021 megvii-model
52 |
53 | Permission is hereby granted, free of charge, to any person obtaining a copy
54 | of this software and associated documentation files (the "Software"), to deal
55 | in the Software without restriction, including without limitation the rights
56 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
57 | copies of the Software, and to permit persons to whom the Software is
58 | furnished to do so, subject to the following conditions:
59 |
60 | The above copyright notice and this permission notice shall be included in all
61 | copies or substantial portions of the Software.
62 |
63 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
64 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
65 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
66 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
67 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
68 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
69 | SOFTWARE.
70 |
71 | =====
72 |
73 | ECP-CANDLE/Benchmarks
74 | https://github.com/ECP-CANDLE/Benchmarks
75 |
76 |
77 | MIT License
78 |
79 | Copyright (c) 2016 - 2017 ECP-CANDLE
80 |
81 | Permission is hereby granted, free of charge, to any person obtaining a copy
82 | of this software and associated documentation files (the "Software"), to deal
83 | in the Software without restriction, including without limitation the rights
84 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
85 | copies of the Software, and to permit persons to whom the Software is
86 | furnished to do so, subject to the following conditions:
87 |
88 | The above copyright notice and this permission notice shall be included in all
89 | copies or substantial portions of the Software.
90 |
91 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
92 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
93 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
94 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
95 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
96 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
97 | SOFTWARE.
98 |
99 | =====
100 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GeNAS (IJCAI 2023)
2 |
3 | **GeNAS: Neural Architecture Search with Better Generalization**
4 |
5 | [Joonhyun Jeong](https://bestdeveloper691.github.io/)<sup>1,2</sup>, [Joonsang Yu](https://scholar.google.co.kr/citations?user=IC6M7_IAAAAJ&hl=ko)<sup>1,3</sup>, [Geondo Park](https://scholar.google.com/citations?user=Z8SGJ60AAAAJ&hl=ko)<sup>2</sup>, [Dongyoon Han](https://scholar.google.com/citations?user=jcP7m1QAAAAJ&hl=en)<sup>3</sup>, [YoungJoon Yoo](https://yjyoo3312.github.io/)<sup>1</sup>
6 |
7 | <sup>1</sup> NAVER Cloud, ImageVision
8 | <sup>2</sup> KAIST
9 | <sup>3</sup> NAVER AI Lab
10 |
11 | [IJCAI 2023](https://ijcai-23.org/)
12 | [arXiv:2305.08611](https://arxiv.org/abs/2305.08611)
13 |
14 | ## Introduction
15 |
16 | Neural Architecture Search (NAS) aims to automatically excavate the optimal network architecture with superior test performance. Recent NAS approaches rely on validation loss or accuracy to find the superior network for the target data. In this paper, we investigate a new neural architecture search measure for
17 | excavating architectures with better generalization. We demonstrate that the flatness of the loss surface can be a promising proxy for predicting the generalization capability of neural network architectures. We evaluate our proposed method on various search spaces, showing similar or even better performance compared to the state-of-the-art NAS methods. Notably, the resultant architecture found by the flatness measure generalizes robustly to various shifts in data distribution (e.g., ImageNet-V2/-A/-O), as well as to various tasks such as object detection and semantic segmentation.
18 |
19 | ![motivation](./repo_figures/motivation.png)
20 |
21 | ## Updates
22 | **_2023-08-09_** We release the official implementation of GeNAS.
23 |
24 | ## Requirements
25 |
26 | * PyTorch 1.7.1
27 |
28 | Please see [requirements](./requirements.txt) for detailed specs.
29 |
30 | ## Quick Start
31 |
32 | 1. Train the SuperNet, following [SPOS](https://github.com/megvii-model/SinglePathOneShot).
33 |
34 | ```bash
35 | cd train_supernet
36 | python3 train.py \
37 | --seed 1 \
38 | --data [CIFAR_DATASET_DIRECTORY] \
39 | --epochs 250 \
40 | --save [OUTPUT_DIRECTORY] \
41 | --random_label 0 \
42 | --split_data 1
43 | ```
44 |
45 | 2. Evolutionary Searching
46 |
47 | - You can skip step 1 and use [the pretrained SuperNet checkpoints](https://drive.google.com/drive/folders/19TAHE5C66n1PCLaAjcemGfmkIJQNkVKj?usp=sharing).
48 |
49 | ### Searching with Flatness
50 |
51 | ```bash
52 | cd evolution_search
53 | python3 search.py \
54 | --split_data 1 \
55 | --seed 3 \
56 | --init_model_path [SUPERNET_WEIGHT@INITIAL_EPOCH] \
57 | --model_path [SUPERNET_WEIGHT@FINAL_EPOCH] \
58 | --data [CIFAR_DATASET_DIRECTORY] \
59 | --metric wlm \
60 | --stds 0.001,0.003,0.006 \
61 | --max_train_img_size 850 \
62 | --max_val_img_size 25000 \
63 | --wlm_weight 0 \
64 | --acc_weight 0
65 | ```
66 |
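The `--stds` list drives the flatness (WLM) measure: the supernet weights are perturbed with Gaussian noise of increasing standard deviation, and the degradation of the loss is tracked at each level (see `evolution_search/metrics/tester_wlm.py`). A minimal conceptual sketch, not the repo's implementation, with `loss_fn`, `data`, and `target` as placeholder inputs:

```python
import torch

def flatness_proxy(model, loss_fn, data, target, stds=(0.001, 0.003, 0.006)):
    losses, prev = [], 0.0
    for std in stds:
        # cumulative perturbation: only the residual std is added at each step,
        # mirroring tester_wlm.py's memory-efficient scheme
        for p in model.parameters():
            p.data.add_(torch.normal(0.0, std - prev, size=p.size()).to(p.device))
        prev = std
        with torch.no_grad():
            losses.append(loss_fn(model(data), target).item())
    return losses  # flatter architectures degrade more gracefully
```
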
67 | ### Searching with Angle + Flatness
68 |
69 | ```bash
70 | python3 search.py \
71 | --split_data 1 \
72 | --seed 3 \
73 | --init_model_path [SUPERNET_WEIGHT@INITIAL_EPOCH] \
74 | --model_path [SUPERNET_WEIGHT@FINAL_EPOCH] \
75 | --data [CIFAR_DATASET_DIRECTORY] \
76 | --metric angle+wlm \
77 | --stds 0.001,0.003,0.006 \
78 | --max_train_img_size 850 \
79 | --max_val_img_size 25000 \
80 | --wlm_weight 16 \
81 | --acc_weight 0
82 | ```
83 |
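For reference, the angle component (following RLNAS) compares a candidate's weights at initialization (`--init_model_path`) against those after supernet training (`--model_path`); a larger angle between the two flattened weight vectors indicates a better-trained candidate. A conceptual sketch of that distance:

```python
import torch

def weight_angle(init_params, trained_params):
    # flatten and concatenate the candidate's weights from both snapshots
    v0 = torch.cat([p.reshape(-1) for p in init_params])
    v1 = torch.cat([p.reshape(-1) for p in trained_params])
    cos = torch.dot(v0, v1) / (v0.norm() * v1.norm())
    return torch.acos(cos.clamp(-1.0, 1.0))  # angle in radians
```
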
84 | 3. Re-training on ImageNet
85 |
86 | - We used 8 V100 GPUs for re-training on ImageNet. Run the commands below from the `retrain_architecture` directory.
87 |
88 | ### Searched on CIFAR-100 with Flatness
89 | ```bash
90 | python3 retrain.py \
91 | --data_root [IMAGENET_DATA_DIRECTORY] \
92 | --auxiliary \
93 | --arch=GENAS_FLATNESS_CIFAR100 \
94 | --init_channels 46
95 | ```
96 |
97 | ### Searched on CIFAR-100 with Angle + Flatness
98 | ```bash
99 | python3 retrain.py \
100 | --data_root [IMAGENET_DATA_DIRECTORY] \
101 | --auxiliary \
102 | --arch=GENAS_ANGLE_FLATNESS_CIFAR100 \
103 | --init_channels 48
104 | ```
105 |
106 | ### Searched on CIFAR-10 with Flatness
107 | ```bash
108 | python3 retrain.py \
109 | --data_root [IMAGENET_DATA_DIRECTORY] \
110 | --auxiliary \
111 | --arch=GENAS_FLATNESS_CIFAR10 \
112 | --init_channels 52
113 | ```
114 |
115 | ### Searched on CIFAR-10 with Angle + Flatness
116 | ```bash
117 | python3 retrain.py \
118 | --data_root [IMAGENET_DATA_DIRECTORY] \
119 | --auxiliary \
120 | --arch=GENAS_ANGLE_FLATNESS_CIFAR10 \
121 | --init_channels 44
122 | ```
123 |
124 | ## Model Zoo
125 |
126 | | Search Dataset | Search Metric | Params (M) | FLOPs (G) | ImageNet Top-1 Acc (%) | Weight |
127 | | :--------: | :----------------: | :-----------------: | :--------------: | :------: | :------: |
| CIFAR-10 | Angle | 5.3 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1J_xyxU3ZbuDDr1ASEjdUIkjnrf5rNqB_/view?usp=sharing) |
| CIFAR-10 | Accuracy | 5.4 | 0.6 | 75.3 | [ckpt](https://drive.google.com/file/d/1jo76ZhbqJt11qls3q2rMVsUcfzkQWp1Q/view?usp=sharing) |
| CIFAR-10 | Flatness | 5.6 | 0.6 | 76.0 | [ckpt](https://drive.google.com/file/d/1VamhvAUSi2XZVE0Vn4Lxxp1S_dqODTil/view?usp=sharing) |
| CIFAR-10 | Angle + Flatness | 5.3 | 0.6 | 76.1 | [ckpt](https://drive.google.com/file/d/1p2PSkt5ZyFY2NLGgU45Ilr5NIaXNizW9/view?usp=sharing) |
| CIFAR-10 | Accuracy + Flatness | 5.6 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1QBEyY-vFYpGOlwRSsTFxMM8GBtY3F8k7/view?usp=sharing) |
| | | | | | |
| CIFAR-100 | Angle | 5.4 | 0.6 | 75.0 | [ckpt](https://drive.google.com/file/d/1CmpkPsWNWVdbDbcmyuVfp38A7lB2MWoC/view?usp=sharing) |
| CIFAR-100 | Accuracy | 5.4 | 0.6 | 75.4 | [ckpt](https://drive.google.com/file/d/1TWzs-upwnAgOvF0HjKDSwC3TeDRstl4C/view?usp=sharing) |
| CIFAR-100 | Flatness | 5.2 | 0.6 | 76.1 | [ckpt](https://drive.google.com/file/d/1YLcZNpTytP9XTYDYoQ_nv8gQHnRAFe67/view?usp=sharing) |
| CIFAR-100 | Angle + Flatness | 5.4 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1reRbr4cFeoL8fwOTPQjQTg7w_QZAAIm4/view?usp=sharing) |
| CIFAR-100 | Accuracy + Flatness | 5.4 | 0.6 | 75.9 | [ckpt](https://drive.google.com/file/d/1-GVpP7yUWc7W6Qf8dM_FN3fTo0acO0AI/view?usp=sharing) |
139 |
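A minimal sketch (not an official script) for inspecting a downloaded checkpoint; the file name and the key layout inside the `.pth` file are assumptions:

```python
import torch

state = torch.load('genas_cifar100_flatness.pth', map_location='cpu')
sd = state.get('state_dict', state)  # some checkpoints nest the weights
num_params = sum(v.numel() for v in sd.values() if torch.is_tensor(v))
print('{} tensors, {:.1f}M parameters'.format(len(sd), num_params / 1e6))
```
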
140 | ### Architecture Visualization
141 |
142 | #### angle-based searching
143 |
144 | - normal cell
145 |
146 | ![ABS normal cell](./repo_figures/ABS_architecture_normal.png)
147 |
148 | - reduce cell
149 |
150 | ![ABS reduce cell](./repo_figures/ABS_architecture_reduce.png)
151 |
152 |
153 | #### angle+flatness based searching
154 |
155 | - normal cell
156 |
157 | ![ABS+FBS normal cell](./repo_figures/ABS_FBS_architecture_normal.png)
158 |
159 | - reduce cell
160 |
161 | ![ABS+FBS reduce cell](./repo_figures/ABS_FBS_architecture_reduce.png)
162 |
163 | #### flatness-based searching
164 |
165 | - normal cell
166 |
167 | ![FBS normal cell](./repo_figures/FBS_architecture_normal.png)
168 |
169 | - reduce cell
170 |
171 | ![FBS reduce cell](./repo_figures/FBS_architecture_reduce.png)
172 |
173 | ## Citation
174 | If you find that this project helps your research, please consider citing it as below:
175 |
176 | ```
177 | @article{jeong2023genas,
178 | title={GeNAS: Neural Architecture Search with Better Generalization},
179 | author={Jeong, Joonhyun and Yu, Joonsang and Park, Geondo and Han, Dongyoon and Yoo, Youngjoon},
180 | journal={arXiv preprint arXiv:2305.08611},
181 | year={2023}
182 | }
183 | ```
184 |
185 | ## License
186 | ```
187 | GeNAS
188 | Copyright (c) 2023-present NAVER Cloud Corp.
189 |
190 | Permission is hereby granted, free of charge, to any person obtaining a copy
191 | of this software and associated documentation files (the "Software"), to deal
192 | in the Software without restriction, including without limitation the rights
193 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
194 | copies of the Software, and to permit persons to whom the Software is
195 | furnished to do so, subject to the following conditions:
196 |
197 | The above copyright notice and this permission notice shall be included in
198 | all copies or substantial portions of the Software.
199 |
200 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
201 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
202 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
203 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
204 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
205 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
206 | THE SOFTWARE.
207 | ```
208 |
--------------------------------------------------------------------------------
/evolution_search/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/config.py
3 | """
4 |
5 | import os
6 | class config:
7 |     host = '127.0.0.1'
8 | 
9 |     username = 'test'
10 |     port = 5672
11 | 
12 |     exp_name = os.path.dirname(os.path.abspath(__file__))
13 |     exp_name = '-'.join(i for i in exp_name.split(os.path.sep) if i)
14 | 
15 |     test_send_pipe = exp_name + '-test-send_pipe'
16 |     test_recv_pipe = exp_name + '-test-recv_pipe'
17 | 
18 |     net_cache = 'model_and_data/checkpoint_epoch_250.pth.tar'
19 |     initial_net_cache = 'model_and_data/checkpoint_epoch_0.pth.tar'
20 | 
21 |     layers = 8
22 |     edges = 14
23 |     model_input_size_imagenet = (1, 3, 224, 224)
24 | 
25 |     # Candidate operators
26 |     blocks_keys = [
27 |         'none',
28 |         'max_pool_3x3',
29 |         'avg_pool_3x3',
30 |         'skip_connect',
31 |         'sep_conv_3x3',
32 |         'sep_conv_5x5',
33 |         'dil_conv_3x3',
34 |         'dil_conv_5x5'
35 |     ]
36 |     op_num = len(blocks_keys)
37 | 
38 |     # Operators encoding
39 |     NONE = 0
40 |     MAX_POOLING_3x3 = 1
41 |     AVG_POOL_3x3 = 2
42 |     SKIP_CONNECT = 3
43 |     SEP_CONV_3x3 = 4
44 |     SEP_CONV_5x5 = 5
45 |     DIL_CONV_3x3 = 6
46 |     DIL_CONV_5x5 = 7
47 | 
48 |     time_limit = None
49 |     # time_limit = 0.050
50 |     speed_input_shape = [32, 3, 224, 224]
51 | 
52 |     flops_limit = None
53 | 
54 |     max_epochs = 20
55 |     select_num = 10
56 |     population_num = 50
57 |     mutation_num = 25
58 |     m_prob = 0.1
59 |     crossover_num = 25
60 | 
61 |     momentum = 0.7
62 |     eps = 1e-5
63 | 
64 |     # Enumerate all paths of a single cell
65 |     paths = [[0, 2, 3, 4, 5], [0, 2, 3, 5], [0, 2, 4, 5], [0, 2, 5], [0, 3, 4, 5], [0, 3, 5], [0, 4, 5], [0, 5],
66 |              [1, 2, 3, 4, 5], [1, 2, 3, 5], [1, 2, 4, 5], [1, 2, 5], [1, 3, 4, 5], [1, 3, 5], [1, 4, 5], [1, 5],
67 |              [0, 2, 3, 4], [0, 2, 4], [0, 3, 4], [0, 4],
68 |              [1, 2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 4],
69 |              [0, 2, 3], [0, 3],
70 |              [1, 2, 3], [1, 3],
71 |              [0, 2],
72 |              [1, 2]]
73 | 
74 | for i in ['exp_name']:
75 |     print('{}: {}'.format(i, getattr(config, i)))
76 | 
--------------------------------------------------------------------------------
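The hard-coded `config.paths` table above enumerates every path through a DARTS cell, where nodes 0-1 are the cell inputs and nodes 2-5 are intermediate nodes. A small sketch that reproduces the same list (assuming the table's ascending depth-first ordering):

```python
def paths_to(start, end):
    # all strictly increasing paths from `start` to intermediate node `end`;
    # intermediate nodes are numbered 2..5, so hops below node 2 are skipped
    if start == end:
        return [[end]]
    out = []
    for nxt in range(max(start + 1, 2), end + 1):
        for tail in paths_to(nxt, end):
            out.append([start] + tail)
    return out

paths = []
for end in (5, 4, 3, 2):      # same block order as the table above
    for start in (0, 1):      # the two cell inputs
        paths.extend(paths_to(start, end))
# `paths` matches config.paths element for element (30 paths in total)
```
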
/evolution_search/datasets/DownsampledImageNet.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os, sys, hashlib, torch
5 | import numpy as np
6 | from PIL import Image
7 | import torch.utils.data as data
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 | import pdb
13 |
14 | def calculate_md5(fpath, chunk_size=1024 * 1024):
15 | md5 = hashlib.md5()
16 | with open(fpath, 'rb') as f:
17 | for chunk in iter(lambda: f.read(chunk_size), b''):
18 | md5.update(chunk)
19 | return md5.hexdigest()
20 |
21 |
22 | def check_md5(fpath, md5, **kwargs):
23 | return md5 == calculate_md5(fpath, **kwargs)
24 |
25 |
26 | def check_integrity(fpath, md5=None):
27 | if not os.path.isfile(fpath): return False
28 | if md5 is None: return True
29 | else : return check_md5(fpath, md5)
30 |
31 |
32 | class ImageNet16(data.Dataset):
33 | # http://image-net.org/download-images
34 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
35 | # https://arxiv.org/pdf/1707.08819.pdf
36 |
37 | train_list = [
38 | ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
39 | ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
40 | ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
41 | ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
42 | ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
43 | ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
44 | ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
45 | ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
46 | ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
47 | ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
48 | ]
49 | valid_list = [
50 | ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
51 | ]
52 |
53 | def __init__(self, root, train, transform, use_num_of_class_only=None):
54 | self.root = root
55 | self.transform = transform
56 | self.train = train # training set or valid set
57 | if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
58 |
59 | if self.train: downloaded_list = self.train_list
60 | else : downloaded_list = self.valid_list
61 | self.data = []
62 | self.targets = []
63 |
64 | # now load the picked numpy arrays
65 | for i, (file_name, checksum) in enumerate(downloaded_list):
66 | file_path = os.path.join(self.root, file_name)
67 | #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
68 | with open(file_path, 'rb') as f:
69 | if sys.version_info[0] == 2:
70 | entry = pickle.load(f)
71 | else:
72 | entry = pickle.load(f, encoding='latin1')
73 | self.data.append(entry['data'])
74 | self.targets.extend(entry['labels'])
75 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
76 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
77 | if use_num_of_class_only is not None:
78 | assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
79 | new_data, new_targets = [], []
80 | for I, L in zip(self.data, self.targets):
81 | if 1 <= L <= use_num_of_class_only:
82 | new_data.append( I )
83 | new_targets.append( L )
84 | self.data = new_data
85 | self.targets = new_targets
86 | # self.mean.append(entry['mean'])
87 | #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
88 | #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
89 | #print ('Mean : {:}'.format(self.mean))
90 | #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
91 | #std_data = np.std(temp, axis=0)
92 | #std_data = np.mean(np.mean(std_data, axis=0), axis=0)
93 | #print ('Std : {:}'.format(std_data))
94 |
95 | def __getitem__(self, index):
96 | img, target = self.data[index], self.targets[index] - 1
97 |
98 | img = Image.fromarray(img)
99 |
100 | if self.transform is not None:
101 | img = self.transform(img)
102 |
103 | return img, target
104 |
105 | def __len__(self):
106 | return len(self.data)
107 |
108 | def _check_integrity(self):
109 | root = self.root
110 | for fentry in (self.train_list + self.valid_list):
111 | filename, md5 = fentry[0], fentry[1]
112 | fpath = os.path.join(root, filename)
113 | if not check_integrity(fpath, md5):
114 | return False
115 | return True
116 |
117 | #
118 | if __name__ == '__main__':
119 | train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
120 | valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
121 |
122 | print ( len(train) )
123 | print ( len(valid) )
124 | image, label = train[111]
125 | trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
126 | validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
127 | print ( len(trainX) )
128 | print ( len(validX) )
129 | #import pdb; pdb.set_trace()
130 |
--------------------------------------------------------------------------------
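`ImageNet16` validates every batch file against the md5 sums above before loading. A quick standalone check of a single batch file (placeholder path; run from `evolution_search/`):

```python
from datasets.DownsampledImageNet import calculate_md5, check_integrity

fpath = '/path/to/ImageNet16/train_data_batch_1'   # placeholder location
expected = '27846dcaa50de8e21a7d1a35f30f0e91'       # md5 from ImageNet16.train_list
print(calculate_md5(fpath))
print(check_integrity(fpath, expected))             # True if the file is intact
```
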
/evolution_search/datasets/SearchDatasetWrap.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import torch, copy, random
5 | import torch.utils.data as data
6 |
7 |
8 | class SearchDataset(data.Dataset):
9 |
10 | def __init__(self, name, data, train_split, valid_split, check=True):
11 | self.datasetname = name
12 | if isinstance(data, (list, tuple)): # new type of SearchDataset
13 | assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
14 | self.train_data = data[0]
15 | self.valid_data = data[1]
16 | self.train_split = train_split.copy()
17 | self.valid_split = valid_split.copy()
18 | self.mode_str = 'V2' # new mode
19 | else:
20 | self.mode_str = 'V1' # old mode
21 | self.data = data
22 | self.train_split = train_split.copy()
23 | self.valid_split = valid_split.copy()
24 | if check:
25 | intersection = set(train_split).intersection(set(valid_split))
26 |       assert len(intersection) == 0, 'the split train and validation sets should have no intersection'
27 | self.length = len(self.train_split)
28 |
29 | def __repr__(self):
30 | return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
31 |
32 | def __len__(self):
33 | return self.length
34 |
35 | def __getitem__(self, index):
36 | assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
37 | train_index = self.train_split[index]
38 | valid_index = random.choice( self.valid_split )
39 | if self.mode_str == 'V1':
40 | train_image, train_label = self.data[train_index]
41 | valid_image, valid_label = self.data[valid_index]
42 | elif self.mode_str == 'V2':
43 | train_image, train_label = self.train_data[train_index]
44 | valid_image, valid_label = self.valid_data[valid_index]
45 | else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
46 | return train_image, train_label, valid_image, valid_label
47 |
--------------------------------------------------------------------------------
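`SearchDataset` yields a fixed training sample paired with a randomly drawn validation sample per index, so one loader can feed both weight and architecture updates. A toy V1-mode sketch (single underlying dataset, disjoint splits; run from `evolution_search/`):

```python
import torch
import torch.utils.data as data
from datasets.SearchDatasetWrap import SearchDataset

class Toy(data.Dataset):
    def __len__(self):
        return 10
    def __getitem__(self, i):
        return torch.randn(3, 32, 32), i % 10  # (image, label)

ds = SearchDataset('toy', Toy(), train_split=list(range(5)), valid_split=list(range(5, 10)))
train_img, train_lbl, valid_img, valid_lbl = ds[0]
print(len(ds), train_lbl, valid_lbl)  # 5, label of sample 0, a random validation label
```
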
/evolution_search/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
5 | from .SearchDatasetWrap import SearchDataset
6 |
--------------------------------------------------------------------------------
/evolution_search/datasets/config_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import json
3 | from collections import namedtuple
4 |
5 | support_types = ('str', 'int', 'bool', 'float', 'none')
6 |
7 | def convert_param(original_lists):
8 | assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
9 | ctype, value = original_lists[0], original_lists[1]
10 | assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
11 | is_list = isinstance(value, list)
12 | if not is_list: value = [value]
13 | outs = []
14 | for x in value:
15 | if ctype == 'int':
16 | x = int(x)
17 | elif ctype == 'str':
18 | x = str(x)
19 | elif ctype == 'bool':
20 | x = bool(int(x))
21 | elif ctype == 'float':
22 | x = float(x)
23 | elif ctype == 'none':
24 | assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
25 | x = None
26 | else:
27 |       raise TypeError('Unknown ctype : {:}'.format(ctype))
28 | outs.append(x)
29 | if not is_list: outs = outs[0]
30 | return outs
31 |
32 | def load_config(path, extra, logger):
33 | path = str(path)
34 | if hasattr(logger, 'log'): logger.log(path)
35 | assert os.path.exists(path), 'Can not find {:}'.format(path)
36 | # Reading data back
37 | with open(path, 'r') as f:
38 | data = json.load(f)
39 | content = { k: convert_param(v) for k,v in data.items()}
40 | assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
41 | if isinstance(extra, dict): content = {**content, **extra}
42 | Arguments = namedtuple('Configure', ' '.join(content.keys()))
43 | content = Arguments(**content)
44 | if hasattr(logger, 'log'): logger.log('{:}'.format(content))
45 | return content
--------------------------------------------------------------------------------
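`load_config` expects a JSON file where every value is a `[ctype, value]` pair and returns a namedtuple. A self-contained sketch of the format (hypothetical field names; the real split files live under `datasets/configs/`):

```python
import json, tempfile
from datasets.config_utils import load_config

cfg = {
    'train': ['int', ['0', '2', '4']],   # typed list -> [0, 2, 4]
    'valid': ['int', ['1', '3', '5']],
    'name':  ['str', 'cifar-demo'],
}
with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    json.dump(cfg, f)
    path = f.name

split = load_config(path, None, None)   # extra=None, logger=None
print(split.train, split.name)          # [0, 2, 4] cifar-demo
```
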
/evolution_search/datasets/get_dataset_with_transform.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os, sys, torch
5 | import os.path as osp
6 | import numpy as np
7 | import torchvision.datasets as dset
8 | import torchvision.transforms as transforms
9 | from copy import deepcopy
10 | from PIL import Image
11 |
12 | from .DownsampledImageNet import ImageNet16
13 | from .SearchDatasetWrap import SearchDataset
14 | from .config_utils import load_config as load_dataset_config
15 | from torchvision.transforms import transforms
16 | from PIL import ImageFilter, ImageOps
17 | import random
18 | import torchvision.datasets as datasets
19 |
20 | Dataset2Class = {'cifar10' : 10,
21 | 'cifar100': 100,
22 | 'imagenet-1k-s':1000,
23 | 'imagenet-1k' : 1000,
24 | 'ImageNet16' : 1000,
25 | 'ImageNet16-150': 150,
26 | 'ImageNet16-120': 120,
27 | 'ImageNet16-200': 200}
28 |
29 | class CUTOUT(object):
30 |
31 | def __init__(self, length):
32 | self.length = length
33 |
34 | def __repr__(self):
35 | return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
36 |
37 | def __call__(self, img):
38 | h, w = img.size(1), img.size(2)
39 | mask = np.ones((h, w), np.float32)
40 | y = np.random.randint(h)
41 | x = np.random.randint(w)
42 |
43 | y1 = np.clip(y - self.length // 2, 0, h)
44 | y2 = np.clip(y + self.length // 2, 0, h)
45 | x1 = np.clip(x - self.length // 2, 0, w)
46 | x2 = np.clip(x + self.length // 2, 0, w)
47 |
48 | mask[y1: y2, x1: x2] = 0.
49 | mask = torch.from_numpy(mask)
50 | mask = mask.expand_as(img)
51 | img *= mask
52 | return img
53 |
54 |
55 | imagenet_pca = {
56 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
57 | 'eigvec': np.asarray([
58 | [-0.5675, 0.7192, 0.4009],
59 | [-0.5808, -0.0045, -0.8140],
60 | [-0.5836, -0.6948, 0.4203],
61 | ])
62 | }
63 |
64 |
65 | class Lighting(object):
66 | def __init__(self, alphastd,
67 | eigval=imagenet_pca['eigval'],
68 | eigvec=imagenet_pca['eigvec']):
69 | self.alphastd = alphastd
70 | assert eigval.shape == (3,)
71 | assert eigvec.shape == (3, 3)
72 | self.eigval = eigval
73 | self.eigvec = eigvec
74 |
75 | def __call__(self, img):
76 | if self.alphastd == 0.:
77 | return img
78 | rnd = np.random.randn(3) * self.alphastd
79 | rnd = rnd.astype('float32')
80 | v = rnd
81 | old_dtype = np.asarray(img).dtype
82 | v = v * self.eigval
83 | v = v.reshape((3, 1))
84 | inc = np.dot(self.eigvec, v).reshape((3,))
85 | img = np.add(img, inc)
86 | if old_dtype == np.uint8:
87 | img = np.clip(img, 0, 255)
88 | img = Image.fromarray(img.astype(old_dtype), 'RGB')
89 | return img
90 |
91 | def __repr__(self):
92 | return self.__class__.__name__ + '()'
93 |
94 |
95 | class Cifar10RandomLabels(datasets.CIFAR10):
96 |     """CIFAR10 dataset, with support for randomly corrupted labels.
97 | Params
98 | ------
99 | rand_seed: int
100 | Default 0. numpy random seed.
101 | num_classes: int
102 | Default 10. The number of classes in the dataset.
103 | """
104 | def __init__(self, rand_seed=0, num_classes=10, **kwargs):
105 | super(Cifar10RandomLabels, self).__init__(**kwargs)
106 | self.n_classes = num_classes
107 | self.rand_seed = rand_seed
108 | self.random_labels()
109 |
110 | def random_labels(self):
111 | labels = np.array(self.targets)
112 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
113 | np.random.seed(self.rand_seed)
114 | rnd_labels = np.random.randint(0, self.n_classes, len(labels))
115 | # we need to explicitly cast the labels from npy.int64 to
116 | # builtin int type, otherwise pytorch will fail...
117 | labels = [int(x) for x in rnd_labels]
118 |
119 | self.targets = labels
120 |
121 | class Cifar100RandomLabels(datasets.CIFAR100):
122 |     """CIFAR100 dataset, with support for randomly corrupted labels.
123 | Params
124 | ------
125 | rand_seed: int
126 | Default 0. numpy random seed.
127 | num_classes: int
128 | Default 100. The number of classes in the dataset.
129 | """
130 | def __init__(self, rand_seed=0, num_classes=100, **kwargs):
131 | super(Cifar100RandomLabels, self).__init__(**kwargs)
132 | self.n_classes = num_classes
133 | self.rand_seed = rand_seed
134 | self.random_labels()
135 |
136 | def random_labels(self):
137 | labels = np.array(self.targets)
138 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
139 | np.random.seed(self.rand_seed)
140 | rnd_labels = np.random.randint(0, self.n_classes, len(labels))
141 | # we need to explicitly cast the labels from npy.int64 to
142 | # builtin int type, otherwise pytorch will fail...
143 | labels = [int(x) for x in rnd_labels]
144 |
145 | self.targets = labels
146 |
147 | class ImageNet16RandomLabels(ImageNet16):
148 |     """ImageNet16 dataset, with support for randomly corrupted labels.
149 | Params
150 | ------
151 | rand_seed: int
152 | Default 0. numpy random seed.
153 | num_classes: int
154 | Default 120. The number of classes in the dataset.
155 | """
156 | def __init__(self, rand_seed=0, num_classes=120, **kwargs):
157 | super(ImageNet16RandomLabels, self).__init__(**kwargs)
158 | self.n_classes = num_classes
159 | self.rand_seed = rand_seed
160 | self.random_labels()
161 | # print('min_label:{}, max_label:{}'.format(min(self.targets), max(self.targets)))
162 |
163 | def random_labels(self):
164 | labels = np.array(self.targets)
165 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
166 | np.random.seed(self.rand_seed)
167 | rnd_labels = np.random.randint(1, self.n_classes+1, len(labels))
168 | # we need to explicitly cast the labels from npy.int64 to
169 | # builtin int type, otherwise pytorch will fail...
170 | labels = [int(x) for x in rnd_labels]
171 |
172 | self.targets = labels
173 |
174 | def get_datasets(name, root, cutout, rand_seed, byol_aug_type=None, random_label=True):
175 |
176 | if name == 'cifar10':
177 | mean = [x / 255 for x in [125.3, 123.0, 113.9]]
178 | std = [x / 255 for x in [63.0, 62.1, 66.7]]
179 | elif name == 'cifar100':
180 | mean = [x / 255 for x in [129.3, 124.1, 112.4]]
181 | std = [x / 255 for x in [68.2, 65.4, 70.4]]
182 | elif name.startswith('imagenet-1k'):
183 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
184 | elif name.startswith('ImageNet16'):
185 | mean = [x / 255 for x in [122.68, 116.66, 104.01]]
186 | std = [x / 255 for x in [63.22, 61.26 , 65.09]]
187 | else:
188 |         raise TypeError("Unknown dataset : {:}".format(name))
189 |
190 | # Data Argumentation
191 | if name == 'cifar10' or name == 'cifar100':
192 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
193 | if cutout > 0 : lists += [CUTOUT(cutout)]
194 | if byol_aug_type is None:
195 | train_transform = transforms.Compose(lists)
196 | elif byol_aug_type=='byol':
197 | online_aug = get_train_transform('BYOL_Tau', 32, mean, std)
198 | target_aug = get_train_transform('BYOL_Tau_Hat', 32, mean, std)
199 | train_transform = TwoImageAugmentations(online_aug, target_aug)
200 | else:
201 | raise NotImplementedError
202 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
203 | xshape = (1, 3, 32, 32)
204 | elif name.startswith('ImageNet16'):
205 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
206 | if cutout > 0 : lists += [CUTOUT(cutout)]
207 | if byol_aug_type is None:
208 | train_transform = transforms.Compose(lists)
209 | elif byol_aug_type=='byol':
210 | online_aug = get_train_transform('BYOL_Tau', 16, mean, std)
211 | target_aug = get_train_transform('BYOL_Tau_Hat', 16, mean, std)
212 | train_transform = TwoImageAugmentations(online_aug, target_aug)
213 | else:
214 | raise NotImplementedError
215 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
216 | xshape = (1, 3, 16, 16)
217 | elif name == 'tiered':
218 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
219 | if cutout > 0 : lists += [CUTOUT(cutout)]
220 | train_transform = transforms.Compose(lists)
221 | test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
222 | xshape = (1, 3, 32, 32)
223 | elif name.startswith('imagenet-1k'):
224 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
225 | if name == 'imagenet-1k':
226 | xlists = [transforms.RandomResizedCrop(224)]
227 | xlists.append(
228 | transforms.ColorJitter(
229 | brightness=0.4,
230 | contrast=0.4,
231 | saturation=0.4,
232 | hue=0.2))
233 | xlists.append( Lighting(0.1))
234 | elif name == 'imagenet-1k-s':
235 | xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
236 | else: raise ValueError('invalid name : {:}'.format(name))
237 | xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
238 | xlists.append( transforms.ToTensor() )
239 | xlists.append( normalize )
240 | train_transform = transforms.Compose(xlists)
241 | test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
242 | xshape = (1, 3, 224, 224)
243 | else:
244 |         raise TypeError("Unknown dataset : {:}".format(name))
245 |
246 | if name == 'cifar10':
247 | if random_label:
248 | train_data = Cifar10RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed)
249 | else:
250 | train_data = datasets.CIFAR10(root=root, train=True, transform=train_transform, download=True)
251 | test_data = datasets.CIFAR10(root=root, train=True , transform=test_transform, download=True)
252 | # test_data = datasets.CIFAR10(root=root, train=False, transform=test_transform , download=True)
253 | assert len(train_data) == 50000 and len(test_data) == 50000
254 | elif name == 'cifar100':
255 | if random_label:
256 | train_data = Cifar100RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed)
257 | else:
258 | train_data = datasets.CIFAR100(root=root, train=True , transform=train_transform, download=True)
259 | test_data = datasets.CIFAR100(root=root, train=True, transform=test_transform , download=True)
260 | assert len(train_data) == 50000 and len(test_data) == 50000
261 | elif name.startswith('imagenet-1k'):
262 | train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
263 | test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
264 | assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
265 | elif name == 'ImageNet16':
266 | if random_label:
267 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, rand_seed=rand_seed)
268 | else:
269 | train_data = ImageNet16(root=root, train=True, transform=train_transform)
270 | test_data = ImageNet16(root=root, train=False, transform=test_transform)
271 | assert len(train_data) == 1281167 and len(test_data) == 50000
272 | elif name == 'ImageNet16-120':
273 | if random_label:
274 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=120, use_num_of_class_only=120, rand_seed=rand_seed)
275 | else:
276 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=120)
277 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=120)
278 | assert len(train_data) == 151700 and len(test_data) == 6000
279 | elif name == 'ImageNet16-150':
280 | if random_label:
281 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=150, use_num_of_class_only=150, rand_seed=rand_seed)
282 | else:
283 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=150)
284 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=150)
285 | assert len(train_data) == 190272 and len(test_data) == 7500
286 | elif name == 'ImageNet16-200':
287 | if random_label:
288 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, num_classes=200, use_num_of_class_only=200, rand_seed=rand_seed)
289 | else:
290 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=200)
291 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=200)
292 | assert len(train_data) == 254775 and len(test_data) == 10000
293 |     else: raise TypeError("Unknown dataset : {:}".format(name))
294 |
295 | class_num = Dataset2Class[name]
296 | return train_data, test_data, xshape, class_num
297 |
298 |
299 | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, use_valid_no_shuffle=False):
300 | # NOTE: detailed dataset configuration is given in NAS-BENCH-201 paper, https://arxiv.org/pdf/2001.00326.pdf.
301 | if isinstance(batch_size, (list,tuple)):
302 | batch, test_batch = batch_size
303 | else:
304 | batch, test_batch = batch_size, batch_size
305 | if dataset == 'cifar10' or dataset == 'cifar100':
306 | #split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
307 | if dataset == 'cifar10':
308 | cifar_split = load_dataset_config('{:}/cifar-split.txt'.format(config_root), None, None)
309 | elif dataset == 'cifar100':
310 | cifar_split = load_dataset_config('{:}/cifar100-split.txt'.format(config_root), None, None)
311 | train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
312 | #logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
313 | # To split data
314 | xvalid_data = valid_data
315 | search_data = SearchDataset(dataset, train_data, train_split, valid_split)
316 | # data loader
317 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
318 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
319 | valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
320 | if use_valid_no_shuffle:
321 | # NOTE: using validation dataset
322 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(xvalid_data, valid_split), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True)
323 | # NOTE: using search training dataset
324 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True)
325 | elif dataset == 'ImageNet16-120':
326 | imagenet_test_split = load_dataset_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
327 | search_train_data = train_data
328 | search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
329 | search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
330 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
331 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
332 | valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
333 | if use_valid_no_shuffle:
334 | # NOTE: using validation dataset
335 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(valid_data, imagenet_test_split.xvalid), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True)
336 | # NOTE: using search training dataset
337 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True)
338 | else:
339 | raise ValueError('invalid dataset : {:}'.format(dataset))
340 |
341 | if use_valid_no_shuffle:
342 | return search_loader, train_loader, valid_loader, valid_loader_no_shuffle
343 | else:
344 | return search_loader, train_loader, valid_loader
345 |
346 | def get_train_transform(aug, image_size, mean, std):
347 |
348 | if aug == 'BYOL_Tau':
349 | transform = transforms.Compose([
350 | transforms.RandomResizedCrop(image_size),
351 | transforms.RandomHorizontalFlip(),
352 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
353 | transforms.RandomGrayscale(p=0.2),
354 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=1.0),
355 | transforms.RandomApply([Solarization(128)], p=0.0),
356 | transforms.ToTensor(),
357 | transforms.Normalize(mean, std),
358 |
359 | ])
360 | elif aug == 'BYOL_Tau_Hat':
361 | transform = transforms.Compose([
362 | transforms.RandomResizedCrop(image_size),
363 | transforms.RandomHorizontalFlip(),
364 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
365 | transforms.RandomGrayscale(p=0.2),
366 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.1),
367 | transforms.RandomApply([Solarization(128)], p=0.2),
368 | transforms.ToTensor(),
369 | transforms.Normalize(mean, std),
370 | ])
371 | else:
372 | raise NotImplementedError
373 |
374 | return transform
375 |
376 |
377 | class TwoImageAugmentations:
378 | def __init__(self, online_aug, target_aug):
379 | self.online_aug = online_aug
380 | self.target_aug = target_aug
381 |
382 | def __call__(self, x):
383 | online_image = self.online_aug(x)
384 | target_image = self.target_aug(x)
385 | return [online_image, target_image]
386 |
387 | class GaussianBlur(object):
388 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
389 |
390 | def __init__(self, sigma=[.1, 2.]):
391 | self.sigma = sigma
392 |
393 | def __call__(self, x):
394 | sigma = random.uniform(self.sigma[0], self.sigma[1])
395 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
396 | return x
397 |
398 |
399 | class Solarization(object):
400 | def __init__(self, magnitude=128):
401 | self.magnitude = magnitude
402 |
403 | def __call__(self, x):
404 | x = ImageOps.solarize(x, self.magnitude)
405 | return x
406 |
407 | if __name__ == '__main__':
408 | byol = True
409 |     train_data, test_data, xshape, class_num = get_datasets('cifar10', '/home/zhangxuanyang/dataset/cifar.python/', -1, 0, byol_aug_type='byol' if byol else None)
410 | search_loader, _, valid_loader = get_nas_search_loaders(train_data, test_data, 'cifar10', 'configs/nas-benchmark/', \
411 | (3, 3), 4)
412 | for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
413 | print(base_inputs)
414 | break
415 |
416 | # import pdb; pdb.set_trace()
417 |
--------------------------------------------------------------------------------
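Putting the pieces together, a hedged sketch of how search-time code can build the split CIFAR-10 loaders (dataset root is a placeholder; run from `evolution_search/`):

```python
from datasets import get_datasets, get_nas_search_loaders

train_data, test_data, xshape, n_classes = get_datasets(
    'cifar10', '/path/to/cifar.python', cutout=-1, rand_seed=0, random_label=False)
search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, test_data, 'cifar10', 'datasets/configs', batch_size=64, workers=4)
print(xshape, n_classes, len(search_loader))
```
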
/evolution_search/datasets/test_utils.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os
5 | 
6 | def test_imagenet_data(imagenet):
7 |   total_length = len(imagenet)
8 |   assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
9 |   map_id = {}
10 |   for index in range(total_length):
11 |     path, target = imagenet.imgs[index]
12 |     folder, image_name = os.path.split(path)
13 |     _, folder = os.path.split(folder)
14 |     if folder not in map_id:
15 |       map_id[folder] = target
16 |     else:
17 |       assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
18 |     assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
19 |   print ('Check ImageNet Dataset OK')
18 |
--------------------------------------------------------------------------------
/evolution_search/genotypes.py:
--------------------------------------------------------------------------------
1 | """
2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/genotypes.py
3 | """
4 |
5 | from collections import namedtuple
6 |
7 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
8 |
9 | PRIMITIVES = [
10 | 'none',
11 | 'max_pool_3x3',
12 | 'avg_pool_3x3',
13 | 'skip_connect',
14 | 'sep_conv_3x3',
15 | 'sep_conv_5x5',
16 | 'dil_conv_3x3',
17 | 'dil_conv_5x5'
18 | ]
19 |
20 | NASNet = Genotype(
21 | normal = [
22 | ('sep_conv_5x5', 1),
23 | ('sep_conv_3x3', 0),
24 | ('sep_conv_5x5', 0),
25 | ('sep_conv_3x3', 0),
26 | ('avg_pool_3x3', 1),
27 | ('skip_connect', 0),
28 | ('avg_pool_3x3', 0),
29 | ('avg_pool_3x3', 0),
30 | ('sep_conv_3x3', 1),
31 | ('skip_connect', 1),
32 | ],
33 | normal_concat = [2, 3, 4, 5, 6],
34 | reduce = [
35 | ('sep_conv_5x5', 1),
36 | ('sep_conv_7x7', 0),
37 | ('max_pool_3x3', 1),
38 | ('sep_conv_7x7', 0),
39 | ('avg_pool_3x3', 1),
40 | ('sep_conv_5x5', 0),
41 | ('skip_connect', 3),
42 | ('avg_pool_3x3', 2),
43 | ('sep_conv_3x3', 2),
44 | ('max_pool_3x3', 1),
45 | ],
46 | reduce_concat = [4, 5, 6],
47 | )
48 |
49 | AmoebaNet = Genotype(
50 | normal = [
51 | ('avg_pool_3x3', 0),
52 | ('max_pool_3x3', 1),
53 | ('sep_conv_3x3', 0),
54 | ('sep_conv_5x5', 2),
55 | ('sep_conv_3x3', 0),
56 | ('avg_pool_3x3', 3),
57 | ('sep_conv_3x3', 1),
58 | ('skip_connect', 1),
59 | ('skip_connect', 0),
60 | ('avg_pool_3x3', 1),
61 | ],
62 | normal_concat = [4, 5, 6],
63 | reduce = [
64 | ('avg_pool_3x3', 0),
65 | ('sep_conv_3x3', 1),
66 | ('max_pool_3x3', 0),
67 | ('sep_conv_7x7', 2),
68 | ('sep_conv_7x7', 0),
69 | ('avg_pool_3x3', 1),
70 | ('max_pool_3x3', 0),
71 | ('max_pool_3x3', 1),
72 | ('conv_7x1_1x7', 0),
73 | ('sep_conv_3x3', 5),
74 | ],
75 | reduce_concat = [3, 4, 6]
76 | )
77 |
78 | DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])
79 | DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
80 |
81 | def parse_searched_cell(normal_reduce_cell):
82 | '''
83 |     normal_reduce_cell: flattened list of normal + reduce cell operations.
84 |     e.g., [
85 |             14 elements for normal cell edges, where each element denotes the operation on that edge, followed by
86 |             14 elements for reduce cell edges, where each element denotes the operation on that edge
87 |          ]
88 | '''
89 | assert len(normal_reduce_cell) == 28, "cell should contain normal + reduce edges (14 + 14 = 28)"
90 | normal_cell = normal_reduce_cell[:14]
91 | reduce_cell = normal_reduce_cell[14:]
92 |
93 | normal_cell_decoded = []
94 | reduce_cell_decoded = []
95 | # normal cell decode
96 | for i in range(len(normal_cell)):
97 | # NOTE: for generating intermediate node 0
98 | if i in [0, 1]:
99 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i))
100 | # NOTE: for generating intermediate node 1
101 | elif i in [2, 3, 4]:
102 | if normal_cell[i] != 0:
103 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 2))
104 | # NOTE: for generating intermediate node 2
105 | elif i in [5, 6, 7, 8]:
106 | if normal_cell[i] != 0:
107 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 5))
108 | # NOTE: for generating intermediate node 3
109 | elif i in [9, 10, 11, 12, 13]:
110 | if normal_cell[i] != 0:
111 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 9))
112 |
113 | # reduce cell decode
114 | for i in range(len(reduce_cell)):
115 | # NOTE: for generating intermediate node 0
116 | if i in [0, 1]:
117 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i))
118 | # NOTE: for generating intermediate node 1
119 | elif i in [2, 3, 4]:
120 | if reduce_cell[i] != 0:
121 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 2))
122 | # NOTE: for generating intermediate node 2
123 | elif i in [5, 6, 7, 8]:
124 | if reduce_cell[i] != 0:
125 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 5))
126 | # NOTE: for generating intermediate node 3
127 | elif i in [9, 10, 11, 12, 13]:
128 | if reduce_cell[i] != 0:
129 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 9))
130 |
131 | return Genotype(normal=normal_cell_decoded, normal_concat=[2, 3, 4, 5], reduce=reduce_cell_decoded, reduce_concat=[2, 3, 4, 5])
132 |
133 | RLDARTS = parse_searched_cell((5, 4, 5, 0, 5, 0, 0, 5, 5, 0, 0, 0, 7, 4, 5, 4, 2, 0, 5, 0, 0, 4, 4, 0, 4, 4, 0, 0))
134 | RLDARTS_GT = parse_searched_cell((5, 5, 4, 5, 0, 0, 4, 0, 4, 0, 0, 0, 4, 4, 1, 3, 3, 3, 0, 3, 2, 0, 0, 0, 4, 0, 0, 6))
135 |
136 | DARTS = DARTS_V2
137 |
138 |
--------------------------------------------------------------------------------
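A quick worked example of the decoding above: feed `parse_searched_cell` the 28-way operation string used for `RLDARTS` and inspect the normal cell (zero-valued edges, i.e. `none`, are dropped except on the first two edges):

```python
from genotypes import parse_searched_cell

geno = parse_searched_cell((5, 4, 5, 0, 5, 0, 0, 5, 5, 0, 0, 0, 7, 4,   # normal cell
                            5, 4, 2, 0, 5, 0, 0, 4, 4, 0, 4, 4, 0, 0))  # reduce cell
print(geno.normal)
# [('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('sep_conv_5x5', 0),
#  ('sep_conv_5x5', 2), ('sep_conv_5x5', 2), ('sep_conv_5x5', 3),
#  ('dil_conv_5x5', 3), ('sep_conv_3x3', 4)]
```
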
/evolution_search/metrics/tester_acc.py:
--------------------------------------------------------------------------------
1 | """
2 | GeNAS
3 | Copyright (c) 2023-present NAVER Cloud Corp.
4 | MIT license
5 | """
6 | import torch
7 | import math
8 | from config import config
9 | assert torch.cuda.is_available()
10 |
11 | def accuracy(output, target, topk=(1,)):
12 | maxk = max(topk)
13 | batch_size = target.size(0)
14 |
15 | _, pred = output.topk(maxk, 1, True, True)
16 | pred = pred.t()
17 | correct = pred.eq(target.view(1, -1).expand_as(pred))
18 |
19 | res = []
20 | for k in topk:
21 | correct_k = correct[:k].reshape(-1).float().sum(0)
22 | res.append(correct_k.mul_(100.0 / batch_size))
23 | return res
24 |
25 |
26 | def no_grad_wrapper(func):
27 | def new_func(*args, **kwargs):
28 | with torch.no_grad():
29 | return func(*args, **kwargs)
30 | return new_func
31 |
32 | @no_grad_wrapper
33 | def get_cand_acc(model, genotype, train_dataloader, val_dataloader, max_train_img_size, max_val_img_size=25000):
34 | '''
35 |     genotype: normal (14 edges) + reduce cell (14 edges) with operation indices, e.g., [6, 6, 3, 0, 4, 7, 0, 0, 1, 0, 6, 0, 6, 0, 1, 4, 3, 7, 0, 7, 0, 0, 4, 0, 0, 0, 3, 2]
36 | train_dataloader: half (25K) of original training set (50K).
37 | val_dataloader: another half (25K) of original training set (50K).
38 | '''
39 | # separate genotype
40 | normal_genotype = tuple(genotype[:config.edges])
41 | reduce_genotype = tuple(genotype[config.edges:])
42 |
43 | train_dataloader_iter = iter(train_dataloader)
44 | val_dataloader_iter = iter(val_dataloader)
45 |
46 | if torch.cuda.is_available():
47 | device = torch.device('cuda')
48 | else:
49 | device = torch.device('cpu')
50 |
51 | # NOTE: # iterations of BN statistics re-tracking for search loader
52 | max_train_iters = math.ceil(max_train_img_size / train_dataloader.batch_size)
53 |
54 | # NOTE: # iterations of measure validation accuracy for all validation images
55 | max_test_iters = math.ceil(max_val_img_size / val_dataloader.batch_size)
56 |
57 | if max_train_iters > 0:
58 | # NOTE: [from SPOS paper] "Before the inference of an architecture, the statistics of all the Batch Normalization (BN) [9] operations are recalculated on a random subset of training data"
59 | for m in model.modules():
60 | if isinstance(m, torch.nn.BatchNorm2d):
61 | m.running_mean = torch.zeros_like(m.running_mean)
62 | m.running_var = torch.ones_like(m.running_var)
63 |
64 | model.train()
65 |
66 | for step in range(max_train_iters):
67 |             batch = next(train_dataloader_iter)
68 | if len(batch) == 4:
69 | data, target, _, _ = batch
70 | elif len(batch) == 2:
71 | data, target = batch
72 |
73 | target = target.type(torch.LongTensor)
74 |
75 | data, target = data.to(device), target.to(device)
76 |
77 | output = model(data, normal_genotype, reduce_genotype)
78 |
79 | del data, target, output
80 |
81 | top1 = 0
82 | top5 = 0
83 | total = 0
84 |
85 | print('starting test....')
86 | model.eval()
87 |
88 | for step in range(max_test_iters):
89 |         data, target = next(val_dataloader_iter)
90 | batchsize = data.shape[0]
91 | target = target.type(torch.LongTensor)
92 | data, target = data.to(device), target.to(device)
93 | logits = model(data, normal_genotype, reduce_genotype)
94 | prec1, prec5 = accuracy(logits, target, topk=(1, 5))
95 | top1 += prec1.item() * batchsize
96 | top5 += prec5.item() * batchsize
97 | total += batchsize
98 |
99 | del data, target, logits, prec1, prec5
100 |
101 | top1, top5 = top1 / total, top5 / total
102 | top1, top5 = top1 / 100, top5 / 100
103 | return top1, top5
104 |
105 | def main():
106 | pass
--------------------------------------------------------------------------------
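get_cand_acc follows the SPOS recipe: reset every BatchNorm's running statistics, re-track them in train mode over roughly max_train_img_size images, then score the candidate in eval mode on the held-out half. A minimal usage sketch, assuming a trained supernet and the split-CIFAR loaders built in search.py (train_queue and valid_queue are placeholders here, and the all-4s candidate is illustrative only):

from metrics.tester_acc import get_cand_acc
from super_model import Network
from config import config
import utils

model = Network().cuda()  # weight-sharing supernet; load trained weights in practice
operations = [list(range(config.op_num)) for _ in range(config.edges)]
normal = utils.check_cand([4] * config.edges, operations)    # illustrative candidate
reduce_ = utils.check_cand([4] * config.edges, operations)
top1, top5 = get_cand_acc(model, normal + reduce_, train_queue, valid_queue,
                          max_train_img_size=5000)
print(top1, top5)  # fractions in [0, 1]: the accuracy percentages are divided by 100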
/evolution_search/metrics/tester_wlm.py:
--------------------------------------------------------------------------------
1 | """
2 | GeNAS
3 | Copyright (c) 2023-present NAVER Cloud Corp.
4 | MIT license
5 | """
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | from copy import deepcopy
10 | import math
11 | from config import config
12 | assert torch.cuda.is_available()
13 |
14 | def check_strictly_increasing(L):
15 |     return all(x < y for x, y in zip(L, L[1:]))
16 | 
17 | def no_grad_wrapper(func):
18 |     def new_func(*args, **kwargs):
19 |         with torch.no_grad():
20 |             return func(*args, **kwargs)
21 |     return new_func
22 | 
23 | @no_grad_wrapper
24 | def get_cand_wlm(model, genotype, train_dataloader, val_dataloader, stds, max_train_img_size, max_val_img_size=10000):
25 |     '''
26 |     genotype: normal (14 edges) + reduce cell (14 edges) with operation indices.
27 |     train_dataloader: half (25K) of original training set (50K).
28 |     val_dataloader: another half (25K) of original training set (50K).
29 |     stds: strictly increasing std values of the gaussian noise used to perturb the supernet weights.
30 |     '''
31 |     assert check_strictly_increasing(stds)
32 |     criterion = nn.CrossEntropyLoss()
33 | 
34 |     # separate genotype
35 |     normal_genotype = tuple(genotype[:config.edges])
36 |     reduce_genotype = tuple(genotype[config.edges:])
37 | 
38 |     train_dataloader_iter = iter(train_dataloader)
39 |     if torch.cuda.is_available():
40 |         device = torch.device('cuda')
41 |     else:
42 |         device = torch.device('cpu')
43 | 
44 |     # NOTE: # iterations of BN statistics re-tracking for search loader
45 |     max_train_iters = math.ceil(max_train_img_size / train_dataloader.batch_size)
46 | 
47 |     # NOTE: # iterations of loss evaluation over validation images
48 |     max_test_iters = math.ceil(max_val_img_size / val_dataloader.batch_size)
49 |     if max_train_iters > 0:
50 | # NOTE: [from SPOS paper] "Before the inference of an architecture, the statistics of all the Batch Normalization (BN) [9] operations are recalculated on a random subset of training data"
51 | for m in model.modules():
52 | if isinstance(m, torch.nn.BatchNorm2d):
53 | m.running_mean = torch.zeros_like(m.running_mean)
54 | m.running_var = torch.ones_like(m.running_var)
55 |
56 | model.train()
57 |
58 | for step in range(max_train_iters):
59 |             batch = next(train_dataloader_iter)
60 | if len(batch) == 4:
61 | data, target, _, _ = batch
62 | elif len(batch) == 2:
63 | data, target = batch
64 |
65 | target = target.type(torch.LongTensor)
66 | data, target = data.to(device), target.to(device)
67 | output = model(data, normal_genotype, reduce_genotype)
68 | del data, target, output
69 |
70 | losses_per_stds = []
71 |
72 | model.eval()
73 |
74 | model_ = deepcopy(model)
75 |
76 | for std_idx, cur_std in enumerate(stds):
77 | val_dataloader_iter = iter(val_dataloader)
78 |
79 |         # NOTE: gaussian noise is added either cumulatively, parameterized by the residual of stds (\sigma_{t+1} - \sigma_t), or directly with \sigma_t on a fresh copy of the model
80 |         # NOTE: the former bypasses deep-copying the model for every std, and is thus memory efficient.
81 |         # NOTE: while the former and the latter could give different results, we take the former for efficient memory usage.
82 | if std_idx == 0:
83 | std = cur_std # initial std
84 | else:
85 | std = cur_std - stds[std_idx - 1] # cumulate
86 |
87 | for name, param in model_.named_parameters():
88 | # NOTE: add gaussian noise for all parameters
89 | param.data.add_(torch.normal(0, std, size=param.size()).type(param.dtype).to(param.device))
90 | model_.eval()
91 |
92 | losses_over_batch = 0
93 |
94 | for step in range(max_test_iters):
95 | # NOTE: using validation dataset
96 |             data, target = next(val_dataloader_iter)
97 | 
98 | 
99 | batchsize = data.shape[0]
100 | target = target.type(torch.LongTensor)
101 | data, target = data.to(device), target.to(device)
102 |
103 | logits = model_(data, normal_genotype, reduce_genotype)
104 | loss = criterion(logits, target)
105 | losses_over_batch += loss.item()
106 |
107 | del data, target, logits
108 |
109 | losses_mean = losses_over_batch / max_test_iters
110 | losses_per_stds.append(losses_mean)
111 |
112 | del model_
113 |
114 | # calculate wide & flat measure
115 | poor_minima = 0
116 |     # NOTE: sum of absolute finite-difference loss slopes w.r.t. std (the non-perturbed loss at std 0 is regarded as 0)
117 | # TODO: add initial loss value (std:0, non-perturbed loss value)
118 | poor_minima += abs(losses_per_stds[0] / stds[0])
119 | for i in range(len(losses_per_stds) - 1):
120 | poor_minima += abs((losses_per_stds[i+1] - losses_per_stds[i]) / (stds[i+1] - stds[i]))
121 |
122 | wlm = 1 / (poor_minima + 1e-5) # wide & flat minima measure
123 | return wlm, losses_per_stds
--------------------------------------------------------------------------------
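The score at the end of get_cand_wlm is easiest to see numerically: poor_minima sums the absolute finite-difference slopes |ΔL/Δσ| of the perturbed validation loss (with the unperturbed loss treated as 0 at σ = 0), and wlm is its reciprocal, so a loss curve that rises slowly under growing weight noise (a wide, flat minimum) scores higher. A self-contained sketch with made-up loss values:

# Worked example of the wide-&-flat-minima (WLM) measure; numbers are illustrative.
def wlm_score(losses_per_stds, stds):
    # sum of absolute loss slopes w.r.t. perturbation std (loss at std 0 is taken as 0)
    poor_minima = abs(losses_per_stds[0] / stds[0])
    for i in range(len(losses_per_stds) - 1):
        poor_minima += abs((losses_per_stds[i + 1] - losses_per_stds[i]) / (stds[i + 1] - stds[i]))
    return 1 / (poor_minima + 1e-5)

stds = [0.01, 0.02, 0.03]
print(wlm_score([0.5, 0.8, 1.4], stds))   # sharp minimum: 1/140 ~ 0.0071
print(wlm_score([0.1, 0.15, 0.2], stds))  # flat minimum:  1/20  = 0.05 (preferred)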
/evolution_search/operations.py:
--------------------------------------------------------------------------------
1 | """
2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/operations.py
3 | """
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | OPS = {
9 | 'none' : lambda C, stride, affine: Zero(stride),
10 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
11 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
12 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
13 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
14 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
15 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
16 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
19 | nn.ReLU(inplace=False),
20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
22 | nn.BatchNorm2d(C, affine=affine)
23 | ),
24 | }
25 |
26 | class ReLUConvBN(nn.Module):
27 |
28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
29 | super(ReLUConvBN, self).__init__()
30 | self.op = nn.Sequential(
31 | nn.ReLU(inplace=False),
32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
33 | nn.BatchNorm2d(C_out, affine=affine)
34 | )
35 |
36 | def forward(self, x, rngs=None):
37 | return self.op(x)
38 |
39 | class DilConv(nn.Module):
40 |
41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
42 | super(DilConv, self).__init__()
43 | self.op = nn.Sequential(
44 | nn.ReLU(inplace=False),
45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
47 | nn.BatchNorm2d(C_out, affine=affine),
48 | )
49 |
50 | def forward(self, x, rngs=None):
51 | return self.op(x)
52 |
53 | class SepConv(nn.Module):
54 |
55 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
56 | super(SepConv, self).__init__()
57 | self.op = nn.Sequential(
58 | nn.ReLU(inplace=False),
59 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
60 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
61 | nn.BatchNorm2d(C_in, affine=affine),
62 | nn.ReLU(inplace=False),
63 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
64 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
65 | nn.BatchNorm2d(C_out, affine=affine),
66 | )
67 |
68 | def forward(self, x, rngs=None):
69 | return self.op(x)
70 |
71 |
72 | class Identity(nn.Module):
73 |
74 | def __init__(self):
75 | super(Identity, self).__init__()
76 |
77 | def forward(self, x, rngs=None):
78 | return x
79 |
80 | class Zero(nn.Module):
81 |
82 | def __init__(self, stride):
83 | super(Zero, self).__init__()
84 | self.stride = stride
85 | def forward(self, x, rngs=None):
86 | n, c, h, w = x.size()
87 | h //= self.stride
88 | w //= self.stride
89 | if x.is_cuda:
90 | with torch.cuda.device(x.get_device()):
91 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0)
92 | else:
93 | padding = torch.FloatTensor(n, c, h, w).fill_(0)
94 | return padding
95 |
96 | class FactorizedReduce(nn.Module):
97 |
98 | def __init__(self, C_in, C_out, affine=True):
99 | super(FactorizedReduce, self).__init__()
100 | assert C_out % 2 == 0
101 | self.relu = nn.ReLU(inplace=False)
102 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
103 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
104 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
105 |
106 | def forward(self, x, rngs=None):
107 | x = self.relu(x)
108 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
109 | out = self.bn(out)
110 | return out
111 |
112 |
113 |
--------------------------------------------------------------------------------
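Every op's forward accepts an unused rngs argument so all ops share one call signature inside the supernet's MixedOp. FactorizedReduce halves the spatial size by concatenating two stride-2 1x1 convolutions, the second applied to the input shifted by one pixel so the two branches sample complementary positions. A quick shape sanity check (a sketch, run from inside evolution_search/):

import torch
from operations import OPS, FactorizedReduce

C = 16
x = torch.rand(2, C, 32, 32)
# both stride-2 paths halve H and W while keeping C channels
print(FactorizedReduce(C, C)(x).shape)           # torch.Size([2, 16, 16, 16])
print(OPS['sep_conv_3x3'](C, 2, True)(x).shape)  # torch.Size([2, 16, 16, 16])
print(OPS['skip_connect'](C, 2, True)(x).shape)  # FactorizedReduce under stride 2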
/evolution_search/search.py:
--------------------------------------------------------------------------------
1 | """
2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/search.py
3 | """
4 |
5 | import os
6 | import sys
7 | import time
8 | import glob
9 | import random
10 | import numpy as np
11 | import pickle
12 | import torch
13 | import logging
14 | import argparse
15 | import torch.nn as nn
16 | import torch.utils
17 | import torch.nn.functional as F
18 | import torchvision.datasets as dset
19 | import torch.backends.cudnn as cudnn
20 | from cand_evaluator import CandEvaluator
21 | from genotypes import parse_searched_cell
22 | from datasets import get_datasets, get_nas_search_loaders
23 | from config import config
24 | import collections
25 | import sys
26 |
27 | sys.setrecursionlimit(10000)
28 | import argparse
29 | import utils
30 | import functools
31 |
32 | print = functools.partial(print, flush=True)
33 |
34 | choice = (
35 | lambda x: x[np.random.randint(len(x))] if isinstance(x, tuple) else choice(tuple(x))
36 | )
37 |
38 |
39 | class EvolutionTrainer(object):
40 | def __init__(
41 | self,
42 | log_dir,
43 | final_model_path,
44 | initial_model_path,
45 | metric="angle",
46 | train_loader=None,
47 | valid_loader=None,
48 | perturb_stds=None,
49 | max_train_img_size=5000,
50 | max_val_img_size=10000,
51 | wlm_weight=0,
52 | acc_weight=0,
53 | refresh=False,
54 | ):
55 | self.log_dir = log_dir
56 | self.checkpoint_name = os.path.join(self.log_dir, "checkpoint.brainpkl")
57 | self.refresh = refresh
58 | self.cand_evaluator = CandEvaluator(
59 | logging,
60 | final_model_path,
61 | initial_model_path,
62 | metric,
63 | train_loader,
64 | valid_loader,
65 | perturb_stds,
66 | max_train_img_size,
67 | max_val_img_size,
68 | wlm_weight,
69 | acc_weight,
70 | )
71 |
72 | self.memory = []
73 | self.candidates = []
74 | self.vis_dict = {}
75 | self.keep_top_k = {config.select_num: [], 50: []}
76 | self.epoch = 0
77 | self.cand_idx = 0 # for generating candidate idx
78 | self.operations = [list(range(config.op_num)) for _ in range(config.edges)]
79 |
80 | self.metric = metric
81 |
82 | def save_checkpoint(self):
83 | if not os.path.exists(self.log_dir):
84 | os.mkdir(self.log_dir)
85 | info = {}
86 | info["memory"] = self.memory
87 | info["candidates"] = self.candidates
88 | info["vis_dict"] = self.vis_dict
89 | info["keep_top_k"] = self.keep_top_k
90 | info["epoch"] = self.epoch
91 | info["cand_idx"] = self.cand_idx
92 | torch.save(info, self.checkpoint_name)
93 | logging.info("save checkpoint to {}".format(self.checkpoint_name))
94 |
95 | def load_checkpoint(self):
96 | if not os.path.exists(self.checkpoint_name):
97 | return False
98 | info = torch.load(self.checkpoint_name)
99 | self.memory = info["memory"]
100 | self.candidates = info["candidates"]
101 | self.vis_dict = info["vis_dict"]
102 | self.keep_top_k = info["keep_top_k"]
103 | self.epoch = info["epoch"]
104 | self.cand_idx = info["cand_idx"]
105 |
106 | if self.refresh:
107 | for i, j in self.vis_dict.items():
108 | for k in ["test_key"]:
109 | if k in j:
110 | j.pop(k)
111 | self.refresh = False
112 |
113 | logging.info("load checkpoint from {}".format(self.checkpoint_name))
114 | return True
115 |
116 | def legal(self, cand):
117 | assert isinstance(cand, tuple) and len(cand) == (2 * config.edges)
118 | if cand not in self.vis_dict:
119 | self.vis_dict[cand] = {}
120 | info = self.vis_dict[cand]
121 | if "visited" in info:
122 | return False
123 |
124 | if config.flops_limit is not None:
125 | pass
126 |
127 | self.vis_dict[cand] = info
128 | info["visited"] = True
129 |
130 | return True
131 |
132 | def update_top_k(self, candidates, *, k, key, reverse=False):
133 | assert k in self.keep_top_k
134 | logging.info("select ......")
135 | t = self.keep_top_k[k]
136 | t += candidates
137 | t.sort(key=key, reverse=reverse)
138 | self.keep_top_k[k] = t[:k]
139 |
140 | def gen_key(self, cand):
141 | # NOTE: generate unique id for candidate
142 | self.cand_idx += 1
143 | key = "{}-{}".format(self.cand_idx, time.time())
144 | return key
145 |
146 | def eval_cand(self, cand, cand_key):
147 | # NOTE: evaluate candidate
148 | try:
149 | result = self.cand_evaluator.eval(cand)
150 | return result
151 | except:
152 | import traceback
153 |
154 | traceback.print_exc()
155 | return {"status": "uncatched error"}
156 |
157 | def sync_candidates(self):
158 | while True:
159 | ok = True
160 | for cand in self.candidates:
161 | info = self.vis_dict[cand]
162 | if self.metric in info:
163 | continue
164 | ok = False
165 | if "test_key" not in info:
166 | info["test_key"] = self.gen_key(cand)
167 |
168 | self.save_checkpoint()
169 |
170 | for cand in self.candidates:
171 | info = self.vis_dict[cand]
172 | if self.metric in info:
173 | continue
174 | key = info.pop("test_key")
175 |
176 | try:
177 | logging.info("try to get {}".format(key))
178 | res = self.eval_cand(
179 | cand, key
180 |                 )  # NOTE: currently, key and cand have an implicit connection
181 | logging.info(res)
182 | info[self.metric] = res[self.metric]
183 | self.save_checkpoint()
184 | except:
185 | import traceback
186 |
187 | traceback.print_exc()
188 | time.sleep(1)
189 |
190 | time.sleep(5)
191 | if ok:
192 | break
193 |
194 | def stack_random_cand(self, random_func, *, batchsize=10):
195 | while True:
196 | cands = [random_func() for _ in range(batchsize)]
197 | for cand in cands:
198 | if cand not in self.vis_dict:
199 | self.vis_dict[cand] = {}
200 | else:
201 | continue
202 | info = self.vis_dict[cand]
203 | # for cand in cands:
204 | yield cand
205 |
206 | def stack_random_cand_crossover(self, random_func, max_iters, *, batchsize=10):
207 | cand_count = 0
208 | while True:
209 | if cand_count > max_iters:
210 | break
211 | cands = [random_func() for _ in range(batchsize)]
212 | cand_count += 1
213 | for cand in cands:
214 | if cand not in self.vis_dict:
215 | self.vis_dict[cand] = {}
216 | else:
217 | continue
218 | info = self.vis_dict[cand]
219 | # for cand in cands:
220 | yield cand
221 |
222 | def random_can(self, num):
223 | logging.info("random select ........")
224 | candidates = []
225 | cand_iter = self.stack_random_cand(
226 | lambda: tuple(
227 | np.random.randint(config.op_num) for _ in range(2 * config.edges)
228 | )
229 | )
230 | while len(candidates) < num:
231 | cand = next(cand_iter)
232 | normal_cand = cand[: config.edges]
233 | reduction_cand = cand[config.edges :]
234 | normal_cand = utils.check_cand(normal_cand, self.operations)
235 | reduction_cand = utils.check_cand(reduction_cand, self.operations)
236 | cand = normal_cand + reduction_cand
237 | cand = tuple(cand)
238 | if not self.legal(cand):
239 | continue
240 | candidates.append(cand)
241 | logging.info("random {}/{}".format(len(candidates), num))
242 | logging.info("random_num = {}".format(len(candidates)))
243 | return candidates
244 |
245 | def get_mutation(self, k, mutation_num, m_prob):
246 | assert k in self.keep_top_k
247 | logging.info("mutation ......")
248 | res = []
249 | iter = 0
250 | max_iters = mutation_num * 10
251 |
252 | def random_func():
253 | cand = list(choice(self.keep_top_k[k]))
254 | for i in range(config.edges):
255 | if np.random.random_sample() < m_prob:
256 | cand[i] = np.random.randint(0, config.op_num)
257 | return tuple(cand)
258 |
259 | cand_iter = self.stack_random_cand(random_func)
260 | while len(res) < mutation_num and max_iters > 0:
261 | cand = next(cand_iter)
262 | normal_cand = cand[: config.edges]
263 | reduction_cand = cand[config.edges :]
264 | normal_cand = utils.check_cand(normal_cand, self.operations)
265 | reduction_cand = utils.check_cand(reduction_cand, self.operations)
266 | cand = normal_cand + reduction_cand
267 | cand = tuple(cand)
268 | if not self.legal(cand):
269 | continue
270 | res.append(cand)
271 | logging.info("mutation {}/{}".format(len(res), mutation_num))
272 | max_iters -= 1
273 |
274 | logging.info("mutation_num = {}".format(len(res)))
275 | return res
276 |
277 | def get_crossover(self, k, crossover_num):
278 | assert k in self.keep_top_k
279 | logging.info("crossover ......")
280 | res = []
281 | iter = 0
282 | max_iters = 10 * crossover_num
283 |
284 | def random_func():
285 | p1 = choice(self.keep_top_k[k])
286 | p2 = choice(self.keep_top_k[k])
287 | return tuple(choice([i, j]) for i, j in zip(p1, p2))
288 |
289 | cand_iter = self.stack_random_cand_crossover(random_func, crossover_num)
290 | while len(res) < crossover_num:
291 | try:
292 | cand = next(cand_iter)
293 | normal_cand = cand[: config.edges]
294 | reduction_cand = cand[config.edges :]
295 | normal_cand = utils.check_cand(normal_cand, self.operations)
296 | reduction_cand = utils.check_cand(reduction_cand, self.operations)
297 | cand = normal_cand + reduction_cand
298 | cand = tuple(cand)
299 | except Exception as e:
300 | logging.info(e)
301 | break
302 | if not self.legal(cand):
303 | continue
304 | res.append(cand)
305 | logging.info("crossover {}/{}".format(len(res), crossover_num))
306 |
307 | logging.info("crossover_num = {}".format(len(res)))
308 | return res
309 |
310 | def train(self):
311 | logging.info(
312 | "population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} max_epochs = {}".format(
313 | config.population_num,
314 | config.select_num,
315 | config.mutation_num,
316 | config.crossover_num,
317 | config.population_num - config.mutation_num - config.crossover_num,
318 | config.max_epochs,
319 | )
320 | )
321 |
322 | if not self.load_checkpoint():
323 | self.candidates = self.random_can(config.population_num)
324 | self.save_checkpoint()
325 |
326 | while self.epoch < config.max_epochs:
327 | logging.info("epoch = {}".format(self.epoch))
328 |
329 | self.sync_candidates() # NOTE: evaluate candidates
330 |
331 | logging.info("sync finish")
332 |
333 | self.memory.append([])
334 | for cand in self.candidates:
335 | self.memory[-1].append(cand)
336 | self.vis_dict[cand]["visited"] = True
337 |
338 | self.update_top_k(
339 | self.candidates,
340 | k=config.select_num,
341 | key=lambda x: self.vis_dict[x][self.metric],
342 | reverse=True,
343 | )
344 | self.update_top_k(
345 | self.candidates,
346 | k=50,
347 | key=lambda x: self.vis_dict[x][self.metric],
348 | reverse=True,
349 | )
350 |
351 | logging.info(
352 | "epoch = {} : top {} result".format(
353 | self.epoch, len(self.keep_top_k[50])
354 | )
355 | )
356 | for i, cand in enumerate(self.keep_top_k[50]):
357 | logging.info(
358 | "No.{} {} {} = {}".format(
359 | i + 1, cand, self.metric, self.vis_dict[cand][self.metric]
360 | )
361 | )
362 | # ops = [config.blocks_keys[i] for i in cand]
363 | ops = [config.blocks_keys[i] for i in cand]
364 | logging.info(ops)
365 |
366 | mutation = self.get_mutation(
367 | config.select_num, config.mutation_num, config.m_prob
368 | )
369 | crossover = self.get_crossover(config.select_num, config.crossover_num)
370 | rand = self.random_can(
371 | config.population_num - len(mutation) - len(crossover)
372 | )
373 | self.candidates = mutation + crossover + rand
374 |
375 | self.epoch += 1
376 | self.save_checkpoint()
377 |
378 | logging.info(self.keep_top_k[config.select_num])
379 | logging.info("finish!")
380 | logging.info(
381 | "Top-1 Searched Cell Architecture : {}".format(
382 | parse_searched_cell(self.keep_top_k[config.select_num][0])
383 | )
384 | )
385 |
386 |
387 | def prepare_seed(rand_seed):
388 | random.seed(rand_seed)
389 | np.random.seed(rand_seed)
390 | torch.manual_seed(rand_seed)
391 | torch.cuda.manual_seed(rand_seed)
392 | torch.cuda.manual_seed_all(rand_seed)
393 |
394 |
395 | class SplitArgs(argparse.Action):
396 | def __call__(self, parser, namespace, values, option_string=None):
397 | setattr(namespace, self.dest, [float(val) for val in values.split(",")])
398 |
399 |
400 | def main():
401 | parser = argparse.ArgumentParser()
402 | parser.add_argument("-r", "--refresh", action="store_true")
403 | parser.add_argument("--save", type=str, default="log", help="experiment name")
404 | parser.add_argument("--seed", type=int, default=1, help="experiment name")
405 | parser.add_argument(
406 | "--init_model_path",
407 | type=str,
408 | default=config.initial_net_cache,
409 | help="initial model ckpt path",
410 | )
411 | parser.add_argument(
412 | "--model_path", type=str, default=config.net_cache, help="final model ckpt path"
413 | )
414 | parser.add_argument(
415 | "--metric", type=str, default="angle", help="metric to evaulate candidate with."
416 | )
417 |
418 | """ below are required if args.metric is not "angle" """
419 | parser.add_argument(
420 | "--data",
421 | type=str,
422 | default="",
423 | help='data root path. required if --metric is not "angle"',
424 | )
425 | parser.add_argument(
426 | "--split_data",
427 | type=int,
428 | choices=[0, 1],
429 | default=1,
430 | help="Whether use split data for training & validation. (default: True)",
431 | )
432 | parser.add_argument("--batch_size", type=int, default=64, help="train batch_size")
433 | parser.add_argument(
434 | "--test_batch_size", type=int, default=512, help="test batch_size"
435 | )
436 | parser.add_argument(
437 | "--cutout", action="store_true", default=False, help="use cutout"
438 | )
439 | parser.add_argument("--cutout_length", type=int, default=16, help="cutout length")
440 | # GeNAS hyperparameters
441 | parser.add_argument(
442 | "--stds",
443 | default=None,
444 | action=SplitArgs,
445 | help="std values for weight perturbation",
446 | )
447 | parser.add_argument(
448 | "--max_train_img_size",
449 | type=int,
450 | default=5000,
451 | help="maximum number of training imgs for batch norm statistics recalculation.",
452 | )
453 | parser.add_argument(
454 | "--max_val_img_size",
455 | type=int,
456 | default=10000,
457 | help="maximum number of validation imgs for evaluating architecture candidates. (required only for wlm)",
458 | )
459 | # Combined metric
460 | parser.add_argument("--wlm_weight", type=float, default=0, help="wlm weight")
461 | parser.add_argument("--acc_weight", type=float, default=0, help="acc weight")
462 |
463 | args = parser.parse_args()
464 |
465 | args.split_data = bool(args.split_data)
466 |
467 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob("*.py"))
468 |
469 | log_format = "%(asctime)s %(message)s"
470 | logging.basicConfig(
471 | stream=sys.stdout,
472 | level=logging.INFO,
473 | format=log_format,
474 | datefmt="%m/%d %I:%M:%S %p",
475 | )
476 | fh = logging.FileHandler(os.path.join(args.save, "search_log.txt"))
477 | fh.setFormatter(logging.Formatter(log_format))
478 | logging.getLogger().addHandler(fh)
479 |
480 | if (
481 | args.split_data
482 |     ):  # NOTE: split the train set in half to form new train/val sets: the new train set is used for supernet training, the new val set for candidate evaluation
483 | train_data, valid_data, xshape, class_num = get_datasets(
484 | "cifar100", args.data, -1, args.seed, random_label=False
485 | ) # NOTE: using GT label
486 | train_queue, _, _, valid_queue = get_nas_search_loaders(
487 | train_data,
488 | valid_data,
489 | "cifar100",
490 | "datasets/configs/",
491 | (args.batch_size, args.batch_size),
492 | 4,
493 | use_valid_no_shuffle=True,
494 | )
495 | else:
496 |         raise ValueError("only --split_data 1 is supported")
497 |
498 | refresh = args.refresh
499 | # np.random.seed(args.seed)
500 | prepare_seed(args.seed)
501 |
502 | t = time.time()
503 |
504 | trainer = EvolutionTrainer(
505 | args.save,
506 | args.model_path,
507 | args.init_model_path,
508 | metric=args.metric,
509 | train_loader=train_queue,
510 | valid_loader=valid_queue,
511 | perturb_stds=args.stds,
512 | max_train_img_size=args.max_train_img_size,
513 | max_val_img_size=args.max_val_img_size,
514 | wlm_weight=args.wlm_weight,
515 | acc_weight=args.acc_weight,
516 | refresh=refresh,
517 | )
518 |
519 | trainer.train()
520 | logging.info("total searching time = {:.2f} hours".format((time.time() - t) / 3600))
521 |
522 |
523 | if __name__ == "__main__":
524 | try:
525 | main()
526 | os._exit(0)
527 | except:
528 | import traceback
529 |
530 | traceback.print_exc()
531 | time.sleep(1)
532 | os._exit(1)
533 |
--------------------------------------------------------------------------------
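EvolutionTrainer's population update is standard evolutionary search: keep the top-k candidates by the chosen metric, then refill the population with mutation, crossover, and fresh random candidates. Below is a stripped-down sketch of the two operators on raw 28-gene tuples (the real code additionally repairs candidates with utils.check_cand and filters them through legal). Note that, exactly as in get_mutation above, only the first config.edges genes (the normal cell) are mutated; the constants assume config.op_num == 8 and config.edges == 14.

import numpy as np

op_num, edges = 8, 14  # assumed values of config.op_num / config.edges

def mutate(parent, m_prob=0.1):
    cand = list(parent)
    for i in range(edges):  # mirrors get_mutation: only normal-cell genes mutate
        if np.random.random_sample() < m_prob:
            cand[i] = np.random.randint(0, op_num)
    return tuple(cand)

def crossover(p1, p2):
    # uniform crossover: each gene is taken from a randomly chosen parent
    return tuple(int(np.random.choice([a, b])) for a, b in zip(p1, p2))

p1 = tuple(np.random.randint(op_num) for _ in range(2 * edges))
p2 = tuple(np.random.randint(op_num) for _ in range(2 * edges))
print(mutate(p1))
print(crossover(p1, p2))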
/evolution_search/super_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/super_model.py
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from operations import *
9 | from torch.autograd import Variable
10 | from genotypes import PRIMITIVES
11 | from genotypes import Genotype
12 | import math
13 | import numpy as np
14 | from config import config
15 | import copy
16 | from utils import check_cand
17 |
18 | class MixedOp(nn.Module):
19 |
20 | def __init__(self, C, stride):
21 | super(MixedOp, self).__init__()
22 | self._ops = nn.ModuleList()
23 | for idx, primitive in enumerate(PRIMITIVES):
24 | op = OPS[primitive](C, stride, True)
25 | op.idx = idx
26 | if 'pool' in primitive:
27 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=True))
28 | self._ops.append(op)
29 |
30 | def forward(self, x, rng):
31 | return self._ops[rng](x)
32 |
33 |
34 | class Cell(nn.Module):
35 |
36 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
37 | super(Cell, self).__init__()
38 | if reduction_prev:
39 |             # NOTE: if the (k-1)-th cell was a reduction cell, the (k-2)-th cell output must also be spatially reduced (stride 2) to match.
40 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=True)
41 | else:
42 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=True)
43 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=True)
44 | self._steps = steps
45 | self._multiplier = multiplier
46 | self._C = C
47 | self.out_C = self._multiplier * C
48 | self.reduction = reduction
49 |
50 | self._ops = nn.ModuleList()
51 | self._bns = nn.ModuleList()
52 | self.time_stamp = 1
53 |
54 | for i in range(self._steps):
55 | for j in range(2+i):
56 | stride = 2 if reduction and j < 2 else 1
57 | op = MixedOp(C, stride)
58 | self._ops.append(op)
59 |
60 | def forward(self, s0, s1, rngs):
61 | s0 = self.preprocess0(s0)
62 | s1 = self.preprocess1(s1)
63 | states = [s0, s1]
64 | offset = 0
65 | for i in range(self._steps):
66 |             # NOTE: all incoming edges are summed; candidates keep only two non-'none' operations per node (enforced by check_cand).
67 | s = sum(self._ops[offset+j](h, rngs[offset+j]) for j, h in enumerate(states))
68 | offset += len(states)
69 | states.append(s)
70 |         return torch.cat(states[-self._multiplier:], dim=1) # NOTE: the final 4 intermediate node outputs are concatenated (the k-1, k-2 node outputs are excluded).
71 |
72 | class Network(nn.Module):
73 | def __init__(self, C=16, num_classes=100, layers=8, steps=4, multiplier=4, stem_multiplier=3):
74 | super(Network, self).__init__()
75 | self._C = C
76 | self._num_classes = num_classes
77 | self._layers = layers
78 | self._steps = steps
79 | self._multiplier = multiplier
80 |
81 | C_curr = stem_multiplier * C
82 |
83 | self.stem = nn.Sequential(
84 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
85 | nn.BatchNorm2d(C_curr)
86 | )
87 |
88 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
89 |
90 | self.cells = nn.ModuleList()
91 | reduction_prev = False
92 |
93 | for i in range(layers):
94 | if i in [layers // 3, 2 * layers // 3]:
95 | C_curr *= 2
96 | reduction = True
97 | else:
98 | reduction = False
99 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
100 | reduction_prev = reduction
101 | self.cells += [cell]
102 | C_prev_prev, C_prev = C_prev, multiplier * C_curr
103 |
104 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
105 | self.classifier = nn.Linear(C_prev, num_classes)
106 |
107 | def forward_normal_only(self, input, rng):
108 | ''' forward function for only normal cells '''
109 | s0 = s1 = self.stem(input)
110 | for i, cell in enumerate(self.cells):
111 | s0, s1 = s1, cell(s0, s1, rng)
112 | out = self.global_pooling(s1)
113 | logits = self.classifier(out.view(out.size(0),-1))
114 | return logits
115 |
116 | def forward(self, input, normal_rng, reduction_rng):
117 | ''' forward function for normal + reduction cells '''
118 | s0 = s1 = self.stem(input)
119 | for i, cell in enumerate(self.cells):
120 | if i in [self._layers // 3, 2 * self._layers // 3]:
121 | s0, s1 = s1, cell(s0, s1, reduction_rng)
122 | else:
123 | s0, s1 = s1, cell(s0, s1, normal_rng)
124 | out = self.global_pooling(s1)
125 | logits = self.classifier(out.view(out.size(0),-1))
126 | return logits
127 |
128 | if __name__ == '__main__':
129 | from copy import deepcopy
130 | model = Network()
131 | operations = []
132 | for _ in range(config.edges):
133 | operations.append(list(range(config.op_num)))
134 | normal_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)]
135 | reduction_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)]
136 |     normal_rng = check_cand(normal_rng, operations) # NOTE: modify genotype to accept only two edges (operations) from previous nodes
137 |     reduction_rng = check_cand(reduction_rng, operations) # NOTE: modify genotype to accept only two edges (operations) from previous nodes
138 | x = torch.rand(4,3,32,32)
139 | logit = model(x, normal_rng, reduction_rng)
140 | print('logit:{0}'.format(logit))
141 |
--------------------------------------------------------------------------------
/evolution_search/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/utils.py
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | import torch
8 | import shutil
9 | import torchvision.transforms as transforms
10 | from torch.autograd import Variable
11 | from collections import defaultdict
12 | from config import config
13 |
14 | class AvgrageMeter(object):
15 |
16 | def __init__(self):
17 | self.reset()
18 |
19 | def reset(self):
20 | self.avg = 0
21 | self.sum = 0
22 | self.cnt = 0
23 |
24 | def update(self, val, n=1):
25 | self.sum += val * n
26 | self.cnt += n
27 | self.avg = self.sum / self.cnt
28 |
29 |
30 | def accuracy(output, target, topk=(1,)):
31 | maxk = max(topk)
32 | batch_size = target.size(0)
33 |
34 | _, pred = output.topk(maxk, 1, True, True)
35 | pred = pred.t()
36 | correct = pred.eq(target.view(1, -1).expand_as(pred))
37 |
38 | res = []
39 | for k in topk:
40 |         correct_k = correct[:k].reshape(-1).float().sum(0)  # reshape (not view): correct is non-contiguous after t()
41 | res.append(correct_k.mul_(100.0/batch_size))
42 | return res
43 |
44 |
45 | class Cutout(object):
46 | def __init__(self, length):
47 | self.length = length
48 |
49 | def __call__(self, img):
50 | h, w = img.size(1), img.size(2)
51 | mask = np.ones((h, w), np.float32)
52 | y = np.random.randint(h)
53 | x = np.random.randint(w)
54 |
55 | y1 = np.clip(y - self.length // 2, 0, h)
56 | y2 = np.clip(y + self.length // 2, 0, h)
57 | x1 = np.clip(x - self.length // 2, 0, w)
58 | x2 = np.clip(x + self.length // 2, 0, w)
59 |
60 | mask[y1: y2, x1: x2] = 0.
61 | mask = torch.from_numpy(mask)
62 | mask = mask.expand_as(img)
63 | img *= mask
64 | return img
65 |
66 |
67 | def _data_transforms_cifar10(args):
68 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
69 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
70 |
71 | train_transform = transforms.Compose([
72 | transforms.RandomCrop(32, padding=4),
73 | transforms.RandomHorizontalFlip(),
74 | transforms.ToTensor(),
75 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
76 | ])
77 | if args.cutout:
78 | train_transform.transforms.append(Cutout(args.cutout_length))
79 |
80 | valid_transform = transforms.Compose([
81 | transforms.ToTensor(),
82 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
83 | ])
84 | return train_transform, valid_transform
85 |
86 |
87 | def count_parameters_in_MB(model):
88 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
89 |
90 |
91 | def save_checkpoint(state, is_best, save):
92 | filename = os.path.join(save, 'checkpoint.pth.tar')
93 | torch.save(state, filename)
94 | if is_best:
95 | best_filename = os.path.join(save, 'model_best.pth.tar')
96 | shutil.copyfile(filename, best_filename)
97 |
98 |
99 | def save(model, model_path):
100 | torch.save(model.state_dict(), model_path)
101 |
102 |
103 | def load(model, model_path):
104 | model.load_state_dict(torch.load(model_path))
105 |
106 |
107 | def drop_path(x, drop_prob):
108 | if drop_prob > 0.:
109 | keep_prob = 1.-drop_prob
110 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
111 | x.div_(keep_prob)
112 | x.mul_(mask)
113 | return x
114 |
115 |
116 | def create_exp_dir(path, scripts_to_save=None):
117 | if not os.path.exists(path):
118 | os.makedirs(path, exist_ok=True)
119 | print('Experiment dir : {}'.format(path))
120 |
121 | if scripts_to_save is not None:
122 | os.makedirs(os.path.join(path, 'scripts'), exist_ok=True)
123 | for script in scripts_to_save:
124 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
125 | shutil.copyfile(script, dst_file)
126 |
127 | def get_location(s, key):
128 | d = defaultdict(list)
129 | for k,va in [(v,i) for i,v in enumerate(s)]:
130 | d[k].append(va)
131 | return d[key]
132 |
133 | def list_substract(list1, list2):
134 | list1 = [item for item in list1 if item not in set(list2)]
135 | return list1
136 |
137 | def check_cand(cand, operations):
138 | cand = np.reshape(cand, [-1, config.edges])
139 | offset, cell_cand = 0, cand[0]
140 | for j in range(4):
141 | edges = cell_cand[offset:offset+j+2]
142 | edges_ops = operations[offset:offset+j+2]
143 | none_idxs = get_location(edges, 0)
144 | if len(none_idxs) < j:
145 | general_idxs = list_substract(range(j+2), none_idxs)
146 | num = min(j-len(none_idxs), len(general_idxs))
147 | general_idxs = np.random.choice(general_idxs, size=num, replace=False, p=None)
148 | for k in general_idxs:
149 | edges[k] = 0
150 | elif len(none_idxs) > j:
151 | none_idxs = np.random.choice(none_idxs, size=len(none_idxs)-j, replace=False, p=None)
152 | for k in none_idxs:
153 | if len(edges_ops[k]) > 1:
154 | l = np.random.randint(len(edges_ops[k])-1)
155 | edges[k] = edges_ops[k][l+1]
156 | offset += len(edges)
157 |
158 | return tuple(cell_cand)
159 |
--------------------------------------------------------------------------------
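check_cand repairs a sampled cell so that every intermediate node keeps exactly two active incoming edges: node j chooses among j+2 candidate inputs, so its gene group must contain exactly j zeros ('none' ops). A worked example (the exact output varies, since the repair picks random edges to drop or revive):

from config import config
from utils import check_cand

operations = [list(range(config.op_num)) for _ in range(config.edges)]
raw = [5, 4,           # node 0 (2 edges): needs 0 zeros -> kept as is
       3, 3, 3,        # node 1 (3 edges): needs 1 zero  -> one op is zeroed out
       0, 0, 0, 2,     # node 2 (4 edges): needs 2 zeros -> one zero is revived
       1, 0, 0, 0, 0]  # node 3 (5 edges): needs 3 zeros -> one zero is revived
fixed = check_cand(raw, operations)
print(fixed)  # a 14-tuple with exactly 0/1/2/3 zeros in the four groups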
/repo_figures/ABS_FBS_architecture_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_FBS_architecture_normal.png
--------------------------------------------------------------------------------
/repo_figures/ABS_FBS_architecture_reduce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_FBS_architecture_reduce.png
--------------------------------------------------------------------------------
/repo_figures/ABS_architecture_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_architecture_normal.png
--------------------------------------------------------------------------------
/repo_figures/ABS_architecture_reduce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_architecture_reduce.png
--------------------------------------------------------------------------------
/repo_figures/FBS_architecture_normal.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/FBS_architecture_normal.png
--------------------------------------------------------------------------------
/repo_figures/FBS_architecture_reduce.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/FBS_architecture_reduce.png
--------------------------------------------------------------------------------
/repo_figures/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/motivation.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1+cu110
2 | torchvision==0.8.2+cu110
3 | opencv-contrib-python==4.6.0
4 | matplotlib==2.2.2
5 | numpy==1.20.0
6 | tqdm==4.64.1
7 | wget
--------------------------------------------------------------------------------
/retrain_architecture/config.py:
--------------------------------------------------------------------------------
1 | import os
2 | """
3 | MASTER_HOST: master node ip
4 | MASTER_PORT: master node port
5 | NODE_NUM: # nodes
6 | MY_RANK: current node idx
7 | GPU_NUM: # gpus per node
8 | """
9 | MASTER_HOST = os.environ["HOST_RANK0"]
10 | MASTER_PORT = 13322
11 | NODE_NUM = int(os.environ["WORLD_SIZE"])
12 | MY_RANK = int(os.environ["RANK"])
13 | GPU_NUM = int(os.environ["GPU_COUNT"])
--------------------------------------------------------------------------------
/retrain_architecture/genotypes.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/genotypes.py
3 | '''
4 |
5 | from ast import parse
6 | from collections import namedtuple
7 |
8 | Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat")
9 |
10 | PRIMITIVES = [
11 | "none",
12 | "max_pool_3x3",
13 | "avg_pool_3x3",
14 | "skip_connect",
15 | "sep_conv_3x3",
16 | "sep_conv_5x5",
17 | "dil_conv_3x3",
18 | "dil_conv_5x5",
19 | ]
20 |
21 | NASNet = Genotype(
22 | normal=[
23 | ("sep_conv_5x5", 1),
24 | ("sep_conv_3x3", 0),
25 | ("sep_conv_5x5", 0),
26 | ("sep_conv_3x3", 0),
27 | ("avg_pool_3x3", 1),
28 | ("skip_connect", 0),
29 | ("avg_pool_3x3", 0),
30 | ("avg_pool_3x3", 0),
31 | ("sep_conv_3x3", 1),
32 | ("skip_connect", 1),
33 | ],
34 | normal_concat=[2, 3, 4, 5, 6],
35 | reduce=[
36 | ("sep_conv_5x5", 1),
37 | ("sep_conv_7x7", 0),
38 | ("max_pool_3x3", 1),
39 | ("sep_conv_7x7", 0),
40 | ("avg_pool_3x3", 1),
41 | ("sep_conv_5x5", 0),
42 | ("skip_connect", 3),
43 | ("avg_pool_3x3", 2),
44 | ("sep_conv_3x3", 2),
45 | ("max_pool_3x3", 1),
46 | ],
47 | reduce_concat=[4, 5, 6],
48 | )
49 |
50 | AmoebaNet = Genotype(
51 | normal=[
52 | ("avg_pool_3x3", 0),
53 | ("max_pool_3x3", 1),
54 | ("sep_conv_3x3", 0),
55 | ("sep_conv_5x5", 2),
56 | ("sep_conv_3x3", 0),
57 | ("avg_pool_3x3", 3),
58 | ("sep_conv_3x3", 1),
59 | ("skip_connect", 1),
60 | ("skip_connect", 0),
61 | ("avg_pool_3x3", 1),
62 | ],
63 | normal_concat=[4, 5, 6],
64 | reduce=[
65 | ("avg_pool_3x3", 0),
66 | ("sep_conv_3x3", 1),
67 | ("max_pool_3x3", 0),
68 | ("sep_conv_7x7", 2),
69 | ("sep_conv_7x7", 0),
70 | ("avg_pool_3x3", 1),
71 | ("max_pool_3x3", 0),
72 | ("max_pool_3x3", 1),
73 | ("conv_7x1_1x7", 0),
74 | ("sep_conv_3x3", 5),
75 | ],
76 | reduce_concat=[3, 4, 6],
77 | )
78 |
79 | DARTS_V1_CIFAR10 = Genotype(
80 | normal=[
81 | ("sep_conv_3x3", 1),
82 | ("sep_conv_3x3", 0),
83 | ("skip_connect", 0),
84 | ("sep_conv_3x3", 1),
85 | ("skip_connect", 0),
86 | ("sep_conv_3x3", 1),
87 | ("sep_conv_3x3", 0),
88 | ("skip_connect", 2),
89 | ],
90 | normal_concat=[2, 3, 4, 5],
91 | reduce=[
92 | ("max_pool_3x3", 0),
93 | ("max_pool_3x3", 1),
94 | ("skip_connect", 2),
95 | ("max_pool_3x3", 0),
96 | ("max_pool_3x3", 0),
97 | ("skip_connect", 2),
98 | ("skip_connect", 2),
99 | ("avg_pool_3x3", 0),
100 | ],
101 | reduce_concat=[2, 3, 4, 5],
102 | )
103 |
104 | DARTS_V2_CIFAR10 = Genotype(
105 | normal=[
106 | ("sep_conv_3x3", 0),
107 | ("sep_conv_3x3", 1),
108 | ("sep_conv_3x3", 0),
109 | ("sep_conv_3x3", 1),
110 | ("sep_conv_3x3", 1),
111 | ("skip_connect", 0),
112 | ("skip_connect", 0),
113 | ("dil_conv_3x3", 2),
114 | ],
115 | normal_concat=[2, 3, 4, 5],
116 | reduce=[
117 | ("max_pool_3x3", 0),
118 | ("max_pool_3x3", 1),
119 | ("skip_connect", 2),
120 | ("max_pool_3x3", 1),
121 | ("max_pool_3x3", 0),
122 | ("skip_connect", 2),
123 | ("skip_connect", 2),
124 | ("max_pool_3x3", 1),
125 | ],
126 | reduce_concat=[2, 3, 4, 5],
127 | )
128 |
129 |
130 | def parse_searched_cell(normal_reduce_cell):
131 | """
132 | normal_reduce_cell: list of normal + reduce cell.
133 | e.g) [
134 | 14 elements for normal cell edges where each element denote operation for each edge +
135 | 14 elements for reduce cell edges where each element denote operation for each edge
136 | ]
137 | """
138 | assert (
139 | len(normal_reduce_cell) == 28
140 | ), "cell should contain normal + reduce edges (14 + 14 = 28)"
141 | normal_cell = normal_reduce_cell[:14]
142 | reduce_cell = normal_reduce_cell[14:]
143 |
144 | normal_cell_decoded = []
145 | reduce_cell_decoded = []
146 | # normal cell decode
147 | for i in range(len(normal_cell)):
148 | # NOTE: for generating intermediate node 0
149 | if i in [0, 1]:
150 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i))
151 | # NOTE: for generating intermediate node 1
152 | elif i in [2, 3, 4]:
153 | if normal_cell[i] != 0:
154 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 2))
155 | # NOTE: for generating intermediate node 2
156 | elif i in [5, 6, 7, 8]:
157 | if normal_cell[i] != 0:
158 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 5))
159 | # NOTE: for generating intermediate node 3
160 | elif i in [9, 10, 11, 12, 13]:
161 | if normal_cell[i] != 0:
162 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 9))
163 |
164 | # reduce cell decode
165 | for i in range(len(reduce_cell)):
166 | # NOTE: for generating intermediate node 0
167 | if i in [0, 1]:
168 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i))
169 | # NOTE: for generating intermediate node 1
170 | elif i in [2, 3, 4]:
171 | if reduce_cell[i] != 0:
172 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 2))
173 | # NOTE: for generating intermediate node 2
174 | elif i in [5, 6, 7, 8]:
175 | if reduce_cell[i] != 0:
176 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 5))
177 | # NOTE: for generating intermediate node 3
178 | elif i in [9, 10, 11, 12, 13]:
179 | if reduce_cell[i] != 0:
180 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 9))
181 |
182 | return Genotype(
183 | normal=normal_cell_decoded,
184 | normal_concat=[2, 3, 4, 5],
185 | reduce=reduce_cell_decoded,
186 | reduce_concat=[2, 3, 4, 5],
187 | )
188 |
189 |
190 | # print(parse_searched_cell((5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0, 5, 6, 4, 0, 5, 0, 6, 0, 4, 0, 1, 5, 0, 0)))
191 | # cand=[[5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0], [5, 6, 4, 0, 5, 0, 6, 0, 4, 0, 1, 5, 0, 0]]
192 | # NOTE: [5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0] means
193 | # NOTE: [5, 4]: for generating intermediate node 0, operation 5:sep_conv_5x5(k-2 node (prev prev cell output)) + operation 4:sep_conv_3x3(k-1 node (prev cell output))
194 | # NOTE: [6, 0, 5]: for generating intermediate node 1, operation 6:dil_conv_3x3(k-2 node) + operation 5:sep_conv_5x5(intermediate node 0)
195 | # NOTE: [4, 0, 0, 7]: for generating intermediate node 2, operation 4:sep_conv_3x3(k-2 node) + operation 7:dil_conv_5x5(intermediate node 1)
196 | # NOTE: [0, 7, 6, 0, 0]: for generating intermediate node 3, operation 7:dil_conv_5x5(k-1 node) + operation 6:dil_conv_3x3(intermediate node 0)
197 | # NOTE: all intermediate node outputs (0, 1, 2, 3) are concatenated to form the output of the current cell.
198 | # RLDARTS = Genotype(
199 | # normal=[('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0),
200 | # ('dil_conv_5x5', 3), ('dil_conv_5x5', 1), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5],
201 | # reduce=[('sep_conv_5x5', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), ('dil_conv_3x3', 1),
202 | # ('sep_conv_3x3', 3), ('max_pool_3x3', 1), ('sep_conv_5x5', 2)], reduce_concat=[2, 3, 4, 5])
203 | RLDARTS_OURS_GT = parse_searched_cell(
204 | (5, 5, 2, 5, 0, 4, 4, 0, 0, 0, 4, 0, 0, 4, 3, 3, 0, 3, 2, 3, 7, 0, 0, 3, 0, 0, 0, 4)
205 | )
206 | PCDARTS_OURS_SEARCHEPOCH40 = Genotype(
207 | normal=[
208 | ("sep_conv_3x3", 0),
209 | ("sep_conv_3x3", 1),
210 | ("sep_conv_3x3", 1),
211 | ("sep_conv_5x5", 0),
212 | ("sep_conv_5x5", 1),
213 | ("sep_conv_3x3", 3),
214 | ("sep_conv_5x5", 4),
215 | ("sep_conv_3x3", 0),
216 | ],
217 | normal_concat=range(2, 6),
218 | reduce=[
219 | ("max_pool_3x3", 0),
220 | ("sep_conv_3x3", 1),
221 | ("max_pool_3x3", 0),
222 | ("skip_connect", 1),
223 | ("sep_conv_5x5", 2),
224 | ("sep_conv_3x3", 0),
225 | ("skip_connect", 0),
226 | ("dil_conv_3x3", 4),
227 | ],
228 | reduce_concat=range(2, 6),
229 | )
230 |
231 | PCDARTS_OURS = Genotype(
232 | normal=[
233 | ("dil_conv_5x5", 1),
234 | ("dil_conv_3x3", 0),
235 | ("dil_conv_5x5", 1),
236 | ("max_pool_3x3", 0),
237 | ("dil_conv_3x3", 0),
238 | ("sep_conv_3x3", 3),
239 | ("sep_conv_3x3", 0),
240 | ("sep_conv_3x3", 2),
241 | ],
242 | normal_concat=range(2, 6),
243 | reduce=[
244 | ("max_pool_3x3", 0),
245 | ("max_pool_3x3", 1),
246 | ("max_pool_3x3", 0),
247 | ("skip_connect", 2),
248 | ("sep_conv_5x5", 2),
249 | ("skip_connect", 1),
250 | ("sep_conv_3x3", 0),
251 | ("sep_conv_3x3", 2),
252 | ],
253 | reduce_concat=range(2, 6),
254 | )
255 |
256 | # PDARTS searched on CIFAR-10
257 | PDARTS_CIFAR10 = Genotype(
258 | normal=[
259 | ("skip_connect", 0),
260 | ("dil_conv_3x3", 1),
261 | ("skip_connect", 0),
262 | ("sep_conv_3x3", 1),
263 | ("sep_conv_3x3", 1),
264 | ("sep_conv_3x3", 3),
265 | ("sep_conv_3x3", 0),
266 | ("dil_conv_5x5", 4),
267 | ],
268 | normal_concat=range(2, 6),
269 | reduce=[
270 | ("avg_pool_3x3", 0),
271 | ("sep_conv_5x5", 1),
272 | ("sep_conv_3x3", 0),
273 | ("dil_conv_5x5", 2),
274 | ("max_pool_3x3", 0),
275 | ("dil_conv_3x3", 1),
276 | ("dil_conv_3x3", 1),
277 | ("dil_conv_5x5", 3),
278 | ],
279 | reduce_concat=range(2, 6),
280 | )
281 |
282 |
283 | # DARTS-v1 searched on CIFAR-100
284 | DARTS_V1_CIFAR100 = Genotype(
285 | normal=[
286 | ("skip_connect", 0),
287 | ("sep_conv_3x3", 1),
288 | ("skip_connect", 0),
289 | ("sep_conv_3x3", 1),
290 | ("skip_connect", 0),
291 | ("skip_connect", 1),
292 | ("skip_connect", 0),
293 | ("skip_connect", 1),
294 | ],
295 | normal_concat=range(2, 6),
296 | reduce=[
297 | ("avg_pool_3x3", 0),
298 | ("avg_pool_3x3", 1),
299 | ("avg_pool_3x3", 0),
300 | ("skip_connect", 2),
301 | ("skip_connect", 2),
302 | ("avg_pool_3x3", 0),
303 | ("skip_connect", 2),
304 | ("avg_pool_3x3", 0),
305 | ],
306 | reduce_concat=range(2, 6),
307 | )
308 |
309 | SDARTS_RS_CIFAR10 = Genotype(
310 | normal=[
311 | ("sep_conv_3x3", 1),
312 | ("sep_conv_3x3", 0),
313 | ("sep_conv_5x5", 1),
314 | ("skip_connect", 0),
315 | ("sep_conv_3x3", 3),
316 | ("skip_connect", 1),
317 | ("sep_conv_3x3", 1),
318 | ("dil_conv_3x3", 2),
319 | ],
320 | normal_concat=range(2, 6),
321 | reduce=[
322 | ("max_pool_3x3", 0),
323 | ("sep_conv_3x3", 1),
324 | ("skip_connect", 2),
325 | ("max_pool_3x3", 0),
326 | ("dil_conv_5x5", 3),
327 | ("max_pool_3x3", 0),
328 | ("sep_conv_3x3", 2),
329 | ("sep_conv_5x5", 3),
330 | ],
331 | reduce_concat=range(2, 6),
332 | )
333 |
334 | SDARTS_ADV_CIFAR10 = Genotype(
335 | normal=[
336 | ("sep_conv_3x3", 0),
337 | ("sep_conv_3x3", 1),
338 | ("sep_conv_3x3", 1),
339 | ("skip_connect", 0),
340 | ("sep_conv_5x5", 0),
341 | ("dil_conv_3x3", 3),
342 | ("dil_conv_3x3", 4),
343 | ("skip_connect", 0),
344 | ],
345 | normal_concat=range(2, 6),
346 | reduce=[
347 | ("max_pool_3x3", 0),
348 | ("sep_conv_5x5", 1),
349 | ("skip_connect", 2),
350 | ("max_pool_3x3", 0),
351 | ("skip_connect", 3),
352 | ("skip_connect", 2),
353 | ("skip_connect", 2),
354 | ("sep_conv_5x5", 4),
355 | ],
356 | reduce_concat=range(2, 6),
357 | )
358 |
359 | DROPNAS = Genotype(
360 | normal=[
361 | ("skip_connect", 0),
362 | ("sep_conv_3x3", 1),
363 | ("sep_conv_3x3", 1),
364 | ("max_pool_3x3", 2),
365 | ("sep_conv_3x3", 1),
366 | ("sep_conv_5x5", 2),
367 | ("sep_conv_5x5", 0),
368 | ("sep_conv_5x5", 1),
369 | ],
370 | normal_concat=[2, 3, 4, 5],
371 | reduce=[
372 | ("max_pool_3x3", 0),
373 | ("sep_conv_5x5", 1),
374 | ("dil_conv_5x5", 2),
375 | ("sep_conv_5x5", 1),
376 | ("dil_conv_5x5", 2),
377 | ("dil_conv_5x5", 3),
378 | ("dil_conv_5x5", 2),
379 | ("dil_conv_5x5", 4),
380 | ],
381 | reduce_concat=[2, 3, 4, 5],
382 | )
383 |
384 | GENAS_FLATNESS_CIFAR10 = parse_searched_cell(
385 | (3, 6, 0, 4, 4, 0, 6, 6, 0, 0, 0, 0, 4, 3, 5, 2, 0, 6, 4, 0, 3, 0, 6, 0, 0, 5, 0, 7)
386 | )
387 |
388 | GENAS_ANGLE_FLATNESS_CIFAR10 = parse_searched_cell(
389 | (5, 4, 4, 4, 0, 4, 5, 0, 0, 4, 0, 0, 0, 4, 7, 7, 2, 3, 0, 4, 1, 0, 0, 0, 0, 0, 5, 6)
390 | )
391 |
392 | GENAS_FLATNESS_CIFAR100 = parse_searched_cell(
393 | (3, 5, 0, 4, 4, 0, 0, 4, 5, 0, 4, 0, 0, 7, 6, 3, 2, 0, 5, 0, 5, 0, 7, 0, 1, 1, 0, 0)
394 | )
395 |
396 | GENAS_ANGLE_FLATNESS_CIFAR100 = parse_searched_cell(
397 | (3, 5, 1, 4, 0, 0, 4, 0, 4, 0, 4, 0, 0, 5, 6, 4, 0, 7, 2, 1, 3, 0, 0, 0, 4, 0, 0, 7)
398 | )
--------------------------------------------------------------------------------
/retrain_architecture/model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/model.py
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from operations import *
8 | from torch.autograd import Variable
9 | from utils import drop_path
10 |
11 |
12 | class Cell(nn.Module):
13 |
14 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
15 | super(Cell, self).__init__()
16 | print(C_prev_prev, C_prev, C)
17 |
18 | if reduction_prev:
19 | self.preprocess0 = FactorizedReduce(C_prev_prev, C)
20 | else:
21 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
22 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
23 |
24 | if reduction:
25 | op_names, indices = zip(*genotype.reduce)
26 | concat = genotype.reduce_concat
27 | else:
28 | op_names, indices = zip(*genotype.normal)
29 | concat = genotype.normal_concat
30 | self._compile(C, op_names, indices, concat, reduction)
31 |
32 | def _compile(self, C, op_names, indices, concat, reduction):
33 | assert len(op_names) == len(indices)
34 | self._steps = len(op_names) // 2
35 | self._concat = concat
36 | self.multiplier = len(concat)
37 |
38 | self._ops = nn.ModuleList()
39 | for name, index in zip(op_names, indices):
40 | stride = 2 if reduction and index < 2 else 1
41 | op = OPS[name](C, stride, True)
42 | self._ops += [op]
43 | self._indices = indices
44 |
45 | def forward(self, s0, s1, drop_prob):
46 | s0 = self.preprocess0(s0)
47 | s1 = self.preprocess1(s1)
48 |
49 | states = [s0, s1]
50 | for i in range(self._steps):
51 | h1 = states[self._indices[2 * i]]
52 | h2 = states[self._indices[2 * i + 1]]
53 | op1 = self._ops[2 * i]
54 | op2 = self._ops[2 * i + 1]
55 | h1 = op1(h1)
56 | h2 = op2(h2)
57 | if self.training and drop_prob > 0.:
58 | if not isinstance(op1, Identity):
59 | h1 = drop_path(h1, drop_prob)
60 | if not isinstance(op2, Identity):
61 | h2 = drop_path(h2, drop_prob)
62 | s = h1 + h2
63 | states += [s]
64 | return torch.cat([states[i] for i in self._concat], dim=1)
65 |
66 |
67 | class AuxiliaryHeadCIFAR(nn.Module):
68 |
69 | def __init__(self, C, num_classes):
70 | """assuming input size 8x8"""
71 | super(AuxiliaryHeadCIFAR, self).__init__()
72 | self.features = nn.Sequential(
73 | nn.ReLU(inplace=True),
74 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
75 | nn.Conv2d(C, 128, 1, bias=False),
76 | nn.BatchNorm2d(128),
77 | nn.ReLU(inplace=True),
78 | nn.Conv2d(128, 768, 2, bias=False),
79 | nn.BatchNorm2d(768),
80 | nn.ReLU(inplace=True)
81 | )
82 | self.classifier = nn.Linear(768, num_classes)
83 |
84 | def forward(self, x):
85 | x = self.features(x)
86 | x = self.classifier(x.view(x.size(0), -1))
87 | return x
88 |
89 |
90 | class AuxiliaryHeadImageNet(nn.Module):
91 |
92 | def __init__(self, C, num_classes):
93 | """assuming input size 14x14"""
94 | super(AuxiliaryHeadImageNet, self).__init__()
95 | self.features = nn.Sequential(
96 | nn.ReLU(inplace=True),
97 | nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
98 | nn.Conv2d(C, 128, 1, bias=False),
99 | nn.BatchNorm2d(128),
100 | nn.ReLU(inplace=True),
101 | nn.Conv2d(128, 768, 2, bias=False),
102 | # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
103 | # Commenting it out for consistency with the experiments in the paper.
104 | # nn.BatchNorm2d(768),
105 | nn.ReLU(inplace=True)
106 | )
107 | self.classifier = nn.Linear(768, num_classes)
108 |
109 | def forward(self, x):
110 | x = self.features(x)
111 | x = self.classifier(x.view(x.size(0), -1))
112 | return x
113 |
114 |
115 | class NetworkCIFAR(nn.Module):
116 |
117 | def __init__(self, C, num_classes, layers, auxiliary, genotype):
118 | super(NetworkCIFAR, self).__init__()
119 | self._layers = layers
120 | self._auxiliary = auxiliary
121 |
122 | stem_multiplier = 3
123 | C_curr = stem_multiplier * C
124 | self.stem = nn.Sequential(
125 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
126 | nn.BatchNorm2d(C_curr)
127 | )
128 |
129 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
130 | self.cells = nn.ModuleList()
131 | reduction_prev = False
132 | for i in range(layers):
133 | if i in [layers // 3, 2 * layers // 3]:
134 | C_curr *= 2
135 | reduction = True
136 | else:
137 | reduction = False
138 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
139 | reduction_prev = reduction
140 | self.cells += [cell]
141 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
142 | if i == 2 * layers // 3:
143 | C_to_auxiliary = C_prev
144 |
145 | if auxiliary:
146 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
147 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
148 | self.classifier = nn.Linear(C_prev, num_classes)
149 |
150 | def forward(self, input):
151 | logits_aux = None
152 | s0 = s1 = self.stem(input)
153 | for i, cell in enumerate(self.cells):
154 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
155 | if i == 2 * self._layers // 3:
156 | if self._auxiliary and self.training:
157 | logits_aux = self.auxiliary_head(s1)
158 | out = self.global_pooling(s1)
159 | logits = self.classifier(out.view(out.size(0), -1))
160 | return logits, logits_aux
161 |
162 |
163 | class NetworkImageNet(nn.Module):
164 |
165 | def __init__(self, C, num_classes, layers, auxiliary, genotype):
166 | super(NetworkImageNet, self).__init__()
167 | self._layers = layers
168 | self._auxiliary = auxiliary
169 |
170 | self.stem0 = nn.Sequential(
171 | nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
172 | nn.BatchNorm2d(C // 2),
173 | nn.ReLU(inplace=True),
174 | nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
175 | nn.BatchNorm2d(C),
176 | )
177 |
178 | self.stem1 = nn.Sequential(
179 | nn.ReLU(inplace=True),
180 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
181 | nn.BatchNorm2d(C),
182 | )
183 |
184 | C_prev_prev, C_prev, C_curr = C, C, C
185 |
186 | self.cells = nn.ModuleList()
187 | reduction_prev = True
188 | for i in range(layers):
189 | if i in [layers // 3, 2 * layers // 3]:
190 | C_curr *= 2
191 | reduction = True
192 | else:
193 | reduction = False
194 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
195 | reduction_prev = reduction
196 | self.cells += [cell]
197 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
198 | if i == 2 * layers // 3:
199 | C_to_auxiliary = C_prev
200 |
201 | if auxiliary:
202 | self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
203 | self.global_pooling = nn.AvgPool2d(7)
204 | self.classifier = nn.Linear(C_prev, num_classes)
205 |
206 | def forward(self, input):
207 | logits_aux = None
208 | s0 = self.stem0(input)
209 | s1 = self.stem1(s0)
210 | for i, cell in enumerate(self.cells):
211 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
212 | if i == 2 * self._layers // 3:
213 | if self._auxiliary and self.training:
214 | logits_aux = self.auxiliary_head(s1)
215 | out = self.global_pooling(s1)
216 | logits = self.classifier(out.view(out.size(0), -1))
217 | return logits, logits_aux
218 |
--------------------------------------------------------------------------------
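A minimal smoke test for the networks above, as a sketch: it assumes this file is importable as model.py and that genotypes.py defines the PDARTS genotype used as the retrain.py default (both assumptions, not guaranteed here). Note that forward() reads self.drop_path_prob, which is never set in __init__; callers such as retrain.py assign it externally, and utils.drop_path allocates its mask with torch.cuda.FloatTensor, so a nonzero value requires a GPU.

    import torch
    from model import NetworkCIFAR   # assumed import path
    import genotypes                 # assumed to define PDARTS

    net = NetworkCIFAR(C=36, num_classes=10, layers=20, auxiliary=True,
                       genotype=genotypes.PDARTS)
    net.drop_path_prob = 0.0   # must be set before forward(); 0 keeps this CPU-safe
    net.train()                # the auxiliary head only fires in training mode
    x = torch.randn(2, 3, 32, 32)
    logits, logits_aux = net(x)
    print(logits.shape, logits_aux.shape)   # torch.Size([2, 10]) for both
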
/retrain_architecture/operations.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/operations.py
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 |
8 | OPS = {
9 | 'none' : lambda C, stride, affine: Zero(stride),
10 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
11 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
12 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
13 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
14 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
15 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
16 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
19 | nn.ReLU(inplace=False),
20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
22 | nn.BatchNorm2d(C, affine=affine)
23 | ),
24 | }
25 |
26 | class ReLUConvBN(nn.Module):
27 |
28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
29 | super(ReLUConvBN, self).__init__()
30 | self.op = nn.Sequential(
31 | nn.ReLU(inplace=False),
32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
33 | nn.BatchNorm2d(C_out, affine=affine)
34 | )
35 |
36 | def forward(self, x, rngs=None):
37 | return self.op(x)
38 |
39 | class DilConv(nn.Module):
40 |
41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
42 | super(DilConv, self).__init__()
43 | self.op = nn.Sequential(
44 | nn.ReLU(inplace=False),
45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
47 | nn.BatchNorm2d(C_out, affine=affine),
48 | )
49 |
50 | def forward(self, x, rngs=None):
51 | return self.op(x)
52 |
53 | class SepConv(nn.Module):
54 |
55 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
56 | super(SepConv, self).__init__()
57 | self.op = nn.Sequential(
58 | nn.ReLU(inplace=False),
59 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
60 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
61 | nn.BatchNorm2d(C_in, affine=affine),
62 | nn.ReLU(inplace=False),
63 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
64 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
65 | nn.BatchNorm2d(C_out, affine=affine),
66 | )
67 |
68 | def forward(self, x, rngs=None):
69 | return self.op(x)
70 |
71 |
72 | class Identity(nn.Module):
73 |
74 | def __init__(self):
75 | super(Identity, self).__init__()
76 |
77 | def forward(self, x, rngs=None):
78 | return x
79 |
80 | class Zero(nn.Module):
81 |
82 | def __init__(self, stride):
83 | super(Zero, self).__init__()
84 | self.stride = stride
85 | def forward(self, x, rngs=None):
86 | n, c, h, w = x.size()
87 | h //= self.stride
88 | w //= self.stride
89 | if x.is_cuda:
90 | with torch.cuda.device(x.get_device()):
91 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0)
92 | else:
93 | padding = torch.FloatTensor(n, c, h, w).fill_(0)
94 | return padding
95 |
96 | class FactorizedReduce(nn.Module):
97 |
98 | def __init__(self, C_in, C_out, affine=True):
99 | super(FactorizedReduce, self).__init__()
100 | assert C_out % 2 == 0
101 | self.relu = nn.ReLU(inplace=False)
102 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
103 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
104 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
105 |
106 | def forward(self, x, rngs=None):
107 | x = self.relu(x)
108 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
109 | out = self.bn(out)
110 | return out
111 |
112 |
113 |
--------------------------------------------------------------------------------
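A shape sanity check over the OPS registry above, as a sketch meant to run inside this module's namespace (e.g. appended under a __main__ guard) so that OPS and its classes are in scope; the channel count and input size are arbitrary.

    import torch

    if __name__ == '__main__':
        C, stride = 16, 2
        x = torch.randn(1, C, 32, 32)
        for name in ['skip_connect', 'sep_conv_3x3', 'dil_conv_5x5',
                     'avg_pool_3x3', 'none']:
            op = OPS[name](C, stride, True)  # every entry is a (C, stride, affine) factory
            y = op(x)
            # with stride=2, every op halves the spatial size: (1, 16, 16, 16)
            print('{:14s} -> {}'.format(name, tuple(y.shape)))
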
/retrain_architecture/retrain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import time
5 | import torch
6 | import utils
7 | import glob
8 | import random
9 | import logging
10 | import argparse
11 | import torch.nn as nn
12 | import genotypes
13 | import torch.utils
14 | import torchvision.transforms as transforms
15 | import torch.backends.cudnn as cudnn
16 | import time
17 | import torch.multiprocessing as mp
18 | import torch.distributed as dist
19 | from torch.autograd import Variable
20 | from model import NetworkImageNet as Network
21 | from tensorboardX import SummaryWriter
22 | from thop import profile
23 | import torchvision.datasets as datasets
24 | from config import (
25 | MASTER_HOST,
26 | MASTER_PORT,
27 | NODE_NUM,
28 | MY_RANK,
29 | GPU_NUM,
30 | )
31 |
32 | parser = argparse.ArgumentParser("training imagenet")
33 | parser.add_argument(
34 | "--data_root", type=str, required=True, help="imagenet dataset root directory"
35 | )
36 | parser.add_argument(
37 | "--workers", type=int, default=32, help="number of workers to load dataset"
38 | )
39 | parser.add_argument("--batch_size", type=int, default=1024, help="batch size")
40 | parser.add_argument(
41 | "--learning_rate", type=float, default=0.5, help="init learning rate"
42 | )
43 | parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
44 | parser.add_argument("--weight_decay", type=float, default=3e-5, help="weight decay")
45 | parser.add_argument("--report_freq", type=float, default=100, help="report frequency")
46 | parser.add_argument("--epochs", type=int, default=250, help="num of training epochs")
47 | parser.add_argument(
48 | "--init_channels", type=int, default=48, help="num of init channels"
49 | )
50 | parser.add_argument("--layers", type=int, default=14, help="total number of layers")
51 | parser.add_argument(
52 | "--auxiliary", action="store_true", default=False, help="use auxiliary tower"
53 | )
54 | parser.add_argument(
55 | "--auxiliary_weight", type=float, default=0.4, help="weight for auxiliary loss"
56 | )
57 | parser.add_argument(
58 | "--drop_path_prob", type=float, default=0, help="drop path probability"
59 | )
60 | parser.add_argument("--save", type=str, default="test", help="experiment name")
61 | parser.add_argument("--seed", type=int, default=0, help="random seed")
62 | parser.add_argument(
63 | "--arch", type=str, default="PDARTS", help="which architecture to use"
64 | )
65 | parser.add_argument("--grad_clip", type=float, default=5.0, help="gradient clipping")
66 | parser.add_argument("--label_smooth", type=float, default=0.1, help="label smoothing")
67 | parser.add_argument(
68 | "--lr_scheduler", type=str, default="linear", help="lr scheduler, linear or cosine"
69 | )
70 | parser.add_argument("--note", type=str, default="try", help="note for this run")
71 |
72 | args, unparsed = parser.parse_known_args()
73 |
74 | # args.save = "eval-{}".format(args.save)
75 |
76 | if not os.path.exists(args.save):
77 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob("*.py"))
78 |
79 | time.sleep(1)
80 | log_format = "%(asctime)s %(message)s"
81 | logging.basicConfig(
82 | stream=sys.stdout,
83 | level=logging.INFO,
84 | format=log_format,
85 | datefmt="%m/%d %I:%M:%S %p",
86 | )
87 | fh = logging.FileHandler(os.path.join(args.save, "log.txt"))
88 | fh.setFormatter(logging.Formatter(log_format))
89 | logging.getLogger().addHandler(fh)
90 | writer = SummaryWriter(logdir=args.save)
91 |
92 | IMAGENET_TRAINING_SET_SIZE = 1281167
93 | IMAGENET_TEST_SET_SIZE = 50000
94 | CLASSES = 1000
95 | train_iters = (
96 | IMAGENET_TRAINING_SET_SIZE // args.batch_size
97 | )  # NOTE: for each training iteration, each GPU across all nodes takes args.batch_size (1024) // (# gpus per node (4) * # nodes (2)) = 128 images, which DistributedDataParallel gathers into the effective batch size of args.batch_size = 1024.
98 | val_iters = (
99 | IMAGENET_TEST_SET_SIZE // args.batch_size
100 | )  # NOTE: validation runs without DistributedDataParallel, so a single GPU (local gpu id = 0 on each node) takes full batches; val_iters is computed with the original args.batch_size = 1024.
101 |
102 | # Average loss across processes for logging.
103 | def reduce_tensor(tensor, device=0, world_size=1):
104 | tensor = tensor.clone()
105 | dist.reduce(tensor, device)
106 | tensor.div_(world_size)
107 | return tensor
108 |
109 |
110 | class CrossEntropyLabelSmooth(nn.Module):
111 | def __init__(self, num_classes, epsilon):
112 | super(CrossEntropyLabelSmooth, self).__init__()
113 | self.num_classes = num_classes
114 | self.epsilon = epsilon
115 | self.logsoftmax = nn.LogSoftmax(dim=1)
116 |
117 | def forward(self, inputs, targets):
118 | log_probs = self.logsoftmax(inputs)
119 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
120 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
121 | loss = (-targets * log_probs).mean(0).sum()
122 | return loss
123 |
124 |
125 | def main(local_rank, *args):
126 |     # NOTE: local_rank is passed in by mp.spawn and denotes the local gpu id inside the current node.
127 | args = args[0] # NOTE: take arguments
128 | if not torch.cuda.is_available():
129 | logging.info("No GPU device available")
130 | sys.exit(1)
131 |
132 | num_gpus_per_node = GPU_NUM # NOTE: num gpus per node.
133 |
134 | np.random.seed(args.seed)
135 | cudnn.benchmark = True
136 | cudnn.deterministic = True
137 |
138 | torch.manual_seed(args.seed)
139 | cudnn.enabled = True
140 | torch.cuda.manual_seed(args.seed)
141 | logging.info("args = %s", args)
142 | logging.info("unparsed_args = %s", unparsed)
143 |
144 | n_nodes = NODE_NUM
145 | args.world_size = n_nodes * num_gpus_per_node
146 | assert (
147 | args.world_size == 8
148 | ), "world_size is not 8." # for reproducibility
149 | args.dist_url = "tcp://{}:{}".format(MASTER_HOST, MASTER_PORT)
150 | args.distributed = args.world_size > 1 # NOTE: whether using distributed or not
151 | os.environ["NCCL_DEBUG"] = "info"
152 | # os.environ["NCCL_SOCKET_IFNAME"] = "bond0"
153 | print("master addr: {} with {} node(s)".format(args.dist_url, n_nodes))
154 |
155 | global_rank = (
156 | num_gpus_per_node * MY_RANK + local_rank
157 | ) # global gpu id over all gpus over all nodes
158 | # NOTE: init DDP connection
159 | torch.distributed.init_process_group(
160 | backend="nccl",
161 | init_method=args.dist_url,
162 | world_size=args.world_size,
163 | rank=global_rank,
164 | )
165 | print("init process group finished...")
166 |
167 | # reset batch size accordingly with number of total processes over all nodes
168 | args.batch_size = (
169 | args.batch_size // args.world_size
170 | ) # 1024 (original batch_size) // 8 (# total processes over all nodes) = 128
171 |
172 | # Data loading
173 | traindir = os.path.join(args.data_root, "train")
174 | valdir = os.path.join(args.data_root, "val")
175 | train_transform = utils.get_train_transform()
176 | eval_transform = utils.get_eval_transform()
177 | print("train dataset preparing...")
178 | train_dataset = datasets.ImageFolder(root=traindir, transform=train_transform)
179 | print("train dataset prepared...")
180 | val_dataset = datasets.ImageFolder(root=valdir, transform=eval_transform)
181 | print("val dataset prepared...")
182 |
183 | if args.distributed:
184 | # NOTE: train_sampler assigned to each process over all process on multiple nodes
185 | train_sampler = torch.utils.data.distributed.DistributedSampler(
186 | train_dataset, num_replicas=args.world_size, rank=global_rank
187 | )
188 | else:
189 | train_sampler = None
190 |
191 |     # NOTE: for each training iteration, each GPU across all nodes takes args.batch_size (1024) // (# gpus per node (4) * # nodes (2)) = 128 images, which DistributedDataParallel gathers into the effective batch size of 1024.
192 | train_loader = torch.utils.data.DataLoader(
193 | train_dataset,
194 | batch_size=args.batch_size,
195 | shuffle=(train_sampler is None),
196 | num_workers=args.workers // args.world_size,
197 | pin_memory=True,
198 | sampler=train_sampler,
199 | )
200 |
201 |     # NOTE: validation runs without DistributedDataParallel; only the GPU with local rank 0 on each node evaluates, using the per-process batch size computed above.
202 | val_loader = torch.utils.data.DataLoader(
203 | val_dataset,
204 | batch_size=args.batch_size,
205 | shuffle=False,
206 | num_workers=args.workers // args.world_size,
207 | pin_memory=True,
208 | )
209 |
210 | genotype = eval("genotypes.%s" % args.arch)
211 | logging.info("---------Genotype---------")
212 | logging.info(genotype)
213 | logging.info("--------------------------")
214 | torch.cuda.set_device(
215 | local_rank
216 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
217 | model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
218 | model = model.cuda(
219 | local_rank
220 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
221 | print(
222 | "local rank: ",
223 | local_rank,
224 | "model deployed on : ",
225 | next(model.parameters()).device,
226 | )
227 | # model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True)
228 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank],
229 | # output_device=args.local_rank, broadcast_buffers=False)
230 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
231 | model_profile = Network(
232 | args.init_channels, CLASSES, args.layers, args.auxiliary, genotype
233 | )
234 | model_profile = model_profile.cuda(
235 | local_rank
236 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
237 | model_input_size_imagenet = (1, 3, 224, 224)
238 | model_profile.drop_path_prob = 0
239 | flops, _ = profile(model_profile, model_input_size_imagenet)
240 | logging.info(
241 | "flops = %fM, param size = %fM",
242 | flops / 1e6,
243 | utils.count_parameters_in_MB(model),
244 | )
245 |
246 | criterion = nn.CrossEntropyLoss()
247 | criterion = criterion.cuda(
248 | local_rank
249 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
250 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
251 | criterion_smooth = criterion_smooth.cuda(
252 | local_rank
253 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
254 |
255 | optimizer = torch.optim.SGD(
256 | model.parameters(),
257 | args.learning_rate,
258 | momentum=args.momentum,
259 | weight_decay=args.weight_decay,
260 | )
261 |
262 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
263 | optimizer, float(args.epochs)
264 | )
265 |
266 | start_epoch = 0
267 | best_acc_top1 = 0
268 | best_acc_top5 = 0
269 | checkpoint_tar = os.path.join(args.save, "checkpoint.pth.tar")
270 | if os.path.exists(checkpoint_tar):
271 | logging.info("loading checkpoint {} ..........".format(checkpoint_tar))
272 | checkpoint = torch.load(
273 | checkpoint_tar, map_location={"cuda:0": "cuda:{}".format(local_rank)}
274 | )
275 | start_epoch = checkpoint["epoch"] + 1
276 | model.load_state_dict(checkpoint["state_dict"])
277 | logging.info(
278 | "loaded checkpoint {} epoch = {}".format(
279 | checkpoint_tar, checkpoint["epoch"]
280 | )
281 | )
282 |
283 | for epoch in range(start_epoch, args.epochs):
284 | if args.distributed:
285 | train_sampler.set_epoch(epoch)
286 | if args.lr_scheduler == "cosine":
287 | scheduler.step()
288 | current_lr = scheduler.get_lr()[0]
289 | elif args.lr_scheduler == "linear":
290 | current_lr = adjust_lr(optimizer, epoch)
291 | else:
292 | logging.info("Wrong lr type, exit")
293 | sys.exit(1)
294 |
295 | logging.info("Epoch: %d lr %e", epoch, current_lr)
296 | if epoch < 5:
297 | for param_group in optimizer.param_groups:
298 | param_group["lr"] = current_lr * (epoch + 1) / 5.0
299 | logging.info(
300 | "Warming-up Epoch: %d, LR: %e", epoch, current_lr * (epoch + 1) / 5.0
301 | )
302 |
303 | model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs
304 | epoch_start = time.time()
305 | train_acc, train_obj = train(
306 | train_loader,
307 | model,
308 | criterion_smooth,
309 | optimizer,
310 | epoch,
311 | local_rank,
312 | args.world_size,
313 | )
314 |
315 | writer.add_scalar("Train/Loss", train_obj, epoch)
316 | writer.add_scalar("Train/LR", current_lr, epoch)
317 |
318 |         # NOTE: only the process with local gpu id 0 on each node runs the infer function.
319 |         # NOTE: the other processes on the node wait for process 0 to finish inference.
320 |         # NOTE: once process 0 finishes infer, the next epoch's train function runs over all gpus on all distributed nodes.
321 | if local_rank == 0:
322 | valid_acc_top1, valid_acc_top5, valid_obj = infer(
323 | val_loader, model.module, criterion, epoch, local_rank, args.world_size
324 | )
325 | is_best = False
326 | # if valid_acc_top5 > best_acc_top5:
327 | # best_acc_top5 = valid_acc_top5
328 | if valid_acc_top1 > best_acc_top1:
329 | best_acc_top1 = valid_acc_top1
330 | best_acc_top5 = valid_acc_top5
331 | is_best = True
332 |
333 | logging.info("Valid_acc_top1: %f", valid_acc_top1)
334 | logging.info("Valid_acc_top5: %f", valid_acc_top5)
335 | logging.info("best_acc_top1: %f", best_acc_top1)
336 | logging.info("best_acc_top5: %f", best_acc_top5)
337 | epoch_duration = time.time() - epoch_start
338 | logging.info("Epoch time: %ds.", epoch_duration)
339 |
340 | utils.save_checkpoint(
341 | {
342 | "epoch": epoch,
343 | "state_dict": model.state_dict(),
344 | "best_acc_top1": best_acc_top1,
345 | "optimizer": optimizer.state_dict(),
346 | },
347 | is_best,
348 | args.save,
349 | )
350 |
351 |
352 | def adjust_lr(optimizer, epoch):
353 |     # Use a smaller slope for the last 5 epochs: the regular per-epoch decrement of roughly lr/250 is relatively large once lr is small.
354 | if args.epochs - epoch > 5:
355 | lr = args.learning_rate * (args.epochs - 5 - epoch) / (args.epochs - 5)
356 | else:
357 | lr = args.learning_rate * (args.epochs - epoch) / ((args.epochs - 5) * 5)
358 | for param_group in optimizer.param_groups:
359 | param_group["lr"] = lr
360 | return lr
361 |
362 |
363 | def train(train_loader, model, criterion, optimizer, epoch, local_rank, world_size):
364 | objs = utils.AvgrageMeter()
365 | top1 = utils.AvgrageMeter()
366 | top5 = utils.AvgrageMeter()
367 | batch_time = utils.AvgrageMeter()
368 | model.train()
369 |
370 | for i, (image, target) in enumerate(train_loader):
371 | # image: [128 (1024 // 8 (num total gpus))]
372 | t0 = time.time()
373 | image = image.cuda(
374 | local_rank, non_blocking=True
375 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
376 | target = target.cuda(
377 | local_rank, non_blocking=True
378 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
379 | datatime = time.time() - t0
380 |
381 | b_start = time.time()
382 | logits, logits_aux = model(image)
383 | optimizer.zero_grad()
384 | loss = criterion(logits, target)
385 | if args.auxiliary:
386 | loss_aux = criterion(logits_aux, target)
387 | loss += args.auxiliary_weight * loss_aux
388 | loss_reduce = reduce_tensor(loss, 0, world_size)
389 | loss.backward()
390 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
391 | optimizer.step()
392 | batch_time.update(time.time() - b_start)
393 |
394 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
395 | n = image.size(0)
396 | objs.update(loss_reduce.data.item(), n)
397 | top1.update(prec1.data.item(), n)
398 | top5.update(prec5.data.item(), n)
399 |
400 | if i % args.report_freq == 0 and local_rank == 0:
401 | logging.info(
402 | "TRAIN Step: %03d/%03d Objs: %e R1: %f R5: %f BTime: %.3fs Datatime: %.3f",
403 | i,
404 | train_iters,
405 | objs.avg,
406 | top1.avg,
407 | top5.avg,
408 | batch_time.avg,
409 | float(datatime),
410 | )
411 |
412 | return top1.avg, objs.avg
413 |
414 |
415 | def infer(val_loader, model, criterion, epoch, local_rank, world_size):
416 | objs = utils.AvgrageMeter()
417 | top1 = utils.AvgrageMeter()
418 | top5 = utils.AvgrageMeter()
419 | model.eval()
420 |
421 | for i, (image, target) in enumerate(val_loader):
422 | t0 = time.time()
423 | image = image.cuda(
424 | local_rank, non_blocking=True
425 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
426 | target = target.cuda(
427 | local_rank, non_blocking=True
428 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn
429 | datatime = time.time() - t0
430 |
431 | with torch.no_grad():
432 | logits, _ = model(image)
433 | loss = criterion(logits, target)
434 |
435 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
436 | n = image.size(0)
437 | objs.update(loss.data.item(), n)
438 | top1.update(prec1.data.item(), n)
439 | top5.update(prec5.data.item(), n)
440 |
441 | if i % args.report_freq == 0:
442 | logging.info(
443 | "[%03d] VALID Step: %03d/%03d Objs: %e R1: %f R5: %f Datatime: %.3f",
444 | epoch,
445 | i,
446 | val_iters * world_size,
447 | objs.avg,
448 | top1.avg,
449 | top5.avg,
450 | float(datatime),
451 | )
452 |
453 | return top1.avg, top5.avg, objs.avg
454 |
455 |
456 | if __name__ == "__main__":
457 | mp.spawn(main, (args,), nprocs=int(GPU_NUM), join=True) # GPU_NUM: # gpus per node.
458 |
--------------------------------------------------------------------------------
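A standalone sketch of the learning-rate trajectory produced by adjust_lr() plus the 5-epoch warm-up in main(), using the argparse defaults above (learning_rate=0.5, epochs=250); useful for checking the "linear" schedule without launching a distributed run.

    LEARNING_RATE, EPOCHS = 0.5, 250

    def lr_at(epoch):
        # mirrors adjust_lr(): linear decay, with a gentler slope in the last 5 epochs
        if EPOCHS - epoch > 5:
            lr = LEARNING_RATE * (EPOCHS - 5 - epoch) / (EPOCHS - 5)
        else:
            lr = LEARNING_RATE * (EPOCHS - epoch) / ((EPOCHS - 5) * 5)
        if epoch < 5:  # mirrors the warm-up override in main()
            lr = lr * (epoch + 1) / 5.0
        return lr

    for e in (0, 4, 5, 125, 244, 245, 249):
        print(e, round(lr_at(e), 6))
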
/retrain_architecture/thop/__init__.py:
--------------------------------------------------------------------------------
1 | from .profile import profile
--------------------------------------------------------------------------------
/retrain_architecture/thop/count_hooks.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/count_hooks.py
3 | '''
4 | import argparse
5 |
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | multiply_adds = 1
10 | num = 0
11 |
12 | def count_ABN(m, x, y):
13 | x = x[0]
14 |
15 | # bn
16 | nelements = x.numel()
17 | # subtract, divide, gamma, beta + relu
18 | total_ops = 4 * nelements
19 | m.total_ops = torch.Tensor([int(total_ops)])
20 | for p in m.parameters():
21 | m.total_params += torch.Tensor([p.numel()])
22 |
23 | def count_convNd(m, x, y):
24 | x = x[0]
25 | cin = m.in_channels
26 | # batch_size = x.size(0)
27 |
28 | kernel_ops = m.weight.size()[2:].numel()
29 | bias_ops = 1 if m.bias is not None else 0
30 | ops_per_element = kernel_ops + bias_ops
31 | output_elements = y.nelement()
32 |
33 | # cout x oW x oH
34 | total_ops = cin * output_elements * ops_per_element // m.groups
35 | m.total_ops = torch.Tensor([int(total_ops)])
36 | for p in m.parameters():
37 | m.total_params += torch.Tensor([p.numel()])
38 |
39 | def count_conv2d(m, x, y):
40 | x = x[0]
41 |
42 | cin = m.in_channels
43 | cout = m.out_channels
44 | kh, kw = m.kernel_size
45 | batch_size = x.size()[0]
46 |
47 | out_h = y.size(2)
48 | out_w = y.size(3)
49 |
50 | # ops per output element
51 | # kernel_mul = kh * kw * cin
52 | # kernel_add = kh * kw * cin - 1
53 | kernel_ops = multiply_adds * kh * kw
54 | bias_ops = 1 if m.bias is not None else 0
55 | ops_per_element = kernel_ops + bias_ops
56 |
57 | # total ops
58 | # num_out_elements = y.numel()
59 | output_elements = batch_size * out_w * out_h * cout
60 | total_ops = output_elements * ops_per_element * cin // m.groups
61 | m.total_ops = torch.Tensor([int(total_ops)])
62 | for p in m.parameters():
63 | m.total_params += torch.Tensor([p.numel()])
64 |
65 | def count_convtranspose2d(m, x, y):
66 | x = x[0]
67 |
68 | cin = m.in_channels
69 | cout = m.out_channels
70 | kh, kw = m.kernel_size
71 | # batch_size = x.size()[0]
72 |
73 | out_h = y.size(2)
74 | out_w = y.size(3)
75 |
76 | # ops per output element
77 | # kernel_mul = kh * kw * cin
78 | # kernel_add = kh * kw * cin - 1
79 | kernel_ops = multiply_adds * kh * kw * cin // m.groups
80 | bias_ops = 1 if m.bias is not None else 0
81 | ops_per_element = kernel_ops + bias_ops
82 |
83 | # total ops
84 | # num_out_elements = y.numel()
85 | # output_elements = batch_size * out_w * out_h * cout
86 | ops_per_element = m.weight.nelement()
87 | output_elements = y.nelement()
88 | total_ops = output_elements * ops_per_element
89 |
90 | m.total_ops = torch.Tensor([int(total_ops)])
91 | for p in m.parameters():
92 | m.total_params += torch.Tensor([p.numel()])
93 |
94 | def count_bn(m, x, y):
95 | x = x[0]
96 |
97 | nelements = x.numel()
98 | # subtract, divide, gamma, beta
99 | total_ops = 4 * nelements
100 |
101 | m.total_ops = torch.Tensor([int(total_ops)])
102 | for p in m.parameters():
103 | m.total_params += torch.Tensor([p.numel()])
104 |
105 | def count_relu(m, x, y):
106 | x = x[0]
107 |
108 | nelements = x.numel()
109 | total_ops = nelements
110 |
111 | m.total_ops = torch.Tensor([int(total_ops)])
112 | for p in m.parameters():
113 | m.total_params += torch.Tensor([p.numel()])
114 |
115 | def count_softmax(m, x, y):
116 | x = x[0]
117 |
118 | batch_size, nfeatures = x.size()
119 |
120 | total_exp = nfeatures
121 | total_add = nfeatures - 1
122 | total_div = nfeatures
123 | total_ops = batch_size * (total_exp + total_add + total_div)
124 |
125 | m.total_ops = torch.Tensor([int(total_ops)])
126 | for p in m.parameters():
127 | m.total_params += torch.Tensor([p.numel()])
128 |
129 | def count_maxpool(m, x, y):
130 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size]))
131 | num_elements = y.numel()
132 | total_ops = kernel_ops * num_elements
133 |
134 | m.total_ops = torch.Tensor([int(total_ops)])
135 | for p in m.parameters():
136 | m.total_params += torch.Tensor([p.numel()])
137 |
138 | def count_adap_maxpool(m, x, y):
139 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
140 | kernel_ops = torch.prod(kernel)
141 | num_elements = y.numel()
142 | total_ops = kernel_ops * num_elements
143 |
144 | m.total_ops = torch.Tensor([int(total_ops)])
145 | for p in m.parameters():
146 | m.total_params += torch.Tensor([p.numel()])
147 |
148 | def count_avgpool(m, x, y):
149 | total_add = torch.prod(torch.Tensor([m.kernel_size]))
150 | total_div = 1
151 | kernel_ops = total_add + total_div
152 | num_elements = y.numel()
153 | total_ops = kernel_ops * num_elements
154 |
155 | m.total_ops = torch.Tensor([int(total_ops)])
156 | for p in m.parameters():
157 | m.total_params += torch.Tensor([p.numel()])
158 |
159 | def count_adap_avgpool(m, x, y):
160 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze()
161 | total_add = torch.prod(kernel)
162 | total_div = 1
163 | kernel_ops = total_add + total_div
164 | num_elements = y.numel()
165 | total_ops = kernel_ops * num_elements
166 |
167 | m.total_ops = torch.Tensor([int(total_ops)])
168 | for p in m.parameters():
169 | m.total_params += torch.Tensor([p.numel()])
170 |
171 | def count_linear(m, x, y):
172 | # per output element
173 | total_mul = m.in_features
174 | total_add = m.in_features - 1
175 | num_elements = y.numel()
176 | total_ops = (total_mul + total_add) * num_elements
177 |
178 | m.total_ops = torch.Tensor([int(total_ops)])
179 | for p in m.parameters():
180 | m.total_params += torch.Tensor([p.numel()])
--------------------------------------------------------------------------------
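A quick check of count_convNd against the closed-form MAC count cin // groups * k_h * k_w * output_elements, as a sketch to run in this module's namespace. The total_ops/total_params buffers are initialized by hand here because profile() normally registers them before the hooks fire.

    import torch
    import torch.nn as nn

    conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
    x = torch.randn(1, 16, 8, 8)
    y = conv(x)

    conv.total_ops = torch.zeros(1)     # profile() registers these as buffers
    conv.total_params = torch.zeros(1)
    count_convNd(conv, (x,), y)         # hooks receive the input as a tuple

    expected = 16 * y.nelement() * 3 * 3 // 1   # cin * out_elements * kernel_ops // groups
    print(int(conv.total_ops.item()), expected)  # 294912 == 294912
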
/retrain_architecture/thop/profile.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/profile.py
3 | '''
4 | import logging
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn.modules.conv import _ConvNd
9 | from .count_hooks import *
10 |
11 | register_hooks = {
12 | nn.Conv1d: count_convNd,
13 | nn.Conv2d: count_convNd,
14 | nn.Conv3d: count_convNd,
15 | nn.ConvTranspose2d: count_convtranspose2d,
16 |
17 | # nn.BatchNorm1d: count_bn,
18 | # nn.BatchNorm2d: count_bn,
19 | # nn.BatchNorm3d: count_bn,
20 |
21 | # # nn.ReLU: count_relu,
22 | # # nn.ReLU6: count_relu,
23 | # # nn.LeakyReLU: count_relu,
24 |
25 | # nn.MaxPool1d: count_maxpool,
26 | # nn.MaxPool2d: count_maxpool,
27 | # nn.MaxPool3d: count_maxpool,
28 | # nn.AdaptiveMaxPool1d: count_adap_maxpool,
29 | # nn.AdaptiveMaxPool2d: count_adap_maxpool,
30 | # nn.AdaptiveMaxPool3d: count_adap_maxpool,
31 |
32 | # nn.AvgPool1d: count_avgpool,
33 | # nn.AvgPool2d: count_avgpool,
34 | # nn.AvgPool3d: count_avgpool,
35 |
36 | # nn.AdaptiveAvgPool1d: count_adap_avgpool,
37 | # nn.AdaptiveAvgPool2d: count_adap_avgpool,
38 | # nn.AdaptiveAvgPool3d: count_adap_avgpool,
39 | nn.Linear: count_linear,
40 | nn.Dropout: None,
41 | }
42 |
43 |
44 | def profile(model, input_size, custom_ops={}, device="cpu"):
45 | handler_collection = []
46 |
47 | def add_hooks(m):
48 | if len(list(m.children())) > 0:
49 | return
50 |
51 | m.register_buffer('total_ops', torch.zeros(1))
52 | m.register_buffer('total_params', torch.zeros(1))
53 |
54 | # for p in m.parameters():
55 | # m.total_params += torch.Tensor([p.numel()])
56 | m_type = type(m)
57 | fn = None
58 |
59 | if m_type in custom_ops:
60 | fn = custom_ops[m_type]
61 | elif m_type in register_hooks:
62 | fn = register_hooks[m_type]
63 | else:
64 | #print("Not implemented for ", m)
65 | pass
66 |
67 | if fn is not None:
68 | #print("Register FLOP counter for module %s" % str(m))
69 | handler = m.register_forward_hook(fn)
70 | handler_collection.append(handler)
71 |
72 | # original_device = model.parameters().__next__().device
73 | training = model.training
74 |
75 | model.eval().to(device)
76 | model.apply(add_hooks)
77 | x = torch.zeros(input_size).to(device)
78 | with torch.no_grad():
79 | model(x)
80 |
81 | total_ops = 0
82 | total_params = 0
83 | for m in model.modules():
84 | if len(list(m.children())) > 0: # skip for non-leaf module
85 | continue
86 | total_ops += m.total_ops
87 | total_params += m.total_params
88 | total_ops = total_ops.item()
89 | total_params = total_params.item()
90 |
91 | # model.train(training).to(original_device)
92 | for handler in handler_collection:
93 | handler.remove()
94 |
95 | return total_ops, total_params
96 |
--------------------------------------------------------------------------------
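An example use of profile() on a toy model, as a sketch. With the hook table above, only conv and linear layers contribute to the counts (the batchnorm/pooling entries are commented out), and unhooked leaf modules contribute zero. A tiny Flatten module is defined locally to stay compatible with older PyTorch versions.

    import torch.nn as nn

    class Flatten(nn.Module):
        def forward(self, x):
            return x.view(x.size(0), -1)

    toy = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1, bias=False),  # 216 params; 3*9 MACs per output element
        nn.ReLU(),
        Flatten(),
        nn.Linear(8 * 32 * 32, 10),                 # 81,930 params
    )
    flops, params = profile(toy, input_size=(1, 3, 32, 32))
    print(flops, params)   # conv: 3*9*8*32*32 = 221184 ops; linear: (2*8192 - 1)*10 ops
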
/retrain_architecture/thop/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/utils.py
3 | '''
4 |
5 | def clever_format(num, format="%.2f"):
6 | if num > 1e12:
7 | return format % (num / 1e12) + "T"
8 | if num > 1e9:
9 | return format % (num / 1e9) + "G"
10 | if num > 1e6:
11 | return format % (num / 1e6) + "M"
12 | if num > 1e3:
13 |         return format % (num / 1e3) + "K"
14 |     return format % num  # fall back for values <= 1e3 instead of implicitly returning None
--------------------------------------------------------------------------------
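A minimal demonstration of clever_format's thresholds; the first value relies on the fall-back return added above for numbers at or below 1e3.

    for n in (512, 2.5e3, 3.1e6, 7.2e9):
        print(clever_format(n))   # 512.00, 2.50K, 3.10M, 7.20G
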
/retrain_architecture/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/utils.py
3 | '''
4 |
5 | import os
6 | import numpy as np
7 | import torch
8 | import shutil
9 | import torchvision.transforms as transforms
10 | from torch.autograd import Variable
11 | import cv2
12 | import random
13 | import PIL
14 | from PIL import Image
15 | import math
16 |
17 | class AvgrageMeter(object):
18 |
19 | def __init__(self):
20 | self.reset()
21 |
22 | def reset(self):
23 | self.avg = 0
24 | self.sum = 0
25 | self.cnt = 0
26 |
27 | def update(self, val, n=1):
28 | self.sum += val * n
29 | self.cnt += n
30 | self.avg = self.sum / self.cnt
31 |
32 |
33 | def accuracy(output, target, topk=(1,)):
34 | maxk = max(topk)
35 | batch_size = target.size(0)
36 |
37 | _, pred = output.topk(maxk, 1, True, True)
38 | pred = pred.t()
39 | correct = pred.eq(target.view(1, -1).expand_as(pred))
40 |
41 | res = []
42 | for k in topk:
43 | correct_k = correct[:k].reshape(-1).float().sum(0)
44 | res.append(correct_k.mul_(100.0/batch_size))
45 | return res
46 |
47 |
48 | class Cutout(object):
49 | def __init__(self, length):
50 | self.length = length
51 |
52 | def __call__(self, img):
53 | h, w = img.size(1), img.size(2)
54 | mask = np.ones((h, w), np.float32)
55 | y = np.random.randint(h)
56 | x = np.random.randint(w)
57 |
58 | y1 = np.clip(y - self.length // 2, 0, h)
59 | y2 = np.clip(y + self.length // 2, 0, h)
60 | x1 = np.clip(x - self.length // 2, 0, w)
61 | x2 = np.clip(x + self.length // 2, 0, w)
62 |
63 | mask[y1: y2, x1: x2] = 0.
64 | mask = torch.from_numpy(mask)
65 | mask = mask.expand_as(img)
66 | img *= mask
67 | return img
68 |
69 |
70 | def _data_transforms_cifar10(args):
71 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
72 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
73 |
74 | train_transform = transforms.Compose([
75 | transforms.RandomCrop(32, padding=4),
76 | transforms.RandomHorizontalFlip(),
77 | transforms.ToTensor(),
78 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
79 | ])
80 | if args.cutout:
81 | train_transform.transforms.append(Cutout(args.cutout_length))
82 |
83 | valid_transform = transforms.Compose([
84 | transforms.ToTensor(),
85 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
86 | ])
87 | return train_transform, valid_transform
88 |
89 |
90 | def count_parameters_in_MB(model):
91 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
92 |
93 |
94 | def save_checkpoint(state, is_best, save):
95 | filename = os.path.join(save, 'checkpoint.pth.tar')
96 | torch.save(state, filename)
97 | if is_best:
98 | best_filename = os.path.join(save, 'model_best.pth.tar')
99 | shutil.copyfile(filename, best_filename)
100 |
101 |
102 | def save(model, model_path):
103 | torch.save(model.state_dict(), model_path)
104 |
105 |
106 | def load(model, model_path):
107 | model.load_state_dict(torch.load(model_path))
108 |
109 |
110 | def drop_path(x, drop_prob):
111 | if drop_prob > 0.:
112 | keep_prob = 1.-drop_prob
113 |     mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))  # CUDA-only: the Bernoulli mask is allocated directly on the GPU
114 | x.div_(keep_prob)
115 | x.mul_(mask)
116 | return x
117 |
118 |
119 | def create_exp_dir(path, scripts_to_save=None):
120 | if not os.path.exists(path):
121 | os.mkdir(path)
122 | print('Experiment dir : {}'.format(path))
123 |
124 | if scripts_to_save is not None:
125 | os.mkdir(os.path.join(path, 'scripts'))
126 | for script in scripts_to_save:
127 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
128 | shutil.copyfile(script, dst_file)
129 |
130 |
131 | class OpencvResize(object):
132 |
133 | def __init__(self, size=256):
134 | self.size = size
135 |
136 | def __call__(self, img):
137 | assert isinstance(img, PIL.Image.Image)
138 | img = np.asarray(img) # (H,W,3) RGB
139 |         img = img[:, :, ::-1] # RGB -> BGR (OpenCV convention)
140 | img = np.ascontiguousarray(img)
141 | H, W, _ = img.shape
142 | target_size = (int(self.size / H * W + 0.5), self.size) if H < W else (self.size, int(self.size / W * H + 0.5))
143 | img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR)
144 |         img = img[:, :, ::-1] # back to RGB
145 | img = np.ascontiguousarray(img)
146 | img = Image.fromarray(img)
147 | return img
148 |
149 | class RandomResizedCrop(object):
150 |
151 | def __init__(self, scale=(0.08, 1.0), target_size: int = 224, max_attempts: int = 10):
152 | assert scale[0] <= scale[1]
153 | self.scale = scale
154 | assert target_size > 0
155 | self.target_size = target_size
156 | assert max_attempts > 0
157 | self.max_attempts = max_attempts
158 |
159 | def __call__(self, img):
160 | assert isinstance(img, PIL.Image.Image)
161 | img = np.asarray(img, dtype=np.uint8)
162 | H, W, C = img.shape
163 |
164 | well_cropped = False
165 | for _ in range(self.max_attempts):
166 | crop_area = (H * W) * random.uniform(self.scale[0], self.scale[1])
167 | crop_edge = round(math.sqrt(crop_area))
168 | dH = H - crop_edge
169 | dW = W - crop_edge
170 | crop_left = random.randint(min(dW, 0), max(dW, 0))
171 | crop_top = random.randint(min(dH, 0), max(dH, 0))
172 | if dH >= 0 and dW >= 0:
173 | well_cropped = True
174 | break
175 |
176 | crop_bottom = crop_top + crop_edge
177 | crop_right = crop_left + crop_edge
178 | if well_cropped:
179 | crop_image = img[crop_top:crop_bottom, :, :][:, crop_left:crop_right, :]
180 |
181 | else:
182 | roi_top = max(crop_top, 0)
183 | padding_top = roi_top - crop_top
184 | roi_bottom = min(crop_bottom, H)
185 | padding_bottom = crop_bottom - roi_bottom
186 | roi_left = max(crop_left, 0)
187 | padding_left = roi_left - crop_left
188 | roi_right = min(crop_right, W)
189 | padding_right = crop_right - roi_right
190 |
191 | roi_image = img[roi_top:roi_bottom, :, :][:, roi_left:roi_right, :]
192 | crop_image = cv2.copyMakeBorder(roi_image, padding_top, padding_bottom, padding_left, padding_right,
193 | borderType=cv2.BORDER_CONSTANT, value=0)
194 |
195 |         random.choice([1])  # no-op; retained from the copied RLNAS implementation
196 | target_image = cv2.resize(crop_image, (self.target_size, self.target_size), interpolation=cv2.INTER_LINEAR)
197 | target_image = PIL.Image.fromarray(target_image.astype('uint8'))
198 | return target_image
199 |
200 |
201 | class LighteningJitter(object):
202 |
203 | def __init__(self, eigen_vecs, eigen_values, max_eigen_jitter=0.1):
204 | self.eigen_vecs = np.array(eigen_vecs, dtype=np.float32)
205 | self.eigen_values = np.array(eigen_values, dtype=np.float32)
206 | self.max_eigen_jitter = max_eigen_jitter
207 |
208 | def __call__(self, img):
209 | assert isinstance(img, PIL.Image.Image)
210 | img = np.asarray(img, dtype=np.float32)
211 | img = np.ascontiguousarray(img / 255)
212 |
213 | cur_eigen_jitter = np.random.normal(scale=self.max_eigen_jitter, size=self.eigen_values.shape)
214 | color_purb = (self.eigen_vecs @ (self.eigen_values * cur_eigen_jitter)).reshape([1, 1, -1])
215 | img += color_purb
216 | img = np.ascontiguousarray(img * 255)
217 | img.clip(0, 255, out=img)
218 | img = PIL.Image.fromarray(np.uint8(img))
219 | return img
220 |
221 | def get_train_transform():
222 | eigvec = np.array([
223 | [-0.5836, -0.6948, 0.4203],
224 | [-0.5808, -0.0045, -0.8140],
225 | [-0.5675, 0.7192, 0.4009]
226 | ])
227 |
228 | eigval = np.array([0.2175, 0.0188, 0.0045])
229 |
230 | transform = transforms.Compose([
231 | RandomResizedCrop(target_size=224, scale=(0.08, 1.0)),
232 | LighteningJitter(eigen_vecs=eigvec[::-1, :], eigen_values=eigval,
233 | max_eigen_jitter=0.1),
234 | transforms.RandomHorizontalFlip(0.5),
235 | transforms.ToTensor(),
236 | ])
237 | return transform
238 |
239 | def get_eval_transform():
240 | transform = transforms.Compose([
241 | OpencvResize(256),
242 | transforms.CenterCrop(224),
243 | transforms.ToTensor(),
244 | ])
245 |
246 | return transform
--------------------------------------------------------------------------------
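A sketch of how AvgrageMeter and accuracy() cooperate inside the train/infer loops of retrain.py, with random tensors standing in for logits and targets; it assumes it runs in this module's namespace.

    import torch

    top1 = AvgrageMeter()
    logits = torch.randn(8, 10)              # batch of 8, 10 classes
    target = torch.randint(0, 10, (8,))
    prec1, prec5 = accuracy(logits, target, topk=(1, 5))
    top1.update(prec1.item(), n=logits.size(0))
    print(top1.avg)   # running top-1 accuracy, in percent
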
/retrain_architecture/visualize.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/ECP-CANDLE/Benchmarks/blob/master/common/darts/visualize.py
3 | '''
4 | import sys
5 | import genotypes
6 | from graphviz import Digraph
7 |
8 |
9 | def plot(genotype, filename):
10 | g = Digraph(
11 | format='pdf',
12 | edge_attr=dict(fontsize='20', fontname="times"),
13 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
14 | engine='dot')
15 | g.body.extend(['rankdir=LR'])
16 |
17 | g.node("c_{k-2}", fillcolor='darkseagreen2')
18 | g.node("c_{k-1}", fillcolor='darkseagreen2')
19 | assert len(genotype) % 2 == 0
20 | steps = len(genotype) // 2
21 |
22 | for i in range(steps):
23 | g.node(str(i), fillcolor='lightblue')
24 |
25 | for i in range(steps):
26 | for k in [2*i, 2*i + 1]:
27 | op, j = genotype[k]
28 | if j == 0:
29 | u = "c_{k-2}"
30 | elif j == 1:
31 | u = "c_{k-1}"
32 | else:
33 | u = str(j-2)
34 | v = str(i)
35 | g.edge(u, v, label=op, fillcolor="gray")
36 |
37 | g.node("c_{k}", fillcolor='palegoldenrod')
38 | for i in range(steps):
39 | g.edge(str(i), "c_{k}", fillcolor="gray")
40 |
41 | g.render(filename, view=True)
42 |
43 |
44 | if __name__ == '__main__':
45 | if len(sys.argv) != 2:
46 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
47 | sys.exit(1)
48 |
49 | genotype_name = sys.argv[1]
50 | try:
51 | genotype = eval('genotypes.{}'.format(genotype_name))
52 | except AttributeError:
53 | print("{} is not specified in genotypes.py".format(genotype_name))
54 | sys.exit(1)
55 |
56 | plot(genotype.normal, "normal")
57 | plot(genotype.reduce, "reduction")
58 |
59 |
--------------------------------------------------------------------------------
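Usage sketch: graphviz (both the Python package and the dot binary) must be installed, and ARCH_NAME must be an attribute of genotypes.py, for example the PDARTS default used by retrain.py:

    python visualize.py PDARTS

This renders the normal and reduction cells to normal.pdf and reduction.pdf and opens them in a viewer (g.render(..., view=True)).
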
/train_supernet/config.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/config.py
3 | '''
4 | import os
5 | class config:
6 |     # Basic configuration
7 | layers = 8
8 | edges = 14
9 | model_input_size_imagenet = (1, 3, 224, 224)
10 |
11 | # Candidate operators
12 | blocks_keys = [
13 | 'none',
14 | 'max_pool_3x3',
15 | 'avg_pool_3x3',
16 | 'skip_connect',
17 | 'sep_conv_3x3',
18 | 'sep_conv_5x5',
19 | 'dil_conv_3x3',
20 | 'dil_conv_5x5'
21 | ]
22 | op_num=len(blocks_keys)
23 |
24 | # Operators encoding
25 | NONE = 0
26 | MAX_POOLING_3x3 = 1
27 | AVG_POOL_3x3 = 2
28 | SKIP_CONNECT = 3
29 | SEP_CONV_3x3 = 4
30 | SEP_CONV_5x5 = 5
31 | DIL_CONV_3x3 = 6
32 | DIL_CONV_5x5 = 7
33 |
34 |
35 | # Shrinking configuration
36 | exp_name = './'
37 | net_cache = os.path.join(exp_name, 'weight.pt')
38 | base_net_cache = os.path.join(exp_name, 'base_weight.pt')
39 | modify_base_net_cache = os.path.join(exp_name, 'weight_0.pt')
40 | shrinking_finish_threshold = 1000000
41 | sample_num = 1000
42 | per_stage_drop_num = 14
43 | epsilon = 1e-12
44 |
45 | # Enumerate all paths of a single cell
46 | paths = [[0, 2, 3, 4, 5], [0, 2, 3, 5], [0, 2, 4, 5], [0, 2, 5], [0, 3, 4, 5], [0, 3, 5],[0, 4, 5],[0, 5],
47 | [1, 2, 3, 4, 5], [1, 2, 3, 5], [1, 2, 4, 5], [1, 2, 5], [1, 3, 4, 5], [1, 3, 5],[1, 4, 5],[1, 5]]
48 |
49 |
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
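The 16 'paths' above are exactly the increasing node sequences that start at a cell input (node 0 or 1), pass through any subset of the intermediate nodes 2..4, and end at the last intermediate node 5. The sketch below re-derives them to make the encoding explicit; the import path is an assumption (run it where this config.py is importable).

    from itertools import combinations
    from config import config   # assumed import path

    derived = []
    for start in (0, 1):
        for r in range(4):                       # choose 0..3 of the middle nodes
            for mid in combinations((2, 3, 4), r):
                derived.append([start, *mid, 5])

    print(sorted(derived) == sorted(config.paths))   # True
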
/train_supernet/datasets/DownsampledImageNet.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os, sys, hashlib, torch
5 | import numpy as np
6 | from PIL import Image
7 | import torch.utils.data as data
8 | if sys.version_info[0] == 2:
9 | import cPickle as pickle
10 | else:
11 | import pickle
12 | import pdb
13 |
14 | def calculate_md5(fpath, chunk_size=1024 * 1024):
15 | md5 = hashlib.md5()
16 | with open(fpath, 'rb') as f:
17 | for chunk in iter(lambda: f.read(chunk_size), b''):
18 | md5.update(chunk)
19 | return md5.hexdigest()
20 |
21 |
22 | def check_md5(fpath, md5, **kwargs):
23 | return md5 == calculate_md5(fpath, **kwargs)
24 |
25 |
26 | def check_integrity(fpath, md5=None):
27 | if not os.path.isfile(fpath): return False
28 | if md5 is None: return True
29 | else : return check_md5(fpath, md5)
30 |
31 |
32 | class ImageNet16(data.Dataset):
33 | # http://image-net.org/download-images
34 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets
35 | # https://arxiv.org/pdf/1707.08819.pdf
36 |
37 | train_list = [
38 | ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'],
39 | ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'],
40 | ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'],
41 | ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'],
42 | ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'],
43 | ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'],
44 | ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'],
45 | ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'],
46 | ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'],
47 | ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'],
48 | ]
49 | valid_list = [
50 | ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'],
51 | ]
52 |
53 | def __init__(self, root, train, transform, use_num_of_class_only=None):
54 | self.root = root
55 | self.transform = transform
56 | self.train = train # training set or valid set
57 | if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.')
58 |
59 | if self.train: downloaded_list = self.train_list
60 | else : downloaded_list = self.valid_list
61 | self.data = []
62 | self.targets = []
63 |
64 | # now load the picked numpy arrays
65 | for i, (file_name, checksum) in enumerate(downloaded_list):
66 | file_path = os.path.join(self.root, file_name)
67 | #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path))
68 | with open(file_path, 'rb') as f:
69 | if sys.version_info[0] == 2:
70 | entry = pickle.load(f)
71 | else:
72 | entry = pickle.load(f, encoding='latin1')
73 | self.data.append(entry['data'])
74 | self.targets.extend(entry['labels'])
75 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16)
76 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
77 | if use_num_of_class_only is not None:
78 | assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only)
79 | new_data, new_targets = [], []
80 | for I, L in zip(self.data, self.targets):
81 | if 1 <= L <= use_num_of_class_only:
82 | new_data.append( I )
83 | new_targets.append( L )
84 | self.data = new_data
85 | self.targets = new_targets
86 | # self.mean.append(entry['mean'])
87 | #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16)
88 | #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1)
89 | #print ('Mean : {:}'.format(self.mean))
90 | #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3))
91 | #std_data = np.std(temp, axis=0)
92 | #std_data = np.mean(np.mean(std_data, axis=0), axis=0)
93 | #print ('Std : {:}'.format(std_data))
94 |
95 | def __getitem__(self, index):
96 | img, target = self.data[index], self.targets[index] - 1
97 |
98 | img = Image.fromarray(img)
99 |
100 | if self.transform is not None:
101 | img = self.transform(img)
102 |
103 | return img, target
104 |
105 | def __len__(self):
106 | return len(self.data)
107 |
108 | def _check_integrity(self):
109 | root = self.root
110 | for fentry in (self.train_list + self.valid_list):
111 | filename, md5 = fentry[0], fentry[1]
112 | fpath = os.path.join(root, filename)
113 | if not check_integrity(fpath, md5):
114 | return False
115 | return True
116 |
117 | #
118 | if __name__ == '__main__':
119 | train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None)
120 | valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None)
121 |
122 | print ( len(train) )
123 | print ( len(valid) )
124 | image, label = train[111]
125 | trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200)
126 | validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200)
127 | print ( len(trainX) )
128 | print ( len(validX) )
129 | #import pdb; pdb.set_trace()
130 |
--------------------------------------------------------------------------------
/train_supernet/datasets/SearchDatasetWrap.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import torch, copy, random
5 | import torch.utils.data as data
6 |
7 |
8 | class SearchDataset(data.Dataset):
9 |
10 | def __init__(self, name, data, train_split, valid_split, check=True):
11 | self.datasetname = name
12 | if isinstance(data, (list, tuple)): # new type of SearchDataset
13 | assert len(data) == 2, 'invalid length: {:}'.format( len(data) )
14 | self.train_data = data[0]
15 | self.valid_data = data[1]
16 | self.train_split = train_split.copy()
17 | self.valid_split = valid_split.copy()
18 | self.mode_str = 'V2' # new mode
19 | else:
20 | self.mode_str = 'V1' # old mode
21 | self.data = data
22 | self.train_split = train_split.copy()
23 | self.valid_split = valid_split.copy()
24 | if check:
25 | intersection = set(train_split).intersection(set(valid_split))
26 |       assert len(intersection) == 0, 'the split train and validation sets should have no intersection'
27 | self.length = len(self.train_split)
28 |
29 | def __repr__(self):
30 | return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str))
31 |
32 | def __len__(self):
33 | return self.length
34 |
35 | def __getitem__(self, index):
36 | assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index)
37 | train_index = self.train_split[index]
38 | valid_index = random.choice( self.valid_split )
39 | if self.mode_str == 'V1':
40 | train_image, train_label = self.data[train_index]
41 | valid_image, valid_label = self.data[valid_index]
42 | elif self.mode_str == 'V2':
43 | train_image, train_label = self.train_data[train_index]
44 | valid_image, valid_label = self.valid_data[valid_index]
45 | else: raise ValueError('invalid mode : {:}'.format(self.mode_str))
46 | return train_image, train_label, valid_image, valid_label
47 |
--------------------------------------------------------------------------------
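A sketch of SearchDataset in its 'V1' mode: a single underlying dataset plus disjoint index splits, where each item pairs a training sample with a randomly drawn validation sample. The toy dataset below is a stand-in for CIFAR; run this where SearchDataset is in scope.

    import torch
    import torch.utils.data as data

    class Toy(data.Dataset):
        def __len__(self):
            return 10
        def __getitem__(self, i):
            return torch.full((3, 4, 4), float(i)), i   # (image, label) pair

    ds = SearchDataset('toy', Toy(), train_split=[0, 1, 2, 3, 4],
                       valid_split=[5, 6, 7, 8, 9])
    tr_img, tr_lab, va_img, va_lab = ds[0]
    print(len(ds), tr_lab, va_lab)   # 5, 0, and a random label from 5..9
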
/train_supernet/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders
5 | from .SearchDatasetWrap import SearchDataset
6 |
--------------------------------------------------------------------------------
/train_supernet/datasets/config_utils/__init__.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import json
3 | from collections import namedtuple
4 |
5 | support_types = ('str', 'int', 'bool', 'float', 'none')
6 |
7 | def convert_param(original_lists):
8 | assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists)
9 | ctype, value = original_lists[0], original_lists[1]
10 | assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types)
11 | is_list = isinstance(value, list)
12 | if not is_list: value = [value]
13 | outs = []
14 | for x in value:
15 | if ctype == 'int':
16 | x = int(x)
17 | elif ctype == 'str':
18 | x = str(x)
19 | elif ctype == 'bool':
20 | x = bool(int(x))
21 | elif ctype == 'float':
22 | x = float(x)
23 | elif ctype == 'none':
24 | assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x)
25 | x = None
26 | else:
27 | raise TypeError('Does not know this type : {:}'.format(ctype))
28 | outs.append(x)
29 | if not is_list: outs = outs[0]
30 | return outs
31 |
32 | def load_config(path, extra, logger):
33 | path = str(path)
34 | if hasattr(logger, 'log'): logger.log(path)
35 | assert os.path.exists(path), 'Can not find {:}'.format(path)
36 | # Reading data back
37 | with open(path, 'r') as f:
38 | data = json.load(f)
39 | content = { k: convert_param(v) for k,v in data.items()}
40 | assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra)
41 | if isinstance(extra, dict): content = {**content, **extra}
42 | Arguments = namedtuple('Configure', ' '.join(content.keys()))
43 | content = Arguments(**content)
44 | if hasattr(logger, 'log'): logger.log('{:}'.format(content))
45 | return content
--------------------------------------------------------------------------------
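load_config() expects each JSON value as a [type, value] pair, where type is one of support_types and value may be a scalar or a list. A minimal round-trip through a temporary file, as a sketch (the keys are made up for illustration; load_config from this module is assumed to be in scope):

    import json, os, tempfile

    cfg = {
        "batch_size": ["int", "256"],
        "scheduler":  ["str", "cos"],
        "eta_min":    ["float", "0.0"],
        "milestones": ["int", ["150", "225"]],
    }
    fd, path = tempfile.mkstemp(suffix=".json")
    with os.fdopen(fd, "w") as f:
        json.dump(cfg, f)

    content = load_config(path, extra=None, logger=None)
    print(content.batch_size, content.milestones)   # 256 [150, 225]
    os.remove(path)
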
/train_supernet/datasets/get_dataset_with_transform.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os, sys, torch
5 | import os.path as osp
6 | import numpy as np
7 | import torchvision.datasets as dset
8 | import torchvision.transforms as transforms
9 | from copy import deepcopy
10 | from PIL import Image
11 |
12 | from .DownsampledImageNet import ImageNet16
13 | from .SearchDatasetWrap import SearchDataset
14 | from .config_utils import load_config as load_dataset_config
15 | from torchvision.transforms import transforms
16 | from PIL import ImageFilter, ImageOps
17 | import random
18 | import torchvision.datasets as datasets
19 |
20 | Dataset2Class = {'cifar10' : 10,
21 | 'cifar100': 100,
22 | 'imagenet-1k-s':1000,
23 | 'imagenet-1k' : 1000,
24 | 'ImageNet16' : 1000,
25 | 'ImageNet16-150': 150,
26 | 'ImageNet16-120': 120,
27 | 'ImageNet16-200': 200}
28 |
29 | class CUTOUT(object):
30 |
31 | def __init__(self, length):
32 | self.length = length
33 |
34 | def __repr__(self):
35 | return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__))
36 |
37 | def __call__(self, img):
38 | h, w = img.size(1), img.size(2)
39 | mask = np.ones((h, w), np.float32)
40 | y = np.random.randint(h)
41 | x = np.random.randint(w)
42 |
43 | y1 = np.clip(y - self.length // 2, 0, h)
44 | y2 = np.clip(y + self.length // 2, 0, h)
45 | x1 = np.clip(x - self.length // 2, 0, w)
46 | x2 = np.clip(x + self.length // 2, 0, w)
47 |
48 | mask[y1: y2, x1: x2] = 0.
49 | mask = torch.from_numpy(mask)
50 | mask = mask.expand_as(img)
51 | img *= mask
52 | return img
53 |
54 |
55 | imagenet_pca = {
56 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]),
57 | 'eigvec': np.asarray([
58 | [-0.5675, 0.7192, 0.4009],
59 | [-0.5808, -0.0045, -0.8140],
60 | [-0.5836, -0.6948, 0.4203],
61 | ])
62 | }
63 |
64 |
65 | class Lighting(object):
66 | def __init__(self, alphastd,
67 | eigval=imagenet_pca['eigval'],
68 | eigvec=imagenet_pca['eigvec']):
69 | self.alphastd = alphastd
70 | assert eigval.shape == (3,)
71 | assert eigvec.shape == (3, 3)
72 | self.eigval = eigval
73 | self.eigvec = eigvec
74 |
75 | def __call__(self, img):
76 | if self.alphastd == 0.:
77 | return img
78 | rnd = np.random.randn(3) * self.alphastd
79 | rnd = rnd.astype('float32')
80 | v = rnd
81 | old_dtype = np.asarray(img).dtype
82 | v = v * self.eigval
83 | v = v.reshape((3, 1))
84 | inc = np.dot(self.eigvec, v).reshape((3,))
85 | img = np.add(img, inc)
86 | if old_dtype == np.uint8:
87 | img = np.clip(img, 0, 255)
88 | img = Image.fromarray(img.astype(old_dtype), 'RGB')
89 | return img
90 |
91 | def __repr__(self):
92 | return self.__class__.__name__ + '()'
93 |
94 |
95 | class Cifar10RandomLabels(datasets.CIFAR10):
96 | """CIFAR10 dataset, with support for randomly corrupt labels.
97 | Params
98 | ------
99 | rand_seed: int
100 | Default 0. numpy random seed.
101 | num_classes: int
102 | Default 10. The number of classes in the dataset.
103 | """
104 | def __init__(self, rand_seed=0, num_classes=10, **kwargs):
105 | super(Cifar10RandomLabels, self).__init__(**kwargs)
106 | self.n_classes = num_classes
107 | self.rand_seed = rand_seed
108 | self.random_labels()
109 |
110 | def random_labels(self):
111 | labels = np.array(self.targets)
112 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
113 | np.random.seed(self.rand_seed)
114 | rnd_labels = np.random.randint(0, self.n_classes, len(labels))
115 | # we need to explicitly cast the labels from np.int64 to the
116 | # builtin int type, otherwise PyTorch will fail...
117 | labels = [int(x) for x in rnd_labels]
118 |
119 | self.targets = labels
120 |
121 | class Cifar100RandomLabels(datasets.CIFAR100):
122 | """CIFAR10 dataset, with support for randomly corrupt labels.
123 | Params
124 | ------
125 | rand_seed: int
126 | Default 0. numpy random seed.
127 | num_classes: int
128 | Default 100. The number of classes in the dataset.
129 | """
130 | def __init__(self, rand_seed=0, num_classes=100, **kwargs):
131 | super(Cifar100RandomLabels, self).__init__(**kwargs)
132 | self.n_classes = num_classes
133 | self.rand_seed = rand_seed
134 | self.random_labels()
135 |
136 | def random_labels(self):
137 | labels = np.array(self.targets)
138 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
139 | np.random.seed(self.rand_seed)
140 | rnd_labels = np.random.randint(0, self.n_classes, len(labels))
141 | # we need to explicitly cast the labels from np.int64 to the
142 | # builtin int type, otherwise PyTorch will fail...
143 | labels = [int(x) for x in rnd_labels]
144 |
145 | self.targets = labels
146 |
147 | class ImageNet16RandomLabels(ImageNet16):
148 | """CIFAR10 dataset, with support for randomly corrupt labels.
149 | Params
150 | ------
151 | rand_seed: int
152 | Default 0. numpy random seed.
153 | num_classes: int
154 | Default 120. The number of classes in the dataset.
155 | """
156 | def __init__(self, rand_seed=0, num_classes=120, **kwargs):
157 | super(ImageNet16RandomLabels, self).__init__(**kwargs)
158 | self.n_classes = num_classes
159 | self.rand_seed = rand_seed
160 | self.random_labels()
161 | # print('min_label:{}, max_label:{}'.format(min(self.targets), max(self.targets)))
162 |
163 | def random_labels(self):
164 | labels = np.array(self.targets)
165 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
166 | np.random.seed(self.rand_seed)
167 | rnd_labels = np.random.randint(1, self.n_classes+1, len(labels))
168 | # we need to explicitly cast the labels from np.int64 to the
169 | # builtin int type, otherwise PyTorch will fail...
170 | labels = [int(x) for x in rnd_labels]
171 |
172 | self.targets = labels
173 |
174 | def get_datasets(name, root, cutout, rand_seed, byol_aug_type=None, random_label=True):
175 |
176 | if name == 'cifar10':
177 | mean = [x / 255 for x in [125.3, 123.0, 113.9]]
178 | std = [x / 255 for x in [63.0, 62.1, 66.7]]
179 | elif name == 'cifar100':
180 | mean = [x / 255 for x in [129.3, 124.1, 112.4]]
181 | std = [x / 255 for x in [68.2, 65.4, 70.4]]
182 | elif name.startswith('imagenet-1k'):
183 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
184 | elif name.startswith('ImageNet16'):
185 | mean = [x / 255 for x in [122.68, 116.66, 104.01]]
186 | std = [x / 255 for x in [63.22, 61.26 , 65.09]]
187 | else:
188 | raise TypeError("Unknown dataset : {:}".format(name))
189 |
190 | # Data Augmentation
191 | if name == 'cifar10' or name == 'cifar100':
192 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
193 | if cutout > 0 : lists += [CUTOUT(cutout)]
194 | if byol_aug_type is None:
195 | train_transform = transforms.Compose(lists)
196 | elif byol_aug_type=='byol':
197 | online_aug = get_train_transform('BYOL_Tau', 32, mean, std)
198 | target_aug = get_train_transform('BYOL_Tau_Hat', 32, mean, std)
199 | train_transform = TwoImageAugmentations(online_aug, target_aug)
200 | else:
201 | raise NotImplementedError
202 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
203 | xshape = (1, 3, 32, 32)
204 | elif name.startswith('ImageNet16'):
205 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)]
206 | if cutout > 0 : lists += [CUTOUT(cutout)]
207 | if byol_aug_type is None:
208 | train_transform = transforms.Compose(lists)
209 | elif byol_aug_type=='byol':
210 | online_aug = get_train_transform('BYOL_Tau', 16, mean, std)
211 | target_aug = get_train_transform('BYOL_Tau_Hat', 16, mean, std)
212 | train_transform = TwoImageAugmentations(online_aug, target_aug)
213 | else:
214 | raise NotImplementedError
215 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)])
216 | xshape = (1, 3, 16, 16)
217 | elif name == 'tiered':
218 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)]
219 | if cutout > 0 : lists += [CUTOUT(cutout)]
220 | train_transform = transforms.Compose(lists)
221 | test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)])
222 | xshape = (1, 3, 32, 32)
223 | elif name.startswith('imagenet-1k'):
224 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
225 | if name == 'imagenet-1k':
226 | xlists = [transforms.RandomResizedCrop(224)]
227 | xlists.append(
228 | transforms.ColorJitter(
229 | brightness=0.4,
230 | contrast=0.4,
231 | saturation=0.4,
232 | hue=0.2))
233 | xlists.append( Lighting(0.1))
234 | elif name == 'imagenet-1k-s':
235 | xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))]
236 | else: raise ValueError('invalid name : {:}'.format(name))
237 | xlists.append( transforms.RandomHorizontalFlip(p=0.5) )
238 | xlists.append( transforms.ToTensor() )
239 | xlists.append( normalize )
240 | train_transform = transforms.Compose(xlists)
241 | test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize])
242 | xshape = (1, 3, 224, 224)
243 | else:
244 | raise TypeError("Unknown dataset : {:}".format(name))
245 |
246 | if name == 'cifar10':
247 | if random_label:
248 | train_data = Cifar10RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed)
249 | else:
250 | train_data = datasets.CIFAR10(root=root, train=True, transform=train_transform, download=True)
251 | test_data = datasets.CIFAR10(root=root, train=True , transform=test_transform, download=True) # NOTE: intentionally the training set; get_nas_search_loaders later splits it into disjoint train/valid halves
252 | # test_data = datasets.CIFAR10(root=root, train=False, transform=test_transform , download=True)
253 | assert len(train_data) == 50000 and len(test_data) == 50000
254 | elif name == 'cifar100':
255 | if random_label:
256 | train_data = Cifar100RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed)
257 | else:
258 | train_data = datasets.CIFAR100(root=root, train=True , transform=train_transform, download=True)
259 | test_data = datasets.CIFAR100(root=root, train=True, transform=test_transform , download=True) # NOTE: intentionally the training set; split into disjoint train/valid halves downstream
260 | assert len(train_data) == 50000 and len(test_data) == 50000
261 | elif name.startswith('imagenet-1k'):
262 | train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform)
263 | test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform)
264 | assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000)
265 | elif name == 'ImageNet16':
266 | if random_label:
267 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, rand_seed=rand_seed)
268 | else:
269 | train_data = ImageNet16(root=root, train=True, transform=train_transform)
270 | test_data = ImageNet16(root=root, train=False, transform=test_transform)
271 | assert len(train_data) == 1281167 and len(test_data) == 50000
272 | elif name == 'ImageNet16-120':
273 | if random_label:
274 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=120, use_num_of_class_only=120, rand_seed=rand_seed)
275 | else:
276 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=120)
277 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=120)
278 | assert len(train_data) == 151700 and len(test_data) == 6000
279 | elif name == 'ImageNet16-150':
280 | if random_label:
281 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=150, use_num_of_class_only=150, rand_seed=rand_seed)
282 | else:
283 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=150)
284 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=150)
285 | assert len(train_data) == 190272 and len(test_data) == 7500
286 | elif name == 'ImageNet16-200':
287 | if random_label:
288 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, num_classes=200, use_num_of_class_only=200, rand_seed=rand_seed)
289 | else:
290 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=200)
291 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=200)
292 | assert len(train_data) == 254775 and len(test_data) == 10000
293 | else: raise TypeError("Unknown dataset : {:}".format(name))
294 |
295 | class_num = Dataset2Class[name]
296 | return train_data, test_data, xshape, class_num
297 |
298 |
299 | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, use_valid_no_shuffle=False):
300 | # NOTE: detailed dataset configuration is given in NAS-BENCH-201 paper, https://arxiv.org/pdf/2001.00326.pdf.
301 | if isinstance(batch_size, (list,tuple)):
302 | batch, test_batch = batch_size
303 | else:
304 | batch, test_batch = batch_size, batch_size
305 | if dataset == 'cifar10' or dataset == 'cifar100':
306 | #split_Fpath = 'configs/nas-benchmark/cifar-split.txt'
307 | if dataset == 'cifar10':
308 | cifar_split = load_dataset_config('{:}/cifar-split.txt'.format(config_root), None, None)
309 | elif dataset == 'cifar100':
310 | cifar_split = load_dataset_config('{:}/cifar100-split.txt'.format(config_root), None, None)
311 | train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set
312 | #logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set
313 | # To split data
314 | xvalid_data = valid_data
315 | search_data = SearchDataset(dataset, train_data, train_split, valid_split)
316 | # data loader
317 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
318 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True)
319 | valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True)
320 | if use_valid_no_shuffle:
321 | # NOTE: using validation dataset
322 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(xvalid_data, valid_split), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True)
323 | # NOTE: using search training dataset
324 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True)
325 | elif dataset == 'ImageNet16-120':
326 | imagenet_test_split = load_dataset_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None)
327 | search_train_data = train_data
328 | search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform
329 | search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid)
330 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
331 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True)
332 | valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True)
333 | if use_valid_no_shuffle:
334 | # NOTE: using validation dataset
335 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(valid_data, imagenet_test_split.xvalid), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True)
336 | # NOTE: using search training dataset
337 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True)
338 | else:
339 | raise ValueError('invalid dataset : {:}'.format(dataset))
340 |
341 | if use_valid_no_shuffle:
342 | return search_loader, train_loader, valid_loader, valid_loader_no_shuffle
343 | else:
344 | return search_loader, train_loader, valid_loader
345 |
346 | def get_train_transform(aug, image_size, mean, std):
347 |
348 | if aug == 'BYOL_Tau':
349 | transform = transforms.Compose([
350 | transforms.RandomResizedCrop(image_size),
351 | transforms.RandomHorizontalFlip(),
352 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
353 | transforms.RandomGrayscale(p=0.2),
354 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=1.0),
355 | transforms.RandomApply([Solarization(128)], p=0.0),
356 | transforms.ToTensor(),
357 | transforms.Normalize(mean, std),
358 |
359 | ])
360 | elif aug == 'BYOL_Tau_Hat':
361 | transform = transforms.Compose([
362 | transforms.RandomResizedCrop(image_size),
363 | transforms.RandomHorizontalFlip(),
364 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
365 | transforms.RandomGrayscale(p=0.2),
366 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.1),
367 | transforms.RandomApply([Solarization(128)], p=0.2),
368 | transforms.ToTensor(),
369 | transforms.Normalize(mean, std),
370 | ])
371 | else:
372 | raise NotImplementedError
373 |
374 | return transform
375 |
376 |
377 | class TwoImageAugmentations:
378 | def __init__(self, online_aug, target_aug):
379 | self.online_aug = online_aug
380 | self.target_aug = target_aug
381 |
382 | def __call__(self, x):
383 | online_image = self.online_aug(x)
384 | target_image = self.target_aug(x)
385 | return [online_image, target_image]
386 |
387 | class GaussianBlur(object):
388 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""
389 |
390 | def __init__(self, sigma=[.1, 2.]):
391 | self.sigma = sigma
392 |
393 | def __call__(self, x):
394 | sigma = random.uniform(self.sigma[0], self.sigma[1])
395 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma))
396 | return x
397 |
398 |
399 | class Solarization(object):
400 | def __init__(self, magnitude=128):
401 | self.magnitude = magnitude
402 |
403 | def __call__(self, x):
404 | x = ImageOps.solarize(x, self.magnitude)
405 | return x
406 |
407 | if __name__ == '__main__':
408 | byol = True
409 | train_data, test_data, xshape, class_num = get_datasets('cifar10', '/home/zhangxuanyang/dataset/cifar.python/', -1, 0, byol_aug_type='byol' if byol else None)
410 | search_loader, _, valid_loader = get_nas_search_loaders(train_data, test_data, 'cifar10', 'datasets/configs', \
411 | (3, 3), 4)
412 | for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader):
413 | print(base_inputs)
414 | break
415 |
416 | # import pdb; pdb.set_trace()
417 |
--------------------------------------------------------------------------------
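Taken together, a typical call sequence for this module looks like the sketch below (assumptions: the script is run from `train_supernet/` so the split files resolve to `datasets/configs`, and CIFAR-100 is downloadable to `./data`):

```python
from datasets import get_datasets, get_nas_search_loaders

# random_label=True corrupts the training labels (the supernet default in train.py)
train_data, valid_data, xshape, class_num = get_datasets(
    'cifar100', './data', cutout=-1, rand_seed=0, random_label=True)
print(xshape, class_num)  # (1, 3, 32, 32) 100

search_loader, train_loader, valid_loader = get_nas_search_loaders(
    train_data, valid_data, 'cifar100', 'datasets/configs', (64, 64), 4)

# Each search batch carries the two disjoint halves of the CIFAR split.
base_inputs, base_targets, arch_inputs, arch_targets = next(iter(search_loader))
```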
/train_supernet/datasets/test_utils.py:
--------------------------------------------------------------------------------
1 | ##################################################
2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 #
3 | ##################################################
4 | import os
5 |
6 | def test_imagenet_data(imagenet):
7 | total_length = len(imagenet)
8 | assert total_length == 1281167 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length)
9 | map_id = {}
10 | for index in range(total_length):
11 | path, target = imagenet.imgs[index]
12 | folder, image_name = os.path.split(path)
13 | _, folder = os.path.split(folder)
14 | if folder not in map_id:
15 | map_id[folder] = target
16 | else:
17 | assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target)
18 | assert image_name.find(folder) == 0, '{} is wrong.'.format(path)
19 | print('Check ImageNet Dataset OK')
20 |
--------------------------------------------------------------------------------
/train_supernet/genotypes.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/genotypes.py
3 | '''
4 | from collections import namedtuple
5 |
6 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
7 |
8 | PRIMITIVES = [
9 | 'none',
10 | 'max_pool_3x3',
11 | 'avg_pool_3x3',
12 | 'skip_connect',
13 | 'sep_conv_3x3',
14 | 'sep_conv_5x5',
15 | 'dil_conv_3x3',
16 | 'dil_conv_5x5'
17 | ]
18 |
19 | NASNet = Genotype(
20 | normal = [
21 | ('sep_conv_5x5', 1),
22 | ('sep_conv_3x3', 0),
23 | ('sep_conv_5x5', 0),
24 | ('sep_conv_3x3', 0),
25 | ('avg_pool_3x3', 1),
26 | ('skip_connect', 0),
27 | ('avg_pool_3x3', 0),
28 | ('avg_pool_3x3', 0),
29 | ('sep_conv_3x3', 1),
30 | ('skip_connect', 1),
31 | ],
32 | normal_concat = [2, 3, 4, 5, 6],
33 | reduce = [
34 | ('sep_conv_5x5', 1),
35 | ('sep_conv_7x7', 0),
36 | ('max_pool_3x3', 1),
37 | ('sep_conv_7x7', 0),
38 | ('avg_pool_3x3', 1),
39 | ('sep_conv_5x5', 0),
40 | ('skip_connect', 3),
41 | ('avg_pool_3x3', 2),
42 | ('sep_conv_3x3', 2),
43 | ('max_pool_3x3', 1),
44 | ],
45 | reduce_concat = [4, 5, 6],
46 | )
47 |
48 | AmoebaNet = Genotype(
49 | normal = [
50 | ('avg_pool_3x3', 0),
51 | ('max_pool_3x3', 1),
52 | ('sep_conv_3x3', 0),
53 | ('sep_conv_5x5', 2),
54 | ('sep_conv_3x3', 0),
55 | ('avg_pool_3x3', 3),
56 | ('sep_conv_3x3', 1),
57 | ('skip_connect', 1),
58 | ('skip_connect', 0),
59 | ('avg_pool_3x3', 1),
60 | ],
61 | normal_concat = [4, 5, 6],
62 | reduce = [
63 | ('avg_pool_3x3', 0),
64 | ('sep_conv_3x3', 1),
65 | ('max_pool_3x3', 0),
66 | ('sep_conv_7x7', 2),
67 | ('sep_conv_7x7', 0),
68 | ('avg_pool_3x3', 1),
69 | ('max_pool_3x3', 0),
70 | ('max_pool_3x3', 1),
71 | ('conv_7x1_1x7', 0),
72 | ('sep_conv_3x3', 5),
73 | ],
74 | reduce_concat = [3, 4, 6]
75 | )
76 |
77 | DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5])
78 | DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
79 |
80 | DARTS = DARTS_V2
81 |
82 |
--------------------------------------------------------------------------------
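The supernet never materializes a `Genotype` directly; candidates are lists of indices into `PRIMITIVES`, one per edge. Below is a hypothetical decoder (not part of the repo) showing how such a list maps back to the `(op_name, input_node)` pairs a `Genotype` stores, assuming the standard DARTS layout where intermediate node j has j+2 incoming edges and keeps two active ones:

```python
from genotypes import PRIMITIVES, Genotype

def decode_cell(rng):
    # Hypothetical helper: turn 14 op indices into (op_name, input_node) pairs,
    # keeping only the non-'none' edges of each of the 4 intermediate nodes.
    gene, offset = [], 0
    for j in range(4):
        edges = rng[offset:offset + j + 2]
        for k, op_idx in enumerate(edges):
            if PRIMITIVES[op_idx] != 'none':
                gene.append((PRIMITIVES[op_idx], k))
        offset += j + 2
    return gene

cand = [4, 4, 0, 5, 3, 0, 0, 4, 3, 0, 0, 0, 4, 2]  # valid: 2 active edges per node
genotype = Genotype(normal=decode_cell(cand), normal_concat=[2, 3, 4, 5],
                    reduce=decode_cell(cand), reduce_concat=[2, 3, 4, 5])
print(genotype)
```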
/train_supernet/operations.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/operations.py
3 | '''
4 | import torch
5 | import torch.nn as nn
6 |
7 | OPS = {
8 | 'none' : lambda C, stride, affine: Zero(stride),
9 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
10 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
11 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
12 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
13 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
14 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
15 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
16 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
17 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
18 | nn.ReLU(inplace=False),
19 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
20 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
21 | nn.BatchNorm2d(C, affine=affine)
22 | ),
23 | }
24 |
25 | class ReLUConvBN(nn.Module):
26 |
27 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
28 | super(ReLUConvBN, self).__init__()
29 | self.op = nn.Sequential(
30 | nn.ReLU(inplace=False),
31 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
32 | nn.BatchNorm2d(C_out, affine=affine)
33 | )
34 |
35 | def forward(self, x, rngs=None):
36 | return self.op(x)
37 |
38 | class DilConv(nn.Module):
39 |
40 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
41 | super(DilConv, self).__init__()
42 | self.op = nn.Sequential(
43 | nn.ReLU(inplace=False),
44 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
45 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
46 | nn.BatchNorm2d(C_out, affine=affine),
47 | )
48 |
49 | def forward(self, x, rngs=None):
50 | return self.op(x)
51 |
52 | class SepConv(nn.Module):
53 |
54 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
55 | super(SepConv, self).__init__()
56 | self.op = nn.Sequential(
57 | nn.ReLU(inplace=False),
58 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
59 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
60 | nn.BatchNorm2d(C_in, affine=affine),
61 | nn.ReLU(inplace=False),
62 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
63 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
64 | nn.BatchNorm2d(C_out, affine=affine),
65 | )
66 |
67 | def forward(self, x, rngs=None):
68 | return self.op(x)
69 |
70 |
71 | class Identity(nn.Module):
72 |
73 | def __init__(self):
74 | super(Identity, self).__init__()
75 |
76 | def forward(self, x, rngs=None):
77 | return x
78 |
79 | class Zero(nn.Module):
80 |
81 | def __init__(self, stride):
82 | super(Zero, self).__init__()
83 | self.stride = stride
84 | def forward(self, x, rngs=None):
85 | n, c, h, w = x.size()
86 | h //= self.stride
87 | w //= self.stride
88 | if x.is_cuda:
89 | with torch.cuda.device(x.get_device()):
90 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0)
91 | else:
92 | padding = torch.FloatTensor(n, c, h, w).fill_(0)
93 | return padding
94 |
95 | class FactorizedReduce(nn.Module):
96 |
97 | def __init__(self, C_in, C_out, affine=True):
98 | super(FactorizedReduce, self).__init__()
99 | assert C_out % 2 == 0
100 | self.relu = nn.ReLU(inplace=False)
101 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
102 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
103 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
104 |
105 | def forward(self, x, rngs=None):
106 | x = self.relu(x)
107 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
108 | out = self.bn(out)
109 | return out
110 |
111 |
112 |
--------------------------------------------------------------------------------
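All candidate operations are channel-preserving, so any of them can be dropped onto an edge without extra shape bookkeeping. A quick smoke-test sketch of that invariant (stride 2, as used on the reduction edges):

```python
import torch
from operations import OPS

x = torch.randn(2, 16, 32, 32)
for name in ['sep_conv_3x3', 'dil_conv_5x5', 'skip_connect', 'avg_pool_3x3', 'none']:
    op = OPS[name](16, 2, True)      # C=16 channels, stride=2, affine BatchNorm
    print(name, tuple(op(x).shape))  # every op yields (2, 16, 16, 16)
```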
/train_supernet/super_model.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/super_model.py
3 | '''
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from operations import *
8 | from torch.autograd import Variable
9 | from genotypes import PRIMITIVES
10 | from genotypes import Genotype
11 | import math
12 | import numpy as np
13 | from config import config
14 | import copy
15 | from utils import check_cand
16 |
17 | class MixedOp(nn.Module):
18 |
19 | def __init__(self, C, stride):
20 | super(MixedOp, self).__init__()
21 | self._ops = nn.ModuleList()
22 | for idx, primitive in enumerate(PRIMITIVES):
23 | op = OPS[primitive](C, stride, True)
24 | op.idx = idx
25 | if 'pool' in primitive:
26 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=True))
27 | self._ops.append(op)
28 |
29 | def forward(self, x, rng):
30 | return self._ops[rng](x)
31 |
32 |
33 | class Cell(nn.Module):
34 |
35 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev):
36 | super(Cell, self).__init__()
37 | if reduction_prev:
38 | # NOTE: if K-1 cell output was from stride-2 op, K-2 cell output should shrink its spatial size by stride-2.
39 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=True)
40 | else:
41 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=True)
42 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=True)
43 | self._steps = steps
44 | self._multiplier = multiplier
45 | self._C = C
46 | self.out_C = self._multiplier * C
47 | self.reduction = reduction
48 |
49 | self._ops = nn.ModuleList()
50 | self._bns = nn.ModuleList()
51 | self.time_stamp = 1
52 |
53 | for i in range(self._steps):
54 | for j in range(2+i):
55 | stride = 2 if reduction and j < 2 else 1
56 | op = MixedOp(C, stride)
57 | self._ops.append(op)
58 |
59 | def forward(self, s0, s1, rngs):
60 | s0 = self.preprocess0(s0)
61 | s1 = self.preprocess1(s1)
62 | states = [s0, s1]
63 | offset = 0
64 | for i in range(self._steps):
65 | s = sum(self._ops[offset+j](h, rngs[offset+j]) for j, h in enumerate(states))
66 | offset += len(states)
67 | states.append(s)
68 | return torch.cat(states[-self._multiplier:], dim=1)
69 |
70 | class Network(nn.Module):
71 | def __init__(self, C=16, num_classes=10, layers=8, steps=4, multiplier=4, stem_multiplier=3):
72 | super(Network, self).__init__()
73 | self._C = C
74 | self._num_classes = num_classes
75 | self._layers = layers
76 | self._steps = steps
77 | self._multiplier = multiplier
78 |
79 | C_curr = stem_multiplier * C
80 |
81 | self.stem = nn.Sequential(
82 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
83 | nn.BatchNorm2d(C_curr)
84 | )
85 |
86 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
87 |
88 | self.cells = nn.ModuleList()
89 | reduction_prev = False
90 |
91 | for i in range(layers):
92 | if i in [layers // 3, 2 * layers // 3]:
93 | C_curr *= 2
94 | reduction = True
95 | else:
96 | reduction = False
97 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
98 | reduction_prev = reduction
99 | self.cells += [cell]
100 | C_prev_prev, C_prev = C_prev, multiplier * C_curr
101 |
102 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
103 | self.classifier = nn.Linear(C_prev, num_classes)
104 |
105 | def forward(self, input, normal_rng, reduction_rng):
106 | s0 = s1 = self.stem(input)
107 | for i, cell in enumerate(self.cells):
108 | if i in [self._layers // 3, 2 * self._layers // 3]:
109 | s0, s1 = s1, cell(s0, s1, reduction_rng)
110 | else:
111 | s0, s1 = s1, cell(s0, s1, normal_rng)
112 | out = self.global_pooling(s1)
113 | logits = self.classifier(out.view(out.size(0),-1))
114 | return logits
115 |
116 | if __name__ == '__main__':
117 | from copy import deepcopy
118 | np.random.seed(0)
119 | model = Network()
120 | print(model)
121 | # (demo) sample a random op per edge for the normal and reduction cells, then run a forward pass:
122 | operations = []
123 | for _ in range(config.edges):
124 | operations.append(list(range(config.op_num)))
125 | norm_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)]
126 | reduction_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)]
127 | print("operations: ", operations)
128 | print("norm rng: ", norm_rng)
129 | norm_rng = check_cand(norm_rng, operations)
130 | print("after check_cand norm_rng: ", norm_rng)
131 | reduction_rng = check_cand(reduction_rng, operations)
132 | x = torch.rand(4,3,32,32)
133 | logit = model(x, norm_rng, reduction_rng)
134 | print('logit:{0}'.format(logit))
--------------------------------------------------------------------------------
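With the default `layers=8`, the reduction cells sit at indices `8 // 3 = 2` and `2 * 8 // 3 = 5`; every other cell is a normal cell and receives `normal_rng`. A standalone CPU forward pass, sketched under the assumption that `config.edges == 14` (the DARTS cell size) and run from `train_supernet/` so `config` and `utils` resolve:

```python
import numpy as np
import torch
from super_model import Network

np.random.seed(0)
model = Network(C=16, num_classes=10, layers=8)

# One op index per edge; skipping check_cand is fine for a smoke test,
# though the sample may then violate the two-active-edges-per-node rule.
normal_rng = [np.random.randint(8) for _ in range(14)]
reduction_rng = [np.random.randint(8) for _ in range(14)]

logits = model(torch.rand(2, 3, 32, 32), normal_rng, reduction_rng)
print(logits.shape)  # torch.Size([2, 10])
```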
/train_supernet/train.py:
--------------------------------------------------------------------------------
1 | '''
2 | Code adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/train.py
3 | '''
4 | import os
5 | import sys
6 | import time
7 | import glob
8 | import numpy as np
9 | import torch
10 | import utils
11 | import logging
12 | import argparse
13 | import torch.nn as nn
14 | import torch.utils
15 | import torch.nn.functional as F
16 | import torchvision.datasets as dset
17 | import torch.backends.cudnn as cudnn
18 |
19 | from torch.autograd import Variable
20 | from super_model import Network
21 | from copy import deepcopy
22 | from config import config
23 | from datasets import get_datasets, get_nas_search_loaders
24 |
25 | parser = argparse.ArgumentParser("cifar")
26 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
27 | parser.add_argument('--batch_size', type=int, default=64, help='batch size')
28 | parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
29 | parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
30 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
31 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
32 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
33 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
34 | parser.add_argument('--epochs', type=int, default=50, help='num of training epochs')
35 | parser.add_argument('--init_channels', type=int, default=16, help='num of init channels')
36 | parser.add_argument('--layers', type=int, default=8, help='total number of layers')
37 | parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
38 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
39 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
40 | parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability')
41 | parser.add_argument('--save', type=str, default='models', help='experiment name')
42 | parser.add_argument('--seed', type=int, default=1, help='random seed')
43 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
44 | parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data')
45 | parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss')
46 | parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
47 | parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
48 | parser.add_argument('--random_label', type=int, choices=[0, 1], default=1, help='whether to train with randomly corrupted labels (default: True)')
49 | parser.add_argument('--split_data', type=int, choices=[0, 1], default=1, help='whether to split the training data into new train & validation sets (default: True)')
50 | args = parser.parse_args()
51 |
52 | args.split_data = bool(args.split_data)
53 |
54 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
55 |
56 | log_format = '%(asctime)s %(message)s'
57 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
58 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
59 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
60 | fh.setFormatter(logging.Formatter(log_format))
61 | logging.getLogger().addHandler(fh)
62 |
63 | CIFAR_CLASSES = 100
64 |
65 | def main():
66 | if not torch.cuda.is_available():
67 | logging.info('no gpu device available')
68 | sys.exit(1)
69 |
70 | np.random.seed(args.seed)
71 | torch.cuda.set_device(args.gpu)
72 | cudnn.benchmark = True
73 | torch.manual_seed(args.seed)
74 | cudnn.enabled=True
75 | torch.cuda.manual_seed(args.seed)
76 | seed = args.seed
77 | logging.info('gpu device = %d' % args.gpu)
78 | logging.info("args = %s", args)
79 |
80 | criterion = nn.CrossEntropyLoss()
81 | criterion = criterion.cuda()
82 | # NOTE: layers: number of cells in network
83 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers)
84 | model = model.cuda()
85 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
86 |
87 | optimizer = torch.optim.SGD(
88 | model.parameters(),
89 | args.learning_rate,
90 | momentum=args.momentum,
91 | weight_decay=args.weight_decay)
92 |
93 | if args.split_data: # NOTE: split train data in half to be new train, val set. new train is used for supernet training, new val set is used for evaluation
94 | train_data, valid_data, xshape, class_num = get_datasets('cifar100', args.data, -1, args.seed, random_label=bool(args.random_label))
95 | train_queue, _, _, valid_queue = get_nas_search_loaders(train_data, valid_data, 'cifar100',
96 | 'datasets/configs/', \
97 | (args.batch_size, args.batch_size), 4, use_valid_no_shuffle=True)
98 | else:
99 | raise ValueError("only --split_data 1 is supported")
100 |
101 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
102 | optimizer, float(args.epochs), eta_min=args.learning_rate_min)
103 |
104 | operations = []
105 | for _ in range(config.edges):
106 | operations.append(list(range(config.op_num)))
107 | print('operations={}'.format(operations))
108 |
109 | utils.save_checkpoint({'epoch': -1,
110 | 'state_dict': model.state_dict(),
111 | 'optimizer': optimizer.state_dict()}, args.save)
112 |
113 | for epoch in range(args.epochs):
114 | scheduler.step()
115 | lr = scheduler.get_lr()[0]
116 | logging.info('epoch %d lr %e', epoch, lr)
117 |
118 | # training
119 | seed, train_acc, train_obj = train(train_queue, model, criterion, optimizer, operations, seed, epoch)
120 | logging.info('train_acc %f', train_acc)
121 |
122 | # validation
123 | valid_acc, valid_obj = infer(valid_queue, model, criterion, seed, operations)
124 | logging.info('valid_acc %f', valid_acc)
125 |
126 | if (epoch+1)%5 == 0:
127 | utils.save_checkpoint({'epoch':epoch,
128 | 'state_dict':model.state_dict(),
129 | 'optimizer':optimizer.state_dict()}, args.save)
130 |
131 | def get_random_cand(seed, operations):
132 | # Uniform Sampling
133 | rng = []
134 | for op in operations:
135 | np.random.seed(seed)
136 | k = np.random.randint(len(op))
137 | select_op = op[k]
138 | rng.append(select_op)
139 | seed += 1
140 |
141 | return rng, seed
142 |
143 | def train(train_queue, model, criterion, optimizer, operations, seed, epoch):
144 | objs = utils.AvgrageMeter()
145 | top1 = utils.AvgrageMeter()
146 | top5 = utils.AvgrageMeter()
147 |
148 | model.train()
149 |
150 | for step, batch in enumerate(train_queue):
151 | if len(batch) == 4:
152 | input, target, _, _ = batch
153 | elif len(batch) == 2:
154 | input, target = batch
155 | n = input.size(0)
156 |
157 | input = input.cuda(non_blocking=True)
158 | target = target.cuda(non_blocking=True)
159 |
160 | optimizer.zero_grad()
161 |
162 | # NOTE: at every training iteration, the operation on each edge is sampled uniformly at random (as in SPOS; there are no architecture parameters)
163 | normal_rng, seed = get_random_cand(seed, operations)
164 | reduction_rng, seed = get_random_cand(seed, operations)
165 |
166 | normal_rng = utils.check_cand(normal_rng, operations)
167 | reduction_rng = utils.check_cand(reduction_rng, operations)
168 |
169 | logits = model(input, normal_rng, reduction_rng)
170 | loss = criterion(logits, target)
171 |
172 | loss.backward()
173 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
174 | optimizer.step()
175 |
176 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
177 | objs.update(loss.item(), n)
178 | top1.update(prec1.item(), n)
179 | top5.update(prec5.item(), n)
180 |
181 | if step % args.report_freq == 0:
182 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
183 |
184 | return seed, top1.avg, objs.avg
185 |
186 |
187 | def infer(valid_queue, model, criterion, seed, operations):
188 | objs = utils.AvgrageMeter()
189 | top1 = utils.AvgrageMeter()
190 | top5 = utils.AvgrageMeter()
191 | model.eval()
192 |
193 | normal_rng, seed = get_random_cand(seed, operations)
194 | reduction_rng, seed = get_random_cand(seed, operations)
195 |
196 | normal_rng = utils.check_cand(normal_rng, operations)
197 | reduction_rng = utils.check_cand(reduction_rng, operations)
198 |
199 | # NOTE: there are no architecture parameters to optimize (the supernet has none);
200 | # NOTE: instead, an operation is randomly selected for each edge and evaluated.
201 | for step, (input, target) in enumerate(valid_queue):
202 | input = input.cuda(non_blocking=True)
203 | target = target.cuda(non_blocking=True)
204 |
205 | with torch.no_grad():
206 | logits = model(input, normal_rng, reduction_rng)
207 | loss = criterion(logits, target)
208 |
209 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
210 | n = input.size(0)
211 | objs.update(loss.item(), n)
212 | top1.update(prec1.item(), n)
213 | top5.update(prec5.item(), n)
214 |
215 | if step % args.report_freq == 0:
216 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
217 |
218 | return top1.avg, objs.avg
219 |
220 |
221 | if __name__ == '__main__':
222 | main()
223 |
--------------------------------------------------------------------------------
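`get_random_cand` reseeds NumPy before every single draw and hands the advanced seed back to the caller, so the sequence of sampled architectures is fully reproducible and stays synchronized as `train` and `infer` thread the seed through the epochs. The function is copied (condensed) below for a standalone illustration of that property:

```python
import numpy as np

def get_random_cand(seed, operations):  # condensed copy of the train.py function
    rng = []
    for op in operations:
        np.random.seed(seed)
        rng.append(op[np.random.randint(len(op))])
        seed += 1
    return rng, seed

operations = [list(range(8)) for _ in range(14)]  # 8 candidate ops on each of 14 edges
cand_a, next_seed = get_random_cand(0, operations)
cand_b, _ = get_random_cand(0, operations)
assert cand_a == cand_b  # same seed -> same architecture
assert next_seed == 14   # one seed increment per edge
```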
/train_supernet/utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/utils.py
3 | '''
4 | import os
5 | import sys
6 | import shutil
7 | import numpy as np
8 | import time, datetime
9 | import torch
10 | import glob
11 | import random
12 | import logging
13 | import argparse
14 | import torch.nn as nn
15 | import torch.utils
16 | import torchvision.datasets as dset
17 | import torchvision.transforms as transforms
18 | import torch.backends.cudnn as cudnn
19 | from torch.autograd import Variable
20 | import joblib
21 | import pdb
22 | import pickle
23 | from collections import defaultdict
24 | from config import config
25 | import copy
26 |
27 | def broadcast(args, obj, src, group=torch.distributed.group.WORLD, async_op=False):
28 | print('local_rank:{}, obj:{}'.format(args.local_rank, obj))
29 | obj_tensor = torch.from_numpy(np.array(obj)).cuda()
30 | torch.distributed.broadcast(obj_tensor, src, group, async_op)
31 | obj = obj_tensor.cpu().numpy()
32 | print('local_rank:{}, tensor:{}'.format(args.local_rank, obj))
33 | return obj
34 |
35 | class CrossEntropyLabelSmooth(nn.Module):
36 |
37 | def __init__(self, num_classes, epsilon):
38 | super(CrossEntropyLabelSmooth, self).__init__()
39 | self.num_classes = num_classes
40 | self.epsilon = epsilon
41 | self.logsoftmax = nn.LogSoftmax(dim=1)
42 |
43 | def forward(self, inputs, targets):
44 | log_probs = self.logsoftmax(inputs)
45 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
46 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
47 | loss = (-targets * log_probs).mean(0).sum()
48 | return loss
49 |
50 | def get_optimizer_schedule(model, args, total_iters):
51 | all_parameters = model.parameters()
52 | weight_parameters = []
53 | for pname, p in model.named_parameters():
54 | if p.ndimension() == 4 or 'classifier.0.weight' in pname or 'classifier.0.bias' in pname:
55 | weight_parameters.append(p)
56 | weight_parameters_id = list(map(id, weight_parameters))
57 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters))
58 | optimizer = torch.optim.SGD(
59 | [{'params' : other_parameters},
60 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}],
61 | args.learning_rate,
62 | momentum=args.momentum,
63 | )
64 |
65 | delta_iters = total_iters / (1.-args.min_lr / args.learning_rate)
66 | print('delta_iters={}'.format(delta_iters))
67 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/delta_iters), last_epoch=-1)
68 | return optimizer, scheduler
69 |
70 | def get_location(s, key):
71 | d = defaultdict(list)
72 | for k,va in [(v,i) for i,v in enumerate(s)]:
73 | d[k].append(va)
74 | return d[key]
75 |
76 | def list_substract(list1, list2):
77 | list1 = [item for item in list1 if item not in set(list2)]
78 | return list1
79 |
80 | def check_cand(cand, operations):
81 | cand = np.reshape(cand, [-1, config.edges])
82 | offset, cell_cand = 0, cand[0]
83 | for j in range(4): # 4 intermediate nodes; node j has j+2 incoming edges
84 | edges = cell_cand[offset:offset+j+2]
85 | edges_ops = operations[offset:offset+j+2]
86 | none_idxs = get_location(edges, 0) # edges currently assigned the 'none' op (index 0)
87 | if len(none_idxs) < j: # too few 'none' edges: zero out randomly chosen active edges
88 | general_idxs = list_substract(range(j+2), none_idxs)
89 | num = min(j-len(none_idxs), len(general_idxs))
90 | general_idxs = np.random.choice(general_idxs, size=num, replace=False, p=None)
91 | for k in general_idxs:
92 | edges[k] = 0
93 | elif len(none_idxs) > j: # too many 'none' edges: reassign the surplus to random non-'none' ops
94 | none_idxs = np.random.choice(none_idxs, size=len(none_idxs)-j, replace=False, p=None)
95 | for k in none_idxs:
96 | if len(edges_ops[k]) > 1:
97 | l = np.random.randint(len(edges_ops[k])-1)
98 | edges[k] = edges_ops[k][l+1]
99 | offset += len(edges)
100 |
101 | return cell_cand.tolist() # each node now keeps exactly 2 active (non-'none') incoming edges
102 |
103 | class AvgrageMeter(object):
104 |
105 | def __init__(self):
106 | self.reset()
107 |
108 | def reset(self):
109 | self.avg = 0
110 | self.sum = 0
111 | self.cnt = 0
112 |
113 | def update(self, val, n=1):
114 | self.sum += val * n
115 | self.cnt += n
116 | self.avg = self.sum / self.cnt
117 |
118 | def accuracy(output, target, topk=(1,)):
119 | maxk = max(topk)
120 | batch_size = target.size(0)
121 |
122 | _, pred = output.topk(maxk, 1, True, True)
123 | pred = pred.t()
124 | correct = pred.eq(target.view(1, -1).expand_as(pred))
125 |
126 | res = []
127 | for k in topk:
128 | # correct_k = correct[:k].view(-1).float().sum(0)
129 | # for pytorch >= 1.7.0
130 | correct_k = correct[:k].reshape(-1).float().sum(0)
131 | res.append(correct_k.mul_(100.0/batch_size))
132 | return res
133 |
134 | def save_checkpoint(state, save):
135 | if not os.path.exists(save):
136 | os.makedirs(save)
137 | filename = os.path.join(save, 'checkpoint_epoch_{}.pth.tar'.format(state['epoch']+1))
138 | torch.save(state, filename)
139 | print('Save CheckPoint....')
140 |
141 |
142 | def save(model, save, suffix):
143 | torch.save(model.module.state_dict(), save)
144 | shutil.copyfile(save, 'weight_{}.pt'.format(suffix))
145 |
146 | def create_exp_dir(path, scripts_to_save=None):
147 | if not os.path.exists(path):
148 | os.mkdir(path)
149 | print('Experiment dir : {}'.format(path))
150 |
151 | script_path = os.path.join(path, 'scripts')
152 | if scripts_to_save is not None and not os.path.exists(script_path):
153 | os.mkdir(script_path)
154 | for script in scripts_to_save:
155 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
156 | shutil.copyfile(script, dst_file)
157 |
158 | def merge_ops(rngs):
159 | cand = []
160 | for rng in rngs:
161 | for r in rng:
162 | cand.append(r)
163 | cand += [-1]
164 | cand = cand[:-1]
165 | return cand
166 |
167 | def split_ops(cand):
168 | cell, layer = 0, 0
169 | cand_ = [[]]
170 | for c in cand:
171 | if c == -1:
172 | cand_.append([])
173 | layer += 1
174 | else:
175 | cand_[layer].append(c)
176 | return cand_
177 |
178 | def get_search_space_size(operations):
179 | comb_num = 1
180 | for j in range(len(operations)):
181 | comb_num *= len(operations[j])
182 | return comb_num
183 |
184 | def count_parameters_in_MB(model):
185 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6
186 |
187 | class Cutout(object):
188 | def __init__(self, length):
189 | self.length = length
190 |
191 | def __call__(self, img):
192 | h, w = img.size(1), img.size(2)
193 | mask = np.ones((h, w), np.float32)
194 | y = np.random.randint(h)
195 | x = np.random.randint(w)
196 |
197 | y1 = np.clip(y - self.length // 2, 0, h)
198 | y2 = np.clip(y + self.length // 2, 0, h)
199 | x1 = np.clip(x - self.length // 2, 0, w)
200 | x2 = np.clip(x + self.length // 2, 0, w)
201 |
202 | mask[y1: y2, x1: x2] = 0.
203 | mask = torch.from_numpy(mask)
204 | mask = mask.expand_as(img)
205 | img *= mask
206 | return img
207 |
208 | def _data_transforms_cifar10(args):
209 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
210 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
211 |
212 | train_transform = transforms.Compose([
213 | transforms.RandomCrop(32, padding=4),
214 | transforms.RandomHorizontalFlip(),
215 | transforms.ToTensor(),
216 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
217 | ])
218 | if args.cutout:
219 | train_transform.transforms.append(Cutout(args.cutout_length))
220 |
221 | valid_transform = transforms.Compose([
222 | transforms.ToTensor(),
223 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
224 | ])
225 | return train_transform, valid_transform
226 |
227 | class Cifar10RandomLabels(dset.CIFAR10):
228 | """CIFAR10 dataset, with support for randomly corrupt labels.
229 | Params
230 | ------
231 | rand_seed: int
232 | Default 0. numpy random seed.
233 | num_classes: int
234 | Default 10. The number of classes in the dataset.
235 | """
236 | def __init__(self, rand_seed=0, num_classes=10, **kwargs):
237 | super(Cifar10RandomLabels, self).__init__(**kwargs)
238 | self.n_classes = num_classes
239 | self.rand_seed = rand_seed
240 | self.random_labels()
241 |
242 | def random_labels(self):
243 | labels = np.array(self.targets)
244 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed))
245 | np.random.seed(self.rand_seed)
246 | rnd_labels = np.random.randint(0, self.n_classes, len(labels))
247 | # we need to explicitly cast the labels from np.int64 to the
248 | # builtin int type, otherwise PyTorch will fail...
249 | labels = [int(x) for x in rnd_labels]
250 |
251 | self.targets = labels
--------------------------------------------------------------------------------
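`check_cand` (defined above) repairs a raw sample so that each of the four intermediate nodes keeps exactly two active (non-'none') incoming edges, matching the DARTS cell definition. A quick check of that invariant, assuming `config.edges == 14` and running from `train_supernet/` so `config` resolves:

```python
import numpy as np
from utils import check_cand

np.random.seed(0)
operations = [list(range(8)) for _ in range(14)]
cand = check_cand([np.random.randint(8) for _ in range(14)], operations)

offset = 0
for j in range(4):  # node j has j+2 incoming edges, exactly 2 of them active
    edges = cand[offset:offset + j + 2]
    assert sum(e != 0 for e in edges) == 2
    offset += j + 2
print(cand)
```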