├── LICENSE ├── NOTICE ├── README.md ├── evolution_search ├── cand_evaluator.py ├── config.py ├── datasets │ ├── DownsampledImageNet.py │ ├── SearchDatasetWrap.py │ ├── __init__.py │ ├── config_utils │ │ └── __init__.py │ ├── configs │ │ ├── cifar-split.txt │ │ ├── cifar100-split.txt │ │ ├── cifar100-test-split.txt │ │ └── imagenet-16-120-test-split.txt │ ├── get_dataset_with_transform.py │ └── test_utils.py ├── genotypes.py ├── metrics │ ├── tester_acc.py │ └── tester_wlm.py ├── operations.py ├── search.py ├── super_model.py └── utils.py ├── repo_figures ├── ABS_FBS_architecture_normal.png ├── ABS_FBS_architecture_reduce.png ├── ABS_architecture_normal.png ├── ABS_architecture_reduce.png ├── FBS_architecture_normal.png ├── FBS_architecture_reduce.png └── motivation.png ├── requirements.txt ├── retrain_architecture ├── config.py ├── genotypes.py ├── model.py ├── operations.py ├── retrain.py ├── thop │ ├── __init__.py │ ├── count_hooks.py │ ├── profile.py │ └── utils.py ├── utils.py └── visualize.py └── train_supernet ├── config.py ├── datasets ├── DownsampledImageNet.py ├── SearchDatasetWrap.py ├── __init__.py ├── config_utils │ └── __init__.py ├── configs │ ├── cifar-split.txt │ ├── cifar100-split.txt │ ├── cifar100-test-split.txt │ └── imagenet-16-120-test-split.txt ├── get_dataset_with_transform.py └── test_utils.py ├── genotypes.py ├── operations.py ├── super_model.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | GeNAS 2 | Copyright (c) 2023-present NAVER Cloud Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | GeNAS 2 | Copyright (c) 2023-present NAVER Cloud Corp. 3 | 4 | Permission is hereby granted, free of charge, to any person obtaining a copy 5 | of this software and associated documentation files (the "Software"), to deal 6 | in the Software without restriction, including without limitation the rights 7 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 8 | copies of the Software, and to permit persons to whom the Software is 9 | furnished to do so, subject to the following conditions: 10 | 11 | The above copyright notice and this permission notice shall be included in 12 | all copies or substantial portions of the Software. 
13 | 14 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 15 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 16 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 17 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 18 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 19 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 20 | THE SOFTWARE. 21 | 22 | -------------------------------------------------------------------------------------- 23 | 24 | This project contains subcomponents with separate copyright notices and license terms. 25 | Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses. 26 | 27 | ===== 28 | 29 | D-X-Y/AutoDL-Projects 30 | https://github.com/D-X-Y/AutoDL-Projects 31 | 32 | 33 | MIT License 34 | 35 | Copyright (c) since 2019.01.01, author: Xuanyi Dong (GitHub: https://github.com/D-X-Y) 36 | 37 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 38 | 39 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 40 | 41 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 42 | 43 | ===== 44 | 45 | megvii-model/RLNAS 46 | https://github.com/megvii-model/RLNAS 47 | 48 | 49 | MIT License 50 | 51 | Copyright (c) 2021 megvii-model 52 | 53 | Permission is hereby granted, free of charge, to any person obtaining a copy 54 | of this software and associated documentation files (the "Software"), to deal 55 | in the Software without restriction, including without limitation the rights 56 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 57 | copies of the Software, and to permit persons to whom the Software is 58 | furnished to do so, subject to the following conditions: 59 | 60 | The above copyright notice and this permission notice shall be included in all 61 | copies or substantial portions of the Software. 62 | 63 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 64 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 65 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 66 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 67 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 68 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 69 | SOFTWARE. 
70 | 71 | ===== 72 | 73 | ECP-CANDLE/Benchmarks 74 | https://github.com/ECP-CANDLE/Benchmarks 75 | 76 | 77 | MIT License 78 | 79 | Copyright (c) 2016 - 2017 ECP-CANDLE 80 | 81 | Permission is hereby granted, free of charge, to any person obtaining a copy 82 | of this software and associated documentation files (the "Software"), to deal 83 | in the Software without restriction, including without limitation the rights 84 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 85 | copies of the Software, and to permit persons to whom the Software is 86 | furnished to do so, subject to the following conditions: 87 | 88 | The above copyright notice and this permission notice shall be included in all 89 | copies or substantial portions of the Software. 90 | 91 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 92 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 93 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 94 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 95 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 96 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 97 | SOFTWARE. 98 | 99 | ===== 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GeNAS (IJCAI 2023) 2 | 3 | **GeNAS: Neural Architecture Search with Better Generalization**
4 | 5 | [Joonhyun Jeong](https://bestdeveloper691.github.io/)<sup>1,2</sup>, [Joonsang Yu](https://scholar.google.co.kr/citations?user=IC6M7_IAAAAJ&hl=ko)<sup>1,3</sup>, [Geondo Park](https://scholar.google.com/citations?user=Z8SGJ60AAAAJ&hl=ko)<sup>2</sup>, [Dongyoon Han](https://scholar.google.com/citations?user=jcP7m1QAAAAJ&hl=en)<sup>3</sup>, [YoungJoon Yoo](https://yjyoo3312.github.io/)<sup>1</sup>
6 | 7 | <sup>1</sup> NAVER Cloud, ImageVision
8 | <sup>2</sup> KAIST
9 | <sup>3</sup> NAVER AI Lab
10 | 11 | [![](https://img.shields.io/badge/IJCAI-2023-blue)](https://ijcai-23.org/) 12 | [![Paper](https://img.shields.io/badge/Paper-arxiv.2305.08611-red)](https://arxiv.org/abs/2305.08611) 13 | 14 | ## Introduction 15 | 16 | Neural Architecture Search (NAS) aims to automatically excavate the optimal network architecture with superior test performance. Recent NAS approaches rely on validation loss or accuracy to find the superior network for the target data. In this paper, we investigate a new neural architecture search measure for excavating architectures with better generalization. We demonstrate that the flatness of the loss surface can be a promising proxy for predicting the generalization capability of neural network architectures. We evaluate our proposed method on various search spaces, showing similar or even better performance compared to the state-of-the-art NAS methods. Notably, the resultant architecture found by the flatness measure generalizes robustly to various shifts in data distribution (e.g., ImageNet-V2, -A, -O), as well as to various tasks such as object detection and semantic segmentation. 17 | 18 | 19 | 20 | 21 | ## Updates 22 | **_2023-08-09_** We release the official implementation of GeNAS. 23 | 24 | ## Requirements 25 | 26 | * PyTorch 1.7.1 27 | 28 | Please see [requirements](./requirements.txt) for detailed specs. 29 | 30 | ## Quick Start 31 | 32 | 1. Train the SuperNet, following [SPOS](https://github.com/megvii-model/SinglePathOneShot). 33 | 34 | ```bash 35 | cd train_supernet 36 | python3 train.py \ 37 | --seed 1 \ 38 | --data [CIFAR_DATASET_DIRECTORY] \ 39 | --epochs 250 \ 40 | --save [OUTPUT_DIRECTORY] \ 41 | --random_label 0 \ 42 | --split_data 1 43 | ``` 44 | 45 | 2. Evolutionary Search 46 | 47 | - You can skip step 1 and use [the pretrained SuperNet checkpoints](https://drive.google.com/drive/folders/19TAHE5C66n1PCLaAjcemGfmkIJQNkVKj?usp=sharing). 48 | 49 | ### Searching with Flatness 50 | 51 | ```bash 52 | cd evolution_search 53 | python3 search.py \ 54 | --split_data 1 \ 55 | --seed 3 \ 56 | --init_model_path [SUPERNET_WEIGHT@INITIAL_EPOCH] \ 57 | --model_path [SUPERNET_WEIGHT@FINAL_EPOCH] \ 58 | --data [CIFAR_DATASET_DIRECTORY] \ 59 | --metric wlm \ 60 | --stds 0.001,0.003,0.006 \ 61 | --max_train_img_size 850 \ 62 | --max_val_img_size 25000 \ 63 | --wlm_weight 0 \ 64 | --acc_weight 0 65 | ``` 66 | 67 | ### Searching with Angle + Flatness 68 | 69 | ```bash 70 | python3 search.py \ 71 | --split_data 1 \ 72 | --seed 3 \ 73 | --init_model_path [SUPERNET_WEIGHT@INITIAL_EPOCH] \ 74 | --model_path [SUPERNET_WEIGHT@FINAL_EPOCH] \ 75 | --data [CIFAR_DATASET_DIRECTORY] \ 76 | --metric angle+wlm \ 77 | --stds 0.001,0.003,0.006 \ 78 | --max_train_img_size 850 \ 79 | --max_val_img_size 25000 \ 80 | --wlm_weight 16 \ 81 | --acc_weight 0 82 | ``` 83 | 84 | 3. Re-training on ImageNet 85 | 86 | - We used 8 V100 GPUs for re-training on ImageNet.
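
For reference, the flatness metric used in step 2 (`--metric wlm`) perturbs the supernet weights with Gaussian noise at the increasing standard deviations given by `--stds`, measures the validation loss after each perturbation, and scores a candidate by the inverse of the accumulated loss slope, so flatter minima score higher. Below is a minimal sketch of the scoring rule implemented in `evolution_search/metrics/tester_wlm.py`; the function name `wlm_score` and the loss values are illustrative only:

```python
def wlm_score(losses_per_stds, stds, eps=1e-5):
    # Accumulated loss slope across perturbation levels;
    # the non-perturbed loss is treated as 0.
    poor_minima = abs(losses_per_stds[0] / stds[0])
    for i in range(len(losses_per_stds) - 1):
        poor_minima += abs((losses_per_stds[i + 1] - losses_per_stds[i])
                           / (stds[i + 1] - stds[i]))
    return 1 / (poor_minima + eps)  # wide & flat minima measure

# e.g., with --stds 0.001,0.003,0.006 and hypothetical perturbed losses:
print(wlm_score([0.02, 0.05, 0.11], [0.001, 0.003, 0.006]))  # higher = flatter
```
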
87 | 88 | ### searched on CIFAR-100 with flatness 89 | ```bash 90 | python3 retrain.py \ 91 | --data_root [IMAGENET_DATA_DIRECTORY] \ 92 | --auxiliary \ 93 | --arch=GENAS_FLATNESS_CIFAR100 \ 94 | --init_channels 46 95 | ``` 96 | 97 | ### searched on CIFAR-100 with angle + flatness 98 | ```bash 99 | python3 retrain.py \ 100 | --data_root [IMAGENET_DATA_DIRECTORY] \ 101 | --auxiliary \ 102 | --arch=GENAS_ANGLE_FLATNESS_CIFAR100 \ 103 | --init_channels 48 104 | ``` 105 | 106 | ### searched on CIFAR-10 with flatness 107 | ```bash 108 | python3 retrain.py \ 109 | --data_root [IMAGENET_DATA_DIRECTORY] \ 110 | --auxiliary \ 111 | --arch=GENAS_FLATNESS_CIFAR10 \ 112 | --init_channels 52 113 | ``` 114 | 115 | ### searched on CIFAR-10 with angle + flatness 116 | ```bash 117 | python3 retrain.py \ 118 | --data_root [IMAGENET_DATA_DIRECTORY] \ 119 | --auxiliary \ 120 | --arch=GENAS_ANGLE_FLATNESS_CIFAR10 \ 121 | --init_channels 44 122 | ``` 123 | 124 | ## Model Zoo 125 | 126 | | Search Dataset | Search Metric | Params (M) | FLOPs (G) | ImageNet Top-1 Acc (%) | Weight | 127 | | :--------: | :----------------: | :-----------------: | :--------------: | :------: | :------: | 128 | CIFAR-10 | Angle | 5.3 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1J_xyxU3ZbuDDr1ASEjdUIkjnrf5rNqB_/view?usp=sharing) 129 | CIFAR-10 | Accuracy | 5.4 | 0.6 | 75.3 | [ckpt](https://drive.google.com/file/d/1jo76ZhbqJt11qls3q2rMVsUcfzkQWp1Q/view?usp=sharing) 130 | CIFAR-10 | Flatness | 5.6 | 0.6 | 76.0 | [ckpt](https://drive.google.com/file/d/1VamhvAUSi2XZVE0Vn4Lxxp1S_dqODTil/view?usp=sharing) 131 | CIFAR-10 | Angle + Flatness | 5.3 | 0.6 | 76.1 | [ckpt](https://drive.google.com/file/d/1p2PSkt5ZyFY2NLGgU45Ilr5NIaXNizW9/view?usp=sharing) 132 | CIFAR-10 | Accuracy + Flatness | 5.6 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1QBEyY-vFYpGOlwRSsTFxMM8GBtY3F8k7/view?usp=sharing) 133 | | | | | | 134 | CIFAR-100 | Angle | 5.4 | 0.6 | 75.0 | [ckpt](https://drive.google.com/file/d/1CmpkPsWNWVdbDbcmyuVfp38A7lB2MWoC/view?usp=sharing) 135 | CIFAR-100 | Accuracy | 5.4 | 0.6 | 75.4 | [ckpt](https://drive.google.com/file/d/1TWzs-upwnAgOvF0HjKDSwC3TeDRstl4C/view?usp=sharing) 136 | CIFAR-100 | Flatness | 5.2 | 0.6 | 76.1 | [ckpt](https://drive.google.com/file/d/1YLcZNpTytP9XTYDYoQ_nv8gQHnRAFe67/view?usp=sharing) 137 | CIFAR-100 | Angle + Flatness | 5.4 | 0.6 | 75.7 | [ckpt](https://drive.google.com/file/d/1reRbr4cFeoL8fwOTPQjQTg7w_QZAAIm4/view?usp=sharing) 138 | CIFAR-100 | Accuracy + Flatness | 5.4 | 0.6 | 75.9 | [ckpt](https://drive.google.com/file/d/1-GVpP7yUWc7W6Qf8dM_FN3fTo0acO0AI/view?usp=sharing) 139 | 140 | ### Architecture Visualization 141 | 142 | #### angle-based searching 143 | 144 | - normal cell 145 | 146 | 147 | 148 | - reduce cell 149 | 150 | 151 | 152 | 153 | #### angle+flatness based searching 154 | 155 | - normal cell 156 | 157 | 158 | 159 | - reduce cell 160 | 161 | 162 | 163 | #### flatness-based searching 164 | 165 | - normal cell 166 | 167 | 168 | 169 | - reduce cell 170 | 171 | 172 | 173 | ## Citation 174 | If you find that this project helps your research, please consider citing as below: 175 | 176 | ``` 177 | @article{jeong2023genas, 178 | title={GeNAS: Neural Architecture Search with Better Generalization}, 179 | author={Jeong, Joonhyun and Yu, Joonsang and Park, Geondo and Han, Dongyoon and Yoo, Youngjoon}, 180 | journal={arXiv preprint arXiv:2305.08611}, 181 | year={2023} 182 | } 183 | ``` 184 | 185 | ## License 186 | ``` 187 | GeNAS 188 | Copyright (c) 2023-present NAVER Cloud Corp. 
189 | 190 | Permission is hereby granted, free of charge, to any person obtaining a copy 191 | of this software and associated documentation files (the "Software"), to deal 192 | in the Software without restriction, including without limitation the rights 193 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 194 | copies of the Software, and to permit persons to whom the Software is 195 | furnished to do so, subject to the following conditions: 196 | 197 | The above copyright notice and this permission notice shall be included in 198 | all copies or substantial portions of the Software. 199 | 200 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 201 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 202 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 203 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 204 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 205 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 206 | THE SOFTWARE. 207 | ``` 208 | -------------------------------------------------------------------------------- /evolution_search/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/config.py 3 | """ 4 | 5 | import os 6 | class config: 7 | host = '127.0.0.1' 8 | 9 | username = 'test' 10 | port = 5672 11 | 12 | exp_name = os.path.dirname(os.path.abspath(__file__)) 13 | exp_name = '-'.join(i for i in exp_name.split(os.path.sep) if i); 14 | 15 | test_send_pipe = exp_name + '-test-send_pipe' 16 | test_recv_pipe = exp_name + '-test-recv_pipe' 17 | 18 | net_cache = 'model_and_data/checkpoint_epoch_250.pth.tar' 19 | initial_net_cache = 'model_and_data/checkpoint_epoch_0.pth.tar' 20 | 21 | 22 | layers = 8 23 | edges = 14 24 | model_input_size_imagenet = (1, 3, 224, 224) 25 | 26 | # Candidate operators 27 | blocks_keys = [ 28 | 'none', 29 | 'max_pool_3x3', 30 | 'avg_pool_3x3', 31 | 'skip_connect', 32 | 'sep_conv_3x3', 33 | 'sep_conv_5x5', 34 | 'dil_conv_3x3', 35 | 'dil_conv_5x5' 36 | ] 37 | op_num = len(blocks_keys) 38 | 39 | # Operators encoding 40 | NONE = 0 41 | MAX_POOLING_3x3 = 1 42 | AVG_POOL_3x3 = 2 43 | SKIP_CONNECT = 3 44 | SEP_CONV_3x3 = 4 45 | SEP_CONV_5x5 = 5 46 | DIL_CONV_3x3 = 6 47 | DIL_CONV_5x5 = 7 48 | 49 | time_limit=None 50 | #time_limit=0.050 51 | speed_input_shape=[32,3,224,224] 52 | 53 | flops_limit=None 54 | 55 | max_epochs=20 56 | select_num = 10 57 | population_num = 50 58 | mutation_num = 25 59 | m_prob = 0.1 60 | crossover_num = 25 61 | 62 | 63 | momentum = 0.7 64 | eps = 1e-5 65 | 66 | # Enumerate all paths of a single cell 67 | # paths = [[0, 2, 3, 4, 5], [0, 2, 3, 5], [0, 2, 4, 5], [0, 2, 5], [0, 3, 4, 5], [0, 3, 5], [0, 4, 5], [0, 5], 68 | # [1, 2, 3, 4, 5], [1, 2, 3, 5], [1, 2, 4, 5], [1, 2, 5], [1, 3, 4, 5], [1, 3, 5], [1, 4, 5], [1, 5]] 69 | # Enumerate all paths of a single cell 70 | paths = [[0, 2, 3, 4, 5], [0, 2, 3, 5], [0, 2, 4, 5], [0, 2, 5], [0, 3, 4, 5], [0, 3, 5], [0, 4, 5], [0, 5], 71 | [1, 2, 3, 4, 5], [1, 2, 3, 5], [1, 2, 4, 5], [1, 2, 5], [1, 3, 4, 5], [1, 3, 5], [1, 4, 5], [1, 5], 72 | [0, 2, 3, 4], [0, 2, 4], [0, 3, 4], [0, 4], 73 | [1, 2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 4], 74 | [0, 2, 3], [0, 3], 75 | [1, 2, 3], [1, 3], 76 | [0, 2], 77 | [1, 2]] 78 | 79 | for i in ['exp_name']: 80 | 
print('{}: {}'.format(i,eval('config.{}'.format(i)))) 81 | -------------------------------------------------------------------------------- /evolution_search/datasets/DownsampledImageNet.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import os, sys, hashlib, torch 5 | import numpy as np 6 | from PIL import Image 7 | import torch.utils.data as data 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | import pdb 13 | 14 | def calculate_md5(fpath, chunk_size=1024 * 1024): 15 | md5 = hashlib.md5() 16 | with open(fpath, 'rb') as f: 17 | for chunk in iter(lambda: f.read(chunk_size), b''): 18 | md5.update(chunk) 19 | return md5.hexdigest() 20 | 21 | 22 | def check_md5(fpath, md5, **kwargs): 23 | return md5 == calculate_md5(fpath, **kwargs) 24 | 25 | 26 | def check_integrity(fpath, md5=None): 27 | if not os.path.isfile(fpath): return False 28 | if md5 is None: return True 29 | else : return check_md5(fpath, md5) 30 | 31 | 32 | class ImageNet16(data.Dataset): 33 | # http://image-net.org/download-images 34 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets 35 | # https://arxiv.org/pdf/1707.08819.pdf 36 | 37 | train_list = [ 38 | ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], 39 | ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], 40 | ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], 41 | ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], 42 | ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], 43 | ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], 44 | ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], 45 | ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], 46 | ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], 47 | ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], 48 | ] 49 | valid_list = [ 50 | ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], 51 | ] 52 | 53 | def __init__(self, root, train, transform, use_num_of_class_only=None): 54 | self.root = root 55 | self.transform = transform 56 | self.train = train # training set or valid set 57 | if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') 58 | 59 | if self.train: downloaded_list = self.train_list 60 | else : downloaded_list = self.valid_list 61 | self.data = [] 62 | self.targets = [] 63 | 64 | # now load the picked numpy arrays 65 | for i, (file_name, checksum) in enumerate(downloaded_list): 66 | file_path = os.path.join(self.root, file_name) 67 | #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) 68 | with open(file_path, 'rb') as f: 69 | if sys.version_info[0] == 2: 70 | entry = pickle.load(f) 71 | else: 72 | entry = pickle.load(f, encoding='latin1') 73 | self.data.append(entry['data']) 74 | self.targets.extend(entry['labels']) 75 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) 76 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 77 | if use_num_of_class_only is not None: 78 | assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) 79 | new_data, new_targets = [], [] 80 | for I, L in zip(self.data, self.targets): 81 | if 1 <= L <= use_num_of_class_only: 82 | 
new_data.append( I ) 83 | new_targets.append( L ) 84 | self.data = new_data 85 | self.targets = new_targets 86 | # self.mean.append(entry['mean']) 87 | #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) 88 | #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) 89 | #print ('Mean : {:}'.format(self.mean)) 90 | #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3)) 91 | #std_data = np.std(temp, axis=0) 92 | #std_data = np.mean(np.mean(std_data, axis=0), axis=0) 93 | #print ('Std : {:}'.format(std_data)) 94 | 95 | def __getitem__(self, index): 96 | img, target = self.data[index], self.targets[index] - 1 97 | 98 | img = Image.fromarray(img) 99 | 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | 103 | return img, target 104 | 105 | def __len__(self): 106 | return len(self.data) 107 | 108 | def _check_integrity(self): 109 | root = self.root 110 | for fentry in (self.train_list + self.valid_list): 111 | filename, md5 = fentry[0], fentry[1] 112 | fpath = os.path.join(root, filename) 113 | if not check_integrity(fpath, md5): 114 | return False 115 | return True 116 | 117 | # 118 | if __name__ == '__main__': 119 | train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None) 120 | valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None) 121 | 122 | print ( len(train) ) 123 | print ( len(valid) ) 124 | image, label = train[111] 125 | trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) 126 | validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) 127 | print ( len(trainX) ) 128 | print ( len(validX) ) 129 | #import pdb; pdb.set_trace() 130 | -------------------------------------------------------------------------------- /evolution_search/datasets/SearchDatasetWrap.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import torch, copy, random 5 | import torch.utils.data as data 6 | 7 | 8 | class SearchDataset(data.Dataset): 9 | 10 | def __init__(self, name, data, train_split, valid_split, check=True): 11 | self.datasetname = name 12 | if isinstance(data, (list, tuple)): # new type of SearchDataset 13 | assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) 14 | self.train_data = data[0] 15 | self.valid_data = data[1] 16 | self.train_split = train_split.copy() 17 | self.valid_split = valid_split.copy() 18 | self.mode_str = 'V2' # new mode 19 | else: 20 | self.mode_str = 'V1' # old mode 21 | self.data = data 22 | self.train_split = train_split.copy() 23 | self.valid_split = valid_split.copy() 24 | if check: 25 | intersection = set(train_split).intersection(set(valid_split)) 26 | assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' 27 | self.length = len(self.train_split) 28 | 29 | def __repr__(self): 30 | return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) 31 | 32 | def __len__(self): 33 | return self.length 34 | 35 | def __getitem__(self, index): 36 | assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) 37 | train_index = self.train_split[index] 38 | valid_index = 
random.choice( self.valid_split ) 39 | if self.mode_str == 'V1': 40 | train_image, train_label = self.data[train_index] 41 | valid_image, valid_label = self.data[valid_index] 42 | elif self.mode_str == 'V2': 43 | train_image, train_label = self.train_data[train_index] 44 | valid_image, valid_label = self.valid_data[valid_index] 45 | else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) 46 | return train_image, train_label, valid_image, valid_label 47 | -------------------------------------------------------------------------------- /evolution_search/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders 5 | from .SearchDatasetWrap import SearchDataset 6 | -------------------------------------------------------------------------------- /evolution_search/datasets/config_utils/__init__.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import json 3 | from collections import namedtuple 4 | 5 | support_types = ('str', 'int', 'bool', 'float', 'none') 6 | 7 | def convert_param(original_lists): 8 | assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) 9 | ctype, value = original_lists[0], original_lists[1] 10 | assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) 11 | is_list = isinstance(value, list) 12 | if not is_list: value = [value] 13 | outs = [] 14 | for x in value: 15 | if ctype == 'int': 16 | x = int(x) 17 | elif ctype == 'str': 18 | x = str(x) 19 | elif ctype == 'bool': 20 | x = bool(int(x)) 21 | elif ctype == 'float': 22 | x = float(x) 23 | elif ctype == 'none': 24 | assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x) 25 | x = None 26 | else: 27 | raise TypeError('Does not know this type : {:}'.format(ctype)) 28 | outs.append(x) 29 | if not is_list: outs = outs[0] 30 | return outs 31 | 32 | def load_config(path, extra, logger): 33 | path = str(path) 34 | if hasattr(logger, 'log'): logger.log(path) 35 | assert os.path.exists(path), 'Can not find {:}'.format(path) 36 | # Reading data back 37 | with open(path, 'r') as f: 38 | data = json.load(f) 39 | content = { k: convert_param(v) for k,v in data.items()} 40 | assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) 41 | if isinstance(extra, dict): content = {**content, **extra} 42 | Arguments = namedtuple('Configure', ' '.join(content.keys())) 43 | content = Arguments(**content) 44 | if hasattr(logger, 'log'): logger.log('{:}'.format(content)) 45 | return content -------------------------------------------------------------------------------- /evolution_search/datasets/get_dataset_with_transform.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import os, sys, torch 5 | import os.path as osp 6 | import numpy as np 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as transforms 9 | from copy import deepcopy 10 | from PIL import Image 11 | 12 | from .DownsampledImageNet import ImageNet16 13 | from .SearchDatasetWrap 
import SearchDataset 14 | from .config_utils import load_config as load_dataset_config 15 | from torchvision.transforms import transforms 16 | from PIL import ImageFilter, ImageOps 17 | import random 18 | import torchvision.datasets as datasets 19 | 20 | Dataset2Class = {'cifar10' : 10, 21 | 'cifar100': 100, 22 | 'imagenet-1k-s':1000, 23 | 'imagenet-1k' : 1000, 24 | 'ImageNet16' : 1000, 25 | 'ImageNet16-150': 150, 26 | 'ImageNet16-120': 120, 27 | 'ImageNet16-200': 200} 28 | 29 | class CUTOUT(object): 30 | 31 | def __init__(self, length): 32 | self.length = length 33 | 34 | def __repr__(self): 35 | return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) 36 | 37 | def __call__(self, img): 38 | h, w = img.size(1), img.size(2) 39 | mask = np.ones((h, w), np.float32) 40 | y = np.random.randint(h) 41 | x = np.random.randint(w) 42 | 43 | y1 = np.clip(y - self.length // 2, 0, h) 44 | y2 = np.clip(y + self.length // 2, 0, h) 45 | x1 = np.clip(x - self.length // 2, 0, w) 46 | x2 = np.clip(x + self.length // 2, 0, w) 47 | 48 | mask[y1: y2, x1: x2] = 0. 49 | mask = torch.from_numpy(mask) 50 | mask = mask.expand_as(img) 51 | img *= mask 52 | return img 53 | 54 | 55 | imagenet_pca = { 56 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]), 57 | 'eigvec': np.asarray([ 58 | [-0.5675, 0.7192, 0.4009], 59 | [-0.5808, -0.0045, -0.8140], 60 | [-0.5836, -0.6948, 0.4203], 61 | ]) 62 | } 63 | 64 | 65 | class Lighting(object): 66 | def __init__(self, alphastd, 67 | eigval=imagenet_pca['eigval'], 68 | eigvec=imagenet_pca['eigvec']): 69 | self.alphastd = alphastd 70 | assert eigval.shape == (3,) 71 | assert eigvec.shape == (3, 3) 72 | self.eigval = eigval 73 | self.eigvec = eigvec 74 | 75 | def __call__(self, img): 76 | if self.alphastd == 0.: 77 | return img 78 | rnd = np.random.randn(3) * self.alphastd 79 | rnd = rnd.astype('float32') 80 | v = rnd 81 | old_dtype = np.asarray(img).dtype 82 | v = v * self.eigval 83 | v = v.reshape((3, 1)) 84 | inc = np.dot(self.eigvec, v).reshape((3,)) 85 | img = np.add(img, inc) 86 | if old_dtype == np.uint8: 87 | img = np.clip(img, 0, 255) 88 | img = Image.fromarray(img.astype(old_dtype), 'RGB') 89 | return img 90 | 91 | def __repr__(self): 92 | return self.__class__.__name__ + '()' 93 | 94 | 95 | class Cifar10RandomLabels(datasets.CIFAR10): 96 | """CIFAR10 dataset, with support for randomly corrupt labels. 97 | Params 98 | ------ 99 | rand_seed: int 100 | Default 0. numpy random seed. 101 | num_classes: int 102 | Default 10. The number of classes in the dataset. 103 | """ 104 | def __init__(self, rand_seed=0, num_classes=10, **kwargs): 105 | super(Cifar10RandomLabels, self).__init__(**kwargs) 106 | self.n_classes = num_classes 107 | self.rand_seed = rand_seed 108 | self.random_labels() 109 | 110 | def random_labels(self): 111 | labels = np.array(self.targets) 112 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 113 | np.random.seed(self.rand_seed) 114 | rnd_labels = np.random.randint(0, self.n_classes, len(labels)) 115 | # we need to explicitly cast the labels from npy.int64 to 116 | # builtin int type, otherwise pytorch will fail... 117 | labels = [int(x) for x in rnd_labels] 118 | 119 | self.targets = labels 120 | 121 | class Cifar100RandomLabels(datasets.CIFAR100): 122 | """CIFAR10 dataset, with support for randomly corrupt labels. 123 | Params 124 | ------ 125 | rand_seed: int 126 | Default 0. numpy random seed. 127 | num_classes: int 128 | Default 100. 
The number of classes in the dataset. 129 | """ 130 | def __init__(self, rand_seed=0, num_classes=100, **kwargs): 131 | super(Cifar100RandomLabels, self).__init__(**kwargs) 132 | self.n_classes = num_classes 133 | self.rand_seed = rand_seed 134 | self.random_labels() 135 | 136 | def random_labels(self): 137 | labels = np.array(self.targets) 138 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 139 | np.random.seed(self.rand_seed) 140 | rnd_labels = np.random.randint(0, self.n_classes, len(labels)) 141 | # we need to explicitly cast the labels from npy.int64 to 142 | # builtin int type, otherwise pytorch will fail... 143 | labels = [int(x) for x in rnd_labels] 144 | 145 | self.targets = labels 146 | 147 | class ImageNet16RandomLabels(ImageNet16): 148 | """CIFAR10 dataset, with support for randomly corrupt labels. 149 | Params 150 | ------ 151 | rand_seed: int 152 | Default 0. numpy random seed. 153 | num_classes: int 154 | Default 120. The number of classes in the dataset. 155 | """ 156 | def __init__(self, rand_seed=0, num_classes=120, **kwargs): 157 | super(ImageNet16RandomLabels, self).__init__(**kwargs) 158 | self.n_classes = num_classes 159 | self.rand_seed = rand_seed 160 | self.random_labels() 161 | # print('min_label:{}, max_label:{}'.format(min(self.targets), max(self.targets))) 162 | 163 | def random_labels(self): 164 | labels = np.array(self.targets) 165 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 166 | np.random.seed(self.rand_seed) 167 | rnd_labels = np.random.randint(1, self.n_classes+1, len(labels)) 168 | # we need to explicitly cast the labels from npy.int64 to 169 | # builtin int type, otherwise pytorch will fail... 
170 | labels = [int(x) for x in rnd_labels] 171 | 172 | self.targets = labels 173 | 174 | def get_datasets(name, root, cutout, rand_seed, byol_aug_type=None, random_label=True): 175 | 176 | if name == 'cifar10': 177 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 178 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 179 | elif name == 'cifar100': 180 | mean = [x / 255 for x in [129.3, 124.1, 112.4]] 181 | std = [x / 255 for x in [68.2, 65.4, 70.4]] 182 | elif name.startswith('imagenet-1k'): 183 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 184 | elif name.startswith('ImageNet16'): 185 | mean = [x / 255 for x in [122.68, 116.66, 104.01]] 186 | std = [x / 255 for x in [63.22, 61.26 , 65.09]] 187 | else: 188 | raise TypeError("Unknow dataset : {:}".format(name)) 189 | 190 | # Data Argumentation 191 | if name == 'cifar10' or name == 'cifar100': 192 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] 193 | if cutout > 0 : lists += [CUTOUT(cutout)] 194 | if byol_aug_type is None: 195 | train_transform = transforms.Compose(lists) 196 | elif byol_aug_type=='byol': 197 | online_aug = get_train_transform('BYOL_Tau', 32, mean, std) 198 | target_aug = get_train_transform('BYOL_Tau_Hat', 32, mean, std) 199 | train_transform = TwoImageAugmentations(online_aug, target_aug) 200 | else: 201 | raise NotImplementedError 202 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 203 | xshape = (1, 3, 32, 32) 204 | elif name.startswith('ImageNet16'): 205 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] 206 | if cutout > 0 : lists += [CUTOUT(cutout)] 207 | if byol_aug_type is None: 208 | train_transform = transforms.Compose(lists) 209 | elif byol_aug_type=='byol': 210 | online_aug = get_train_transform('BYOL_Tau', 16, mean, std) 211 | target_aug = get_train_transform('BYOL_Tau_Hat', 16, mean, std) 212 | train_transform = TwoImageAugmentations(online_aug, target_aug) 213 | else: 214 | raise NotImplementedError 215 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 216 | xshape = (1, 3, 16, 16) 217 | elif name == 'tiered': 218 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] 219 | if cutout > 0 : lists += [CUTOUT(cutout)] 220 | train_transform = transforms.Compose(lists) 221 | test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) 222 | xshape = (1, 3, 32, 32) 223 | elif name.startswith('imagenet-1k'): 224 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 225 | if name == 'imagenet-1k': 226 | xlists = [transforms.RandomResizedCrop(224)] 227 | xlists.append( 228 | transforms.ColorJitter( 229 | brightness=0.4, 230 | contrast=0.4, 231 | saturation=0.4, 232 | hue=0.2)) 233 | xlists.append( Lighting(0.1)) 234 | elif name == 'imagenet-1k-s': 235 | xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] 236 | else: raise ValueError('invalid name : {:}'.format(name)) 237 | xlists.append( transforms.RandomHorizontalFlip(p=0.5) ) 238 | xlists.append( transforms.ToTensor() ) 239 | xlists.append( normalize ) 240 | train_transform = transforms.Compose(xlists) 241 | test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), 
transforms.ToTensor(), normalize]) 242 | xshape = (1, 3, 224, 224) 243 | else: 244 | raise TypeError("Unknow dataset : {:}".format(name)) 245 | 246 | if name == 'cifar10': 247 | if random_label: 248 | train_data = Cifar10RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed) 249 | else: 250 | train_data = datasets.CIFAR10(root=root, train=True, transform=train_transform, download=True) 251 | test_data = datasets.CIFAR10(root=root, train=True , transform=test_transform, download=True) 252 | # test_data = datasets.CIFAR10(root=root, train=False, transform=test_transform , download=True) 253 | assert len(train_data) == 50000 and len(test_data) == 50000 254 | elif name == 'cifar100': 255 | if random_label: 256 | train_data = Cifar100RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed) 257 | else: 258 | train_data = datasets.CIFAR100(root=root, train=True , transform=train_transform, download=True) 259 | test_data = datasets.CIFAR100(root=root, train=True, transform=test_transform , download=True) 260 | assert len(train_data) == 50000 and len(test_data) == 50000 261 | elif name.startswith('imagenet-1k'): 262 | train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) 263 | test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform) 264 | assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000) 265 | elif name == 'ImageNet16': 266 | if random_label: 267 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, rand_seed=rand_seed) 268 | else: 269 | train_data = ImageNet16(root=root, train=True, transform=train_transform) 270 | test_data = ImageNet16(root=root, train=False, transform=test_transform) 271 | assert len(train_data) == 1281167 and len(test_data) == 50000 272 | elif name == 'ImageNet16-120': 273 | if random_label: 274 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=120, use_num_of_class_only=120, rand_seed=rand_seed) 275 | else: 276 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=120) 277 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=120) 278 | assert len(train_data) == 151700 and len(test_data) == 6000 279 | elif name == 'ImageNet16-150': 280 | if random_label: 281 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=150, use_num_of_class_only=150, rand_seed=rand_seed) 282 | else: 283 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=150) 284 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=150) 285 | assert len(train_data) == 190272 and len(test_data) == 7500 286 | elif name == 'ImageNet16-200': 287 | if random_label: 288 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, num_classes=200, use_num_of_class_only=200, rand_seed=rand_seed) 289 | else: 290 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=200) 291 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=200) 292 | assert len(train_data) == 254775 and len(test_data) == 10000 293 | else: raise TypeError("Unknow dataset : {:}".format(name)) 294 | 295 
| class_num = Dataset2Class[name] 296 | return train_data, test_data, xshape, class_num 297 | 298 | 299 | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, use_valid_no_shuffle=False): 300 | # NOTE: detailed dataset configuration is given in NAS-BENCH-201 paper, https://arxiv.org/pdf/2001.00326.pdf. 301 | if isinstance(batch_size, (list,tuple)): 302 | batch, test_batch = batch_size 303 | else: 304 | batch, test_batch = batch_size, batch_size 305 | if dataset == 'cifar10' or dataset == 'cifar100': 306 | #split_Fpath = 'configs/nas-benchmark/cifar-split.txt' 307 | if dataset == 'cifar10': 308 | cifar_split = load_dataset_config('{:}/cifar-split.txt'.format(config_root), None, None) 309 | elif dataset == 'cifar100': 310 | cifar_split = load_dataset_config('{:}/cifar100-split.txt'.format(config_root), None, None) 311 | train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set 312 | #logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set 313 | # To split data 314 | xvalid_data = valid_data 315 | search_data = SearchDataset(dataset, train_data, train_split, valid_split) 316 | # data loader 317 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) 318 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True) 319 | valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True) 320 | if use_valid_no_shuffle: 321 | # NOTE: using validation dataset 322 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(xvalid_data, valid_split), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True) 323 | # NOTE: using search training dataset 324 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True) 325 | elif dataset == 'ImageNet16-120': 326 | imagenet_test_split = load_dataset_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None) 327 | search_train_data = train_data 328 | search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform 329 | search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) 330 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) 331 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) 332 | valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True) 333 | if use_valid_no_shuffle: 334 | # NOTE: using validation dataset 335 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(valid_data, imagenet_test_split.xvalid), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True) 336 | # NOTE: using search training dataset 337 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, 
batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True) 338 | else: 339 | raise ValueError('invalid dataset : {:}'.format(dataset)) 340 | 341 | if use_valid_no_shuffle: 342 | return search_loader, train_loader, valid_loader, valid_loader_no_shuffle 343 | else: 344 | return search_loader, train_loader, valid_loader 345 | 346 | def get_train_transform(aug, image_size, mean, std): 347 | 348 | if aug == 'BYOL_Tau': 349 | transform = transforms.Compose([ 350 | transforms.RandomResizedCrop(image_size), 351 | transforms.RandomHorizontalFlip(), 352 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), 353 | transforms.RandomGrayscale(p=0.2), 354 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=1.0), 355 | transforms.RandomApply([Solarization(128)], p=0.0), 356 | transforms.ToTensor(), 357 | transforms.Normalize(mean, std), 358 | 359 | ]) 360 | elif aug == 'BYOL_Tau_Hat': 361 | transform = transforms.Compose([ 362 | transforms.RandomResizedCrop(image_size), 363 | transforms.RandomHorizontalFlip(), 364 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), 365 | transforms.RandomGrayscale(p=0.2), 366 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.1), 367 | transforms.RandomApply([Solarization(128)], p=0.2), 368 | transforms.ToTensor(), 369 | transforms.Normalize(mean, std), 370 | ]) 371 | else: 372 | raise NotImplementedError 373 | 374 | return transform 375 | 376 | 377 | class TwoImageAugmentations: 378 | def __init__(self, online_aug, target_aug): 379 | self.online_aug = online_aug 380 | self.target_aug = target_aug 381 | 382 | def __call__(self, x): 383 | online_image = self.online_aug(x) 384 | target_image = self.target_aug(x) 385 | return [online_image, target_image] 386 | 387 | class GaussianBlur(object): 388 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 389 | 390 | def __init__(self, sigma=[.1, 2.]): 391 | self.sigma = sigma 392 | 393 | def __call__(self, x): 394 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 395 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 396 | return x 397 | 398 | 399 | class Solarization(object): 400 | def __init__(self, magnitude=128): 401 | self.magnitude = magnitude 402 | 403 | def __call__(self, x): 404 | x = ImageOps.solarize(x, self.magnitude) 405 | return x 406 | 407 | if __name__ == '__main__': 408 | byol = True 409 | train_data, test_data, xshape, class_num = get_datasets('cifar10', '/home/zhangxuanyang/dataset/cifar.python/', -1, byol) 410 | search_loader, _, valid_loader = get_nas_search_loaders(train_data, test_data, 'cifar10', 'configs/nas-benchmark/', \ 411 | (3, 3), 4) 412 | for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): 413 | print(base_inputs) 414 | break 415 | 416 | # import pdb; pdb.set_trace() 417 | -------------------------------------------------------------------------------- /evolution_search/datasets/test_utils.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | def test_imagenet_data(imagenet): 5 | total_length = len(imagenet) 6 | assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) 7 | map_id = {} 8 | for index in range(total_length): 9 | path, target = imagenet.imgs[index] 10 | folder, image_name = 
os.path.split(path) 11 | _, folder = os.path.split(folder) 12 | if folder not in map_id: 13 | map_id[folder] = target 14 | else: 15 | assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) 16 | assert image_name.find(folder) == 0, '{} is wrong.'.format(path) 17 | print ('Check ImageNet Dataset OK') 18 | -------------------------------------------------------------------------------- /evolution_search/genotypes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/genotypes.py 3 | """ 4 | 5 | from collections import namedtuple 6 | 7 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 8 | 9 | PRIMITIVES = [ 10 | 'none', 11 | 'max_pool_3x3', 12 | 'avg_pool_3x3', 13 | 'skip_connect', 14 | 'sep_conv_3x3', 15 | 'sep_conv_5x5', 16 | 'dil_conv_3x3', 17 | 'dil_conv_5x5' 18 | ] 19 | 20 | NASNet = Genotype( 21 | normal = [ 22 | ('sep_conv_5x5', 1), 23 | ('sep_conv_3x3', 0), 24 | ('sep_conv_5x5', 0), 25 | ('sep_conv_3x3', 0), 26 | ('avg_pool_3x3', 1), 27 | ('skip_connect', 0), 28 | ('avg_pool_3x3', 0), 29 | ('avg_pool_3x3', 0), 30 | ('sep_conv_3x3', 1), 31 | ('skip_connect', 1), 32 | ], 33 | normal_concat = [2, 3, 4, 5, 6], 34 | reduce = [ 35 | ('sep_conv_5x5', 1), 36 | ('sep_conv_7x7', 0), 37 | ('max_pool_3x3', 1), 38 | ('sep_conv_7x7', 0), 39 | ('avg_pool_3x3', 1), 40 | ('sep_conv_5x5', 0), 41 | ('skip_connect', 3), 42 | ('avg_pool_3x3', 2), 43 | ('sep_conv_3x3', 2), 44 | ('max_pool_3x3', 1), 45 | ], 46 | reduce_concat = [4, 5, 6], 47 | ) 48 | 49 | AmoebaNet = Genotype( 50 | normal = [ 51 | ('avg_pool_3x3', 0), 52 | ('max_pool_3x3', 1), 53 | ('sep_conv_3x3', 0), 54 | ('sep_conv_5x5', 2), 55 | ('sep_conv_3x3', 0), 56 | ('avg_pool_3x3', 3), 57 | ('sep_conv_3x3', 1), 58 | ('skip_connect', 1), 59 | ('skip_connect', 0), 60 | ('avg_pool_3x3', 1), 61 | ], 62 | normal_concat = [4, 5, 6], 63 | reduce = [ 64 | ('avg_pool_3x3', 0), 65 | ('sep_conv_3x3', 1), 66 | ('max_pool_3x3', 0), 67 | ('sep_conv_7x7', 2), 68 | ('sep_conv_7x7', 0), 69 | ('avg_pool_3x3', 1), 70 | ('max_pool_3x3', 0), 71 | ('max_pool_3x3', 1), 72 | ('conv_7x1_1x7', 0), 73 | ('sep_conv_3x3', 5), 74 | ], 75 | reduce_concat = [3, 4, 6] 76 | ) 77 | 78 | DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5]) 79 | DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 80 | 81 | def parse_searched_cell(normal_reduce_cell): 82 | ''' 83 | normal_reduce_cell: list of normal + reduce cell. 
84 | e.g) [ 85 | 14 elements for normal cell edges where each element denote operation for each edge + 86 | 14 elements for reduce cell edges where each element denote operation for each edge 87 | ] 88 | ''' 89 | assert len(normal_reduce_cell) == 28, "cell should contain normal + reduce edges (14 + 14 = 28)" 90 | normal_cell = normal_reduce_cell[:14] 91 | reduce_cell = normal_reduce_cell[14:] 92 | 93 | normal_cell_decoded = [] 94 | reduce_cell_decoded = [] 95 | # normal cell decode 96 | for i in range(len(normal_cell)): 97 | # NOTE: for generating intermediate node 0 98 | if i in [0, 1]: 99 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i)) 100 | # NOTE: for generating intermediate node 1 101 | elif i in [2, 3, 4]: 102 | if normal_cell[i] != 0: 103 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 2)) 104 | # NOTE: for generating intermediate node 2 105 | elif i in [5, 6, 7, 8]: 106 | if normal_cell[i] != 0: 107 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 5)) 108 | # NOTE: for generating intermediate node 3 109 | elif i in [9, 10, 11, 12, 13]: 110 | if normal_cell[i] != 0: 111 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 9)) 112 | 113 | # reduce cell decode 114 | for i in range(len(reduce_cell)): 115 | # NOTE: for generating intermediate node 0 116 | if i in [0, 1]: 117 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i)) 118 | # NOTE: for generating intermediate node 1 119 | elif i in [2, 3, 4]: 120 | if reduce_cell[i] != 0: 121 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 2)) 122 | # NOTE: for generating intermediate node 2 123 | elif i in [5, 6, 7, 8]: 124 | if reduce_cell[i] != 0: 125 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 5)) 126 | # NOTE: for generating intermediate node 3 127 | elif i in [9, 10, 11, 12, 13]: 128 | if reduce_cell[i] != 0: 129 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 9)) 130 | 131 | return Genotype(normal=normal_cell_decoded, normal_concat=[2, 3, 4, 5], reduce=reduce_cell_decoded, reduce_concat=[2, 3, 4, 5]) 132 | 133 | RLDARTS = parse_searched_cell((5, 4, 5, 0, 5, 0, 0, 5, 5, 0, 0, 0, 7, 4, 5, 4, 2, 0, 5, 0, 0, 4, 4, 0, 4, 4, 0, 0)) 134 | RLDARTS_GT = parse_searched_cell((5, 5, 4, 5, 0, 0, 4, 0, 4, 0, 0, 0, 4, 4, 1, 3, 3, 3, 0, 3, 2, 0, 0, 0, 4, 0, 0, 6)) 135 | 136 | DARTS = DARTS_V2 137 | 138 | -------------------------------------------------------------------------------- /evolution_search/metrics/tester_acc.py: -------------------------------------------------------------------------------- 1 | """ 2 | GeNAS 3 | Copyright (c) 2023-present NAVER Cloud Corp. 
4 | MIT license 5 | """ 6 | import torch 7 | import math 8 | from config import config 9 | assert torch.cuda.is_available() 10 | 11 | def accuracy(output, target, topk=(1,)): 12 | maxk = max(topk) 13 | batch_size = target.size(0) 14 | 15 | _, pred = output.topk(maxk, 1, True, True) 16 | pred = pred.t() 17 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 18 | 19 | res = [] 20 | for k in topk: 21 | correct_k = correct[:k].reshape(-1).float().sum(0) 22 | res.append(correct_k.mul_(100.0 / batch_size)) 23 | return res 24 | 25 | 26 | def no_grad_wrapper(func): 27 | def new_func(*args, **kwargs): 28 | with torch.no_grad(): 29 | return func(*args, **kwargs) 30 | return new_func 31 | 32 | @no_grad_wrapper 33 | def get_cand_acc(model, genotype, train_dataloader, val_dataloader, max_train_img_size, max_val_img_size=25000): 34 | ''' 35 | genotype: normal (14 edges) + reduce cell (14 edges) with operation indices. e.g. 6, 6, 3, 0, 4, 7, 0, 0, 1, 0, 6, 0, 6, 0, 1, 4, 3, 7, 0, 7, 0, 0, 4, 0, 0, 0, 3, 2] 36 | train_dataloader: half (25K) of original training set (50K). 37 | val_dataloader: another half (25K) of original training set (50K). 38 | ''' 39 | # separate genotype 40 | normal_genotype = tuple(genotype[:config.edges]) 41 | reduce_genotype = tuple(genotype[config.edges:]) 42 | 43 | train_dataloader_iter = iter(train_dataloader) 44 | val_dataloader_iter = iter(val_dataloader) 45 | 46 | if torch.cuda.is_available(): 47 | device = torch.device('cuda') 48 | else: 49 | device = torch.device('cpu') 50 | 51 | # NOTE: # iterations of BN statistics re-tracking for search loader 52 | max_train_iters = math.ceil(max_train_img_size / train_dataloader.batch_size) 53 | 54 | # NOTE: # iterations of measure validation accuracy for all validation images 55 | max_test_iters = math.ceil(max_val_img_size / val_dataloader.batch_size) 56 | 57 | if max_train_iters > 0: 58 | # NOTE: [from SPOS paper] "Before the inference of an architecture, the statistics of all the Batch Normalization (BN) [9] operations are recalculated on a random subset of training data" 59 | for m in model.modules(): 60 | if isinstance(m, torch.nn.BatchNorm2d): 61 | m.running_mean = torch.zeros_like(m.running_mean) 62 | m.running_var = torch.ones_like(m.running_var) 63 | 64 | model.train() 65 | 66 | for step in range(max_train_iters): 67 | batch = train_dataloader_iter.next() 68 | if len(batch) == 4: 69 | data, target, _, _ = batch 70 | elif len(batch) == 2: 71 | data, target = batch 72 | 73 | target = target.type(torch.LongTensor) 74 | 75 | data, target = data.to(device), target.to(device) 76 | 77 | output = model(data, normal_genotype, reduce_genotype) 78 | 79 | del data, target, output 80 | 81 | top1 = 0 82 | top5 = 0 83 | total = 0 84 | 85 | print('starting test....') 86 | model.eval() 87 | 88 | for step in range(max_test_iters): 89 | data, target = val_dataloader_iter.next() 90 | batchsize = data.shape[0] 91 | target = target.type(torch.LongTensor) 92 | data, target = data.to(device), target.to(device) 93 | logits = model(data, normal_genotype, reduce_genotype) 94 | prec1, prec5 = accuracy(logits, target, topk=(1, 5)) 95 | top1 += prec1.item() * batchsize 96 | top5 += prec5.item() * batchsize 97 | total += batchsize 98 | 99 | del data, target, logits, prec1, prec5 100 | 101 | top1, top5 = top1 / total, top5 / total 102 | top1, top5 = top1 / 100, top5 / 100 103 | return top1, top5 104 | 105 | def main(): 106 | pass -------------------------------------------------------------------------------- /evolution_search/metrics/tester_wlm.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | GeNAS 3 | Copyright (c) 2023-present NAVER Cloud Corp. 4 | MIT license 5 | """ 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | from copy import deepcopy 10 | import math 11 | from config import config 12 | assert torch.cuda.is_available() 13 | 14 | def check_strictly_increasing(L): 15 | return all(x < y for x, y in zip(L, L[1:])) 16 | 17 | def no_grad_wrapper(func): 18 | def new_func(*args, **kwargs): 19 | with torch.no_grad(): 20 | return func(*args, **kwargs) 21 | return new_func 22 | 23 | @no_grad_wrapper 24 | def get_cand_wlm(model, genotype, train_dataloader, val_dataloader, stds, max_train_img_size, max_val_img_size=25000): 25 | ''' 26 | genotype: normal (14 edges) + reduce cell (14 edges) with operation indices. 27 | train_dataloader: half (25K) of original training set (50K). 28 | val_dataloader: another half (25K) of original training set (50K). 29 | stds: std values of the gaussian weight perturbations (expected to be strictly increasing). 30 | ''' 31 | assert check_strictly_increasing(stds), "stds should be strictly increasing" 32 | # separate genotype 33 | normal_genotype = tuple(genotype[:config.edges]) 34 | reduce_genotype = tuple(genotype[config.edges:]) 35 | 36 | train_dataloader_iter = iter(train_dataloader) 37 | 38 | if torch.cuda.is_available(): 39 | device = torch.device('cuda') 40 | else: 41 | device = torch.device('cpu') 42 | 43 | criterion = nn.CrossEntropyLoss() 44 | 45 | # NOTE: # iterations of BN statistics re-tracking for search loader 46 | max_train_iters = math.ceil(max_train_img_size / train_dataloader.batch_size) 47 | # NOTE: # iterations for measuring validation loss 48 | max_test_iters = math.ceil(max_val_img_size / val_dataloader.batch_size) 49 | if max_train_iters > 0: 50 | # NOTE: [from SPOS paper] "Before the inference of an architecture, the statistics of all the Batch Normalization (BN) [9] operations are recalculated on a random subset of training data" 51 | for m in model.modules(): 52 | if isinstance(m, torch.nn.BatchNorm2d): 53 | m.running_mean = torch.zeros_like(m.running_mean) 54 | m.running_var = torch.ones_like(m.running_var) 55 | 56 | model.train() 57 | 58 | for step in range(max_train_iters): 59 | batch = next(train_dataloader_iter) 60 | if len(batch) == 4: 61 | data, target, _, _ = batch 62 | elif len(batch) == 2: 63 | data, target = batch 64 | 65 | target = target.type(torch.LongTensor) 66 | data, target = data.to(device), target.to(device) 67 | output = model(data, normal_genotype, reduce_genotype) 68 | del data, target, output 69 | 70 | losses_per_stds = [] 71 | 72 | model.eval() 73 | 74 | model_ = deepcopy(model) 75 | 76 | for std_idx, cur_std in enumerate(stds): 77 | val_dataloader_iter = iter(val_dataloader) 78 | 79 | # NOTE: add gaussian noise parameterized by the residual std (\sigma_{t+1} - \sigma_t) in a cumulative way, rather than directly adding noise of std \sigma_t to a fresh copy of the model. 80 | # NOTE: the former bypasses deep-copying the model for every std, and is thus memory efficient. 81 | # NOTE: while the two variants could give different results, we take the former for efficient memory usage. 82 | if std_idx == 0: 83 | std = cur_std # initial std 84 | else: 85 | std = cur_std - stds[std_idx - 1] # cumulate 86 | 87 | for name, param in model_.named_parameters(): 88 | # NOTE: add gaussian noise to all parameters 89 | param.data.add_(torch.normal(0, std, size=param.size()).type(param.dtype).to(param.device)) 90 | model_.eval() 91 | 92 | losses_over_batch = 0 93 | 94 | for step in range(max_test_iters): 95 | # NOTE: using validation dataset 96 | data, target = next(val_dataloader_iter) 97 | 98 | 99 | batchsize = data.shape[0] 100 | target = target.type(torch.LongTensor) 101 | data, target = data.to(device), target.to(device) 102 | 103 | logits = model_(data, normal_genotype, reduce_genotype) 104 | loss = criterion(logits, target) 105 | losses_over_batch += loss.item() 106 | 107 | del data, target, logits 108 | 109 | losses_mean = losses_over_batch / max_test_iters 110 | losses_per_stds.append(losses_mean) 111 | 112 | del model_ 113 | 114 | # calculate wide & flat measure 115 | poor_minima = 0 116 | # NOTE: summation of gradients for loss.
(regard non-perturbed loss value as 0) 117 | # TODO: add initial loss value (std:0, non-perturbed loss value) 118 | poor_minima += abs(losses_per_stds[0] / stds[0]) 119 | for i in range(len(losses_per_stds) - 1): 120 | poor_minima += abs((losses_per_stds[i+1] - losses_per_stds[i]) / (stds[i+1] - stds[i])) 121 | 122 | wlm = 1 / (poor_minima + 1e-5) # wide & flat minima measure 123 | return wlm, losses_per_stds -------------------------------------------------------------------------------- /evolution_search/operations.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/operations.py 3 | """ 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | OPS = { 9 | 'none' : lambda C, stride, affine: Zero(stride), 10 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 11 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 12 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 13 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 14 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 15 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 16 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 19 | nn.ReLU(inplace=False), 20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 22 | nn.BatchNorm2d(C, affine=affine) 23 | ), 24 | } 25 | 26 | class ReLUConvBN(nn.Module): 27 | 28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 29 | super(ReLUConvBN, self).__init__() 30 | self.op = nn.Sequential( 31 | nn.ReLU(inplace=False), 32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 33 | nn.BatchNorm2d(C_out, affine=affine) 34 | ) 35 | 36 | def forward(self, x, rngs=None): 37 | return self.op(x) 38 | 39 | class DilConv(nn.Module): 40 | 41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 42 | super(DilConv, self).__init__() 43 | self.op = nn.Sequential( 44 | nn.ReLU(inplace=False), 45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 47 | nn.BatchNorm2d(C_out, affine=affine), 48 | ) 49 | 50 | def forward(self, x, rngs=None): 51 | return self.op(x) 52 | 53 | class SepConv(nn.Module): 54 | 55 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 56 | super(SepConv, self).__init__() 57 | self.op = nn.Sequential( 58 | nn.ReLU(inplace=False), 59 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 60 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 61 | nn.BatchNorm2d(C_in, affine=affine), 62 | nn.ReLU(inplace=False), 63 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 64 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, 
bias=False), 65 | nn.BatchNorm2d(C_out, affine=affine), 66 | ) 67 | 68 | def forward(self, x, rngs=None): 69 | return self.op(x) 70 | 71 | 72 | class Identity(nn.Module): 73 | 74 | def __init__(self): 75 | super(Identity, self).__init__() 76 | 77 | def forward(self, x, rngs=None): 78 | return x 79 | 80 | class Zero(nn.Module): 81 | 82 | def __init__(self, stride): 83 | super(Zero, self).__init__() 84 | self.stride = stride 85 | def forward(self, x, rngs=None): 86 | n, c, h, w = x.size() 87 | h //= self.stride 88 | w //= self.stride 89 | if x.is_cuda: 90 | with torch.cuda.device(x.get_device()): 91 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0) 92 | else: 93 | padding = torch.FloatTensor(n, c, h, w).fill_(0) 94 | return padding 95 | 96 | class FactorizedReduce(nn.Module): 97 | 98 | def __init__(self, C_in, C_out, affine=True): 99 | super(FactorizedReduce, self).__init__() 100 | assert C_out % 2 == 0 101 | self.relu = nn.ReLU(inplace=False) 102 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 103 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 104 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 105 | 106 | def forward(self, x, rngs=None): 107 | x = self.relu(x) 108 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 109 | out = self.bn(out) 110 | return out 111 | 112 | 113 | -------------------------------------------------------------------------------- /evolution_search/search.py: -------------------------------------------------------------------------------- 1 | """ 2 | Code Adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/search.py 3 | """ 4 | 5 | import os 6 | import sys 7 | import time 8 | import glob 9 | import random 10 | import numpy as np 11 | import pickle 12 | import torch 13 | import logging 14 | import argparse 15 | import torch.nn as nn 16 | import torch.utils 17 | import torch.nn.functional as F 18 | import torchvision.datasets as dset 19 | import torch.backends.cudnn as cudnn 20 | from cand_evaluator import CandEvaluator 21 | from genotypes import parse_searched_cell 22 | from datasets import get_datasets, get_nas_search_loaders 23 | from config import config 24 | import collections 25 | import sys 26 | 27 | sys.setrecursionlimit(10000) 28 | import argparse 29 | import utils 30 | import functools 31 | 32 | print = functools.partial(print, flush=True) 33 | 34 | choice = ( 35 | lambda x: x[np.random.randint(len(x))] if isinstance(x, tuple) else choice(tuple(x)) 36 | ) 37 | 38 | 39 | class EvolutionTrainer(object): 40 | def __init__( 41 | self, 42 | log_dir, 43 | final_model_path, 44 | initial_model_path, 45 | metric="angle", 46 | train_loader=None, 47 | valid_loader=None, 48 | perturb_stds=None, 49 | max_train_img_size=5000, 50 | max_val_img_size=10000, 51 | wlm_weight=0, 52 | acc_weight=0, 53 | refresh=False, 54 | ): 55 | self.log_dir = log_dir 56 | self.checkpoint_name = os.path.join(self.log_dir, "checkpoint.brainpkl") 57 | self.refresh = refresh 58 | self.cand_evaluator = CandEvaluator( 59 | logging, 60 | final_model_path, 61 | initial_model_path, 62 | metric, 63 | train_loader, 64 | valid_loader, 65 | perturb_stds, 66 | max_train_img_size, 67 | max_val_img_size, 68 | wlm_weight, 69 | acc_weight, 70 | ) 71 | 72 | self.memory = [] 73 | self.candidates = [] 74 | self.vis_dict = {} 75 | self.keep_top_k = {config.select_num: [], 50: []} 76 | self.epoch = 0 77 | self.cand_idx = 0 # for generating candidate idx 78 | 
self.operations = [list(range(config.op_num)) for _ in range(config.edges)] 79 | 80 | self.metric = metric 81 | 82 | def save_checkpoint(self): 83 | if not os.path.exists(self.log_dir): 84 | os.mkdir(self.log_dir) 85 | info = {} 86 | info["memory"] = self.memory 87 | info["candidates"] = self.candidates 88 | info["vis_dict"] = self.vis_dict 89 | info["keep_top_k"] = self.keep_top_k 90 | info["epoch"] = self.epoch 91 | info["cand_idx"] = self.cand_idx 92 | torch.save(info, self.checkpoint_name) 93 | logging.info("save checkpoint to {}".format(self.checkpoint_name)) 94 | 95 | def load_checkpoint(self): 96 | if not os.path.exists(self.checkpoint_name): 97 | return False 98 | info = torch.load(self.checkpoint_name) 99 | self.memory = info["memory"] 100 | self.candidates = info["candidates"] 101 | self.vis_dict = info["vis_dict"] 102 | self.keep_top_k = info["keep_top_k"] 103 | self.epoch = info["epoch"] 104 | self.cand_idx = info["cand_idx"] 105 | 106 | if self.refresh: 107 | for i, j in self.vis_dict.items(): 108 | for k in ["test_key"]: 109 | if k in j: 110 | j.pop(k) 111 | self.refresh = False 112 | 113 | logging.info("load checkpoint from {}".format(self.checkpoint_name)) 114 | return True 115 | 116 | def legal(self, cand): 117 | assert isinstance(cand, tuple) and len(cand) == (2 * config.edges) 118 | if cand not in self.vis_dict: 119 | self.vis_dict[cand] = {} 120 | info = self.vis_dict[cand] 121 | if "visited" in info: 122 | return False 123 | 124 | if config.flops_limit is not None: 125 | pass 126 | 127 | self.vis_dict[cand] = info 128 | info["visited"] = True 129 | 130 | return True 131 | 132 | def update_top_k(self, candidates, *, k, key, reverse=False): 133 | assert k in self.keep_top_k 134 | logging.info("select ......") 135 | t = self.keep_top_k[k] 136 | t += candidates 137 | t.sort(key=key, reverse=reverse) 138 | self.keep_top_k[k] = t[:k] 139 | 140 | def gen_key(self, cand): 141 | # NOTE: generate unique id for candidate 142 | self.cand_idx += 1 143 | key = "{}-{}".format(self.cand_idx, time.time()) 144 | return key 145 | 146 | def eval_cand(self, cand, cand_key): 147 | # NOTE: evaluate candidate 148 | try: 149 | result = self.cand_evaluator.eval(cand) 150 | return result 151 | except: 152 | import traceback 153 | 154 | traceback.print_exc() 155 | return {"status": "uncaught error"} 156 | 157 | def sync_candidates(self): 158 | while True: 159 | ok = True 160 | for cand in self.candidates: 161 | info = self.vis_dict[cand] 162 | if self.metric in info: 163 | continue 164 | ok = False 165 | if "test_key" not in info: 166 | info["test_key"] = self.gen_key(cand) 167 | 168 | self.save_checkpoint() 169 | 170 | for cand in self.candidates: 171 | info = self.vis_dict[cand] 172 | if self.metric in info: 173 | continue 174 | key = info.pop("test_key") 175 | 176 | try: 177 | logging.info("try to get {}".format(key)) 178 | res = self.eval_cand( 179 | cand, key 180 | ) # NOTE: currently, key and cand have an implicit connection 181 | logging.info(res) 182 | info[self.metric] = res[self.metric] 183 | self.save_checkpoint() 184 | except: 185 | import traceback 186 | 187 | traceback.print_exc() 188 | time.sleep(1) 189 | 190 | time.sleep(5) 191 | if ok: 192 | break 193 | 194 | def stack_random_cand(self, random_func, *, batchsize=10): 195 | while True: 196 | cands = [random_func() for _ in range(batchsize)] 197 | for cand in cands: 198 | if cand not in self.vis_dict: 199 | self.vis_dict[cand] = {} 200 | else: 201 | continue 202 | info = self.vis_dict[cand] 203 | # for cand in cands: 204 | yield cand
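# --- annotation (illustrative sketch, not part of the original search.py) ---
# A candidate is a tuple of 2 * config.edges = 28 integers, each indexing PRIMITIVES
# (0 = 'none', i.e. the edge is pruned), and `utils.check_cand` enforces the DARTS rule
# that every intermediate node keeps exactly two incoming edges. Assuming
# config.op_num == 8 and config.edges == 14, drawing and legalizing one candidate looks like:
#
#   raw = tuple(np.random.randint(config.op_num) for _ in range(2 * config.edges))
#   normal_cand = utils.check_cand(raw[:config.edges], self.operations)
#   reduction_cand = utils.check_cand(raw[config.edges:], self.operations)
#   cand = tuple(normal_cand + reduction_cand)  # e.g. (5, 4, 5, 0, ...), ready for self.legal(cand)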
205 | 206 | def stack_random_cand_crossover(self, random_func, max_iters, *, batchsize=10): 207 | cand_count = 0 208 | while True: 209 | if cand_count > max_iters: 210 | break 211 | cands = [random_func() for _ in range(batchsize)] 212 | cand_count += 1 213 | for cand in cands: 214 | if cand not in self.vis_dict: 215 | self.vis_dict[cand] = {} 216 | else: 217 | continue 218 | info = self.vis_dict[cand] 219 | # for cand in cands: 220 | yield cand 221 | 222 | def random_can(self, num): 223 | logging.info("random select ........") 224 | candidates = [] 225 | cand_iter = self.stack_random_cand( 226 | lambda: tuple( 227 | np.random.randint(config.op_num) for _ in range(2 * config.edges) 228 | ) 229 | ) 230 | while len(candidates) < num: 231 | cand = next(cand_iter) 232 | normal_cand = cand[: config.edges] 233 | reduction_cand = cand[config.edges :] 234 | normal_cand = utils.check_cand(normal_cand, self.operations) 235 | reduction_cand = utils.check_cand(reduction_cand, self.operations) 236 | cand = normal_cand + reduction_cand 237 | cand = tuple(cand) 238 | if not self.legal(cand): 239 | continue 240 | candidates.append(cand) 241 | logging.info("random {}/{}".format(len(candidates), num)) 242 | logging.info("random_num = {}".format(len(candidates))) 243 | return candidates 244 | 245 | def get_mutation(self, k, mutation_num, m_prob): 246 | assert k in self.keep_top_k 247 | logging.info("mutation ......") 248 | res = [] 249 | iter = 0 250 | max_iters = mutation_num * 10 251 | 252 | def random_func(): 253 | cand = list(choice(self.keep_top_k[k])) 254 | for i in range(config.edges): 255 | if np.random.random_sample() < m_prob: 256 | cand[i] = np.random.randint(0, config.op_num) 257 | return tuple(cand) 258 | 259 | cand_iter = self.stack_random_cand(random_func) 260 | while len(res) < mutation_num and max_iters > 0: 261 | cand = next(cand_iter) 262 | normal_cand = cand[: config.edges] 263 | reduction_cand = cand[config.edges :] 264 | normal_cand = utils.check_cand(normal_cand, self.operations) 265 | reduction_cand = utils.check_cand(reduction_cand, self.operations) 266 | cand = normal_cand + reduction_cand 267 | cand = tuple(cand) 268 | if not self.legal(cand): 269 | continue 270 | res.append(cand) 271 | logging.info("mutation {}/{}".format(len(res), mutation_num)) 272 | max_iters -= 1 273 | 274 | logging.info("mutation_num = {}".format(len(res))) 275 | return res 276 | 277 | def get_crossover(self, k, crossover_num): 278 | assert k in self.keep_top_k 279 | logging.info("crossover ......") 280 | res = [] 281 | iter = 0 282 | max_iters = 10 * crossover_num 283 | 284 | def random_func(): 285 | p1 = choice(self.keep_top_k[k]) 286 | p2 = choice(self.keep_top_k[k]) 287 | return tuple(choice([i, j]) for i, j in zip(p1, p2)) 288 | 289 | cand_iter = self.stack_random_cand_crossover(random_func, crossover_num) 290 | while len(res) < crossover_num: 291 | try: 292 | cand = next(cand_iter) 293 | normal_cand = cand[: config.edges] 294 | reduction_cand = cand[config.edges :] 295 | normal_cand = utils.check_cand(normal_cand, self.operations) 296 | reduction_cand = utils.check_cand(reduction_cand, self.operations) 297 | cand = normal_cand + reduction_cand 298 | cand = tuple(cand) 299 | except Exception as e: 300 | logging.info(e) 301 | break 302 | if not self.legal(cand): 303 | continue 304 | res.append(cand) 305 | logging.info("crossover {}/{}".format(len(res), crossover_num)) 306 | 307 | logging.info("crossover_num = {}".format(len(res))) 308 | return res 309 | 310 | def train(self): 311 | logging.info( 312 | 
"population_num = {} select_num = {} mutation_num = {} crossover_num = {} random_num = {} max_epochs = {}".format( 313 | config.population_num, 314 | config.select_num, 315 | config.mutation_num, 316 | config.crossover_num, 317 | config.population_num - config.mutation_num - config.crossover_num, 318 | config.max_epochs, 319 | ) 320 | ) 321 | 322 | if not self.load_checkpoint(): 323 | self.candidates = self.random_can(config.population_num) 324 | self.save_checkpoint() 325 | 326 | while self.epoch < config.max_epochs: 327 | logging.info("epoch = {}".format(self.epoch)) 328 | 329 | self.sync_candidates() # NOTE: evaluate candidates 330 | 331 | logging.info("sync finish") 332 | 333 | self.memory.append([]) 334 | for cand in self.candidates: 335 | self.memory[-1].append(cand) 336 | self.vis_dict[cand]["visited"] = True 337 | 338 | self.update_top_k( 339 | self.candidates, 340 | k=config.select_num, 341 | key=lambda x: self.vis_dict[x][self.metric], 342 | reverse=True, 343 | ) 344 | self.update_top_k( 345 | self.candidates, 346 | k=50, 347 | key=lambda x: self.vis_dict[x][self.metric], 348 | reverse=True, 349 | ) 350 | 351 | logging.info( 352 | "epoch = {} : top {} result".format( 353 | self.epoch, len(self.keep_top_k[50]) 354 | ) 355 | ) 356 | for i, cand in enumerate(self.keep_top_k[50]): 357 | logging.info( 358 | "No.{} {} {} = {}".format( 359 | i + 1, cand, self.metric, self.vis_dict[cand][self.metric] 360 | ) 361 | ) 362 | # ops = [config.blocks_keys[i] for i in cand] 363 | ops = [config.blocks_keys[i] for i in cand] 364 | logging.info(ops) 365 | 366 | mutation = self.get_mutation( 367 | config.select_num, config.mutation_num, config.m_prob 368 | ) 369 | crossover = self.get_crossover(config.select_num, config.crossover_num) 370 | rand = self.random_can( 371 | config.population_num - len(mutation) - len(crossover) 372 | ) 373 | self.candidates = mutation + crossover + rand 374 | 375 | self.epoch += 1 376 | self.save_checkpoint() 377 | 378 | logging.info(self.keep_top_k[config.select_num]) 379 | logging.info("finish!") 380 | logging.info( 381 | "Top-1 Searched Cell Architecture : {}".format( 382 | parse_searched_cell(self.keep_top_k[config.select_num][0]) 383 | ) 384 | ) 385 | 386 | 387 | def prepare_seed(rand_seed): 388 | random.seed(rand_seed) 389 | np.random.seed(rand_seed) 390 | torch.manual_seed(rand_seed) 391 | torch.cuda.manual_seed(rand_seed) 392 | torch.cuda.manual_seed_all(rand_seed) 393 | 394 | 395 | class SplitArgs(argparse.Action): 396 | def __call__(self, parser, namespace, values, option_string=None): 397 | setattr(namespace, self.dest, [float(val) for val in values.split(",")]) 398 | 399 | 400 | def main(): 401 | parser = argparse.ArgumentParser() 402 | parser.add_argument("-r", "--refresh", action="store_true") 403 | parser.add_argument("--save", type=str, default="log", help="experiment name") 404 | parser.add_argument("--seed", type=int, default=1, help="experiment name") 405 | parser.add_argument( 406 | "--init_model_path", 407 | type=str, 408 | default=config.initial_net_cache, 409 | help="initial model ckpt path", 410 | ) 411 | parser.add_argument( 412 | "--model_path", type=str, default=config.net_cache, help="final model ckpt path" 413 | ) 414 | parser.add_argument( 415 | "--metric", type=str, default="angle", help="metric to evaulate candidate with." 416 | ) 417 | 418 | """ below are required if args.metric is not "angle" """ 419 | parser.add_argument( 420 | "--data", 421 | type=str, 422 | default="", 423 | help='data root path. 
required if --metric is not "angle"', 424 | ) 425 | parser.add_argument( 426 | "--split_data", 427 | type=int, 428 | choices=[0, 1], 429 | default=1, 430 | help="Whether use split data for training & validation. (default: True)", 431 | ) 432 | parser.add_argument("--batch_size", type=int, default=64, help="train batch_size") 433 | parser.add_argument( 434 | "--test_batch_size", type=int, default=512, help="test batch_size" 435 | ) 436 | parser.add_argument( 437 | "--cutout", action="store_true", default=False, help="use cutout" 438 | ) 439 | parser.add_argument("--cutout_length", type=int, default=16, help="cutout length") 440 | # GeNAS hyperparameters 441 | parser.add_argument( 442 | "--stds", 443 | default=None, 444 | action=SplitArgs, 445 | help="std values for weight perturbation", 446 | ) 447 | parser.add_argument( 448 | "--max_train_img_size", 449 | type=int, 450 | default=5000, 451 | help="maximum number of training imgs for batch norm statistics recalculation.", 452 | ) 453 | parser.add_argument( 454 | "--max_val_img_size", 455 | type=int, 456 | default=10000, 457 | help="maximum number of validation imgs for evaluating architecture candidates. (required only for wlm)", 458 | ) 459 | # Combined metric 460 | parser.add_argument("--wlm_weight", type=float, default=0, help="wlm weight") 461 | parser.add_argument("--acc_weight", type=float, default=0, help="acc weight") 462 | 463 | args = parser.parse_args() 464 | 465 | args.split_data = bool(args.split_data) 466 | 467 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob("*.py")) 468 | 469 | log_format = "%(asctime)s %(message)s" 470 | logging.basicConfig( 471 | stream=sys.stdout, 472 | level=logging.INFO, 473 | format=log_format, 474 | datefmt="%m/%d %I:%M:%S %p", 475 | ) 476 | fh = logging.FileHandler(os.path.join(args.save, "search_log.txt")) 477 | fh.setFormatter(logging.Formatter(log_format)) 478 | logging.getLogger().addHandler(fh) 479 | 480 | if ( 481 | args.split_data 482 | ): # NOTE: split train data in half to be new train, val set. 
new train is used for supernet training, new val set is used for evaluation 483 | train_data, valid_data, xshape, class_num = get_datasets( 484 | "cifar100", args.data, -1, args.seed, random_label=False 485 | ) # NOTE: using GT label 486 | train_queue, _, _, valid_queue = get_nas_search_loaders( 487 | train_data, 488 | valid_data, 489 | "cifar100", 490 | "datasets/configs/", 491 | (args.batch_size, args.batch_size), 492 | 4, 493 | use_valid_no_shuffle=True, 494 | ) 495 | else: 496 | raise ValueError("only --split_data 1 is supported") 497 | 498 | refresh = args.refresh 499 | # np.random.seed(args.seed) 500 | prepare_seed(args.seed) 501 | 502 | t = time.time() 503 | 504 | trainer = EvolutionTrainer( 505 | args.save, 506 | args.model_path, 507 | args.init_model_path, 508 | metric=args.metric, 509 | train_loader=train_queue, 510 | valid_loader=valid_queue, 511 | perturb_stds=args.stds, 512 | max_train_img_size=args.max_train_img_size, 513 | max_val_img_size=args.max_val_img_size, 514 | wlm_weight=args.wlm_weight, 515 | acc_weight=args.acc_weight, 516 | refresh=refresh, 517 | ) 518 | 519 | trainer.train() 520 | logging.info("total searching time = {:.2f} hours".format((time.time() - t) / 3600)) 521 | 522 | 523 | if __name__ == "__main__": 524 | try: 525 | main() 526 | os._exit(0) 527 | except: 528 | import traceback 529 | 530 | traceback.print_exc() 531 | time.sleep(1) 532 | os._exit(1) 533 | -------------------------------------------------------------------------------- /evolution_search/super_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/super_model.py 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from operations import * 9 | from torch.autograd import Variable 10 | from genotypes import PRIMITIVES 11 | from genotypes import Genotype 12 | import math 13 | import numpy as np 14 | from config import config 15 | import copy 16 | from utils import check_cand 17 | 18 | class MixedOp(nn.Module): 19 | 20 | def __init__(self, C, stride): 21 | super(MixedOp, self).__init__() 22 | self._ops = nn.ModuleList() 23 | for idx, primitive in enumerate(PRIMITIVES): 24 | op = OPS[primitive](C, stride, True) 25 | op.idx = idx 26 | if 'pool' in primitive: 27 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=True)) 28 | self._ops.append(op) 29 | 30 | def forward(self, x, rng): 31 | return self._ops[rng](x) 32 | 33 | 34 | class Cell(nn.Module): 35 | 36 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): 37 | super(Cell, self).__init__() 38 | if reduction_prev: 39 | # NOTE: if the (K-1)-th cell's output came from a stride-2 op, the (K-2)-th cell's output must also shrink its spatial size by stride 2.
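# (annotation, illustrative shapes assuming 32x32 CIFAR inputs) e.g. if the previous cell
# reduced, s1 is (N, C_prev, 16, 16) while s0 still has the pre-reduction resolution
# (N, C_prev_prev, 32, 32); FactorizedReduce below applies two stride-2 1x1 convs (one on
# a 1-pixel-shifted view) and concatenates them, so both preprocessed inputs become
# (N, C, 16, 16) before the mixed ops run.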
40 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=True) 41 | else: 42 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=True) 43 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=True) 44 | self._steps = steps 45 | self._multiplier = multiplier 46 | self._C = C 47 | self.out_C = self._multiplier * C 48 | self.reduction = reduction 49 | 50 | self._ops = nn.ModuleList() 51 | self._bns = nn.ModuleList() 52 | self.time_stamp = 1 53 | 54 | for i in range(self._steps): 55 | for j in range(2+i): 56 | stride = 2 if reduction and j < 2 else 1 57 | op = MixedOp(C, stride) 58 | self._ops.append(op) 59 | 60 | def forward(self, s0, s1, rngs): 61 | s0 = self.preprocess0(s0) 62 | s1 = self.preprocess1(s1) 63 | states = [s0, s1] 64 | offset = 0 65 | for i in range(self._steps): 66 | # NOTE: only two edges (operations) from two previous nodes are summed. 67 | s = sum(self._ops[offset+j](h, rngs[offset+j]) for j, h in enumerate(states)) 68 | offset += len(states) 69 | states.append(s) 70 | return torch.cat(states[-self._multiplier:], dim=1) # NOTE: the final 4 intermediate node outputs are concatenated (excluding the k-1, k-2 node outputs). 71 | 72 | class Network(nn.Module): 73 | def __init__(self, C=16, num_classes=100, layers=8, steps=4, multiplier=4, stem_multiplier=3): 74 | super(Network, self).__init__() 75 | self._C = C 76 | self._num_classes = num_classes 77 | self._layers = layers 78 | self._steps = steps 79 | self._multiplier = multiplier 80 | 81 | C_curr = stem_multiplier * C 82 | 83 | self.stem = nn.Sequential( 84 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False), 85 | nn.BatchNorm2d(C_curr) 86 | ) 87 | 88 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 89 | 90 | self.cells = nn.ModuleList() 91 | reduction_prev = False 92 | 93 | for i in range(layers): 94 | if i in [layers // 3, 2 * layers // 3]: 95 | C_curr *= 2 96 | reduction = True 97 | else: 98 | reduction = False 99 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 100 | reduction_prev = reduction 101 | self.cells += [cell] 102 | C_prev_prev, C_prev = C_prev, multiplier * C_curr 103 | 104 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 105 | self.classifier = nn.Linear(C_prev, num_classes) 106 | 107 | def forward_normal_only(self, input, rng): 108 | ''' forward function for only normal cells ''' 109 | s0 = s1 = self.stem(input) 110 | for i, cell in enumerate(self.cells): 111 | s0, s1 = s1, cell(s0, s1, rng) 112 | out = self.global_pooling(s1) 113 | logits = self.classifier(out.view(out.size(0),-1)) 114 | return logits 115 | 116 | def forward(self, input, normal_rng, reduction_rng): 117 | ''' forward function for normal + reduction cells ''' 118 | s0 = s1 = self.stem(input) 119 | for i, cell in enumerate(self.cells): 120 | if i in [self._layers // 3, 2 * self._layers // 3]: 121 | s0, s1 = s1, cell(s0, s1, reduction_rng) 122 | else: 123 | s0, s1 = s1, cell(s0, s1, normal_rng) 124 | out = self.global_pooling(s1) 125 | logits = self.classifier(out.view(out.size(0),-1)) 126 | return logits 127 | 128 | if __name__ == '__main__': 129 | from copy import deepcopy 130 | model = Network() 131 | operations = [] 132 | for _ in range(config.edges): 133 | operations.append(list(range(config.op_num))) 134 | normal_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)] 135 | reduction_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)] 136 | normal_rng = check_cand(normal_rng, operations) # NOTE: modify genotype to accept only two edges (operations)
from previous nodes 137 | reduction_rng = check_cand(reduction_rng, operations) # NOTE: modify genotype to accept only two edges (operations) from previous nodes 138 | x = torch.rand(4,3,32,32) 139 | logit = model(x, normal_rng, reduction_rng) 140 | print('logit:{0}'.format(logit)) 141 | -------------------------------------------------------------------------------- /evolution_search/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/evolution_search/utils.py 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | import torch 8 | import shutil 9 | import torchvision.transforms as transforms 10 | from torch.autograd import Variable 11 | from collections import defaultdict 12 | from config import config 13 | 14 | class AvgrageMeter(object): 15 | 16 | def __init__(self): 17 | self.reset() 18 | 19 | def reset(self): 20 | self.avg = 0 21 | self.sum = 0 22 | self.cnt = 0 23 | 24 | def update(self, val, n=1): 25 | self.sum += val * n 26 | self.cnt += n 27 | self.avg = self.sum / self.cnt 28 | 29 | 30 | def accuracy(output, target, topk=(1,)): 31 | maxk = max(topk) 32 | batch_size = target.size(0) 33 | 34 | _, pred = output.topk(maxk, 1, True, True) 35 | pred = pred.t() 36 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 37 | 38 | res = [] 39 | for k in topk: 40 | correct_k = correct[:k].view(-1).float().sum(0) 41 | res.append(correct_k.mul_(100.0/batch_size)) 42 | return res 43 | 44 | 45 | class Cutout(object): 46 | def __init__(self, length): 47 | self.length = length 48 | 49 | def __call__(self, img): 50 | h, w = img.size(1), img.size(2) 51 | mask = np.ones((h, w), np.float32) 52 | y = np.random.randint(h) 53 | x = np.random.randint(w) 54 | 55 | y1 = np.clip(y - self.length // 2, 0, h) 56 | y2 = np.clip(y + self.length // 2, 0, h) 57 | x1 = np.clip(x - self.length // 2, 0, w) 58 | x2 = np.clip(x + self.length // 2, 0, w) 59 | 60 | mask[y1: y2, x1: x2] = 0.
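# (annotation) e.g. with length=16 on a 32x32 image, a 16x16 square clipped at the
# borders is zeroed: a center of (y=5, x=30) gives rows 0:13 and cols 22:32;
# the numpy mask is then converted to a tensor and broadcast over channels below.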
61 | mask = torch.from_numpy(mask) 62 | mask = mask.expand_as(img) 63 | img *= mask 64 | return img 65 | 66 | 67 | def _data_transforms_cifar10(args): 68 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] 69 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] 70 | 71 | train_transform = transforms.Compose([ 72 | transforms.RandomCrop(32, padding=4), 73 | transforms.RandomHorizontalFlip(), 74 | transforms.ToTensor(), 75 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 76 | ]) 77 | if args.cutout: 78 | train_transform.transforms.append(Cutout(args.cutout_length)) 79 | 80 | valid_transform = transforms.Compose([ 81 | transforms.ToTensor(), 82 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 83 | ]) 84 | return train_transform, valid_transform 85 | 86 | 87 | def count_parameters_in_MB(model): 88 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 89 | 90 | 91 | def save_checkpoint(state, is_best, save): 92 | filename = os.path.join(save, 'checkpoint.pth.tar') 93 | torch.save(state, filename) 94 | if is_best: 95 | best_filename = os.path.join(save, 'model_best.pth.tar') 96 | shutil.copyfile(filename, best_filename) 97 | 98 | 99 | def save(model, model_path): 100 | torch.save(model.state_dict(), model_path) 101 | 102 | 103 | def load(model, model_path): 104 | model.load_state_dict(torch.load(model_path)) 105 | 106 | 107 | def drop_path(x, drop_prob): 108 | if drop_prob > 0.: 109 | keep_prob = 1.-drop_prob 110 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 111 | x.div_(keep_prob) 112 | x.mul_(mask) 113 | return x 114 | 115 | 116 | def create_exp_dir(path, scripts_to_save=None): 117 | if not os.path.exists(path): 118 | os.makedirs(path, exist_ok=True) 119 | print('Experiment dir : {}'.format(path)) 120 | 121 | if scripts_to_save is not None: 122 | os.makedirs(os.path.join(path, 'scripts'), exist_ok=True) 123 | for script in scripts_to_save: 124 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 125 | shutil.copyfile(script, dst_file) 126 | 127 | def get_location(s, key): 128 | d = defaultdict(list) 129 | for k,va in [(v,i) for i,v in enumerate(s)]: 130 | d[k].append(va) 131 | return d[key] 132 | 133 | def list_substract(list1, list2): 134 | list1 = [item for item in list1 if item not in set(list2)] 135 | return list1 136 | 137 | def check_cand(cand, operations): 138 | cand = np.reshape(cand, [-1, config.edges]) 139 | offset, cell_cand = 0, cand[0] 140 | for j in range(4): 141 | edges = cell_cand[offset:offset+j+2] 142 | edges_ops = operations[offset:offset+j+2] 143 | none_idxs = get_location(edges, 0) 144 | if len(none_idxs) < j: 145 | general_idxs = list_substract(range(j+2), none_idxs) 146 | num = min(j-len(none_idxs), len(general_idxs)) 147 | general_idxs = np.random.choice(general_idxs, size=num, replace=False, p=None) 148 | for k in general_idxs: 149 | edges[k] = 0 150 | elif len(none_idxs) > j: 151 | none_idxs = np.random.choice(none_idxs, size=len(none_idxs)-j, replace=False, p=None) 152 | for k in none_idxs: 153 | if len(edges_ops[k]) > 1: 154 | l = np.random.randint(len(edges_ops[k])-1) 155 | edges[k] = edges_ops[k][l+1] 156 | offset += len(edges) 157 | 158 | return tuple(cell_cand) 159 | -------------------------------------------------------------------------------- /repo_figures/ABS_FBS_architecture_normal.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_FBS_architecture_normal.png -------------------------------------------------------------------------------- /repo_figures/ABS_FBS_architecture_reduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_FBS_architecture_reduce.png -------------------------------------------------------------------------------- /repo_figures/ABS_architecture_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_architecture_normal.png -------------------------------------------------------------------------------- /repo_figures/ABS_architecture_reduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/ABS_architecture_reduce.png -------------------------------------------------------------------------------- /repo_figures/FBS_architecture_normal.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/FBS_architecture_normal.png -------------------------------------------------------------------------------- /repo_figures/FBS_architecture_reduce.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/FBS_architecture_reduce.png -------------------------------------------------------------------------------- /repo_figures/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/clovaai/GeNAS/fd3cba87b55c5638458a059c3e15e7f74469a7bb/repo_figures/motivation.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1+cu110 2 | torchvision==0.8.2+cu110 3 | opencv-contrib-python==4.6.0 4 | matplotlib == 2.2.2 5 | numpy==1.20.0 6 | Tqdm == 4.64.1 7 | wget -------------------------------------------------------------------------------- /retrain_architecture/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | """ 3 | MASTER_HOST: master node ip 4 | MASTER_PORT: master node port 5 | NODE_NUM: # nodes 6 | MY_RANK: current node idx 7 | GPU_NUM: # gpus per node 8 | """ 9 | MASTER_HOST = os.environ["HOST_RANK0"] 10 | MASTER_PORT = 13322 11 | NODE_NUM = int(os.environ["WORLD_SIZE"]) 12 | MY_RANK = int(os.environ["RANK"]) 13 | GPU_NUM = int(os.environ["GPU_COUNT"]) -------------------------------------------------------------------------------- /retrain_architecture/genotypes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code adapted from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/genotypes.py 3 | ''' 4 | 5 | from ast import parse 6 | from collections import namedtuple 7 | 8 | Genotype = namedtuple("Genotype", "normal normal_concat reduce reduce_concat") 9 | 10 | 
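# (annotation) integer-coded candidates index into PRIMITIVES below,
# e.g. 4 -> "sep_conv_3x3", 7 -> "dil_conv_5x5"; index 0 ("none") marks a pruned edge.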
PRIMITIVES = [ 11 | "none", 12 | "max_pool_3x3", 13 | "avg_pool_3x3", 14 | "skip_connect", 15 | "sep_conv_3x3", 16 | "sep_conv_5x5", 17 | "dil_conv_3x3", 18 | "dil_conv_5x5", 19 | ] 20 | 21 | NASNet = Genotype( 22 | normal=[ 23 | ("sep_conv_5x5", 1), 24 | ("sep_conv_3x3", 0), 25 | ("sep_conv_5x5", 0), 26 | ("sep_conv_3x3", 0), 27 | ("avg_pool_3x3", 1), 28 | ("skip_connect", 0), 29 | ("avg_pool_3x3", 0), 30 | ("avg_pool_3x3", 0), 31 | ("sep_conv_3x3", 1), 32 | ("skip_connect", 1), 33 | ], 34 | normal_concat=[2, 3, 4, 5, 6], 35 | reduce=[ 36 | ("sep_conv_5x5", 1), 37 | ("sep_conv_7x7", 0), 38 | ("max_pool_3x3", 1), 39 | ("sep_conv_7x7", 0), 40 | ("avg_pool_3x3", 1), 41 | ("sep_conv_5x5", 0), 42 | ("skip_connect", 3), 43 | ("avg_pool_3x3", 2), 44 | ("sep_conv_3x3", 2), 45 | ("max_pool_3x3", 1), 46 | ], 47 | reduce_concat=[4, 5, 6], 48 | ) 49 | 50 | AmoebaNet = Genotype( 51 | normal=[ 52 | ("avg_pool_3x3", 0), 53 | ("max_pool_3x3", 1), 54 | ("sep_conv_3x3", 0), 55 | ("sep_conv_5x5", 2), 56 | ("sep_conv_3x3", 0), 57 | ("avg_pool_3x3", 3), 58 | ("sep_conv_3x3", 1), 59 | ("skip_connect", 1), 60 | ("skip_connect", 0), 61 | ("avg_pool_3x3", 1), 62 | ], 63 | normal_concat=[4, 5, 6], 64 | reduce=[ 65 | ("avg_pool_3x3", 0), 66 | ("sep_conv_3x3", 1), 67 | ("max_pool_3x3", 0), 68 | ("sep_conv_7x7", 2), 69 | ("sep_conv_7x7", 0), 70 | ("avg_pool_3x3", 1), 71 | ("max_pool_3x3", 0), 72 | ("max_pool_3x3", 1), 73 | ("conv_7x1_1x7", 0), 74 | ("sep_conv_3x3", 5), 75 | ], 76 | reduce_concat=[3, 4, 6], 77 | ) 78 | 79 | DARTS_V1_CIFAR10 = Genotype( 80 | normal=[ 81 | ("sep_conv_3x3", 1), 82 | ("sep_conv_3x3", 0), 83 | ("skip_connect", 0), 84 | ("sep_conv_3x3", 1), 85 | ("skip_connect", 0), 86 | ("sep_conv_3x3", 1), 87 | ("sep_conv_3x3", 0), 88 | ("skip_connect", 2), 89 | ], 90 | normal_concat=[2, 3, 4, 5], 91 | reduce=[ 92 | ("max_pool_3x3", 0), 93 | ("max_pool_3x3", 1), 94 | ("skip_connect", 2), 95 | ("max_pool_3x3", 0), 96 | ("max_pool_3x3", 0), 97 | ("skip_connect", 2), 98 | ("skip_connect", 2), 99 | ("avg_pool_3x3", 0), 100 | ], 101 | reduce_concat=[2, 3, 4, 5], 102 | ) 103 | 104 | DARTS_V2_CIFAR10 = Genotype( 105 | normal=[ 106 | ("sep_conv_3x3", 0), 107 | ("sep_conv_3x3", 1), 108 | ("sep_conv_3x3", 0), 109 | ("sep_conv_3x3", 1), 110 | ("sep_conv_3x3", 1), 111 | ("skip_connect", 0), 112 | ("skip_connect", 0), 113 | ("dil_conv_3x3", 2), 114 | ], 115 | normal_concat=[2, 3, 4, 5], 116 | reduce=[ 117 | ("max_pool_3x3", 0), 118 | ("max_pool_3x3", 1), 119 | ("skip_connect", 2), 120 | ("max_pool_3x3", 1), 121 | ("max_pool_3x3", 0), 122 | ("skip_connect", 2), 123 | ("skip_connect", 2), 124 | ("max_pool_3x3", 1), 125 | ], 126 | reduce_concat=[2, 3, 4, 5], 127 | ) 128 | 129 | 130 | def parse_searched_cell(normal_reduce_cell): 131 | """ 132 | normal_reduce_cell: list of normal + reduce cell. 
133 | e.g.) [ 134 | 14 elements for normal cell edges where each element denotes the operation for that edge + 135 | 14 elements for reduce cell edges where each element denotes the operation for that edge 136 | ] 137 | """ 138 | assert ( 139 | len(normal_reduce_cell) == 28 140 | ), "cell should contain normal + reduce edges (14 + 14 = 28)" 141 | normal_cell = normal_reduce_cell[:14] 142 | reduce_cell = normal_reduce_cell[14:] 143 | 144 | normal_cell_decoded = [] 145 | reduce_cell_decoded = [] 146 | # normal cell decode 147 | for i in range(len(normal_cell)): 148 | # NOTE: for generating intermediate node 0 149 | if i in [0, 1]: 150 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i)) 151 | # NOTE: for generating intermediate node 1 152 | elif i in [2, 3, 4]: 153 | if normal_cell[i] != 0: 154 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 2)) 155 | # NOTE: for generating intermediate node 2 156 | elif i in [5, 6, 7, 8]: 157 | if normal_cell[i] != 0: 158 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 5)) 159 | # NOTE: for generating intermediate node 3 160 | elif i in [9, 10, 11, 12, 13]: 161 | if normal_cell[i] != 0: 162 | normal_cell_decoded.append((PRIMITIVES[normal_cell[i]], i - 9)) 163 | 164 | # reduce cell decode 165 | for i in range(len(reduce_cell)): 166 | # NOTE: for generating intermediate node 0 167 | if i in [0, 1]: 168 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i)) 169 | # NOTE: for generating intermediate node 1 170 | elif i in [2, 3, 4]: 171 | if reduce_cell[i] != 0: 172 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 2)) 173 | # NOTE: for generating intermediate node 2 174 | elif i in [5, 6, 7, 8]: 175 | if reduce_cell[i] != 0: 176 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 5)) 177 | # NOTE: for generating intermediate node 3 178 | elif i in [9, 10, 11, 12, 13]: 179 | if reduce_cell[i] != 0: 180 | reduce_cell_decoded.append((PRIMITIVES[reduce_cell[i]], i - 9)) 181 | 182 | return Genotype( 183 | normal=normal_cell_decoded, 184 | normal_concat=[2, 3, 4, 5], 185 | reduce=reduce_cell_decoded, 186 | reduce_concat=[2, 3, 4, 5], 187 | ) 188 | 189 | 190 | # print(parse_searched_cell((5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0, 5, 6, 4, 0, 5, 0, 6, 0, 4, 0, 1, 5, 0, 0))) 191 | # cand=[[5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0], [5, 6, 4, 0, 5, 0, 6, 0, 4, 0, 1, 5, 0, 0]] 192 | # NOTE: [5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0] means 193 | # NOTE: [5, 4]: for generating intermediate node 0, operation 5: sep_conv_5x5(k-2 node (prev prev cell output)) + operation 4: sep_conv_3x3(k-1 node (prev cell output)) 194 | # NOTE: [6, 0, 5]: for generating intermediate node 1, operation 6: dil_conv_3x3(k-2 node) + operation 5: sep_conv_5x5(itm node 0) 195 | # NOTE: [4, 0, 0, 7]: for generating intermediate node 2, operation 4: sep_conv_3x3(k-2 node) + operation 7: dil_conv_5x5(itm node 1) 196 | # NOTE: [0, 7, 6, 0, 0]: for generating intermediate node 3, operation 7: dil_conv_5x5(k-1 node) + operation 6: dil_conv_3x3(itm node 0) 197 | # NOTE: all intermediate node outputs (0, 1, 2, 3) are concatenated to be the output of the current cell.
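# (annotation, illustrative check) decoding the example candidate above with
# parse_searched_cell reproduces exactly the normal cell of the commented RLDARTS
# genotype below:
#   parse_searched_cell((5, 4, 6, 0, 5, 4, 0, 0, 7, 0, 7, 6, 0, 0,
#                        5, 6, 4, 0, 5, 0, 6, 0, 4, 0, 1, 5, 0, 0)).normal
#   == [('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_5x5', 2),
#       ('sep_conv_3x3', 0), ('dil_conv_5x5', 3), ('dil_conv_5x5', 1), ('dil_conv_3x3', 2)]
# where index 0 = k-2 node, 1 = k-1 node, and 2+ are intermediate nodes.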
198 | # RLDARTS = Genotype( 199 | # normal=[('sep_conv_5x5', 0), ('sep_conv_3x3', 1), ('dil_conv_3x3', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), 200 | # ('dil_conv_5x5', 3), ('dil_conv_5x5', 1), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], 201 | # reduce=[('sep_conv_5x5', 0), ('dil_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 2), ('dil_conv_3x3', 1), 202 | # ('sep_conv_3x3', 3), ('max_pool_3x3', 1), ('sep_conv_5x5', 2)], reduce_concat=[2, 3, 4, 5]) 203 | RLDARTS_OURS_GT = parse_searched_cell( 204 | (5, 5, 2, 5, 0, 4, 4, 0, 0, 0, 4, 0, 0, 4, 3, 3, 0, 3, 2, 3, 7, 0, 0, 3, 0, 0, 0, 4) 205 | ) 206 | PCDARTS_OURS_SEARCHEPOCH40 = Genotype( 207 | normal=[ 208 | ("sep_conv_3x3", 0), 209 | ("sep_conv_3x3", 1), 210 | ("sep_conv_3x3", 1), 211 | ("sep_conv_5x5", 0), 212 | ("sep_conv_5x5", 1), 213 | ("sep_conv_3x3", 3), 214 | ("sep_conv_5x5", 4), 215 | ("sep_conv_3x3", 0), 216 | ], 217 | normal_concat=range(2, 6), 218 | reduce=[ 219 | ("max_pool_3x3", 0), 220 | ("sep_conv_3x3", 1), 221 | ("max_pool_3x3", 0), 222 | ("skip_connect", 1), 223 | ("sep_conv_5x5", 2), 224 | ("sep_conv_3x3", 0), 225 | ("skip_connect", 0), 226 | ("dil_conv_3x3", 4), 227 | ], 228 | reduce_concat=range(2, 6), 229 | ) 230 | 231 | PCDARTS_OURS = Genotype( 232 | normal=[ 233 | ("dil_conv_5x5", 1), 234 | ("dil_conv_3x3", 0), 235 | ("dil_conv_5x5", 1), 236 | ("max_pool_3x3", 0), 237 | ("dil_conv_3x3", 0), 238 | ("sep_conv_3x3", 3), 239 | ("sep_conv_3x3", 0), 240 | ("sep_conv_3x3", 2), 241 | ], 242 | normal_concat=range(2, 6), 243 | reduce=[ 244 | ("max_pool_3x3", 0), 245 | ("max_pool_3x3", 1), 246 | ("max_pool_3x3", 0), 247 | ("skip_connect", 2), 248 | ("sep_conv_5x5", 2), 249 | ("skip_connect", 1), 250 | ("sep_conv_3x3", 0), 251 | ("sep_conv_3x3", 2), 252 | ], 253 | reduce_concat=range(2, 6), 254 | ) 255 | 256 | # PDARTS searched on CIFAR-10 257 | PDARTS_CIFAR10 = Genotype( 258 | normal=[ 259 | ("skip_connect", 0), 260 | ("dil_conv_3x3", 1), 261 | ("skip_connect", 0), 262 | ("sep_conv_3x3", 1), 263 | ("sep_conv_3x3", 1), 264 | ("sep_conv_3x3", 3), 265 | ("sep_conv_3x3", 0), 266 | ("dil_conv_5x5", 4), 267 | ], 268 | normal_concat=range(2, 6), 269 | reduce=[ 270 | ("avg_pool_3x3", 0), 271 | ("sep_conv_5x5", 1), 272 | ("sep_conv_3x3", 0), 273 | ("dil_conv_5x5", 2), 274 | ("max_pool_3x3", 0), 275 | ("dil_conv_3x3", 1), 276 | ("dil_conv_3x3", 1), 277 | ("dil_conv_5x5", 3), 278 | ], 279 | reduce_concat=range(2, 6), 280 | ) 281 | 282 | 283 | # DARTS-v1 searched on CIFAR-100 284 | DARTS_V1_CIFAR100 = Genotype( 285 | normal=[ 286 | ("skip_connect", 0), 287 | ("sep_conv_3x3", 1), 288 | ("skip_connect", 0), 289 | ("sep_conv_3x3", 1), 290 | ("skip_connect", 0), 291 | ("skip_connect", 1), 292 | ("skip_connect", 0), 293 | ("skip_connect", 1), 294 | ], 295 | normal_concat=range(2, 6), 296 | reduce=[ 297 | ("avg_pool_3x3", 0), 298 | ("avg_pool_3x3", 1), 299 | ("avg_pool_3x3", 0), 300 | ("skip_connect", 2), 301 | ("skip_connect", 2), 302 | ("avg_pool_3x3", 0), 303 | ("skip_connect", 2), 304 | ("avg_pool_3x3", 0), 305 | ], 306 | reduce_concat=range(2, 6), 307 | ) 308 | 309 | SDARTS_RS_CIFAR10 = Genotype( 310 | normal=[ 311 | ("sep_conv_3x3", 1), 312 | ("sep_conv_3x3", 0), 313 | ("sep_conv_5x5", 1), 314 | ("skip_connect", 0), 315 | ("sep_conv_3x3", 3), 316 | ("skip_connect", 1), 317 | ("sep_conv_3x3", 1), 318 | ("dil_conv_3x3", 2), 319 | ], 320 | normal_concat=range(2, 6), 321 | reduce=[ 322 | ("max_pool_3x3", 0), 323 | ("sep_conv_3x3", 1), 324 | ("skip_connect", 2), 325 | ("max_pool_3x3", 0), 326 | ("dil_conv_5x5", 3), 327 | 
("max_pool_3x3", 0), 328 | ("sep_conv_3x3", 2), 329 | ("sep_conv_5x5", 3), 330 | ], 331 | reduce_concat=range(2, 6), 332 | ) 333 | 334 | SDARTS_ADV_CIFAR10 = Genotype( 335 | normal=[ 336 | ("sep_conv_3x3", 0), 337 | ("sep_conv_3x3", 1), 338 | ("sep_conv_3x3", 1), 339 | ("skip_connect", 0), 340 | ("sep_conv_5x5", 0), 341 | ("dil_conv_3x3", 3), 342 | ("dil_conv_3x3", 4), 343 | ("skip_connect", 0), 344 | ], 345 | normal_concat=range(2, 6), 346 | reduce=[ 347 | ("max_pool_3x3", 0), 348 | ("sep_conv_5x5", 1), 349 | ("skip_connect", 2), 350 | ("max_pool_3x3", 0), 351 | ("skip_connect", 3), 352 | ("skip_connect", 2), 353 | ("skip_connect", 2), 354 | ("sep_conv_5x5", 4), 355 | ], 356 | reduce_concat=range(2, 6), 357 | ) 358 | 359 | DROPNAS = Genotype( 360 | normal=[ 361 | ("skip_connect", 0), 362 | ("sep_conv_3x3", 1), 363 | ("sep_conv_3x3", 1), 364 | ("max_pool_3x3", 2), 365 | ("sep_conv_3x3", 1), 366 | ("sep_conv_5x5", 2), 367 | ("sep_conv_5x5", 0), 368 | ("sep_conv_5x5", 1), 369 | ], 370 | normal_concat=[2, 3, 4, 5], 371 | reduce=[ 372 | ("max_pool_3x3", 0), 373 | ("sep_conv_5x5", 1), 374 | ("dil_conv_5x5", 2), 375 | ("sep_conv_5x5", 1), 376 | ("dil_conv_5x5", 2), 377 | ("dil_conv_5x5", 3), 378 | ("dil_conv_5x5", 2), 379 | ("dil_conv_5x5", 4), 380 | ], 381 | reduce_concat=[2, 3, 4, 5], 382 | ) 383 | 384 | GENAS_FLATNESS_CIFAR10 = parse_searched_cell( 385 | (3, 6, 0, 4, 4, 0, 6, 6, 0, 0, 0, 0, 4, 3, 5, 2, 0, 6, 4, 0, 3, 0, 6, 0, 0, 5, 0, 7) 386 | ) 387 | 388 | GENAS_ANGLE_FLATNESS_CIFAR10 = parse_searched_cell( 389 | (5, 4, 4, 4, 0, 4, 5, 0, 0, 4, 0, 0, 0, 4, 7, 7, 2, 3, 0, 4, 1, 0, 0, 0, 0, 0, 5, 6) 390 | ) 391 | 392 | GENAS_FLATNESS_CIFAR100 = parse_searched_cell( 393 | (3, 5, 0, 4, 4, 0, 0, 4, 5, 0, 4, 0, 0, 7, 6, 3, 2, 0, 5, 0, 5, 0, 7, 0, 1, 1, 0, 0) 394 | ) 395 | 396 | GENAS_ANGLE_FLATNESS_CIFAR100 = parse_searched_cell( 397 | (3, 5, 1, 4, 0, 0, 4, 0, 4, 0, 4, 0, 0, 5, 6, 4, 0, 7, 2, 1, 3, 0, 0, 0, 4, 0, 0, 7) 398 | ) -------------------------------------------------------------------------------- /retrain_architecture/model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/model.py 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from operations import * 8 | from torch.autograd import Variable 9 | from utils import drop_path 10 | 11 | 12 | class Cell(nn.Module): 13 | 14 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 15 | super(Cell, self).__init__() 16 | print(C_prev_prev, C_prev, C) 17 | 18 | if reduction_prev: 19 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 20 | else: 21 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) 22 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) 23 | 24 | if reduction: 25 | op_names, indices = zip(*genotype.reduce) 26 | concat = genotype.reduce_concat 27 | else: 28 | op_names, indices = zip(*genotype.normal) 29 | concat = genotype.normal_concat 30 | self._compile(C, op_names, indices, concat, reduction) 31 | 32 | def _compile(self, C, op_names, indices, concat, reduction): 33 | assert len(op_names) == len(indices) 34 | self._steps = len(op_names) // 2 35 | self._concat = concat 36 | self.multiplier = len(concat) 37 | 38 | self._ops = nn.ModuleList() 39 | for name, index in zip(op_names, indices): 40 | stride = 2 if reduction and index < 2 else 1 41 | op = OPS[name](C, stride, True) 42 | self._ops += [op] 43 | self._indices = indices 44 | 
45 | def forward(self, s0, s1, drop_prob): 46 | s0 = self.preprocess0(s0) 47 | s1 = self.preprocess1(s1) 48 | 49 | states = [s0, s1] 50 | for i in range(self._steps): 51 | h1 = states[self._indices[2 * i]] 52 | h2 = states[self._indices[2 * i + 1]] 53 | op1 = self._ops[2 * i] 54 | op2 = self._ops[2 * i + 1] 55 | h1 = op1(h1) 56 | h2 = op2(h2) 57 | if self.training and drop_prob > 0.: 58 | if not isinstance(op1, Identity): 59 | h1 = drop_path(h1, drop_prob) 60 | if not isinstance(op2, Identity): 61 | h2 = drop_path(h2, drop_prob) 62 | s = h1 + h2 63 | states += [s] 64 | return torch.cat([states[i] for i in self._concat], dim=1) 65 | 66 | 67 | class AuxiliaryHeadCIFAR(nn.Module): 68 | 69 | def __init__(self, C, num_classes): 70 | """assuming input size 8x8""" 71 | super(AuxiliaryHeadCIFAR, self).__init__() 72 | self.features = nn.Sequential( 73 | nn.ReLU(inplace=True), 74 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 75 | nn.Conv2d(C, 128, 1, bias=False), 76 | nn.BatchNorm2d(128), 77 | nn.ReLU(inplace=True), 78 | nn.Conv2d(128, 768, 2, bias=False), 79 | nn.BatchNorm2d(768), 80 | nn.ReLU(inplace=True) 81 | ) 82 | self.classifier = nn.Linear(768, num_classes) 83 | 84 | def forward(self, x): 85 | x = self.features(x) 86 | x = self.classifier(x.view(x.size(0), -1)) 87 | return x 88 | 89 | 90 | class AuxiliaryHeadImageNet(nn.Module): 91 | 92 | def __init__(self, C, num_classes): 93 | """assuming input size 14x14""" 94 | super(AuxiliaryHeadImageNet, self).__init__() 95 | self.features = nn.Sequential( 96 | nn.ReLU(inplace=True), 97 | nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), 98 | nn.Conv2d(C, 128, 1, bias=False), 99 | nn.BatchNorm2d(128), 100 | nn.ReLU(inplace=True), 101 | nn.Conv2d(128, 768, 2, bias=False), 102 | # NOTE: This batchnorm was omitted in my earlier implementation due to a typo. 103 | # Commenting it out for consistency with the experiments in the paper. 
104 | # nn.BatchNorm2d(768), 105 | nn.ReLU(inplace=True) 106 | ) 107 | self.classifier = nn.Linear(768, num_classes) 108 | 109 | def forward(self, x): 110 | x = self.features(x) 111 | x = self.classifier(x.view(x.size(0), -1)) 112 | return x 113 | 114 | 115 | class NetworkCIFAR(nn.Module): 116 | 117 | def __init__(self, C, num_classes, layers, auxiliary, genotype): 118 | super(NetworkCIFAR, self).__init__() 119 | self._layers = layers 120 | self._auxiliary = auxiliary 121 | 122 | stem_multiplier = 3 123 | C_curr = stem_multiplier * C 124 | self.stem = nn.Sequential( 125 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False), 126 | nn.BatchNorm2d(C_curr) 127 | ) 128 | 129 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 130 | self.cells = nn.ModuleList() 131 | reduction_prev = False 132 | for i in range(layers): 133 | if i in [layers // 3, 2 * layers // 3]: 134 | C_curr *= 2 135 | reduction = True 136 | else: 137 | reduction = False 138 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 139 | reduction_prev = reduction 140 | self.cells += [cell] 141 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 142 | if i == 2 * layers // 3: 143 | C_to_auxiliary = C_prev 144 | 145 | if auxiliary: 146 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes) 147 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 148 | self.classifier = nn.Linear(C_prev, num_classes) 149 | 150 | def forward(self, input): 151 | logits_aux = None 152 | s0 = s1 = self.stem(input) 153 | for i, cell in enumerate(self.cells): 154 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 155 | if i == 2 * self._layers // 3: 156 | if self._auxiliary and self.training: 157 | logits_aux = self.auxiliary_head(s1) 158 | out = self.global_pooling(s1) 159 | logits = self.classifier(out.view(out.size(0), -1)) 160 | return logits, logits_aux 161 | 162 | 163 | class NetworkImageNet(nn.Module): 164 | 165 | def __init__(self, C, num_classes, layers, auxiliary, genotype): 166 | super(NetworkImageNet, self).__init__() 167 | self._layers = layers 168 | self._auxiliary = auxiliary 169 | 170 | self.stem0 = nn.Sequential( 171 | nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), 172 | nn.BatchNorm2d(C // 2), 173 | nn.ReLU(inplace=True), 174 | nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), 175 | nn.BatchNorm2d(C), 176 | ) 177 | 178 | self.stem1 = nn.Sequential( 179 | nn.ReLU(inplace=True), 180 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), 181 | nn.BatchNorm2d(C), 182 | ) 183 | 184 | C_prev_prev, C_prev, C_curr = C, C, C 185 | 186 | self.cells = nn.ModuleList() 187 | reduction_prev = True 188 | for i in range(layers): 189 | if i in [layers // 3, 2 * layers // 3]: 190 | C_curr *= 2 191 | reduction = True 192 | else: 193 | reduction = False 194 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 195 | reduction_prev = reduction 196 | self.cells += [cell] 197 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 198 | if i == 2 * layers // 3: 199 | C_to_auxiliary = C_prev 200 | 201 | if auxiliary: 202 | self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) 203 | self.global_pooling = nn.AvgPool2d(7) 204 | self.classifier = nn.Linear(C_prev, num_classes) 205 | 206 | def forward(self, input): 207 | logits_aux = None 208 | s0 = self.stem0(input) 209 | s1 = self.stem1(s0) 210 | for i, cell in enumerate(self.cells): 211 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 212 | if i == 2 * self._layers // 3: 213 | if self._auxiliary and 
self.training: 214 | logits_aux = self.auxiliary_head(s1) 215 | out = self.global_pooling(s1) 216 | logits = self.classifier(out.view(out.size(0), -1)) 217 | return logits, logits_aux 218 | -------------------------------------------------------------------------------- /retrain_architecture/operations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/operations.py 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | 8 | OPS = { 9 | 'none' : lambda C, stride, affine: Zero(stride), 10 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 11 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 12 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 13 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 14 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 15 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 16 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 17 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 18 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 19 | nn.ReLU(inplace=False), 20 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 21 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 22 | nn.BatchNorm2d(C, affine=affine) 23 | ), 24 | } 25 | 26 | class ReLUConvBN(nn.Module): 27 | 28 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 29 | super(ReLUConvBN, self).__init__() 30 | self.op = nn.Sequential( 31 | nn.ReLU(inplace=False), 32 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 33 | nn.BatchNorm2d(C_out, affine=affine) 34 | ) 35 | 36 | def forward(self, x, rngs=None): 37 | return self.op(x) 38 | 39 | class DilConv(nn.Module): 40 | 41 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 42 | super(DilConv, self).__init__() 43 | self.op = nn.Sequential( 44 | nn.ReLU(inplace=False), 45 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 46 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 47 | nn.BatchNorm2d(C_out, affine=affine), 48 | ) 49 | 50 | def forward(self, x, rngs=None): 51 | return self.op(x) 52 | 53 | class SepConv(nn.Module): 54 | 55 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 56 | super(SepConv, self).__init__() 57 | self.op = nn.Sequential( 58 | nn.ReLU(inplace=False), 59 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 60 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 61 | nn.BatchNorm2d(C_in, affine=affine), 62 | nn.ReLU(inplace=False), 63 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 64 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 65 | nn.BatchNorm2d(C_out, affine=affine), 66 | ) 67 | 68 | def forward(self, x, rngs=None): 69 | return self.op(x) 70 | 71 | 72 | class Identity(nn.Module): 73 | 74 | def __init__(self): 75 | super(Identity, 
self).__init__() 76 | 77 | def forward(self, x, rngs=None): 78 | return x 79 | 80 | class Zero(nn.Module): 81 | 82 | def __init__(self, stride): 83 | super(Zero, self).__init__() 84 | self.stride = stride 85 | def forward(self, x, rngs=None): 86 | n, c, h, w = x.size() 87 | h //= self.stride 88 | w //= self.stride 89 | if x.is_cuda: 90 | with torch.cuda.device(x.get_device()): 91 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0) 92 | else: 93 | padding = torch.FloatTensor(n, c, h, w).fill_(0) 94 | return padding 95 | 96 | class FactorizedReduce(nn.Module): 97 | 98 | def __init__(self, C_in, C_out, affine=True): 99 | super(FactorizedReduce, self).__init__() 100 | assert C_out % 2 == 0 101 | self.relu = nn.ReLU(inplace=False) 102 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 103 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 104 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 105 | 106 | def forward(self, x, rngs=None): 107 | x = self.relu(x) 108 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 109 | out = self.bn(out) 110 | return out 111 | 112 | 113 | -------------------------------------------------------------------------------- /retrain_architecture/retrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import time 5 | import torch 6 | import utils 7 | import glob 8 | import random 9 | import logging 10 | import argparse 11 | import torch.nn as nn 12 | import genotypes 13 | import torch.utils 14 | import torchvision.transforms as transforms 15 | import torch.backends.cudnn as cudnn 16 | import time 17 | import torch.multiprocessing as mp 18 | import torch.distributed as dist 19 | from torch.autograd import Variable 20 | from model import NetworkImageNet as Network 21 | from tensorboardX import SummaryWriter 22 | from thop import profile 23 | import torchvision.datasets as datasets 24 | from config import ( 25 | MASTER_HOST, 26 | MASTER_PORT, 27 | NODE_NUM, 28 | MY_RANK, 29 | GPU_NUM, 30 | ) 31 | 32 | parser = argparse.ArgumentParser("training imagenet") 33 | parser.add_argument( 34 | "--data_root", type=str, required=True, help="imagenet dataset root directory" 35 | ) 36 | parser.add_argument( 37 | "--workers", type=int, default=32, help="number of workers to load dataset" 38 | ) 39 | parser.add_argument("--batch_size", type=int, default=1024, help="batch size") 40 | parser.add_argument( 41 | "--learning_rate", type=float, default=0.5, help="init learning rate" 42 | ) 43 | parser.add_argument("--momentum", type=float, default=0.9, help="momentum") 44 | parser.add_argument("--weight_decay", type=float, default=3e-5, help="weight decay") 45 | parser.add_argument("--report_freq", type=float, default=100, help="report frequency") 46 | parser.add_argument("--epochs", type=int, default=250, help="num of training epochs") 47 | parser.add_argument( 48 | "--init_channels", type=int, default=48, help="num of init channels" 49 | ) 50 | parser.add_argument("--layers", type=int, default=14, help="total number of layers") 51 | parser.add_argument( 52 | "--auxiliary", action="store_true", default=False, help="use auxiliary tower" 53 | ) 54 | parser.add_argument( 55 | "--auxiliary_weight", type=float, default=0.4, help="weight for auxiliary loss" 56 | ) 57 | parser.add_argument( 58 | "--drop_path_prob", type=float, default=0, help="drop path probability" 59 | ) 60 | parser.add_argument("--save", type=str, 
default="test", help="experiment name") 61 | parser.add_argument("--seed", type=int, default=0, help="random seed") 62 | parser.add_argument( 63 | "--arch", type=str, default="PDARTS", help="which architecture to use" 64 | ) 65 | parser.add_argument("--grad_clip", type=float, default=5.0, help="gradient clipping") 66 | parser.add_argument("--label_smooth", type=float, default=0.1, help="label smoothing") 67 | parser.add_argument( 68 | "--lr_scheduler", type=str, default="linear", help="lr scheduler, linear or cosine" 69 | ) 70 | parser.add_argument("--note", type=str, default="try", help="note for this run") 71 | 72 | args, unparsed = parser.parse_known_args() 73 | 74 | # args.save = "eval-{}".format(args.save) 75 | 76 | if not os.path.exists(args.save): 77 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob("*.py")) 78 | 79 | time.sleep(1) 80 | log_format = "%(asctime)s %(message)s" 81 | logging.basicConfig( 82 | stream=sys.stdout, 83 | level=logging.INFO, 84 | format=log_format, 85 | datefmt="%m/%d %I:%M:%S %p", 86 | ) 87 | fh = logging.FileHandler(os.path.join(args.save, "log.txt")) 88 | fh.setFormatter(logging.Formatter(log_format)) 89 | logging.getLogger().addHandler(fh) 90 | writer = SummaryWriter(logdir=args.save) 91 | 92 | IMAGENET_TRAINING_SET_SIZE = 1281167 93 | IMAGENET_TEST_SET_SIZE = 50000 94 | CLASSES = 1000 95 | train_iters = ( 96 | IMAGENET_TRAINING_SET_SIZE // args.batch_size 97 | ) # NOTE: for each training iteration, all gpus on multiple nodes take args.batch_size (1024) // (# gpu per node (4)* # node (2)) = 128 imgs, which are gathered to be args.batch_size=1024 in DistributedDataParallel. 98 | val_iters = ( 99 | IMAGENET_TEST_SET_SIZE // args.batch_size 100 | ) # NOTE: Without DistributedDataParallel. Thus, single GPU (gpu id = 0 per node) takes args.batch_size = 1024 imgs. 101 | 102 | # Average loss across processes for logging. 103 | def reduce_tensor(tensor, device=0, world_size=1): 104 | tensor = tensor.clone() 105 | dist.reduce(tensor, device) 106 | tensor.div_(world_size) 107 | return tensor 108 | 109 | 110 | class CrossEntropyLabelSmooth(nn.Module): 111 | def __init__(self, num_classes, epsilon): 112 | super(CrossEntropyLabelSmooth, self).__init__() 113 | self.num_classes = num_classes 114 | self.epsilon = epsilon 115 | self.logsoftmax = nn.LogSoftmax(dim=1) 116 | 117 | def forward(self, inputs, targets): 118 | log_probs = self.logsoftmax(inputs) 119 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 120 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 121 | loss = (-targets * log_probs).mean(0).sum() 122 | return loss 123 | 124 | 125 | def main(local_rank, *args): 126 | # NOTE: local_rank is reserved from mp.spawn, which denotes local gpu id inside current node. 127 | args = args[0] # NOTE: take arguments 128 | if not torch.cuda.is_available(): 129 | logging.info("No GPU device available") 130 | sys.exit(1) 131 | 132 | num_gpus_per_node = GPU_NUM # NOTE: num gpus per node. 133 | 134 | np.random.seed(args.seed) 135 | cudnn.benchmark = True 136 | cudnn.deterministic = True 137 | 138 | torch.manual_seed(args.seed) 139 | cudnn.enabled = True 140 | torch.cuda.manual_seed(args.seed) 141 | logging.info("args = %s", args) 142 | logging.info("unparsed_args = %s", unparsed) 143 | 144 | n_nodes = NODE_NUM 145 | args.world_size = n_nodes * num_gpus_per_node 146 | assert ( 147 | args.world_size == 8 148 | ), "world_size is not 8." 
# for reproducibility 149 | args.dist_url = "tcp://{}:{}".format(MASTER_HOST, MASTER_PORT) 150 | args.distributed = args.world_size > 1 # NOTE: whether distributed training is used or not 151 | os.environ["NCCL_DEBUG"] = "info" 152 | # os.environ["NCCL_SOCKET_IFNAME"] = "bond0" 153 | print("master addr: {} with {} node(s)".format(args.dist_url, n_nodes)) 154 | 155 | global_rank = ( 156 | num_gpus_per_node * MY_RANK + local_rank 157 | ) # global gpu id over all gpus across all nodes 158 | # NOTE: init DDP connection 159 | torch.distributed.init_process_group( 160 | backend="nccl", 161 | init_method=args.dist_url, 162 | world_size=args.world_size, 163 | rank=global_rank, 164 | ) 165 | print("init process group finished...") 166 | 167 | # reset batch size according to the total number of processes over all nodes 168 | args.batch_size = ( 169 | args.batch_size // args.world_size 170 | ) # 1024 (original batch_size) // 8 (# total processes over all nodes) = 128 171 | 172 | # Data loading 173 | traindir = os.path.join(args.data_root, "train") 174 | valdir = os.path.join(args.data_root, "val") 175 | train_transform = utils.get_train_transform() 176 | eval_transform = utils.get_eval_transform() 177 | print("train dataset preparing...") 178 | train_dataset = datasets.ImageFolder(root=traindir, transform=train_transform) 179 | print("train dataset prepared...") 180 | val_dataset = datasets.ImageFolder(root=valdir, transform=eval_transform) 181 | print("val dataset prepared...") 182 | 183 | if args.distributed: 184 | # NOTE: a train_sampler is assigned to each process across all nodes 185 | train_sampler = torch.utils.data.distributed.DistributedSampler( 186 | train_dataset, num_replicas=args.world_size, rank=global_rank 187 | ) 188 | else: 189 | train_sampler = None 190 | 191 | # NOTE: for each training iteration, each gpu across the nodes takes args.batch_size (1024) // (# gpus per node (4) * # nodes (2)) = 128 imgs; DistributedDataParallel all-reduces the gradients, so the effective batch size per step is args.batch_size = 1024. 192 | train_loader = torch.utils.data.DataLoader( 193 | train_dataset, 194 | batch_size=args.batch_size, 195 | shuffle=(train_sampler is None), 196 | num_workers=args.workers // args.world_size, 197 | pin_memory=True, 198 | sampler=train_sampler, 199 | ) 200 | 201 | # NOTE: the val_loader is not wrapped by DistributedDataParallel; only a single GPU (gpu id = 0 per node) iterates it, using the already-divided per-process batch size.
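# A minimal sketch of the batch-size arithmetic used above (illustrative names,
# not variables of this file), assuming the 2-node x 4-GPU setup from config.py:
#
#     global_batch = 1024                        # args.batch_size before the reset above
#     world_size   = 2 * 4                       # NODE_NUM * GPU_NUM = 8
#     per_gpu      = global_batch // world_size  # 1024 // 8 = 128 images per process per step
#
# Each DDP process loads per_gpu images and gradients are all-reduced, so one
# optimizer step still covers 1024 images. The lone validating GPU, by contrast,
# walks the full 50000-image set with the divided batch size of 128, which is
# why infer() below logs roughly val_iters * world_size total steps.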
202 | val_loader = torch.utils.data.DataLoader( 203 | val_dataset, 204 | batch_size=args.batch_size, 205 | shuffle=False, 206 | num_workers=args.workers // args.world_size, 207 | pin_memory=True, 208 | ) 209 | 210 | genotype = eval("genotypes.%s" % args.arch) 211 | logging.info("---------Genotype---------") 212 | logging.info(genotype) 213 | logging.info("--------------------------") 214 | torch.cuda.set_device( 215 | local_rank 216 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 217 | model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype) 218 | model = model.cuda( 219 | local_rank 220 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 221 | print( 222 | "local rank: ", 223 | local_rank, 224 | "model deployed on : ", 225 | next(model.parameters()).device, 226 | ) 227 | # model = apex.parallel.DistributedDataParallel(model, delay_allreduce=True) 228 | # model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], 229 | # output_device=args.local_rank, broadcast_buffers=False) 230 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) 231 | model_profile = Network( 232 | args.init_channels, CLASSES, args.layers, args.auxiliary, genotype 233 | ) 234 | model_profile = model_profile.cuda( 235 | local_rank 236 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 237 | model_input_size_imagenet = (1, 3, 224, 224) 238 | model_profile.drop_path_prob = 0 239 | flops, _ = profile(model_profile, model_input_size_imagenet) 240 | logging.info( 241 | "flops = %fM, param size = %fM", 242 | flops / 1e6, 243 | utils.count_parameters_in_MB(model), 244 | ) 245 | 246 | criterion = nn.CrossEntropyLoss() 247 | criterion = criterion.cuda( 248 | local_rank 249 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 250 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 251 | criterion_smooth = criterion_smooth.cuda( 252 | local_rank 253 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 254 | 255 | optimizer = torch.optim.SGD( 256 | model.parameters(), 257 | args.learning_rate, 258 | momentum=args.momentum, 259 | weight_decay=args.weight_decay, 260 | ) 261 | 262 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 263 | optimizer, float(args.epochs) 264 | ) 265 | 266 | start_epoch = 0 267 | best_acc_top1 = 0 268 | best_acc_top5 = 0 269 | checkpoint_tar = os.path.join(args.save, "checkpoint.pth.tar") 270 | if os.path.exists(checkpoint_tar): 271 | logging.info("loading checkpoint {} ..........".format(checkpoint_tar)) 272 | checkpoint = torch.load( 273 | checkpoint_tar, map_location={"cuda:0": "cuda:{}".format(local_rank)} 274 | ) 275 | start_epoch = checkpoint["epoch"] + 1 276 | model.load_state_dict(checkpoint["state_dict"]) 277 | logging.info( 278 | "loaded checkpoint {} epoch = {}".format( 279 | checkpoint_tar, checkpoint["epoch"] 280 | ) 281 | ) 282 | 283 | for epoch in range(start_epoch, args.epochs): 284 | if args.distributed: 285 | train_sampler.set_epoch(epoch) 286 | if args.lr_scheduler == "cosine": 287 | scheduler.step() 288 | current_lr = scheduler.get_lr()[0] 289 | elif args.lr_scheduler == "linear": 290 | current_lr = adjust_lr(optimizer, epoch) 291 | else: 292 | logging.info("Wrong lr type, exit") 293 | sys.exit(1) 294 | 295 | logging.info("Epoch: %d lr %e", epoch, current_lr) 296 | if epoch < 5: 297 | for param_group in optimizer.param_groups: 298 | param_group["lr"] = current_lr * (epoch + 1) / 5.0 299 | 
logging.info( 300 | "Warming-up Epoch: %d, LR: %e", epoch, current_lr * (epoch + 1) / 5.0 301 | ) 302 | 303 | model.module.drop_path_prob = args.drop_path_prob * epoch / args.epochs 304 | epoch_start = time.time() 305 | train_acc, train_obj = train( 306 | train_loader, 307 | model, 308 | criterion_smooth, 309 | optimizer, 310 | epoch, 311 | local_rank, 312 | args.world_size, 313 | ) 314 | 315 | writer.add_scalar("Train/Loss", train_obj, epoch) 316 | writer.add_scalar("Train/LR", current_lr, epoch) 317 | 318 | # NOTE: only gpu id == 0 in the current node executes the infer function, 319 | # NOTE: while the other processes in the current node wait for gpu process id 0 to finish it. 320 | # NOTE: once gpu id == 0 has finished infer, the next epoch's train function is executed over all gpus on all distributed nodes. 321 | if local_rank == 0: 322 | valid_acc_top1, valid_acc_top5, valid_obj = infer( 323 | val_loader, model.module, criterion, epoch, local_rank, args.world_size 324 | ) 325 | is_best = False 326 | # if valid_acc_top5 > best_acc_top5: 327 | # best_acc_top5 = valid_acc_top5 328 | if valid_acc_top1 > best_acc_top1: 329 | best_acc_top1 = valid_acc_top1 330 | best_acc_top5 = valid_acc_top5 331 | is_best = True 332 | 333 | logging.info("Valid_acc_top1: %f", valid_acc_top1) 334 | logging.info("Valid_acc_top5: %f", valid_acc_top5) 335 | logging.info("best_acc_top1: %f", best_acc_top1) 336 | logging.info("best_acc_top5: %f", best_acc_top5) 337 | epoch_duration = time.time() - epoch_start 338 | logging.info("Epoch time: %ds.", epoch_duration) 339 | 340 | utils.save_checkpoint( 341 | { 342 | "epoch": epoch, 343 | "state_dict": model.state_dict(), 344 | "best_acc_top1": best_acc_top1, 345 | "optimizer": optimizer.state_dict(), 346 | }, 347 | is_best, 348 | args.save, 349 | ) 350 | 351 | 352 | def adjust_lr(optimizer, epoch): 353 | # Use a smaller slope for the last 5 epochs, because lr * 1/250 is still relatively large 354 | if args.epochs - epoch > 5: 355 | lr = args.learning_rate * (args.epochs - 5 - epoch) / (args.epochs - 5) 356 | else: 357 | lr = args.learning_rate * (args.epochs - epoch) / ((args.epochs - 5) * 5) 358 | for param_group in optimizer.param_groups: 359 | param_group["lr"] = lr 360 | return lr 361 | 362 | 363 | def train(train_loader, model, criterion, optimizer, epoch, local_rank, world_size): 364 | objs = utils.AvgrageMeter() 365 | top1 = utils.AvgrageMeter() 366 | top5 = utils.AvgrageMeter() 367 | batch_time = utils.AvgrageMeter() 368 | model.train() 369 | 370 | for i, (image, target) in enumerate(train_loader): 371 | # image: [128 (1024 // 8 (num total gpus))] 372 | t0 = time.time() 373 | image = image.cuda( 374 | local_rank, non_blocking=True 375 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 376 | target = target.cuda( 377 | local_rank, non_blocking=True 378 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 379 | datatime = time.time() - t0 380 | 381 | b_start = time.time() 382 | logits, logits_aux = model(image) 383 | optimizer.zero_grad() 384 | loss = criterion(logits, target) 385 | if args.auxiliary: 386 | loss_aux = criterion(logits_aux, target) 387 | loss += args.auxiliary_weight * loss_aux 388 | loss_reduce = reduce_tensor(loss, 0, world_size) 389 | loss.backward() 390 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 391 | optimizer.step() 392 | batch_time.update(time.time() - b_start) 393 | 394 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 395 | n = image.size(0) 396 | objs.update(loss_reduce.data.item(), n) 397 | 
top1.update(prec1.data.item(), n) 398 | top5.update(prec5.data.item(), n) 399 | 400 | if i % args.report_freq == 0 and local_rank == 0: 401 | logging.info( 402 | "TRAIN Step: %03d/%03d Objs: %e R1: %f R5: %f BTime: %.3fs Datatime: %.3f", 403 | i, 404 | train_iters, 405 | objs.avg, 406 | top1.avg, 407 | top5.avg, 408 | batch_time.avg, 409 | float(datatime), 410 | ) 411 | 412 | return top1.avg, objs.avg 413 | 414 | 415 | def infer(val_loader, model, criterion, epoch, local_rank, world_size): 416 | objs = utils.AvgrageMeter() 417 | top1 = utils.AvgrageMeter() 418 | top5 = utils.AvgrageMeter() 419 | model.eval() 420 | 421 | for i, (image, target) in enumerate(val_loader): 422 | t0 = time.time() 423 | image = image.cuda( 424 | local_rank, non_blocking=True 425 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 426 | target = target.cuda( 427 | local_rank, non_blocking=True 428 | ) # NOTE: enforce using current assigned gpu id given by mp.spawn 429 | datatime = time.time() - t0 430 | 431 | with torch.no_grad(): 432 | logits, _ = model(image) 433 | loss = criterion(logits, target) 434 | 435 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 436 | n = image.size(0) 437 | objs.update(loss.data.item(), n) 438 | top1.update(prec1.data.item(), n) 439 | top5.update(prec5.data.item(), n) 440 | 441 | if i % args.report_freq == 0: 442 | logging.info( 443 | "[%03d] VALID Step: %03d/%03d Objs: %e R1: %f R5: %f Datatime: %.3f", 444 | epoch, 445 | i, 446 | val_iters * world_size, 447 | objs.avg, 448 | top1.avg, 449 | top5.avg, 450 | float(datatime), 451 | ) 452 | 453 | return top1.avg, top5.avg, objs.avg 454 | 455 | 456 | if __name__ == "__main__": 457 | mp.spawn(main, (args,), nprocs=int(GPU_NUM), join=True) # GPU_NUM: # gpus per node. 458 | -------------------------------------------------------------------------------- /retrain_architecture/thop/__init__.py: -------------------------------------------------------------------------------- 1 | from .profile import profile -------------------------------------------------------------------------------- /retrain_architecture/thop/count_hooks.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/count_hooks.py 3 | ''' 4 | import argparse 5 | 6 | import torch 7 | import torch.nn as nn 8 | import numpy as np 9 | multiply_adds = 1 10 | num = 0 11 | 12 | def count_ABN(m, x, y): 13 | x = x[0] 14 | 15 | # bn 16 | nelements = x.numel() 17 | # subtract, divide, gamma, beta + relu 18 | total_ops = 4 * nelements 19 | m.total_ops = torch.Tensor([int(total_ops)]) 20 | for p in m.parameters(): 21 | m.total_params += torch.Tensor([p.numel()]) 22 | 23 | def count_convNd(m, x, y): 24 | x = x[0] 25 | cin = m.in_channels 26 | # batch_size = x.size(0) 27 | 28 | kernel_ops = m.weight.size()[2:].numel() 29 | bias_ops = 1 if m.bias is not None else 0 30 | ops_per_element = kernel_ops + bias_ops 31 | output_elements = y.nelement() 32 | 33 | # cout x oW x oH 34 | total_ops = cin * output_elements * ops_per_element // m.groups 35 | m.total_ops = torch.Tensor([int(total_ops)]) 36 | for p in m.parameters(): 37 | m.total_params += torch.Tensor([p.numel()]) 38 | 39 | def count_conv2d(m, x, y): 40 | x = x[0] 41 | 42 | cin = m.in_channels 43 | cout = m.out_channels 44 | kh, kw = m.kernel_size 45 | batch_size = x.size()[0] 46 | 47 | out_h = y.size(2) 48 | out_w = y.size(3) 49 | 50 | # ops per output element 
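# A worked instance of the count computed below, with illustrative values (not
# taken from any model in this repo): a 3x3 conv with cin=64, cout=128,
# groups=1, no bias, producing a 32x32 output map at batch_size=1:
#
#     kernel_ops      = multiply_adds * kh * kw = 1 * 3 * 3 = 9
#     output_elements = 1 * 32 * 32 * 128       = 131072
#     total_ops       = 131072 * 9 * 64 // 1    = 75497472  (~75.5M)
#
# Because multiply_adds = 1, a fused multiply-add counts as a single op, so
# these totals are MACs; raw FLOP counts would be roughly twice as large.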
51 | # kernel_mul = kh * kw * cin 52 | # kernel_add = kh * kw * cin - 1 53 | kernel_ops = multiply_adds * kh * kw 54 | bias_ops = 1 if m.bias is not None else 0 55 | ops_per_element = kernel_ops + bias_ops 56 | 57 | # total ops 58 | # num_out_elements = y.numel() 59 | output_elements = batch_size * out_w * out_h * cout 60 | total_ops = output_elements * ops_per_element * cin // m.groups 61 | m.total_ops = torch.Tensor([int(total_ops)]) 62 | for p in m.parameters(): 63 | m.total_params += torch.Tensor([p.numel()]) 64 | 65 | def count_convtranspose2d(m, x, y): 66 | x = x[0] 67 | 68 | cin = m.in_channels 69 | cout = m.out_channels 70 | kh, kw = m.kernel_size 71 | # batch_size = x.size()[0] 72 | 73 | out_h = y.size(2) 74 | out_w = y.size(3) 75 | 76 | # ops per output element 77 | # kernel_mul = kh * kw * cin 78 | # kernel_add = kh * kw * cin - 1 79 | kernel_ops = multiply_adds * kh * kw * cin // m.groups 80 | bias_ops = 1 if m.bias is not None else 0 81 | ops_per_element = kernel_ops + bias_ops 82 | 83 | # total ops 84 | # num_out_elements = y.numel() 85 | # output_elements = batch_size * out_w * out_h * cout 86 | ops_per_element = m.weight.nelement() 87 | output_elements = y.nelement() 88 | total_ops = output_elements * ops_per_element 89 | 90 | m.total_ops = torch.Tensor([int(total_ops)]) 91 | for p in m.parameters(): 92 | m.total_params += torch.Tensor([p.numel()]) 93 | 94 | def count_bn(m, x, y): 95 | x = x[0] 96 | 97 | nelements = x.numel() 98 | # subtract, divide, gamma, beta 99 | total_ops = 4 * nelements 100 | 101 | m.total_ops = torch.Tensor([int(total_ops)]) 102 | for p in m.parameters(): 103 | m.total_params += torch.Tensor([p.numel()]) 104 | 105 | def count_relu(m, x, y): 106 | x = x[0] 107 | 108 | nelements = x.numel() 109 | total_ops = nelements 110 | 111 | m.total_ops = torch.Tensor([int(total_ops)]) 112 | for p in m.parameters(): 113 | m.total_params += torch.Tensor([p.numel()]) 114 | 115 | def count_softmax(m, x, y): 116 | x = x[0] 117 | 118 | batch_size, nfeatures = x.size() 119 | 120 | total_exp = nfeatures 121 | total_add = nfeatures - 1 122 | total_div = nfeatures 123 | total_ops = batch_size * (total_exp + total_add + total_div) 124 | 125 | m.total_ops = torch.Tensor([int(total_ops)]) 126 | for p in m.parameters(): 127 | m.total_params += torch.Tensor([p.numel()]) 128 | 129 | def count_maxpool(m, x, y): 130 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) 131 | num_elements = y.numel() 132 | total_ops = kernel_ops * num_elements 133 | 134 | m.total_ops = torch.Tensor([int(total_ops)]) 135 | for p in m.parameters(): 136 | m.total_params += torch.Tensor([p.numel()]) 137 | 138 | def count_adap_maxpool(m, x, y): 139 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 140 | kernel_ops = torch.prod(kernel) 141 | num_elements = y.numel() 142 | total_ops = kernel_ops * num_elements 143 | 144 | m.total_ops = torch.Tensor([int(total_ops)]) 145 | for p in m.parameters(): 146 | m.total_params += torch.Tensor([p.numel()]) 147 | 148 | def count_avgpool(m, x, y): 149 | total_add = torch.prod(torch.Tensor([m.kernel_size])) 150 | total_div = 1 151 | kernel_ops = total_add + total_div 152 | num_elements = y.numel() 153 | total_ops = kernel_ops * num_elements 154 | 155 | m.total_ops = torch.Tensor([int(total_ops)]) 156 | for p in m.parameters(): 157 | m.total_params += torch.Tensor([p.numel()]) 158 | 159 | def count_adap_avgpool(m, x, y): 160 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 161 
| total_add = torch.prod(kernel) 162 | total_div = 1 163 | kernel_ops = total_add + total_div 164 | num_elements = y.numel() 165 | total_ops = kernel_ops * num_elements 166 | 167 | m.total_ops = torch.Tensor([int(total_ops)]) 168 | for p in m.parameters(): 169 | m.total_params += torch.Tensor([p.numel()]) 170 | 171 | def count_linear(m, x, y): 172 | # per output element 173 | total_mul = m.in_features 174 | total_add = m.in_features - 1 175 | num_elements = y.numel() 176 | total_ops = (total_mul + total_add) * num_elements 177 | 178 | m.total_ops = torch.Tensor([int(total_ops)]) 179 | for p in m.parameters(): 180 | m.total_params += torch.Tensor([p.numel()]) -------------------------------------------------------------------------------- /retrain_architecture/thop/profile.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/profile.py 3 | ''' 4 | import logging 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn.modules.conv import _ConvNd 9 | from .count_hooks import * 10 | 11 | register_hooks = { 12 | nn.Conv1d: count_convNd, 13 | nn.Conv2d: count_convNd, 14 | nn.Conv3d: count_convNd, 15 | nn.ConvTranspose2d: count_convtranspose2d, 16 | 17 | # nn.BatchNorm1d: count_bn, 18 | # nn.BatchNorm2d: count_bn, 19 | # nn.BatchNorm3d: count_bn, 20 | 21 | # # nn.ReLU: count_relu, 22 | # # nn.ReLU6: count_relu, 23 | # # nn.LeakyReLU: count_relu, 24 | 25 | # nn.MaxPool1d: count_maxpool, 26 | # nn.MaxPool2d: count_maxpool, 27 | # nn.MaxPool3d: count_maxpool, 28 | # nn.AdaptiveMaxPool1d: count_adap_maxpool, 29 | # nn.AdaptiveMaxPool2d: count_adap_maxpool, 30 | # nn.AdaptiveMaxPool3d: count_adap_maxpool, 31 | 32 | # nn.AvgPool1d: count_avgpool, 33 | # nn.AvgPool2d: count_avgpool, 34 | # nn.AvgPool3d: count_avgpool, 35 | 36 | # nn.AdaptiveAvgPool1d: count_adap_avgpool, 37 | # nn.AdaptiveAvgPool2d: count_adap_avgpool, 38 | # nn.AdaptiveAvgPool3d: count_adap_avgpool, 39 | nn.Linear: count_linear, 40 | nn.Dropout: None, 41 | } 42 | 43 | 44 | def profile(model, input_size, custom_ops={}, device="cpu"): 45 | handler_collection = [] 46 | 47 | def add_hooks(m): 48 | if len(list(m.children())) > 0: 49 | return 50 | 51 | m.register_buffer('total_ops', torch.zeros(1)) 52 | m.register_buffer('total_params', torch.zeros(1)) 53 | 54 | # for p in m.parameters(): 55 | # m.total_params += torch.Tensor([p.numel()]) 56 | m_type = type(m) 57 | fn = None 58 | 59 | if m_type in custom_ops: 60 | fn = custom_ops[m_type] 61 | elif m_type in register_hooks: 62 | fn = register_hooks[m_type] 63 | else: 64 | #print("Not implemented for ", m) 65 | pass 66 | 67 | if fn is not None: 68 | #print("Register FLOP counter for module %s" % str(m)) 69 | handler = m.register_forward_hook(fn) 70 | handler_collection.append(handler) 71 | 72 | # original_device = model.parameters().__next__().device 73 | training = model.training 74 | 75 | model.eval().to(device) 76 | model.apply(add_hooks) 77 | x = torch.zeros(input_size).to(device) 78 | with torch.no_grad(): 79 | model(x) 80 | 81 | total_ops = 0 82 | total_params = 0 83 | for m in model.modules(): 84 | if len(list(m.children())) > 0: # skip for non-leaf module 85 | continue 86 | total_ops += m.total_ops 87 | total_params += m.total_params 88 | total_ops = total_ops.item() 89 | total_params = total_params.item() 90 | 91 | # model.train(training).to(original_device) 92 | for handler in handler_collection: 93 | 
handler.remove() 94 | 95 | return total_ops, total_params 96 | -------------------------------------------------------------------------------- /retrain_architecture/thop/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/thop/utils.py 3 | ''' 4 | 5 | def clever_format(num, format="%.2f"): 6 | if num > 1e12: 7 | return format % (num / 1e12) + "T" 8 | if num > 1e9: 9 | return format % (num / 1e9) + "G" 10 | if num > 1e6: 11 | return format % (num / 1e6) + "M" 12 | if num > 1e3: 13 | return format % (num / 1e3) + "K" 14 | return format % num  # fall back for values <= 1e3, so the function never returns None -------------------------------------------------------------------------------- /retrain_architecture/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/retrain_architecture/utils.py 3 | ''' 4 | 5 | import os 6 | import numpy as np 7 | import torch 8 | import shutil 9 | import torchvision.transforms as transforms 10 | from torch.autograd import Variable 11 | import cv2 12 | import random 13 | import PIL 14 | from PIL import Image 15 | import math 16 | 17 | class AvgrageMeter(object): 18 | 19 | def __init__(self): 20 | self.reset() 21 | 22 | def reset(self): 23 | self.avg = 0 24 | self.sum = 0 25 | self.cnt = 0 26 | 27 | def update(self, val, n=1): 28 | self.sum += val * n 29 | self.cnt += n 30 | self.avg = self.sum / self.cnt 31 | 32 | 33 | def accuracy(output, target, topk=(1,)): 34 | maxk = max(topk) 35 | batch_size = target.size(0) 36 | 37 | _, pred = output.topk(maxk, 1, True, True) 38 | pred = pred.t() 39 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 40 | 41 | res = [] 42 | for k in topk: 43 | correct_k = correct[:k].reshape(-1).float().sum(0) 44 | res.append(correct_k.mul_(100.0/batch_size)) 45 | return res 46 | 47 | 48 | class Cutout(object): 49 | def __init__(self, length): 50 | self.length = length 51 | 52 | def __call__(self, img): 53 | h, w = img.size(1), img.size(2) 54 | mask = np.ones((h, w), np.float32) 55 | y = np.random.randint(h) 56 | x = np.random.randint(w) 57 | 58 | y1 = np.clip(y - self.length // 2, 0, h) 59 | y2 = np.clip(y + self.length // 2, 0, h) 60 | x1 = np.clip(x - self.length // 2, 0, w) 61 | x2 = np.clip(x + self.length // 2, 0, w) 62 | 63 | mask[y1: y2, x1: x2] = 0. 
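# A worked example of the clipping above, with illustrative numbers: for a
# 32x32 image, length=16, and a sampled center (y, x) = (2, 5):
#
#     y1, y2 = clip(2 - 8, 0, 32), clip(2 + 8, 0, 32)  ->  0, 10
#     x1, x2 = clip(5 - 8, 0, 32), clip(5 + 8, 0, 32)  ->  0, 13
#
# so the zeroed patch is 10x13 rather than 16x16: cutouts whose sampled center
# falls near the border only partially occlude the image, mirroring the
# original Cutout implementation.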
64 | mask = torch.from_numpy(mask) 65 | mask = mask.expand_as(img) 66 | img *= mask 67 | return img 68 | 69 | 70 | def _data_transforms_cifar10(args): 71 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] 72 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] 73 | 74 | train_transform = transforms.Compose([ 75 | transforms.RandomCrop(32, padding=4), 76 | transforms.RandomHorizontalFlip(), 77 | transforms.ToTensor(), 78 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 79 | ]) 80 | if args.cutout: 81 | train_transform.transforms.append(Cutout(args.cutout_length)) 82 | 83 | valid_transform = transforms.Compose([ 84 | transforms.ToTensor(), 85 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 86 | ]) 87 | return train_transform, valid_transform 88 | 89 | 90 | def count_parameters_in_MB(model): 91 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 92 | 93 | 94 | def save_checkpoint(state, is_best, save): 95 | filename = os.path.join(save, 'checkpoint.pth.tar') 96 | torch.save(state, filename) 97 | if is_best: 98 | best_filename = os.path.join(save, 'model_best.pth.tar') 99 | shutil.copyfile(filename, best_filename) 100 | 101 | 102 | def save(model, model_path): 103 | torch.save(model.state_dict(), model_path) 104 | 105 | 106 | def load(model, model_path): 107 | model.load_state_dict(torch.load(model_path)) 108 | 109 | 110 | def drop_path(x, drop_prob): 111 | if drop_prob > 0.: 112 | keep_prob = 1.-drop_prob 113 | mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 114 | x.div_(keep_prob) 115 | x.mul_(mask) 116 | return x 117 | 118 | 119 | def create_exp_dir(path, scripts_to_save=None): 120 | if not os.path.exists(path): 121 | os.mkdir(path) 122 | print('Experiment dir : {}'.format(path)) 123 | 124 | if scripts_to_save is not None: 125 | os.mkdir(os.path.join(path, 'scripts')) 126 | for script in scripts_to_save: 127 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 128 | shutil.copyfile(script, dst_file) 129 | 130 | 131 | class OpencvResize(object): 132 | 133 | def __init__(self, size=256): 134 | self.size = size 135 | 136 | def __call__(self, img): 137 | assert isinstance(img, PIL.Image.Image) 138 | img = np.asarray(img) # (H,W,3) RGB 139 | img = img[:, :, ::-1] # 2 BGR 140 | img = np.ascontiguousarray(img) 141 | H, W, _ = img.shape 142 | target_size = (int(self.size / H * W + 0.5), self.size) if H < W else (self.size, int(self.size / W * H + 0.5)) 143 | img = cv2.resize(img, target_size, interpolation=cv2.INTER_LINEAR) 144 | img = img[:, :, ::-1] # 2 RGB 145 | img = np.ascontiguousarray(img) 146 | img = Image.fromarray(img) 147 | return img 148 | 149 | class RandomResizedCrop(object): 150 | 151 | def __init__(self, scale=(0.08, 1.0), target_size: int = 224, max_attempts: int = 10): 152 | assert scale[0] <= scale[1] 153 | self.scale = scale 154 | assert target_size > 0 155 | self.target_size = target_size 156 | assert max_attempts > 0 157 | self.max_attempts = max_attempts 158 | 159 | def __call__(self, img): 160 | assert isinstance(img, PIL.Image.Image) 161 | img = np.asarray(img, dtype=np.uint8) 162 | H, W, C = img.shape 163 | 164 | well_cropped = False 165 | for _ in range(self.max_attempts): 166 | crop_area = (H * W) * random.uniform(self.scale[0], self.scale[1]) 167 | crop_edge = round(math.sqrt(crop_area)) 168 | dH = H - crop_edge 169 | dW = W - crop_edge 170 | crop_left = random.randint(min(dW, 0), max(dW, 0)) 171 | crop_top = random.randint(min(dH, 0), max(dH, 0)) 172 | if dH >= 0 and dW 
>= 0: 173 | well_cropped = True 174 | break 175 | 176 | crop_bottom = crop_top + crop_edge 177 | crop_right = crop_left + crop_edge 178 | if well_cropped: 179 | crop_image = img[crop_top:crop_bottom, :, :][:, crop_left:crop_right, :] 180 | 181 | else: 182 | roi_top = max(crop_top, 0) 183 | padding_top = roi_top - crop_top 184 | roi_bottom = min(crop_bottom, H) 185 | padding_bottom = crop_bottom - roi_bottom 186 | roi_left = max(crop_left, 0) 187 | padding_left = roi_left - crop_left 188 | roi_right = min(crop_right, W) 189 | padding_right = crop_right - roi_right 190 | 191 | roi_image = img[roi_top:roi_bottom, :, :][:, roi_left:roi_right, :] 192 | crop_image = cv2.copyMakeBorder(roi_image, padding_top, padding_bottom, padding_left, padding_right, 193 | borderType=cv2.BORDER_CONSTANT, value=0) 194 | 195 | random.choice([1]) 196 | target_image = cv2.resize(crop_image, (self.target_size, self.target_size), interpolation=cv2.INTER_LINEAR) 197 | target_image = PIL.Image.fromarray(target_image.astype('uint8')) 198 | return target_image 199 | 200 | 201 | class LighteningJitter(object): 202 | 203 | def __init__(self, eigen_vecs, eigen_values, max_eigen_jitter=0.1): 204 | self.eigen_vecs = np.array(eigen_vecs, dtype=np.float32) 205 | self.eigen_values = np.array(eigen_values, dtype=np.float32) 206 | self.max_eigen_jitter = max_eigen_jitter 207 | 208 | def __call__(self, img): 209 | assert isinstance(img, PIL.Image.Image) 210 | img = np.asarray(img, dtype=np.float32) 211 | img = np.ascontiguousarray(img / 255) 212 | 213 | cur_eigen_jitter = np.random.normal(scale=self.max_eigen_jitter, size=self.eigen_values.shape) 214 | color_purb = (self.eigen_vecs @ (self.eigen_values * cur_eigen_jitter)).reshape([1, 1, -1]) 215 | img += color_purb 216 | img = np.ascontiguousarray(img * 255) 217 | img.clip(0, 255, out=img) 218 | img = PIL.Image.fromarray(np.uint8(img)) 219 | return img 220 | 221 | def get_train_transform(): 222 | eigvec = np.array([ 223 | [-0.5836, -0.6948, 0.4203], 224 | [-0.5808, -0.0045, -0.8140], 225 | [-0.5675, 0.7192, 0.4009] 226 | ]) 227 | 228 | eigval = np.array([0.2175, 0.0188, 0.0045]) 229 | 230 | transform = transforms.Compose([ 231 | RandomResizedCrop(target_size=224, scale=(0.08, 1.0)), 232 | LighteningJitter(eigen_vecs=eigvec[::-1, :], eigen_values=eigval, 233 | max_eigen_jitter=0.1), 234 | transforms.RandomHorizontalFlip(0.5), 235 | transforms.ToTensor(), 236 | ]) 237 | return transform 238 | 239 | def get_eval_transform(): 240 | transform = transforms.Compose([ 241 | OpencvResize(256), 242 | transforms.CenterCrop(224), 243 | transforms.ToTensor(), 244 | ]) 245 | 246 | return transform -------------------------------------------------------------------------------- /retrain_architecture/visualize.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/ECP-CANDLE/Benchmarks/blob/master/common/darts/visualize.py 3 | ''' 4 | import sys 5 | import genotypes 6 | from graphviz import Digraph 7 | 8 | 9 | def plot(genotype, filename): 10 | g = Digraph( 11 | format='pdf', 12 | edge_attr=dict(fontsize='20', fontname="times"), 13 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), 14 | engine='dot') 15 | g.body.extend(['rankdir=LR']) 16 | 17 | g.node("c_{k-2}", fillcolor='darkseagreen2') 18 | g.node("c_{k-1}", fillcolor='darkseagreen2') 19 | assert len(genotype) % 2 == 0 20 | steps = len(genotype) // 2 21 | 22 | for i in range(steps): 23 
| g.node(str(i), fillcolor='lightblue') 24 | 25 | for i in range(steps): 26 | for k in [2*i, 2*i + 1]: 27 | op, j = genotype[k] 28 | if j == 0: 29 | u = "c_{k-2}" 30 | elif j == 1: 31 | u = "c_{k-1}" 32 | else: 33 | u = str(j-2) 34 | v = str(i) 35 | g.edge(u, v, label=op, fillcolor="gray") 36 | 37 | g.node("c_{k}", fillcolor='palegoldenrod') 38 | for i in range(steps): 39 | g.edge(str(i), "c_{k}", fillcolor="gray") 40 | 41 | g.render(filename, view=True) 42 | 43 | 44 | if __name__ == '__main__': 45 | if len(sys.argv) != 2: 46 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0])) 47 | sys.exit(1) 48 | 49 | genotype_name = sys.argv[1] 50 | try: 51 | genotype = eval('genotypes.{}'.format(genotype_name)) 52 | except AttributeError: 53 | print("{} is not specified in genotypes.py".format(genotype_name)) 54 | sys.exit(1) 55 | 56 | plot(genotype.normal, "normal") 57 | plot(genotype.reduce, "reduction") 58 | 59 | -------------------------------------------------------------------------------- /train_supernet/config.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/config.py 3 | ''' 4 | import os 5 | class config: 6 | # Basic configration 7 | layers = 8 8 | edges = 14 9 | model_input_size_imagenet = (1, 3, 224, 224) 10 | 11 | # Candidate operators 12 | blocks_keys = [ 13 | 'none', 14 | 'max_pool_3x3', 15 | 'avg_pool_3x3', 16 | 'skip_connect', 17 | 'sep_conv_3x3', 18 | 'sep_conv_5x5', 19 | 'dil_conv_3x3', 20 | 'dil_conv_5x5' 21 | ] 22 | op_num=len(blocks_keys) 23 | 24 | # Operators encoding 25 | NONE = 0 26 | MAX_POOLING_3x3 = 1 27 | AVG_POOL_3x3 = 2 28 | SKIP_CONNECT = 3 29 | SEP_CONV_3x3 = 4 30 | SEP_CONV_5x5 = 5 31 | DIL_CONV_3x3 = 6 32 | DIL_CONV_5x5 = 7 33 | 34 | 35 | # Shrinking configuration 36 | exp_name = './' 37 | net_cache = os.path.join(exp_name, 'weight.pt') 38 | base_net_cache = os.path.join(exp_name, 'base_weight.pt') 39 | modify_base_net_cache = os.path.join(exp_name, 'weight_0.pt') 40 | shrinking_finish_threshold = 1000000 41 | sample_num = 1000 42 | per_stage_drop_num = 14 43 | epsilon = 1e-12 44 | 45 | # Enumerate all paths of a single cell 46 | paths = [[0, 2, 3, 4, 5], [0, 2, 3, 5], [0, 2, 4, 5], [0, 2, 5], [0, 3, 4, 5], [0, 3, 5],[0, 4, 5],[0, 5], 47 | [1, 2, 3, 4, 5], [1, 2, 3, 5], [1, 2, 4, 5], [1, 2, 5], [1, 3, 4, 5], [1, 3, 5],[1, 4, 5],[1, 5]] 48 | 49 | 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /train_supernet/datasets/DownsampledImageNet.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import os, sys, hashlib, torch 5 | import numpy as np 6 | from PIL import Image 7 | import torch.utils.data as data 8 | if sys.version_info[0] == 2: 9 | import cPickle as pickle 10 | else: 11 | import pickle 12 | import pdb 13 | 14 | def calculate_md5(fpath, chunk_size=1024 * 1024): 15 | md5 = hashlib.md5() 16 | with open(fpath, 'rb') as f: 17 | for chunk in iter(lambda: f.read(chunk_size), b''): 18 | md5.update(chunk) 19 | return md5.hexdigest() 20 | 21 | 22 | def check_md5(fpath, md5, **kwargs): 23 | return md5 == calculate_md5(fpath, **kwargs) 24 | 25 | 26 | def check_integrity(fpath, md5=None): 27 | if not os.path.isfile(fpath): return False 28 | if md5 is None: return 
True 29 | else : return check_md5(fpath, md5) 30 | 31 | 32 | class ImageNet16(data.Dataset): 33 | # http://image-net.org/download-images 34 | # A Downsampled Variant of ImageNet as an Alternative to the CIFAR datasets 35 | # https://arxiv.org/pdf/1707.08819.pdf 36 | 37 | train_list = [ 38 | ['train_data_batch_1', '27846dcaa50de8e21a7d1a35f30f0e91'], 39 | ['train_data_batch_2', 'c7254a054e0e795c69120a5727050e3f'], 40 | ['train_data_batch_3', '4333d3df2e5ffb114b05d2ffc19b1e87'], 41 | ['train_data_batch_4', '1620cdf193304f4a92677b695d70d10f'], 42 | ['train_data_batch_5', '348b3c2fdbb3940c4e9e834affd3b18d'], 43 | ['train_data_batch_6', '6e765307c242a1b3d7d5ef9139b48945'], 44 | ['train_data_batch_7', '564926d8cbf8fc4818ba23d2faac7564'], 45 | ['train_data_batch_8', 'f4755871f718ccb653440b9dd0ebac66'], 46 | ['train_data_batch_9', 'bb6dd660c38c58552125b1a92f86b5d4'], 47 | ['train_data_batch_10','8f03f34ac4b42271a294f91bf480f29b'], 48 | ] 49 | valid_list = [ 50 | ['val_data', '3410e3017fdaefba8d5073aaa65e4bd6'], 51 | ] 52 | 53 | def __init__(self, root, train, transform, use_num_of_class_only=None): 54 | self.root = root 55 | self.transform = transform 56 | self.train = train # training set or valid set 57 | if not self._check_integrity(): raise RuntimeError('Dataset not found or corrupted.') 58 | 59 | if self.train: downloaded_list = self.train_list 60 | else : downloaded_list = self.valid_list 61 | self.data = [] 62 | self.targets = [] 63 | 64 | # now load the pickled numpy arrays 65 | for i, (file_name, checksum) in enumerate(downloaded_list): 66 | file_path = os.path.join(self.root, file_name) 67 | #print ('Load {:}/{:02d}-th : {:}'.format(i, len(downloaded_list), file_path)) 68 | with open(file_path, 'rb') as f: 69 | if sys.version_info[0] == 2: 70 | entry = pickle.load(f) 71 | else: 72 | entry = pickle.load(f, encoding='latin1') 73 | self.data.append(entry['data']) 74 | self.targets.extend(entry['labels']) 75 | self.data = np.vstack(self.data).reshape(-1, 3, 16, 16) 76 | self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC 77 | if use_num_of_class_only is not None: 78 | assert isinstance(use_num_of_class_only, int) and use_num_of_class_only > 0 and use_num_of_class_only < 1000, 'invalid use_num_of_class_only : {:}'.format(use_num_of_class_only) 79 | new_data, new_targets = [], [] 80 | for I, L in zip(self.data, self.targets): 81 | if 1 <= L <= use_num_of_class_only: 82 | new_data.append( I ) 83 | new_targets.append( L ) 84 | self.data = new_data 85 | self.targets = new_targets 86 | # self.mean.append(entry['mean']) 87 | #self.mean = np.vstack(self.mean).reshape(-1, 3, 16, 16) 88 | #self.mean = np.mean(np.mean(np.mean(self.mean, axis=0), axis=1), axis=1) 89 | #print ('Mean : {:}'.format(self.mean)) 90 | #temp = self.data - np.reshape(self.mean, (1, 1, 1, 3)) 91 | #std_data = np.std(temp, axis=0) 92 | #std_data = np.mean(np.mean(std_data, axis=0), axis=0) 93 | #print ('Std : {:}'.format(std_data)) 94 | 95 | def __getitem__(self, index): 96 | img, target = self.data[index], self.targets[index] - 1 97 | 98 | img = Image.fromarray(img) 99 | 100 | if self.transform is not None: 101 | img = self.transform(img) 102 | 103 | return img, target 104 | 105 | def __len__(self): 106 | return len(self.data) 107 | 108 | def _check_integrity(self): 109 | root = self.root 110 | for fentry in (self.train_list + self.valid_list): 111 | filename, md5 = fentry[0], fentry[1] 112 | fpath = os.path.join(root, filename) 113 | if not check_integrity(fpath, md5): 114 | return False 115 | return True 116 | 117 
# 118 | if __name__ == '__main__': 119 | train = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None) 120 | valid = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False, None) 121 | 122 | print ( len(train) ) 123 | print ( len(valid) ) 124 | image, label = train[111] 125 | trainX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', True , None, 200) 126 | validX = ImageNet16('/data02/dongxuanyi/.torch/cifar.python/ImageNet16', False , None, 200) 127 | print ( len(trainX) ) 128 | print ( len(validX) ) 129 | #import pdb; pdb.set_trace() 130 | -------------------------------------------------------------------------------- /train_supernet/datasets/SearchDatasetWrap.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import torch, copy, random 5 | import torch.utils.data as data 6 | 7 | 8 | class SearchDataset(data.Dataset): 9 | 10 | def __init__(self, name, data, train_split, valid_split, check=True): 11 | self.datasetname = name 12 | if isinstance(data, (list, tuple)): # new type of SearchDataset 13 | assert len(data) == 2, 'invalid length: {:}'.format( len(data) ) 14 | self.train_data = data[0] 15 | self.valid_data = data[1] 16 | self.train_split = train_split.copy() 17 | self.valid_split = valid_split.copy() 18 | self.mode_str = 'V2' # new mode 19 | else: 20 | self.mode_str = 'V1' # old mode 21 | self.data = data 22 | self.train_split = train_split.copy() 23 | self.valid_split = valid_split.copy() 24 | if check: 25 | intersection = set(train_split).intersection(set(valid_split)) 26 | assert len(intersection) == 0, 'the splitted train and validation sets should have no intersection' 27 | self.length = len(self.train_split) 28 | 29 | def __repr__(self): 30 | return ('{name}(name={datasetname}, train={tr_L}, valid={val_L}, version={ver})'.format(name=self.__class__.__name__, datasetname=self.datasetname, tr_L=len(self.train_split), val_L=len(self.valid_split), ver=self.mode_str)) 31 | 32 | def __len__(self): 33 | return self.length 34 | 35 | def __getitem__(self, index): 36 | assert index >= 0 and index < self.length, 'invalid index = {:}'.format(index) 37 | train_index = self.train_split[index] 38 | valid_index = random.choice( self.valid_split ) 39 | if self.mode_str == 'V1': 40 | train_image, train_label = self.data[train_index] 41 | valid_image, valid_label = self.data[valid_index] 42 | elif self.mode_str == 'V2': 43 | train_image, train_label = self.train_data[train_index] 44 | valid_image, valid_label = self.valid_data[valid_index] 45 | else: raise ValueError('invalid mode : {:}'.format(self.mode_str)) 46 | return train_image, train_label, valid_image, valid_label 47 | -------------------------------------------------------------------------------- /train_supernet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | from .get_dataset_with_transform import get_datasets, get_nas_search_loaders 5 | from .SearchDatasetWrap import SearchDataset 6 | -------------------------------------------------------------------------------- /train_supernet/datasets/config_utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | import sys, os 2 | import json 3 | from collections import namedtuple 4 | 5 | support_types = ('str', 'int', 'bool', 'float', 'none') 6 | 7 | def convert_param(original_lists): 8 | assert isinstance(original_lists, list), 'The type is not right : {:}'.format(original_lists) 9 | ctype, value = original_lists[0], original_lists[1] 10 | assert ctype in support_types, 'Ctype={:}, support={:}'.format(ctype, support_types) 11 | is_list = isinstance(value, list) 12 | if not is_list: value = [value] 13 | outs = [] 14 | for x in value: 15 | if ctype == 'int': 16 | x = int(x) 17 | elif ctype == 'str': 18 | x = str(x) 19 | elif ctype == 'bool': 20 | x = bool(int(x)) 21 | elif ctype == 'float': 22 | x = float(x) 23 | elif ctype == 'none': 24 | assert x == 'None', 'for none type, the value must be None instead of {:}'.format(x) 25 | x = None 26 | else: 27 | raise TypeError('Does not know this type : {:}'.format(ctype)) 28 | outs.append(x) 29 | if not is_list: outs = outs[0] 30 | return outs 31 | 32 | def load_config(path, extra, logger): 33 | path = str(path) 34 | if hasattr(logger, 'log'): logger.log(path) 35 | assert os.path.exists(path), 'Can not find {:}'.format(path) 36 | # Reading data back 37 | with open(path, 'r') as f: 38 | data = json.load(f) 39 | content = { k: convert_param(v) for k,v in data.items()} 40 | assert extra is None or isinstance(extra, dict), 'invalid type of extra : {:}'.format(extra) 41 | if isinstance(extra, dict): content = {**content, **extra} 42 | Arguments = namedtuple('Configure', ' '.join(content.keys())) 43 | content = Arguments(**content) 44 | if hasattr(logger, 'log'): logger.log('{:}'.format(content)) 45 | return content -------------------------------------------------------------------------------- /train_supernet/datasets/get_dataset_with_transform.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import os, sys, torch 5 | import os.path as osp 6 | import numpy as np 7 | import torchvision.datasets as dset 8 | import torchvision.transforms as transforms 9 | from copy import deepcopy 10 | from PIL import Image 11 | 12 | from .DownsampledImageNet import ImageNet16 13 | from .SearchDatasetWrap import SearchDataset 14 | from .config_utils import load_config as load_dataset_config 15 | from torchvision.transforms import transforms 16 | from PIL import ImageFilter, ImageOps 17 | import random 18 | import torchvision.datasets as datasets 19 | 20 | Dataset2Class = {'cifar10' : 10, 21 | 'cifar100': 100, 22 | 'imagenet-1k-s':1000, 23 | 'imagenet-1k' : 1000, 24 | 'ImageNet16' : 1000, 25 | 'ImageNet16-150': 150, 26 | 'ImageNet16-120': 120, 27 | 'ImageNet16-200': 200} 28 | 29 | class CUTOUT(object): 30 | 31 | def __init__(self, length): 32 | self.length = length 33 | 34 | def __repr__(self): 35 | return ('{name}(length={length})'.format(name=self.__class__.__name__, **self.__dict__)) 36 | 37 | def __call__(self, img): 38 | h, w = img.size(1), img.size(2) 39 | mask = np.ones((h, w), np.float32) 40 | y = np.random.randint(h) 41 | x = np.random.randint(w) 42 | 43 | y1 = np.clip(y - self.length // 2, 0, h) 44 | y2 = np.clip(y + self.length // 2, 0, h) 45 | x1 = np.clip(x - self.length // 2, 0, w) 46 | x2 = np.clip(x + self.length // 2, 0, w) 47 | 48 | mask[y1: y2, x1: x2] = 0. 
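# Note on ordering: in get_datasets below, CUTOUT is appended after ToTensor
# and Normalize (lists += [CUTOUT(cutout)]), so the patch is zeroed in
# normalized space; a zero there corresponds to the per-channel dataset mean
# in raw pixel space, not to black. An illustrative CIFAR-10 pipeline with
# cutout=16 would therefore be:
#
#     RandomHorizontalFlip -> RandomCrop(32, padding=4) -> ToTensor
#         -> Normalize(mean, std) -> CUTOUT(16)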
49 | mask = torch.from_numpy(mask) 50 | mask = mask.expand_as(img) 51 | img *= mask 52 | return img 53 | 54 | 55 | imagenet_pca = { 56 | 'eigval': np.asarray([0.2175, 0.0188, 0.0045]), 57 | 'eigvec': np.asarray([ 58 | [-0.5675, 0.7192, 0.4009], 59 | [-0.5808, -0.0045, -0.8140], 60 | [-0.5836, -0.6948, 0.4203], 61 | ]) 62 | } 63 | 64 | 65 | class Lighting(object): 66 | def __init__(self, alphastd, 67 | eigval=imagenet_pca['eigval'], 68 | eigvec=imagenet_pca['eigvec']): 69 | self.alphastd = alphastd 70 | assert eigval.shape == (3,) 71 | assert eigvec.shape == (3, 3) 72 | self.eigval = eigval 73 | self.eigvec = eigvec 74 | 75 | def __call__(self, img): 76 | if self.alphastd == 0.: 77 | return img 78 | rnd = np.random.randn(3) * self.alphastd 79 | rnd = rnd.astype('float32') 80 | v = rnd 81 | old_dtype = np.asarray(img).dtype 82 | v = v * self.eigval 83 | v = v.reshape((3, 1)) 84 | inc = np.dot(self.eigvec, v).reshape((3,)) 85 | img = np.add(img, inc) 86 | if old_dtype == np.uint8: 87 | img = np.clip(img, 0, 255) 88 | img = Image.fromarray(img.astype(old_dtype), 'RGB') 89 | return img 90 | 91 | def __repr__(self): 92 | return self.__class__.__name__ + '()' 93 | 94 | 95 | class Cifar10RandomLabels(datasets.CIFAR10): 96 | """CIFAR10 dataset, with support for randomly corrupt labels. 97 | Params 98 | ------ 99 | rand_seed: int 100 | Default 0. numpy random seed. 101 | num_classes: int 102 | Default 10. The number of classes in the dataset. 103 | """ 104 | def __init__(self, rand_seed=0, num_classes=10, **kwargs): 105 | super(Cifar10RandomLabels, self).__init__(**kwargs) 106 | self.n_classes = num_classes 107 | self.rand_seed = rand_seed 108 | self.random_labels() 109 | 110 | def random_labels(self): 111 | labels = np.array(self.targets) 112 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 113 | np.random.seed(self.rand_seed) 114 | rnd_labels = np.random.randint(0, self.n_classes, len(labels)) 115 | # we need to explicitly cast the labels from npy.int64 to 116 | # builtin int type, otherwise pytorch will fail... 117 | labels = [int(x) for x in rnd_labels] 118 | 119 | self.targets = labels 120 | 121 | class Cifar100RandomLabels(datasets.CIFAR100): 122 | """CIFAR100 dataset, with support for randomly corrupt labels. 123 | Params 124 | ------ 125 | rand_seed: int 126 | Default 0. numpy random seed. 127 | num_classes: int 128 | Default 100. The number of classes in the dataset. 129 | """ 130 | def __init__(self, rand_seed=0, num_classes=100, **kwargs): 131 | super(Cifar100RandomLabels, self).__init__(**kwargs) 132 | self.n_classes = num_classes 133 | self.rand_seed = rand_seed 134 | self.random_labels() 135 | 136 | def random_labels(self): 137 | labels = np.array(self.targets) 138 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 139 | np.random.seed(self.rand_seed) 140 | rnd_labels = np.random.randint(0, self.n_classes, len(labels)) 141 | # we need to explicitly cast the labels from npy.int64 to 142 | # builtin int type, otherwise pytorch will fail... 143 | labels = [int(x) for x in rnd_labels] 144 | 145 | self.targets = labels 146 | 147 | class ImageNet16RandomLabels(ImageNet16): 148 | """ImageNet16 dataset, with support for randomly corrupt labels. 149 | Params 150 | ------ 151 | rand_seed: int 152 | Default 0. numpy random seed. 153 | num_classes: int 154 | Default 120. The number of classes in the dataset. 
155 | """ 156 | def __init__(self, rand_seed=0, num_classes=120, **kwargs): 157 | super(ImageNet16RandomLabels, self).__init__(**kwargs) 158 | self.n_classes = num_classes 159 | self.rand_seed = rand_seed 160 | self.random_labels() 161 | # print('min_label:{}, max_label:{}'.format(min(self.targets), max(self.targets))) 162 | 163 | def random_labels(self): 164 | labels = np.array(self.targets) 165 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 166 | np.random.seed(self.rand_seed) 167 | rnd_labels = np.random.randint(1, self.n_classes+1, len(labels)) 168 | # we need to explicitly cast the labels from npy.int64 to 169 | # builtin int type, otherwise pytorch will fail... 170 | labels = [int(x) for x in rnd_labels] 171 | 172 | self.targets = labels 173 | 174 | def get_datasets(name, root, cutout, rand_seed, byol_aug_type=None, random_label=True): 175 | 176 | if name == 'cifar10': 177 | mean = [x / 255 for x in [125.3, 123.0, 113.9]] 178 | std = [x / 255 for x in [63.0, 62.1, 66.7]] 179 | elif name == 'cifar100': 180 | mean = [x / 255 for x in [129.3, 124.1, 112.4]] 181 | std = [x / 255 for x in [68.2, 65.4, 70.4]] 182 | elif name.startswith('imagenet-1k'): 183 | mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] 184 | elif name.startswith('ImageNet16'): 185 | mean = [x / 255 for x in [122.68, 116.66, 104.01]] 186 | std = [x / 255 for x in [63.22, 61.26 , 65.09]] 187 | else: 188 | raise TypeError("Unknow dataset : {:}".format(name)) 189 | 190 | # Data Argumentation 191 | if name == 'cifar10' or name == 'cifar100': 192 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] 193 | if cutout > 0 : lists += [CUTOUT(cutout)] 194 | if byol_aug_type is None: 195 | train_transform = transforms.Compose(lists) 196 | elif byol_aug_type=='byol': 197 | online_aug = get_train_transform('BYOL_Tau', 32, mean, std) 198 | target_aug = get_train_transform('BYOL_Tau_Hat', 32, mean, std) 199 | train_transform = TwoImageAugmentations(online_aug, target_aug) 200 | else: 201 | raise NotImplementedError 202 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 203 | xshape = (1, 3, 32, 32) 204 | elif name.startswith('ImageNet16'): 205 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(16, padding=2), transforms.ToTensor(), transforms.Normalize(mean, std)] 206 | if cutout > 0 : lists += [CUTOUT(cutout)] 207 | if byol_aug_type is None: 208 | train_transform = transforms.Compose(lists) 209 | elif byol_aug_type=='byol': 210 | online_aug = get_train_transform('BYOL_Tau', 16, mean, std) 211 | target_aug = get_train_transform('BYOL_Tau_Hat', 16, mean, std) 212 | train_transform = TwoImageAugmentations(online_aug, target_aug) 213 | else: 214 | raise NotImplementedError 215 | test_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean, std)]) 216 | xshape = (1, 3, 16, 16) 217 | elif name == 'tiered': 218 | lists = [transforms.RandomHorizontalFlip(), transforms.RandomCrop(80, padding=4), transforms.ToTensor(), transforms.Normalize(mean, std)] 219 | if cutout > 0 : lists += [CUTOUT(cutout)] 220 | train_transform = transforms.Compose(lists) 221 | test_transform = transforms.Compose([transforms.CenterCrop(80), transforms.ToTensor(), transforms.Normalize(mean, std)]) 222 | xshape = (1, 3, 32, 32) 223 | elif name.startswith('imagenet-1k'): 224 | normalize = transforms.Normalize(mean=[0.485, 
0.456, 0.406], std=[0.229, 0.224, 0.225]) 225 | if name == 'imagenet-1k': 226 | xlists = [transforms.RandomResizedCrop(224)] 227 | xlists.append( 228 | transforms.ColorJitter( 229 | brightness=0.4, 230 | contrast=0.4, 231 | saturation=0.4, 232 | hue=0.2)) 233 | xlists.append( Lighting(0.1)) 234 | elif name == 'imagenet-1k-s': 235 | xlists = [transforms.RandomResizedCrop(224, scale=(0.2, 1.0))] 236 | else: raise ValueError('invalid name : {:}'.format(name)) 237 | xlists.append( transforms.RandomHorizontalFlip(p=0.5) ) 238 | xlists.append( transforms.ToTensor() ) 239 | xlists.append( normalize ) 240 | train_transform = transforms.Compose(xlists) 241 | test_transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), normalize]) 242 | xshape = (1, 3, 224, 224) 243 | else: 244 | raise TypeError("Unknown dataset : {:}".format(name)) 245 | 246 | if name == 'cifar10': 247 | if random_label: 248 | train_data = Cifar10RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed) 249 | else: 250 | train_data = datasets.CIFAR10(root=root, train=True, transform=train_transform, download=True) 251 | test_data = datasets.CIFAR10(root=root, train=True , transform=test_transform, download=True) # NOTE: evaluation data intentionally re-uses the CIFAR training set; the validation split indices in get_nas_search_loaders select a held-out half of it 252 | # test_data = datasets.CIFAR10(root=root, train=False, transform=test_transform , download=True) 253 | assert len(train_data) == 50000 and len(test_data) == 50000 254 | elif name == 'cifar100': 255 | if random_label: 256 | train_data = Cifar100RandomLabels(root=root, train=True , transform=train_transform, download=True, rand_seed=rand_seed) 257 | else: 258 | train_data = datasets.CIFAR100(root=root, train=True , transform=train_transform, download=True) 259 | test_data = datasets.CIFAR100(root=root, train=True, transform=test_transform , download=True) 260 | assert len(train_data) == 50000 and len(test_data) == 50000 261 | elif name.startswith('imagenet-1k'): 262 | train_data = dset.ImageFolder(osp.join(root, 'train'), train_transform) 263 | test_data = dset.ImageFolder(osp.join(root, 'val'), test_transform) 264 | assert len(train_data) == 1281167 and len(test_data) == 50000, 'invalid number of images : {:} & {:} vs {:} & {:}'.format(len(train_data), len(test_data), 1281167, 50000) 265 | elif name == 'ImageNet16': 266 | if random_label: 267 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, rand_seed=rand_seed) 268 | else: 269 | train_data = ImageNet16(root=root, train=True, transform=train_transform) 270 | test_data = ImageNet16(root=root, train=False, transform=test_transform) 271 | assert len(train_data) == 1281167 and len(test_data) == 50000 272 | elif name == 'ImageNet16-120': 273 | if random_label: 274 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=120, use_num_of_class_only=120, rand_seed=rand_seed) 275 | else: 276 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=120) 277 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=120) 278 | assert len(train_data) == 151700 and len(test_data) == 6000 279 | elif name == 'ImageNet16-150': 280 | if random_label: 281 | train_data = ImageNet16RandomLabels(root=root, train=True , transform=train_transform, num_classes=150, use_num_of_class_only=150, rand_seed=rand_seed) 282 | else: 283 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=150) 284 | 
test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=150) 285 | assert len(train_data) == 190272 and len(test_data) == 7500 286 | elif name == 'ImageNet16-200': 287 | if random_label: 288 | train_data = ImageNet16RandomLabels(root=root, train=True ,transform=train_transform, num_classes=200, use_num_of_class_only=200, rand_seed=rand_seed) 289 | else: 290 | train_data = ImageNet16(root=root, train=True , transform=train_transform, use_num_of_class_only=200) 291 | test_data = ImageNet16(root=root, train=False, transform=test_transform , use_num_of_class_only=200) 292 | assert len(train_data) == 254775 and len(test_data) == 10000 293 | else: raise TypeError("Unknown dataset : {:}".format(name)) 294 | 295 | class_num = Dataset2Class[name] 296 | return train_data, test_data, xshape, class_num 297 | 298 | 299 | def get_nas_search_loaders(train_data, valid_data, dataset, config_root, batch_size, workers, use_valid_no_shuffle=False): 300 | # NOTE: detailed dataset configuration is given in the NAS-Bench-201 paper, https://arxiv.org/pdf/2001.00326.pdf. 301 | if isinstance(batch_size, (list,tuple)): 302 | batch, test_batch = batch_size 303 | else: 304 | batch, test_batch = batch_size, batch_size 305 | if dataset == 'cifar10' or dataset == 'cifar100': 306 | #split_Fpath = 'configs/nas-benchmark/cifar-split.txt' 307 | if dataset == 'cifar10': 308 | cifar_split = load_dataset_config('{:}/cifar-split.txt'.format(config_root), None, None) 309 | elif dataset == 'cifar100': 310 | cifar_split = load_dataset_config('{:}/cifar100-split.txt'.format(config_root), None, None) 311 | train_split, valid_split = cifar_split.train, cifar_split.valid # search over the proposed training and validation set 312 | #logger.log('Load split file from {:}'.format(split_Fpath)) # they are two disjoint groups in the original CIFAR-10 training set 313 | # To split data 314 | xvalid_data = valid_data 315 | search_data = SearchDataset(dataset, train_data, train_split, valid_split) 316 | # data loader 317 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) 318 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(train_split), num_workers=workers, pin_memory=True) 319 | valid_loader = torch.utils.data.DataLoader(xvalid_data, batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(valid_split), num_workers=workers, pin_memory=True) 320 | if use_valid_no_shuffle: 321 | # NOTE: using validation dataset 322 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(xvalid_data, valid_split), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True) 323 | # NOTE: using search training dataset 324 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True) 325 | elif dataset == 'ImageNet16-120': 326 | imagenet_test_split = load_dataset_config('{:}/imagenet-16-120-test-split.txt'.format(config_root), None, None) 327 | search_train_data = train_data 328 | search_valid_data = deepcopy(valid_data) ; search_valid_data.transform = train_data.transform 329 | search_data = SearchDataset(dataset, [search_train_data,search_valid_data], list(range(len(search_train_data))), imagenet_test_split.xvalid) 330 | search_loader = torch.utils.data.DataLoader(search_data, batch_size=batch, shuffle=True , num_workers=workers,
pin_memory=True) 331 | train_loader = torch.utils.data.DataLoader(train_data , batch_size=batch, shuffle=True , num_workers=workers, pin_memory=True) 332 | valid_loader = torch.utils.data.DataLoader(valid_data , batch_size=test_batch, sampler=torch.utils.data.sampler.SubsetRandomSampler(imagenet_test_split.xvalid), num_workers=workers, pin_memory=True) 333 | if use_valid_no_shuffle: 334 | # NOTE: using validation dataset 335 | valid_loader_no_shuffle = torch.utils.data.DataLoader(torch.utils.data.Subset(valid_data, imagenet_test_split.xvalid), batch_size=test_batch, shuffle=False, num_workers=workers, pin_memory=True) 336 | # NOTE: using search training dataset 337 | # valid_loader_no_shuffle = torch.utils.data.DataLoader(search_data, batch_size=test_batch, shuffle=False , num_workers=workers, pin_memory=True) 338 | else: 339 | raise ValueError('invalid dataset : {:}'.format(dataset)) 340 | 341 | if use_valid_no_shuffle: 342 | return search_loader, train_loader, valid_loader, valid_loader_no_shuffle 343 | else: 344 | return search_loader, train_loader, valid_loader 345 | 346 | def get_train_transform(aug, image_size, mean, std): 347 | 348 | if aug == 'BYOL_Tau': 349 | transform = transforms.Compose([ 350 | transforms.RandomResizedCrop(image_size), 351 | transforms.RandomHorizontalFlip(), 352 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), 353 | transforms.RandomGrayscale(p=0.2), 354 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=1.0), 355 | transforms.RandomApply([Solarization(128)], p=0.0), 356 | transforms.ToTensor(), 357 | transforms.Normalize(mean, std), 358 | 359 | ]) 360 | elif aug == 'BYOL_Tau_Hat': 361 | transform = transforms.Compose([ 362 | transforms.RandomResizedCrop(image_size), 363 | transforms.RandomHorizontalFlip(), 364 | transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), 365 | transforms.RandomGrayscale(p=0.2), 366 | transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.1), 367 | transforms.RandomApply([Solarization(128)], p=0.2), 368 | transforms.ToTensor(), 369 | transforms.Normalize(mean, std), 370 | ]) 371 | else: 372 | raise NotImplementedError 373 | 374 | return transform 375 | 376 | 377 | class TwoImageAugmentations: 378 | def __init__(self, online_aug, target_aug): 379 | self.online_aug = online_aug 380 | self.target_aug = target_aug 381 | 382 | def __call__(self, x): 383 | online_image = self.online_aug(x) 384 | target_image = self.target_aug(x) 385 | return [online_image, target_image] 386 | 387 | class GaussianBlur(object): 388 | """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" 389 | 390 | def __init__(self, sigma=[.1, 2.]): 391 | self.sigma = sigma 392 | 393 | def __call__(self, x): 394 | sigma = random.uniform(self.sigma[0], self.sigma[1]) 395 | x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) 396 | return x 397 | 398 | 399 | class Solarization(object): 400 | def __init__(self, magnitude=128): 401 | self.magnitude = magnitude 402 | 403 | def __call__(self, x): 404 | x = ImageOps.solarize(x, self.magnitude) 405 | return x 406 | 407 | if __name__ == '__main__': 408 | byol = True 409 | train_data, test_data, xshape, class_num = get_datasets('cifar10', '/home/zhangxuanyang/dataset/cifar.python/', -1, 0, byol_aug_type='byol' if byol else None) 410 | search_loader, _, valid_loader = get_nas_search_loaders(train_data, test_data, 'cifar10', 'configs/nas-benchmark/', \ 411 | (3, 3), 4) 412 | for step, (base_inputs, base_targets, arch_inputs, arch_targets) in enumerate(search_loader): 413 | print(base_inputs)
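# search_loader yields 4-tuples: (base_inputs, base_targets) are drawn from the train split and
# (arch_inputs, arch_targets) from the disjoint validation split of the same CIFAR-10 training set;
# with the (3, 3) batch sizes above, base_inputs is a [3, 3, 32, 32] tensor.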
414 | break 415 | 416 | # import pdb; pdb.set_trace() 417 | -------------------------------------------------------------------------------- /train_supernet/datasets/test_utils.py: -------------------------------------------------------------------------------- 1 | ################################################## 2 | # Copyright (c) Xuanyi Dong [GitHub D-X-Y], 2019 # 3 | ################################################## 4 | import os 5 | 6 | def test_imagenet_data(imagenet): 7 | total_length = len(imagenet) 8 | assert total_length == 1281166 or total_length == 50000, 'The length of ImageNet is wrong : {}'.format(total_length) 9 | map_id = {} 10 | for index in range(total_length): 11 | path, target = imagenet.imgs[index] 12 | folder, image_name = os.path.split(path) 13 | _, folder = os.path.split(folder) 14 | if folder not in map_id: 15 | map_id[folder] = target 16 | else: 17 | assert map_id[folder] == target, 'Class : {} is not {}'.format(folder, target) 18 | assert image_name.find(folder) == 0, '{} is wrong.'.format(path) 19 | print('Check ImageNet Dataset OK') 20 | -------------------------------------------------------------------------------- /train_supernet/genotypes.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/genotypes.py 3 | ''' 4 | from collections import namedtuple 5 | 6 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 7 | 8 | PRIMITIVES = [ 9 | 'none', 10 | 'max_pool_3x3', 11 | 'avg_pool_3x3', 12 | 'skip_connect', 13 | 'sep_conv_3x3', 14 | 'sep_conv_5x5', 15 | 'dil_conv_3x3', 16 | 'dil_conv_5x5' 17 | ] 18 | 19 | NASNet = Genotype( 20 | normal = [ 21 | ('sep_conv_5x5', 1), 22 | ('sep_conv_3x3', 0), 23 | ('sep_conv_5x5', 0), 24 | ('sep_conv_3x3', 0), 25 | ('avg_pool_3x3', 1), 26 | ('skip_connect', 0), 27 | ('avg_pool_3x3', 0), 28 | ('avg_pool_3x3', 0), 29 | ('sep_conv_3x3', 1), 30 | ('skip_connect', 1), 31 | ], 32 | normal_concat = [2, 3, 4, 5, 6], 33 | reduce = [ 34 | ('sep_conv_5x5', 1), 35 | ('sep_conv_7x7', 0), 36 | ('max_pool_3x3', 1), 37 | ('sep_conv_7x7', 0), 38 | ('avg_pool_3x3', 1), 39 | ('sep_conv_5x5', 0), 40 | ('skip_connect', 3), 41 | ('avg_pool_3x3', 2), 42 | ('sep_conv_3x3', 2), 43 | ('max_pool_3x3', 1), 44 | ], 45 | reduce_concat = [4, 5, 6], 46 | ) 47 | 48 | AmoebaNet = Genotype( 49 | normal = [ 50 | ('avg_pool_3x3', 0), 51 | ('max_pool_3x3', 1), 52 | ('sep_conv_3x3', 0), 53 | ('sep_conv_5x5', 2), 54 | ('sep_conv_3x3', 0), 55 | ('avg_pool_3x3', 3), 56 | ('sep_conv_3x3', 1), 57 | ('skip_connect', 1), 58 | ('skip_connect', 0), 59 | ('avg_pool_3x3', 1), 60 | ], 61 | normal_concat = [4, 5, 6], 62 | reduce = [ 63 | ('avg_pool_3x3', 0), 64 | ('sep_conv_3x3', 1), 65 | ('max_pool_3x3', 0), 66 | ('sep_conv_7x7', 2), 67 | ('sep_conv_7x7', 0), 68 | ('avg_pool_3x3', 1), 69 | ('max_pool_3x3', 0), 70 | ('max_pool_3x3', 1), 71 | ('conv_7x1_1x7', 0), 72 | ('sep_conv_3x3', 5), 73 | ], 74 | reduce_concat = [3, 4, 6] 75 | ) 76 | 77 | DARTS_V1 = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 0), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('avg_pool_3x3', 0)], reduce_concat=[2, 3, 4, 5]) 78 | DARTS_V2 = Genotype(normal=[('sep_conv_3x3', 0),
('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 79 | 80 | DARTS = DARTS_V2 81 | 82 | -------------------------------------------------------------------------------- /train_supernet/operations.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/operations.py 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | 7 | OPS = { 8 | 'none' : lambda C, stride, affine: Zero(stride), 9 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 10 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 11 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 12 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 13 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 14 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 15 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 16 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 17 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 18 | nn.ReLU(inplace=False), 19 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 20 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 21 | nn.BatchNorm2d(C, affine=affine) 22 | ), 23 | } 24 | 25 | class ReLUConvBN(nn.Module): 26 | 27 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 28 | super(ReLUConvBN, self).__init__() 29 | self.op = nn.Sequential( 30 | nn.ReLU(inplace=False), 31 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 32 | nn.BatchNorm2d(C_out, affine=affine) 33 | ) 34 | 35 | def forward(self, x, rngs=None): 36 | return self.op(x) 37 | 38 | class DilConv(nn.Module): 39 | 40 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 41 | super(DilConv, self).__init__() 42 | self.op = nn.Sequential( 43 | nn.ReLU(inplace=False), 44 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 45 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 46 | nn.BatchNorm2d(C_out, affine=affine), 47 | ) 48 | 49 | def forward(self, x, rngs=None): 50 | return self.op(x) 51 | 52 | class SepConv(nn.Module): 53 | 54 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 55 | super(SepConv, self).__init__() 56 | self.op = nn.Sequential( 57 | nn.ReLU(inplace=False), 58 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 59 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 60 | nn.BatchNorm2d(C_in, affine=affine), 61 | nn.ReLU(inplace=False), 62 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 63 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, 
bias=False), 64 | nn.BatchNorm2d(C_out, affine=affine), 65 | ) 66 | 67 | def forward(self, x, rngs=None): 68 | return self.op(x) 69 | 70 | 71 | class Identity(nn.Module): 72 | 73 | def __init__(self): 74 | super(Identity, self).__init__() 75 | 76 | def forward(self, x, rngs=None): 77 | return x 78 | 79 | class Zero(nn.Module): 80 | 81 | def __init__(self, stride): 82 | super(Zero, self).__init__() 83 | self.stride = stride 84 | def forward(self, x, rngs=None): 85 | n, c, h, w = x.size() 86 | h //= self.stride 87 | w //= self.stride 88 | if x.is_cuda: 89 | with torch.cuda.device(x.get_device()): 90 | padding = torch.cuda.FloatTensor(n, c, h, w).fill_(0) 91 | else: 92 | padding = torch.FloatTensor(n, c, h, w).fill_(0) 93 | return padding 94 | 95 | class FactorizedReduce(nn.Module): 96 | 97 | def __init__(self, C_in, C_out, affine=True): 98 | super(FactorizedReduce, self).__init__() 99 | assert C_out % 2 == 0 100 | self.relu = nn.ReLU(inplace=False) 101 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 102 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 103 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 104 | 105 | def forward(self, x, rngs=None): 106 | x = self.relu(x) 107 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 108 | out = self.bn(out) 109 | return out 110 | 111 | 112 | -------------------------------------------------------------------------------- /train_supernet/super_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/super_model.py 3 | ''' 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from operations import * 8 | from torch.autograd import Variable 9 | from genotypes import PRIMITIVES 10 | from genotypes import Genotype 11 | import math 12 | import numpy as np 13 | from config import config 14 | import copy 15 | from utils import check_cand 16 | 17 | class MixedOp(nn.Module): 18 | 19 | def __init__(self, C, stride): 20 | super(MixedOp, self).__init__() 21 | self._ops = nn.ModuleList() 22 | for idx, primitive in enumerate(PRIMITIVES): 23 | op = OPS[primitive](C, stride, True) 24 | op.idx = idx 25 | if 'pool' in primitive: 26 | op = nn.Sequential(op, nn.BatchNorm2d(C, affine=True)) 27 | self._ops.append(op) 28 | 29 | def forward(self, x, rng): 30 | return self._ops[rng](x) 31 | 32 | 33 | class Cell(nn.Module): 34 | 35 | def __init__(self, steps, multiplier, C_prev_prev, C_prev, C, reduction, reduction_prev): 36 | super(Cell, self).__init__() 37 | if reduction_prev: 38 | # NOTE: if K-1 cell output was from stride-2 op, K-2 cell output should shrink its spatial size by stride-2. 
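# (Concretely: when cell K-1 was a reduction cell, s0 comes from cell K-2 at twice the spatial
# resolution of s1, so FactorizedReduce halves it with two parallel stride-2 1x1 convolutions whose
# outputs are concatenated; otherwise a 1x1 ReLUConvBN only aligns the channel count.)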
39 | self.preprocess0 = FactorizedReduce(C_prev_prev, C, affine=True) 40 | else: 41 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0, affine=True) 42 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0, affine=True) 43 | self._steps = steps 44 | self._multiplier = multiplier 45 | self._C = C 46 | self.out_C = self._multiplier * C 47 | self.reduction = reduction 48 | 49 | self._ops = nn.ModuleList() 50 | self._bns = nn.ModuleList() 51 | self.time_stamp = 1 52 | 53 | for i in range(self._steps): 54 | for j in range(2+i): 55 | stride = 2 if reduction and j < 2 else 1 56 | op = MixedOp(C, stride) 57 | self._ops.append(op) 58 | 59 | def forward(self, s0, s1, rngs): 60 | s0 = self.preprocess0(s0) 61 | s1 = self.preprocess1(s1) 62 | states = [s0, s1] 63 | offset = 0 64 | for i in range(self._steps): 65 | s = sum(self._ops[offset+j](h, rngs[offset+j]) for j, h in enumerate(states)) 66 | offset += len(states) 67 | states.append(s) 68 | return torch.cat(states[-self._multiplier:], dim=1) 69 | 70 | class Network(nn.Module): 71 | def __init__(self, C=16, num_classes=10, layers=8, steps=4, multiplier=4, stem_multiplier=3): 72 | super(Network, self).__init__() 73 | self._C = C 74 | self._num_classes = num_classes 75 | self._layers = layers 76 | self._steps = steps 77 | self._multiplier = multiplier 78 | 79 | C_curr = stem_multiplier * C 80 | 81 | self.stem = nn.Sequential( 82 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False), 83 | nn.BatchNorm2d(C_curr) 84 | ) 85 | 86 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 87 | 88 | self.cells = nn.ModuleList() 89 | reduction_prev = False 90 | 91 | for i in range(layers): 92 | if i in [layers // 3, 2 * layers // 3]: 93 | C_curr *= 2 94 | reduction = True 95 | else: 96 | reduction = False 97 | cell = Cell(steps, multiplier, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 98 | reduction_prev = reduction 99 | self.cells += [cell] 100 | C_prev_prev, C_prev = C_prev, multiplier * C_curr 101 | 102 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 103 | self.classifier = nn.Linear(C_prev, num_classes) 104 | 105 | def forward(self, input, normal_rng, reduction_rng): 106 | s0 = s1 = self.stem(input) 107 | for i, cell in enumerate(self.cells): 108 | if i in [self._layers // 3, 2 * self._layers // 3]: 109 | s0, s1 = s1, cell(s0, s1, reduction_rng) 110 | else: 111 | s0, s1 = s1, cell(s0, s1, normal_rng) 112 | out = self.global_pooling(s1) 113 | logits = self.classifier(out.view(out.size(0),-1)) 114 | return logits 115 | 116 | if __name__ == '__main__': 117 | from copy import deepcopy 118 | np.random.seed(0) 119 | model = Network() 120 | print(model) 121 | # exit(0)  # left disabled so the random-candidate smoke test below actually runs 122 | operations = [] 123 | for _ in range(config.edges): 124 | operations.append(list(range(config.op_num))) 125 | norm_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)] 126 | reduction_rng = [np.random.randint(len(config.blocks_keys)) for i in range(config.edges)] 127 | print("operations: ", operations) 128 | print("norm rng: ", norm_rng) 129 | norm_rng = check_cand(norm_rng, operations) 130 | print("after check_cand norm_rng: ", norm_rng) 131 | reduction_rng = check_cand(reduction_rng, operations) 132 | x = torch.rand(4,3,32,32) 133 | logit = model(x, norm_rng, reduction_rng) 134 | print('logit:{0}'.format(logit)) -------------------------------------------------------------------------------- /train_supernet/train.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Code adapted from
https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/train.py 3 | ''' 4 | import os 5 | import sys 6 | import time 7 | import glob 8 | import numpy as np 9 | import torch 10 | import utils 11 | import logging 12 | import argparse 13 | import torch.nn as nn 14 | import torch.utils 15 | import torch.nn.functional as F 16 | import torchvision.datasets as dset 17 | import torch.backends.cudnn as cudnn 18 | 19 | from torch.autograd import Variable 20 | from super_model import Network 21 | from copy import deepcopy 22 | from config import config 23 | from datasets import get_datasets, get_nas_search_loaders 24 | 25 | parser = argparse.ArgumentParser("cifar") 26 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') 27 | parser.add_argument('--batch_size', type=int, default=64, help='batch size') 28 | parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate') 29 | parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate') 30 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 31 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay') 32 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency') 33 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id') 34 | parser.add_argument('--epochs', type=int, default=50, help='num of training epochs') 35 | parser.add_argument('--init_channels', type=int, default=16, help='num of init channels') 36 | parser.add_argument('--layers', type=int, default=8, help='total number of layers') 37 | parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model') 38 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout') 39 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 40 | parser.add_argument('--drop_path_prob', type=float, default=0.3, help='drop path probability') 41 | parser.add_argument('--save', type=str, default='models', help='experiment name') 42 | parser.add_argument('--seed', type=int, default=1, help='random seed') 43 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') 44 | parser.add_argument('--train_portion', type=float, default=0.5, help='portion of training data') 45 | parser.add_argument('--unrolled', action='store_true', default=False, help='use one-step unrolled validation loss') 46 | parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding') 47 | parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding') 48 | parser.add_argument('--random_label', type=int, choices=[0, 1], default=1, help='Whether to use random labels for the dataset. (default: True)') 49 | parser.add_argument('--split_data', type=int, choices=[0, 1], default=1, help='Whether to split the data into training & validation sets.
(default: True)') 50 | args = parser.parse_args() 51 | 52 | args.split_data = bool(args.split_data) 53 | 54 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py')) 55 | 56 | log_format = '%(asctime)s %(message)s' 57 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 58 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 59 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 60 | fh.setFormatter(logging.Formatter(log_format)) 61 | logging.getLogger().addHandler(fh) 62 | 63 | CIFAR_CLASSES = 100 64 | 65 | def main(): 66 | if not torch.cuda.is_available(): 67 | logging.info('no gpu device available') 68 | sys.exit(1) 69 | 70 | np.random.seed(args.seed) 71 | torch.cuda.set_device(args.gpu) 72 | cudnn.benchmark = True 73 | torch.manual_seed(args.seed) 74 | cudnn.enabled=True 75 | torch.cuda.manual_seed(args.seed) 76 | seed = args.seed 77 | logging.info('gpu device = %d' % args.gpu) 78 | logging.info("args = %s", args) 79 | 80 | criterion = nn.CrossEntropyLoss() 81 | criterion = criterion.cuda() 82 | # NOTE: layers: number of cells in the network 83 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers) 84 | model = model.cuda() 85 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 86 | 87 | optimizer = torch.optim.SGD( 88 | model.parameters(), 89 | args.learning_rate, 90 | momentum=args.momentum, 91 | weight_decay=args.weight_decay) 92 | 93 | if args.split_data: # NOTE: split the training data in half into new train/val sets; the new train half is used for supernet training, the new val half for evaluation 94 | train_data, valid_data, xshape, class_num = get_datasets('cifar100', args.data, -1, args.seed, random_label=bool(args.random_label)) 95 | train_queue, _, _, valid_queue = get_nas_search_loaders(train_data, valid_data, 'cifar100', 96 | 'datasets/configs/', \ 97 | (args.batch_size, args.batch_size), 4, use_valid_no_shuffle=True) 98 | else: 99 | raise ValueError("only --split_data 1 is supported") 100 | 101 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 102 | optimizer, float(args.epochs), eta_min=args.learning_rate_min) 103 | 104 | operations = [] 105 | for _ in range(config.edges): 106 | operations.append(list(range(config.op_num))) 107 | print('operations={}'.format(operations)) 108 | 109 | utils.save_checkpoint({'epoch': -1, 110 | 'state_dict': model.state_dict(), 111 | 'optimizer': optimizer.state_dict()}, args.save) 112 | 113 | for epoch in range(args.epochs): 114 | scheduler.step() 115 | lr = scheduler.get_lr()[0] 116 | logging.info('epoch %d lr %e', epoch, lr) 117 | 118 | # training 119 | seed, train_acc, train_obj = train(train_queue, model, criterion, optimizer, operations, seed, epoch) 120 | logging.info('train_acc %f', train_acc) 121 | 122 | # validation 123 | valid_acc, valid_obj = infer(valid_queue, model, criterion, seed, operations) 124 | logging.info('valid_acc %f', valid_acc) 125 | 126 | if (epoch+1)%5 == 0: 127 | utils.save_checkpoint({'epoch':epoch, 128 | 'state_dict':model.state_dict(), 129 | 'optimizer':optimizer.state_dict()}, args.save) 130 | 131 | def get_random_cand(seed, operations): 132 | # Uniform Sampling 133 | rng = [] 134 | for op in operations: 135 | np.random.seed(seed) 136 | k = np.random.randint(len(op)) 137 | select_op = op[k] 138 | rng.append(select_op) 139 | seed += 1 140 | 141 | return rng, seed 142 | 143 | def train(train_queue, model, criterion, optimizer, operations, seed, epoch): 144 | objs = utils.AvgrageMeter() 145 | top1 = utils.AvgrageMeter() 146 | top5 = utils.AvgrageMeter() 147 | 
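# The running `seed` below is threaded through get_random_cand: each sampled edge reseeds numpy and
# increments the seed by one, making the sampled architecture sequence reproducible; the updated
# seed is returned to main() and carried into the next epoch.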
148 | model.train() 149 | 150 | for step, batch in enumerate(train_queue): 151 | if len(batch) == 4: 152 | input, target, _, _ = batch 153 | elif len(batch) == 2: 154 | input, target = batch 155 | n = input.size(0) 156 | 157 | input = input.cuda(non_blocking=True) 158 | target = target.cuda(non_blocking=True) 159 | 160 | optimizer.zero_grad() 161 | 162 | # NOTE: at every training iteration, the operation for each edge is randomly sampled (as in SPOS; there are no architecture parameters). 163 | normal_rng, seed = get_random_cand(seed, operations) 164 | reduction_rng, seed = get_random_cand(seed, operations) 165 | 166 | normal_rng = utils.check_cand(normal_rng, operations) 167 | reduction_rng = utils.check_cand(reduction_rng, operations) 168 | 169 | logits = model(input, normal_rng, reduction_rng) 170 | loss = criterion(logits, target) 171 | 172 | loss.backward() 173 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 174 | optimizer.step() 175 | 176 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 177 | objs.update(loss.item(), n) 178 | top1.update(prec1.item(), n) 179 | top5.update(prec5.item(), n) 180 | 181 | if step % args.report_freq == 0: 182 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 183 | 184 | return seed, top1.avg, objs.avg 185 | 186 | 187 | def infer(valid_queue, model, criterion, seed, operations): 188 | objs = utils.AvgrageMeter() 189 | top1 = utils.AvgrageMeter() 190 | top5 = utils.AvgrageMeter() 191 | model.eval() 192 | 193 | normal_rng, seed = get_random_cand(seed, operations) 194 | reduction_rng, seed = get_random_cand(seed, operations) 195 | 196 | normal_rng = utils.check_cand(normal_rng, operations) 197 | reduction_rng = utils.check_cand(reduction_rng, operations) 198 | 199 | # NOTE: there is no optimization of architecture parameters (none exist); 200 | # NOTE: instead, an operation is randomly sampled for each edge and the resulting subnet is evaluated.
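# A sketch of what a single validation pass sees (shapes assume this file's CIFAR-100 setup):
#   normal_rng / reduction_rng: 14 op indices, one per cell edge (2+3+4+5 edges over 4 nodes)
#   model(input, normal_rng, reduction_rng) -> logits of shape [batch_size, 100]
# i.e. one fixed random subnet is scored over the whole validation queue per call to infer().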
201 | for step, (input, target) in enumerate(valid_queue): 202 | input = input.cuda(non_blocking=True) 203 | target = target.cuda(non_blocking=True) 204 | 205 | with torch.no_grad(): 206 | logits = model(input, normal_rng, reduction_rng) 207 | loss = criterion(logits, target) 208 | 209 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 210 | n = input.size(0) 211 | objs.update(loss.item(), n) 212 | top1.update(prec1.item(), n) 213 | top5.update(prec5.item(), n) 214 | 215 | if step % args.report_freq == 0: 216 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 217 | 218 | return top1.avg, objs.avg 219 | 220 | 221 | if __name__ == '__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /train_supernet/utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copied from https://github.com/megvii-model/RLNAS/blob/main/darts_search_space/cifar10/rlnas/train_supernet/utils.py 3 | ''' 4 | import os 5 | import sys 6 | import shutil 7 | import numpy as np 8 | import time, datetime 9 | import torch 10 | import glob 11 | import random 12 | import logging 13 | import argparse 14 | import torch.nn as nn 15 | import torch.utils 16 | import torchvision.datasets as dset 17 | import torchvision.transforms as transforms 18 | import torch.backends.cudnn as cudnn 19 | from torch.autograd import Variable 20 | import joblib 21 | import pdb 22 | import pickle 23 | from collections import defaultdict 24 | from config import config 25 | import copy 26 | 27 | def broadcast(args, obj, src, group=torch.distributed.group.WORLD, async_op=False): 28 | print('local_rank:{}, obj:{}'.format(args.local_rank, obj)) 29 | obj_tensor = torch.from_numpy(np.array(obj)).cuda() 30 | torch.distributed.broadcast(obj_tensor, src, group, async_op) 31 | obj = obj_tensor.cpu().numpy() 32 | print('local_rank:{}, tensor:{}'.format(args.local_rank, obj)) 33 | return obj 34 | 35 | class CrossEntropyLabelSmooth(nn.Module): 36 | 37 | def __init__(self, num_classes, epsilon): 38 | super(CrossEntropyLabelSmooth, self).__init__() 39 | self.num_classes = num_classes 40 | self.epsilon = epsilon 41 | self.logsoftmax = nn.LogSoftmax(dim=1) 42 | 43 | def forward(self, inputs, targets): 44 | log_probs = self.logsoftmax(inputs) 45 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 46 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 47 | loss = (-targets * log_probs).mean(0).sum() 48 | return loss 49 | 50 | def get_optimizer_schedule(model, args, total_iters): 51 | all_parameters = model.parameters() 52 | weight_parameters = [] 53 | for pname, p in model.named_parameters(): 54 | if p.ndimension() == 4 or 'classifier.0.weight' in pname or 'classifier.0.bias' in pname: 55 | weight_parameters.append(p) 56 | weight_parameters_id = list(map(id, weight_parameters)) 57 | other_parameters = list(filter(lambda p: id(p) not in weight_parameters_id, all_parameters)) 58 | optimizer = torch.optim.SGD( 59 | [{'params' : other_parameters}, 60 | {'params' : weight_parameters, 'weight_decay' : args.weight_decay}], 61 | args.learning_rate, 62 | momentum=args.momentum, 63 | ) 64 | 65 | delta_iters = total_iters / (1.-args.min_lr / args.learning_rate) 66 | print('delta_iters={}'.format(delta_iters)) 67 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step : (1.0-step/delta_iters), last_epoch=-1) 68 | return optimizer, scheduler 69 | 70 | def get_location(s, key): 71 | d = 
defaultdict(list) 72 | for k,va in [(v,i) for i,v in enumerate(s)]: 73 | d[k].append(va) 74 | return d[key] 75 | 76 | def list_substract(list1, list2): 77 | list1 = [item for item in list1 if item not in set(list2)] 78 | return list1 79 | 80 | def check_cand(cand, operations): 81 | cand = np.reshape(cand, [-1, config.edges]) 82 | offset, cell_cand = 0, cand[0] 83 | for j in range(4): 84 | edges = cell_cand[offset:offset+j+2] 85 | edges_ops = operations[offset:offset+j+2] 86 | none_idxs = get_location(edges, 0) 87 | if len(none_idxs) < j: 88 | general_idxs = list_substract(range(j+2), none_idxs) 89 | num = min(j-len(none_idxs), len(general_idxs)) 90 | general_idxs = np.random.choice(general_idxs, size=num, replace=False, p=None) 91 | for k in general_idxs: 92 | edges[k] = 0 93 | elif len(none_idxs) > j: 94 | none_idxs = np.random.choice(none_idxs, size=len(none_idxs)-j, replace=False, p=None) 95 | for k in none_idxs: 96 | if len(edges_ops[k]) > 1: 97 | l = np.random.randint(len(edges_ops[k])-1) 98 | edges[k] = edges_ops[k][l+1] 99 | offset += len(edges) 100 | 101 | return cell_cand.tolist() 102 | 103 | class AvgrageMeter(object): 104 | 105 | def __init__(self): 106 | self.reset() 107 | 108 | def reset(self): 109 | self.avg = 0 110 | self.sum = 0 111 | self.cnt = 0 112 | 113 | def update(self, val, n=1): 114 | self.sum += val * n 115 | self.cnt += n 116 | self.avg = self.sum / self.cnt 117 | 118 | def accuracy(output, target, topk=(1,)): 119 | maxk = max(topk) 120 | batch_size = target.size(0) 121 | 122 | _, pred = output.topk(maxk, 1, True, True) 123 | pred = pred.t() 124 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 125 | 126 | res = [] 127 | for k in topk: 128 | # correct_k = correct[:k].view(-1).float().sum(0) 129 | # for pytorch >= 1.7.0 130 | correct_k = correct[:k].reshape(-1).float().sum(0) 131 | res.append(correct_k.mul_(100.0/batch_size)) 132 | return res 133 | 134 | def save_checkpoint(state, save): 135 | if not os.path.exists(save): 136 | os.makedirs(save) 137 | filename = os.path.join(save, 'checkpoint_epoch_{}.pth.tar'.format(state['epoch']+1)) 138 | torch.save(state, filename) 139 | print('Save CheckPoint....') 140 | 141 | 142 | def save(model, save, suffix): 143 | torch.save(model.module.state_dict(), save) 144 | shutil.copyfile(save, 'weight_{}.pt'.format(suffix)) 145 | 146 | def create_exp_dir(path, scripts_to_save=None): 147 | if not os.path.exists(path): 148 | os.mkdir(path) 149 | print('Experiment dir : {}'.format(path)) 150 | 151 | script_path = os.path.join(path, 'scripts') 152 | if scripts_to_save is not None and not os.path.exists(script_path): 153 | os.mkdir(script_path) 154 | for script in scripts_to_save: 155 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 156 | shutil.copyfile(script, dst_file) 157 | 158 | def merge_ops(rngs): 159 | cand = [] 160 | for rng in rngs: 161 | for r in rng: 162 | cand.append(r) 163 | cand += [-1] 164 | cand = cand[:-1] 165 | return cand 166 | 167 | def split_ops(cand): 168 | cell, layer = 0, 0 169 | cand_ = [[]] 170 | for c in cand: 171 | if c == -1: 172 | cand_.append([]) 173 | layer += 1 174 | else: 175 | cand_[layer].append(c) 176 | return cand_ 177 | 178 | def get_search_space_size(operations): 179 | comb_num = 1 180 | for j in range(len(operations)): 181 | comb_num *= len(operations[j]) 182 | return comb_num 183 | 184 | def count_parameters_in_MB(model): 185 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 186 | 187 | class Cutout(object): 188 
| def __init__(self, length): 189 | self.length = length 190 | 191 | def __call__(self, img): 192 | h, w = img.size(1), img.size(2) 193 | mask = np.ones((h, w), np.float32) 194 | y = np.random.randint(h) 195 | x = np.random.randint(w) 196 | 197 | y1 = np.clip(y - self.length // 2, 0, h) 198 | y2 = np.clip(y + self.length // 2, 0, h) 199 | x1 = np.clip(x - self.length // 2, 0, w) 200 | x2 = np.clip(x + self.length // 2, 0, w) 201 | 202 | mask[y1: y2, x1: x2] = 0. 203 | mask = torch.from_numpy(mask) 204 | mask = mask.expand_as(img) 205 | img *= mask 206 | return img 207 | 208 | def _data_transforms_cifar10(args): 209 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] 210 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] 211 | 212 | train_transform = transforms.Compose([ 213 | transforms.RandomCrop(32, padding=4), 214 | transforms.RandomHorizontalFlip(), 215 | transforms.ToTensor(), 216 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 217 | ]) 218 | if args.cutout: 219 | train_transform.transforms.append(Cutout(args.cutout_length)) 220 | 221 | valid_transform = transforms.Compose([ 222 | transforms.ToTensor(), 223 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 224 | ]) 225 | return train_transform, valid_transform 226 | 227 | class Cifar10RandomLabels(dset.CIFAR10): 228 | """CIFAR10 dataset, with support for randomly corrupted labels. 229 | Params 230 | ------ 231 | rand_seed: int 232 | Default 0. numpy random seed. 233 | num_classes: int 234 | Default 10. The number of classes in the dataset. 235 | """ 236 | def __init__(self, rand_seed=0, num_classes=10, **kwargs): 237 | super(Cifar10RandomLabels, self).__init__(**kwargs) 238 | self.n_classes = num_classes 239 | self.rand_seed = rand_seed 240 | self.random_labels() 241 | 242 | def random_labels(self): 243 | labels = np.array(self.targets) 244 | print('num_classes:{}, random labels num:{}, random seed:{}'.format(self.n_classes, len(labels), self.rand_seed)) 245 | np.random.seed(self.rand_seed) 246 | rnd_labels = np.random.randint(0, self.n_classes, len(labels)) 247 | # we need to explicitly cast the labels from np.int64 to the 248 | # builtin int type, otherwise pytorch will fail... 249 | labels = [int(x) for x in rnd_labels] 250 | 251 | self.targets = labels --------------------------------------------------------------------------------
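One detail of check_cand (train_supernet/utils.py above) worth spelling out: it repairs a uniformly sampled candidate so that intermediate node j (j = 0..3) of a cell has exactly j of its j + 2 incoming edges assigned the 'none' op (index 0), leaving exactly two active inputs per node, as in a valid DARTS genotype. A minimal sketch of driving it directly, assuming config.edges == 14 and config.op_num == 8 as in this search space:

import numpy as np
from utils import check_cand  # train_supernet/utils.py; reads config.edges internally

operations = [list(range(8)) for _ in range(14)]  # 8 candidate ops for each of the 14 cell edges

np.random.seed(0)
cand = [np.random.randint(len(ops)) for ops in operations]  # uniformly sampled candidate
cand = check_cand(cand, operations)                         # repair the per-node 'none' counts

offset = 0
for j in range(4):                        # nodes 0..3 have 2, 3, 4, 5 incoming edges
    edges = cand[offset:offset + j + 2]
    assert edges.count(0) == j            # exactly two active (non-'none') inputs remain
    offset += j + 2
print(cand)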