├── LICENSE
├── README.md
├── darts
│   └── cnn
│       ├── architect.py
│       ├── eval-EXP
│       │   └── log.txt
│       ├── genotypes.py
│       ├── model.py
│       ├── operations.py
│       ├── test.py
│       ├── test_imagenet.py
│       ├── train.py
│       ├── train_imagenet.py
│       ├── train_search.py
│       ├── utils.py
│       └── visualize.py
├── docs
│   └── arch.png
├── gin
│   └── models
│       ├── graphcnn.py
│       └── mlp.py
├── models
│   ├── configs.py
│   ├── layers.py
│   ├── model.py
│   ├── pretraining_darts.py
│   ├── pretraining_darts.sh
│   ├── pretraining_nasbench101.py
│   ├── pretraining_nasbench101.sh
│   ├── pretraining_nasbench201.py
│   └── pretraining_nasbench201.sh
├── plot_scripts
│   ├── distance_comparison_fig3.py
│   ├── draw_darts.py
│   ├── drawfig4.sh
│   ├── drawfig5-darts.sh
│   ├── drawfig5-nas101.sh
│   ├── drawfig5-nas201.sh
│   ├── nas201.jpg
│   ├── pearson_plot_fig2.py
│   ├── plot_cdf.py
│   ├── plot_dngo_search_arch2vec.py
│   ├── plot_nasbench101_comparison.py
│   ├── plot_reinforce_search_arch2vec.py
│   ├── summarize_nasbench201.py
│   ├── try_networkx.py
│   ├── visdensity.py
│   └── visgraph.py
├── preprocessing
│   ├── api.py
│   ├── gen_isomorphism_graphs.py
│   ├── gen_json.py
│   └── nasbench201_json.py
├── pybnn
│   ├── __init__.py
│   ├── base_model.py
│   ├── bayesian_linear_regression.py
│   ├── dngo.py
│   ├── dngo_supervised.py
│   └── util
│       ├── __init__.py
│       └── normalization.py
├── requirements.txt
├── results
│   ├── BO-arch2vec-model-nasbench-101.json
│   ├── BO-supervised-nasbench-101.json
│   ├── BOHB-Search-Encoding-A.json
│   ├── RL-arch2vec-model-nasbench-101.json
│   ├── RL-supervised-nasbench-101.json
│   ├── Random-Search-Encoding-A.json
│   ├── Regularized-Evolution-Encoding-A.json
│   └── Reinforce-Search-Encoding-A.json
├── run_scripts
│   ├── extract_arch2vec.sh
│   ├── extract_arch2vec_darts.sh
│   ├── extract_arch2vec_nasbench201.sh
│   ├── run_bo_arch2vec_darts.sh
│   ├── run_bo_arch2vec_nasbench201_ImageNet.sh
│   ├── run_bo_arch2vec_nasbench201_cifar100.sh
│   ├── run_bo_arch2vec_nasbench201_cifar10_valid.sh
│   ├── run_dngo_arch2vec.sh
│   ├── run_dngo_supervised.sh
│   ├── run_reinforce_arch2vec.sh
│   ├── run_reinforce_arch2vec_darts.sh
│   ├── run_reinforce_arch2vec_nasbench201_ImageNet.sh
│   ├── run_reinforce_arch2vec_nasbench201_cifar100.sh
│   ├── run_reinforce_arch2vec_nasbench201_cifar10_valid.sh
│   └── run_reinforce_supervised.sh
├── search_methods
│   ├── dngo.py
│   ├── dngo_darts.py
│   ├── dngo_search_NB201_8x8.py
│   ├── reinforce.py
│   ├── reinforce_darts.py
│   ├── reinforce_search_NB201_8x8.py
│   ├── supervised_dngo.py
│   └── supervised_reinforce.py
└── utils
    └── utils.py

/README.md:
--------------------------------------------------------------------------------
1 | # Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?
2 | Code for the paper:
3 | > [Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?](https://arxiv.org/abs/2006.06936)\
4 | > Shen Yan, Yu Zheng, Wei Ao, Xiao Zeng, Mi Zhang.\
5 | > _NeurIPS 2020_.
6 | 
7 | <p align="center">

8 |   <img src="docs/arch.png" alt="arch2vec" width="100%">
9 |   Top: The supervision signal for representation learning comes from the accuracies of architectures selected by the search strategies. Bottom (ours): Disentangling architecture representation learning and architecture search through unsupervised pre-training.
10 | </p>
11 | 
12 | The repository is built upon [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric), [pybnn](https://github.com/automl/pybnn), [nas_benchmarks](https://github.com/automl/nas_benchmarks), and [bananas](https://github.com/naszilla/bananas).
13 | 
14 | ## 1. Requirements
15 | - NVIDIA GPU, Linux, Python3
16 | ```bash
17 | pip install -r requirements.txt
18 | ```
19 | 
20 | ## 2. Experiments on NAS-Bench-101
21 | ### Dataset preparation on NAS-Bench-101
22 | 
23 | Install [nasbench](https://github.com/google-research/nasbench) and download [nasbench_only108.tfrecord](https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord) into the `./data` folder.
24 | 
25 | ```bash
26 | python preprocessing/gen_json.py
27 | ```
28 | 
29 | Data will be saved in `./data/data.json`.
30 | 
31 | ### Pretraining
32 | ```bash
33 | bash models/pretraining_nasbench101.sh
34 | ```
35 | 
36 | The pretrained model will be saved in `./pretrained/dim-16/`.
37 | 
38 | ### arch2vec extraction
39 | ```bash
40 | bash run_scripts/extract_arch2vec.sh
41 | ```
42 | 
43 | The extracted arch2vec will be saved in `./pretrained/dim-16/`.
44 | 
45 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/file/d/16GnqqrN46PJWl8QnES83WY3W58NUhgCr/view?usp=sharing) on NAS-Bench-101.
46 | 
47 | 
48 | ### Run RL search experiments on NAS-Bench-101
49 | ```bash
50 | bash run_scripts/run_reinforce_supervised.sh
51 | bash run_scripts/run_reinforce_arch2vec.sh
52 | ```
53 | 
54 | Search results will be saved in `./saved_logs/rl/dim16`.
55 | 
56 | Generate the JSON file:
57 | ```bash
58 | python plot_scripts/plot_reinforce_search_arch2vec.py
59 | ```
60 | 
61 | 
62 | ### Run BO search experiments on NAS-Bench-101
63 | ```bash
64 | bash run_scripts/run_dngo_supervised.sh
65 | bash run_scripts/run_dngo_arch2vec.sh
66 | ```
67 | 
68 | Search results will be saved in `./saved_logs/bo/dim16`.
69 | 
70 | Generate the JSON file:
71 | ```bash
72 | python plot_scripts/plot_dngo_search_arch2vec.py
73 | ```
74 | 
75 | ### Plot the NAS comparison curve on NAS-Bench-101
76 | ```bash
77 | python plot_scripts/plot_nasbench101_comparison.py
78 | ```
79 | 
80 | ### Plot the CDF comparison curve on NAS-Bench-101
81 | Download the search results from [search_logs](https://drive.google.com/drive/u/1/folders/1FKZghhBX0-gVNcQpzYjMShOH7mdkfwC1).
82 | ```bash
83 | python plot_scripts/plot_cdf.py
84 | ```
85 | 
86 | 
87 | ## 3. Experiments on NAS-Bench-201
88 | 
89 | ### Dataset preparation
90 | Download [NAS-Bench-201-v1_0-e61699.pth](https://drive.google.com/file/d/1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs/view) into the `./data` folder.
91 | ```bash
92 | python preprocessing/nasbench201_json.py
93 | ```
94 | Data corresponding to the three datasets in NAS-Bench-201 will be saved in the `./data/` folder as `cifar10_valid_converged.json`, `cifar100.json`, and `ImageNet16_120.json`.
95 | 
96 | ### Pretraining
97 | ```bash
98 | bash models/pretraining_nasbench201.sh
99 | ```
100 | The pretrained model will be saved in `./pretrained/dim-16/`.
101 | 
102 | Note that the pretrained model is shared across the three datasets in NAS-Bench-201.
103 | 
104 | ### arch2vec extraction
105 | ```bash
106 | bash run_scripts/extract_arch2vec_nasbench201.sh
107 | ```
108 | The extracted arch2vec will be saved in `./pretrained/dim-16/` as `cifar10_valid_converged-arch2vec.pt`, `cifar100-arch2vec.pt`, and `ImageNet16_120-arch2vec.pt`.
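To sanity-check an extracted (or downloaded) embedding file, you can load it directly with `torch.load`. The snippet below is a minimal sketch: it only assumes the `.pt` file is a standard PyTorch serialization; the exact structure of the loaded object depends on the extraction script, so print it first and adapt the access accordingly.

```python
import torch

# Minimal inspection sketch -- the layout of the loaded object (dict,
# tensor, ...) is an assumption; print it first and adapt the access.
emb = torch.load('pretrained/dim-16/cifar10_valid_converged-arch2vec.pt',
                 map_location='cpu')
print(type(emb))
if isinstance(emb, dict):
    first_key = next(iter(emb))
    print(first_key, emb[first_key])
```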
109 | 
110 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/drive/u/1/folders/16AIs4GfGNgeaHriTAICLCxIBdYE223id) on NAS-Bench-201.
111 | 
112 | ### Run RL search experiments on NAS-Bench-201
113 | ```bash
114 | CIFAR-10: ./run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh
115 | CIFAR-100: ./run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh
116 | ImageNet-16-120: ./run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh
117 | ```
118 | 
119 | 
120 | ### Run BO search experiments on NAS-Bench-201
121 | ```bash
122 | CIFAR-10: ./run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh
123 | CIFAR-100: ./run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh
124 | ImageNet-16-120: ./run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh
125 | ```
126 | 
127 | 
128 | ### Summarize search results on NAS-Bench-201
129 | ```bash
130 | python ./plot_scripts/summarize_nasbench201.py
131 | ```
132 | The corresponding table will be printed to the console.
133 | 
134 | 
135 | ## 4. Experiments on the DARTS Search Space
136 | CIFAR-10 is downloaded automatically by torchvision; ImageNet must be downloaded manually (preferably to an SSD) from http://image-net.org/download.
137 | 
138 | ### Randomly sample 600,000 isomorphic graphs in the DARTS space
139 | ```bash
140 | python preprocessing/gen_isomorphism_graphs.py
141 | ```
142 | Data will be saved in `./data/data_darts_counter600000.json`.
143 | 
144 | Alternatively, you can download the extracted [data_darts_counter600000.json](https://drive.google.com/file/d/1xboQV_NtsSDyOPM4H7RxtDNL-2WXo3Wr/view?usp=sharing).
145 | 
146 | ### Pretraining
147 | ```bash
148 | bash models/pretraining_darts.sh
149 | ```
150 | The pretrained model will be saved in `./pretrained/dim-16/`.
151 | 
152 | ### arch2vec extraction
153 | ```bash
154 | bash run_scripts/extract_arch2vec_darts.sh
155 | ```
156 | The extracted arch2vec will be saved in `./pretrained/dim-16/arch2vec-darts.pt`.
157 | 
158 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/file/d/1bDZCD-XDzded6SRjDUpRV6xTINpwTNcm/view?usp=sharing) on the DARTS search space.
159 | 
160 | ### Run RL search experiments on the DARTS search space
161 | ```bash
162 | bash run_scripts/run_reinforce_arch2vec_darts.sh
163 | ```
164 | Logs will be saved in `./darts-rl/`.
165 | 
166 | The final search result will be saved in `./saved_logs/rl/dim16`.
167 | 
168 | ### Run BO search experiments on the DARTS search space
169 | ```bash
170 | bash run_scripts/run_bo_arch2vec_darts.sh
171 | ```
172 | Logs will be saved in `./darts-bo/`.
173 | 
174 | The final search result will be saved in `./saved_logs/bo/dim16`.
175 | 
176 | ### Evaluate the learned cells from the DARTS search space on CIFAR-10
177 | ```bash
178 | python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_rl --seed 1
179 | python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_bo --seed 1
180 | ```
181 | - Expected results (RL): 2.60\% test error with 3.3M model params.
182 | - Expected results (BO): 2.48\% test error with 3.6M model params.
183 | 
184 | 
185 | ### Transfer learning on ImageNet
186 | ```bash
187 | python darts/cnn/train_imagenet.py --arch arch2vec_rl --seed 1
188 | python darts/cnn/train_imagenet.py --arch arch2vec_bo --seed 1
189 | ```
190 | - Expected results (RL): 25.8\% test error with 4.8M model params and 533M mult-adds.
191 | - Expected results (BO): 25.5\% test error with 5.2M model params and 580M mult-adds.
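To double-check the parameter and mult-add counts reported above, note that `darts/cnn/train_imagenet.py` already imports [thop](https://github.com/Lyken17/thop) and contains a commented-out profiling block. The sketch below (run from the repository root) restates that idea as a standalone script; the channel/layer counts and the 224x224 input match the ImageNet pipeline's defaults, while building without the auxiliary tower is an assumption made here so that only the main network is profiled.

```python
import torch
import darts.cnn.genotypes as genotypes
from darts.cnn.model import NetworkImageNet as Network
from thop import profile

# Build the searched network with the train_imagenet.py defaults
# (48 init channels, 14 cells, 1000 classes).
model = Network(48, 1000, 14, auxiliary=False, genotype=genotypes.arch2vec_rl)
model.eval()  # avoid batch-norm training-mode statistics on a single input

input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input,))
print('mult-adds: {:.0f}M, params: {:.2f}M'.format(macs / 1e6, params / 1e6))
```

Swapping in `genotypes.arch2vec_bo` profiles the BO-searched cell instead.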
192 | 
193 | 
194 | ### Visualize the learned cell
195 | ```bash
196 | python darts/cnn/visualize.py arch2vec_rl
197 | python darts/cnn/visualize.py arch2vec_bo
198 | ```
199 | 
200 | ## 5. Analyzing the results
201 | ### Visualize a sequence of decoded cells from the latent space
202 | Download the pretrained supervised embeddings for [nasbench101](https://drive.google.com/file/d/19-1gpMdXftXoH7G5929peoOnS1xKf5wN/view?usp=sharing) and [nasbench201](https://drive.google.com/file/d/1_Pw8MDp6ZrlI6EJ0kS3MVEz3HOSJMnIV/view?usp=sharing).
203 | ```bash
204 | bash plot_scripts/drawfig5-nas101.sh # visualization on nasbench-101
205 | bash plot_scripts/drawfig5-nas201.sh # visualization on nasbench-201
206 | bash plot_scripts/drawfig5-darts.sh  # visualization on darts
207 | ```
208 | The plots will be saved in `./graphvisualization`.
209 | 
210 | ### Plot the distribution of L2 distance by edit distance
211 | Install [nas_benchmarks](https://github.com/automl/nas_benchmarks) and download [nasbench_full.tfrecord](https://storage.googleapis.com/nasbench/nasbench_full.tfrecord) into the same directory.
212 | ```bash
213 | python plot_scripts/distance_comparison_fig3.py
214 | ```
215 | 
216 | ### Latent space 2D visualization
217 | ```bash
218 | bash plot_scripts/drawfig4.sh
219 | ```
220 | The plots will be saved in `./density`.
221 | 
222 | ### Predictive performance comparison
223 | Download [predicted_accuracy](https://drive.google.com/drive/u/1/folders/1mNlg5s3FQ8PEcgTDSnAuM6qa8ECDTzhh) into `saved_logs/`.
224 | ```bash
225 | python plot_scripts/pearson_plot_fig2.py
226 | ```
227 | 
228 | 
229 | 
230 | 
231 | 
232 | # Citation
233 | If you find this useful for your work, please consider citing:
234 | ```
235 | @InProceedings{yan2020arch,
236 |   title = {Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?},
237 |   author = {Yan, Shen and Zheng, Yu and Ao, Wei and Zeng, Xiao and Zhang, Mi},
238 |   booktitle = {NeurIPS},
239 |   year = {2020}
240 | }
241 | ```
242 | 
243 | 
244 | 
245 | 
246 | 
--------------------------------------------------------------------------------
/darts/cnn/architect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 | 
6 | 
7 | def _concat(xs):
8 |   return torch.cat([x.view(-1) for x in xs])
9 | 
10 | 
11 | class Architect(object):
12 | 
13 |   def __init__(self, model, args):
14 |     self.network_momentum = args.momentum
15 |     self.network_weight_decay = args.weight_decay
16 |     self.model = model
17 |     self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
18 |         lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
19 | 
20 |   def _compute_unrolled_model(self, input, target, eta, network_optimizer):
21 |     loss = self.model._loss(input, target)
22 |     theta = _concat(self.model.parameters()).data
23 |     try:
24 |       moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
25 |     except KeyError:  # the network optimizer has no momentum buffer yet (e.g. before its first step)
26 |       moment = torch.zeros_like(theta)
27 |     dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta
28 |     unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta))  # one virtual SGD step: w' = w - eta * (moment + dL_train/dw)
29 |     return unrolled_model
30 | 
31 |   def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
32 |     self.optimizer.zero_grad()
33 |     if unrolled:
34 | 
self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer) 35 | else: 36 | self._backward_step(input_valid, target_valid) 37 | self.optimizer.step() 38 | 39 | def _backward_step(self, input_valid, target_valid): 40 | loss = self.model._loss(input_valid, target_valid) 41 | loss.backward() 42 | 43 | def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer): 44 | unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer) 45 | unrolled_loss = unrolled_model._loss(input_valid, target_valid) 46 | 47 | unrolled_loss.backward() 48 | dalpha = [v.grad for v in unrolled_model.arch_parameters()] 49 | vector = [v.grad.data for v in unrolled_model.parameters()] 50 | implicit_grads = self._hessian_vector_product(vector, input_train, target_train) 51 | 52 | for g, ig in zip(dalpha, implicit_grads): 53 | g.data.sub_(eta, ig.data) 54 | 55 | for v, g in zip(self.model.arch_parameters(), dalpha): 56 | if v.grad is None: 57 | v.grad = Variable(g.data) 58 | else: 59 | v.grad.data.copy_(g.data) 60 | 61 | def _construct_model_from_theta(self, theta): 62 | model_new = self.model.new() 63 | model_dict = self.model.state_dict() 64 | 65 | params, offset = {}, 0 66 | for k, v in self.model.named_parameters(): 67 | v_length = np.prod(v.size()) 68 | params[k] = theta[offset: offset+v_length].view(v.size()) 69 | offset += v_length 70 | 71 | assert offset == len(theta) 72 | model_dict.update(params) 73 | model_new.load_state_dict(model_dict) 74 | return model_new.cuda() 75 | 76 | def _hessian_vector_product(self, vector, input, target, r=1e-2): 77 | R = r / _concat(vector).norm() 78 | for p, v in zip(self.model.parameters(), vector): 79 | p.data.add_(R, v) 80 | loss = self.model._loss(input, target) 81 | grads_p = torch.autograd.grad(loss, self.model.arch_parameters()) 82 | 83 | for p, v in zip(self.model.parameters(), vector): 84 | p.data.sub_(2*R, v) 85 | loss = self.model._loss(input, target) 86 | grads_n = torch.autograd.grad(loss, self.model.arch_parameters()) 87 | 88 | for p, v in zip(self.model.parameters(), vector): 89 | p.data.add_(R, v) 90 | 91 | return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)] 92 | 93 | -------------------------------------------------------------------------------- /darts/cnn/genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 4 | 5 | PRIMITIVES = [ 6 | 'none', 7 | 'max_pool_3x3', 8 | 'avg_pool_3x3', 9 | 'skip_connect', 10 | 'sep_conv_3x3', 11 | 'sep_conv_5x5', 12 | 'dil_conv_3x3', 13 | 'dil_conv_5x5' 14 | ] 15 | 16 | NASNet = Genotype( 17 | normal = [ 18 | ('sep_conv_5x5', 1), 19 | ('sep_conv_3x3', 0), 20 | ('sep_conv_5x5', 0), 21 | ('sep_conv_3x3', 0), 22 | ('avg_pool_3x3', 1), 23 | ('skip_connect', 0), 24 | ('avg_pool_3x3', 0), 25 | ('avg_pool_3x3', 0), 26 | ('sep_conv_3x3', 1), 27 | ('skip_connect', 1), 28 | ], 29 | normal_concat = [2, 3, 4, 5, 6], 30 | reduce = [ 31 | ('sep_conv_5x5', 1), 32 | ('sep_conv_7x7', 0), 33 | ('max_pool_3x3', 1), 34 | ('sep_conv_7x7', 0), 35 | ('avg_pool_3x3', 1), 36 | ('sep_conv_5x5', 0), 37 | ('skip_connect', 3), 38 | ('avg_pool_3x3', 2), 39 | ('sep_conv_3x3', 2), 40 | ('max_pool_3x3', 1), 41 | ], 42 | reduce_concat = [4, 5, 6], 43 | ) 44 | 45 | AmoebaNet = Genotype( 46 | normal = [ 47 | ('avg_pool_3x3', 0), 48 | ('max_pool_3x3', 1), 49 | 
('sep_conv_3x3', 0), 50 | ('sep_conv_5x5', 2), 51 | ('sep_conv_3x3', 0), 52 | ('avg_pool_3x3', 3), 53 | ('sep_conv_3x3', 1), 54 | ('skip_connect', 1), 55 | ('skip_connect', 0), 56 | ('avg_pool_3x3', 1), 57 | ], 58 | normal_concat = [4, 5, 6], 59 | reduce = [ 60 | ('avg_pool_3x3', 0), 61 | ('sep_conv_3x3', 1), 62 | ('max_pool_3x3', 0), 63 | ('sep_conv_7x7', 2), 64 | ('sep_conv_7x7', 0), 65 | ('avg_pool_3x3', 1), 66 | ('max_pool_3x3', 0), 67 | ('max_pool_3x3', 1), 68 | ('conv_7x1_1x7', 0), 69 | ('sep_conv_3x3', 5), 70 | ], 71 | reduce_concat = [3, 4, 6] 72 | ) 73 | 74 | 75 | DARTS = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 76 | 77 | BANANAS = Genotype(normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_5x5', 2), ('sep_conv_5x5', 0), ('skip_connect', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_3x3', 1), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('none', 1), ('dil_conv_3x3', 2), ('sep_conv_5x5', 3), ('sep_conv_5x5', 4), ('sep_conv_3x3', 1)], reduce_concat=[2, 3, 4, 5]) 78 | 79 | arch2vec_bo = Genotype(normal=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 2), ('sep_conv_3x3', 0)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 2), ('sep_conv_3x3', 0)], reduce_concat=[2, 3, 4, 5]) 80 | 81 | arch2vec_rl = Genotype(normal=[('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 1), ('sep_conv_3x3', 0)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 1), ('sep_conv_3x3', 0)], reduce_concat=[2, 3, 4, 5]) 82 | -------------------------------------------------------------------------------- /darts/cnn/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | from darts.cnn.operations import * 5 | from darts.cnn.utils import drop_path 6 | 7 | 8 | class Cell(nn.Module): 9 | 10 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev): 11 | super(Cell, self).__init__() 12 | #print(C_prev_prev, C_prev, C) 13 | 14 | if reduction_prev: 15 | self.preprocess0 = FactorizedReduce(C_prev_prev, C) 16 | else: 17 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0) 18 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0) 19 | 20 | if reduction: 21 | op_names, indices = zip(*genotype.reduce) 22 | concat = genotype.reduce_concat 23 | else: 24 | op_names, indices = zip(*genotype.normal) 25 | concat = genotype.normal_concat 26 | self._compile(C, op_names, indices, concat, reduction) 27 | 28 | def _compile(self, C, op_names, indices, concat, reduction): 29 | assert len(op_names) == len(indices) 30 | self._steps = len(op_names) // 2 31 | self._concat = concat 32 | self.multiplier = len(concat) 33 | 34 | 
self._ops = nn.ModuleList() 35 | for name, index in zip(op_names, indices): 36 | stride = 2 if reduction and index < 2 else 1 37 | op = OPS[name](C, stride, True) 38 | self._ops += [op] 39 | self._indices = indices 40 | 41 | def forward(self, s0, s1, drop_prob): 42 | s0 = self.preprocess0(s0) 43 | s1 = self.preprocess1(s1) 44 | 45 | states = [s0, s1] 46 | for i in range(self._steps): 47 | h1 = states[self._indices[2*i]] 48 | h2 = states[self._indices[2*i+1]] 49 | op1 = self._ops[2*i] 50 | op2 = self._ops[2*i+1] 51 | h1 = op1(h1) 52 | h2 = op2(h2) 53 | if self.training and drop_prob > 0.: 54 | if not isinstance(op1, Identity): 55 | h1 = drop_path(h1, drop_prob) 56 | if not isinstance(op2, Identity): 57 | h2 = drop_path(h2, drop_prob) 58 | s = h1 + h2 59 | states += [s] 60 | return torch.cat([states[i] for i in self._concat], dim=1) 61 | 62 | 63 | class AuxiliaryHeadCIFAR(nn.Module): 64 | 65 | def __init__(self, C, num_classes): 66 | """assuming input size 8x8""" 67 | super(AuxiliaryHeadCIFAR, self).__init__() 68 | self.features = nn.Sequential( 69 | nn.ReLU(inplace=True), 70 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2 71 | nn.Conv2d(C, 128, 1, bias=False), 72 | nn.BatchNorm2d(128), 73 | nn.ReLU(inplace=True), 74 | nn.Conv2d(128, 768, 2, bias=False), 75 | nn.BatchNorm2d(768), 76 | nn.ReLU(inplace=True) 77 | ) 78 | self.classifier = nn.Linear(768, num_classes) 79 | 80 | def forward(self, x): 81 | x = self.features(x) 82 | x = self.classifier(x.view(x.size(0),-1)) 83 | return x 84 | 85 | 86 | class AuxiliaryHeadImageNet(nn.Module): 87 | 88 | def __init__(self, C, num_classes): 89 | """assuming input size 14x14""" 90 | super(AuxiliaryHeadImageNet, self).__init__() 91 | self.features = nn.Sequential( 92 | nn.ReLU(inplace=True), 93 | nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False), 94 | nn.Conv2d(C, 128, 1, bias=False), 95 | nn.BatchNorm2d(128), 96 | nn.ReLU(inplace=True), 97 | nn.Conv2d(128, 768, 2, bias=False), 98 | # NOTE: This batchnorm was omitted in my earlier implementation due to a typo. 99 | # Commenting it out for consistency with the experiments in the paper. 
100 | # nn.BatchNorm2d(768), 101 | nn.ReLU(inplace=True) 102 | ) 103 | self.classifier = nn.Linear(768, num_classes) 104 | 105 | def forward(self, x): 106 | x = self.features(x) 107 | x = self.classifier(x.view(x.size(0),-1)) 108 | return x 109 | 110 | 111 | class NetworkCIFAR(nn.Module): 112 | 113 | def __init__(self, C, num_classes, layers, auxiliary, genotype): 114 | super(NetworkCIFAR, self).__init__() 115 | self._layers = layers 116 | self._auxiliary = auxiliary 117 | 118 | stem_multiplier = 3 119 | C_curr = stem_multiplier*C 120 | self.stem = nn.Sequential( 121 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False), 122 | nn.BatchNorm2d(C_curr) 123 | ) 124 | 125 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C 126 | self.cells = nn.ModuleList() 127 | reduction_prev = False 128 | for i in range(layers): 129 | if i in [layers//3, 2*layers//3]: 130 | C_curr *= 2 131 | reduction = True 132 | else: 133 | reduction = False 134 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 135 | reduction_prev = reduction 136 | self.cells += [cell] 137 | C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr 138 | if i == 2*layers//3: 139 | C_to_auxiliary = C_prev 140 | 141 | if auxiliary: 142 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes) 143 | self.global_pooling = nn.AdaptiveAvgPool2d(1) 144 | self.classifier = nn.Linear(C_prev, num_classes) 145 | 146 | def forward(self, input): 147 | logits_aux = None 148 | s0 = s1 = self.stem(input) 149 | for i, cell in enumerate(self.cells): 150 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 151 | if i == 2*self._layers//3: 152 | if self._auxiliary and self.training: 153 | logits_aux = self.auxiliary_head(s1) 154 | out = self.global_pooling(s1) 155 | logits = self.classifier(out.view(out.size(0),-1)) 156 | return logits, logits_aux 157 | 158 | 159 | class NetworkImageNet(nn.Module): 160 | 161 | def __init__(self, C, num_classes, layers, auxiliary, genotype): 162 | super(NetworkImageNet, self).__init__() 163 | self._layers = layers 164 | self._auxiliary = auxiliary 165 | 166 | self.stem0 = nn.Sequential( 167 | nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False), 168 | nn.BatchNorm2d(C // 2), 169 | nn.ReLU(inplace=True), 170 | nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False), 171 | nn.BatchNorm2d(C), 172 | ) 173 | 174 | self.stem1 = nn.Sequential( 175 | nn.ReLU(inplace=True), 176 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False), 177 | nn.BatchNorm2d(C), 178 | ) 179 | 180 | C_prev_prev, C_prev, C_curr = C, C, C 181 | 182 | self.cells = nn.ModuleList() 183 | reduction_prev = True 184 | for i in range(layers): 185 | if i in [layers // 3, 2 * layers // 3]: 186 | C_curr *= 2 187 | reduction = True 188 | else: 189 | reduction = False 190 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev) 191 | reduction_prev = reduction 192 | self.cells += [cell] 193 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr 194 | if i == 2 * layers // 3: 195 | C_to_auxiliary = C_prev 196 | 197 | if auxiliary: 198 | self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes) 199 | self.global_pooling = nn.AvgPool2d(7) 200 | self.classifier = nn.Linear(C_prev, num_classes) 201 | self.drop_path_prob = 0 202 | 203 | def forward(self, input): 204 | logits_aux = None 205 | s0 = self.stem0(input) 206 | s1 = self.stem1(s0) 207 | for i, cell in enumerate(self.cells): 208 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob) 209 | if i == 2 * self._layers // 3: 210 | if 
self._auxiliary and self.training: 211 | logits_aux = self.auxiliary_head(s1) 212 | out = self.global_pooling(s1) 213 | logits = self.classifier(out.view(out.size(0), -1)) 214 | return logits, logits_aux 215 | 216 | -------------------------------------------------------------------------------- /darts/cnn/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | OPS = { 5 | 'none' : lambda C, stride, affine: Zero(stride), 6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False), 7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1), 8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine), 9 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine), 10 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine), 11 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine), 12 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine), 13 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine), 14 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential( 15 | nn.ReLU(inplace=False), 16 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False), 17 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False), 18 | nn.BatchNorm2d(C, affine=affine) 19 | ), 20 | } 21 | 22 | class ReLUConvBN(nn.Module): 23 | 24 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 25 | super(ReLUConvBN, self).__init__() 26 | self.op = nn.Sequential( 27 | nn.ReLU(inplace=False), 28 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 29 | nn.BatchNorm2d(C_out, affine=affine) 30 | ) 31 | 32 | def forward(self, x): 33 | return self.op(x) 34 | 35 | class DilConv(nn.Module): 36 | 37 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True): 38 | super(DilConv, self).__init__() 39 | self.op = nn.Sequential( 40 | nn.ReLU(inplace=False), 41 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False), 42 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 43 | nn.BatchNorm2d(C_out, affine=affine), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.op(x) 48 | 49 | 50 | class SepConv(nn.Module): 51 | 52 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 53 | super(SepConv, self).__init__() 54 | self.op = nn.Sequential( 55 | nn.ReLU(inplace=False), 56 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False), 57 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 58 | nn.BatchNorm2d(C_in, affine=affine), 59 | nn.ReLU(inplace=False), 60 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 61 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 62 | nn.BatchNorm2d(C_out, affine=affine), 63 | ) 64 | 65 | def forward(self, x): 66 | return self.op(x) 67 | 68 | 69 | class Identity(nn.Module): 70 | 71 | def __init__(self): 72 | super(Identity, self).__init__() 73 | 74 | def forward(self, x): 75 | return x 76 | 77 | 78 | class Zero(nn.Module): 79 | 80 | def __init__(self, stride): 81 | super(Zero, 
self).__init__() 82 | self.stride = stride 83 | 84 | def forward(self, x): 85 | if self.stride == 1: 86 | return x.mul(0.) 87 | return x[:,:,::self.stride,::self.stride].mul(0.) 88 | 89 | 90 | class FactorizedReduce(nn.Module): 91 | 92 | def __init__(self, C_in, C_out, affine=True): 93 | super(FactorizedReduce, self).__init__() 94 | assert C_out % 2 == 0 95 | self.relu = nn.ReLU(inplace=False) 96 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 97 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 98 | self.bn = nn.BatchNorm2d(C_out, affine=affine) 99 | 100 | def forward(self, x): 101 | x = self.relu(x) 102 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1) 103 | out = self.bn(out) 104 | return out 105 | 106 | -------------------------------------------------------------------------------- /darts/cnn/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import glob 5 | import numpy as np 6 | import torch 7 | import utils 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import darts.cnn.genotypes as genotypes 12 | import torch.utils 13 | import torchvision.datasets as dset 14 | import torch.backends.cudnn as cudnn 15 | 16 | from darts.cnn.model import NetworkCIFAR as Network 17 | 18 | 19 | parser = argparse.ArgumentParser("cifar") 20 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus') 21 | parser.add_argument('--batch_size', type=int, default=32, help='batch size') 22 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency') 23 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id') 24 | parser.add_argument('--init_channels', type=int, default=36, help='num of init channels') 25 | parser.add_argument('--layers', type=int, default=20, help='total number of layers') 26 | parser.add_argument('--model_path', type=str, default='EXP/model.pt', help='path of pretrained model') 27 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') 28 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout') 29 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 30 | parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability') 31 | parser.add_argument('--seed', type=int, default=0, help='random seed') 32 | parser.add_argument('--arch', type=str, default='BANANAS', help='which architecture to use') 33 | args = parser.parse_args() 34 | 35 | log_format = '%(asctime)s %(message)s' 36 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 37 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 38 | 39 | CIFAR_CLASSES = 10 40 | 41 | 42 | def main(): 43 | 44 | np.random.seed(args.seed) 45 | 46 | if torch.cuda.is_available(): 47 | device = torch.device('cuda:{}'.format(args.gpu)) 48 | cudnn.benchmark = True 49 | torch.manual_seed(args.seed) 50 | cudnn.enabled = True 51 | cudnn.deterministic = True 52 | torch.cuda.manual_seed(args.seed) 53 | logging.info('gpu device = %d' % args.gpu) 54 | else: 55 | device = torch.device('cpu') 56 | logging.info('No gpu device available') 57 | torch.manual_seed(args.seed) 58 | 59 | logging.info("args = %s", args) 60 | genotype = eval("genotypes.%s" % args.arch) 61 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, 
genotype) 62 | model = model.to(device) 63 | utils.load(model, args.model_path, args.gpu) 64 | 65 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 66 | 67 | criterion = nn.CrossEntropyLoss() 68 | criterion = criterion.cuda() 69 | 70 | _, test_transform = utils._data_transforms_cifar10(args.cutout, args.cutout_length) 71 | test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform) 72 | 73 | test_queue = torch.utils.data.DataLoader( 74 | test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2) 75 | 76 | model.drop_path_prob = args.drop_path_prob 77 | test_acc, test_obj = infer(test_queue, model, criterion, args.gpu) 78 | logging.info('test_acc %f', test_acc) 79 | 80 | 81 | def infer(test_queue, model, criterion, gpu_id=0): 82 | objs = utils.AvgrageMeter() 83 | top1 = utils.AvgrageMeter() 84 | top5 = utils.AvgrageMeter() 85 | model.eval() 86 | 87 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \ 88 | else 'cpu') 89 | 90 | for step, (input, target) in enumerate(test_queue): 91 | input = input.to(device) 92 | target = target.to(device) 93 | 94 | logits, _ = model(input) 95 | loss = criterion(logits, target) 96 | 97 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 98 | n = input.size(0) 99 | objs.update(loss.item(), n) 100 | top1.update(prec1.item(), n) 101 | top5.update(prec5.item(), n) 102 | 103 | if step % args.report_freq == 0: 104 | logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 105 | 106 | return top1.avg, objs.avg 107 | 108 | 109 | if __name__ == '__main__': 110 | main() 111 | 112 | -------------------------------------------------------------------------------- /darts/cnn/test_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | import torch 5 | import utils 6 | import glob 7 | import random 8 | import logging 9 | import argparse 10 | import torch.nn as nn 11 | import genotypes 12 | import torch.utils 13 | import torchvision.datasets as dset 14 | import torchvision.transforms as transforms 15 | import torch.backends.cudnn as cudnn 16 | 17 | from model import NetworkImageNet as Network 18 | 19 | 20 | parser = argparse.ArgumentParser("imagenet") 21 | parser.add_argument('--data', type=str, default='../data/imagenet/', help='location of the data corpus') 22 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 23 | parser.add_argument('--report_freq', type=float, default=100, help='report frequency') 24 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id') 25 | parser.add_argument('--init_channels', type=int, default=48, help='num of init channels') 26 | parser.add_argument('--layers', type=int, default=14, help='total number of layers') 27 | parser.add_argument('--model_path', type=str, default='EXP/model.pt', help='path of pretrained model') 28 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') 29 | parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability') 30 | parser.add_argument('--seed', type=int, default=0, help='random seed') 31 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use') 32 | args = parser.parse_args() 33 | 34 | log_format = '%(asctime)s %(message)s' 35 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 36 | format=log_format, datefmt='%m/%d 
%I:%M:%S %p')
37 | 
38 | CLASSES = 1000
39 | 
40 | 
41 | def main():
42 |   if not torch.cuda.is_available():
43 |     logging.info('no gpu device available')
44 |     sys.exit(1)
45 | 
46 |   np.random.seed(args.seed)
47 |   torch.cuda.set_device(args.gpu)
48 |   cudnn.benchmark = True
49 |   torch.manual_seed(args.seed)
50 |   cudnn.enabled = True
51 |   torch.cuda.manual_seed(args.seed)
52 |   logging.info('gpu device = %d' % args.gpu)
53 |   logging.info("args = %s", args)
54 | 
55 |   genotype = eval("genotypes.%s" % args.arch)
56 |   model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
57 |   model = model.cuda()
58 |   model.load_state_dict(torch.load(args.model_path)['state_dict'])
59 | 
60 |   logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
61 | 
62 |   criterion = nn.CrossEntropyLoss()
63 |   criterion = criterion.cuda()
64 | 
65 |   validdir = os.path.join(args.data, 'val')
66 |   normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67 |   valid_data = dset.ImageFolder(
68 |     validdir,
69 |     transforms.Compose([
70 |       transforms.Resize(256),
71 |       transforms.CenterCrop(224),
72 |       transforms.ToTensor(),
73 |       normalize,
74 |     ]))
75 | 
76 |   valid_queue = torch.utils.data.DataLoader(
77 |     valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
78 | 
79 |   model.drop_path_prob = args.drop_path_prob
80 |   valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
81 |   logging.info('valid_acc_top1 %f', valid_acc_top1)
82 |   logging.info('valid_acc_top5 %f', valid_acc_top5)
83 | 
84 | 
85 | def infer(valid_queue, model, criterion):
86 |   objs = utils.AvgrageMeter()
87 |   top1 = utils.AvgrageMeter()
88 |   top5 = utils.AvgrageMeter()
89 |   model.eval()
90 | 
91 |   for step, (input, target) in enumerate(valid_queue):
92 |     input = input.cuda()
93 |     target = target.cuda()
94 | 
95 |     logits, _ = model(input)
96 |     loss = criterion(logits, target)
97 | 
98 |     prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
99 |     n = input.size(0)
100 |     objs.update(loss.item(), n)
101 |     top1.update(prec1.item(), n)
102 |     top5.update(prec5.item(), n)
103 | 
104 |     if step % args.report_freq == 0:
105 |       logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
106 | 
107 |   return top1.avg, top5.avg, objs.avg
108 | 
109 | 
110 | if __name__ == '__main__':
111 |   main()
112 | 
--------------------------------------------------------------------------------
/darts/cnn/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import time
5 | import glob
6 | import numpy as np
7 | import random
8 | import torch
9 | import darts.cnn.utils as utils
10 | import logging
11 | import argparse
12 | import torch.nn as nn
13 | import darts.cnn.genotypes as genotypes
14 | import torch.utils
15 | import torchvision.datasets as dset
16 | import torch.backends.cudnn as cudnn
17 | 
18 | from darts.cnn.model import NetworkCIFAR as Network
19 | 
20 | 
21 | parser = argparse.ArgumentParser("cifar")
22 | parser.add_argument('--data', type=str, default='./data', help='location of the data corpus')
23 | parser.add_argument('--batch_size', type=int, default=96, help='batch size')
24 | parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
25 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
26 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
27 | 
parser.add_argument('--report_freq', type=float, default=500, help='report frequency') 28 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id') 29 | parser.add_argument('--epochs', type=int, default=600, help='num of training epochs') 30 | parser.add_argument('--init_channels', type=int, default=36, help='num of init channels') 31 | parser.add_argument('--layers', type=int, default=20, help='total number of layers') 32 | parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model') 33 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') 34 | parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss') 35 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout') 36 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length') 37 | parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability') 38 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 39 | parser.add_argument('--seed', type=int, default=3, help='random seed') 40 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use') 41 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping') 42 | args = parser.parse_args() 43 | 44 | args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S")) 45 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py')) 46 | 47 | log_format = '%(asctime)s %(message)s' 48 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 49 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 50 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 51 | fh.setFormatter(logging.Formatter(log_format)) 52 | logging.getLogger().addHandler(fh) 53 | 54 | CIFAR_CLASSES = 10 55 | 56 | 57 | def main(): 58 | 59 | np.random.seed(args.seed) 60 | random.seed(args.seed) 61 | 62 | if torch.cuda.is_available(): 63 | device = torch.device('cuda:{}'.format(args.gpu)) 64 | cudnn.benchmark = False 65 | torch.manual_seed(args.seed) 66 | cudnn.enabled = True 67 | cudnn.deterministic = True 68 | torch.cuda.manual_seed(args.seed) 69 | logging.info('gpu device = %d' % args.gpu) 70 | else: 71 | device = torch.device('cpu') 72 | logging.info('No gpu device available') 73 | torch.manual_seed(args.seed) 74 | 75 | logging.info("args = %s", args) 76 | genotype = eval("genotypes.%s" % args.arch) 77 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype) 78 | model = model.to(device) 79 | 80 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 81 | total_params = sum(x.data.nelement() for x in model.parameters()) 82 | logging.info('Model total parameters: {}'.format(total_params)) 83 | 84 | criterion = nn.CrossEntropyLoss() 85 | criterion = criterion.cuda() 86 | optimizer = torch.optim.SGD( 87 | model.parameters(), 88 | args.learning_rate, 89 | momentum=args.momentum, 90 | weight_decay=args.weight_decay 91 | ) 92 | 93 | train_transform, valid_transform = utils._data_transforms_cifar10(args.cutout, args.cutout_length) 94 | train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform) 95 | valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform) 96 | 97 | train_queue = torch.utils.data.DataLoader( 98 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, 
num_workers=4) 99 | 100 | valid_queue = torch.utils.data.DataLoader( 101 | valid_data, batch_size=args.batch_size*4, shuffle=False, pin_memory=True, num_workers=4) 102 | 103 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs)) 104 | 105 | for epoch in range(args.epochs): 106 | 107 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0]) 108 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs 109 | 110 | train_acc, train_obj = train(train_queue, model, criterion, optimizer, args.gpu) 111 | logging.info('train_acc %f', train_acc) 112 | 113 | valid_acc, valid_obj = infer(valid_queue, model, criterion, args.gpu) 114 | logging.info('valid_acc %f', valid_acc) 115 | 116 | scheduler.step() 117 | 118 | utils.save(model, os.path.join(args.save, 'weights.pt')) 119 | 120 | 121 | def train(train_queue, model, criterion, optimizer, gpu_id=0): 122 | objs = utils.AvgrageMeter() 123 | top1 = utils.AvgrageMeter() 124 | top5 = utils.AvgrageMeter() 125 | model.train() 126 | 127 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \ 128 | else 'cpu') 129 | 130 | 131 | for step, (input, target) in enumerate(train_queue): 132 | input = input.to(device) 133 | target = target.to(device) 134 | 135 | optimizer.zero_grad() 136 | logits, logits_aux = model(input) 137 | loss = criterion(logits, target) 138 | if args.auxiliary: 139 | loss_aux = criterion(logits_aux, target) 140 | loss += args.auxiliary_weight*loss_aux 141 | loss.backward() 142 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 143 | optimizer.step() 144 | 145 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 146 | n = input.size(0) 147 | objs.update(loss.item(), n) 148 | top1.update(prec1.item(), n) 149 | top5.update(prec5.item(), n) 150 | 151 | if step % args.report_freq == 0: 152 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 153 | 154 | return top1.avg, objs.avg 155 | 156 | 157 | def infer(valid_queue, model, criterion, gpu_id=0): 158 | with torch.no_grad(): 159 | objs = utils.AvgrageMeter() 160 | top1 = utils.AvgrageMeter() 161 | top5 = utils.AvgrageMeter() 162 | model.eval() 163 | 164 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \ 165 | else 'cpu') 166 | 167 | for step, (input, target) in enumerate(valid_queue): 168 | input = input.to(device) 169 | target = target.to(device) 170 | 171 | logits, _ = model(input) 172 | loss = criterion(logits, target) 173 | 174 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 175 | n = input.size(0) 176 | objs.update(loss.item(), n) 177 | top1.update(prec1.item(), n) 178 | top5.update(prec5.item(), n) 179 | 180 | if step % args.report_freq == 0: 181 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 182 | 183 | return top1.avg, objs.avg 184 | 185 | 186 | if __name__ == '__main__': 187 | main() 188 | 189 | -------------------------------------------------------------------------------- /darts/cnn/train_imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import numpy as np 5 | import time 6 | import torch 7 | import darts.cnn.utils as utils 8 | import glob 9 | import random 10 | import logging 11 | import argparse 12 | import torch.nn as nn 13 | import darts.cnn.genotypes as genotypes 14 | import torch.utils 15 | import torchvision.datasets as dset 16 | import torchvision.transforms as transforms 17 | import 
torch.backends.cudnn as cudnn 18 | from darts.cnn.model import NetworkImageNet as Network 19 | from thop import profile 20 | 21 | 22 | parser = argparse.ArgumentParser("imagenet") 23 | parser.add_argument('--data', type=str, default='data/imagenet/', help='location of the data corpus') 24 | parser.add_argument('--batch_size', type=int, default=128, help='batch size') 25 | parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate') 26 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum') 27 | parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay') 28 | parser.add_argument('--report_freq', type=float, default=100, help='report frequency') 29 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id') 30 | parser.add_argument('--epochs', type=int, default=250, help='num of training epochs') 31 | parser.add_argument('--init_channels', type=int, default=48, help='num of init channels') 32 | parser.add_argument('--layers', type=int, default=14, help='total number of layers') 33 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower') 34 | parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss') 35 | parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability') 36 | parser.add_argument('--save', type=str, default='EXP', help='experiment name') 37 | parser.add_argument('--seed', type=int, default=0, help='random seed') 38 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use') 39 | parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping') 40 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing') 41 | parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay') 42 | parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays') 43 | parser.add_argument('--parallel', action='store_true', default=False, help='data parallelism') 44 | args = parser.parse_args() 45 | 46 | args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S")) 47 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py')) 48 | 49 | log_format = '%(asctime)s %(message)s' 50 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 51 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 52 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt')) 53 | fh.setFormatter(logging.Formatter(log_format)) 54 | logging.getLogger().addHandler(fh) 55 | 56 | CLASSES = 1000 57 | 58 | 59 | class CrossEntropyLabelSmooth(nn.Module): 60 | 61 | def __init__(self, num_classes, epsilon): 62 | super(CrossEntropyLabelSmooth, self).__init__() 63 | self.num_classes = num_classes 64 | self.epsilon = epsilon 65 | self.logsoftmax = nn.LogSoftmax(dim=1) 66 | 67 | def forward(self, inputs, targets): 68 | log_probs = self.logsoftmax(inputs) 69 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1) 70 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes 71 | loss = (-targets * log_probs).mean(0).sum() 72 | return loss 73 | 74 | 75 | def main(): 76 | if not torch.cuda.is_available(): 77 | logging.info('no gpu device available') 78 | sys.exit(1) 79 | 80 | np.random.seed(args.seed) 81 | torch.cuda.set_device(args.gpu) 82 | cudnn.benchmark = True 83 | torch.manual_seed(args.seed) 84 | cudnn.enabled=True 85 | 
torch.cuda.manual_seed(args.seed) 86 | logging.info('gpu device = %d' % args.gpu) 87 | logging.info("args = %s", args) 88 | 89 | genotype = eval("genotypes.%s" % args.arch) 90 | model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype) 91 | 92 | if args.parallel: 93 | model = nn.DataParallel(model).cuda() 94 | else: 95 | model = model.cuda() 96 | 97 | #input = torch.randn(1,3,224,224).cuda() 98 | #macs, params = profile(model, inputs=(input,)) 99 | #print('flops: {}, params: {}'.format(macs, params)) #arch2vec_bo: 580M, 5.18M; arch2vec_rl: 533M, 4.82M 100 | #print("param size = %fMB", utils.count_parameters_in_MB(model)) 101 | #exit() 102 | 103 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 104 | 105 | criterion = nn.CrossEntropyLoss() 106 | criterion = criterion.cuda() 107 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth) 108 | criterion_smooth = criterion_smooth.cuda() 109 | 110 | optimizer = torch.optim.SGD( 111 | model.parameters(), 112 | args.learning_rate, 113 | momentum=args.momentum, 114 | weight_decay=args.weight_decay 115 | ) 116 | 117 | traindir = os.path.join(args.data, 'train') 118 | validdir = os.path.join(args.data, 'val') 119 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 120 | train_data = dset.ImageFolder( 121 | traindir, 122 | transforms.Compose([ 123 | transforms.RandomResizedCrop(224), 124 | transforms.RandomHorizontalFlip(), 125 | transforms.ColorJitter( 126 | brightness=0.4, 127 | contrast=0.4, 128 | saturation=0.4, 129 | hue=0.2), 130 | transforms.ToTensor(), 131 | normalize, 132 | ])) 133 | valid_data = dset.ImageFolder( 134 | validdir, 135 | transforms.Compose([ 136 | transforms.Resize(256), 137 | transforms.CenterCrop(224), 138 | transforms.ToTensor(), 139 | normalize, 140 | ])) 141 | 142 | train_queue = torch.utils.data.DataLoader( 143 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4) 144 | 145 | valid_queue = torch.utils.data.DataLoader( 146 | valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4) 147 | 148 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma) 149 | 150 | best_acc_top1 = 0 151 | for epoch in range(args.epochs): 152 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0]) 153 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs 154 | 155 | train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer) 156 | logging.info('train_acc %f', train_acc) 157 | 158 | valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion) 159 | logging.info('valid_acc_top1 %f', valid_acc_top1) 160 | logging.info('valid_acc_top5 %f', valid_acc_top5) 161 | 162 | is_best = False 163 | if valid_acc_top1 > best_acc_top1: 164 | best_acc_top1 = valid_acc_top1 165 | is_best = True 166 | 167 | utils.save_checkpoint({ 168 | 'epoch': epoch + 1, 169 | 'state_dict': model.state_dict(), 170 | 'best_acc_top1': best_acc_top1, 171 | 'optimizer' : optimizer.state_dict(), 172 | }, is_best, args.save) 173 | 174 | scheduler.step() 175 | 176 | 177 | def train(train_queue, model, criterion, optimizer): 178 | objs = utils.AvgrageMeter() 179 | top1 = utils.AvgrageMeter() 180 | top5 = utils.AvgrageMeter() 181 | model.train() 182 | 183 | for step, (input, target) in enumerate(train_queue): 184 | target = target.cuda() 185 | input = input.cuda() 186 | 187 | optimizer.zero_grad() 188 | logits, logits_aux = 
model(input) 189 | loss = criterion(logits, target) 190 | if args.auxiliary: 191 | loss_aux = criterion(logits_aux, target) 192 | loss += args.auxiliary_weight*loss_aux 193 | 194 | loss.backward() 195 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 196 | optimizer.step() 197 | 198 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 199 | n = input.size(0) 200 | objs.update(loss.item(), n) 201 | top1.update(prec1.item(), n) 202 | top5.update(prec5.item(), n) 203 | 204 | if step % args.report_freq == 0: 205 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 206 | 207 | return top1.avg, objs.avg 208 | 209 | 210 | def infer(valid_queue, model, criterion): 211 | objs = utils.AvgrageMeter() 212 | top1 = utils.AvgrageMeter() 213 | top5 = utils.AvgrageMeter() 214 | model.eval() 215 | 216 | for step, (input, target) in enumerate(valid_queue): 217 | with torch.no_grad(): 218 | input = input.cuda() 219 | target = target.cuda() 220 | 221 | logits, _ = model(input) 222 | loss = criterion(logits, target) 223 | 224 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 225 | n = input.size(0) 226 | objs.update(loss.item(), n) 227 | top1.update(prec1.item(), n) 228 | top5.update(prec5.item(), n) 229 | 230 | if step % args.report_freq == 0: 231 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 232 | 233 | return top1.avg, top5.avg, objs.avg 234 | 235 | 236 | if __name__ == '__main__': 237 | main() 238 | -------------------------------------------------------------------------------- /darts/cnn/train_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import time 5 | import glob 6 | import numpy as np 7 | import random 8 | import torch 9 | import darts.cnn.utils as utils 10 | import logging 11 | import torch.nn as nn 12 | import darts.cnn.genotypes 13 | import torch.utils 14 | import torchvision.datasets as dset 15 | import torch.backends.cudnn as cudnn 16 | from collections import namedtuple 17 | 18 | from darts.cnn.model import NetworkCIFAR as Network 19 | 20 | class Train: 21 | 22 | def __init__(self): 23 | 24 | self.data='./data' 25 | self.batch_size= 96 26 | self.learning_rate= 0.025 27 | self.momentum= 0.9 28 | self.weight_decay = 3e-4 29 | self.load_weights = 0 30 | self.report_freq = 500 31 | self.gpu = 0 32 | self.epochs = 50 33 | self.init_channels = 36 34 | self.layers = 20 35 | self.auxiliary = True 36 | self.auxiliary_weight = 0.4 37 | self.cutout = True 38 | self.cutout_length = 16 39 | self.drop_path_prob = 0.2 40 | self.save = 'EXP' 41 | self.seed = 0 42 | self.grad_clip = 5 43 | self.train_portion = 0.9 44 | self.validation_set = True 45 | self.CIFAR_CLASSES = 10 46 | 47 | def main(self, counter, seed, arch, epochs=50, gpu=0, load_weights=False, train_portion=0.9, save='model_search'): 48 | 49 | # Set up save file and logging 50 | self.save = save 51 | self.save = '{}'.format(self.save) 52 | utils.create_exp_dir(self.save, scripts_to_save=glob.glob('*.py')) 53 | log_format = '%(asctime)s %(message)s' 54 | logging.basicConfig(stream=sys.stdout, level=logging.INFO, 55 | format=log_format, datefmt='%m/%d %I:%M:%S %p') 56 | fh = logging.FileHandler(os.path.join(self.save, 'log-seed{}.txt'.format(seed))) 57 | fh.setFormatter(logging.Formatter(log_format)) 58 | logging.getLogger().addHandler(fh) 59 | 60 | 61 | self.arch = arch 62 | self.epochs = epochs 63 | self.load_weights = load_weights 64 | self.gpu = gpu 65 | 
self.train_portion = train_portion 66 | if self.train_portion == 1: 67 | self.validation_set = False 68 | self.seed = seed 69 | 70 | #logging.info('Train class params') 71 | #logging.info('arch: {}, epochs: {}, gpu: {}, load_weights: {}, train_portion: {}' 72 | # .format(arch, epochs, gpu, load_weights, train_portion)) 73 | 74 | # cpu-gpu switch 75 | if not torch.cuda.is_available(): 76 | #logging.info('no gpu device available') 77 | torch.manual_seed(self.seed) 78 | device = torch.device('cpu') 79 | 80 | else: 81 | torch.cuda.manual_seed_all(self.seed) 82 | random.seed(self.seed) 83 | torch.manual_seed(self.seed) 84 | device = torch.device(self.gpu) 85 | cudnn.benchmark = False 86 | cudnn.enabled=True 87 | cudnn.deterministic=True 88 | #logging.info('gpu device = %d' % self.gpu) 89 | 90 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 91 | genotype = eval(self.convert_to_genotype(counter, arch)) 92 | model = Network(self.init_channels, self.CIFAR_CLASSES, self.layers, self.auxiliary, genotype) 93 | model = model.to(device) 94 | 95 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model)) 96 | print("param size = {:.4f}MB".format(utils.count_parameters_in_MB(model))) 97 | total_params = sum(x.data.nelement() for x in model.parameters()) 98 | logging.info('Model total parameters: {}'.format(total_params)) 99 | print('Model total parameters: {}'.format(total_params)) 100 | 101 | criterion = nn.CrossEntropyLoss() 102 | criterion = criterion.to(device) 103 | optimizer = torch.optim.SGD( 104 | model.parameters(), 105 | self.learning_rate, 106 | momentum=self.momentum, 107 | weight_decay=self.weight_decay 108 | ) 109 | 110 | train_transform, test_transform = utils._data_transforms_cifar10(self.cutout, self.cutout_length) 111 | train_data = dset.CIFAR10(root=self.data, train=True, download=True, transform=train_transform) 112 | test_data = dset.CIFAR10(root=self.data, train=False, download=True, transform=test_transform) 113 | 114 | num_train = len(train_data) 115 | indices = list(range(num_train)) 116 | if self.validation_set: 117 | split = int(np.floor(self.train_portion * num_train)) 118 | else: 119 | split = num_train 120 | 121 | train_queue = torch.utils.data.DataLoader( 122 | train_data, batch_size=self.batch_size, 123 | sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]), 124 | pin_memory=True, num_workers=4) 125 | 126 | if self.validation_set: 127 | valid_queue = torch.utils.data.DataLoader( 128 | train_data, batch_size=self.batch_size, 129 | sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]), 130 | pin_memory=True, num_workers=4) 131 | 132 | test_queue = torch.utils.data.DataLoader( 133 | test_data, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=4) 134 | 135 | if self.load_weights: 136 | logging.info('loading saved weights') 137 | ml = 'cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu' 138 | model.load_state_dict(torch.load('weights.pt', map_location = ml)) 139 | logging.info('loaded saved weights') 140 | 141 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(self.epochs)) 142 | 143 | valid_accs = [] 144 | test_accs = [] 145 | 146 | for epoch in range(self.epochs): 147 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0]) 148 | print('epoch {} lr {}'.format(epoch, scheduler.get_lr()[0])) 149 | model.drop_path_prob = self.drop_path_prob * epoch / self.epochs 150 | 151 | train_acc, train_obj = self.train(train_queue, model, 
criterion, optimizer) 152 | 153 | if self.validation_set: 154 | valid_acc, valid_obj = self.infer(valid_queue, model, criterion) 155 | else: 156 | valid_acc, valid_obj = 0, 0 157 | 158 | test_acc, test_obj = self.infer(test_queue, model, criterion, test_data=True) 159 | logging.info('train_acc: {:.4f}, valid_acc: {:.4f}, test_acc: {:.4f}'.format(train_acc, valid_acc, test_acc)) 160 | print('train_acc: {:.4f}, valid_acc: {:.4f}, test_acc: {:.4f}'.format(train_acc, valid_acc, test_acc)) 161 | 162 | #utils.save(model, os.path.join(self.save, 'weights-seed-{}.pt'.format(seed))) 163 | 164 | if epoch in list(range(max(0, epochs - 5), epochs)): 165 | valid_accs.append((epoch, valid_acc)) 166 | test_accs.append((epoch, test_acc)) 167 | 168 | scheduler.step() 169 | 170 | return valid_accs, test_accs 171 | 172 | 173 | def convert_to_genotype(self, counter, arch): 174 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 175 | geno = [] 176 | for item in arch: 177 | geno.append((item[1], int(item[0]))) 178 | geno = Genotype(normal=geno, normal_concat=[2,3,4,5], reduce=geno, reduce_concat=[2,3,4,5]) 179 | logging.info('counter: {}, genotypes: {}'.format(counter, str(geno))) 180 | print('counter: {}, genotypes: {}'.format(counter, str(geno))) 181 | return str(geno) 182 | 183 | 184 | def train(self, train_queue, model, criterion, optimizer): 185 | objs = utils.AvgrageMeter() 186 | top1 = utils.AvgrageMeter() 187 | top5 = utils.AvgrageMeter() 188 | model.train() 189 | 190 | for step, (input, target) in enumerate(train_queue): 191 | device = torch.device('cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu') 192 | input = input.to(device) 193 | target = target.to(device) 194 | 195 | optimizer.zero_grad() 196 | logits, logits_aux = model(input) 197 | loss = criterion(logits, target) 198 | if self.auxiliary: 199 | loss_aux = criterion(logits_aux, target) 200 | loss += self.auxiliary_weight*loss_aux 201 | loss.backward() 202 | nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip) 203 | optimizer.step() 204 | 205 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 206 | n = input.size(0) 207 | 208 | objs.update(loss.item(), n) 209 | top1.update(prec1.item(), n) 210 | top5.update(prec5.item(), n) 211 | 212 | 213 | return top1.avg, objs.avg 214 | 215 | 216 | def infer(self, valid_queue, model, criterion, test_data=False): 217 | objs = utils.AvgrageMeter() 218 | top1 = utils.AvgrageMeter() 219 | top5 = utils.AvgrageMeter() 220 | model.eval() 221 | device = torch.device('cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu') 222 | 223 | for step, (input, target) in enumerate(valid_queue): 224 | with torch.no_grad(): 225 | input = input.to(device) 226 | target = target.to(device) 227 | 228 | logits, _ = model(input) 229 | loss = criterion(logits, target) 230 | 231 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5)) 232 | n = input.size(0) 233 | 234 | objs.update(loss.item(), n) 235 | top1.update(prec1.item(), n) 236 | top5.update(prec5.item(), n) 237 | 238 | #if step % self.report_freq == 0: 239 | # if not test_data: 240 | # logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 241 | # else: 242 | # logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg) 243 | 244 | return top1.avg, objs.avg 245 | 246 | 247 | -------------------------------------------------------------------------------- /darts/cnn/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
import numpy as np 3 | import torch 4 | import shutil 5 | import torchvision.transforms as transforms 6 | 7 | 8 | class AvgrageMeter(object): 9 | 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | self.avg = 0 15 | self.sum = 0 16 | self.cnt = 0 17 | 18 | def update(self, val, n=1): 19 | self.sum += val * n 20 | self.cnt += n 21 | self.avg = self.sum / self.cnt 22 | 23 | 24 | def accuracy(output, target, topk=(1,)): 25 | maxk = max(topk) 26 | batch_size = target.size(0) 27 | 28 | _, pred = output.topk(maxk, 1, True, True) 29 | pred = pred.t() 30 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 31 | 32 | res = [] 33 | for k in topk: 34 | correct_k = correct[:k].view(-1).float().sum(0) 35 | res.append(correct_k.mul_(100.0/batch_size)) 36 | return res 37 | 38 | 39 | class Cutout(object): 40 | def __init__(self, length): 41 | self.length = length 42 | 43 | def __call__(self, img): 44 | h, w = img.size(1), img.size(2) 45 | mask = np.ones((h, w), np.float32) 46 | y = np.random.randint(h) 47 | x = np.random.randint(w) 48 | 49 | y1 = np.clip(y - self.length // 2, 0, h) 50 | y2 = np.clip(y + self.length // 2, 0, h) 51 | x1 = np.clip(x - self.length // 2, 0, w) 52 | x2 = np.clip(x + self.length // 2, 0, w) 53 | 54 | mask[y1: y2, x1: x2] = 0. 55 | mask = torch.from_numpy(mask) 56 | mask = mask.expand_as(img) 57 | img *= mask 58 | return img 59 | 60 | 61 | def _data_transforms_cifar10(cutout, cutout_length): 62 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124] 63 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768] 64 | 65 | train_transform = transforms.Compose([ 66 | transforms.RandomCrop(32, padding=4), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 70 | ]) 71 | if cutout: 72 | train_transform.transforms.append(Cutout(cutout_length)) 73 | 74 | valid_transform = transforms.Compose([ 75 | transforms.ToTensor(), 76 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD), 77 | ]) 78 | return train_transform, valid_transform 79 | 80 | 81 | def count_parameters_in_MB(model): 82 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6 83 | 84 | 85 | def save_checkpoint(state, is_best, save): 86 | filename = os.path.join(save, 'checkpoint.pth.tar') 87 | torch.save(state, filename) 88 | if is_best: 89 | best_filename = os.path.join(save, 'model_best.pth.tar') 90 | shutil.copyfile(filename, best_filename) 91 | 92 | 93 | def save(model, model_path): 94 | torch.save(model.state_dict(), model_path) 95 | 96 | 97 | def load(model, model_path, gpu_id): 98 | ml = 'cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu' 99 | model.load_state_dict(torch.load(model_path, map_location = ml), strict=False) 100 | 101 | 102 | 103 | def drop_path(x, drop_prob): 104 | if drop_prob > 0.: 105 | keep_prob = 1.-drop_prob 106 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob) 107 | x.div_(keep_prob) 108 | x.mul_(mask) 109 | return x 110 | 111 | 112 | def create_exp_dir(path, scripts_to_save=None): 113 | if not os.path.exists(path): 114 | os.mkdir(path) 115 | print('Experiment dir : {}'.format(path)) 116 | 117 | if scripts_to_save is not None: 118 | os.mkdir(os.path.join(path, 'scripts')) 119 | for script in scripts_to_save: 120 | dst_file = os.path.join(path, 'scripts', os.path.basename(script)) 121 | shutil.copyfile(script, dst_file) 122 | 123 | -------------------------------------------------------------------------------- /darts/cnn/visualize.py: 
-------------------------------------------------------------------------------- 1 | import sys 2 | import genotypes 3 | from graphviz import Digraph 4 | 5 | 6 | def plot(genotype, filename): 7 | g = Digraph( 8 | format='pdf', 9 | edge_attr=dict(fontsize='20', fontname="times"), 10 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), 11 | engine='dot') 12 | g.body.extend(['rankdir=LR']) 13 | 14 | g.node("c_{k-2}", fillcolor='darkseagreen2') 15 | g.node("c_{k-1}", fillcolor='darkseagreen2') 16 | assert len(genotype) % 2 == 0 17 | steps = len(genotype) // 2 18 | 19 | for i in range(steps): 20 | g.node(str(i), fillcolor='lightblue') 21 | 22 | for i in range(steps): 23 | for k in [2*i, 2*i + 1]: 24 | op, j = genotype[k] 25 | if j == 0: 26 | u = "c_{k-2}" 27 | elif j == 1: 28 | u = "c_{k-1}" 29 | else: 30 | u = str(j-2) 31 | v = str(i) 32 | g.edge(u, v, label=op, fillcolor="gray") 33 | 34 | g.node("c_{k}", fillcolor='palegoldenrod') 35 | for i in range(steps): 36 | g.edge(str(i), "c_{k}", fillcolor="gray") 37 | 38 | g.render(filename, view=True) 39 | 40 | 41 | if __name__ == '__main__': 42 | if len(sys.argv) != 2: 43 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0])) 44 | sys.exit(1) 45 | 46 | genotype_name = sys.argv[1] 47 | try: 48 | genotype = eval('genotypes.{}'.format(genotype_name)) 49 | except AttributeError: 50 | print("{} is not specified in genotypes.py".format(genotype_name)) 51 | sys.exit(1) 52 | 53 | plot(genotype.normal, "normal & reduction") 54 | 55 | -------------------------------------------------------------------------------- /docs/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/docs/arch.png -------------------------------------------------------------------------------- /gin/models/graphcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import sys 6 | sys.path.append("models/") 7 | from gin.models.mlp import MLP 8 | 9 | class GraphCNN(nn.Module): 10 | def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device): 11 | ''' 12 | num_layers: number of layers in the neural networks (INCLUDING the input layer) 13 | num_mlp_layers: number of layers in mlps (EXCLUDING the input layer) 14 | input_dim: dimensionality of input features 15 | hidden_dim: dimensionality of hidden units at ALL layers 16 | output_dim: number of classes for prediction 17 | final_dropout: dropout ratio on the final linear layer 18 | learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether. 
19 | neighbor_pooling_type: how to aggregate neighbors (sum, average, or max) 20 | graph_pooling_type: how to aggregate entire nodes in a graph (sum, average) 21 | device: which device to use 22 | ''' 23 | 24 | super(GraphCNN, self).__init__() 25 | 26 | self.final_dropout = final_dropout 27 | self.device = device 28 | self.num_layers = num_layers 29 | self.graph_pooling_type = graph_pooling_type 30 | self.neighbor_pooling_type = neighbor_pooling_type 31 | self.learn_eps = learn_eps 32 | self.eps = nn.Parameter(torch.zeros(self.num_layers-1)) 33 | 34 | ###List of MLPs 35 | self.mlps = torch.nn.ModuleList() 36 | 37 | ###List of batchnorms applied to the output of MLP (input of the final prediction linear layer) 38 | self.batch_norms = torch.nn.ModuleList() 39 | 40 | for layer in range(self.num_layers-1): 41 | if layer == 0: 42 | self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim)) 43 | else: 44 | self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim)) 45 | 46 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim)) 47 | 48 | #Linear function that maps the hidden representation at different layers into a prediction score 49 | self.linears_prediction = torch.nn.ModuleList() 50 | for layer in range(num_layers): 51 | if layer == 0: 52 | self.linears_prediction.append(nn.Linear(input_dim, output_dim)) 53 | else: 54 | self.linears_prediction.append(nn.Linear(hidden_dim, output_dim)) 55 | 56 | 57 | def __preprocess_neighbors_maxpool(self, batch_graph): 58 | ###create padded_neighbor_list in concatenated graph 59 | 60 | #compute the maximum number of neighbors within the graphs in the current minibatch 61 | max_deg = max([graph.max_neighbor for graph in batch_graph]) 62 | 63 | padded_neighbor_list = [] 64 | start_idx = [0] 65 | 66 | 67 | for i, graph in enumerate(batch_graph): 68 | start_idx.append(start_idx[i] + len(graph.g)) 69 | padded_neighbors = [] 70 | for j in range(len(graph.neighbors)): 71 | #add offset values to the neighbor indices 72 | pad = [n + start_idx[i] for n in graph.neighbors[j]] 73 | #padding, dummy data is assumed to be stored in -1 74 | pad.extend([-1]*(max_deg - len(pad))) 75 | 76 | #Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether. 77 | if not self.learn_eps: 78 | pad.append(j + start_idx[i]) 79 | 80 | padded_neighbors.append(pad) 81 | padded_neighbor_list.extend(padded_neighbors) 82 | 83 | return torch.LongTensor(padded_neighbor_list) 84 | 85 | 86 | def __preprocess_neighbors_sumavepool(self, batch_graph): 87 | ###create block diagonal sparse matrix 88 | 89 | edge_mat_list = [] 90 | start_idx = [0] 91 | for i, graph in enumerate(batch_graph): 92 | start_idx.append(start_idx[i] + len(graph.g)) 93 | edge_mat_list.append(graph.edge_mat + start_idx[i]) 94 | Adj_block_idx = torch.cat(edge_mat_list, 1) 95 | Adj_block_elem = torch.ones(Adj_block_idx.shape[1]) 96 | 97 | #Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
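# note (added, not in the original source): batching works by offsetting each graph's node indices with start_idx; e.g. two graphs with 3 and 2 nodes give start_idx = [0, 3, 5], so edge (0, 1) of the second graph lands at (3, 4) inside a single 5x5 block-diagonal sparse adjacency, and one spmm call then aggregates neighbors for the whole minibatch.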
98 | 99 | if not self.learn_eps: 100 | num_node = start_idx[-1] 101 | self_loop_edge = torch.LongTensor([range(num_node), range(num_node)]) 102 | elem = torch.ones(num_node) 103 | Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1) 104 | Adj_block_elem = torch.cat([Adj_block_elem, elem], 0) 105 | 106 | Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]])) 107 | 108 | return Adj_block.to(self.device) 109 | 110 | 111 | def __preprocess_graphpool(self, batch_graph): 112 | ###create sum or average pooling sparse matrix over entire nodes in each graph (num graphs x num nodes) 113 | 114 | start_idx = [0] 115 | 116 | #compute the padded neighbor list 117 | for i, graph in enumerate(batch_graph): 118 | start_idx.append(start_idx[i] + len(graph.g)) 119 | 120 | idx = [] 121 | elem = [] 122 | for i, graph in enumerate(batch_graph): 123 | ###average pooling 124 | if self.graph_pooling_type == "average": 125 | elem.extend([1./len(graph.g)]*len(graph.g)) 126 | 127 | else: 128 | ###sum pooling 129 | elem.extend([1]*len(graph.g)) 130 | 131 | idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)]) 132 | elem = torch.FloatTensor(elem) 133 | idx = torch.LongTensor(idx).transpose(0,1) 134 | graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]])) 135 | 136 | return graph_pool.to(self.device) 137 | 138 | def maxpool(self, h, padded_neighbor_list): 139 | ###Element-wise minimum will never affect max-pooling 140 | 141 | dummy = torch.min(h, dim = 0)[0] 142 | h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)]) 143 | pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0] 144 | return pooled_rep 145 | 146 | 147 | def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None): 148 | ###pooling neighboring nodes and center nodes separately by epsilon reweighting. 
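# note (added, not in the original source): this is the GIN update of Xu et al. (2019), h^(l+1) = MLP^(l)((1 + eps_l) * h^(l) + AGG(neighbors)), where eps_l is the learnable per-layer scalar self.eps[layer] and AGG is max pooling or a sparse matmul (sum, or degree-normalized average), depending on neighbor_pooling_type.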
149 | 150 | if self.neighbor_pooling_type == "max": 151 | ##If max pooling 152 | pooled = self.maxpool(h, padded_neighbor_list) 153 | else: 154 | #If sum or average pooling 155 | pooled = torch.spmm(Adj_block, h) 156 | if self.neighbor_pooling_type == "average": 157 | #If average pooling 158 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device)) 159 | pooled = pooled/degree 160 | 161 | #Reweights the center node representation when aggregating it with its neighbors 162 | pooled = pooled + (1 + self.eps[layer])*h 163 | pooled_rep = self.mlps[layer](pooled) 164 | h = self.batch_norms[layer](pooled_rep) 165 | 166 | #non-linearity 167 | h = F.relu(h) 168 | return h 169 | 170 | 171 | def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None): 172 | ###pooling neighboring nodes and center nodes altogether 173 | 174 | if self.neighbor_pooling_type == "max": 175 | ##If max pooling 176 | pooled = self.maxpool(h, padded_neighbor_list) 177 | else: 178 | #If sum or average pooling 179 | pooled = torch.spmm(Adj_block, h) 180 | if self.neighbor_pooling_type == "average": 181 | #If average pooling 182 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device)) 183 | pooled = pooled/degree 184 | 185 | #representation of neighboring and center nodes 186 | pooled_rep = self.mlps[layer](pooled) 187 | 188 | h = self.batch_norms[layer](pooled_rep) 189 | 190 | #non-linearity 191 | h = F.relu(h) 192 | return h 193 | 194 | 195 | def forward(self, batch_graph): 196 | X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device) 197 | graph_pool = self.__preprocess_graphpool(batch_graph) 198 | 199 | if self.neighbor_pooling_type == "max": 200 | padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph) 201 | else: 202 | Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph) 203 | 204 | #list of hidden representation at each layer (including input) 205 | hidden_rep = [X_concat] 206 | h = X_concat 207 | 208 | for layer in range(self.num_layers-1): 209 | if self.neighbor_pooling_type == "max" and self.learn_eps: 210 | h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list) 211 | elif self.neighbor_pooling_type != "max" and self.learn_eps: 212 | h = self.next_layer_eps(h, layer, Adj_block = Adj_block) 213 | elif self.neighbor_pooling_type == "max" and not self.learn_eps: 214 | h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list) 215 | elif self.neighbor_pooling_type != "max" and not self.learn_eps: 216 | h = self.next_layer(h, layer, Adj_block = Adj_block) 217 | 218 | hidden_rep.append(h) 219 | 220 | score_over_layer = 0 221 | 222 | #perform pooling over all nodes in each graph in every layer 223 | for layer, h in enumerate(hidden_rep): 224 | pooled_h = torch.spmm(graph_pool, h) 225 | score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training) 226 | 227 | return score_over_layer 228 | -------------------------------------------------------------------------------- /gin/models/mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ###MLP with linear output 6 | class MLP(nn.Module): 7 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim): 8 | ''' 9 | num_layers: number of layers in the neural networks (EXCLUDING the input layer).
If num_layers=1, this reduces to linear model. 10 | input_dim: dimensionality of input features 11 | hidden_dim: dimensionality of hidden units at ALL layers 12 | output_dim: number of classes for prediction 13 | device: which device to use 14 | ''' 15 | 16 | super(MLP, self).__init__() 17 | 18 | self.linear_or_not = True #default is linear model 19 | self.num_layers = num_layers 20 | 21 | if num_layers < 1: 22 | raise ValueError("number of layers should be positive!") 23 | elif num_layers == 1: 24 | #Linear model 25 | self.linear = nn.Linear(input_dim, output_dim) 26 | else: 27 | #Multi-layer model 28 | self.linear_or_not = False 29 | self.linears = torch.nn.ModuleList() 30 | self.batch_norms = torch.nn.ModuleList() 31 | 32 | self.linears.append(nn.Linear(input_dim, hidden_dim)) 33 | for layer in range(num_layers - 2): 34 | self.linears.append(nn.Linear(hidden_dim, hidden_dim)) 35 | self.linears.append(nn.Linear(hidden_dim, output_dim)) 36 | 37 | for layer in range(num_layers - 1): 38 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim))) 39 | 40 | def forward(self, x): 41 | if self.linear_or_not: 42 | #If linear model 43 | return self.linear(x) 44 | else: 45 | #If MLP 46 | h = x 47 | for layer in range(self.num_layers - 1): 48 | h = F.relu(self.batch_norms[layer](self.linears[layer](h))) 49 | return self.linears[self.num_layers - 1](h) -------------------------------------------------------------------------------- /models/configs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | configs = [{'GAE': # 0 6 | {'activation_ops':torch.sigmoid}, 7 | 'loss': 8 | {'loss_ops':F.mse_loss, 'loss_adj':F.mse_loss}, 9 | 'prep': 10 | {'method':3, 'lbd':0.5} 11 | }, 12 | {'GAE': # 1 13 | {'activation_ops':torch.softmax}, 14 | 'loss': 15 | {'loss_ops':nn.BCELoss(), 'loss_adj':nn.BCELoss()}, 16 | 'prep': 17 | {'method':3, 'lbd':0.5} 18 | }, 19 | {'GAE': # 2 20 | {'activation_ops': torch.softmax}, 21 | 'loss': 22 | {'loss_ops': F.mse_loss, 'loss_adj': nn.BCELoss()}, 23 | 'prep': 24 | {'method':3, 'lbd':0.5} 25 | }, 26 | {'GAE':# 3 27 | {'activation_ops':torch.sigmoid}, 28 | 'loss': 29 | {'loss_ops':F.mse_loss, 'loss_adj':F.mse_loss}, 30 | 'prep': 31 | {'method':4, 'lbd':1.0} 32 | }, 33 | {'GAE': # 4 34 | {'activation_adj': torch.sigmoid, 'activation_ops': torch.softmax, 'adj_hidden_dim': 128, 'ops_hidden_dim': 128}, 35 | 'loss': 36 | {'loss_ops': nn.BCELoss(), 'loss_adj': nn.BCELoss()}, 37 | 'prep': 38 | {'method': 4, 'lbd': 1.0} 39 | }, 40 | ] 41 | -------------------------------------------------------------------------------- /models/layers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import torch.nn as nn 5 | from torch.nn.parameter import Parameter 6 | 7 | class GraphConvolution(nn.Module): 8 | def __init__(self, in_features, out_features, dropout=0., bias=True): 9 | super(GraphConvolution, self).__init__() 10 | self.in_features = in_features 11 | self.out_features = out_features 12 | self.weight = Parameter(torch.Tensor(out_features, in_features)) 13 | if bias: 14 | self.bias = Parameter(torch.Tensor(out_features)) 15 | else: 16 | self.register_parameter('bias', None) 17 | self.reset_parameters() 18 | self.dropout = dropout 19 | 20 | def reset_parameters(self): 21 | stdv = 1. 
/ math.sqrt(self.weight.size(1)) 22 | torch.nn.init.kaiming_uniform_(self.weight) 23 | if self.bias is not None: 24 | self.bias.data.uniform_(-stdv, stdv) 25 | 26 | def forward(self, ops, adj): 27 | ops = F.dropout(ops, self.dropout, self.training) 28 | support = F.linear(ops, self.weight) 29 | output = F.relu(torch.matmul(adj, support)) 30 | 31 | if self.bias is not None: 32 | return output + self.bias 33 | else: 34 | return output 35 | 36 | def __repr__(self): 37 | return self.__class__.__name__ + '(' + str(self.in_features) + '->' + str(self.out_features) + ')' 38 | -------------------------------------------------------------------------------- /models/pretraining_darts.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | import argparse 8 | from nasbench.lib import graph_util 9 | from torch import optim 10 | from models.model import Model, VAEReconstructed_Loss 11 | from utils.utils import load_json, save_checkpoint_vae, preprocessing, one_hot_darts, to_ops_darts 12 | from utils.utils import get_val_acc_vae, is_valid_darts 13 | from models.configs import configs 14 | 15 | 16 | def process(geno): 17 | for i, item in enumerate(geno): 18 | geno[i] = tuple(geno[i]) 19 | return geno 20 | 21 | def _build_dataset(dataset): 22 | print(""" loading dataset """) 23 | X_adj = [] 24 | X_ops = [] 25 | for k, v in dataset.items(): 26 | adj = v[0] 27 | ops = v[1] 28 | X_adj.append(torch.Tensor(adj)) 29 | X_ops.append(torch.Tensor(one_hot_darts(ops))) 30 | 31 | X_adj = torch.stack(X_adj) 32 | X_ops = torch.stack(X_ops) 33 | 34 | X_adj_train, X_adj_val = X_adj[:int(X_adj.shape[0]*0.9)], X_adj[int(X_adj.shape[0]*0.9):] 35 | X_ops_train, X_ops_val = X_ops[:int(X_ops.shape[0]*0.9)], X_ops[int(X_ops.shape[0]*0.9):] 36 | indices = torch.randperm(X_adj_train.shape[0]) 37 | indices_val = torch.randperm(X_adj_val.shape[0]) 38 | X_adj = X_adj_train[indices] 39 | X_ops = X_ops_train[indices] 40 | X_adj_val = X_adj_val[indices_val] 41 | X_ops_val = X_ops_val[indices_val] 42 | 43 | return X_adj, X_ops, indices, X_adj_val, X_ops_val, indices_val 44 | 45 | 46 | def pretraining_gae(dataset, cfg): 47 | """ implementation of VGAE pretraining on DARTS Search Space """ 48 | X_adj, X_ops, indices, X_adj_val, X_ops_val, indices_val = _build_dataset(dataset) 49 | print('train set size: {}, validation set size: {}'.format(indices.shape[0], indices_val.shape[0])) 50 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.dim, 51 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda() 52 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08) 53 | epochs = args.epochs 54 | bs = args.bs 55 | loss_total = [] 56 | best_graph_acc = 0 57 | for epoch in range(0, epochs): 58 | chunks = X_adj.shape[0] // bs 59 | if X_adj.shape[0] % bs > 0: 60 | chunks += 1 61 | X_adj_split = torch.split(X_adj, bs, dim=0) 62 | X_ops_split = torch.split(X_ops, bs, dim=0) 63 | indices_split = torch.split(indices, bs, dim=0) 64 | loss_epoch = [] 65 | Z = [] 66 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)): 67 | optimizer.zero_grad() 68 | adj, ops = adj.cuda(), ops.cuda() 69 | # preprocessing 70 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep']) 71 | # forward 72 | ops_recon, adj_recon, mu, logvar = model(ops, adj) 73 | Z.append(mu) 74 | adj_recon, 
ops_recon = prep_reverse(adj_recon, ops_recon) 75 | adj, ops = prep_reverse(adj, ops) 76 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar) 77 | loss.backward() 78 | nn.utils.clip_grad_norm_(model.parameters(), 5) 79 | optimizer.step() 80 | loss_epoch.append(loss.item()) 81 | if i % 500 == 0: 82 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item())) 83 | Z = torch.cat(Z, dim=0) 84 | z_mean, z_std = Z.mean(0), Z.std(0) 85 | validity_counter = 0 86 | buckets = {} 87 | model.eval() 88 | for _ in range(args.latent_points): 89 | z = torch.randn(11, args.dim).cuda() 90 | z = z * z_std + z_mean 91 | op, ad = model.decoder(z.unsqueeze(0)) 92 | op = op.squeeze(0).cpu() 93 | ad = ad.squeeze(0).cpu() 94 | max_idx = torch.argmax(op, dim=-1) 95 | one_hot = torch.zeros_like(op) 96 | for i in range(one_hot.shape[0]): 97 | one_hot[i][max_idx[i]] = 1 98 | op_decode = to_ops_darts(max_idx) 99 | ad_decode = (ad>0.5).int().triu(1).numpy() 100 | ad_decode = np.ndarray.tolist(ad_decode) 101 | if is_valid_darts(ad_decode, op_decode): 102 | validity_counter += 1 103 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist()) 104 | if fingerprint not in buckets: 105 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist()) 106 | validity = validity_counter / args.latent_points 107 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity)) 108 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8))) 109 | 110 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val) 111 | print('validation set: acc_ops:{0:.2f}, mean_corr_adj:{1:.2f}, mean_fal_pos_adj:{2:.2f}, acc_adj:{3:.2f}'.format( 112 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val)) 113 | 114 | #print("reconstructed adj matrix:", adj_recon[1]) 115 | #print("original adj matrix:", adj[1]) 116 | #print("reconstructed ops matrix:", ops_recon[1]) 117 | #print("original ops matrix:", ops[1]) 118 | 119 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch))) 120 | loss_total.append(sum(loss_epoch) / len(loss_epoch)) 121 | print('loss for epochs: \n', loss_total) 122 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.dim, args.name, args.dropout, args.seed) 123 | 124 | 125 | print('loss for epochs: ', loss_total) 126 | 127 | 128 | if __name__ == '__main__': 129 | parser = argparse.ArgumentParser(description='Pretraining') 130 | parser.add_argument("--seed", type=int, default=3, help="random seed") 131 | parser.add_argument('--data', type=str, default='data/data_darts_counter600000.json', 132 | help='Data file (default: data/data_darts_counter600000.json)') 133 | parser.add_argument('--name', type=str, default='darts') 134 | parser.add_argument('--cfg', type=int, default=4, 135 | help='configuration (default: 4)') 136 | parser.add_argument('--bs', type=int, default=32, 137 | help='batch size (default: 32)') 138 | parser.add_argument('--epochs', type=int, default=10, 139 | help='training epochs (default: 10)') 140 | parser.add_argument('--dropout', type=float, default=0.3, 141 | help='decoder implicit regularization (default: 0.3)') 142 | parser.add_argument('--normalize', action='store_true', default=True, 143 | help='use input normalization') 144 | parser.add_argument('--input_dim', type=int, default=11) 145 | parser.add_argument('--hidden_dim', 
type=int, default=128) 146 | parser.add_argument('--dim', type=int, default=16, 147 | help='feature dimension (default: 16)') 148 | parser.add_argument('--hops', type=int, default=5) 149 | parser.add_argument('--mlps', type=int, default=2) 150 | parser.add_argument('--latent_points', type=int, default=10000, 151 | help='latent points for validity check (default: 10000)') 152 | args = parser.parse_args() 153 | cfg = configs[args.cfg] 154 | dataset = load_json(args.data) 155 | print('using {}'.format(args.data)) 156 | print('feat dim {}'.format(args.dim)) 157 | pretraining_gae(dataset, cfg) 158 | -------------------------------------------------------------------------------- /models/pretraining_darts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python models/pretraining_darts.py --dim 16 --cfg 4 --bs 32 --epochs 10 --hidden_dim 128 --data data/data_darts_counter600000.json --name darts 3 | -------------------------------------------------------------------------------- /models/pretraining_nasbench101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch import optim 8 | from models.model import Model, VAEReconstructed_Loss 9 | from utils.utils import load_json, save_checkpoint_vae, preprocessing 10 | from utils.utils import get_val_acc_vae 11 | from models.configs import configs 12 | import argparse 13 | from nasbench import api 14 | from nasbench.lib import graph_util 15 | 16 | def transform_operations(max_idx): 17 | transform_dict = {0:'input', 1:'conv1x1-bn-relu', 2:'conv3x3-bn-relu', 3:'maxpool3x3', 4:'output'} 18 | ops = [] 19 | for idx in max_idx: 20 | ops.append(transform_dict[idx.item()]) 21 | return ops 22 | 23 | def _build_dataset(dataset, ind_list): 24 | indices = np.random.permutation(ind_list) 25 | X_adj = [] 26 | X_ops = [] 27 | for ind in indices: 28 | X_adj.append(torch.Tensor(dataset[str(ind)]['module_adjacency'])) 29 | X_ops.append(torch.Tensor(dataset[str(ind)]['module_operations'])) 30 | X_adj = torch.stack(X_adj) 31 | X_ops = torch.stack(X_ops) 32 | return X_adj, X_ops, torch.Tensor(indices) 33 | 34 | 35 | def pretraining_model(dataset, cfg, args): 36 | nasbench = api.NASBench('data/nasbench_only108.tfrecord') 37 | train_ind_list, val_ind_list = range(int(len(dataset)*0.9)), range(int(len(dataset)*0.9), len(dataset)) 38 | X_adj_train, X_ops_train, indices_train = _build_dataset(dataset, train_ind_list) 39 | X_adj_val, X_ops_val, indices_val = _build_dataset(dataset, val_ind_list) 40 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.dim, 41 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda() 42 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08) 43 | epochs = args.epochs 44 | bs = args.bs 45 | loss_total = [] 46 | for epoch in range(0, epochs): 47 | chunks = len(train_ind_list) // bs 48 | if len(train_ind_list) % bs > 0: 49 | chunks += 1 50 | X_adj_split = torch.split(X_adj_train, bs, dim=0) 51 | X_ops_split = torch.split(X_ops_train, bs, dim=0) 52 | indices_split = torch.split(indices_train, bs, dim=0) 53 | loss_epoch = [] 54 | Z = [] 55 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)): 56 | optimizer.zero_grad() 57 | adj, ops = adj.cuda(), ops.cuda() 58 | # preprocessing
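# note (added, not in the original source): cfg['prep'] selects the input normalization (for the default --cfg 4 this is method 4 with lbd 1.0 in models/configs.py); the prep_reverse closure returned below undoes it so the reconstruction loss is computed in the original space.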
59 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep']) 60 | # forward 61 | ops_recon, adj_recon, mu, logvar = model(ops, adj.to(torch.long)) 62 | Z.append(mu) 63 | adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon) 64 | adj, ops = prep_reverse(adj, ops) 65 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar) 66 | loss.backward() 67 | nn.utils.clip_grad_norm_(model.parameters(), 5) 68 | optimizer.step() 69 | loss_epoch.append(loss.item()) 70 | if i % 1000 == 0: 71 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item())) 72 | Z = torch.cat(Z, dim=0) 73 | z_mean, z_std = Z.mean(0), Z.std(0) 74 | validity_counter = 0 75 | buckets = {} 76 | model.eval() 77 | for _ in range(args.latent_points): 78 | z = torch.randn(7, args.dim).cuda() 79 | z = z * z_std + z_mean 80 | op, ad = model.decoder(z.unsqueeze(0)) 81 | op = op.squeeze(0).cpu() 82 | ad = ad.squeeze(0).cpu() 83 | max_idx = torch.argmax(op, dim=-1) 84 | one_hot = torch.zeros_like(op) 85 | for i in range(one_hot.shape[0]): 86 | one_hot[i][max_idx[i]] = 1 87 | op_decode = transform_operations(max_idx) 88 | ad_decode = (ad>0.5).int().triu(1).numpy() 89 | ad_decode = np.ndarray.tolist(ad_decode) 90 | spec = api.ModelSpec(matrix=ad_decode, ops=op_decode) 91 | if nasbench.is_valid(spec): 92 | validity_counter += 1 93 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist()) 94 | if fingerprint not in buckets: 95 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist()) 96 | validity = validity_counter / args.latent_points 97 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity)) 98 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8))) 99 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val) 100 | print('validation set: acc_ops:{0:.4f}, mean_corr_adj:{1:.4f}, mean_fal_pos_adj:{2:.4f}, acc_adj:{3:.4f}'.format( 101 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val)) 102 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch))) 103 | loss_total.append(sum(loss_epoch) / len(loss_epoch)) 104 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.dim, args.name, args.dropout, args.seed) 105 | print('loss for epochs: \n', loss_total) 106 | 107 | 108 | 109 | if __name__ == '__main__': 110 | parser = argparse.ArgumentParser(description='Pretraining') 111 | parser.add_argument("--seed", type=int, default=1, help="random seed") 112 | parser.add_argument('--data', type=str, default='data/data.json', 113 | help='Data file (default: data/data.json)') 114 | parser.add_argument('--name', type=str, default='nasbench-101', 115 | help='nasbench-101/nasbench-201/darts') 116 | parser.add_argument('--cfg', type=int, default=4, 117 | help='configuration (default: 4)') 118 | parser.add_argument('--bs', type=int, default=32, 119 | help='batch size (default: 32)') 120 | parser.add_argument('--epochs', type=int, default=8, 121 | help='training epochs (default: 8)') 122 | parser.add_argument('--dropout', type=float, default=0.3, 123 | help='decoder implicit regularization (default: 0.3)') 124 | parser.add_argument('--normalize', action='store_true', default=True, 125 | help='use input normalization') 126 | parser.add_argument('--input_dim', type=int, default=5) 127 | parser.add_argument('--hidden_dim', type=int, default=128) 128 | 
parser.add_argument('--dim', type=int, default=16, 129 | help='feature dimension (default: 16)') 130 | parser.add_argument('--hops', type=int, default=5) 131 | parser.add_argument('--mlps', type=int, default=2) 132 | parser.add_argument('--latent_points', type=int, default=10000, 133 | help='latent points for validity check (default: 10000)') 134 | args = parser.parse_args() 135 | np.random.seed(args.seed) 136 | torch.manual_seed(args.seed) 137 | torch.cuda.manual_seed_all(args.seed) 138 | cfg = configs[args.cfg] 139 | dataset = load_json(args.data) 140 | print('using {}'.format(args.data)) 141 | print('feat dim {}'.format(args.dim)) 142 | pretraining_model(dataset, cfg, args) 143 | -------------------------------------------------------------------------------- /models/pretraining_nasbench101.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python models/pretraining_nasbench101.py --dim 16 --cfg 4 --bs 32 --epochs 8 --seed 1 --name nasbench101 3 | -------------------------------------------------------------------------------- /models/pretraining_nasbench201.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | from torch import optim 8 | from models.model import Model, VAEReconstructed_Loss 9 | from utils.utils import load_json, save_checkpoint_vae, preprocessing 10 | from utils.utils import get_val_acc_vae, to_ops_nasbench201, is_valid_nasbench201 11 | from models.configs import configs 12 | from nasbench.lib import graph_util 13 | import argparse 14 | 15 | 16 | def _build_dataset(dataset, ind_list): 17 | indices = np.random.permutation(ind_list) 18 | X_adj = [] 19 | X_ops = [] 20 | for ind in indices: 21 | X_adj.append(torch.Tensor(dataset[str(ind)]['module_adjacency'])) 22 | X_ops.append(torch.Tensor(dataset[str(ind)]['module_operations'])) 23 | X_adj = torch.stack(X_adj) 24 | X_ops = torch.stack(X_ops) 25 | return X_adj, X_ops, torch.Tensor(indices) 26 | 27 | 28 | def pretraining_gae(dataset, cfg): 29 | """ 30 | implementation of model pretraining.
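Each epoch minimizes the VAE reconstruction + KL loss over (adjacency, ops) mini-batches, then samples args.latent_points latent vectors to report the ratio of valid and unique architectures decoded from the prior, and finally evaluates reconstruction accuracy on the held-out 10% split. (Sentence added to the docstring for clarity.)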
31 | :param dataset: nas-bench-201 dataset loaded from json 32 | :param cfg: configuration for the GAE model, loss, and preprocessing 33 | :return: None; a checkpoint is saved each epoch via save_checkpoint_vae 34 | """ 35 | train_ind_list, val_ind_list = range(int(len(dataset)*0.9)), range(int(len(dataset)*0.9), len(dataset)) 36 | X_adj_train, X_ops_train, indices_train = _build_dataset(dataset, train_ind_list) 37 | X_adj_val, X_ops_val, indices_val = _build_dataset(dataset, val_ind_list) 38 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.latent_dim, 39 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda() 40 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08) 41 | epochs = args.epochs 42 | bs = args.bs 43 | loss_total = [] 44 | for epoch in range(0, epochs): 45 | chunks = len(X_adj_train) // bs 46 | if len(X_adj_train) % bs > 0: 47 | chunks += 1 48 | X_adj_split = torch.split(X_adj_train, bs, dim=0) 49 | X_ops_split = torch.split(X_ops_train, bs, dim=0) 50 | indices_split = torch.split(indices_train, bs, dim=0) 51 | loss_epoch = [] 52 | Z = [] 53 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)): 54 | optimizer.zero_grad() 55 | adj, ops = adj.cuda(), ops.cuda() 56 | # preprocessing 57 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep']) 58 | # forward 59 | ops_recon, adj_recon, mu, logvar = model(ops, adj) 60 | Z.append(mu) 61 | adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon) 62 | adj, ops = prep_reverse(adj, ops) 63 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar) 64 | loss.backward() 65 | nn.utils.clip_grad_norm_(model.parameters(), 5) 66 | optimizer.step() 67 | loss_epoch.append(loss.item()) 68 | if i % 100 == 0: 69 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item())) 70 | Z = torch.cat(Z, dim=0) 71 | z_mean, z_std = Z.mean(0), Z.std(0) 72 | validity_counter = 0 73 | buckets = {} 74 | model.eval() 75 | for _ in range(args.latent_points): 76 | z = torch.randn(8, args.latent_dim).cuda() 77 | z = z * z_std + z_mean 78 | op, ad = model.decoder(z.unsqueeze(0)) 79 | op = op.squeeze(0).cpu() 80 | ad = ad.squeeze(0).cpu() 81 | max_idx = torch.argmax(op, dim=-1) 82 | one_hot = torch.zeros_like(op) 83 | for i in range(one_hot.shape[0]): 84 | one_hot[i][max_idx[i]] = 1 85 | op_decode = to_ops_nasbench201(max_idx) 86 | ad_decode = (ad>0.5).int().triu(1).numpy() 87 | ad_decode = np.ndarray.tolist(ad_decode) 88 | if is_valid_nasbench201(ad_decode, op_decode): 89 | validity_counter += 1 90 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist()) 91 | if fingerprint not in buckets: 92 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist()) 93 | validity = validity_counter / args.latent_points 94 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity)) 95 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8))) 96 | 97 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val) 98 | print('validation set: acc_ops:{0:.2f}, mean_corr_adj:{1:.2f}, mean_fal_pos_adj:{2:.2f}, acc_adj:{3:.2f}'.format( 99 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val)) 100 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch))) 101 | print("reconstructed adj matrix:", adj_recon[1]) 102 | print("original adj matrix:", adj[1])
matrix:", adj[1]) 103 | print("reconstructed ops matrix:", ops_recon[1]) 104 | print("original ops matrix:", ops[1]) 105 | loss_total.append(sum(loss_epoch) / len(loss_epoch)) 106 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.latent_dim, args.name, args.dropout, args.seed) 107 | 108 | 109 | print('loss for epochs: ', loss_total) 110 | 111 | 112 | if __name__ == '__main__': 113 | parser = argparse.ArgumentParser(description='Pretraining') 114 | parser.add_argument("--seed", type=int, default=3, help="random seed") 115 | parser.add_argument('--data', type=str, default='data/cifar10_valid_converged.json') 116 | parser.add_argument('--cfg', type=int, default=4) 117 | parser.add_argument('--bs', type=int, default=32) 118 | parser.add_argument('--epochs', type=int, default=10) 119 | parser.add_argument('--input_dim', type=int, default=7) 120 | parser.add_argument('--hidden_dim', type=int, default=128) 121 | parser.add_argument('--latent_dim', type=int, default=16) 122 | parser.add_argument('--dropout', type=float, default=0.3) 123 | parser.add_argument('--hops', type=int, default=5) 124 | parser.add_argument('--mlps', type=int, default=2) 125 | parser.add_argument('--latent_points', type=int, default=10000) 126 | parser.add_argument('--name', type=str, default='nasbench201', help='the prefix for the saved check point') 127 | args = parser.parse_args() 128 | 129 | #reproducbility is good 130 | np.random.seed(args.seed) 131 | torch.manual_seed(args.seed) 132 | torch.cuda.manual_seed_all(args.seed) 133 | 134 | cfg = configs[args.cfg] 135 | dataset = load_json(args.data) 136 | print('using {}'.format(args.data)) 137 | print('feat dim {}'.format(args.latent_dim)) 138 | 139 | pretraining_gae(dataset, cfg) 140 | -------------------------------------------------------------------------------- /models/pretraining_nasbench201.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python models/pretraining_nasbench201.py 3 | -------------------------------------------------------------------------------- /plot_scripts/draw_darts.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.insert(0, os.getcwd()) 4 | import darts.cnn.genotypes 5 | from graphviz import Digraph 6 | 7 | 8 | def plot(genotype, filename): 9 | g = Digraph( 10 | format='png', 11 | edge_attr=dict(fontsize='20', fontname="times"), 12 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"), 13 | engine='dot') 14 | g.body.extend(['rankdir=UD']) 15 | 16 | g.node("c_{k-2}", fillcolor='darkseagreen2') 17 | g.node("c_{k-1}", fillcolor='darkseagreen2') 18 | assert len(genotype) % 2 == 0 19 | steps = len(genotype) // 2 20 | 21 | for i in range(steps): 22 | g.node(str(i), fillcolor='lightblue') 23 | 24 | for i in range(steps): 25 | for k in [2*i, 2*i + 1]: 26 | j, op = genotype[k] 27 | j = int(j) 28 | if j == 0: 29 | u = "c_{k-2}" 30 | elif j == 1: 31 | u = "c_{k-1}" 32 | else: 33 | u = str(j-2) 34 | v = str(i) 35 | g.edge(u, v, label=op, fillcolor="gray") 36 | 37 | g.node("c_{k}", fillcolor='palegoldenrod') 38 | for i in range(steps): 39 | g.edge(str(i), "c_{k}", fillcolor="gray") 40 | 41 | g.render(filename, view=False) 42 | 43 | 44 | if __name__ == '__main__': 45 | if len(sys.argv) != 2: 46 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0])) 47 | sys.exit(1) 48 | 49 | genotype_name = 
50 | try: 51 | genotype = eval('darts.cnn.genotypes.{}'.format(genotype_name)) 52 | except AttributeError: 53 | print("{} is not specified in genotypes.py".format(genotype_name)) 54 | sys.exit(1) 55 | 56 | plot(genotype.normal, "normal") 57 | plot(genotype.reduce, "reduction") 58 | 59 | -------------------------------------------------------------------------------- /plot_scripts/drawfig4.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python plot_scripts/visdensity.py \ 4 | --emb_path pretrained/dim-16/arch2vec-model-nasbench101.pt \ 5 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_nasbench101.npy \ 6 | --output_path density/nas101 -------------------------------------------------------------------------------- /plot_scripts/drawfig5-darts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python plot_scripts/visgraph.py \ 4 | --data_type darts \ 5 | --data_path data/data_darts_counter600000.json \ 6 | --emb_path pretrained/dim-16/arch2vec-darts.pt \ 7 | --output_path graphvisualization 8 | -------------------------------------------------------------------------------- /plot_scripts/drawfig5-nas101.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python plot_scripts/visgraph.py \ 4 | --data_type nasbench101 \ 5 | --data_path data/data.json \ 6 | --emb_path pretrained/dim-16/arch2vec-model-nasbench101.pt \ 7 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_nasbench101.npy \ 8 | --output_path graphvisualization 9 | -------------------------------------------------------------------------------- /plot_scripts/drawfig5-nas201.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python plot_scripts/visgraph.py \ 4 | --data_type nasbench201 \ 5 | --data_path data/cifar10_valid_converged.json \ 6 | --emb_path pretrained/dim-16/cifar10_valid_converged-arch2vec.pt \ 7 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_cifar10_nasbench201.npy \ 8 | --output_path graphvisualization -------------------------------------------------------------------------------- /plot_scripts/nas201.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/plot_scripts/nas201.jpg -------------------------------------------------------------------------------- /plot_scripts/pearson_plot_fig2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | from sklearn.metrics import mean_squared_error 5 | from math import sqrt 6 | from scipy import stats 7 | from copy import copy 8 | from mpl_toolkits.axes_grid1 import make_axes_locatable 9 | import matplotlib as mpl 10 | import os 11 | 12 | result_path = 'saved_logs/predict_accuracy' 13 | seed = [1, 10] 14 | acc_th = 0.8 15 | 16 | for s in seed: 17 | ## unsupervised 18 | un_pred_acc = np.load(os.path.join(result_path, 'dngo_unsupervised', 'pred_acc_seed{}.npy'.format(s))) 19 | un_test_acc = np.load(os.path.join(result_path, 'dngo_unsupervised', 'test_acc_seed{}.npy'.format(s))) 20 | idx0 = np.logical_and(un_test_acc > acc_th, un_pred_acc > acc_th) 21 | 22 | ## supervised
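# note (added, not in the original source): idx0/idx1 keep only architectures whose test and predicted accuracies both exceed acc_th, so each 2-D histogram below zooms in on the dense upper-right region of the scatter.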
23 | sup_pred_acc = np.load(os.path.join(result_path, 'dngo_supervised', 'pred_acc_seed{}.npy'.format(s))) 24 | sup_test_acc = np.load(os.path.join(result_path, 'dngo_supervised', 'test_acc_seed{}.npy'.format(s))) 25 | idx1 = np.logical_and(sup_test_acc > acc_th, sup_pred_acc > acc_th) 26 | 27 | bins = np.linspace(0.8, 1, 301) 28 | 29 | fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(6, 3), sharey=True) 30 | 31 | ax0.plot([0.8, 1], [0.8, 1], 'yellowgreen', linewidth=2) 32 | ax1.plot([0.8, 1], [0.8, 1], 'yellowgreen', linewidth=2) 33 | 34 | H, xedges, yedges = np.histogram2d(un_test_acc[idx0], un_pred_acc[idx0], bins=bins) 35 | H = H.T 36 | Hm = np.ma.masked_where(H < 1, H) 37 | X, Y = np.meshgrid(xedges, yedges) 38 | palette = copy(plt.cm.viridis) 39 | palette.set_bad('w', 1.0) 40 | ax0.pcolormesh(X, Y, Hm, cmap=palette) 41 | 42 | H, xedges, yedges = np.histogram2d(sup_test_acc[idx1], sup_pred_acc[idx1], bins=bins) 43 | H = H.T 44 | Hm = np.ma.masked_where(H < 1, H) 45 | X, Y = np.meshgrid(xedges, yedges) 46 | palette = copy(plt.cm.viridis) 47 | palette.set_bad('w', 1.0) 48 | ax1.pcolormesh(X, Y, Hm, cmap=palette) 49 | 50 | ax0.set_xlabel('Test Accuracy') 51 | ax0.set_ylabel('Predicted Accuracy') 52 | ax1.set_xlabel('Test Accuracy') 53 | 54 | ax0.set_xlim(0.8, 0.95) 55 | ax0.set_ylim(0.8, 0.95) 56 | ax1.set_xlim(0.8, 0.95) 57 | ax1.set_ylim(0.8, 0.95) 58 | 59 | ax0.set_yticks(ticks=[0.8, 0.85, 0.90, 0.95]) 60 | ax0.set_xticks(ticks=[0.8, 0.85, 0.9]) 61 | ax1.set_xticks(ticks=[0.8, 0.85, 0.9, 0.95]) 62 | 63 | ax0.set_aspect('equal', 'box') 64 | ax1.set_aspect('equal', 'box') 65 | 66 | plt.subplots_adjust(wspace=0.05, top=0.9, bottom=0.1) 67 | plt.savefig('compare_seed{}.png'.format(s), bbox_inches='tight') 68 | plt.show() 69 | plt.close(fig=fig) 70 | 71 | 72 | -------------------------------------------------------------------------------- /plot_scripts/plot_cdf.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import matplotlib as mpl 4 | import matplotlib.pyplot as plt 5 | from matplotlib.lines import Line2D 6 | 7 | def fix_hist_step_vertical_line_at_end(ax): 8 | axpolygons = [poly for poly in ax.get_children() if isinstance(poly, mpl.patches.Polygon)] 9 | for poly in axpolygons: 10 | poly.set_xy(poly.get_xy()[:-1]) 11 | 12 | def plot_cdf_comparison(cmap=plt.get_cmap("tab10")): 13 | fig = plt.figure() 14 | ax = fig.add_subplot(1, 1, 1) 15 | final_test_regret_rd_nas101 = [] 16 | final_test_regret_re_nas101 = [] 17 | final_test_regret_rl_nas101 = [] 18 | final_test_regret_bohb_nas101 = [] 19 | final_test_regret_rl_supervised = [] 20 | final_test_regret_bo_supervised = [] 21 | final_test_regret_rl_arch2vec = [] 22 | final_test_regret_bo_arch2vec = [] 23 | 24 | for i in range(1, 501): 25 | f_name = 'saved_logs/discrete/random_search/run_{}_nas_cifar10a_{}.json'.format(i, 20000) 26 | if not os.path.exists(f_name): 27 | continue 28 | f = open(f_name) 29 | data = json.load(f) 30 | for ind, t in enumerate(data['runtime']): 31 | if t > 1e6: 32 | final_test_regret_rd_nas101.append(data['regret_test'][ind]) 33 | break 34 | f.close() 35 | 36 | for i in range(1, 501): 37 | f_name = 'saved_logs/discrete/regularized_evolution/run_{}_nas_cifar10a_{}.json'.format(i, 3500) 38 | if not os.path.exists(f_name): 39 | continue 40 | f = open(f_name) 41 | data = json.load(f) 42 | for ind, t in enumerate(data['runtime']): 43 | if t > 1e6:
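# note (added, not in the original source): every method below is scored the same way -- the first test regret recorded after the simulated budget of 1e6 seconds of estimated wall-clock time.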
44 | final_test_regret_re_nas101.append(data['regret_test'][ind]) 45 | break 46 | f.close() 47 | 48 | for i in range(1, 501): 49 | f_name = 'saved_logs/discrete/rl/run_{}_nas_cifar10a_{}.json'.format(i, 3670) 50 | if not os.path.exists(f_name): 51 | continue 52 | f = open(f_name) 53 | data = json.load(f) 54 | for ind, t in enumerate(data['runtime']): 55 | if t > 1e6: 56 | final_test_regret_rl_nas101.append(data['regret_test'][ind]) 57 | break 58 | f.close() 59 | 60 | for i in range(1, 501): 61 | f_name = 'saved_logs/discrete/bohb/run_{}_nas_cifar10a_{}.json'.format(i, 1000) 62 | if not os.path.exists(f_name): 63 | continue 64 | f = open(f_name) 65 | data = json.load(f) 66 | for ind, t in enumerate(data['runtime']): 67 | if t > 1e6: 68 | final_test_regret_bohb_nas101.append(data['regret_test'][ind]) 69 | break 70 | f.close() 71 | 72 | for i in range(1, 501): 73 | f_name = 'saved_logs/rl/dim16/nasbench101_supervised_search_logs/run_{}_supervised_rl.json'.format(i) 74 | if not os.path.exists(f_name): 75 | continue 76 | f = open(f_name) 77 | data = json.load(f) 78 | for ind, t in enumerate(data['runtime']): 79 | if t > 1e6: 80 | final_test_regret_rl_supervised.append(data['regret_test'][ind]) 81 | break 82 | f.close() 83 | 84 | for i in range(1, 501): 85 | f_name = 'saved_logs/bo/dim16/nasbench101_supervised_search_logs/run_{}_supervised_bo.json'.format(i) 86 | if not os.path.exists(f_name): 87 | continue 88 | f = open(f_name) 89 | data = json.load(f) 90 | for ind, t in enumerate(data['runtime']): 91 | if t > 1e6: 92 | final_test_regret_bo_supervised.append(data['regret_test'][ind]) 93 | break 94 | f.close() 95 | 96 | for i in range(1, 501): 97 | f_name = 'saved_logs/rl/dim16/nasbench101_search_logs/run_{}_arch2vec-model-vae-nasbench-101.json'.format(i) 98 | if not os.path.exists(f_name): 99 | continue 100 | f = open(f_name) 101 | data = json.load(f) 102 | for ind, t in enumerate(data['runtime']): 103 | if t > 1e6: 104 | final_test_regret_rl_arch2vec.append(data['regret_test'][ind]) 105 | break 106 | f.close() 107 | 108 | for i in range(1, 501): 109 | f_name = 'saved_logs/bo/dim16/nasbench101_search_logs/run_{}_arch2vec-model-vae-nasbench-101.json'.format(i) 110 | if not os.path.exists(f_name): 111 | continue 112 | f = open(f_name) 113 | data = json.load(f) 114 | for ind, t in enumerate(data['runtime']): 115 | if t > 1e6: 116 | final_test_regret_bo_arch2vec.append(data['regret_test'][ind]) 117 | break 118 | f.close() 119 | 120 | 121 | plt_name_rd_nas101 = '{}: {}'.format('Discrete', 'Random Search') 122 | plt_name_re_nas101 = '{}: {}'.format('Discrete', 'Regularized Evolution') 123 | plt_name_rl_nas101 = '{}: {}'.format('Discrete', 'REINFORCE') 124 | plt_name_bohb_nas101 = '{}: {}'.format('Discrete', 'BOHB') 125 | plt_name_rl_supervised = '{}: {}'.format('Supervised', 'REINFORCE') 126 | plt_name_bo_supervised = '{}: {}'.format('Supervised', 'Bayesian Optimization') 127 | plt_name_rl_arch2vec = '{}: {}'.format('arch2vec', 'REINFORCE') 128 | plt_name_bo_arch2vec = '{}: {}'.format('arch2vec', 'Bayesian Optimization') 129 | 
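# note (added, not in the original source): with density=True and cumulative=True, each call below draws an empirical CDF of final test regrets over the runs collected above.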
-------------------------------------------------------------------------------- /plot_scripts/plot_dngo_search_arch2vec.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import matplotlib.pyplot as plt 4 | from collections import defaultdict 5 | 6 | def plot_over_time_dngo_search_arch2vec(name, cmap=plt.get_cmap("tab10")): 7 | length = [] 8 | for i in range(1, 501): 9 | f_name = 'saved_logs/bo/dim16/run_{}_{}-model-nasbench101.json'.format(i, name) 10 | if not os.path.exists(f_name): 11 | continue 12 | f = open(f_name) 13 | data = json.load(f) 14 | length.append(len(data['runtime'])) 15 | f.close() 16 | 17 | data_avg = defaultdict(list) 18 | test_regret_avg = defaultdict(list) 19 | valid_regret_avg = defaultdict(list) 20 | 21 | fig = plt.figure() 22 | ax = fig.add_subplot(1, 2, 1) 23 | ax_test = fig.add_subplot(1, 2, 2) 24 | for i in range(1, 501): 25 | f_name = 'saved_logs/bo/dim16/run_{}_{}-model-nasbench101.json'.format(i, name) 26 | if not os.path.exists(f_name): 27 | continue 28 | f = open(f_name) 29 | data = json.load(f) 30 | for idx in range(min(length)): # truncate every run to the shortest log so curves can be averaged point-wise 31 | data_avg[idx].append(data['runtime'][idx]) 32 | valid_regret_avg[idx].append(data['regret_validation'][idx]) 33 | test_regret_avg[idx].append(data['regret_test'][idx]) 34 | f.close() 35 | 36 | time_plot = [] 37 | valid_plot = [] 38 | test_plot = [] 39 | for idx in range(min(length)): 40 | if sum(data_avg[idx]) / len(data_avg[idx]) > 1e6: # drop points whose average runtime exceeds the search budget 41 | continue 42 | time_plot.append(sum(data_avg[idx]) / len(data_avg[idx])) 43 | valid_plot.append(sum(valid_regret_avg[idx]) / len(valid_regret_avg[idx])) 44 | test_plot.append(sum(test_regret_avg[idx]) / len(test_regret_avg[idx])) 45 | 46 | ax.plot(time_plot, valid_plot, color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'BO')) 47 | ax_test.plot(time_plot, test_plot, '--', color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'BO')) 48 | ax.set_xscale('log') 49 | ax.set_yscale('log') 50 |
ax.set_xlabel('estimated wall-clock time [s]') 51 | ax.set_ylabel('validation regret') 52 | ax.legend() 53 | ax_test.set_xscale('log') 54 | ax_test.set_yscale('log') 55 | ax_test.set_xlabel('estimated wall-clock time [s]') 56 | ax_test.set_ylabel('test regret') 57 | ax_test.legend() 58 | 59 | save_data = {'time_plot': time_plot, 'valid_plot': valid_plot, 'test_plot': test_plot} 60 | with open('results/{}-{}-nasbench-101.json'.format('BO', name), 'w') as f_w: 61 | json.dump(save_data, f_w) 62 | 63 | plt.show() 64 | 65 | if __name__ == '__main__': 66 | name = 'arch2vec' 67 | plot_over_time_dngo_search_arch2vec(name) 68 | 69 | -------------------------------------------------------------------------------- /plot_scripts/plot_nasbench101_comparison.py: -------------------------------------------------------------------------------- 1 | import json 2 | import matplotlib.pyplot as plt 3 | 4 | def plot_over_time_comparison(cmap=plt.get_cmap("tab10")): 5 | fig = plt.figure() 6 | ax_test = fig.add_subplot(1, 1, 1) 7 | 8 | f_random_search = open('results/Random-Search-Encoding-A.json') 9 | f_regularized_evolution = open('results/Regularized-Evolution-Encoding-A.json') 10 | f_reinforce_search = open('results/Reinforce-Search-Encoding-A.json') 11 | f_bohb_search = open('results/BOHB-Search-Encoding-A.json') 12 | f_reinforce_search_arch2vec = open('results/RL-arch2vec-model-nasbench-101.json') 13 | f_bo_search_arch2vec = open('results/BO-arch2vec-model-nasbench-101.json') 14 | f_reinforce_search_supervised = open('results/RL-supervised-nasbench-101.json') 15 | f_bo_search_supervised = open('results/BO-supervised-nasbench-101.json') 16 | result_random_search = json.load(f_random_search) 17 | result_regularized_evolution = json.load(f_regularized_evolution) 18 | result_reinforce_search = json.load(f_reinforce_search) 19 | result_bohb_search = json.load(f_bohb_search) 20 | results_reinforce_search_arch2vec = json.load(f_reinforce_search_arch2vec) 21 | results_bo_search_arch2vec = json.load(f_bo_search_arch2vec) 22 | results_reinforce_search_supervised = json.load(f_reinforce_search_supervised) 23 | results_bo_search_supervised = json.load(f_bo_search_supervised) 24 | f_random_search.close() 25 | f_regularized_evolution.close() 26 | f_reinforce_search.close() 27 | f_bohb_search.close() 28 | f_reinforce_search_arch2vec.close() 29 | f_bo_search_arch2vec.close() 30 | f_reinforce_search_supervised.close() 31 | f_bo_search_supervised.close() 32 | 33 | ax_test.plot(result_random_search['time_plot'], result_random_search['test_plot'], linestyle='-.', marker='^', markevery=1e3, color=cmap(1), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'Random Search')) 34 | ax_test.plot(result_regularized_evolution['time_plot'], result_regularized_evolution['test_plot'], linestyle='-.', marker='s', markevery=1e3, color=cmap(4), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'Regularized Evolution')) 35 | ax_test.plot(result_reinforce_search['time_plot'], result_reinforce_search['test_plot'], linestyle='-.', marker='.', markevery=1e3, color=cmap(6), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'REINFORCE')) 36 | ax_test.plot(result_bohb_search['time_plot'], result_bohb_search['test_plot'] , linestyle='-.', marker='*', markevery=1e3, color=cmap(5), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'BOHB')) 37 | ax_test.plot(results_reinforce_search_supervised['time_plot'], results_reinforce_search_supervised['test_plot'], linestyle='--', marker='.', markevery=1e3, color=cmap(7), lw=2, markersize=4, 
label='{}: {}'.format('Supervised', 'REINFORCE')) 38 | ax_test.plot(results_bo_search_supervised['time_plot'], results_bo_search_supervised['test_plot'], linestyle='--', marker='v', markevery=1e3, color=cmap(9), lw=2, markersize=4, label='{}: {}'.format('Supervised', 'Bayesian Optimization')) 39 | ax_test.plot(results_reinforce_search_arch2vec['time_plot'], results_reinforce_search_arch2vec['test_plot'], linestyle='-.', marker='.', markevery=1e3, color=cmap(0), lw=2, markersize=4, label='{}: {}'.format('arch2vec', 'REINFORCE')) 40 | ax_test.plot(results_bo_search_arch2vec['time_plot'], results_bo_search_arch2vec['test_plot'], linestyle='-.', marker='v', markevery=1e3, color=cmap(3), lw=2, markersize=4, label='{}: {}'.format('arch2vec', 'Bayesian Optimization')) 41 | 42 | ax_test.set_xscale('log') 43 | ax_test.set_yscale('log') 44 | ax_test.set_xlabel('estimated wall-clock time [s]', fontsize=12) 45 | ax_test.set_ylabel('test regret', fontsize=12) 46 | ax_test.legend(prop={"size":10}) 47 | 48 | plt.show() 49 | 50 | if __name__ == '__main__': 51 | plot_over_time_comparison() 52 | -------------------------------------------------------------------------------- /plot_scripts/plot_reinforce_search_arch2vec.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import matplotlib.pyplot as plt 4 | from collections import defaultdict 5 | 6 | def plot_over_time_reinforce_search_arch2vec(name, cmap=plt.get_cmap("tab10")): 7 | length = [] 8 | for i in range(1, 501): 9 | f_name = 'saved_logs/rl/dim16/run_{}_{}-model-nasbench101.json'.format(i, name) 10 | if not os.path.exists(f_name): 11 | continue 12 | f = open(f_name) 13 | data = json.load(f) 14 | length.append(len(data['runtime'])) 15 | f.close() 16 | 17 | data_avg = defaultdict(list) 18 | test_regret_avg = defaultdict(list) 19 | valid_regret_avg = defaultdict(list) 20 | 21 | fig = plt.figure() 22 | ax = fig.add_subplot(1, 2, 1) 23 | ax_test = fig.add_subplot(1, 2, 2) 24 | for i in range(1, 501): 25 | f_name = 'saved_logs/rl/dim16/run_{}_{}-model-nasbench101.json'.format(i, name) 26 | if not os.path.exists(f_name): 27 | continue 28 | f = open(f_name) 29 | data = json.load(f) 30 | for idx in range(min(length)): 31 | data_avg[idx].append(data['runtime'][idx]) 32 | valid_regret_avg[idx].append(data['regret_validation'][idx]) 33 | test_regret_avg[idx].append(data['regret_test'][idx]) 34 | f.close() 35 | 36 | time_plot = [] 37 | valid_plot = [] 38 | test_plot = [] 39 | for idx in range(min(length)): 40 | if sum(data_avg[idx]) / len(data_avg[idx]) > 1e6: 41 | continue 42 | time_plot.append(sum(data_avg[idx]) / len(data_avg[idx])) 43 | valid_plot.append(sum(valid_regret_avg[idx]) / len(valid_regret_avg[idx])) 44 | test_plot.append(sum(test_regret_avg[idx]) / len(test_regret_avg[idx])) 45 | 46 | ax.plot(time_plot, valid_plot, color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'RL')) 47 | ax_test.plot(time_plot, test_plot, '--', color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'RL')) 48 | ax.set_xscale('log') 49 | ax.set_yscale('log') 50 | ax.set_xlabel('estimated wall-clock time [s]') 51 | ax.set_ylabel('validation regret') 52 | ax.legend() 53 | ax_test.set_xscale('log') 54 | ax_test.set_yscale('log') 55 | ax_test.set_xlabel('estimated wall-clock time [s]') 56 | ax_test.set_ylabel('test regret') 57 | ax_test.legend() 58 | 59 | save_data = {'time_plot': time_plot, 'valid_plot': valid_plot, 'test_plot': test_plot} 60 | with open('results/{}-{}-nasbench-101.json'.format('RL', name), 
'w') as f_w: 61 | json.dump(save_data, f_w) 62 | 63 | plt.show() 64 | 65 | if __name__ == '__main__': 66 | name = 'arch2vec' 67 | plot_over_time_reinforce_search_arch2vec(name) 68 | 69 | -------------------------------------------------------------------------------- /plot_scripts/summarize_nasbench201.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import numpy as np 4 | #from prettytable import PrettyTable 5 | 6 | 7 | #t = PrettyTable(['Method', 'CIFAR-10 val', 'CIFAR-10 test', 'CIFAR-100 val', 'CIFAR-100 test', 'ImageNet-16-120 val', 'ImageNet-16-120 test']) 8 | #t = PrettyTable(['Method', 'CIFAR-10 val', 'CIFAR-10 test']) 9 | 10 | def get_summary(dataset, file_name, data_dir, val_test, N_runs): 11 | val_acc = [] 12 | test_acc = [] 13 | for k in range(1, N_runs+1): 14 | file_name_ = file_name.format(dataset, k) 15 | file_path = os.path.join(data_dir, file_name_) 16 | if os.path.isfile(file_path): 17 | with open(file_path, 'r') as f: 18 | acc_dict = json.load(f) 19 | val_acc.append(acc_dict[val_test[0]]) # using average instead of individual 20 | test_acc.append(acc_dict[val_test[1]]) 21 | val_acc = np.array(val_acc) 22 | test_acc = np.array(test_acc) 23 | 24 | return val_acc.mean(), val_acc.std(), test_acc.mean(), test_acc.std() 25 | 26 | 27 | # RL (ours) 28 | row = ['arch2vec-RL'] 29 | data_dir = 'saved_logs/rl/dim16/' 30 | datasets = {'cifar10_valid_converged':500, 'cifar100':500, 'ImageNet16_120':500} 31 | file_name = 'nasbench201_{}_run_{}_full.json' 32 | val_test = ['val_acc_avg', 'test_acc_avg'] 33 | for i, (dataset, N_runs) in enumerate(datasets.items()): 34 | val_mean, val_std, test_mean, test_std = get_summary(dataset, file_name, data_dir, val_test, N_runs) 35 | row.append('{:.2f}+-{:.2f}'.format(val_mean, val_std)) 36 | row.append('{:.2f}+-{:.2f}'.format(test_mean, test_std)) 37 | print(row) 38 | 39 | 40 | 41 | ## BO (ours) 42 | row = ['arch2vec-BO'] 43 | data_dir = 'saved_logs/bo/dim16/' 44 | datasets = {'cifar10_valid_converged':500, 'cifar100':500, 'ImageNet16_120':500} 45 | file_name = 'nasbench201_{}_run_{}_full.json' 46 | val_test = ['val_acc_avg', 'test_acc_avg'] 47 | for i, (dataset, N_runs) in enumerate(datasets.items()): 48 | val_mean, val_std, test_mean, test_std = get_summary(dataset, file_name, data_dir, val_test, N_runs) 49 | row.append('{:.2f}+-{:.2f}'.format(val_mean, val_std)) 50 | row.append('{:.2f}+-{:.2f}'.format(test_mean, test_std)) 51 | 52 | 53 | print(row) 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /plot_scripts/try_networkx.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | 5 | def node_match(n1, n2): 6 | if n1['op'] == n2['op']: 7 | return True 8 | else: 9 | return False 10 | 11 | def edge_match(e1, e2): 12 | return True 13 | 14 | def gen_graph(adj, ops): 15 | G = nx.DiGraph() 16 | for k, op in enumerate(ops): 17 | G.add_node(k, op=op) 18 | assert adj.shape[0] == adj.shape[1] == len(ops) 19 | for row in range(len(ops)): 20 | for col in range(row + 1, len(ops)): 21 | if adj[row, col] > 0: 22 | G.add_edge(row, col) 23 | return G 24 | 25 | def preprocess_adj_op(adj, op): 26 | def counting_trailing_false(l): 27 | count = 0 28 | for TF in l[-1::-1]: 29 | if TF: 30 | break 31 | else: 32 | count += 1 33 | return count 34 | 35 | def transform_op(op): 36 | idx2op = {0:'input', 1:'conv1x1-bn-relu', 
2:'conv3x3-bn-relu', 3:'maxpool3x3', 4:'output'} 37 | return [idx2op[idx] for idx in op.argmax(axis=1)] 38 | 39 | adj = np.array(adj).astype(int) 40 | op = np.array(op).astype(int) 41 | 42 | assert op.shape[0] == adj.shape[0] == adj.shape[1] 43 | # find all zero columns 44 | adj_zero_col = counting_trailing_false(adj.any(axis=0)) 45 | # find all zero rows 46 | adj_zero_row = counting_trailing_false(adj.any(axis=1)) 47 | # find all zero rows in op 48 | op_zero_row = counting_trailing_false(op.any(axis=1)) 49 | assert adj_zero_col == op_zero_row == adj_zero_row - 1, 'Inconsistent result {}={}={}'.format(adj_zero_col, op_zero_row, adj_zero_row - 1) 50 | N = op.shape[0] - adj_zero_col 51 | adj = adj[:N, :N] 52 | op = op[:N] 53 | 54 | return adj, transform_op(op) 55 | 56 | 57 | 58 | if __name__ == '__main__': 59 | 60 | adj1 = np.array([[0, 1, 1, 1, 0], 61 | [0, 0, 1, 0, 0], 62 | [0, 0, 0, 0, 1], 63 | [0, 0, 0, 0, 1], 64 | [0, 0, 0, 0, 0]]) 65 | op1 = ['in', 'conv1x1', 'conv3x3', 'mp3x3', 'out'] 66 | 67 | adj2 = np.array([[0, 1, 1, 1, 0], 68 | [0, 0, 0, 1, 0], 69 | [0, 0, 0, 0, 1], 70 | [0, 0, 0, 0, 1], 71 | [0, 0, 0, 0, 0]]) 72 | op2 = ['in', 'conv1x1', 'mp3x3', 'conv3x3', 'out'] 73 | 74 | 75 | adj3 = np.array([[0, 1, 1, 1, 0, 0], 76 | [0, 0, 1, 0, 0, 0], 77 | [0, 0, 0, 0, 1, 0], 78 | [0, 0, 0, 0, 1, 0], 79 | [0, 0, 0, 0, 0, 1], 80 | [0, 0, 0, 0, 0, 0]]) 81 | op3 = ['in', 'conv1x1', 'conv3x3', 'mp3x3', 'out', 'out2'] 82 | 83 | adj4 = np.array([[0, 1, 1, 1, 0, 0], 84 | [0, 0, 1, 0, 0, 0], 85 | [0, 0, 0, 0, 1, 0], 86 | [0, 0, 0, 0, 1, 0], 87 | [0, 0, 0, 0, 0, 0], 88 | [0, 0, 0, 0, 0, 0]]) 89 | op4 = np.array([[1, 0, 0, 0, 0], 90 | [0, 1, 0, 0, 0], 91 | [0, 0, 1, 0, 0], 92 | [0, 0, 0, 1, 0], 93 | [0, 0, 0, 0, 1], 94 | [0, 0, 0, 0, 0]]) 95 | adj4, op4 = preprocess_adj_op(adj4, op4) 96 | 97 | 98 | 99 | G1 = gen_graph(adj1, op1) 100 | G2 = gen_graph(adj2, op2) 101 | G3 = gen_graph(adj3, op3) 102 | G4 = gen_graph(adj4, op4) 103 | 104 | 105 | plt.subplot(141) 106 | nx.draw(G1, with_labels=True, font_weight='bold') 107 | plt.subplot(142) 108 | nx.draw(G2, with_labels=True, font_weight='bold') 109 | plt.subplot(143) 110 | nx.draw(G3, with_labels=True, font_weight='bold') 111 | plt.subplot(144) 112 | nx.draw(G4, with_labels=True, font_weight='bold') 113 | 114 | print(nx.graph_edit_distance(G1, G2, node_match=node_match, edge_match=edge_match)) # graph edit distance with op-aware node matching 115 | print(nx.graph_edit_distance(G2, G3, node_match=node_match, edge_match=edge_match)) 116 | plt.show() -------------------------------------------------------------------------------- /preprocessing/gen_isomorphism_graphs.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | import numpy as np 4 | import logging 5 | import sys 6 | import os 7 | sys.path.insert(0, os.getcwd()) 8 | from darts.cnn.genotypes import Genotype 9 | from darts.cnn.model import NetworkImageNet as Network 10 | from thop import profile 11 | 12 | def process(geno): 13 | for i, item in enumerate(geno): 14 | geno[i] = tuple(geno[i]) 15 | return geno 16 | 17 | def transform_operations(ops): 18 | transform_dict = {'c_k-2': 0, 'c_k-1': 1, 'none': 2, 'max_pool_3x3': 3, 'avg_pool_3x3': 4, 'skip_connect': 5, 19 | 'sep_conv_3x3': 6, 'sep_conv_5x5': 7, 'dil_conv_3x3': 8, 'dil_conv_5x5': 9, 'output': 10} 20 | 21 | ops_array = np.zeros([11, 11], dtype='int8') 22 | for row, op in enumerate(ops): 23 | ops_array[row, op] = 1 24 | return ops_array 25 | 26 | def sample_arch(): 27 | num_ops = len(OPS) 28 | normal = [] 29 | normal_name = [] 30 | for i in range(NUM_VERTICES): 31 |
ops = np.random.choice(range(num_ops), NUM_VERTICES) 32 | nodes_in_normal = np.random.choice(range(i+2), 2, replace=False) 33 | normal.extend([(nodes_in_normal[0], ops[0]), (nodes_in_normal[1], ops[1])]) 34 | normal_name.extend([(str(nodes_in_normal[0]), OPS[ops[0]]), (str(nodes_in_normal[1]), OPS[ops[1]])]) 35 | 36 | return (normal), (normal_name) 37 | 38 | 39 | def build_mat_encoding(normal, normal_name, counter): 40 | adj = torch.zeros(11, 11) 41 | ops = torch.zeros(11, 11) 42 | block_0 = (normal[0], normal[1]) 43 | prev_b0_n1, prev_b0_n2 = block_0[0][0], block_0[1][0] 44 | prev_b0_o1, prev_b0_o2 = block_0[0][1], block_0[1][1] 45 | 46 | block_1 = (normal[2], normal[3]) 47 | prev_b1_n1, prev_b1_n2 = block_1[0][0], block_1[1][0] 48 | prev_b1_o1, prev_b1_o2 = block_1[0][1], block_1[1][1] 49 | 50 | block_2 = (normal[4], normal[5]) 51 | prev_b2_n1, prev_b2_n2 = block_2[0][0], block_2[1][0] 52 | prev_b2_o1, prev_b2_o2 = block_2[0][1], block_2[1][1] 53 | 54 | block_3 = (normal[6], normal[7]) 55 | prev_b3_n1, prev_b3_n2 = block_3[0][0], block_3[1][0] 56 | prev_b3_o1, prev_b3_o2 = block_3[0][1], block_3[1][1] 57 | 58 | adj[2][-1] = 1 59 | adj[3][-1] = 1 60 | adj[4][-1] = 1 61 | adj[5][-1] = 1 62 | adj[6][-1] = 1 63 | adj[7][-1] = 1 64 | adj[8][-1] = 1 65 | adj[9][-1] = 1 66 | 67 | # B0 68 | adj[prev_b0_n1][2] = 1 69 | adj[prev_b0_n2][3] = 1 70 | 71 | # B1 72 | if prev_b1_n1 == 2: 73 | adj[2][4] = 1 74 | adj[3][4] = 1 75 | else: 76 | adj[prev_b1_n1][4] = 1 77 | 78 | if prev_b1_n2 == 2: 79 | adj[2][5] = 1 80 | adj[3][5] = 1 81 | else: 82 | adj[prev_b1_n2][5] = 1 83 | 84 | # B2 85 | if prev_b2_n1 == 2: 86 | adj[2][6] = 1 87 | adj[3][6] = 1 88 | elif prev_b2_n1 == 3: 89 | adj[4][6] = 1 90 | adj[5][6] = 1 91 | else: 92 | adj[prev_b2_n1][6] = 1 93 | 94 | if prev_b2_n2 == 2: 95 | adj[2][7] = 1 96 | adj[3][7] = 1 97 | elif prev_b2_n2 == 3: 98 | adj[4][7] = 1 99 | adj[5][7] = 1 100 | else: 101 | adj[prev_b2_n2][7] = 1 102 | 103 | # B3 104 | if prev_b3_n1 == 2: 105 | adj[2][8] = 1 106 | adj[3][8] = 1 107 | elif prev_b3_n1 == 3: 108 | adj[4][8] = 1 109 | adj[5][8] = 1 110 | elif prev_b3_n1 == 4: 111 | adj[6][8] = 1 112 | adj[7][8] = 1 113 | else: 114 | adj[prev_b3_n1][8] = 1 115 | 116 | if prev_b3_n2 == 2: 117 | adj[2][9] = 1 118 | adj[3][9] = 1 119 | elif prev_b3_n2 == 3: 120 | adj[4][9] = 1 121 | adj[5][9] = 1 122 | elif prev_b3_n2 == 4: 123 | adj[6][9] = 1 124 | adj[7][9] = 1 125 | else: 126 | adj[prev_b3_n2][9] = 1 127 | 128 | ops[0][0] = 1 129 | ops[1][1] = 1 130 | ops[-1][-1] = 1 131 | ops[2][prev_b0_o1+2] = 1 132 | ops[3][prev_b0_o2+2] = 1 133 | ops[4][prev_b1_o1+2] = 1 134 | ops[5][prev_b1_o2+2] = 1 135 | ops[6][prev_b2_o1+2] = 1 136 | ops[7][prev_b2_o2+2] = 1 137 | ops[8][prev_b3_o1+2] = 1 138 | ops[9][prev_b3_o2+2] = 1 139 | 140 | #print("adj encoding: \n{} \n".format(adj.int())) 141 | #print("ops encoding: \n{} \n".format(ops.int())) 142 | 143 | label = torch.argmax(ops, dim=1) 144 | 145 | fingerprint = graph_util.hash_module(adj.int().numpy(), label.int().numpy().tolist()) 146 | if fingerprint not in buckets: 147 | normal_cell = [(item[1], int(item[0])) for item in normal_name] 148 | reduce_cell = normal_cell.copy() 149 | genotype = Genotype(normal=normal_cell, normal_concat=[2, 3, 4, 5], reduce=reduce_cell, reduce_concat=[2, 3, 4, 5]) 150 | model = Network(48, 1000, 14, False, genotype).cuda() 151 | input = torch.randn(1, 3, 224, 224).cuda() 152 | macs, params = profile(model, inputs=(input, )) 153 | if macs < 6e8: 154 | counter += 1 155 | print("counter: {}, flops: {}, params: 
{}".format(counter, macs, params)) 156 | buckets[fingerprint] = (adj.numpy().astype('int8').tolist(), label.numpy().astype('int8').tolist(), (normal_name)) 157 | 158 | if counter > 0 and counter % 1e5 == 0: 159 | with open('data/data_darts_counter{}.json'.format(counter), 'w') as f: 160 | json.dump(buckets, f) 161 | 162 | return counter 163 | 164 | if __name__ == '__main__': 165 | from nasbench.lib import graph_util 166 | OPS = ['none', 167 | 'max_pool_3x3', 168 | 'avg_pool_3x3', 169 | 'skip_connect', 170 | 'sep_conv_3x3', 171 | 'sep_conv_5x5', 172 | 'dil_conv_3x3', 173 | 'dil_conv_5x5' 174 | ] 175 | NUM_VERTICES = 4 176 | INPUT_1 = 'c_k-2' 177 | INPUT_2 = 'c_k-1' 178 | logging.basicConfig(filename='darts_preparation.log') 179 | 180 | buckets = {} 181 | counter = 0 182 | while counter <= 6e5: 183 | normal, normal_name = sample_arch() 184 | counter = build_mat_encoding(normal, normal_name, counter) 185 | -------------------------------------------------------------------------------- /preprocessing/gen_json.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | from nasbench import api 6 | from random import randint 7 | import json 8 | import numpy as np 9 | from collections import OrderedDict 10 | 11 | # Replace this string with the path to the downloaded nasbench.tfrecord before 12 | # executing. 13 | NASBENCH_TFRECORD = 'data/nasbench_only108.tfrecord' 14 | 15 | INPUT = 'input' 16 | OUTPUT = 'output' 17 | CONV1X1 = 'conv1x1-bn-relu' 18 | CONV3X3 = 'conv3x3-bn-relu' 19 | MAXPOOL3X3 = 'maxpool3x3' 20 | 21 | def gen_data_point(nasbench): 22 | 23 | i = 0 24 | epoch = 108 25 | 26 | padding = [0, 0, 0, 0, 0, 0, 0] 27 | best_val_acc = 0 28 | best_test_acc = 0 29 | 30 | for unique_hash in nasbench.hash_iterator(): 31 | fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(unique_hash) 32 | print('\nIterating over {} / {} unique models in the dataset.'.format(i, 423623)) 33 | test_acc_avg = 0.0 34 | val_acc_avg = 0.0 35 | training_time = 0.0 36 | for repeat_index in range(len(computed_metrics[epoch])): 37 | assert len(computed_metrics[epoch])==3, 'len(computed_metrics[epoch]) should be 3' 38 | data_point = computed_metrics[epoch][repeat_index] 39 | val_acc_avg += data_point['final_validation_accuracy'] 40 | test_acc_avg += data_point['final_test_accuracy'] 41 | training_time += data_point['final_training_time'] 42 | val_acc_avg = val_acc_avg/3.0 43 | test_acc_avg = test_acc_avg/3.0 44 | training_time_avg = training_time/3.0 45 | ops_array = transform_operations(fixed_metrics['module_operations']) 46 | adj_array = fixed_metrics['module_adjacency'].tolist() 47 | model_spec = api.ModelSpec(fixed_metrics['module_adjacency'], fixed_metrics['module_operations']) 48 | data = nasbench.query(model_spec, epochs=108) 49 | print('api training time: {}'.format(data['training_time'])) 50 | print('real training time: {}'.format(training_time_avg)) 51 | 52 | # pad zero to adjacent matrix that has nodes less than 7 53 | if len(adj_array) <= 6: 54 | for row in range(len(adj_array)): 55 | for _ in range(7-len(adj_array)): 56 | adj_array[row].append(0) 57 | for _ in range(7-len(adj_array)): 58 | adj_array.append(padding) 59 | 60 | if val_acc_avg > best_val_acc: 61 | best_val_acc = val_acc_avg 62 | 63 | if test_acc_avg > best_test_acc: 64 | best_test_acc = test_acc_avg 65 | 66 | print('best val. 
acc: {:.4f}, best test acc {:.4f}'.format(best_val_acc, best_test_acc)) 67 | 68 | yield {i: # unique_hash 69 | {'test_accuracy': test_acc_avg, 70 | 'validation_accuracy': val_acc_avg, 71 | 'module_adjacency':adj_array, 72 | 'module_operations': ops_array.tolist(), 73 | 'training_time': training_time_avg}} 74 | 75 | i += 1 76 | 77 | def transform_operations(ops): 78 | transform_dict = {'input':0, 'conv1x1-bn-relu':1, 'conv3x3-bn-relu':2, 'maxpool3x3':3, 'output':4} 79 | ops_array = np.zeros([7,5], dtype='int8') 80 | for row, op in enumerate(ops): 81 | col = transform_dict[op] 82 | ops_array[row, col] = 1 83 | return ops_array 84 | 85 | 86 | def gen_json_file(): 87 | nasbench = api.NASBench(NASBENCH_TFRECORD) 88 | nas_gen = gen_data_point(nasbench) 89 | data_dict = OrderedDict() 90 | for data_point in nas_gen: 91 | data_dict.update(data_point) 92 | with open('data/data.json', 'w') as outfile: 93 | json.dump(data_dict, outfile) 94 | 95 | 96 | 97 | 98 | if __name__ == '__main__': 99 | gen_json_file() 100 | -------------------------------------------------------------------------------- /preprocessing/nasbench201_json.py: -------------------------------------------------------------------------------- 1 | """API source: https://github.com/D-X-Y/NAS-Bench-201/blob/v1.1/nas_201_api/api.py""" 2 | from api import NASBench201API as API 3 | import numpy as np 4 | import json 5 | from collections import OrderedDict 6 | 7 | nas_bench = API('data/NAS-Bench-201-v1_0-e61699.pth') 8 | 9 | 10 | 11 | # num = len(api) 12 | # for i, arch_str in enumerate(api): 13 | # print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str)) 14 | # 15 | # info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults` 16 | # res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys 17 | # cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency 18 | # 19 | # api.show(1) 20 | # api.show(2) 21 | 22 | def info2mat(arch_index): 23 | #info.all_results 24 | 25 | info = nas_bench.query_meta_info_by_index(arch_index) 26 | ops = {'input':0, 'nor_conv_1x1':1, 'nor_conv_3x3':2, 'avg_pool_3x3':3, 'skip_connect':4, 'none':5, 'output':6} 27 | adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0], 28 | [0, 0, 0, 1, 0, 1 ,0 ,0], 29 | [0, 0, 0, 0, 0, 0, 1, 0], 30 | [0, 0, 0, 0, 0, 0, 1, 0], 31 | [0, 0, 0, 0, 0, 0, 0, 1], 32 | [0, 0, 0, 0, 0, 0, 0, 1], 33 | [0, 0, 0, 0, 0, 0, 0, 1], 34 | [0, 0, 0, 0, 0, 0, 0, 0]]) 35 | 36 | nodes = ['input'] 37 | steps = info.arch_str.split('+') 38 | steps_coding = ['0', '0', '1', '0', '1', '2'] 39 | cont = 0 40 | for step in steps: 41 | step = step.strip('|').split('|') 42 | for node in step: 43 | n, idx = node.split('~') 44 | assert idx == steps_coding[cont] 45 | cont += 1 46 | nodes.append(n) 47 | nodes.append('output') 48 | 49 | node_mat =np.zeros([8, len(ops)]).astype(int) 50 | ops_idx = [ops[k] for k in nodes] 51 | node_mat[[0,1,2,3,4,5,6,7],ops_idx] = 1 52 | 53 | # For cifar10-valid with converged 54 | valid_acc, val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=None, dataname='cifar10-valid', use_converged_LR=True) 55 | cifar10_valid_converged = { 'test_accuracy': test_acc, 56 | 'test_accuracy_avg': test_acc_avg, 57 | 'validation_accuracy':valid_acc, 58 | 'validation_accuracy_avg': val_acc_avg, 59 | 'module_adjacency':adj_mat.tolist(), 60 | 'module_operations': node_mat.tolist(), 61 | 'training_time': time_cost} 62 | 63 | 64 | # For cifar100 65 | valid_acc, 
val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=199, dataname='cifar100', use_converged_LR=False) 66 | cifar100 = {'test_accuracy': test_acc, 67 | 'test_accuracy_avg': test_acc_avg, 68 | 'validation_accuracy': valid_acc, 69 | 'validation_accuracy_avg': val_acc_avg, 70 | 'module_adjacency': adj_mat.tolist(), 71 | 'module_operations': node_mat.tolist(), 72 | 'training_time': time_cost} 73 | 74 | # For ImageNet16-120 75 | valid_acc, val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=199, dataname='ImageNet16-120', use_converged_LR=False) 76 | ImageNet16_120 = {'test_accuracy': test_acc, 77 | 'test_accuracy_avg': test_acc_avg, 78 | 'validation_accuracy': valid_acc, 79 | 'validation_accuracy_avg': val_acc_avg, 80 | 'module_adjacency': adj_mat.tolist(), 81 | 'module_operations': node_mat.tolist(), 82 | 'training_time': time_cost} 83 | 84 | 85 | return {'cifar10_valid_converged': cifar10_valid_converged, 86 | 'cifar100':cifar100, 87 | 'ImageNet16_120': ImageNet16_120 } 88 | 89 | def train_and_eval(arch_index, nepoch=None, dataname=None, use_converged_LR=True): 90 | assert dataname !='cifar10', 'Do not allow cifar10 dataset' 91 | if use_converged_LR and dataname=='cifar10-valid': 92 | assert nepoch == None, 'When using use_converged_LR=True, please set nepoch=None, use 12-converged-epoch by default.' 93 | 94 | 95 | info = nas_bench.get_more_info(arch_index, dataname, None, True) 96 | valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time'] 97 | valid_acc_avg = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, False, False)['valid-accuracy'] 98 | test_acc = nas_bench.get_more_info(arch_index, 'cifar10', None, False, True)['test-accuracy'] 99 | test_acc_avg = nas_bench.get_more_info(arch_index, 'cifar10', None, False, False)['test-accuracy'] 100 | 101 | elif not use_converged_LR: 102 | 103 | assert isinstance(nepoch, int), 'nepoch should be int' 104 | xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True) 105 | xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False) 106 | info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True) 107 | cost = nas_bench.get_cost_info(arch_index, dataname, False) 108 | # The following codes are used to estimate the time cost. 109 | # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record. 110 | # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared. 
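# In plain terms, the cost estimate below rescales the measured cifar10-valid
# per-epoch train/valid times by the dataset-size ratio and by the latency ratio
# of the two (architecture, dataset) pairs:
#   est_train = train_per_time / N_c10v_train * N_target_train / lat_c10v * lat_target * nepoch
#   est_valid = valid_per_time / N_c10v_valid * N_target_valid / lat_c10v * lat_target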
111 | nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000, 112 | 'cifar10-valid-train' : 25000, 'cifar10-valid-valid' : 25000, 113 | 'cifar100-train' : 50000, 'cifar100-valid' : 5000} 114 | estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch 115 | estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency'] 116 | try: 117 | valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost 118 | except KeyError: # some entries only report an estimated validation accuracy 119 | valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost 120 | test_acc = info['test-accuracy'] 121 | test_acc_avg = nas_bench.get_more_info(arch_index, dataname, None, False, False)['test-accuracy'] 122 | valid_acc_avg = nas_bench.get_more_info(arch_index, dataname, None, False, False)['valid-accuracy'] 123 | else: 124 | # train a model from scratch. 125 | raise ValueError('NOT IMPLEMENTED YET') 126 | return valid_acc, valid_acc_avg, time_cost, test_acc, test_acc_avg 127 | 128 | 129 | def enumerate_dataset(dataset): 130 | for k in range(len(nas_bench)): 131 | print('{}: {}/{}'.format(dataset, k, len(nas_bench))) 132 | res = info2mat(k) 133 | yield {k:res[dataset]} 134 | 135 | def gen_json_file(dataset): 136 | data_dict = OrderedDict() 137 | enum_dataset = enumerate_dataset(dataset) 138 | for data_point in enum_dataset: 139 | data_dict.update(data_point) 140 | with open('data/{}.json'.format(dataset), 'w') as outfile: 141 | json.dump(data_dict, outfile) 142 | 143 | if __name__ == '__main__': 144 | 145 | for dataset in ['cifar10_valid_converged', 'cifar100', 'ImageNet16_120']: 146 | gen_json_file(dataset) 147 | -------------------------------------------------------------------------------- /pybnn/__init__.py: -------------------------------------------------------------------------------- 1 | from pybnn.dngo import DNGO 2 | from pybnn.bayesian_linear_regression import BayesianLinearRegression 3 | from pybnn.base_model import BaseModel 4 | -------------------------------------------------------------------------------- /pybnn/base_model.py: -------------------------------------------------------------------------------- 1 | import abc 2 | import numpy as np 3 | 4 | 5 | class BaseModel(abc.ABC): 6 | # abc.ABC so the @abc.abstractmethod decorators below are enforced under Python 3 7 | 8 | def __init__(self): 9 | """ 10 | Abstract base class for all models 11 | """ 12 | self.X = None 13 | self.y = None 14 | 15 | @abc.abstractmethod 16 | def train(self, X, y): 17 | """ 18 | Trains the model on the provided data. 19 | 20 | Parameters 21 | ---------- 22 | X: np.ndarray (N, D) 23 | Input data points. The dimensionality of X is (N, D), 24 | with N as the number of points and D as the number of input dimensions. 25 | y: np.ndarray (N,) 26 | The corresponding target values of the input data points. 27 | """ 28 | pass 29 | 30 | def update(self, X, y): 31 | """ 32 | Update the model with the new additional data. Override this function if your 33 | model allows something smarter than simple retraining. 34 | 35 | Parameters 36 | ---------- 37 | X: np.ndarray (N, D) 38 | Input data points. The dimensionality of X is (N, D), 39 | with N as the number of points and D as the number of input dimensions. 40 | y: np.ndarray (N,) 41 | The corresponding target values of the input data points.
42 | """ 43 | X = np.append(self.X, X, axis=0) 44 | y = np.append(self.y, y, axis=0) 45 | self.train(X, y) 46 | 47 | @abc.abstractmethod 48 | def predict(self, X_test): 49 | """ 50 | Predicts for a given set of test data points the mean and variance of its target values 51 | 52 | Parameters 53 | ---------- 54 | X_test: np.ndarray (N, D) 55 | N Test data points with input dimensions D 56 | 57 | Returns 58 | ---------- 59 | mean: ndarray (N,) 60 | Predictive mean of the test data points 61 | var: ndarray (N,) 62 | Predictive variance of the test data points 63 | """ 64 | pass 65 | 66 | def _check_shapes_train(func): 67 | def func_wrapper(self, X, y, *args, **kwargs): 68 | assert X.shape[0] == y.shape[0] 69 | assert len(X.shape) == 2 70 | assert len(y.shape) == 1 71 | return func(self, X, y, *args, **kwargs) 72 | return func_wrapper 73 | 74 | def _check_shapes_predict(func): 75 | def func_wrapper(self, X, *args, **kwargs): 76 | assert len(X.shape) == 2 77 | return func(self, X, *args, **kwargs) 78 | 79 | return func_wrapper 80 | 81 | def get_json_data(self): 82 | """ 83 | Json getter function 84 | 85 | Returns 86 | ---------- 87 | dictionary 88 | """ 89 | json_data = {'X': self.X if self.X is None else self.X.tolist(), 90 | 'y': self.y if self.y is None else self.y.tolist(), 91 | 'hyperparameters': ""} 92 | return json_data 93 | 94 | def get_incumbent(self): 95 | """ 96 | Returns the best observed point and its function value 97 | 98 | Returns 99 | ---------- 100 | incumbent: ndarray (D,) 101 | current incumbent 102 | incumbent_value: ndarray (N,) 103 | the observed value of the incumbent 104 | """ 105 | best_idx = np.argmin(self.y) 106 | return self.X[best_idx], self.y[best_idx] 107 | -------------------------------------------------------------------------------- /pybnn/bayesian_linear_regression.py: -------------------------------------------------------------------------------- 1 | import emcee 2 | import logging 3 | import numpy as np 4 | 5 | from scipy import optimize 6 | from scipy import stats 7 | 8 | from pybnn.base_model import BaseModel 9 | 10 | 11 | def linear_basis_func(x): 12 | return np.append(x, np.ones([x.shape[0], 1]), axis=1) 13 | 14 | 15 | def quadratic_basis_func(x): 16 | x = np.append(x ** 2, x, axis=1) 17 | return np.append(x, np.ones([x.shape[0], 1]), axis=1) 18 | 19 | 20 | logger = logging.getLogger(__name__) 21 | 22 | 23 | class Prior(object): 24 | 25 | def __init__(self, rng=None): 26 | if rng is None: 27 | self.rng = np.random.RandomState(np.random.randint(0, 10000)) 28 | else: 29 | self.rng = rng 30 | 31 | def lnprob(self, theta): 32 | """ 33 | Compute the log probability for theta = [log alpha, log beta] 34 | :param theta: 35 | :return: log p(theta) 36 | """ 37 | lp = 0 38 | lp += stats.norm.logpdf(theta[0], loc=0, scale=1) # log alpha; logpdf, since log densities are summed 39 | lp += stats.norm.logpdf(theta[1], loc=0, scale=1) # log sigma^2 40 | 41 | return lp 42 | 43 | def sample_from_prior(self, n_samples): 44 | p0 = np.zeros([n_samples, 2]) 45 | 46 | # Log alpha 47 | p0[:, 0] = self.rng.normal(loc=0, 48 | scale=1, 49 | size=n_samples) 50 | 51 | # Log sigma^2 52 | p0[:, 1] = self.rng.normal(loc=-3, 53 | scale=1, 54 | size=n_samples) 55 | return p0 56 | 57 | 58 | class BayesianLinearRegression(BaseModel): 59 | 60 | def __init__(self, alpha=1, beta=1000, basis_func=linear_basis_func, 61 | prior=None, do_mcmc=True, n_hypers=20, chain_length=2000, 62 | burnin_steps=2000, rng=None): 63 | """ 64 | Implementation of Bayesian linear regression.
See chapter 3.3 of the book 65 | "Pattern Recognition and Machine Learning" by Bishop for more details. 66 | 67 | Parameters 68 | ---------- 69 | alpha: float 70 | Specifies the variance of the prior for the weights w 71 | beta : float 72 | Defines the inverse of the noise, i.e. beta = 1 / sigma^2 73 | basis_func : function 74 | Function handle used to transform the inputs via basis functions 75 | (see the code above for an example) 76 | prior: Prior object 77 | Prior for alpha and beta. If set to None the default prior is used 78 | do_mcmc: bool 79 | If set to true, different values for alpha and beta are sampled via MCMC from the marginal log likelihood. 80 | Otherwise the marginal log likelihood is optimized with the scipy fmin function 81 | n_hypers : int 82 | Number of samples for alpha and beta 83 | chain_length : int 84 | The chain length of the MCMC sampler 85 | burnin_steps: int 86 | The number of burnin steps before the sampling procedure starts 87 | rng: np.random.RandomState 88 | Random number generator 89 | """ 90 | 91 | if rng is None: 92 | self.rng = np.random.RandomState(np.random.randint(0, 10000)) 93 | else: 94 | self.rng = rng 95 | 96 | self.X = None 97 | self.y = None 98 | self.alpha = alpha 99 | self.beta = beta 100 | self.basis_func = basis_func 101 | if prior is None: 102 | self.prior = Prior(rng=self.rng) 103 | else: 104 | self.prior = prior 105 | self.do_mcmc = do_mcmc 106 | self.n_hypers = n_hypers 107 | self.chain_length = chain_length 108 | self.burned = False 109 | self.burnin_steps = burnin_steps 110 | self.models = None 111 | 112 | def marginal_log_likelihood(self, theta): 113 | """ 114 | Log likelihood of the data marginalised over the weights w. See chapter 3.5 of 115 | the book by Bishop for a derivation. 116 | 117 | Parameters 118 | ---------- 119 | theta: np.array(2,) 120 | The hyperparameter alpha and beta on a log scale 121 | 122 | Returns 123 | ------- 124 | float 125 | lnlikelihood + prior 126 | """ 127 | 128 | # Theta is on a log scale 129 | alpha = np.exp(theta[0]) 130 | beta = 1 / np.exp(theta[1]) 131 | 132 | D = self.X_transformed.shape[1] 133 | N = self.X_transformed.shape[0] 134 | 135 | A = beta * np.dot(self.X_transformed.T, self.X_transformed) 136 | A += np.eye(self.X_transformed.shape[1]) * alpha 137 | try: 138 | A_inv = np.linalg.inv(A) 139 | except np.linalg.linalg.LinAlgError: 140 | A_inv = np.linalg.inv(A + np.random.rand(A.shape[0], A.shape[1]) * 1e-8) 141 | 142 | 143 | m = beta * np.dot(A_inv, self.X_transformed.T) 144 | m = np.dot(m, self.y) 145 | 146 | mll = D / 2 * np.log(alpha) 147 | mll += N / 2 * np.log(beta) 148 | mll -= N / 2 * np.log(2 * np.pi) 149 | mll -= beta / 2. * np.linalg.norm(self.y - np.dot(self.X_transformed, m), 2) ** 2 # squared residual norm, cf. Bishop Eq. (3.86) 150 | mll -= alpha / 2. * np.dot(m.T, m) 151 | mll -= 0.5 * np.log(np.linalg.det(A)) 152 | 153 | if self.prior is not None: 154 | mll += self.prior.lnprob(theta) 155 | 156 | return mll 157 | 158 | def negative_mll(self, theta): 159 | """ 160 | Returns the negative marginal log likelihood (for optimizing it with scipy). 161 | 162 | Parameters 163 | ---------- 164 | theta: np.array(2,) 165 | The hyperparameter alpha and beta on a log scale 166 | 167 | Returns 168 | ------- 169 | float 170 | negative lnlikelihood + prior 171 | """ 172 | return -self.marginal_log_likelihood(theta) 173 | 174 | @BaseModel._check_shapes_train 175 | def train(self, X, y, do_optimize=True): 176 | """ 177 | First optimizes the hyperparameters if do_optimize is True and then computes 178 | the posterior distribution of the weights.
See chapter 3.3 of the book by Bishop 179 | for more details. 180 | 181 | Parameters 182 | ---------- 183 | X: np.ndarray (N, D) 184 | Input data points. The dimensionality of X is (N, D), 185 | with N as the number of points and D is the number of features. 186 | y: np.ndarray (N,) 187 | The corresponding target values. 188 | do_optimize: boolean 189 | If set to true the hyperparameters are optimized otherwise 190 | the default hyperparameters are used. 191 | """ 192 | 193 | self.X = X 194 | 195 | if self.basis_func is not None: 196 | self.X_transformed = self.basis_func(X) 197 | else: 198 | self.X_transformed = self.X 199 | 200 | self.y = y 201 | 202 | if do_optimize: 203 | if self.do_mcmc: 204 | sampler = emcee.EnsembleSampler(self.n_hypers, 2, 205 | self.marginal_log_likelihood) 206 | 207 | # Do a burn-in in the first iteration 208 | if not self.burned: 209 | # Initialize the walkers by sampling from the prior 210 | self.p0 = self.prior.sample_from_prior(self.n_hypers) 211 | 212 | # Run MCMC sampling 213 | self.p0, _, _ = sampler.run_mcmc(self.p0, 214 | self.burnin_steps, 215 | rstate0=self.rng) 216 | 217 | self.burned = True 218 | 219 | # Start sampling 220 | pos, _, _ = sampler.run_mcmc(self.p0, 221 | self.chain_length, 222 | rstate0=self.rng) 223 | 224 | # Save the current position, it will be the start point in 225 | # the next iteration 226 | self.p0 = pos 227 | 228 | # Take the last samples from each walker 229 | self.hypers = np.exp(sampler.chain[:, -1]) 230 | else: 231 | # Optimize hyperparameters of the Bayesian linear regression 232 | res = optimize.fmin(self.negative_mll, self.rng.rand(2)) 233 | self.hypers = [[np.exp(res[0]), np.exp(res[1])]] 234 | 235 | else: 236 | self.hypers = [[self.alpha, self.beta]] 237 | 238 | self.models = [] 239 | for sample in self.hypers: 240 | alpha = sample[0] 241 | beta = sample[1] 242 | 243 | logger.debug("Alpha=%f ; Beta=%f" % (alpha, beta)) 244 | 245 | S_inv = beta * np.dot(self.X_transformed.T, self.X_transformed) 246 | S_inv += np.eye(self.X_transformed.shape[1]) * alpha 247 | try: 248 | S = np.linalg.inv(S_inv) 249 | except np.linalg.linalg.LinAlgError: 250 | S = np.linalg.inv(S_inv + np.random.rand(S_inv.shape[0], S_inv.shape[1]) * 1e-8) 251 | 252 | m = beta * np.dot(np.dot(S, self.X_transformed.T), self.y) 253 | 254 | self.models.append((m, S)) 255 | 256 | @BaseModel._check_shapes_predict 257 | def predict(self, X_test): 258 | r""" 259 | Returns the predictive mean and variance of the objective function at 260 | the given test points. 261 | 262 | Parameters 263 | ---------- 264 | X_test: np.ndarray (N, D) 265 | N input test points 266 | 267 | Returns 268 | ---------- 269 | np.array(N,) 270 | predictive mean 271 | np.array(N,) 272 | predictive variance 273 | 274 | """ 275 | if self.basis_func is not None: 276 | X_transformed = self.basis_func(X_test) 277 | else: 278 | X_transformed = X_test 279 | 280 | # Marginalise predictions over hyperparameters 281 | mu = np.zeros([len(self.hypers), X_transformed.shape[0]]) 282 | var = np.zeros([len(self.hypers), X_transformed.shape[0]]) 283 | 284 | for i, h in enumerate(self.hypers): 285 | mu[i] = np.dot(self.models[i][0].T, X_transformed.T) 286 | var[i] = 1. 
/ h[1] + np.diag(np.dot(np.dot(X_transformed, self.models[i][1]), X_transformed.T)) 287 | 288 | m = mu.mean(axis=0) 289 | v = var.mean(axis=0) 290 | # Clip negative variances and set them to the smallest 291 | # positive float value 292 | if v.shape[0] == 1: 293 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 294 | else: 295 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf) 296 | v[np.where((v < np.finfo(v.dtype).eps) & (v > -np.finfo(v.dtype).eps))] = 0 297 | 298 | return m, v 299 | -------------------------------------------------------------------------------- /pybnn/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/pybnn/util/__init__.py -------------------------------------------------------------------------------- /pybnn/util/normalization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def zero_one_normalization(X, lower=None, upper=None): 5 | 6 | if lower is None: 7 | lower = np.min(X, axis=0) 8 | if upper is None: 9 | upper = np.max(X, axis=0) 10 | 11 | X_normalized = np.true_divide((X - lower), (upper - lower)) 12 | 13 | return X_normalized, lower, upper 14 | 15 | 16 | def zero_one_denormalization(X_normalized, lower, upper): 17 | return lower + (upper - lower) * X_normalized 18 | 19 | 20 | def zero_mean_unit_var_normalization(X, mean=None, std=None): 21 | if mean is None: 22 | mean = np.mean(X, axis=0) 23 | if std is None: 24 | std = np.std(X, axis=0) 25 | 26 | X_normalized = (X - mean) / std 27 | 28 | return X_normalized, mean, std 29 | 30 | 31 | def zero_mean_unit_var_denormalization(X_normalized, mean, std): 32 | return X_normalized * std + mean 33 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch == 1.4.0 2 | torchvision == 0.5.0 3 | tensorflow == 1.15.0 4 | emcee == 3.0.2 5 | tqdm == 4.31.1 6 | networkx == 2.2 7 | graphviz == 0.14.2 8 | thop == 0.0.31.post2004101309 9 | texttable == 1.6.3 10 | python-igraph == 0.8.3 11 | 12 | -------------------------------------------------------------------------------- /run_scripts/extract_arch2vec.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python search_methods/reinforce.py --dim 16 --model_path model-nasbench101.pt 3 | 4 | -------------------------------------------------------------------------------- /run_scripts/extract_arch2vec_darts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python search_methods/reinforce_darts.py --dim 16 3 | -------------------------------------------------------------------------------- /run_scripts/extract_arch2vec_nasbench201.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name cifar10_valid_converged --latent_dim 16 --model_path model-nasbench201.pt 4 | 5 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name cifar100 --latent_dim 16 --model_path model-nasbench201.pt 6 | 7 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name ImageNet16_120 --latent_dim 16 --model_path model-nasbench201.pt 8 | 9 | 
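The extraction runs above write one arch2vec .pt file per NAS-Bench-201 dataset under pretrained/dim-16/. A quick sanity check of such a file, assuming the index-to-dict layout that load_arch2vec in search_methods/dngo.py reads ('feature' plus accuracy/time fields), with the file name inferred from the --dataset_name argument:

```python
import torch

# file name assumed from the --dataset_name value used in the script above
emb = torch.load('pretrained/dim-16/cifar10_valid_converged-arch2vec.pt')
print(len(emb))                  # number of architectures covered
print(emb[0]['feature'].shape)   # one 16-dim embedding per architecture
```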
-------------------------------------------------------------------------------- /run_scripts/run_bo_arch2vec_darts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python search_methods/dngo_darts.py --max_budgets 100 --inner_epochs 50 --objective 0.95 --train_portion 0.9 --dim 16 --seed 3 --output_path saved_logs/bo --init_size 16 --batch_size 5 --logging_path darts-bo 3 | -------------------------------------------------------------------------------- /run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \ 7 | --dataset_name ImageNet16_120 --MAX_BUDGET 1400000 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \ 7 | --dataset_name cifar100 --MAX_BUDGET 500000 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \ 7 | --dataset_name cifar10_valid_converged --MAX_BUDGET 12000 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /run_scripts/run_dngo_arch2vec.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for s in {1..500} 3 | do 4 | python search_methods/dngo.py --dim 16 --seed $s --output_path saved_logs/bo --emb_path arch2vec-model-nasbench101.pt --init_size 16 --topk 5 5 | done 6 | -------------------------------------------------------------------------------- /run_scripts/run_dngo_supervised.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for s in {1..500} 4 | do 5 | python search_methods/supervised_dngo.py --dim 16 --seed $s --init_size 16 --topk 5 --output_path saved_logs/bo 6 | done 7 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_arch2vec.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | #python search_methods/reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl --saved_arch2vec --emb_path arch2vec-nasbench101.pt 4 | for s in {1..500} 5 | do 6 | python search_methods/reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl --saved_arch2vec --emb_path arch2vec-model-nasbench101.pt 7 | done 8 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_arch2vec_darts.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 
| python search_methods/reinforce_darts.py --max_budgets 100 --inner_epochs 50 --objective 0.95 --train_portion 0.9 --dim 16 --seed 3 --bs 16 --output_path saved_logs/rl --saved_arch2vec --logging_path darts-rl 3 | 4 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --gamma 0.4 --baseline 0.4 \ 7 | --output_path saved_logs/rl --saved_arch2vec \ 8 | --dataset_name ImageNet16_120 --MAX_BUDGET 1400000 --model_path model-nasbench201.pt 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --MAX_BUDGET 500000 --baseline 0.4 --gamma 0.4 --saved_arch2vec \ 7 | --dataset_name cifar100 --output_path saved_logs/rl --model_path model-nasbench201.pt 8 | done 9 | done 10 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | for i in {16,} 3 | do 4 | for s in {1..500} 5 | do 6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --gamma 0.4 --baseline 0.4 \ 7 | --output_path saved_logs/rl --saved_arch2vec \ 8 | --dataset_name cifar10_valid_converged --MAX_BUDGET 12000 --model_path model-nasbench201.pt 9 | done 10 | done 11 | -------------------------------------------------------------------------------- /run_scripts/run_reinforce_supervised.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | for s in {1..500} 4 | do 5 | python search_methods/supervised_reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl 6 | done 7 | -------------------------------------------------------------------------------- /search_methods/dngo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | from pybnn.dngo import DNGO 5 | import argparse 6 | import json 7 | import torch 8 | import numpy as np 9 | from collections import defaultdict 10 | from torch.distributions import Normal 11 | 12 | 13 | def load_arch2vec(embedding_path): 14 | embedding = torch.load(embedding_path) 15 | print('load arch2vec from {}'.format(embedding_path)) 16 | ind_list = range(len(embedding)) 17 | features = [embedding[ind]['feature'] for ind in ind_list] 18 | valid_labels = [embedding[ind]['valid_accuracy'] for ind in ind_list] 19 | test_labels = [embedding[ind]['test_accuracy'] for ind in ind_list] 20 | training_time = [embedding[ind]['time'] for ind in ind_list] 21 | features = torch.stack(features, dim=0) 22 | test_labels = torch.Tensor(test_labels) 23 | valid_labels = torch.Tensor(valid_labels) 24 | training_time = torch.Tensor(training_time) 25 | print('loading finished. 
pretrained embeddings shape {}'.format(features.shape)) 26 | return features, valid_labels, test_labels, training_time 27 | 28 | 29 | def get_init_samples(features, valid_labels, test_labels, training_time, visited): 30 | np.random.seed(args.seed) 31 | init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size] 32 | init_inds = torch.Tensor(init_inds).long() 33 | init_feat_samples = features[init_inds] 34 | init_valid_label_samples = valid_labels[init_inds] 35 | init_test_label_samples = test_labels[init_inds] 36 | init_time_samples = training_time[init_inds] 37 | for idx in init_inds: 38 | visited[idx] = True 39 | return init_feat_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, visited 40 | 41 | 42 | def propose_location(ei, features, valid_labels, test_labels, training_time, visited): 43 | k = args.topk 44 | print('remaining length of indices set:', len(features) - len(visited)) 45 | indices = torch.argsort(ei)[-k:] 46 | ind_dedup = [] 47 | for idx in indices: 48 | if idx not in visited: 49 | visited[idx] = True 50 | ind_dedup.append(idx) 51 | ind_dedup = torch.Tensor(ind_dedup).long() 52 | proposed_x, proposed_y_valid, proposed_y_test, proposed_time = features[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup] 53 | return proposed_x, proposed_y_valid, proposed_y_test, proposed_time, visited 54 | 55 | 56 | def expected_improvement_search(): 57 | """ implementation of arch2vec-DNGO """ 58 | BEST_TEST_ACC = 0.943175752957662 59 | BEST_VALID_ACC = 0.9505542318026224 60 | CURR_BEST_VALID = 0. 61 | CURR_BEST_TEST = 0. 62 | MAX_BUDGET = 1.5e6 63 | window_size = 200 64 | counter = 0 65 | rt = 0. 66 | visited = {} 67 | best_trace = defaultdict(list) 68 | features, valid_labels, test_labels, training_time = load_arch2vec(os.path.join('pretrained/dim-{}'.format(args.dim), args.emb_path)) 69 | features, valid_labels, test_labels, training_time = features.cpu().detach(), valid_labels.cpu().detach(), test_labels.cpu().detach(), training_time.cpu().detach() 70 | feat_samples, valid_label_samples, test_label_samples, time_samples, visited = get_init_samples(features, valid_labels, test_labels, training_time, visited) 71 | 72 | for feat, acc_valid, acc_test, t in zip(feat_samples, valid_label_samples, test_label_samples, time_samples): 73 | counter += 1 74 | rt += t.item() 75 | if acc_valid > CURR_BEST_VALID: 76 | CURR_BEST_VALID = acc_valid 77 | CURR_BEST_TEST = acc_test 78 | best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID)) 79 | best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST)) 80 | best_trace['time'].append(rt) 81 | best_trace['counter'].append(counter) 82 | 83 | while rt < MAX_BUDGET: 84 | print("feat_samples:", feat_samples.shape) 85 | print("valid label_samples:", valid_label_samples.shape) 86 | print("test label samples:", test_label_samples.shape) 87 | print("current best validation: {}".format(CURR_BEST_VALID)) 88 | print("current best test: {}".format(CURR_BEST_TEST)) 89 | print("rt: {}".format(rt)) 90 | print(feat_samples.shape) 91 | print(valid_label_samples.shape) 92 | model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False, rng=args.seed) 93 | model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True) 94 | print(model.network) 95 | m = [] 96 | v = [] 97 | chunks = int(features.shape[0] / window_size) 98 | if features.shape[0] % window_size > 0: 99 | chunks += 1 100 | features_split = 
101 |         for i in range(chunks):
102 |             m_split, v_split = model.predict(features_split[i].numpy())
103 |             m.extend(list(m_split))
104 |             v.extend(list(v_split))
105 |         mean = torch.Tensor(m)
106 |         sigma = torch.Tensor(v)
107 |         u = (mean - torch.Tensor([0.95]).expand_as(mean)) / sigma  # u = (mu - 0.95) / sigma
108 |         normal = Normal(torch.zeros_like(u), torch.ones_like(u))
109 |         ucdf = normal.cdf(u)
110 |         updf = torch.exp(normal.log_prob(u))
111 |         ei = sigma * (updf + u * ucdf)  # expected improvement: EI = sigma * (phi(u) + u * Phi(u))
112 |         feat_next, label_next_valid, label_next_test, time_next, visited = propose_location(ei, features, valid_labels, test_labels, training_time, visited)
113 | 
114 |         # add proposed networks to the pool
115 |         for feat, acc_valid, acc_test, t in zip(feat_next, label_next_valid, label_next_test, time_next):
116 |             if acc_valid > CURR_BEST_VALID:
117 |                 CURR_BEST_VALID = acc_valid
118 |                 CURR_BEST_TEST = acc_test
119 |             feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0)
120 |             valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0)
121 |             test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0)
122 |             counter += 1
123 |             rt += t.item()
124 |             best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
125 |             best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
126 |             best_trace['time'].append(rt)
127 |             best_trace['counter'].append(counter)
128 |             if rt >= MAX_BUDGET:
129 |                 break
130 | 
131 |     res = dict()
132 |     res['regret_validation'] = best_trace['regret_validation']
133 |     res['regret_test'] = best_trace['regret_test']
134 |     res['runtime'] = best_trace['time']
135 |     res['counter'] = best_trace['counter']
136 |     save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
137 |     if not os.path.exists(save_path):
138 |         os.makedirs(save_path)
139 |     print('save to {}'.format(save_path))
140 |     # strip the '.pt' suffix for the log file name; fall back to the raw name
141 |     s = args.emb_path[:-3] if args.emb_path.endswith('.pt') else args.emb_path
142 |     fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, s)), 'w')
143 |     json.dump(res, fh)
144 |     fh.close()
145 | 
146 | 
147 | if __name__ == '__main__':
148 |     parser = argparse.ArgumentParser(description="arch2vec-DNGO")
149 |     parser.add_argument("--seed", type=int, default=1, help="random seed")
150 |     parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
151 |     parser.add_argument('--dim', type=int, default=16, help='feature dimension')
152 |     parser.add_argument('--init_size', type=int, default=16, help='init samples')
153 |     parser.add_argument('--topk', type=int, default=5, help='acquisition samples')
154 |     parser.add_argument('--output_path', type=str, default='bo', help='bo')
155 |     parser.add_argument('--emb_path', type=str, default='arch2vec.pt')
156 |     args = parser.parse_args()
157 |     np.random.seed(args.seed)
158 |     torch.manual_seed(args.seed)
159 |     torch.cuda.manual_seed_all(args.seed)
160 |     torch.set_num_threads(2)
161 |     expected_improvement_search()
162 | 
--------------------------------------------------------------------------------
/search_methods/dngo_darts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | from pybnn.dngo import DNGO
5 | import random
6 | import argparse
7 | import json
8 | import torch
9 | import numpy as np
10 | from collections import defaultdict
11 | from torch.distributions import Normal
12 | from darts.cnn.train_search import Train
13 | 
14 | def
load_arch2vec(embedding_path): 15 | embedding = torch.load(embedding_path) 16 | print('load arch2vec from {}'.format(embedding_path)) 17 | ind_list = range(len(embedding)) 18 | features = [embedding[ind]['feature'] for ind in ind_list] 19 | genotype = [embedding[ind]['genotype'] for ind in ind_list] 20 | features = torch.stack(features, dim=0) 21 | print('loading finished. pretrained embeddings shape {}'.format(features.shape)) 22 | return features, genotype 23 | 24 | 25 | def query(counter, seed, genotype, epochs): 26 | trainer = Train() 27 | rewards, rewards_test = trainer.main(counter, seed, genotype, epochs=epochs, train_portion=args.train_portion, save=args.logging_path) 28 | val_sum = 0 29 | for epoch, val_acc in rewards: 30 | val_sum += val_acc 31 | val_avg = val_sum / len(rewards) 32 | return val_avg / 100., rewards_test[-1][-1] / 100. 33 | 34 | def get_init_samples(features, genotype, visited): 35 | count = 0 36 | np.random.seed(args.seed) 37 | init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size] 38 | init_inds = torch.Tensor(init_inds).long() 39 | print('init index: {}'.format(init_inds)) 40 | init_feat_samples = features[init_inds] 41 | init_geno_samples = [genotype[i.item()] for i in init_inds] 42 | init_valid_label_samples = [] 43 | init_test_label_samples = [] 44 | 45 | for geno in init_geno_samples: 46 | val_acc, test_acc = query(count, args.seed, geno, args.inner_epochs) 47 | init_valid_label_samples.append(val_acc) 48 | init_test_label_samples.append(test_acc) 49 | count += 1 50 | 51 | init_valid_label_samples = torch.Tensor(init_valid_label_samples) 52 | init_test_label_samples = torch.Tensor(init_test_label_samples) 53 | for idx in init_inds: 54 | visited[idx.item()] = True 55 | return init_feat_samples, init_geno_samples, init_valid_label_samples, init_test_label_samples, visited 56 | 57 | 58 | def propose_location(ei, features, genotype, visited, counter): 59 | count = counter 60 | k = args.batch_size 61 | c = 0 62 | print('remaining length of indices set:', len(features) - len(visited)) 63 | indices = torch.argsort(ei) 64 | ind_dedup = [] 65 | # remove random sampled indices at each step 66 | for idx in reversed(indices): 67 | if c == k: 68 | break 69 | if idx.item() not in visited: 70 | visited[idx.item()] = True 71 | ind_dedup.append(idx.item()) 72 | c += 1 73 | ind_dedup = torch.Tensor(ind_dedup).long() 74 | print('proposed index: {}'.format(ind_dedup)) 75 | proposed_x = features[ind_dedup] 76 | proposed_geno = [genotype[i.item()] for i in ind_dedup] 77 | proposed_val_acc = [] 78 | proposed_test_acc = [] 79 | for geno in proposed_geno: 80 | val_acc, test_acc = query(count, args.seed, geno, args.inner_epochs) 81 | proposed_val_acc.append(val_acc) 82 | proposed_test_acc.append(test_acc) 83 | count += 1 84 | 85 | return proposed_x, proposed_geno, torch.Tensor(proposed_val_acc), torch.Tensor(proposed_test_acc), visited 86 | 87 | 88 | def expected_improvement_search(features, genotype): 89 | """ implementation of arch2vec-DNGO on DARTS Search Space """ 90 | CURR_BEST_VALID = 0. 91 | CURR_BEST_TEST = 0. 
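    # best-so-far trackers for the search trace; on the DARTS space,
    # MAX_BUDGET counts evaluated architectures (trials) rather than seconds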
92 | CURR_BEST_GENOTYPE = None 93 | MAX_BUDGET = args.max_budgets 94 | window_size = 200 95 | counter = 0 96 | visited = {} 97 | best_trace = defaultdict(list) 98 | 99 | features, genotype = features.cpu().detach(), genotype 100 | feat_samples, geno_samples, valid_label_samples, test_label_samples, visited = get_init_samples(features, genotype, visited) 101 | 102 | for feat, geno, acc_valid, acc_test in zip(feat_samples, geno_samples, valid_label_samples, test_label_samples): 103 | counter += 1 104 | if acc_valid > CURR_BEST_VALID: 105 | CURR_BEST_VALID = acc_valid 106 | CURR_BEST_TEST = acc_test 107 | CURR_BEST_GENOTYPE = geno 108 | best_trace['validation_acc'].append(float(CURR_BEST_VALID)) 109 | best_trace['test_acc'].append(float(CURR_BEST_TEST)) 110 | best_trace['genotype'].append(CURR_BEST_GENOTYPE) 111 | best_trace['counter'].append(counter) 112 | 113 | while counter < MAX_BUDGET: 114 | print("feat_samples:", feat_samples.shape) 115 | print("length of genotypes:", len(geno_samples)) 116 | print("valid label_samples:", valid_label_samples.shape) 117 | print("test label samples:", test_label_samples.shape) 118 | print("current best validation: {}".format(CURR_BEST_VALID)) 119 | print("current best test: {}".format(CURR_BEST_TEST)) 120 | print("counter: {}".format(counter)) 121 | print(feat_samples.shape) 122 | print(valid_label_samples.shape) 123 | model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False) 124 | model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True) 125 | print(model.network) 126 | m = [] 127 | v = [] 128 | chunks = int(features.shape[0] / window_size) 129 | if features.shape[0] % window_size > 0: 130 | chunks += 1 131 | features_split = torch.split(features, window_size, dim=0) 132 | for i in range(chunks): 133 | m_split, v_split = model.predict(features_split[i].numpy()) 134 | m.extend(list(m_split)) 135 | v.extend(list(v_split)) 136 | mean = torch.Tensor(m) 137 | sigma = torch.Tensor(v) 138 | u = (mean - torch.Tensor([args.objective]).expand_as(mean)) / sigma 139 | normal = Normal(torch.zeros_like(u), torch.ones_like(u)) 140 | ucdf = normal.cdf(u) 141 | updf = torch.exp(normal.log_prob(u)) 142 | ei = sigma * (updf + u * ucdf) 143 | feat_next, geno_next, label_next_valid, label_next_test, visited = propose_location(ei, features, genotype, visited, counter) 144 | 145 | # add proposed networks to the pool 146 | for feat, geno, acc_valid, acc_test in zip(feat_next, geno_next, label_next_valid, label_next_test): 147 | feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0) 148 | geno_samples.append(geno) 149 | valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0) 150 | test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0) 151 | counter += 1 152 | if acc_valid.item() > CURR_BEST_VALID: 153 | CURR_BEST_VALID = acc_valid.item() 154 | CURR_BEST_TEST = acc_test.item() 155 | CURR_BEST_GENOTYPE = geno 156 | 157 | best_trace['validation_acc'].append(float(CURR_BEST_VALID)) 158 | best_trace['test_acc'].append(float(CURR_BEST_TEST)) 159 | best_trace['genotype'].append(CURR_BEST_GENOTYPE) 160 | best_trace['counter'].append(counter) 161 | 162 | if counter >= MAX_BUDGET: 163 | break 164 | 165 | res = dict() 166 | res['validation_acc'] = best_trace['validation_acc'] 167 | res['test_acc'] = best_trace['test_acc'] 168 | res['genotype'] = best_trace['genotype'] 169 | res['counter'] = best_trace['counter'] 170 | save_path = 
os.path.join(args.output_path, 'dim{}'.format(args.dim)) 171 | if not os.path.exists(save_path): 172 | os.mkdir(save_path) 173 | print('save to {}'.format(save_path)) 174 | fh = open(os.path.join(save_path, 'run_{}_arch2vec_model_darts.json'.format(args.seed)), 'w') 175 | json.dump(res, fh) 176 | fh.close() 177 | 178 | 179 | if __name__ == '__main__': 180 | parser = argparse.ArgumentParser(description="arch2vec-DNGO") 181 | parser.add_argument("--seed", type=int, default=3, help="random seed") 182 | parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)') 183 | parser.add_argument('--dim', type=int, default=16, help='feature dimension') 184 | parser.add_argument('--objective', type=float, default=0.95, help='ei objective') 185 | parser.add_argument('--init_size', type=int, default=16, help='init samples') 186 | parser.add_argument('--batch_size', type=int, default=5, help='acquisition samples') 187 | parser.add_argument('--inner_epochs', type=int, default=50, help='inner loop epochs') 188 | parser.add_argument('--train_portion', type=float, default=0.9, help='inner loop train/val split') 189 | parser.add_argument('--max_budgets', type=int, default=100, help='max number of trials') 190 | parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='bo') 191 | parser.add_argument('--logging_path', type=str, default='', help='search logging path') 192 | args = parser.parse_args() 193 | torch.manual_seed(args.seed) 194 | embedding_path = 'pretrained/dim-{}/arch2vec-darts.pt'.format(args.dim) 195 | if not os.path.exists(embedding_path): 196 | exit() 197 | features, genotype = load_arch2vec(embedding_path) 198 | expected_improvement_search(features, genotype) 199 | -------------------------------------------------------------------------------- /search_methods/dngo_search_NB201_8x8.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.insert(0, os.getcwd()) 4 | from pybnn.dngo import DNGO 5 | import random 6 | import argparse 7 | import json 8 | import torch 9 | import numpy as np 10 | from collections import defaultdict 11 | from torch.distributions import Normal 12 | import time 13 | 14 | 15 | def load_arch2vec(embedding_path): 16 | embedding = torch.load(embedding_path) 17 | print('load pretrained arch2vec from {}'.format(embedding_path)) 18 | random.seed(args.seed) 19 | random.shuffle(embedding) 20 | features = [embedding[ind]['feature'] for ind in range(len(embedding))] 21 | valid_labels = [embedding[ind]['valid_accuracy']/100.0 for ind in range(len(embedding))] 22 | test_labels = [embedding[ind]['test_accuracy']/100.0 for ind in range(len(embedding))] 23 | training_time = [embedding[ind]['time'] for ind in range(len(embedding))] 24 | other_info = [embedding[ind]['other_info'] for ind in range(len(embedding))] 25 | features = torch.stack(features, dim=0) 26 | valid_labels = torch.Tensor(valid_labels) 27 | test_labels = torch.Tensor(test_labels) 28 | training_time = torch.Tensor(training_time) 29 | print('loading finished. 
pretrained embeddings shape {}, and valid labels shape {}, and test labels shape {}'.format(features.shape, valid_labels.shape, test_labels.shape))
30 |     return features, valid_labels, test_labels, training_time, other_info
31 | 
32 | 
33 | def get_init_samples(features, valid_labels, test_labels, training_time, other_info, visited):
34 |     np.random.seed(args.seed)
35 |     init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size]
36 |     init_inds = torch.Tensor(init_inds).long()
37 |     init_feat_samples = features[init_inds]
38 |     init_valid_label_samples = valid_labels[init_inds]
39 |     init_test_label_samples = test_labels[init_inds]
40 |     init_time_samples = training_time[init_inds]
41 |     print('='*20, init_inds)
42 |     init_other_info_samples = [other_info[k.item()] for k in init_inds]
43 |     for idx in init_inds:
44 |         visited[idx.item()] = True  # key on plain ints so the membership test in propose_location works
45 |     return init_feat_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, init_other_info_samples, visited
46 | 
47 | 
48 | def propose_location(ei, features, valid_labels, test_labels, training_time, other_info, visited):
49 |     k = args.batch_size
50 |     print('remaining length of indices set:', len(features) - len(visited))
51 |     indices = torch.argsort(ei)[-k:]
52 |     ind_dedup = []
53 |     # remove already-visited indices at each step
54 |     for idx in indices:
55 |         if idx.item() not in visited:
56 |             visited[idx.item()] = True
57 |             ind_dedup.append(idx.item())
58 |     ind_dedup = torch.Tensor(ind_dedup).long()
59 |     proposed_x, proposed_y_valid, proposed_y_test, proposed_time, propose_info = features[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup], [other_info[k.item()] for k in ind_dedup]
60 |     return proposed_x, proposed_y_valid, proposed_y_test, proposed_time, propose_info, visited
61 | 
62 | 
63 | def expected_improvement_search(features, valid_labels, test_labels, training_time, other_info):
64 |     """ implementation of expected improvement search given arch2vec.
65 |     :param features: pretrained arch2vec embeddings; the remaining arguments are the matching NAS-Bench-201 statistics.
66 |     :return: None; the best trace is saved to disk.
67 |     """
68 |     CURR_BEST_VALID = 0.
69 |     CURR_BEST_TEST = 0.
70 |     CURR_BEST_INFO = None
71 |     MAX_BUDGET = args.MAX_BUDGET
72 |     window_size = 200
73 |     counter = 0
74 |     rt = 0.
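    # rt accumulates the benchmark-reported training time (in seconds) of
    # every queried architecture and is checked against args.MAX_BUDGET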
75 |     visited = {}
76 |     best_trace = defaultdict(list)
77 | 
78 |     features, valid_labels, test_labels, training_time = features.cpu().detach(), valid_labels.cpu().detach(), test_labels.cpu().detach(), training_time.cpu().detach()
79 |     feat_samples, valid_label_samples, test_label_samples, time_samples, other_info_sampled, visited = get_init_samples(features, valid_labels, test_labels, training_time, other_info, visited)
80 | 
81 |     t_start = time.time()
82 |     for feat, acc_valid, acc_test, t, o_info in zip(feat_samples, valid_label_samples, test_label_samples, time_samples, other_info_sampled):
83 |         counter += 1
84 |         rt += t.item()
85 |         if acc_valid > CURR_BEST_VALID:
86 |             CURR_BEST_VALID = acc_valid
87 |             CURR_BEST_TEST = acc_test
88 |             CURR_BEST_INFO = o_info
89 |         best_trace['validation'].append(float(CURR_BEST_VALID))
90 |         best_trace['test'].append(float(CURR_BEST_TEST))
91 |         best_trace['time'].append(time.time() - t_start)
92 |         best_trace['counter'].append(counter)
93 | 
94 |     while rt < MAX_BUDGET:
95 |         print("feat_samples:", feat_samples.shape)
96 |         print("valid label_samples:", valid_label_samples.shape)
97 |         print("test label samples:", test_label_samples.shape)
98 |         print("current best validation: {}".format(CURR_BEST_VALID))
99 |         print("current best test: {}".format(CURR_BEST_TEST))
100 |         print("rt: {}".format(rt))
101 |         print(feat_samples.shape)
102 |         print(valid_label_samples.shape)
103 |         model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False)
104 |         model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True)
105 |         print(model.network)
106 |         m = []
107 |         v = []
108 |         chunks = int(features.shape[0] / window_size)
109 |         if features.shape[0] % window_size > 0:
110 |             chunks += 1
111 |         features_split = torch.split(features, window_size, dim=0)
112 |         for i in range(chunks):
113 |             m_split, v_split = model.predict(features_split[i].numpy())
114 |             m.extend(list(m_split))
115 |             v.extend(list(v_split))
116 |         mean = torch.Tensor(m)
117 |         sigma = torch.Tensor(v)
118 |         u = (mean - torch.Tensor([1.0]).expand_as(mean)) / sigma
119 |         normal = Normal(torch.zeros_like(u), torch.ones_like(u))
120 |         ucdf = normal.cdf(u)
121 |         updf = torch.exp(normal.log_prob(u))
122 |         ei = sigma * (updf + u * ucdf)
123 |         feat_next, label_next_valid, label_next_test, time_next, info_next, visited = propose_location(ei, features, valid_labels, test_labels, training_time, other_info, visited)
124 | 
125 |         # add proposed networks to selected networks
126 |         for feat, acc_valid, acc_test, t, o_info in zip(feat_next, label_next_valid, label_next_test, time_next, info_next):
127 |             feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0)
128 |             valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0)
129 |             test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0)
130 |             counter += 1
131 |             rt += t.item()
132 |             if acc_valid > CURR_BEST_VALID:
133 |                 CURR_BEST_VALID = acc_valid
134 |                 CURR_BEST_TEST = acc_test
135 |                 CURR_BEST_INFO = o_info
136 | 
137 |             best_trace['validation'].append(float(CURR_BEST_VALID))
138 |             best_trace['test'].append(float(CURR_BEST_TEST))
139 |             best_trace['time'].append(time.time() - t_start)  # the actual searching time
140 |             best_trace['counter'].append(counter)
141 | 
142 |             if rt >= MAX_BUDGET:
143 |                 break
144 | 
145 |     res = dict()  # mirror the trace keys used above (this search logs best accuracies, not regrets)
146 |     res['validation'] = best_trace['validation']
147 |     res['test'] = best_trace['test']
148 |     res['runtime'] = best_trace['time']
149 |     res['counter'] = best_trace['counter']
150 |     save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
151 |     if not os.path.exists(save_path):
152 |         os.makedirs(save_path)
153 |     print('save to {}'.format(save_path))
154 |     print('Current Best Valid {}, Test {}'.format(CURR_BEST_VALID, CURR_BEST_TEST))
155 |     data_dict = {'val_acc': float(CURR_BEST_VALID), 'test_acc': float(CURR_BEST_TEST),
156 |                  'val_acc_avg': float(CURR_BEST_INFO['valid_accuracy_avg']),
157 |                  'test_acc_avg': float(CURR_BEST_INFO['test_accuracy_avg'])}
158 |     save_dir = os.path.join(save_path, 'nasbench201_{}_run_{}_full.json'.format(args.dataset_name, args.seed))
159 |     with open(save_dir, 'w') as f:
160 |         json.dump(data_dict, f)
161 | 
162 | 
163 | if __name__ == '__main__':
164 |     parser = argparse.ArgumentParser(description="DNGO search for NB201")
165 |     parser.add_argument("--gamma", type=float, default=0, help="discount factor (default: 0)")
166 |     parser.add_argument("--seed", type=int, default=1, help="random seed")
167 |     parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
168 |     parser.add_argument('--dim', type=int, default=16, help='feature dimension')
169 |     parser.add_argument('--init_size', type=int, default=16, help='init samples')
170 |     parser.add_argument('--batch_size', type=int, default=1, help='acquisition samples')
171 |     parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='rl/gd/predictor/bo (default: bo)')
172 |     parser.add_argument('--saved_arch2vec', action="store_true", default=True)
173 | 
174 |     parser.add_argument('--dataset_name', type=str, default='ImageNet16_120',
175 |                         help='Select from | cifar100 | ImageNet16_120 | cifar10_valid | cifar10_valid_converged')
176 |     parser.add_argument('--MAX_BUDGET', type=float, default=1200000, help='The budget in seconds')
177 | 
178 |     args = parser.parse_args()
179 |     # reproducibility
180 |     np.random.seed(args.seed)
181 |     torch.manual_seed(args.seed)
182 |     torch.cuda.manual_seed_all(args.seed)
183 |     embedding_path = 'pretrained/dim-{}/{}-arch2vec.pt'.format(args.dim, args.dataset_name)
184 |     if not os.path.exists(embedding_path):
185 |         sys.exit('pretrained embedding not found: {}'.format(embedding_path))
186 |     features, valid_labels, test_labels, training_time, other_info = load_arch2vec(embedding_path)
187 |     expected_improvement_search(features, valid_labels, test_labels, training_time, other_info)
188 | 
--------------------------------------------------------------------------------
/search_methods/reinforce.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import numpy as np
5 | import argparse
6 | import json
7 | import random
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from models.pretraining_nasbench101 import configs
13 | from utils.utils import load_json, preprocessing
14 | from models.model import Model
15 | from torch.distributions import MultivariateNormal
16 | 
17 | class Env(object):
18 |     def __init__(self, name, seed, emb_path, model_path, cfg, data_path=None, save=False):
19 |         self.name = name
20 |         self.model_path = model_path
21 |         self.emb_path = emb_path
22 |         self.seed = seed
23 |         self.dir_name = 'pretrained/dim-{}'.format(args.dim)
24 |         self.visited = {}
25 |         self.features = []
26 |         self.embedding = {}
27 |         self._reset(data_path, save)
28 | 
29 |     def _reset(self, data_path, save):
30 |         if not save:
31 |             print("extract arch2vec from {}".format(os.path.join(self.dir_name, self.model_path)))
from {}".format(os.path.join(self.dir_name, self.model_path))) 32 | if not os.path.exists(os.path.join(self.dir_name, self.model_path)): 33 | exit() 34 | dataset = load_json(data_path) 35 | self.model = Model(input_dim=5, hidden_dim=128, latent_dim=16, num_hops=5, num_mlp_layers=2, dropout=0, **cfg['GAE']).cuda() 36 | self.model.load_state_dict(torch.load(os.path.join(self.dir_name, self.model_path).format(args.dim))['model_state']) 37 | self.model.eval() 38 | with torch.no_grad(): 39 | print("length of the dataset: {}".format(len(dataset))) 40 | self.f_path = os.path.join(self.dir_name, 'arch2vec-{}'.format(self.model_path)) 41 | if os.path.exists(self.f_path): 42 | print('{} is already saved'.format(self.f_path)) 43 | exit() 44 | print('save to {}'.format(self.f_path)) 45 | for ind in range(len(dataset)): 46 | adj = torch.Tensor(dataset[str(ind)]['module_adjacency']).unsqueeze(0).cuda() 47 | ops = torch.Tensor(dataset[str(ind)]['module_operations']).unsqueeze(0).cuda() 48 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep']) 49 | test_acc = dataset[str(ind)]['test_accuracy'] 50 | valid_acc = dataset[str(ind)]['validation_accuracy'] 51 | time = dataset[str(ind)]['training_time'] 52 | x,_ = self.model._encoder(ops, adj) 53 | self.embedding[ind] = {'feature': x.squeeze(0).mean(dim=0).cpu(), 'valid_accuracy': float(valid_acc), 'test_accuracy': float(test_acc), 'time': float(time)} 54 | torch.save(self.embedding, self.f_path) 55 | print("finish arch2vec extraction") 56 | exit() 57 | else: 58 | self.f_path = os.path.join(self.dir_name, self.emb_path) 59 | print("load arch2vec from: {}".format(self.f_path)) 60 | self.embedding = torch.load(self.f_path) 61 | for ind in range(len(self.embedding)): 62 | self.features.append(self.embedding[ind]['feature']) 63 | self.features = torch.stack(self.features, dim=0) 64 | print('loading finished. pretrained embeddings shape: {}'.format(self.features.shape)) 65 | 66 | def get_init_state(self): 67 | """ 68 | :return: 1 x dim 69 | """ 70 | random.seed(args.seed) 71 | rand_indices = random.randint(0, self.features.shape[0]) 72 | self.visited[rand_indices] = True 73 | return self.features[rand_indices], self.embedding[rand_indices]['valid_accuracy'],\ 74 | self.embedding[rand_indices]['test_accuracy'], self.embedding[rand_indices]['time'] 75 | 76 | def step(self, action): 77 | """ 78 | action: 1 x dim 79 | self.features. 
80 |         """
81 |         dist = torch.norm(self.features - action.cpu(), dim=1)
82 |         knn = (-1 * dist).topk(dist.shape[0])  # rank all candidates by ascending distance to the action
83 |         min_dist, min_idx = knn.values, knn.indices
84 |         count = 0
85 |         while True:
86 |             if len(self.visited) == dist.shape[0]:
87 |                 print("cannot find in the dataset")
88 |                 exit()
89 |             if min_idx[count].item() not in self.visited:
90 |                 self.visited[min_idx[count].item()] = True
91 |                 break
92 |             count += 1
93 | 
94 |         return self.features[min_idx[count].item()], self.embedding[min_idx[count].item()]['valid_accuracy'], \
95 |                self.embedding[min_idx[count].item()]['test_accuracy'], self.embedding[min_idx[count].item()]['time']
96 | 
97 | 
98 | class Policy(nn.Module):
99 |     def __init__(self, hidden_dim1, hidden_dim2):
100 |         super(Policy, self).__init__()
101 |         self.fc1 = nn.Linear(hidden_dim1, hidden_dim2)
102 |         self.fc2 = nn.Linear(hidden_dim2, hidden_dim1)
103 |         self.saved_log_probs = []
104 |         self.rewards = []
105 | 
106 |     def forward(self, input):
107 |         x = F.relu(self.fc1(input))
108 |         out = self.fc2(x)
109 |         return out
110 | 
111 | class Policy_LSTM(nn.Module):
112 |     def __init__(self, hidden_dim1, hidden_dim2):
113 |         super(Policy_LSTM, self).__init__()
114 |         self.lstm = torch.nn.LSTMCell(input_size=hidden_dim1, hidden_size=hidden_dim2)
115 |         self.fc = nn.Linear(hidden_dim2, hidden_dim1)
116 |         self.saved_log_probs = []
117 |         self.rewards = []
118 |         self.hx = None
119 |         self.cx = None
120 | 
121 |     def forward(self, input):
122 |         if self.hx is None and self.cx is None:
123 |             self.hx, self.cx = self.lstm(input)
124 |         else:
125 |             self.hx, self.cx = self.lstm(input, (self.hx, self.cx))
126 |         mean = self.fc(self.hx)
127 |         return mean
128 | 
129 | def select_action(state, policy):
130 |     """
131 |     MVN based action selection.
132 |     :param state: 1 x dim
133 |     :param policy: policy network
134 |     :return: action: 1 x dim
135 |     """
136 |     mean = policy(state.view(1, state.shape[0]))
137 |     mvn = MultivariateNormal(mean, torch.eye(state.shape[0]).cuda())
138 |     action = mvn.sample()
139 |     policy.saved_log_probs.append(torch.mean(mvn.log_prob(action)))
140 |     return action
141 | 
142 | 
143 | def finish_episode(policy, optimizer):
144 |     policy_loss = []
145 |     # REINFORCE update with a constant baseline: each step's advantage is
146 |     # its raw, undiscounted validation-accuracy reward minus a fixed
147 |     # baseline of 0.95; each saved log-probability is weighted by the
148 |     # negative advantage below.
149 |     returns = torch.Tensor(policy.rewards)
150 |     returns = returns - 0.95
151 | 
152 |     for log_prob, R in zip(policy.saved_log_probs, returns):
153 |         policy_loss.append(-log_prob * R)
154 | 
155 |     optimizer.zero_grad()
156 |     policy_loss = torch.mean(torch.stack(policy_loss, dim=0))
157 |     print("average reward: {}, policy loss: {}".format(sum(policy.rewards)/len(policy.rewards), policy_loss.item()))
158 |     policy_loss.backward()
159 |     optimizer.step()
160 |     del policy.rewards[:]
161 |     del policy.saved_log_probs[:]
162 |     policy.hx = None
163 |     policy.cx = None
164 | 
165 | 
166 | def reinforce_search(env, args):
167 |     """ implementation of arch2vec-REINFORCE """
168 |     policy = Policy_LSTM(args.dim, 128).cuda()
169 |     optimizer = optim.Adam(policy.parameters(), lr=1e-2)
170 |     counter = 0
171 |     BEST_VALID_ACC = 0.9505542318026224
172 |     BEST_TEST_ACC = 0.943175752957662
173 |     MAX_BUDGET = 1.5e6
174 |     rt = 0
175 |     state, _, _, time = env.get_init_state()
176 |     CURR_BEST_VALID = 0
177 |     CURR_BEST_TEST = 0
178 |     test_trace = []
179 |     valid_trace = []
180 |     time_trace = []
181 |     while rt < MAX_BUDGET:
182 |         for c in range(args.bs):
183 |             state = state.cuda()
184 |             action = select_action(state, policy)
185 |             state, reward, reward_test, time = env.step(action)
186 |             policy.rewards.append(reward)
187 |             counter += 1
188 |             rt += time
189 |             print('counter: {}, validation reward: {}, test reward: {}, time: {}'.format(counter, reward, reward_test, rt))
190 | 
191 |             if reward > CURR_BEST_VALID:
192 |                 CURR_BEST_VALID = reward
193 |                 CURR_BEST_TEST = reward_test
194 | 
195 |             valid_trace.append(float(BEST_VALID_ACC - CURR_BEST_VALID))
196 |             test_trace.append(float(BEST_TEST_ACC - CURR_BEST_TEST))
197 |             time_trace.append(rt)
198 | 
199 |             if rt >= MAX_BUDGET:
200 |                 break
201 | 
202 |         finish_episode(policy, optimizer)
203 | 
204 |     res = dict()
205 |     res['regret_validation'] = valid_trace
206 |     res['regret_test'] = test_trace
207 |     res['runtime'] = time_trace
208 |     save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
209 |     if not os.path.exists(save_path):
210 |         os.makedirs(save_path)
211 |     print('save to {}'.format(save_path))
212 |     # strip the '.pt' suffix for the log file name; fall back to the raw name
213 |     s = args.emb_path[:-3] if args.emb_path.endswith('.pt') else args.emb_path
214 |     fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, s)), 'w')
215 |     json.dump(res, fh)
216 |     fh.close()
217 | 
218 | 
219 | 
220 | if __name__ == '__main__':
221 |     parser = argparse.ArgumentParser(description="arch2vec-REINFORCE")
222 |     parser.add_argument("--gamma", type=float, default=0, help="discount factor (default: 0)")
223 |     parser.add_argument("--seed", type=int, default=1, help="random seed")
224 |     parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
225 |     parser.add_argument('--bs', type=int, default=16, help='batch size')
226 |     parser.add_argument('--dim', type=int, default=7, help='feature dimension')
227 |     parser.add_argument('--output_path', type=str, default='rl', help='rl/bo')
228 |     parser.add_argument('--emb_path', type=str, default='arch2vec.pt')
229 |     parser.add_argument('--model_path', type=str, default='model-nasbench-101.pt')
230 |     parser.add_argument('--saved_arch2vec', action="store_true", default=False)
231 |     args = parser.parse_args()
232 |     cfg = configs[args.cfg]
233 |     env = Env('REINFORCE', args.seed, args.emb_path, args.model_path, cfg, data_path='data/data.json', save=args.saved_arch2vec)
234 |     np.random.seed(args.seed)
235 |     torch.manual_seed(args.seed)
236 |     torch.cuda.manual_seed_all(args.seed)
237 |     torch.set_num_threads(2)
238 |     reinforce_search(env, args)
239 | 
--------------------------------------------------------------------------------
/search_methods/supervised_dngo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import sys
5 | sys.path.insert(0, os.getcwd())
6 | from pybnn.dngo_supervised import DNGO
7 | import json
8 | import argparse
9 | from collections import defaultdict
10 | from torch.distributions import Normal
11 | 
12 | def extract_data(dataset):
13 |     with open(dataset) as f:
14 |         data = json.load(f)
15 |     X_adj = [torch.Tensor(data[str(ind)]['module_adjacency']) for ind in range(len(data))]
16 |     X_ops = [torch.Tensor(data[str(ind)]['module_operations']) for ind in range(len(data))]
17 |     Y = [data[str(ind)]['validation_accuracy'] for ind in range(len(data))]
18 |     Y_test = [data[str(ind)]['test_accuracy'] for ind in range(len(data))]
19 |     training_time = [data[str(ind)]['training_time'] for ind in range(len(data))]
20 |     X_adj = torch.stack(X_adj, dim=0)
21 |     X_ops = torch.stack(X_ops, dim=0)
22 |     Y = torch.Tensor(Y)
23 |     Y_test = torch.Tensor(Y_test)
24 |     training_time = torch.Tensor(training_time)
25 |     rand_ind = torch.randperm(X_ops.shape[0])
26 |     X_adj =
X_adj[rand_ind] 27 | X_ops = X_ops[rand_ind] 28 | Y = Y[rand_ind] 29 | Y_test = Y_test[rand_ind] 30 | training_time = training_time[rand_ind] 31 | print('loading finished. input adj shape {}, input ops shape {} and valid labels shape {}, and test labels shape {}'.format(X_adj.shape, X_ops.shape, Y.shape, Y_test.shape)) 32 | return X_adj, X_ops, Y, Y_test, training_time 33 | 34 | def get_init_samples(X_adj, X_ops, Y, Y_test, training_time, visited): 35 | np.random.seed(args.seed) 36 | init_inds = np.random.permutation(list(range(X_ops.shape[0])))[:args.init_size] 37 | init_inds = torch.Tensor(init_inds).long() 38 | init_x_adj_samples = X_adj[init_inds] 39 | init_x_ops_samples = X_ops[init_inds] 40 | init_valid_label_samples = Y[init_inds] 41 | init_test_label_samples = Y_test[init_inds] 42 | init_time_samples = training_time[init_inds] 43 | for idx in init_inds: 44 | visited[idx.item()] = True 45 | return init_x_adj_samples, init_x_ops_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, visited 46 | 47 | 48 | def propose_location(ei, X_adj, X_ops, valid_labels, test_labels, training_time, visited): 49 | k = args.topk 50 | count = 0 51 | print('remaining length of indices set:', len(X_adj) - len(visited)) 52 | indices = torch.argsort(ei) 53 | ind_dedup = [] 54 | # remove random sampled indices at each step 55 | for idx in reversed(indices): 56 | if count == k: 57 | break 58 | if idx.item() not in visited: 59 | visited[idx.item()] = True 60 | ind_dedup.append(idx.item()) 61 | count += 1 62 | ind_dedup = torch.Tensor(ind_dedup).long() 63 | proposed_x_adj, proposed_x_ops, proposed_y_valid, proposed_y_test, proposed_time = X_adj[ind_dedup], X_ops[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup] 64 | return proposed_x_adj, proposed_x_ops, proposed_y_valid, proposed_y_test, proposed_time, visited 65 | 66 | 67 | def supervised_encoding_search(X_adj, X_ops, Y, Y_test, training_time): 68 | """implementation of supervised learning based BO search""" 69 | BEST_TEST_ACC = 0.943175752957662 70 | BEST_VALID_ACC = 0.9505542318026224 71 | CURR_BEST_VALID = 0. 72 | CURR_BEST_TEST = 0. 73 | MAX_BUDGET = 1.5e6 74 | counter = 0 75 | rt = 0. 
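    # rt accumulates NAS-Bench-101 training time (in seconds); the search
    # stops once the 1.5e6-second budget is exhausted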
76 |     best_trace = defaultdict(list)
77 |     window_size = 512
78 |     visited = {}
79 |     X_adj_sample, X_ops_sample, Y_sample, Y_sample_test, time_sample, visited = get_init_samples(X_adj, X_ops, Y, Y_test, training_time, visited)
80 | 
81 |     for x_adj, x_ops, acc_valid, acc_test, t in zip(X_adj_sample, X_ops_sample, Y_sample, Y_sample_test, time_sample):
82 |         counter += 1
83 |         rt += t.item()
84 |         if acc_valid > CURR_BEST_VALID:
85 |             CURR_BEST_VALID = acc_valid
86 |             CURR_BEST_TEST = acc_test
87 |         best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
88 |         best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
89 |         best_trace['time'].append(rt)
90 |         best_trace['counter'].append(counter)
91 | 
92 |     while rt < MAX_BUDGET:
93 |         print("data adjacency matrix samples:", X_adj_sample.shape)
94 |         print("data operations matrix samples:", X_ops_sample.shape)
95 |         print("valid label_samples:", Y_sample.shape)
96 |         print("test label samples:", Y_sample_test.shape)
97 |         print("current best validation: {}".format(CURR_BEST_VALID))
98 |         print("current best test: {}".format(CURR_BEST_TEST))
99 |         print("rt: {}".format(rt))
100 |         model = DNGO(num_epochs=100, input_dim=5, hidden_dim=128, latent_dim=args.dim, num_hops=5, num_mlp_layers=2, do_mcmc=False, normalize_output=False)
101 |         model.train(X_adj_sample.numpy(), X_ops_sample.numpy(), Y_sample.view(-1).numpy(), do_optimize=True)
102 |         m = []
103 |         v = []
104 |         chunks = int(X_adj.shape[0] / window_size)
105 |         if X_adj.shape[0] % window_size > 0:
106 |             chunks += 1
107 |         X_adj_split = torch.split(X_adj, window_size, dim=0)
108 |         X_ops_split = torch.split(X_ops, window_size, dim=0)
109 |         for i in range(chunks):
110 |             inputs_adj = X_adj_split[i]
111 |             inputs_ops = X_ops_split[i]
112 |             m_split, v_split = model.predict(inputs_ops.numpy(), inputs_adj.numpy())
113 |             m.extend(list(m_split))
114 |             v.extend(list(v_split))
115 |         mean = torch.Tensor(m)
116 |         sigma = torch.Tensor(v)
117 |         u = (mean - torch.Tensor([0.95]).expand_as(mean)) / sigma  # u = (mu - target) / sigma
118 |         normal = Normal(torch.zeros_like(u), torch.ones_like(u))
119 |         ucdf = normal.cdf(u)
120 |         updf = torch.exp(normal.log_prob(u))
121 |         ei = sigma * (updf + u * ucdf)  # expected improvement: EI = sigma * (phi(u) + u * Phi(u))
122 | 
123 |         X_adj_next, X_ops_next, label_next_valid, label_next_test, time_next, visited = propose_location(ei, X_adj, X_ops, Y, Y_test, training_time, visited)
124 | 
125 |         # add proposed networks to selected networks
126 |         for x_adj, x_ops, acc_valid, acc_test, t in zip(X_adj_next, X_ops_next, label_next_valid, label_next_test, time_next):
127 |             X_adj_sample = torch.cat((X_adj_sample, x_adj.view(1, 7, 7)), dim=0)
128 |             X_ops_sample = torch.cat((X_ops_sample, x_ops.view(1, 7, 5)), dim=0)
129 |             Y_sample = torch.cat((Y_sample.view(-1, 1), acc_valid.view(1, 1)), dim=0)
130 |             Y_sample_test = torch.cat((Y_sample_test.view(-1, 1), acc_test.view(1, 1)), dim=0)
131 |             counter += 1
132 |             rt += t.item()
133 |             if acc_valid > CURR_BEST_VALID:
134 |                 CURR_BEST_VALID = acc_valid
135 |                 CURR_BEST_TEST = acc_test
136 | 
137 |             best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
138 |             best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
139 |             best_trace['time'].append(rt)
140 |             best_trace['counter'].append(counter)
141 | 
142 |             if rt >= MAX_BUDGET:
143 |                 break
144 | 
145 |     res = dict()
146 |     res['regret_validation'] = best_trace['regret_validation']
147 |     res['regret_test'] = best_trace['regret_test']
148 |     res['runtime'] = best_trace['time']
149 |     res['counter'] = best_trace['counter']
150 |     save_path =
os.path.join(args.output_path, 'dim{}'.format(args.dim)) 151 | if not os.path.exists(save_path): 152 | os.mkdir(save_path) 153 | print('save to {}'.format(save_path)) 154 | fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, args.benchmark)), 'w') 155 | json.dump(res, fh) 156 | fh.close() 157 | 158 | 159 | 160 | if __name__ == '__main__': 161 | parser = argparse.ArgumentParser(description="Supervised DNGO search") 162 | parser.add_argument("--seed", type=int, default=1, help="random seed") 163 | parser.add_argument('--dim', type=int, default=16, help='feature dimension') 164 | parser.add_argument('--init_size', type=int, default=16, help='init samples') 165 | parser.add_argument('--topk', type=int, default=5, help='acquisition samples') 166 | parser.add_argument('--benchmark', type=str, default='supervised_dngo') 167 | parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='rl/bo (default: bo)') 168 | args = parser.parse_args() 169 | torch.manual_seed(args.seed) 170 | data_path = 'data/data.json' 171 | X_adj, X_ops, Y, Y_test, training_time = extract_data(data_path) 172 | supervised_encoding_search(X_adj, X_ops, Y, Y_test, training_time) 173 | --------------------------------------------------------------------------------