├── LICENSE
├── README.md
├── darts
│   └── cnn
│       ├── architect.py
│       ├── eval-EXP
│       │   └── log.txt
│       ├── genotypes.py
│       ├── model.py
│       ├── operations.py
│       ├── test.py
│       ├── test_imagenet.py
│       ├── train.py
│       ├── train_imagenet.py
│       ├── train_search.py
│       ├── utils.py
│       └── visualize.py
├── docs
│   └── arch.png
├── gin
│   └── models
│       ├── graphcnn.py
│       └── mlp.py
├── models
│   ├── configs.py
│   ├── layers.py
│   ├── model.py
│   ├── pretraining_darts.py
│   ├── pretraining_darts.sh
│   ├── pretraining_nasbench101.py
│   ├── pretraining_nasbench101.sh
│   ├── pretraining_nasbench201.py
│   └── pretraining_nasbench201.sh
├── plot_scripts
│   ├── distance_comparison_fig3.py
│   ├── draw_darts.py
│   ├── drawfig4.sh
│   ├── drawfig5-darts.sh
│   ├── drawfig5-nas101.sh
│   ├── drawfig5-nas201.sh
│   ├── nas201.jpg
│   ├── pearson_plot_fig2.py
│   ├── plot_cdf.py
│   ├── plot_dngo_search_arch2vec.py
│   ├── plot_nasbench101_comparison.py
│   ├── plot_reinforce_search_arch2vec.py
│   ├── summarize_nasbench201.py
│   ├── try_networkx.py
│   ├── visdensity.py
│   └── visgraph.py
├── preprocessing
│   ├── api.py
│   ├── gen_isomorphism_graphs.py
│   ├── gen_json.py
│   └── nasbench201_json.py
├── pybnn
│   ├── __init__.py
│   ├── base_model.py
│   ├── bayesian_linear_regression.py
│   ├── dngo.py
│   ├── dngo_supervised.py
│   └── util
│       ├── __init__.py
│       └── normalization.py
├── requirements.txt
├── results
│   ├── BO-arch2vec-model-nasbench-101.json
│   ├── BO-supervised-nasbench-101.json
│   ├── BOHB-Search-Encoding-A.json
│   ├── RL-arch2vec-model-nasbench-101.json
│   ├── RL-supervised-nasbench-101.json
│   ├── Random-Search-Encoding-A.json
│   ├── Regularized-Evolution-Encoding-A.json
│   └── Reinforce-Search-Encoding-A.json
├── run_scripts
│   ├── extract_arch2vec.sh
│   ├── extract_arch2vec_darts.sh
│   ├── extract_arch2vec_nasbench201.sh
│   ├── run_bo_arch2vec_darts.sh
│   ├── run_bo_arch2vec_nasbench201_ImageNet.sh
│   ├── run_bo_arch2vec_nasbench201_cifar100.sh
│   ├── run_bo_arch2vec_nasbench201_cifar10_valid.sh
│   ├── run_dngo_arch2vec.sh
│   ├── run_dngo_supervised.sh
│   ├── run_reinforce_arch2vec.sh
│   ├── run_reinforce_arch2vec_darts.sh
│   ├── run_reinforce_arch2vec_nasbench201_ImageNet.sh
│   ├── run_reinforce_arch2vec_nasbench201_cifar100.sh
│   ├── run_reinforce_arch2vec_nasbench201_cifar10_valid.sh
│   └── run_reinforce_supervised.sh
├── search_methods
│   ├── dngo.py
│   ├── dngo_darts.py
│   ├── dngo_search_NB201_8x8.py
│   ├── reinforce.py
│   ├── reinforce_darts.py
│   ├── reinforce_search_NB201_8x8.py
│   ├── supervised_dngo.py
│   └── supervised_reinforce.py
└── utils
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?
2 | Code for paper:
3 | > [Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?](https://arxiv.org/abs/2006.06936)\
4 | > Shen Yan, Yu Zheng, Wei Ao, Xiao Zeng, Mi Zhang.\
5 | > _NeurIPS 2020_.
6 |
7 |
8 |
9 | Top: The supervision signal for representation learning comes from the accuracies of architectures selected by the search strategies. Bottom (ours): Disentangling architecture representation learning and architecture search through unsupervised pre-training.
10 |
11 |
12 | The repository is built upon [pytorch_geometric](https://github.com/rusty1s/pytorch_geometric), [pybnn](https://github.com/automl/pybnn), [nas_benchmarks](https://github.com/automl/nas_benchmarks), and [bananas](https://github.com/naszilla/bananas).
13 |
14 | ## 1. Requirements
15 | - NVIDIA GPU, Linux, Python3
16 | ```bash
17 | pip install -r requirements.txt
18 | ```
19 |
20 | ## 2. Experiments on NAS-Bench-101
21 | ### Dataset preparation on NAS-Bench-101
22 |
23 | Install [nasbench](https://github.com/google-research/nasbench) and download [nasbench_only108.tfrecord](https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord) into the `./data` folder.
24 |
25 | ```bash
26 | python preprocessing/gen_json.py
27 | ```
28 |
29 | Data will be saved in `./data/data.json`.
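To sanity-check the generated file, a minimal sketch (the exact schema of `data.json` is an assumption; adjust to the actual contents):
```python
import json

# Load the preprocessed NAS-Bench-101 records produced by gen_json.py.
with open('data/data.json') as f:
    data = json.load(f)
print('number of entries:', len(data))
```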
30 |
31 | ### Pretraining
32 | ```bash
33 | bash models/pretraining_nasbench101.sh
34 | ```
35 |
36 | The pretrained model will be saved in `./pretrained/dim-16/`.
37 |
38 | ### arch2vec extraction
39 | ```bash
40 | bash run_scripts/extract_arch2vec.sh
41 | ```
42 |
43 | The extracted arch2vec will be saved in `./pretrained/dim-16/`.
44 |
45 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/file/d/16GnqqrN46PJWl8QnES83WY3W58NUhgCr/view?usp=sharing) on NAS-Bench-101.
46 |
47 |
48 | ### Run experiments of RL search on NAS-Bench-101
49 | ```bash
50 | bash run_scripts/run_reinforce_supervised.sh
51 | bash run_scripts/run_reinforce_arch2vec.sh
52 | ```
53 |
54 | Search results will be saved in `./saved_logs/rl/dim16`.
55 |
56 | Generate json file:
57 | ```bash
58 | python plot_scripts/plot_reinforce_search_arch2vec.py
59 | ```
60 |
61 |
62 | ### Run experiments of BO search on NAS-Bench-101
63 | ```bash
64 | bash run_scripts/run_dngo_supervised.sh
65 | bash run_scripts/run_dngo_arch2vec.sh
66 | ```
67 |
68 | Search results will be saved in `./saved_logs/bo/dim16`.
69 |
70 | Generate json file:
71 | ```bash
72 | python plot_scripts/plot_dngo_search_arch2vec.py
73 | ```
74 |
75 | ### Plot NAS comparison curve on NAS-Bench-101:
76 | ```bash
77 | python plot_scripts/plot_nasbench101_comparison.py
78 | ```
79 |
80 | ### Plot CDF comparison curve on NAS-Bench-101:
81 | Download the search results from [search_logs](https://drive.google.com/drive/u/1/folders/1FKZghhBX0-gVNcQpzYjMShOH7mdkfwC1).
82 | ```bash
83 | python plot_scripts/plot_cdf.py
84 | ```
85 |
86 |
87 | ## 3. Experiments on NAS-Bench-201
88 |
89 | ### Dataset preparation
90 | Download [NAS-Bench-201-v1_0-e61699.pth](https://drive.google.com/file/d/1SKW0Cu0u8-gb18zDpaAGi0f74UdXeGKs/view) into the `./data` folder.
91 | ```bash
92 | python preprocessing/nasbench201_json.py
93 | ```
94 | Data corresponding to the three datasets in NAS-Bench-201 will be saved in the `./data/` folder as `cifar10_valid_converged.json`, `cifar100.json`, and `ImageNet16_120.json`.
95 |
96 | ### Pretraining
97 | ```bash
98 | bash models/pretraining_nasbench201.sh
99 | ```
100 | The pretrained model will be saved in `./pretrained/dim-16/`.
101 |
102 | Note that the pretrained model is shared across the 3 datasets in NAS-Bench-201.
103 |
104 | ### arch2vec extraction
105 | ```bash
106 | bash run_scripts/extract_arch2vec_nasbench201.sh
107 | ```
108 | The extracted arch2vec will be saved in `./pretrained/dim-16/` as `cifar10_valid_converged-arch2vec.pt`, `cifar100-arch2vec.pt` and `ImageNet16_120-arch2vec.pt`.
109 |
110 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/drive/u/1/folders/16AIs4GfGNgeaHriTAICLCxIBdYE223id) on NAS-Bench-201.
111 |
112 | ### Run experiments of RL search on NAS-Bench-201
113 | ```bash
114 | bash run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh  # CIFAR-10
115 | bash run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh       # CIFAR-100
116 | bash run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh       # ImageNet-16-120
117 | ```
118 |
119 |
120 | ### Run experiments of BO search on NAS-Bench-201
121 | ```bash
122 | bash run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh  # CIFAR-10
123 | bash run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh       # CIFAR-100
124 | bash run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh       # ImageNet-16-120
125 | ```
126 |
127 |
128 | ### Summarize search results on NAS-Bench-201
129 | ```bash
130 | python ./plot_scripts/summarize_nasbench201.py
131 | ```
132 | The corresponding table will be printed to the console.
133 |
134 |
135 | ## 4. Experiments on DARTS Search Space
136 | CIFAR-10 is downloaded automatically by torchvision; ImageNet needs to be downloaded manually (preferably to an SSD) from http://image-net.org/download.
137 |
138 | ### Randomly sample 600,000 isomorphic graphs in the DARTS space
139 | ```bash
140 | python preprocessing/gen_isomorphism_graphs.py
141 | ```
142 | Data will be saved in `./data/data_darts_counter600000.json`.
143 |
144 | Alternatively, you can download the extracted [data_darts_counter600000.json](https://drive.google.com/file/d/1xboQV_NtsSDyOPM4H7RxtDNL-2WXo3Wr/view?usp=sharing).
145 |
146 | ### Pretraining
147 | ```bash
148 | bash models/pretraining_darts.sh
149 | ```
150 | The pretrained model is saved in `./pretrained/dim-16/`.
151 |
152 | ### arch2vec extraction
153 | ```bash
154 | bash run_scripts/extract_arch2vec_darts.sh
155 | ```
156 | The extracted arch2vec will be saved in `./pretrained/dim-16/arch2vec-darts.pt`.
157 |
158 | Alternatively, you can download the pretrained [arch2vec](https://drive.google.com/file/d/1bDZCD-XDzded6SRjDUpRV6xTINpwTNcm/view?usp=sharing) on DARTS search space.
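To quickly inspect the extracted embeddings, a minimal sketch (the internal layout of the `.pt` file is an assumption; adjust to the actual contents):
```python
import torch

# Load the embeddings produced by extract_arch2vec_darts.sh onto the CPU.
arch2vec = torch.load('pretrained/dim-16/arch2vec-darts.pt', map_location='cpu')
print(type(arch2vec))
```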
159 |
160 | ### Run experiments of RL search on DARTS search space
161 | ```bash
162 | bash run_scripts/run_reinforce_arch2vec_darts.sh
163 | ```
164 | Logs will be saved in `./darts-rl/`.
165 |
166 | Final search result will be saved in `./saved_logs/rl/dim16`.
167 |
168 | ### Run experiments of BO search on DARTS search space
169 | ```bash
170 | bash run_scripts/run_bo_arch2vec_darts.sh
171 | ```
172 | Logs will be saved in `./darts-bo/`.
173 |
174 | Final search result will be saved in `./saved_logs/bo/dim16`.
175 |
176 | ### Evaluate the learned cell on DARTS Search Space on CIFAR-10
177 | ```bash
178 | python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_rl --seed 1
179 | python darts/cnn/train.py --auxiliary --cutout --arch arch2vec_bo --seed 1
180 | ```
181 | - Expected results (RL): 2.60\% test error with 3.3M model params.
182 | - Expected results (BO): 2.48\% test error with 3.6M model params.
183 |
184 |
185 | ### Transfer learning on ImageNet
186 | ```bash
187 | python darts/cnn/train_imagenet.py --arch arch2vec_rl --seed 1
188 | python darts/cnn/train_imagenet.py --arch arch2vec_bo --seed 1
189 | ```
190 | - Expected results (RL): 25.8\% test error with 4.8M model params and 533M mult-adds.
191 | - Expected results (BO): 25.5\% test error with 5.2M model params and 580M mult-adds.
192 |
193 |
194 | ### Visualize the learned cell
195 | ```bash
196 | python darts/cnn/visualize.py arch2vec_rl
197 | python darts/cnn/visualize.py arch2vec_bo
198 | ```
199 |
200 | ## 5. Analyzing the results
201 | ### Visualize a sequence of decoded cells from the latent space
202 | Download pretrained supervised embeddings of [nasbench101](https://drive.google.com/file/d/19-1gpMdXftXoH7G5929peoOnS1xKf5wN/view?usp=sharing) and [nasbench201](https://drive.google.com/file/d/1_Pw8MDp6ZrlI6EJ0kS3MVEz3HOSJMnIV/view?usp=sharing).
203 | ```bash
204 | bash plot_scripts/drawfig5-nas101.sh # visualization on nasbench-101
205 | bash plot_scripts/drawfig5-nas201.sh # visualization on nasbench-201
206 | bash plot_scripts/drawfig5-darts.sh # visualization on darts
207 | ```
208 | The plots will be saved in `./graphvisualization`.
209 |
210 | ### Plot distribution of L2 distance by edit distance
211 | Install [nas_benchmarks](https://github.com/automl/nas_benchmarks) and download [nasbench_full.tfrecord](https://storage.googleapis.com/nasbench/nasbench_full.tfrecord) into the same directory.
212 | ```bash
213 | python plot_scripts/distance_comparison_fig3.py
214 | ```
215 |
216 | ### Latent space 2D visualization
217 | ```bash
218 | bash plot_scripts/drawfig4.sh
219 | ```
220 | The plots will be saved in `./density`.
221 |
222 | ### Predictive performance comparison
223 | Download [predicted_accuracy](https://drive.google.com/drive/u/1/folders/1mNlg5s3FQ8PEcgTDSnAuM6qa8ECDTzhh) into `saved_logs/`.
224 | ```bash
225 | python plot_scripts/pearson_plot_fig2.py
226 | ```
227 |
228 |
229 |
230 |
231 |
232 | # Citation
233 | If you find this useful for your work, please consider citing:
234 | ```
235 | @InProceedings{yan2020arch,
236 | title = {Does Unsupervised Architecture Representation Learning Help Neural Architecture Search?},
237 | author = {Yan, Shen and Zheng, Yu and Ao, Wei and Zeng, Xiao and Zhang, Mi},
238 | booktitle = {NeurIPS},
239 | year = {2020}
240 | }
241 | ```
242 |
243 |
244 |
245 |
246 |
--------------------------------------------------------------------------------
/darts/cnn/architect.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 |
7 | def _concat(xs):
8 | return torch.cat([x.view(-1) for x in xs])
9 |
10 |
11 | class Architect(object):
12 |
13 | def __init__(self, model, args):
14 | self.network_momentum = args.momentum
15 | self.network_weight_decay = args.weight_decay
16 | self.model = model
17 | self.optimizer = torch.optim.Adam(self.model.arch_parameters(),
18 | lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)
19 |
20 | def _compute_unrolled_model(self, input, target, eta, network_optimizer):
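# Second-order DARTS: build the "unrolled" model with weights
# w' = w - eta * (momentum_buffer * momentum + dL_train/dw + weight_decay * w),
# i.e. one virtual SGD-with-momentum step on the training loss.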
21 | loss = self.model._loss(input, target)
22 | theta = _concat(self.model.parameters()).data
23 | try:
24 | moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
25 | except KeyError:  # no momentum buffer exists before the first optimizer step
26 | moment = torch.zeros_like(theta)
27 | dtheta = _concat(torch.autograd.grad(loss, self.model.parameters())).data + self.network_weight_decay*theta
28 | unrolled_model = self._construct_model_from_theta(theta.sub(moment + dtheta, alpha=eta))  # theta - eta * (moment + dtheta)
29 | return unrolled_model
30 |
31 | def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
32 | self.optimizer.zero_grad()
33 | if unrolled:
34 | self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
35 | else:
36 | self._backward_step(input_valid, target_valid)
37 | self.optimizer.step()
38 |
39 | def _backward_step(self, input_valid, target_valid):
40 | loss = self.model._loss(input_valid, target_valid)
41 | loss.backward()
42 |
43 | def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
44 | unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
45 | unrolled_loss = unrolled_model._loss(input_valid, target_valid)
46 |
47 | unrolled_loss.backward()
48 | dalpha = [v.grad for v in unrolled_model.arch_parameters()]
49 | vector = [v.grad.data for v in unrolled_model.parameters()]
50 | implicit_grads = self._hessian_vector_product(vector, input_train, target_train)
51 |
52 | for g, ig in zip(dalpha, implicit_grads):
53 | g.data.sub_(ig.data, alpha=eta)  # dalpha -= eta * implicit gradient
54 |
55 | for v, g in zip(self.model.arch_parameters(), dalpha):
56 | if v.grad is None:
57 | v.grad = Variable(g.data)
58 | else:
59 | v.grad.data.copy_(g.data)
60 |
61 | def _construct_model_from_theta(self, theta):
62 | model_new = self.model.new()
63 | model_dict = self.model.state_dict()
64 |
65 | params, offset = {}, 0
66 | for k, v in self.model.named_parameters():
67 | v_length = np.prod(v.size())
68 | params[k] = theta[offset: offset+v_length].view(v.size())
69 | offset += v_length
70 |
71 | assert offset == len(theta)
72 | model_dict.update(params)
73 | model_new.load_state_dict(model_dict)
74 | return model_new.cuda()
75 |
76 | def _hessian_vector_product(self, vector, input, target, r=1e-2):
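# Finite-difference approximation of the Hessian-vector product: perturb the
# weights to w +/- R * vector and take the symmetric difference of the
# architecture gradients, (grads_p - grads_n) / (2R).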
77 | R = r / _concat(vector).norm()
78 | for p, v in zip(self.model.parameters(), vector):
79 | p.data.add_(v, alpha=R)  # w+ = w + R * v (keyword form; the two-arg add_ is deprecated)
80 | loss = self.model._loss(input, target)
81 | grads_p = torch.autograd.grad(loss, self.model.arch_parameters())
82 |
83 | for p, v in zip(self.model.parameters(), vector):
84 | p.data.sub_(v, alpha=2*R)  # w- = w - R * v
85 | loss = self.model._loss(input, target)
86 | grads_n = torch.autograd.grad(loss, self.model.arch_parameters())
87 |
88 | for p, v in zip(self.model.parameters(), vector):
89 | p.data.add_(v, alpha=R)  # restore the original weights
90 |
91 | return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
92 |
93 |
--------------------------------------------------------------------------------
/darts/cnn/genotypes.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
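# A Genotype stores, for the normal and reduction cells, a list of
# (op_name, input_state_index) pairs -- two per intermediate node -- plus the
# indices of the states concatenated to form the cell output (*_concat).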
4 |
5 | PRIMITIVES = [
6 | 'none',
7 | 'max_pool_3x3',
8 | 'avg_pool_3x3',
9 | 'skip_connect',
10 | 'sep_conv_3x3',
11 | 'sep_conv_5x5',
12 | 'dil_conv_3x3',
13 | 'dil_conv_5x5'
14 | ]
15 |
16 | NASNet = Genotype(
17 | normal = [
18 | ('sep_conv_5x5', 1),
19 | ('sep_conv_3x3', 0),
20 | ('sep_conv_5x5', 0),
21 | ('sep_conv_3x3', 0),
22 | ('avg_pool_3x3', 1),
23 | ('skip_connect', 0),
24 | ('avg_pool_3x3', 0),
25 | ('avg_pool_3x3', 0),
26 | ('sep_conv_3x3', 1),
27 | ('skip_connect', 1),
28 | ],
29 | normal_concat = [2, 3, 4, 5, 6],
30 | reduce = [
31 | ('sep_conv_5x5', 1),
32 | ('sep_conv_7x7', 0),
33 | ('max_pool_3x3', 1),
34 | ('sep_conv_7x7', 0),
35 | ('avg_pool_3x3', 1),
36 | ('sep_conv_5x5', 0),
37 | ('skip_connect', 3),
38 | ('avg_pool_3x3', 2),
39 | ('sep_conv_3x3', 2),
40 | ('max_pool_3x3', 1),
41 | ],
42 | reduce_concat = [4, 5, 6],
43 | )
44 |
45 | AmoebaNet = Genotype(
46 | normal = [
47 | ('avg_pool_3x3', 0),
48 | ('max_pool_3x3', 1),
49 | ('sep_conv_3x3', 0),
50 | ('sep_conv_5x5', 2),
51 | ('sep_conv_3x3', 0),
52 | ('avg_pool_3x3', 3),
53 | ('sep_conv_3x3', 1),
54 | ('skip_connect', 1),
55 | ('skip_connect', 0),
56 | ('avg_pool_3x3', 1),
57 | ],
58 | normal_concat = [4, 5, 6],
59 | reduce = [
60 | ('avg_pool_3x3', 0),
61 | ('sep_conv_3x3', 1),
62 | ('max_pool_3x3', 0),
63 | ('sep_conv_7x7', 2),
64 | ('sep_conv_7x7', 0),
65 | ('avg_pool_3x3', 1),
66 | ('max_pool_3x3', 0),
67 | ('max_pool_3x3', 1),
68 | ('conv_7x1_1x7', 0),
69 | ('sep_conv_3x3', 5),
70 | ],
71 | reduce_concat = [3, 4, 6]
72 | )
73 |
74 |
75 | DARTS = Genotype(normal=[('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 1), ('skip_connect', 0), ('skip_connect', 0), ('dil_conv_3x3', 2)], normal_concat=[2, 3, 4, 5], reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('skip_connect', 2), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('skip_connect', 2), ('max_pool_3x3', 1)], reduce_concat=[2, 3, 4, 5])
76 |
77 | BANANAS = Genotype(normal=[('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_5x5', 2), ('sep_conv_5x5', 0), ('skip_connect', 0), ('sep_conv_5x5', 2), ('sep_conv_3x3', 0), ('skip_connect', 2)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_3x3', 1), ('max_pool_3x3', 0), ('max_pool_3x3', 0), ('none', 1), ('dil_conv_3x3', 2), ('sep_conv_5x5', 3), ('sep_conv_5x5', 4), ('sep_conv_3x3', 1)], reduce_concat=[2, 3, 4, 5])
78 |
79 | arch2vec_bo = Genotype(normal=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 2), ('sep_conv_3x3', 0)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_5x5', 1), ('max_pool_3x3', 0), ('skip_connect', 0), ('dil_conv_3x3', 1), ('sep_conv_5x5', 1), ('sep_conv_3x3', 0), ('dil_conv_5x5', 2), ('sep_conv_3x3', 0)], reduce_concat=[2, 3, 4, 5])
80 |
81 | arch2vec_rl = Genotype(normal=[('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 1), ('sep_conv_3x3', 0)], normal_concat=[2, 3, 4, 5], reduce=[('sep_conv_3x3', 0), ('dil_conv_3x3', 1), ('max_pool_3x3', 0), ('dil_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 1), ('dil_conv_5x5', 1), ('sep_conv_3x3', 0)], reduce_concat=[2, 3, 4, 5])
82 |
--------------------------------------------------------------------------------
/darts/cnn/model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | from darts.cnn.operations import *
5 | from darts.cnn.utils import drop_path
6 |
7 |
8 | class Cell(nn.Module):
9 |
10 | def __init__(self, genotype, C_prev_prev, C_prev, C, reduction, reduction_prev):
11 | super(Cell, self).__init__()
12 | #print(C_prev_prev, C_prev, C)
13 |
14 | if reduction_prev:
15 | self.preprocess0 = FactorizedReduce(C_prev_prev, C)
16 | else:
17 | self.preprocess0 = ReLUConvBN(C_prev_prev, C, 1, 1, 0)
18 | self.preprocess1 = ReLUConvBN(C_prev, C, 1, 1, 0)
19 |
20 | if reduction:
21 | op_names, indices = zip(*genotype.reduce)
22 | concat = genotype.reduce_concat
23 | else:
24 | op_names, indices = zip(*genotype.normal)
25 | concat = genotype.normal_concat
26 | self._compile(C, op_names, indices, concat, reduction)
27 |
28 | def _compile(self, C, op_names, indices, concat, reduction):
29 | assert len(op_names) == len(indices)
30 | self._steps = len(op_names) // 2
31 | self._concat = concat
32 | self.multiplier = len(concat)
33 |
34 | self._ops = nn.ModuleList()
35 | for name, index in zip(op_names, indices):
36 | stride = 2 if reduction and index < 2 else 1
37 | op = OPS[name](C, stride, True)
38 | self._ops += [op]
39 | self._indices = indices
40 |
41 | def forward(self, s0, s1, drop_prob):
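# States 0 and 1 are the preprocessed cell inputs; each intermediate node
# sums two ops applied to earlier states, and drop_path regularizes every
# non-identity path during training.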
42 | s0 = self.preprocess0(s0)
43 | s1 = self.preprocess1(s1)
44 |
45 | states = [s0, s1]
46 | for i in range(self._steps):
47 | h1 = states[self._indices[2*i]]
48 | h2 = states[self._indices[2*i+1]]
49 | op1 = self._ops[2*i]
50 | op2 = self._ops[2*i+1]
51 | h1 = op1(h1)
52 | h2 = op2(h2)
53 | if self.training and drop_prob > 0.:
54 | if not isinstance(op1, Identity):
55 | h1 = drop_path(h1, drop_prob)
56 | if not isinstance(op2, Identity):
57 | h2 = drop_path(h2, drop_prob)
58 | s = h1 + h2
59 | states += [s]
60 | return torch.cat([states[i] for i in self._concat], dim=1)
61 |
62 |
63 | class AuxiliaryHeadCIFAR(nn.Module):
64 |
65 | def __init__(self, C, num_classes):
66 | """assuming input size 8x8"""
67 | super(AuxiliaryHeadCIFAR, self).__init__()
68 | self.features = nn.Sequential(
69 | nn.ReLU(inplace=True),
70 | nn.AvgPool2d(5, stride=3, padding=0, count_include_pad=False), # image size = 2 x 2
71 | nn.Conv2d(C, 128, 1, bias=False),
72 | nn.BatchNorm2d(128),
73 | nn.ReLU(inplace=True),
74 | nn.Conv2d(128, 768, 2, bias=False),
75 | nn.BatchNorm2d(768),
76 | nn.ReLU(inplace=True)
77 | )
78 | self.classifier = nn.Linear(768, num_classes)
79 |
80 | def forward(self, x):
81 | x = self.features(x)
82 | x = self.classifier(x.view(x.size(0),-1))
83 | return x
84 |
85 |
86 | class AuxiliaryHeadImageNet(nn.Module):
87 |
88 | def __init__(self, C, num_classes):
89 | """assuming input size 14x14"""
90 | super(AuxiliaryHeadImageNet, self).__init__()
91 | self.features = nn.Sequential(
92 | nn.ReLU(inplace=True),
93 | nn.AvgPool2d(5, stride=2, padding=0, count_include_pad=False),
94 | nn.Conv2d(C, 128, 1, bias=False),
95 | nn.BatchNorm2d(128),
96 | nn.ReLU(inplace=True),
97 | nn.Conv2d(128, 768, 2, bias=False),
98 | # NOTE: This batchnorm was omitted in my earlier implementation due to a typo.
99 | # Commenting it out for consistency with the experiments in the paper.
100 | # nn.BatchNorm2d(768),
101 | nn.ReLU(inplace=True)
102 | )
103 | self.classifier = nn.Linear(768, num_classes)
104 |
105 | def forward(self, x):
106 | x = self.features(x)
107 | x = self.classifier(x.view(x.size(0),-1))
108 | return x
109 |
110 |
111 | class NetworkCIFAR(nn.Module):
112 |
113 | def __init__(self, C, num_classes, layers, auxiliary, genotype):
114 | super(NetworkCIFAR, self).__init__()
115 | self._layers = layers
116 | self._auxiliary = auxiliary
117 |
118 | stem_multiplier = 3
119 | C_curr = stem_multiplier*C
120 | self.stem = nn.Sequential(
121 | nn.Conv2d(3, C_curr, 3, padding=1, bias=False),
122 | nn.BatchNorm2d(C_curr)
123 | )
124 |
125 | C_prev_prev, C_prev, C_curr = C_curr, C_curr, C
126 | self.cells = nn.ModuleList()
127 | reduction_prev = False
128 | for i in range(layers):
129 | if i in [layers//3, 2*layers//3]:
130 | C_curr *= 2
131 | reduction = True
132 | else:
133 | reduction = False
134 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
135 | reduction_prev = reduction
136 | self.cells += [cell]
137 | C_prev_prev, C_prev = C_prev, cell.multiplier*C_curr
138 | if i == 2*layers//3:
139 | C_to_auxiliary = C_prev
140 |
141 | if auxiliary:
142 | self.auxiliary_head = AuxiliaryHeadCIFAR(C_to_auxiliary, num_classes)
143 | self.global_pooling = nn.AdaptiveAvgPool2d(1)
144 | self.classifier = nn.Linear(C_prev, num_classes)
145 |
146 | def forward(self, input):
147 | logits_aux = None
148 | s0 = s1 = self.stem(input)
149 | for i, cell in enumerate(self.cells):
150 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
151 | if i == 2*self._layers//3:
152 | if self._auxiliary and self.training:
153 | logits_aux = self.auxiliary_head(s1)
154 | out = self.global_pooling(s1)
155 | logits = self.classifier(out.view(out.size(0),-1))
156 | return logits, logits_aux
157 |
158 |
159 | class NetworkImageNet(nn.Module):
160 |
161 | def __init__(self, C, num_classes, layers, auxiliary, genotype):
162 | super(NetworkImageNet, self).__init__()
163 | self._layers = layers
164 | self._auxiliary = auxiliary
165 |
166 | self.stem0 = nn.Sequential(
167 | nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
168 | nn.BatchNorm2d(C // 2),
169 | nn.ReLU(inplace=True),
170 | nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
171 | nn.BatchNorm2d(C),
172 | )
173 |
174 | self.stem1 = nn.Sequential(
175 | nn.ReLU(inplace=True),
176 | nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
177 | nn.BatchNorm2d(C),
178 | )
179 |
180 | C_prev_prev, C_prev, C_curr = C, C, C
181 |
182 | self.cells = nn.ModuleList()
183 | reduction_prev = True
184 | for i in range(layers):
185 | if i in [layers // 3, 2 * layers // 3]:
186 | C_curr *= 2
187 | reduction = True
188 | else:
189 | reduction = False
190 | cell = Cell(genotype, C_prev_prev, C_prev, C_curr, reduction, reduction_prev)
191 | reduction_prev = reduction
192 | self.cells += [cell]
193 | C_prev_prev, C_prev = C_prev, cell.multiplier * C_curr
194 | if i == 2 * layers // 3:
195 | C_to_auxiliary = C_prev
196 |
197 | if auxiliary:
198 | self.auxiliary_head = AuxiliaryHeadImageNet(C_to_auxiliary, num_classes)
199 | self.global_pooling = nn.AvgPool2d(7)
200 | self.classifier = nn.Linear(C_prev, num_classes)
201 | self.drop_path_prob = 0
202 |
203 | def forward(self, input):
204 | logits_aux = None
205 | s0 = self.stem0(input)
206 | s1 = self.stem1(s0)
207 | for i, cell in enumerate(self.cells):
208 | s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
209 | if i == 2 * self._layers // 3:
210 | if self._auxiliary and self.training:
211 | logits_aux = self.auxiliary_head(s1)
212 | out = self.global_pooling(s1)
213 | logits = self.classifier(out.view(out.size(0), -1))
214 | return logits, logits_aux
215 |
216 |
--------------------------------------------------------------------------------
/darts/cnn/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | OPS = {
5 | 'none' : lambda C, stride, affine: Zero(stride),
6 | 'avg_pool_3x3' : lambda C, stride, affine: nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
7 | 'max_pool_3x3' : lambda C, stride, affine: nn.MaxPool2d(3, stride=stride, padding=1),
8 | 'skip_connect' : lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
9 | 'sep_conv_3x3' : lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
10 | 'sep_conv_5x5' : lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
11 | 'sep_conv_7x7' : lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
12 | 'dil_conv_3x3' : lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),
13 | 'dil_conv_5x5' : lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),
14 | 'conv_7x1_1x7' : lambda C, stride, affine: nn.Sequential(
15 | nn.ReLU(inplace=False),
16 | nn.Conv2d(C, C, (1,7), stride=(1, stride), padding=(0, 3), bias=False),
17 | nn.Conv2d(C, C, (7,1), stride=(stride, 1), padding=(3, 0), bias=False),
18 | nn.BatchNorm2d(C, affine=affine)
19 | ),
20 | }
21 |
22 | class ReLUConvBN(nn.Module):
23 |
24 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
25 | super(ReLUConvBN, self).__init__()
26 | self.op = nn.Sequential(
27 | nn.ReLU(inplace=False),
28 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
29 | nn.BatchNorm2d(C_out, affine=affine)
30 | )
31 |
32 | def forward(self, x):
33 | return self.op(x)
34 |
35 | class DilConv(nn.Module):
36 |
37 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
38 | super(DilConv, self).__init__()
39 | self.op = nn.Sequential(
40 | nn.ReLU(inplace=False),
41 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=C_in, bias=False),
42 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
43 | nn.BatchNorm2d(C_out, affine=affine),
44 | )
45 |
46 | def forward(self, x):
47 | return self.op(x)
48 |
49 |
50 | class SepConv(nn.Module):
51 |
52 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
53 | super(SepConv, self).__init__()
54 | self.op = nn.Sequential(
55 | nn.ReLU(inplace=False),
56 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, bias=False),
57 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
58 | nn.BatchNorm2d(C_in, affine=affine),
59 | nn.ReLU(inplace=False),
60 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False),
61 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
62 | nn.BatchNorm2d(C_out, affine=affine),
63 | )
64 |
65 | def forward(self, x):
66 | return self.op(x)
67 |
68 |
69 | class Identity(nn.Module):
70 |
71 | def __init__(self):
72 | super(Identity, self).__init__()
73 |
74 | def forward(self, x):
75 | return x
76 |
77 |
78 | class Zero(nn.Module):
79 |
80 | def __init__(self, stride):
81 | super(Zero, self).__init__()
82 | self.stride = stride
83 |
84 | def forward(self, x):
85 | if self.stride == 1:
86 | return x.mul(0.)
87 | return x[:,:,::self.stride,::self.stride].mul(0.)
88 |
89 |
90 | class FactorizedReduce(nn.Module):
91 |
92 | def __init__(self, C_in, C_out, affine=True):
93 | super(FactorizedReduce, self).__init__()
94 | assert C_out % 2 == 0
95 | self.relu = nn.ReLU(inplace=False)
96 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
97 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
98 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
99 |
100 | def forward(self, x):
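# Halve the spatial resolution without dropping information: two stride-2
# 1x1 convs, the second on a one-pixel-shifted view, concatenated on channels.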
101 | x = self.relu(x)
102 | out = torch.cat([self.conv_1(x), self.conv_2(x[:,:,1:,1:])], dim=1)
103 | out = self.bn(out)
104 | return out
105 |
106 |
--------------------------------------------------------------------------------
/darts/cnn/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import glob
5 | import numpy as np
6 | import torch
7 | import utils
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import darts.cnn.genotypes as genotypes
12 | import torch.utils
13 | import torchvision.datasets as dset
14 | import torch.backends.cudnn as cudnn
15 |
16 | from darts.cnn.model import NetworkCIFAR as Network
17 |
18 |
19 | parser = argparse.ArgumentParser("cifar")
20 | parser.add_argument('--data', type=str, default='../data', help='location of the data corpus')
21 | parser.add_argument('--batch_size', type=int, default=32, help='batch size')
22 | parser.add_argument('--report_freq', type=float, default=50, help='report frequency')
23 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
24 | parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
25 | parser.add_argument('--layers', type=int, default=20, help='total number of layers')
26 | parser.add_argument('--model_path', type=str, default='EXP/model.pt', help='path of pretrained model')
27 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
28 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
29 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
30 | parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
31 | parser.add_argument('--seed', type=int, default=0, help='random seed')
32 | parser.add_argument('--arch', type=str, default='BANANAS', help='which architecture to use')
33 | args = parser.parse_args()
34 |
35 | log_format = '%(asctime)s %(message)s'
36 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
37 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
38 |
39 | CIFAR_CLASSES = 10
40 |
41 |
42 | def main():
43 |
44 | np.random.seed(args.seed)
45 |
46 | if torch.cuda.is_available():
47 | device = torch.device('cuda:{}'.format(args.gpu))
48 | cudnn.benchmark = True
49 | torch.manual_seed(args.seed)
50 | cudnn.enabled = True
51 | cudnn.deterministic = True
52 | torch.cuda.manual_seed(args.seed)
53 | logging.info('gpu device = %d' % args.gpu)
54 | else:
55 | device = torch.device('cpu')
56 | logging.info('No gpu device available')
57 | torch.manual_seed(args.seed)
58 |
59 | logging.info("args = %s", args)
60 | genotype = eval("genotypes.%s" % args.arch)
61 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
62 | model = model.to(device)
63 | utils.load(model, args.model_path, args.gpu)
64 |
65 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
66 |
67 | criterion = nn.CrossEntropyLoss()
68 | criterion = criterion.to(device)  # honor the CPU fallback selected above
69 |
70 | _, test_transform = utils._data_transforms_cifar10(args.cutout, args.cutout_length)
71 | test_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=test_transform)
72 |
73 | test_queue = torch.utils.data.DataLoader(
74 | test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=2)
75 |
76 | model.drop_path_prob = args.drop_path_prob
77 | test_acc, test_obj = infer(test_queue, model, criterion, args.gpu)
78 | logging.info('test_acc %f', test_acc)
79 |
80 |
81 | def infer(test_queue, model, criterion, gpu_id=0):
82 | objs = utils.AvgrageMeter()
83 | top1 = utils.AvgrageMeter()
84 | top5 = utils.AvgrageMeter()
85 | model.eval()
86 |
87 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \
88 | else 'cpu')
89 |
90 | for step, (input, target) in enumerate(test_queue):
91 | input = input.to(device)
92 | target = target.to(device)
93 |
94 | logits, _ = model(input)
95 | loss = criterion(logits, target)
96 |
97 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
98 | n = input.size(0)
99 | objs.update(loss.item(), n)
100 | top1.update(prec1.item(), n)
101 | top5.update(prec5.item(), n)
102 |
103 | if step % args.report_freq == 0:
104 | logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
105 |
106 | return top1.avg, objs.avg
107 |
108 |
109 | if __name__ == '__main__':
110 | main()
111 |
112 |
--------------------------------------------------------------------------------
/darts/cnn/test_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import numpy as np
4 | import torch
5 | import utils
6 | import glob
7 | import random
8 | import logging
9 | import argparse
10 | import torch.nn as nn
11 | import genotypes
12 | import torch.utils
13 | import torchvision.datasets as dset
14 | import torchvision.transforms as transforms
15 | import torch.backends.cudnn as cudnn
16 |
17 | from model import NetworkImageNet as Network
18 |
19 |
20 | parser = argparse.ArgumentParser("imagenet")
21 | parser.add_argument('--data', type=str, default='../data/imagenet/', help='location of the data corpus')
22 | parser.add_argument('--batch_size', type=int, default=128, help='batch size')
23 | parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
24 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
25 | parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
26 | parser.add_argument('--layers', type=int, default=14, help='total number of layers')
27 | parser.add_argument('--model_path', type=str, default='EXP/model.pt', help='path of pretrained model')
28 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
29 | parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
30 | parser.add_argument('--seed', type=int, default=0, help='random seed')
31 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
32 | args = parser.parse_args()
33 |
34 | log_format = '%(asctime)s %(message)s'
35 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
36 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
37 |
38 | CLASSES = 1000
39 |
40 |
41 | def main():
42 | if not torch.cuda.is_available():
43 | logging.info('no gpu device available')
44 | sys.exit(1)
45 |
46 | np.random.seed(args.seed)
47 | torch.cuda.set_device(args.gpu)
48 | cudnn.benchmark = True
49 | torch.manual_seed(args.seed)
50 | cudnn.enabled=True
51 | torch.cuda.manual_seed(args.seed)
52 | logging.info('gpu device = %d' % args.gpu)
53 | logging.info("args = %s", args)
54 |
55 | genotype = eval("genotypes.%s" % args.arch)
56 | model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
57 | model = model.cuda()
58 | model.load_state_dict(torch.load(args.model_path)['state_dict'])
59 |
60 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
61 |
62 | criterion = nn.CrossEntropyLoss()
63 | criterion = criterion.cuda()
64 |
65 | validdir = os.path.join(args.data, 'val')
66 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
67 | valid_data = dset.ImageFolder(
68 | validdir,
69 | transforms.Compose([
70 | transforms.Resize(256),
71 | transforms.CenterCrop(224),
72 | transforms.ToTensor(),
73 | normalize,
74 | ]))
75 |
76 | valid_queue = torch.utils.data.DataLoader(
77 | valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
78 |
79 | model.drop_path_prob = args.drop_path_prob
80 | valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
81 | logging.info('valid_acc_top1 %f', valid_acc_top1)
82 | logging.info('valid_acc_top5 %f', valid_acc_top5)
83 |
84 |
85 | def infer(valid_queue, model, criterion):
86 | objs = utils.AvgrageMeter()
87 | top1 = utils.AvgrageMeter()
88 | top5 = utils.AvgrageMeter()
89 | model.eval()
90 |
91 | for step, (input, target) in enumerate(valid_queue):
92 | input = input.cuda()
93 | target = target.cuda()
94 |
95 | logits, _ = model(input)
96 | loss = criterion(logits, target)
97 |
98 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
99 | n = input.size(0)
100 | objs.update(loss.item(), n)   # .item() replaces the deprecated .data[0] indexing
101 | top1.update(prec1.item(), n)
102 | top5.update(prec5.item(), n)
103 |
104 | if step % args.report_freq == 0:
105 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
106 |
107 | return top1.avg, top5.avg, objs.avg
108 |
109 |
110 | if __name__ == '__main__':
111 | main()
112 |
--------------------------------------------------------------------------------
/darts/cnn/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import time
5 | import glob
6 | import numpy as np
7 | import random
8 | import torch
9 | import darts.cnn.utils as utils
10 | import logging
11 | import argparse
12 | import torch.nn as nn
13 | import darts.cnn.genotypes as genotypes
14 | import torch.utils
15 | import torchvision.datasets as dset
16 | import torch.backends.cudnn as cudnn
17 |
18 | from darts.cnn.model import NetworkCIFAR as Network
19 |
20 |
21 | parser = argparse.ArgumentParser("cifar")
22 | parser.add_argument('--data', type=str, default='./data', help='location of the data corpus')
23 | parser.add_argument('--batch_size', type=int, default=96, help='batch size')
24 | parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
25 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
26 | parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
27 | parser.add_argument('--report_freq', type=float, default=500, help='report frequency')
28 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
29 | parser.add_argument('--epochs', type=int, default=600, help='num of training epochs')
30 | parser.add_argument('--init_channels', type=int, default=36, help='num of init channels')
31 | parser.add_argument('--layers', type=int, default=20, help='total number of layers')
32 | parser.add_argument('--model_path', type=str, default='saved_models', help='path to save the model')
33 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
34 | parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
35 | parser.add_argument('--cutout', action='store_true', default=False, help='use cutout')
36 | parser.add_argument('--cutout_length', type=int, default=16, help='cutout length')
37 | parser.add_argument('--drop_path_prob', type=float, default=0.2, help='drop path probability')
38 | parser.add_argument('--save', type=str, default='EXP', help='experiment name')
39 | parser.add_argument('--seed', type=int, default=3, help='random seed')
40 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
41 | parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
42 | args = parser.parse_args()
43 |
44 | args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
45 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
46 |
47 | log_format = '%(asctime)s %(message)s'
48 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
49 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
50 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
51 | fh.setFormatter(logging.Formatter(log_format))
52 | logging.getLogger().addHandler(fh)
53 |
54 | CIFAR_CLASSES = 10
55 |
56 |
57 | def main():
58 |
59 | np.random.seed(args.seed)
60 | random.seed(args.seed)
61 |
62 | if torch.cuda.is_available():
63 | device = torch.device('cuda:{}'.format(args.gpu))
64 | cudnn.benchmark = False
65 | torch.manual_seed(args.seed)
66 | cudnn.enabled = True
67 | cudnn.deterministic = True
68 | torch.cuda.manual_seed(args.seed)
69 | logging.info('gpu device = %d' % args.gpu)
70 | else:
71 | device = torch.device('cpu')
72 | logging.info('No gpu device available')
73 | torch.manual_seed(args.seed)
74 |
75 | logging.info("args = %s", args)
76 | genotype = eval("genotypes.%s" % args.arch)
77 | model = Network(args.init_channels, CIFAR_CLASSES, args.layers, args.auxiliary, genotype)
78 | model = model.to(device)
79 |
80 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
81 | total_params = sum(x.data.nelement() for x in model.parameters())
82 | logging.info('Model total parameters: {}'.format(total_params))
83 |
84 | criterion = nn.CrossEntropyLoss()
85 | criterion = criterion.to(device)  # honor the CPU fallback selected above
86 | optimizer = torch.optim.SGD(
87 | model.parameters(),
88 | args.learning_rate,
89 | momentum=args.momentum,
90 | weight_decay=args.weight_decay
91 | )
92 |
93 | train_transform, valid_transform = utils._data_transforms_cifar10(args.cutout, args.cutout_length)
94 | train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
95 | valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)
96 |
97 | train_queue = torch.utils.data.DataLoader(
98 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
99 |
100 | valid_queue = torch.utils.data.DataLoader(
101 | valid_data, batch_size=args.batch_size*4, shuffle=False, pin_memory=True, num_workers=4)
102 |
103 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))
104 |
105 | for epoch in range(args.epochs):
106 |
107 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
108 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs  # linearly ramp drop-path from 0 to the target rate
109 |
110 | train_acc, train_obj = train(train_queue, model, criterion, optimizer, args.gpu)
111 | logging.info('train_acc %f', train_acc)
112 |
113 | valid_acc, valid_obj = infer(valid_queue, model, criterion, args.gpu)
114 | logging.info('valid_acc %f', valid_acc)
115 |
116 | scheduler.step()
117 |
118 | utils.save(model, os.path.join(args.save, 'weights.pt'))
119 |
120 |
121 | def train(train_queue, model, criterion, optimizer, gpu_id=0):
122 | objs = utils.AvgrageMeter()
123 | top1 = utils.AvgrageMeter()
124 | top5 = utils.AvgrageMeter()
125 | model.train()
126 |
127 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \
128 | else 'cpu')
129 |
130 |
131 | for step, (input, target) in enumerate(train_queue):
132 | input = input.to(device)
133 | target = target.to(device)
134 |
135 | optimizer.zero_grad()
136 | logits, logits_aux = model(input)
137 | loss = criterion(logits, target)
138 | if args.auxiliary:
139 | loss_aux = criterion(logits_aux, target)
140 | loss += args.auxiliary_weight*loss_aux
141 | loss.backward()
142 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
143 | optimizer.step()
144 |
145 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
146 | n = input.size(0)
147 | objs.update(loss.item(), n)
148 | top1.update(prec1.item(), n)
149 | top5.update(prec5.item(), n)
150 |
151 | if step % args.report_freq == 0:
152 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
153 |
154 | return top1.avg, objs.avg
155 |
156 |
157 | def infer(valid_queue, model, criterion, gpu_id=0):
158 | with torch.no_grad():
159 | objs = utils.AvgrageMeter()
160 | top1 = utils.AvgrageMeter()
161 | top5 = utils.AvgrageMeter()
162 | model.eval()
163 |
164 | device = torch.device('cuda:{}'.format(gpu_id) if torch.cuda.is_available() \
165 | else 'cpu')
166 |
167 | for step, (input, target) in enumerate(valid_queue):
168 | input = input.to(device)
169 | target = target.to(device)
170 |
171 | logits, _ = model(input)
172 | loss = criterion(logits, target)
173 |
174 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
175 | n = input.size(0)
176 | objs.update(loss.item(), n)
177 | top1.update(prec1.item(), n)
178 | top5.update(prec5.item(), n)
179 |
180 | if step % args.report_freq == 0:
181 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
182 |
183 | return top1.avg, objs.avg
184 |
185 |
186 | if __name__ == '__main__':
187 | main()
188 |
189 |
--------------------------------------------------------------------------------
/darts/cnn/train_imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import numpy as np
5 | import time
6 | import torch
7 | import darts.cnn.utils as utils
8 | import glob
9 | import random
10 | import logging
11 | import argparse
12 | import torch.nn as nn
13 | import darts.cnn.genotypes as genotypes
14 | import torch.utils
15 | import torchvision.datasets as dset
16 | import torchvision.transforms as transforms
17 | import torch.backends.cudnn as cudnn
18 | from darts.cnn.model import NetworkImageNet as Network
19 | from thop import profile
20 |
21 |
22 | parser = argparse.ArgumentParser("imagenet")
23 | parser.add_argument('--data', type=str, default='data/imagenet/', help='location of the data corpus')
24 | parser.add_argument('--batch_size', type=int, default=128, help='batch size')
25 | parser.add_argument('--learning_rate', type=float, default=0.1, help='init learning rate')
26 | parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
27 | parser.add_argument('--weight_decay', type=float, default=3e-5, help='weight decay')
28 | parser.add_argument('--report_freq', type=float, default=100, help='report frequency')
29 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
30 | parser.add_argument('--epochs', type=int, default=250, help='num of training epochs')
31 | parser.add_argument('--init_channels', type=int, default=48, help='num of init channels')
32 | parser.add_argument('--layers', type=int, default=14, help='total number of layers')
33 | parser.add_argument('--auxiliary', action='store_true', default=False, help='use auxiliary tower')
34 | parser.add_argument('--auxiliary_weight', type=float, default=0.4, help='weight for auxiliary loss')
35 | parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
36 | parser.add_argument('--save', type=str, default='EXP', help='experiment name')
37 | parser.add_argument('--seed', type=int, default=0, help='random seed')
38 | parser.add_argument('--arch', type=str, default='DARTS', help='which architecture to use')
39 | parser.add_argument('--grad_clip', type=float, default=5., help='gradient clipping')
40 | parser.add_argument('--label_smooth', type=float, default=0.1, help='label smoothing')
41 | parser.add_argument('--gamma', type=float, default=0.97, help='learning rate decay')
42 | parser.add_argument('--decay_period', type=int, default=1, help='epochs between two learning rate decays')
43 | parser.add_argument('--parallel', action='store_true', default=False, help='data parallelism')
44 | args = parser.parse_args()
45 |
46 | args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
47 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
48 |
49 | log_format = '%(asctime)s %(message)s'
50 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
51 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
52 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
53 | fh.setFormatter(logging.Formatter(log_format))
54 | logging.getLogger().addHandler(fh)
55 |
56 | CLASSES = 1000
57 |
58 |
59 | class CrossEntropyLabelSmooth(nn.Module):
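# Cross-entropy with label smoothing: the one-hot targets are replaced by
# (1 - epsilon) * one_hot + epsilon / num_classes before applying log-softmax.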
60 |
61 | def __init__(self, num_classes, epsilon):
62 | super(CrossEntropyLabelSmooth, self).__init__()
63 | self.num_classes = num_classes
64 | self.epsilon = epsilon
65 | self.logsoftmax = nn.LogSoftmax(dim=1)
66 |
67 | def forward(self, inputs, targets):
68 | log_probs = self.logsoftmax(inputs)
69 | targets = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
70 | targets = (1 - self.epsilon) * targets + self.epsilon / self.num_classes
71 | loss = (-targets * log_probs).mean(0).sum()
72 | return loss
73 |
74 |
75 | def main():
76 | if not torch.cuda.is_available():
77 | logging.info('no gpu device available')
78 | sys.exit(1)
79 |
80 | np.random.seed(args.seed)
81 | torch.cuda.set_device(args.gpu)
82 | cudnn.benchmark = True
83 | torch.manual_seed(args.seed)
84 | cudnn.enabled=True
85 | torch.cuda.manual_seed(args.seed)
86 | logging.info('gpu device = %d' % args.gpu)
87 | logging.info("args = %s", args)
88 |
89 | genotype = eval("genotypes.%s" % args.arch)
90 | model = Network(args.init_channels, CLASSES, args.layers, args.auxiliary, genotype)
91 |
92 | if args.parallel:
93 | model = nn.DataParallel(model).cuda()
94 | else:
95 | model = model.cuda()
96 |
97 | #input = torch.randn(1,3,224,224).cuda()
98 | #macs, params = profile(model, inputs=(input,))
99 | #print('flops: {}, params: {}'.format(macs, params)) #arch2vec_bo: 580M, 5.18M; arch2vec_rl: 533M, 4.82M
100 | #print("param size = %fMB", utils.count_parameters_in_MB(model))
101 | #exit()
102 |
103 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
104 |
105 | criterion = nn.CrossEntropyLoss()
106 | criterion = criterion.cuda()
107 | criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
108 | criterion_smooth = criterion_smooth.cuda()
109 |
110 | optimizer = torch.optim.SGD(
111 | model.parameters(),
112 | args.learning_rate,
113 | momentum=args.momentum,
114 | weight_decay=args.weight_decay
115 | )
116 |
117 | traindir = os.path.join(args.data, 'train')
118 | validdir = os.path.join(args.data, 'val')
119 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
120 | train_data = dset.ImageFolder(
121 | traindir,
122 | transforms.Compose([
123 | transforms.RandomResizedCrop(224),
124 | transforms.RandomHorizontalFlip(),
125 | transforms.ColorJitter(
126 | brightness=0.4,
127 | contrast=0.4,
128 | saturation=0.4,
129 | hue=0.2),
130 | transforms.ToTensor(),
131 | normalize,
132 | ]))
133 | valid_data = dset.ImageFolder(
134 | validdir,
135 | transforms.Compose([
136 | transforms.Resize(256),
137 | transforms.CenterCrop(224),
138 | transforms.ToTensor(),
139 | normalize,
140 | ]))
141 |
142 | train_queue = torch.utils.data.DataLoader(
143 | train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=4)
144 |
145 | valid_queue = torch.utils.data.DataLoader(
146 | valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=4)
147 |
148 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.decay_period, gamma=args.gamma)
149 |
150 | best_acc_top1 = 0
151 | for epoch in range(args.epochs):
152 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
153 | model.drop_path_prob = args.drop_path_prob * epoch / args.epochs
154 |
155 | train_acc, train_obj = train(train_queue, model, criterion_smooth, optimizer)
156 | logging.info('train_acc %f', train_acc)
157 |
158 | valid_acc_top1, valid_acc_top5, valid_obj = infer(valid_queue, model, criterion)
159 | logging.info('valid_acc_top1 %f', valid_acc_top1)
160 | logging.info('valid_acc_top5 %f', valid_acc_top5)
161 |
162 | is_best = False
163 | if valid_acc_top1 > best_acc_top1:
164 | best_acc_top1 = valid_acc_top1
165 | is_best = True
166 |
167 | utils.save_checkpoint({
168 | 'epoch': epoch + 1,
169 | 'state_dict': model.state_dict(),
170 | 'best_acc_top1': best_acc_top1,
171 | 'optimizer' : optimizer.state_dict(),
172 | }, is_best, args.save)
173 |
174 | scheduler.step()
175 |
176 |
177 | def train(train_queue, model, criterion, optimizer):
178 | objs = utils.AvgrageMeter()
179 | top1 = utils.AvgrageMeter()
180 | top5 = utils.AvgrageMeter()
181 | model.train()
182 |
183 | for step, (input, target) in enumerate(train_queue):
184 | target = target.cuda()
185 | input = input.cuda()
186 |
187 | optimizer.zero_grad()
188 | logits, logits_aux = model(input)
189 | loss = criterion(logits, target)
190 | if args.auxiliary:
191 | loss_aux = criterion(logits_aux, target)
192 | loss += args.auxiliary_weight*loss_aux
193 |
194 | loss.backward()
195 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
196 | optimizer.step()
197 |
198 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
199 | n = input.size(0)
200 | objs.update(loss.item(), n)
201 | top1.update(prec1.item(), n)
202 | top5.update(prec5.item(), n)
203 |
204 | if step % args.report_freq == 0:
205 | logging.info('train %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
206 |
207 | return top1.avg, objs.avg
208 |
209 |
210 | def infer(valid_queue, model, criterion):
211 | objs = utils.AvgrageMeter()
212 | top1 = utils.AvgrageMeter()
213 | top5 = utils.AvgrageMeter()
214 | model.eval()
215 |
216 | for step, (input, target) in enumerate(valid_queue):
217 | with torch.no_grad():
218 | input = input.cuda()
219 | target = target.cuda()
220 |
221 | logits, _ = model(input)
222 | loss = criterion(logits, target)
223 |
224 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
225 | n = input.size(0)
226 | objs.update(loss.item(), n)
227 | top1.update(prec1.item(), n)
228 | top5.update(prec5.item(), n)
229 |
230 | if step % args.report_freq == 0:
231 | logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
232 |
233 | return top1.avg, top5.avg, objs.avg
234 |
235 |
236 | if __name__ == '__main__':
237 | main()
238 |
--------------------------------------------------------------------------------
/darts/cnn/train_search.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import time
5 | import glob
6 | import numpy as np
7 | import random
8 | import torch
9 | import darts.cnn.utils as utils
10 | import logging
11 | import torch.nn as nn
12 | import darts.cnn.genotypes
13 | import torch.utils
14 | import torchvision.datasets as dset
15 | import torch.backends.cudnn as cudnn
16 | from collections import namedtuple
17 |
18 | from darts.cnn.model import NetworkCIFAR as Network
19 |
20 | class Train:
21 |
22 | def __init__(self):
23 |
24 | self.data='./data'
25 | self.batch_size= 96
26 | self.learning_rate= 0.025
27 | self.momentum= 0.9
28 | self.weight_decay = 3e-4
29 | self.load_weights = 0
30 | self.report_freq = 500
31 | self.gpu = 0
32 | self.epochs = 50
33 | self.init_channels = 36
34 | self.layers = 20
35 | self.auxiliary = True
36 | self.auxiliary_weight = 0.4
37 | self.cutout = True
38 | self.cutout_length = 16
39 | self.drop_path_prob = 0.2
40 | self.save = 'EXP'
41 | self.seed = 0
42 | self.grad_clip = 5
43 | self.train_portion = 0.9
44 | self.validation_set = True
45 | self.CIFAR_CLASSES = 10
46 |
47 | def main(self, counter, seed, arch, epochs=50, gpu=0, load_weights=False, train_portion=0.9, save='model_search'):
48 |
49 | # Set up save file and logging
50 | self.save = save
51 | self.save = '{}'.format(self.save)
52 | utils.create_exp_dir(self.save, scripts_to_save=glob.glob('*.py'))
53 | log_format = '%(asctime)s %(message)s'
54 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
55 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
56 | fh = logging.FileHandler(os.path.join(self.save, 'log-seed{}.txt'.format(seed)))
57 | fh.setFormatter(logging.Formatter(log_format))
58 | logging.getLogger().addHandler(fh)
59 |
60 |
61 | self.arch = arch
62 | self.epochs = epochs
63 | self.load_weights = load_weights
64 | self.gpu = gpu
65 | self.train_portion = train_portion
66 | if self.train_portion == 1:
67 | self.validation_set = False
68 | self.seed = seed
69 |
70 | #logging.info('Train class params')
71 | #logging.info('arch: {}, epochs: {}, gpu: {}, load_weights: {}, train_portion: {}'
72 | # .format(arch, epochs, gpu, load_weights, train_portion))
73 |
74 | # cpu-gpu switch
75 | if not torch.cuda.is_available():
76 | #logging.info('no gpu device available')
77 | torch.manual_seed(self.seed)
78 | device = torch.device('cpu')
79 |
80 | else:
81 | torch.cuda.manual_seed_all(self.seed)
82 | random.seed(self.seed)
83 | torch.manual_seed(self.seed)
84 | device = torch.device(self.gpu)
85 | cudnn.benchmark = False
86 | cudnn.enabled=True
87 | cudnn.deterministic=True
88 | #logging.info('gpu device = %d' % self.gpu)
89 |
90 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
91 | genotype = eval(self.convert_to_genotype(counter, arch))
92 | model = Network(self.init_channels, self.CIFAR_CLASSES, self.layers, self.auxiliary, genotype)
93 | model = model.to(device)
94 |
95 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
96 | print("param size = {:.4f}MB".format(utils.count_parameters_in_MB(model)))
97 | total_params = sum(x.data.nelement() for x in model.parameters())
98 | logging.info('Model total parameters: {}'.format(total_params))
99 | print('Model total parameters: {}'.format(total_params))
100 |
101 | criterion = nn.CrossEntropyLoss()
102 | criterion = criterion.to(device)
103 | optimizer = torch.optim.SGD(
104 | model.parameters(),
105 | self.learning_rate,
106 | momentum=self.momentum,
107 | weight_decay=self.weight_decay
108 | )
109 |
110 | train_transform, test_transform = utils._data_transforms_cifar10(self.cutout, self.cutout_length)
111 | train_data = dset.CIFAR10(root=self.data, train=True, download=True, transform=train_transform)
112 | test_data = dset.CIFAR10(root=self.data, train=False, download=True, transform=test_transform)
113 |
114 | num_train = len(train_data)
115 | indices = list(range(num_train))
116 | if self.validation_set:
117 | split = int(np.floor(self.train_portion * num_train))
118 | else:
119 | split = num_train
120 |
121 | train_queue = torch.utils.data.DataLoader(
122 | train_data, batch_size=self.batch_size,
123 | sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
124 | pin_memory=True, num_workers=4)
125 |
126 | if self.validation_set:
127 | valid_queue = torch.utils.data.DataLoader(
128 | train_data, batch_size=self.batch_size,
129 | sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
130 | pin_memory=True, num_workers=4)
131 |
132 | test_queue = torch.utils.data.DataLoader(
133 | test_data, batch_size=self.batch_size, shuffle=False, pin_memory=True, num_workers=4)
134 |
135 | if self.load_weights:
136 | logging.info('loading saved weights')
137 | ml = 'cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu'
138 | model.load_state_dict(torch.load('weights.pt', map_location = ml))
139 | logging.info('loaded saved weights')
140 |
141 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(self.epochs))
142 |
143 | valid_accs = []
144 | test_accs = []
145 |
146 | for epoch in range(self.epochs):
147 | logging.info('epoch %d lr %e', epoch, scheduler.get_lr()[0])
148 | print('epoch {} lr {}'.format(epoch, scheduler.get_lr()[0]))
149 | model.drop_path_prob = self.drop_path_prob * epoch / self.epochs
150 |
151 | train_acc, train_obj = self.train(train_queue, model, criterion, optimizer)
152 |
153 | if self.validation_set:
154 | valid_acc, valid_obj = self.infer(valid_queue, model, criterion)
155 | else:
156 | valid_acc, valid_obj = 0, 0
157 |
158 | test_acc, test_obj = self.infer(test_queue, model, criterion, test_data=True)
159 | logging.info('train_acc: {:.4f}, valid_acc: {:.4f}, test_acc: {:.4f}'.format(train_acc, valid_acc, test_acc))
160 | print('train_acc: {:.4f}, valid_acc: {:.4f}, test_acc: {:.4f}'.format(train_acc, valid_acc, test_acc))
161 |
162 | #utils.save(model, os.path.join(self.save, 'weights-seed-{}.pt'.format(seed)))
163 |
164 | if epoch in list(range(max(0, epochs - 5), epochs)):
165 | valid_accs.append((epoch, valid_acc))
166 | test_accs.append((epoch, test_acc))
167 |
168 | scheduler.step()
169 |
170 | return valid_accs, test_accs
171 |
172 |
173 | def convert_to_genotype(self, counter, arch):
174 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
175 | geno = []
176 | for item in arch:
177 | geno.append((item[1], int(item[0])))
178 | geno = Genotype(normal=geno, normal_concat=[2,3,4,5], reduce=geno, reduce_concat=[2,3,4,5])
179 | logging.info('counter: {}, genotypes: {}'.format(counter, str(geno)))
180 | print('counter: {}, genotypes: {}'.format(counter, str(geno)))
181 | return str(geno)
182 |
183 |
184 | def train(self, train_queue, model, criterion, optimizer):
185 | objs = utils.AvgrageMeter()
186 | top1 = utils.AvgrageMeter()
187 | top5 = utils.AvgrageMeter()
188 | model.train()
189 |
190 | for step, (input, target) in enumerate(train_queue):
191 | device = torch.device('cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu')
192 | input = input.to(device)
193 | target = target.to(device)
194 |
195 | optimizer.zero_grad()
196 | logits, logits_aux = model(input)
197 | loss = criterion(logits, target)
198 | if self.auxiliary:
199 | loss_aux = criterion(logits_aux, target)
200 | loss += self.auxiliary_weight*loss_aux
201 | loss.backward()
202 | nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip)
203 | optimizer.step()
204 |
205 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
206 | n = input.size(0)
207 |
208 | objs.update(loss.item(), n)
209 | top1.update(prec1.item(), n)
210 | top5.update(prec5.item(), n)
211 |
212 |
213 | return top1.avg, objs.avg
214 |
215 |
216 | def infer(self, valid_queue, model, criterion, test_data=False):
217 | objs = utils.AvgrageMeter()
218 | top1 = utils.AvgrageMeter()
219 | top5 = utils.AvgrageMeter()
220 | model.eval()
221 | device = torch.device('cuda:{}'.format(self.gpu) if torch.cuda.is_available() else 'cpu')
222 |
223 | for step, (input, target) in enumerate(valid_queue):
224 | with torch.no_grad():
225 | input = input.to(device)
226 | target = target.to(device)
227 |
228 | logits, _ = model(input)
229 | loss = criterion(logits, target)
230 |
231 | prec1, prec5 = utils.accuracy(logits, target, topk=(1, 5))
232 | n = input.size(0)
233 |
234 | objs.update(loss.item(), n)
235 | top1.update(prec1.item(), n)
236 | top5.update(prec5.item(), n)
237 |
238 | #if step % self.report_freq == 0:
239 | # if not test_data:
240 | # logging.info('valid %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
241 | # else:
242 | # logging.info('test %03d %e %f %f', step, objs.avg, top1.avg, top5.avg)
243 |
244 | return top1.avg, objs.avg
245 |
246 |
247 |
--------------------------------------------------------------------------------
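
`Train.main()` above trains a decoded architecture from scratch on CIFAR-10 and returns `(epoch, accuracy)` tuples for the last five epochs. A hedged usage sketch; the architecture encoding below is illustrative only, using the `(input_index, op_name)` pairs that `convert_to_genotype` consumes:

```python
# Hypothetical driver; the 8 (input_index, op_name) pairs form a 4-step cell.
from darts.cnn.train_search import Train

arch = [(0, 'sep_conv_3x3'), (1, 'sep_conv_3x3'),
        (0, 'skip_connect'), (2, 'sep_conv_3x3'),
        (1, 'sep_conv_3x3'), (3, 'skip_connect'),
        (2, 'dil_conv_3x3'), (4, 'max_pool_3x3')]

trainer = Train()
valid_accs, test_accs = trainer.main(counter=0, seed=0, arch=arch,
                                     epochs=50, gpu=0, save='model_search')
```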
/darts/cnn/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import shutil
5 | import torchvision.transforms as transforms
6 |
7 |
8 | class AvgrageMeter(object):
9 |
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | self.avg = 0
15 | self.sum = 0
16 | self.cnt = 0
17 |
18 | def update(self, val, n=1):
19 | self.sum += val * n
20 | self.cnt += n
21 | self.avg = self.sum / self.cnt
22 |
23 |
24 | def accuracy(output, target, topk=(1,)):
25 | maxk = max(topk)
26 | batch_size = target.size(0)
27 |
28 | _, pred = output.topk(maxk, 1, True, True)
29 | pred = pred.t()
30 | correct = pred.eq(target.view(1, -1).expand_as(pred))
31 |
32 | res = []
33 | for k in topk:
34 | correct_k = correct[:k].view(-1).float().sum(0)
35 | res.append(correct_k.mul_(100.0/batch_size))
36 | return res
37 |
38 |
39 | class Cutout(object):
40 | def __init__(self, length):
41 | self.length = length
42 |
43 | def __call__(self, img):
44 | h, w = img.size(1), img.size(2)
45 | mask = np.ones((h, w), np.float32)
46 | y = np.random.randint(h)
47 | x = np.random.randint(w)
48 |
49 | y1 = np.clip(y - self.length // 2, 0, h)
50 | y2 = np.clip(y + self.length // 2, 0, h)
51 | x1 = np.clip(x - self.length // 2, 0, w)
52 | x2 = np.clip(x + self.length // 2, 0, w)
53 |
54 | mask[y1: y2, x1: x2] = 0.
55 | mask = torch.from_numpy(mask)
56 | mask = mask.expand_as(img)
57 | img *= mask
58 | return img
59 |
60 |
61 | def _data_transforms_cifar10(cutout, cutout_length):
62 | CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
63 | CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]
64 |
65 | train_transform = transforms.Compose([
66 | transforms.RandomCrop(32, padding=4),
67 | transforms.RandomHorizontalFlip(),
68 | transforms.ToTensor(),
69 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
70 | ])
71 | if cutout:
72 | train_transform.transforms.append(Cutout(cutout_length))
73 |
74 | valid_transform = transforms.Compose([
75 | transforms.ToTensor(),
76 | transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
77 | ])
78 | return train_transform, valid_transform
79 |
80 |
81 | def count_parameters_in_MB(model):
82 |   return np.sum([np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name])/1e6
83 |
84 |
85 | def save_checkpoint(state, is_best, save):
86 | filename = os.path.join(save, 'checkpoint.pth.tar')
87 | torch.save(state, filename)
88 | if is_best:
89 | best_filename = os.path.join(save, 'model_best.pth.tar')
90 | shutil.copyfile(filename, best_filename)
91 |
92 |
93 | def save(model, model_path):
94 | torch.save(model.state_dict(), model_path)
95 |
96 |
97 | def load(model, model_path, gpu_id):
98 | ml = 'cuda:{}'.format(gpu_id) if torch.cuda.is_available() else 'cpu'
99 | model.load_state_dict(torch.load(model_path, map_location = ml), strict=False)
100 |
101 |
102 |
103 | def drop_path(x, drop_prob):
104 | if drop_prob > 0.:
105 | keep_prob = 1.-drop_prob
106 | mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
107 | x.div_(keep_prob)
108 | x.mul_(mask)
109 | return x
110 |
111 |
112 | def create_exp_dir(path, scripts_to_save=None):
113 | if not os.path.exists(path):
114 | os.mkdir(path)
115 | print('Experiment dir : {}'.format(path))
116 |
117 | if scripts_to_save is not None:
118 | os.mkdir(os.path.join(path, 'scripts'))
119 | for script in scripts_to_save:
120 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
121 | shutil.copyfile(script, dst_file)
122 |
123 |
--------------------------------------------------------------------------------
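
A quick self-contained sanity check for the two metric helpers above (`AvgrageMeter` keeps a running weighted average; `accuracy` returns top-k precision in percent):

```python
import torch
from darts.cnn.utils import AvgrageMeter, accuracy

logits = torch.randn(8, 10)              # batch of 8, 10 classes
target = torch.randint(0, 10, (8,))
prec1, prec5 = accuracy(logits, target, topk=(1, 5))

top1 = AvgrageMeter()
top1.update(prec1.item(), n=logits.size(0))
print(top1.avg)                          # running top-1 accuracy in percent
```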
/darts/cnn/visualize.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import genotypes
3 | from graphviz import Digraph
4 |
5 |
6 | def plot(genotype, filename):
7 | g = Digraph(
8 | format='pdf',
9 | edge_attr=dict(fontsize='20', fontname="times"),
10 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
11 | engine='dot')
12 | g.body.extend(['rankdir=LR'])
13 |
14 | g.node("c_{k-2}", fillcolor='darkseagreen2')
15 | g.node("c_{k-1}", fillcolor='darkseagreen2')
16 | assert len(genotype) % 2 == 0
17 | steps = len(genotype) // 2
18 |
19 | for i in range(steps):
20 | g.node(str(i), fillcolor='lightblue')
21 |
22 | for i in range(steps):
23 | for k in [2*i, 2*i + 1]:
24 | op, j = genotype[k]
25 | if j == 0:
26 | u = "c_{k-2}"
27 | elif j == 1:
28 | u = "c_{k-1}"
29 | else:
30 | u = str(j-2)
31 | v = str(i)
32 | g.edge(u, v, label=op, fillcolor="gray")
33 |
34 | g.node("c_{k}", fillcolor='palegoldenrod')
35 | for i in range(steps):
36 | g.edge(str(i), "c_{k}", fillcolor="gray")
37 |
38 | g.render(filename, view=True)
39 |
40 |
41 | if __name__ == '__main__':
42 | if len(sys.argv) != 2:
43 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
44 | sys.exit(1)
45 |
46 | genotype_name = sys.argv[1]
47 | try:
48 | genotype = eval('genotypes.{}'.format(genotype_name))
49 | except AttributeError:
50 | print("{} is not specified in genotypes.py".format(genotype_name))
51 | sys.exit(1)
52 |
53 | plot(genotype.normal, "normal & reduction")
54 |
55 |
--------------------------------------------------------------------------------
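
`plot()` above expects an even-length list of `(op, input_index)` pairs, where indices 0 and 1 refer to `c_{k-2}` and `c_{k-1}` and larger indices refer to earlier intermediate nodes. A hypothetical two-step cell (not taken from `genotypes.py`; rendering requires graphviz):

```python
# Illustrative genotype only; op names follow the DARTS convention.
cell = [('sep_conv_3x3', 0), ('sep_conv_3x3', 1),
        ('skip_connect', 0), ('sep_conv_3x3', 2)]
# plot(cell, "example_cell")  # writes example_cell.pdf and opens a viewer
```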
/docs/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/docs/arch.png
--------------------------------------------------------------------------------
/gin/models/graphcnn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import sys
6 | sys.path.append("models/")
7 | from gin.models.mlp import MLP
8 |
9 | class GraphCNN(nn.Module):
10 | def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, final_dropout, learn_eps, graph_pooling_type, neighbor_pooling_type, device):
11 | '''
12 | num_layers: number of layers in the neural networks (INCLUDING the input layer)
13 | num_mlp_layers: number of layers in mlps (EXCLUDING the input layer)
14 | input_dim: dimensionality of input features
15 | hidden_dim: dimensionality of hidden units at ALL layers
16 | output_dim: number of classes for prediction
17 | final_dropout: dropout ratio on the final linear layer
18 | learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether.
19 |             neighbor_pooling_type: how to aggregate neighbors (sum, average, or max)
20 |             graph_pooling_type: how to aggregate all nodes in a graph (sum or average)
21 | device: which device to use
22 | '''
23 |
24 | super(GraphCNN, self).__init__()
25 |
26 | self.final_dropout = final_dropout
27 | self.device = device
28 | self.num_layers = num_layers
29 | self.graph_pooling_type = graph_pooling_type
30 | self.neighbor_pooling_type = neighbor_pooling_type
31 | self.learn_eps = learn_eps
32 | self.eps = nn.Parameter(torch.zeros(self.num_layers-1))
33 |
34 | ###List of MLPs
35 | self.mlps = torch.nn.ModuleList()
36 |
37 | ###List of batchnorms applied to the output of MLP (input of the final prediction linear layer)
38 | self.batch_norms = torch.nn.ModuleList()
39 |
40 | for layer in range(self.num_layers-1):
41 | if layer == 0:
42 | self.mlps.append(MLP(num_mlp_layers, input_dim, hidden_dim, hidden_dim))
43 | else:
44 | self.mlps.append(MLP(num_mlp_layers, hidden_dim, hidden_dim, hidden_dim))
45 |
46 | self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
47 |
48 |         #Linear function that maps the hidden representation at different layers into a prediction score
49 | self.linears_prediction = torch.nn.ModuleList()
50 | for layer in range(num_layers):
51 | if layer == 0:
52 | self.linears_prediction.append(nn.Linear(input_dim, output_dim))
53 | else:
54 | self.linears_prediction.append(nn.Linear(hidden_dim, output_dim))
55 |
56 |
57 | def __preprocess_neighbors_maxpool(self, batch_graph):
58 | ###create padded_neighbor_list in concatenated graph
59 |
60 | #compute the maximum number of neighbors within the graphs in the current minibatch
61 | max_deg = max([graph.max_neighbor for graph in batch_graph])
62 |
63 | padded_neighbor_list = []
64 | start_idx = [0]
65 |
66 |
67 | for i, graph in enumerate(batch_graph):
68 | start_idx.append(start_idx[i] + len(graph.g))
69 | padded_neighbors = []
70 | for j in range(len(graph.neighbors)):
71 | #add off-set values to the neighbor indices
72 | pad = [n + start_idx[i] for n in graph.neighbors[j]]
73 |                 #padding: dummy neighbor slots are filled with -1
74 | pad.extend([-1]*(max_deg - len(pad)))
75 |
76 | #Add center nodes in the maxpooling if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
77 | if not self.learn_eps:
78 | pad.append(j + start_idx[i])
79 |
80 | padded_neighbors.append(pad)
81 | padded_neighbor_list.extend(padded_neighbors)
82 |
83 | return torch.LongTensor(padded_neighbor_list)
84 |
85 |
86 | def __preprocess_neighbors_sumavepool(self, batch_graph):
87 | ###create block diagonal sparse matrix
88 |
89 | edge_mat_list = []
90 | start_idx = [0]
91 | for i, graph in enumerate(batch_graph):
92 | start_idx.append(start_idx[i] + len(graph.g))
93 | edge_mat_list.append(graph.edge_mat + start_idx[i])
94 | Adj_block_idx = torch.cat(edge_mat_list, 1)
95 | Adj_block_elem = torch.ones(Adj_block_idx.shape[1])
96 |
97 | #Add self-loops in the adjacency matrix if learn_eps is False, i.e., aggregate center nodes and neighbor nodes altogether.
98 |
99 | if not self.learn_eps:
100 | num_node = start_idx[-1]
101 | self_loop_edge = torch.LongTensor([range(num_node), range(num_node)])
102 | elem = torch.ones(num_node)
103 | Adj_block_idx = torch.cat([Adj_block_idx, self_loop_edge], 1)
104 | Adj_block_elem = torch.cat([Adj_block_elem, elem], 0)
105 |
106 | Adj_block = torch.sparse.FloatTensor(Adj_block_idx, Adj_block_elem, torch.Size([start_idx[-1],start_idx[-1]]))
107 |
108 | return Adj_block.to(self.device)
109 |
110 |
111 | def __preprocess_graphpool(self, batch_graph):
112 |         ###create sum or average pooling sparse matrix over all nodes in each graph (num graphs x num nodes)
113 |
114 | start_idx = [0]
115 |
116 | #compute the padded neighbor list
117 | for i, graph in enumerate(batch_graph):
118 | start_idx.append(start_idx[i] + len(graph.g))
119 |
120 | idx = []
121 | elem = []
122 | for i, graph in enumerate(batch_graph):
123 | ###average pooling
124 | if self.graph_pooling_type == "average":
125 | elem.extend([1./len(graph.g)]*len(graph.g))
126 |
127 | else:
128 | ###sum pooling
129 | elem.extend([1]*len(graph.g))
130 |
131 | idx.extend([[i, j] for j in range(start_idx[i], start_idx[i+1], 1)])
132 | elem = torch.FloatTensor(elem)
133 | idx = torch.LongTensor(idx).transpose(0,1)
134 | graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))
135 |
136 | return graph_pool.to(self.device)
137 |
138 | def maxpool(self, h, padded_neighbor_list):
139 | ###Element-wise minimum will never affect max-pooling
140 |
141 | dummy = torch.min(h, dim = 0)[0]
142 | h_with_dummy = torch.cat([h, dummy.reshape((1, -1)).to(self.device)])
143 | pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]
144 | return pooled_rep
145 |
146 |
147 | def next_layer_eps(self, h, layer, padded_neighbor_list = None, Adj_block = None):
148 | ###pooling neighboring nodes and center nodes separately by epsilon reweighting.
149 |
150 | if self.neighbor_pooling_type == "max":
151 | ##If max pooling
152 | pooled = self.maxpool(h, padded_neighbor_list)
153 | else:
154 | #If sum or average pooling
155 | pooled = torch.spmm(Adj_block, h)
156 | if self.neighbor_pooling_type == "average":
157 | #If average pooling
158 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
159 | pooled = pooled/degree
160 |
161 | #Reweights the center node representation when aggregating it with its neighbors
162 | pooled = pooled + (1 + self.eps[layer])*h
163 | pooled_rep = self.mlps[layer](pooled)
164 | h = self.batch_norms[layer](pooled_rep)
165 |
166 | #non-linearity
167 | h = F.relu(h)
168 | return h
169 |
170 |
171 | def next_layer(self, h, layer, padded_neighbor_list = None, Adj_block = None):
172 | ###pooling neighboring nodes and center nodes altogether
173 |
174 | if self.neighbor_pooling_type == "max":
175 | ##If max pooling
176 | pooled = self.maxpool(h, padded_neighbor_list)
177 | else:
178 | #If sum or average pooling
179 | pooled = torch.spmm(Adj_block, h)
180 | if self.neighbor_pooling_type == "average":
181 | #If average pooling
182 | degree = torch.spmm(Adj_block, torch.ones((Adj_block.shape[0], 1)).to(self.device))
183 | pooled = pooled/degree
184 |
185 | #representation of neighboring and center nodes
186 | pooled_rep = self.mlps[layer](pooled)
187 |
188 | h = self.batch_norms[layer](pooled_rep)
189 |
190 | #non-linearity
191 | h = F.relu(h)
192 | return h
193 |
194 |
195 | def forward(self, batch_graph):
196 | X_concat = torch.cat([graph.node_features for graph in batch_graph], 0).to(self.device)
197 | graph_pool = self.__preprocess_graphpool(batch_graph)
198 |
199 | if self.neighbor_pooling_type == "max":
200 | padded_neighbor_list = self.__preprocess_neighbors_maxpool(batch_graph)
201 | else:
202 | Adj_block = self.__preprocess_neighbors_sumavepool(batch_graph)
203 |
204 | #list of hidden representation at each layer (including input)
205 | hidden_rep = [X_concat]
206 | h = X_concat
207 |
208 | for layer in range(self.num_layers-1):
209 | if self.neighbor_pooling_type == "max" and self.learn_eps:
210 | h = self.next_layer_eps(h, layer, padded_neighbor_list = padded_neighbor_list)
211 | elif not self.neighbor_pooling_type == "max" and self.learn_eps:
212 | h = self.next_layer_eps(h, layer, Adj_block = Adj_block)
213 | elif self.neighbor_pooling_type == "max" and not self.learn_eps:
214 | h = self.next_layer(h, layer, padded_neighbor_list = padded_neighbor_list)
215 | elif not self.neighbor_pooling_type == "max" and not self.learn_eps:
216 | h = self.next_layer(h, layer, Adj_block = Adj_block)
217 |
218 | hidden_rep.append(h)
219 |
220 | score_over_layer = 0
221 |
222 | #perform pooling over all nodes in each graph in every layer
223 | for layer, h in enumerate(hidden_rep):
224 | pooled_h = torch.spmm(graph_pool, h)
225 | score_over_layer += F.dropout(self.linears_prediction[layer](pooled_h), self.final_dropout, training = self.training)
226 |
227 | return score_over_layer
228 |
--------------------------------------------------------------------------------
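
`GraphCNN.forward()` above only reads a few attributes from each graph object (`node_features`, `g`, and `edge_mat` for sum/average pooling). A hedged sketch that drives it with a minimal stub graph; the `SimpleNamespace` stand-in is ours, not part of the repo:

```python
import torch
from types import SimpleNamespace
from gin.models.graphcnn import GraphCNN

device = torch.device('cpu')
# 3-node path graph: undirected edges 0-1 and 1-2, stored as directed pairs.
graph = SimpleNamespace(g=range(3),
                        node_features=torch.randn(3, 5),
                        edge_mat=torch.LongTensor([[0, 1, 1, 2],
                                                   [1, 0, 2, 1]]))

model = GraphCNN(num_layers=3, num_mlp_layers=2, input_dim=5, hidden_dim=16,
                 output_dim=2, final_dropout=0.5, learn_eps=True,
                 graph_pooling_type='sum', neighbor_pooling_type='sum',
                 device=device)
scores = model([graph])                  # -> shape (1, 2)
```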
/gin/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | ###MLP with linear output
6 | class MLP(nn.Module):
7 | def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
8 | '''
9 |             num_layers: number of layers in the neural networks (EXCLUDING the input layer). If num_layers=1, this reduces to a linear model.
10 | input_dim: dimensionality of input features
11 | hidden_dim: dimensionality of hidden units at ALL layers
12 | output_dim: number of classes for prediction
13 |
14 | '''
15 |
16 | super(MLP, self).__init__()
17 |
18 | self.linear_or_not = True #default is linear model
19 | self.num_layers = num_layers
20 |
21 | if num_layers < 1:
22 | raise ValueError("number of layers should be positive!")
23 | elif num_layers == 1:
24 | #Linear model
25 | self.linear = nn.Linear(input_dim, output_dim)
26 | else:
27 | #Multi-layer model
28 | self.linear_or_not = False
29 | self.linears = torch.nn.ModuleList()
30 | self.batch_norms = torch.nn.ModuleList()
31 |
32 | self.linears.append(nn.Linear(input_dim, hidden_dim))
33 | for layer in range(num_layers - 2):
34 | self.linears.append(nn.Linear(hidden_dim, hidden_dim))
35 | self.linears.append(nn.Linear(hidden_dim, output_dim))
36 |
37 | for layer in range(num_layers - 1):
38 | self.batch_norms.append(nn.BatchNorm1d((hidden_dim)))
39 |
40 | def forward(self, x):
41 | if self.linear_or_not:
42 | #If linear model
43 | return self.linear(x)
44 | else:
45 | #If MLP
46 | h = x
47 | for layer in range(self.num_layers - 1):
48 | h = F.relu(self.batch_norms[layer](self.linears[layer](h)))
49 | return self.linears[self.num_layers - 1](h)
--------------------------------------------------------------------------------
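
A quick shape check for `MLP` above (batch size must exceed 1 in training mode because of `BatchNorm1d`):

```python
import torch
from gin.models.mlp import MLP

mlp = MLP(num_layers=2, input_dim=5, hidden_dim=16, output_dim=3)
out = mlp(torch.randn(4, 5))             # -> shape (4, 3)
```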
/models/configs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | configs = [{'GAE': # 0
6 | {'activation_ops':torch.sigmoid},
7 | 'loss':
8 | {'loss_ops':F.mse_loss, 'loss_adj':F.mse_loss},
9 | 'prep':
10 | {'method':3, 'lbd':0.5}
11 | },
12 | {'GAE': # 1
13 | {'activation_ops':torch.softmax},
14 | 'loss':
15 | {'loss_ops':nn.BCELoss(), 'loss_adj':nn.BCELoss()},
16 | 'prep':
17 | {'method':3, 'lbd':0.5}
18 | },
19 | {'GAE': # 2
20 | {'activation_ops': torch.softmax},
21 | 'loss':
22 | {'loss_ops': F.mse_loss, 'loss_adj': nn.BCELoss()},
23 | 'prep':
24 | {'method':3, 'lbd':0.5}
25 | },
26 | {'GAE':# 3
27 | {'activation_ops':torch.sigmoid},
28 | 'loss':
29 | {'loss_ops':F.mse_loss, 'loss_adj':F.mse_loss},
30 | 'prep':
31 | {'method':4, 'lbd':1.0}
32 | },
33 | {'GAE': # 4
34 | {'activation_adj': torch.sigmoid, 'activation_ops': torch.softmax, 'adj_hidden_dim': 128, 'ops_hidden_dim': 128},
35 | 'loss':
36 | {'loss_ops': nn.BCELoss(), 'loss_adj': nn.BCELoss()},
37 | 'prep':
38 | {'method': 4, 'lbd': 1.0}
39 | },
40 | ]
41 |
--------------------------------------------------------------------------------
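
Each config entry above bundles decoder activations, reconstruction losses, and preprocessing options; the pretraining scripts unpack them as keyword arguments (`**cfg['GAE']`, `**cfg['loss']`, `**cfg['prep']`). A minimal access sketch:

```python
from models.configs import configs

cfg = configs[4]          # the configuration used by the pretraining scripts
print(cfg['GAE'])         # decoder activations and hidden dimensions
print(cfg['prep'])        # {'method': 4, 'lbd': 1.0}
```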
/models/layers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import torch.nn as nn
5 | from torch.nn.parameter import Parameter
6 |
7 | class GraphConvolution(nn.Module):
8 | def __init__(self, in_features, out_features, dropout=0., bias=True):
9 | super(GraphConvolution, self).__init__()
10 | self.in_features = in_features
11 | self.out_features = out_features
12 | self.weight = Parameter(torch.Tensor(out_features, in_features))
13 | if bias:
14 | self.bias = Parameter(torch.Tensor(out_features))
15 | else:
16 | self.register_parameter('bias', None)
17 | self.reset_parameters()
18 | self.dropout = dropout
19 |
20 | def reset_parameters(self):
21 | stdv = 1. / math.sqrt(self.weight.size(1))
22 | torch.nn.init.kaiming_uniform_(self.weight)
23 | if self.bias is not None:
24 | self.bias.data.uniform_(-stdv, stdv)
25 |
26 | def forward(self, ops, adj):
27 | ops = F.dropout(ops, self.dropout, self.training)
28 | support = F.linear(ops, self.weight)
29 | output = F.relu(torch.matmul(adj, support))
30 |
31 | if self.bias is not None:
32 | return output + self.bias
33 | else:
34 | return output
35 |
36 | def __repr__(self):
37 | return self.__class__.__name__ + '(' + str(self.in_features) + '->' + str(self.out_features) + ')'
38 |
--------------------------------------------------------------------------------
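
A hedged shape sketch for `GraphConvolution` above: `ops` is `(batch, nodes, in_features)` and `adj` is `(batch, nodes, nodes)`, matching the batched `torch.matmul` in `forward()`:

```python
import torch
from models.layers import GraphConvolution

gc = GraphConvolution(in_features=5, out_features=8)
ops = torch.randn(2, 7, 5)               # 2 graphs, 7 nodes, 5-dim features
adj = torch.rand(2, 7, 7)                # batched adjacency matrices
out = gc(ops, adj)                       # -> shape (2, 7, 8)
```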
/models/pretraining_darts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | import argparse
8 | from nasbench.lib import graph_util
9 | from torch import optim
10 | from models.model import Model, VAEReconstructed_Loss
11 | from utils.utils import load_json, save_checkpoint_vae, preprocessing, one_hot_darts, to_ops_darts
12 | from utils.utils import get_val_acc_vae, is_valid_darts
13 | from models.configs import configs
14 |
15 |
16 | def process(geno):
17 | for i, item in enumerate(geno):
18 | geno[i] = tuple(geno[i])
19 | return geno
20 |
21 | def _build_dataset(dataset):
22 | print(""" loading dataset """)
23 | X_adj = []
24 | X_ops = []
25 | for k, v in dataset.items():
26 | adj = v[0]
27 | ops = v[1]
28 | X_adj.append(torch.Tensor(adj))
29 | X_ops.append(torch.Tensor(one_hot_darts(ops)))
30 |
31 | X_adj = torch.stack(X_adj)
32 | X_ops = torch.stack(X_ops)
33 |
34 | X_adj_train, X_adj_val = X_adj[:int(X_adj.shape[0]*0.9)], X_adj[int(X_adj.shape[0]*0.9):]
35 | X_ops_train, X_ops_val = X_ops[:int(X_ops.shape[0]*0.9)], X_ops[int(X_ops.shape[0]*0.9):]
36 | indices = torch.randperm(X_adj_train.shape[0])
37 | indices_val = torch.randperm(X_adj_val.shape[0])
38 | X_adj = X_adj_train[indices]
39 | X_ops = X_ops_train[indices]
40 | X_adj_val = X_adj_val[indices_val]
41 | X_ops_val = X_ops_val[indices_val]
42 |
43 | return X_adj, X_ops, indices, X_adj_val, X_ops_val, indices_val
44 |
45 |
46 | def pretraining_gae(dataset, cfg):
47 | """ implementation of VGAE pretraining on DARTS Search Space """
48 | X_adj, X_ops, indices, X_adj_val, X_ops_val, indices_val = _build_dataset(dataset)
49 | print('train set size: {}, validation set size: {}'.format(indices.shape[0], indices_val.shape[0]))
50 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.dim,
51 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda()
52 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08)
53 | epochs = args.epochs
54 | bs = args.bs
55 | loss_total = []
56 | best_graph_acc = 0
57 | for epoch in range(0, epochs):
58 | chunks = X_adj.shape[0] // bs
59 | if X_adj.shape[0] % bs > 0:
60 | chunks += 1
61 | X_adj_split = torch.split(X_adj, bs, dim=0)
62 | X_ops_split = torch.split(X_ops, bs, dim=0)
63 | indices_split = torch.split(indices, bs, dim=0)
64 | loss_epoch = []
65 | Z = []
66 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)):
67 | optimizer.zero_grad()
68 | adj, ops = adj.cuda(), ops.cuda()
69 | # preprocessing
70 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
71 | # forward
72 | ops_recon, adj_recon, mu, logvar = model(ops, adj)
73 | Z.append(mu)
74 | adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon)
75 | adj, ops = prep_reverse(adj, ops)
76 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar)
77 | loss.backward()
78 | nn.utils.clip_grad_norm_(model.parameters(), 5)
79 | optimizer.step()
80 | loss_epoch.append(loss.item())
81 | if i % 500 == 0:
82 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item()))
83 | Z = torch.cat(Z, dim=0)
84 | z_mean, z_std = Z.mean(0), Z.std(0)
85 | validity_counter = 0
86 | buckets = {}
87 | model.eval()
88 | for _ in range(args.latent_points):
89 | z = torch.randn(11, args.dim).cuda()
90 | z = z * z_std + z_mean
91 | op, ad = model.decoder(z.unsqueeze(0))
92 | op = op.squeeze(0).cpu()
93 | ad = ad.squeeze(0).cpu()
94 | max_idx = torch.argmax(op, dim=-1)
95 | one_hot = torch.zeros_like(op)
96 | for i in range(one_hot.shape[0]):
97 | one_hot[i][max_idx[i]] = 1
98 | op_decode = to_ops_darts(max_idx)
99 | ad_decode = (ad>0.5).int().triu(1).numpy()
100 | ad_decode = np.ndarray.tolist(ad_decode)
101 | if is_valid_darts(ad_decode, op_decode):
102 | validity_counter += 1
103 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist())
104 | if fingerprint not in buckets:
105 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist())
106 | validity = validity_counter / args.latent_points
107 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity))
108 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8)))
109 |
110 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model,cfg,X_adj_val, X_ops_val,indices_val)
111 | print('validation set: acc_ops:{0:.2f}, mean_corr_adj:{1:.2f}, mean_fal_pos_adj:{2:.2f}, acc_adj:{3:.2f}'.format(
112 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val))
113 |
114 | #print("reconstructed adj matrix:", adj_recon[1])
115 | #print("original adj matrix:", adj[1])
116 | #print("reconstructed ops matrix:", ops_recon[1])
117 | #print("original ops matrix:", ops[1])
118 |
119 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch)))
120 | loss_total.append(sum(loss_epoch) / len(loss_epoch))
121 | print('loss for epochs: \n', loss_total)
122 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.dim, args.name, args.dropout, args.seed)
123 |
124 |
125 | print('loss for epochs: ', loss_total)
126 |
127 |
128 | if __name__ == '__main__':
129 | parser = argparse.ArgumentParser(description='Pretraining')
130 | parser.add_argument("--seed", type=int, default=3, help="random seed")
131 | parser.add_argument('--data', type=str, default='data/data_darts_counter600000.json',
132 |                         help='Data file (default: data/data_darts_counter600000.json)')
133 | parser.add_argument('--name', type=str, default='darts')
134 | parser.add_argument('--cfg', type=int, default=4,
135 | help='configuration (default: 4)')
136 | parser.add_argument('--bs', type=int, default=32,
137 | help='batch size (default: 32)')
138 | parser.add_argument('--epochs', type=int, default=10,
139 | help='training epochs (default: 10)')
140 | parser.add_argument('--dropout', type=float, default=0.3,
141 | help='decoder implicit regularization (default: 0.3)')
142 | parser.add_argument('--normalize', action='store_true', default=True,
143 | help='use input normalization')
144 | parser.add_argument('--input_dim', type=int, default=11)
145 | parser.add_argument('--hidden_dim', type=int, default=128)
146 | parser.add_argument('--dim', type=int, default=16,
147 | help='feature dimension (default: 16)')
148 | parser.add_argument('--hops', type=int, default=5)
149 | parser.add_argument('--mlps', type=int, default=2)
150 | parser.add_argument('--latent_points', type=int, default=10000,
151 |                         help='latent points for validity check (default: 10000)')
152 | args = parser.parse_args()
153 | cfg = configs[args.cfg]
154 | dataset = load_json(args.data)
155 | print('using {}'.format(args.data))
156 | print('feat dim {}'.format(args.dim))
157 | pretraining_gae(dataset, cfg)
158 |
--------------------------------------------------------------------------------
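
The validity check above samples latent points from the prior and rescales them to the empirical statistics of the training embeddings before decoding. A sketch of that step in isolation (the function name is ours):

```python
import torch

def sample_latent(z_mean, z_std, num_nodes=11, dim=16):
    """Draw z ~ N(0, I) per node, then match the training embedding statistics."""
    z = torch.randn(num_nodes, dim)
    return z * z_std + z_mean

# z_mean and z_std are per-dimension statistics, e.g. Z.mean(0) and Z.std(0).
```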
/models/pretraining_darts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python models/pretraining_darts.py --cfg 4 --bs 32 --epochs 10 --hidden_dim 128 --dim 16 --data data/data_darts_counter600000.json --name darts
3 |
--------------------------------------------------------------------------------
/models/pretraining_nasbench101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch import optim
8 | from models.model import Model, VAEReconstructed_Loss
9 | from utils.utils import load_json, save_checkpoint_vae, preprocessing
10 | from utils.utils import get_val_acc_vae
11 | from models.configs import configs
12 | import argparse
13 | from nasbench import api
14 | from nasbench.lib import graph_util
15 |
16 | def transform_operations(max_idx):
17 | transform_dict = {0:'input', 1:'conv1x1-bn-relu', 2:'conv3x3-bn-relu', 3:'maxpool3x3', 4:'output'}
18 | ops = []
19 | for idx in max_idx:
20 | ops.append(transform_dict[idx.item()])
21 | return ops
22 |
23 | def _build_dataset(dataset, ind_list):
24 |     indices = np.random.permutation(ind_list)
25 | X_adj = []
26 | X_ops = []
27 | for ind in indices:
28 | X_adj.append(torch.Tensor(dataset[str(ind)]['module_adjacency']))
29 | X_ops.append(torch.Tensor(dataset[str(ind)]['module_operations']))
30 | X_adj = torch.stack(X_adj)
31 | X_ops = torch.stack(X_ops)
32 | return X_adj, X_ops, torch.Tensor(indices)
33 |
34 |
35 | def pretraining_model(dataset, cfg, args):
36 | nasbench = api.NASBench('data/nasbench_only108.tfrecord')
37 | train_ind_list, val_ind_list = range(int(len(dataset)*0.9)), range(int(len(dataset)*0.9), len(dataset))
38 | X_adj_train, X_ops_train, indices_train = _build_dataset(dataset, train_ind_list)
39 | X_adj_val, X_ops_val, indices_val = _build_dataset(dataset, val_ind_list)
40 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.dim,
41 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda()
42 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08)
43 | epochs = args.epochs
44 | bs = args.bs
45 | loss_total = []
46 | for epoch in range(0, epochs):
47 | chunks = len(train_ind_list) // bs
48 | if len(train_ind_list) % bs > 0:
49 | chunks += 1
50 | X_adj_split = torch.split(X_adj_train, bs, dim=0)
51 | X_ops_split = torch.split(X_ops_train, bs, dim=0)
52 | indices_split = torch.split(indices_train, bs, dim=0)
53 | loss_epoch = []
54 | Z = []
55 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)):
56 | optimizer.zero_grad()
57 | adj, ops = adj.cuda(), ops.cuda()
58 | # preprocessing
59 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
60 | # forward
61 | ops_recon, adj_recon, mu, logvar = model(ops, adj.to(torch.long))
62 | Z.append(mu)
63 | adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon)
64 | adj, ops = prep_reverse(adj, ops)
65 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar)
66 | loss.backward()
67 | nn.utils.clip_grad_norm_(model.parameters(), 5)
68 | optimizer.step()
69 | loss_epoch.append(loss.item())
70 | if i%1000==0:
71 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item()))
72 | Z = torch.cat(Z, dim=0)
73 | z_mean, z_std = Z.mean(0), Z.std(0)
74 | validity_counter = 0
75 | buckets = {}
76 | model.eval()
77 | for _ in range(args.latent_points):
78 | z = torch.randn(7, args.dim).cuda()
79 | z = z * z_std + z_mean
80 | op, ad = model.decoder(z.unsqueeze(0))
81 | op = op.squeeze(0).cpu()
82 | ad = ad.squeeze(0).cpu()
83 | max_idx = torch.argmax(op, dim=-1)
84 | one_hot = torch.zeros_like(op)
85 | for i in range(one_hot.shape[0]):
86 | one_hot[i][max_idx[i]] = 1
87 | op_decode = transform_operations(max_idx)
88 | ad_decode = (ad>0.5).int().triu(1).numpy()
89 | ad_decode = np.ndarray.tolist(ad_decode)
90 | spec = api.ModelSpec(matrix=ad_decode, ops=op_decode)
91 | if nasbench.is_valid(spec):
92 | validity_counter += 1
93 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist())
94 | if fingerprint not in buckets:
95 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist())
96 | validity = validity_counter / args.latent_points
97 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity))
98 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8)))
99 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val)
100 | print('validation set: acc_ops:{0:.4f}, mean_corr_adj:{1:.4f}, mean_fal_pos_adj:{2:.4f}, acc_adj:{3:.4f}'.format(
101 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val))
102 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch)))
103 | loss_total.append(sum(loss_epoch) / len(loss_epoch))
104 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.dim, args.name, args.dropout, args.seed)
105 | print('loss for epochs: \n', loss_total)
106 |
107 |
108 |
109 | if __name__ == '__main__':
110 | parser = argparse.ArgumentParser(description='Pretraining')
111 | parser.add_argument("--seed", type=int, default=1, help="random seed")
112 | parser.add_argument('--data', type=str, default='data/data.json',
113 |                         help='Data file (default: data/data.json)')
114 | parser.add_argument('--name', type=str, default='nasbench-101',
115 | help='nasbench-101/nasbench-201/darts')
116 | parser.add_argument('--cfg', type=int, default=4,
117 | help='configuration (default: 4)')
118 | parser.add_argument('--bs', type=int, default=32,
119 | help='batch size (default: 32)')
120 | parser.add_argument('--epochs', type=int, default=8,
121 | help='training epochs (default: 8)')
122 | parser.add_argument('--dropout', type=float, default=0.3,
123 | help='decoder implicit regularization (default: 0.3)')
124 | parser.add_argument('--normalize', action='store_true', default=True,
125 | help='use input normalization')
126 | parser.add_argument('--input_dim', type=int, default=5)
127 | parser.add_argument('--hidden_dim', type=int, default=128)
128 | parser.add_argument('--dim', type=int, default=16,
129 | help='feature dimension (default: 16)')
130 | parser.add_argument('--hops', type=int, default=5)
131 | parser.add_argument('--mlps', type=int, default=2)
132 | parser.add_argument('--latent_points', type=int, default=10000,
133 |                         help='latent points for validity check (default: 10000)')
134 | args = parser.parse_args()
135 | np.random.seed(args.seed)
136 | torch.manual_seed(args.seed)
137 | torch.cuda.manual_seed_all(args.seed)
138 | cfg = configs[args.cfg]
139 | dataset = load_json(args.data)
140 | print('using {}'.format(args.data))
141 | print('feat dim {}'.format(args.dim))
142 | pretraining_model(dataset, cfg, args)
143 |
--------------------------------------------------------------------------------
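
The decoder output above is turned into a graph by thresholding edge probabilities at 0.5 and keeping only the strict upper triangle, which enforces a DAG. The same step in isolation:

```python
import torch

ad = torch.rand(7, 7)                    # decoder edge probabilities, 7 nodes
ad_decode = (ad > 0.5).int().triu(1).numpy().tolist()
```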
/models/pretraining_nasbench101.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python models/pretraining_nasbench101.py --dim 16 --cfg 4 --bs 32 --epochs 8 --seed 1 --name nasbench101
3 |
--------------------------------------------------------------------------------
/models/pretraining_nasbench201.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import torch
5 | import torch.nn as nn
6 | import numpy as np
7 | from torch import optim
8 | from models.model import Model, VAEReconstructed_Loss
9 | from utils.utils import load_json, save_checkpoint_vae, preprocessing
10 | from utils.utils import get_val_acc_vae, to_ops_nasbench201, is_valid_nasbench201
11 | from models.configs import configs
12 | from nasbench.lib import graph_util
13 | import argparse
14 |
15 |
16 | def _build_dataset(dataset, ind_list):
17 |     indices = np.random.permutation(ind_list)
18 | X_adj = []
19 | X_ops = []
20 | for ind in indices:
21 | X_adj.append(torch.Tensor(dataset[str(ind)]['module_adjacency']))
22 | X_ops.append(torch.Tensor(dataset[str(ind)]['module_operations']))
23 | X_adj = torch.stack(X_adj)
24 | X_ops = torch.stack(X_ops)
25 | return X_adj, X_ops, torch.Tensor(indices)
26 |
27 |
28 | def pretraining_gae(dataset, cfg):
29 | """
30 |     implementation of VGAE pretraining on the NAS-Bench-201 search space.
31 |     :param dataset: nas-bench-201 dataset (json)
32 |     :param cfg: model, loss, and preprocessing configuration
33 |     :return: None; a checkpoint is saved after every epoch
34 | """
35 | train_ind_list, val_ind_list = range(int(len(dataset)*0.9)), range(int(len(dataset)*0.9), len(dataset))
36 | X_adj_train, X_ops_train, indices_train = _build_dataset(dataset, train_ind_list)
37 | X_adj_val, X_ops_val, indices_val = _build_dataset(dataset, val_ind_list)
38 | model = Model(input_dim=args.input_dim, hidden_dim=args.hidden_dim, latent_dim=args.latent_dim,
39 | num_hops=args.hops, num_mlp_layers=args.mlps, dropout=args.dropout, **cfg['GAE']).cuda()
40 | optimizer = optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08)
41 | epochs = args.epochs
42 | bs = args.bs
43 | loss_total = []
44 | for epoch in range(0, epochs):
45 | chunks = len(X_adj_train) // bs
46 | if len(X_adj_train) % bs > 0:
47 | chunks += 1
48 | X_adj_split = torch.split(X_adj_train, bs, dim=0)
49 | X_ops_split = torch.split(X_ops_train, bs, dim=0)
50 | indices_split = torch.split(indices_train, bs, dim=0)
51 | loss_epoch = []
52 | Z = []
53 | for i, (adj, ops, ind) in enumerate(zip(X_adj_split, X_ops_split, indices_split)):
54 | optimizer.zero_grad()
55 | adj, ops = adj.cuda(), ops.cuda()
56 | # preprocessing
57 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
58 | # forward
59 | ops_recon, adj_recon, mu, logvar = model(ops, adj)
60 | Z.append(mu)
61 | adj_recon, ops_recon = prep_reverse(adj_recon, ops_recon)
62 | adj, ops = prep_reverse(adj, ops)
63 | loss = VAEReconstructed_Loss(**cfg['loss'])((ops_recon, adj_recon), (ops, adj), mu, logvar)
64 | loss.backward()
65 | nn.utils.clip_grad_norm_(model.parameters(), 5)
66 | optimizer.step()
67 | loss_epoch.append(loss.item())
68 | if i%100==0:
69 | print('epoch {}: batch {} / {}: loss: {:.5f}'.format(epoch, i, chunks, loss.item()))
70 | Z = torch.cat(Z, dim=0)
71 | z_mean, z_std = Z.mean(0), Z.std(0)
72 | validity_counter = 0
73 | buckets = {}
74 | model.eval()
75 | for _ in range(args.latent_points):
76 | z = torch.randn(8, args.latent_dim).cuda()
77 | z = z * z_std + z_mean
78 | op, ad = model.decoder(z.unsqueeze(0))
79 | op = op.squeeze(0).cpu()
80 | ad = ad.squeeze(0).cpu()
81 | max_idx = torch.argmax(op, dim=-1)
82 | one_hot = torch.zeros_like(op)
83 | for i in range(one_hot.shape[0]):
84 | one_hot[i][max_idx[i]] = 1
85 | op_decode = to_ops_nasbench201(max_idx)
86 | ad_decode = (ad>0.5).int().triu(1).numpy()
87 | ad_decode = np.ndarray.tolist(ad_decode)
88 | if is_valid_nasbench201(ad_decode, op_decode):
89 | validity_counter += 1
90 | fingerprint = graph_util.hash_module(np.array(ad_decode), one_hot.numpy().tolist())
91 | if fingerprint not in buckets:
92 | buckets[fingerprint] = (ad_decode, one_hot.numpy().astype('int8').tolist())
93 | validity = validity_counter / args.latent_points
94 | print('Ratio of valid decodings from the prior: {:.4f}'.format(validity))
95 | print('Ratio of unique decodings from the prior: {:.4f}'.format(len(buckets) / (validity_counter+1e-8)))
96 |
97 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val = get_val_acc_vae(model, cfg, X_adj_val, X_ops_val, indices_val)
98 | print('validation set: acc_ops:{0:.2f}, mean_corr_adj:{1:.2f}, mean_fal_pos_adj:{2:.2f}, acc_adj:{3:.2f}'.format(
99 | acc_ops_val, mean_corr_adj_val, mean_fal_pos_adj_val, acc_adj_val))
100 | print('epoch {}: average loss {:.5f}'.format(epoch, sum(loss_epoch)/len(loss_epoch)))
101 | print("reconstructed adj matrix:", adj_recon[1])
102 | print("original adj matrix:", adj[1])
103 | print("reconstructed ops matrix:", ops_recon[1])
104 | print("original ops matrix:", ops[1])
105 | loss_total.append(sum(loss_epoch) / len(loss_epoch))
106 | save_checkpoint_vae(model, optimizer, epoch, sum(loss_epoch) / len(loss_epoch), args.latent_dim, args.name, args.dropout, args.seed)
107 |
108 |
109 | print('loss for epochs: ', loss_total)
110 |
111 |
112 | if __name__ == '__main__':
113 | parser = argparse.ArgumentParser(description='Pretraining')
114 | parser.add_argument("--seed", type=int, default=3, help="random seed")
115 | parser.add_argument('--data', type=str, default='data/cifar10_valid_converged.json')
116 | parser.add_argument('--cfg', type=int, default=4)
117 | parser.add_argument('--bs', type=int, default=32)
118 | parser.add_argument('--epochs', type=int, default=10)
119 | parser.add_argument('--input_dim', type=int, default=7)
120 | parser.add_argument('--hidden_dim', type=int, default=128)
121 | parser.add_argument('--latent_dim', type=int, default=16)
122 | parser.add_argument('--dropout', type=float, default=0.3)
123 | parser.add_argument('--hops', type=int, default=5)
124 | parser.add_argument('--mlps', type=int, default=2)
125 | parser.add_argument('--latent_points', type=int, default=10000)
126 | parser.add_argument('--name', type=str, default='nasbench201', help='the prefix for the saved check point')
127 | args = parser.parse_args()
128 |
129 |     #reproducibility is good
130 | np.random.seed(args.seed)
131 | torch.manual_seed(args.seed)
132 | torch.cuda.manual_seed_all(args.seed)
133 |
134 | cfg = configs[args.cfg]
135 | dataset = load_json(args.data)
136 | print('using {}'.format(args.data))
137 | print('feat dim {}'.format(args.latent_dim))
138 |
139 | pretraining_gae(dataset, cfg)
140 |
--------------------------------------------------------------------------------
/models/pretraining_nasbench201.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python models/pretraining_nasbench201.py
3 |
--------------------------------------------------------------------------------
/plot_scripts/draw_darts.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | sys.path.insert(0, os.getcwd())
4 | import darts.cnn.genotypes as genotypes
5 | from graphviz import Digraph
6 |
7 |
8 | def plot(genotype, filename):
9 | g = Digraph(
10 | format='png',
11 | edge_attr=dict(fontsize='20', fontname="times"),
12 | node_attr=dict(style='filled', shape='rect', align='center', fontsize='20', height='0.5', width='0.5', penwidth='2', fontname="times"),
13 | engine='dot')
14 |     g.body.extend(['rankdir=TB'])
15 |
16 | g.node("c_{k-2}", fillcolor='darkseagreen2')
17 | g.node("c_{k-1}", fillcolor='darkseagreen2')
18 | assert len(genotype) % 2 == 0
19 | steps = len(genotype) // 2
20 |
21 | for i in range(steps):
22 | g.node(str(i), fillcolor='lightblue')
23 |
24 | for i in range(steps):
25 | for k in [2*i, 2*i + 1]:
26 | j, op = genotype[k]
27 | j = int(j)
28 | if j == 0:
29 | u = "c_{k-2}"
30 | elif j == 1:
31 | u = "c_{k-1}"
32 | else:
33 | u = str(j-2)
34 | v = str(i)
35 | g.edge(u, v, label=op, fillcolor="gray")
36 |
37 | g.node("c_{k}", fillcolor='palegoldenrod')
38 | for i in range(steps):
39 | g.edge(str(i), "c_{k}", fillcolor="gray")
40 |
41 | g.render(filename, view=False)
42 |
43 |
44 | if __name__ == '__main__':
45 | if len(sys.argv) != 2:
46 | print("usage:\n python {} ARCH_NAME".format(sys.argv[0]))
47 | sys.exit(1)
48 |
49 | genotype_name = sys.argv[1]
50 | try:
51 | genotype = eval('genotypes.{}'.format(genotype_name))
52 | except AttributeError:
53 | print("{} is not specified in genotypes.py".format(genotype_name))
54 | sys.exit(1)
55 |
56 | plot(genotype.normal, "normal")
57 | plot(genotype.reduce, "reduction")
58 |
59 |
--------------------------------------------------------------------------------
/plot_scripts/drawfig4.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python plot_scripts/visdensity.py \
4 | --emb_path pretrained/dim-16/arch2vec-model-nasbench101.pt \
5 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_nasbench101.npy \
6 | --output_path density/nas101
--------------------------------------------------------------------------------
/plot_scripts/drawfig5-darts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python plot_scripts/visgraph.py \
4 | --data_type darts \
5 | --data_path data/data_darts_counter600000.json \
6 | --emb_path pretrained/dim-16/arch2vec-darts.pt \
7 | --output_path graphvisualization
8 |
--------------------------------------------------------------------------------
/plot_scripts/drawfig5-nas101.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python plot_scripts/visgraph.py \
4 | --data_type nasbench101 \
5 | --data_path data/data.json \
6 | --emb_path pretrained/dim-16/arch2vec-model-nasbench101.pt \
7 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_nasbench101.npy \
8 | --output_path graphvisualization
9 |
--------------------------------------------------------------------------------
/plot_scripts/drawfig5-nas201.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python plot_scripts/visgraph.py \
4 | --data_type nasbench201 \
5 | --data_path data/cifar10_valid_converged.json \
6 | --emb_path pretrained/dim-16/cifar10_valid_converged-arch2vec.pt \
7 | --supervised_emb_path pretrained/dim-16/supervised_dngo_embedding_cifar10_nasbench201.npy \
8 | --output_path graphvisualization
--------------------------------------------------------------------------------
/plot_scripts/nas201.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/plot_scripts/nas201.jpg
--------------------------------------------------------------------------------
/plot_scripts/pearson_plot_fig2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | from sklearn.metrics import mean_squared_error
5 | from math import sqrt
6 | from scipy import stats
7 | from copy import copy
8 | from mpl_toolkits.axes_grid1 import make_axes_locatable
9 | import matplotlib as mpl
10 | import os
11 |
12 | result_path = 'saved_logs/predict_accuracy'
13 | seed = [1, 10]
14 | acc_th = 0.8
15 |
16 | for s in seed:
17 | ## unsupervised
18 | un_pred_acc = np.load(os.path.join(result_path, 'dngo_unsupervised', 'pred_acc_seed{}.npy'.format(s)))
19 | un_test_acc = np.load(os.path.join(result_path, 'dngo_unsupervised', 'test_acc_seed{}.npy'.format(s)))
20 | idx0 = np.logical_and(un_test_acc > acc_th, un_pred_acc > acc_th) # np.logical_and(un_pred_acc > th, un_test_acc > th)
21 |
22 | ## supervised
23 | sup_pred_acc = np.load(os.path.join(result_path, 'dngo_supervised', 'pred_acc_seed{}.npy'.format(s)))
24 | sup_test_acc = np.load(os.path.join(result_path, 'dngo_supervised', 'test_acc_seed{}.npy'.format(s)))
25 | idx1 = np.logical_and(sup_test_acc > acc_th, sup_pred_acc > acc_th) # np.logical_and(sup_pred_acc > th, sup_test_acc > th)
26 |
27 | bins = np.linspace(0.8, 1, 301)
28 |
29 | fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(6, 3), sharey=True)
30 |
31 | ax0.plot([0.8, 1], [0.8, 1], 'yellowgreen', linewidth=2)
32 | ax1.plot([0.8, 1], [0.8, 1], 'yellowgreen', linewidth=2)
33 |
34 | H, xedges, yedges = np.histogram2d(un_test_acc[idx0], un_pred_acc[idx0], bins=bins)
35 | H = H.T
36 | Hm = np.ma.masked_where(H < 1, H)
37 | X, Y = np.meshgrid(xedges, yedges)
38 | palette = copy(plt.cm.viridis)
39 | palette.set_bad('w', 1.0)
40 | ax0.pcolormesh(X, Y, Hm, cmap=palette)
41 |
42 |     H, xedges, yedges = np.histogram2d(sup_test_acc[idx1], sup_pred_acc[idx1], bins=bins)
43 | H = H.T
44 | Hm = np.ma.masked_where(H < 1, H)
45 | X, Y = np.meshgrid(xedges, yedges)
46 | palette = copy(plt.cm.viridis)
47 | palette.set_bad('w', 1.0)
48 | ax1.pcolormesh(X, Y, Hm, cmap=palette)
49 |
50 | ax0.set_xlabel('Test Accuracy')
51 | ax0.set_ylabel('Predicted Accuracy')
52 | ax1.set_xlabel('Test Accuracy')
53 |
54 | ax0.set_xlim(0.8, 0.95)
55 | ax0.set_ylim(0.8, 0.95)
56 | ax1.set_xlim(0.8, 0.95)
57 | ax1.set_ylim(0.8, 0.95)
58 |
59 | ax0.set_yticks(ticks=[0.8, 0.85, 0.90, 0.95])
60 | ax0.set_xticks(ticks=[0.8, 0.85, 0.9])
61 | ax1.set_xticks(ticks=[0.8, 0.85, 0.9, 0.95])
62 |
63 | ax0.set_aspect('equal', 'box')
64 | ax1.set_aspect('equal', 'box')
65 |
66 | plt.subplots_adjust(wspace=0.05, top=0.9, bottom=0.1)
67 |     plt.savefig('compare_seed{}.png'.format(s), bbox_inches='tight')
68 |     plt.show()
69 | plt.close(fig=fig)
70 |
71 |
72 |
--------------------------------------------------------------------------------
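
The density plots above mask 2-D histogram cells with zero counts so they render white instead of the colormap's lowest color. A self-contained sketch of that trick on synthetic data:

```python
import numpy as np
import matplotlib.pyplot as plt
from copy import copy

x = np.random.uniform(0.8, 0.95, 5000)
y = np.clip(x + np.random.normal(0, 0.01, 5000), 0.8, 1.0)
bins = np.linspace(0.8, 1, 301)

H, xedges, yedges = np.histogram2d(x, y, bins=bins)
Hm = np.ma.masked_where(H.T < 1, H.T)    # hide empty bins
X, Y = np.meshgrid(xedges, yedges)
palette = copy(plt.cm.viridis)
palette.set_bad('w', 1.0)                # masked cells render white
plt.pcolormesh(X, Y, Hm, cmap=palette)
plt.savefig('density_sketch.png', bbox_inches='tight')
```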
/plot_scripts/plot_cdf.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import matplotlib as mpl
4 | import matplotlib.pyplot as plt
5 | from matplotlib.lines import Line2D
6 |
7 | def fix_hist_step_vertical_line_at_end(ax):
8 | axpolygons = [poly for poly in ax.get_children() if isinstance(poly, mpl.patches.Polygon)]
9 | for poly in axpolygons:
10 | poly.set_xy(poly.get_xy()[:-1])
11 |
12 | def plot_cdf_comparison(cmap=plt.get_cmap("tab10")):
13 | fig = plt.figure()
14 | ax = fig.add_subplot(1, 1, 1)
15 | final_test_regret_rd_nas101 = []
16 | final_test_regret_re_nas101 = []
17 | final_test_regret_rl_nas101 = []
18 | final_test_regret_bohb_nas101 = []
19 | final_test_regret_rl_supervised = []
20 | final_test_regret_bo_supervised = []
21 | final_test_regret_rl_arch2vec = []
22 | final_test_regret_bo_arch2vec = []
23 |
24 | for i in range(1, 501):
25 | f_name = 'saved_logs/discrete/random_search/run_{}_nas_cifar10a_{}.json'.format(i, 20000)
26 | if not os.path.exists(f_name):
27 | continue
28 | f = open(f_name)
29 | data = json.load(f)
30 | for ind, t in enumerate(data['runtime']):
31 | if t > 1e6:
32 | final_test_regret_rd_nas101.append(data['regret_test'][ind])
33 | break
34 | f.close()
35 |
36 | for i in range(1, 501):
37 | f_name = 'saved_logs/discrete/regularized_evolution/run_{}_nas_cifar10a_{}.json'.format(i, 3500)
38 | if not os.path.exists(f_name):
39 | continue
40 | f = open(f_name)
41 | data = json.load(f)
42 | for ind, t in enumerate(data['runtime']):
43 | if t > 1e6:
44 | final_test_regret_re_nas101.append(data['regret_test'][ind])
45 | break
46 | f.close()
47 |
48 | for i in range(1, 501):
49 | f_name = 'saved_logs/discrete/rl/run_{}_nas_cifar10a_{}.json'.format(i, 3670)
50 | if not os.path.exists(f_name):
51 | continue
52 | f = open(f_name)
53 | data = json.load(f)
54 | for ind, t in enumerate(data['runtime']):
55 | if t > 1e6:
56 | final_test_regret_rl_nas101.append(data['regret_test'][ind])
57 | break
58 | f.close()
59 |
60 | for i in range(1, 501):
61 | f_name = 'saved_logs/discrete/bohb/run_{}_nas_cifar10a_{}.json'.format(i, 1000)
62 | if not os.path.exists(f_name):
63 | continue
64 | f = open(f_name)
65 | data = json.load(f)
66 | for ind, t in enumerate(data['runtime']):
67 | if t > 1e6:
68 | final_test_regret_bohb_nas101.append(data['regret_test'][ind])
69 | break
70 | f.close()
71 |
72 | for i in range(1, 501):
73 | f_name = 'saved_logs/rl/dim16/nasbench101_supervised_search_logs/run_{}_supervised_rl.json'.format(i)
74 | if not os.path.exists(f_name):
75 | continue
76 | f = open(f_name)
77 | data = json.load(f)
78 | for ind, t in enumerate(data['runtime']):
79 | if t > 1e6:
80 | final_test_regret_rl_supervised.append(data['regret_test'][ind])
81 | break
82 | f.close()
83 |
84 | for i in range(1, 501):
85 | f_name = 'saved_logs/bo/dim16/nasbench101_supervised_search_logs/run_{}_supervised_bo.json'.format(i)
86 | if not os.path.exists(f_name):
87 | continue
88 | f = open(f_name)
89 | data = json.load(f)
90 | for ind, t in enumerate(data['runtime']):
91 | if t > 1e6:
92 | final_test_regret_bo_supervised.append(data['regret_test'][ind])
93 | break
94 | f.close()
95 |
96 | for i in range(1, 501):
97 | f_name = 'saved_logs/rl/dim16/nasbench101_search_logs/run_{}_arch2vec-model-vae-nasbench-101.json'.format(i)
98 | if not os.path.exists(f_name):
99 | continue
100 | f = open(f_name)
101 | data = json.load(f)
102 | for ind, t in enumerate(data['runtime']):
103 | if t > 1e6:
104 | final_test_regret_rl_arch2vec.append(data['regret_test'][ind])
105 | break
106 | f.close()
107 |
108 | for i in range(1, 501):
109 | f_name = 'saved_logs/bo/dim16/nasbench101_search_logs/run_{}_arch2vec-model-vae-nasbench-101.json'.format(i)
110 | if not os.path.exists(f_name):
111 | continue
112 | f = open(f_name)
113 | data = json.load(f)
114 | for ind, t in enumerate(data['runtime']):
115 | if t > 1e6:
116 | final_test_regret_bo_arch2vec.append(data['regret_test'][ind])
117 | break
118 | f.close()
119 |
120 |
121 | plt_name_rd_nas101 = '{}: {}'.format('Discrete', 'Random Search')
122 | plt_name_re_nas101 = '{}: {}'.format('Discrete', 'Regularized Evolution')
123 | plt_name_rl_nas101 = '{}: {}'.format('Discrete', 'REINFORCE')
124 | plt_name_bohb_nas101 = '{}: {}'.format('Discrete', 'BOHB')
125 | plt_name_rl_supervised = '{}: {}'.format('Supervised', 'REINFORCE')
126 | plt_name_bo_supervised = '{}: {}'.format('Supervised', 'Bayesian Optimization')
127 | plt_name_rl_arch2vec = '{}: {}'.format('arch2vec', 'REINFORCE')
128 | plt_name_bo_arch2vec = '{}: {}'.format('arch2vec', 'Bayesian Optimization')
129 |
130 |     plt.hist(final_test_regret_rd_nas101, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='--', color=cmap(1), lw=2, label=plt_name_rd_nas101)
131 |     plt.hist(final_test_regret_re_nas101, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='--', lw=2.0, color=cmap(4), label=plt_name_re_nas101)
132 |     plt.hist(final_test_regret_rl_nas101, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='--', lw=2.0, color=cmap(6), label=plt_name_rl_nas101)
133 |     plt.hist(final_test_regret_bohb_nas101, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='--', lw=2.0, color=cmap(5), label=plt_name_bohb_nas101)
134 |     plt.hist(final_test_regret_rl_supervised, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='-.', lw=2.0, color=cmap(7), label=plt_name_rl_supervised)
135 |     plt.hist(final_test_regret_bo_supervised, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='-.', lw=2.0, color=cmap(9), label=plt_name_bo_supervised)
136 |     plt.hist(final_test_regret_rl_arch2vec, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='-.', lw=2.0, color=cmap(0), label=plt_name_rl_arch2vec)
137 |     plt.hist(final_test_regret_bo_arch2vec, bins=10, range=[8e-4, 1.2e-2], density=True, cumulative=True, histtype='step', linestyle='-.', lw=2.0, color=cmap(3), label=plt_name_bo_arch2vec)
138 | fix_hist_step_vertical_line_at_end(ax)
139 |
140 |
141 | ax.set_xscale('log')
142 | ax.set_xlabel('final test regret', fontsize=12)
143 | ax.set_ylabel('CDF', fontsize=12)
144 | handles, labels = ax.get_legend_handles_labels()
145 |     new_handles = [Line2D([], [], c=h.get_edgecolor(), ls=h.get_linestyle()) for h in handles]  # keep each method's linestyle in the legend
146 | ax.legend(prop={"size":8}, handles=new_handles, labels=labels, loc='upper left')
147 |
148 |
149 | plt.show()
150 |
151 | if __name__ == '__main__':
152 | plot_cdf_comparison()
153 |
--------------------------------------------------------------------------------
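
fix_hist_step_vertical_line_at_end above trims the Polygon that matplotlib's cumulative step histogram draws, since the final vertex closes the outline with a vertical line at the right edge. An alternative sketch that sidesteps the issue by computing the empirical CDF directly (toy data, hypothetical values):

    import numpy as np
    import matplotlib.pyplot as plt

    def plot_ecdf(ax, values, **kwargs):
        # sort the regrets and step from 1/n up to 1; nothing to trim afterwards
        xs = np.sort(np.asarray(values))
        ys = np.arange(1, len(xs) + 1) / len(xs)
        ax.step(xs, ys, where='post', **kwargs)

    fig, ax = plt.subplots()
    plot_ecdf(ax, np.random.lognormal(-6, 0.5, 500), label='toy regrets')
    ax.set_xscale('log')
    ax.set_ylabel('CDF')
    ax.legend()
    plt.show()
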
/plot_scripts/plot_dngo_search_arch2vec.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import matplotlib.pyplot as plt
4 | from collections import defaultdict
5 |
6 | def plot_over_time_dngo_search_arch2vec(name, cmap=plt.get_cmap("tab10")):
7 | length = []
8 | for i in range(1, 501):
9 | f_name = 'saved_logs/bo/dim16/run_{}_{}-model-nasbench101.json'.format(i, name)
10 | if not os.path.exists(f_name):
11 | continue
12 | f = open(f_name)
13 | data = json.load(f)
14 | length.append(len(data['runtime']))
15 | f.close()
16 |
17 | data_avg = defaultdict(list)
18 | test_regret_avg = defaultdict(list)
19 | valid_regret_avg = defaultdict(list)
20 |
21 | fig = plt.figure()
22 | ax = fig.add_subplot(1, 2, 1)
23 | ax_test = fig.add_subplot(1, 2, 2)
24 | for i in range(1, 501):
25 | f_name = 'saved_logs/bo/dim16/run_{}_{}-model-nasbench101.json'.format(i, name)
26 | if not os.path.exists(f_name):
27 | continue
28 | f = open(f_name)
29 | data = json.load(f)
30 | for idx in range(min(length)):
31 | data_avg[idx].append(data['runtime'][idx])
32 | valid_regret_avg[idx].append(data['regret_validation'][idx])
33 | test_regret_avg[idx].append(data['regret_test'][idx])
34 | f.close()
35 |
36 | time_plot = []
37 | valid_plot = []
38 | test_plot = []
39 | for idx in range(min(length)):
40 | if sum(data_avg[idx]) / len(data_avg[idx]) > 1e6:
41 | continue
42 | time_plot.append(sum(data_avg[idx]) / len(data_avg[idx]))
43 | valid_plot.append(sum(valid_regret_avg[idx]) / len(valid_regret_avg[idx]))
44 | test_plot.append(sum(test_regret_avg[idx]) / len(test_regret_avg[idx]))
45 |
46 | ax.plot(time_plot, valid_plot, color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'BO'))
47 | ax_test.plot(time_plot, test_plot, '--', color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'BO'))
48 | ax.set_xscale('log')
49 | ax.set_yscale('log')
50 | ax.set_xlabel('estimated wall-clock time [s]')
51 | ax.set_ylabel('validation regret')
52 | ax.legend()
53 | ax_test.set_xscale('log')
54 | ax_test.set_yscale('log')
55 | ax_test.set_xlabel('estimated wall-clock time [s]')
56 | ax_test.set_ylabel('test regret')
57 | ax_test.legend()
58 |
59 | save_data = {'time_plot': time_plot, 'valid_plot': valid_plot, 'test_plot': test_plot}
60 | with open('results/{}-{}-nasbench-101.json'.format('BO', name), 'w') as f_w:
61 | json.dump(save_data, f_w)
62 |
63 | plt.show()
64 |
65 | if __name__ == '__main__':
66 | name = 'arch2vec'
67 | plot_over_time_dngo_search_arch2vec(name)
68 |
69 |
--------------------------------------------------------------------------------
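
The averaging logic in plot_over_time_dngo_search_arch2vec (and in its REINFORCE twin below) truncates every run to the shortest one and averages per step with Python-level sums. The same computation as a single numpy call, sketched under the assumption that runs holds the per-seed JSON dicts loaded above:

    import numpy as np

    def average_over_runs(runs, key, n_steps):
        # truncate each run to the shortest length, then average per step
        return np.mean([run[key][:n_steps] for run in runs], axis=0)

    # hypothetical usage: time_plot = average_over_runs(runs, 'runtime', min(length))
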
/plot_scripts/plot_nasbench101_comparison.py:
--------------------------------------------------------------------------------
1 | import json
2 | import matplotlib.pyplot as plt
3 |
4 | def plot_over_time_comparison(cmap=plt.get_cmap("tab10")):
5 | fig = plt.figure()
6 | ax_test = fig.add_subplot(1, 1, 1)
7 |
8 | f_random_search = open('results/Random-Search-Encoding-A.json')
9 | f_regularized_evolution = open('results/Regularized-Evolution-Encoding-A.json')
10 | f_reinforce_search = open('results/Reinforce-Search-Encoding-A.json')
11 | f_bohb_search = open('results/BOHB-Search-Encoding-A.json')
12 | f_reinforce_search_arch2vec = open('results/RL-arch2vec-model-nasbench-101.json')
13 | f_bo_search_arch2vec = open('results/BO-arch2vec-model-nasbench-101.json')
14 | f_reinforce_search_supervised = open('results/RL-supervised-nasbench-101.json')
15 | f_bo_search_supervised = open('results/BO-supervised-nasbench-101.json')
16 | result_random_search = json.load(f_random_search)
17 | result_regularized_evolution = json.load(f_regularized_evolution)
18 | result_reinforce_search = json.load(f_reinforce_search)
19 | result_bohb_search = json.load(f_bohb_search)
20 | results_reinforce_search_arch2vec = json.load(f_reinforce_search_arch2vec)
21 | results_bo_search_arch2vec = json.load(f_bo_search_arch2vec)
22 | results_reinforce_search_supervised = json.load(f_reinforce_search_supervised)
23 | results_bo_search_supervised = json.load(f_bo_search_supervised)
24 | f_random_search.close()
25 | f_regularized_evolution.close()
26 | f_reinforce_search.close()
27 | f_bohb_search.close()
28 | f_reinforce_search_arch2vec.close()
29 | f_bo_search_arch2vec.close()
30 | f_reinforce_search_supervised.close()
31 | f_bo_search_supervised.close()
32 |
33 | ax_test.plot(result_random_search['time_plot'], result_random_search['test_plot'], linestyle='-.', marker='^', markevery=1e3, color=cmap(1), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'Random Search'))
34 | ax_test.plot(result_regularized_evolution['time_plot'], result_regularized_evolution['test_plot'], linestyle='-.', marker='s', markevery=1e3, color=cmap(4), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'Regularized Evolution'))
35 | ax_test.plot(result_reinforce_search['time_plot'], result_reinforce_search['test_plot'], linestyle='-.', marker='.', markevery=1e3, color=cmap(6), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'REINFORCE'))
36 | ax_test.plot(result_bohb_search['time_plot'], result_bohb_search['test_plot'] , linestyle='-.', marker='*', markevery=1e3, color=cmap(5), lw=2, markersize=4, label='{}: {}'.format('Discrete', 'BOHB'))
37 | ax_test.plot(results_reinforce_search_supervised['time_plot'], results_reinforce_search_supervised['test_plot'], linestyle='--', marker='.', markevery=1e3, color=cmap(7), lw=2, markersize=4, label='{}: {}'.format('Supervised', 'REINFORCE'))
38 | ax_test.plot(results_bo_search_supervised['time_plot'], results_bo_search_supervised['test_plot'], linestyle='--', marker='v', markevery=1e3, color=cmap(9), lw=2, markersize=4, label='{}: {}'.format('Supervised', 'Bayesian Optimization'))
39 | ax_test.plot(results_reinforce_search_arch2vec['time_plot'], results_reinforce_search_arch2vec['test_plot'], linestyle='-.', marker='.', markevery=1e3, color=cmap(0), lw=2, markersize=4, label='{}: {}'.format('arch2vec', 'REINFORCE'))
40 | ax_test.plot(results_bo_search_arch2vec['time_plot'], results_bo_search_arch2vec['test_plot'], linestyle='-.', marker='v', markevery=1e3, color=cmap(3), lw=2, markersize=4, label='{}: {}'.format('arch2vec', 'Bayesian Optimization'))
41 |
42 | ax_test.set_xscale('log')
43 | ax_test.set_yscale('log')
44 | ax_test.set_xlabel('estimated wall-clock time [s]', fontsize=12)
45 | ax_test.set_ylabel('test regret', fontsize=12)
46 | ax_test.legend(prop={"size":10})
47 |
48 | plt.show()
49 |
50 | if __name__ == '__main__':
51 | plot_over_time_comparison()
52 |
--------------------------------------------------------------------------------
/plot_scripts/plot_reinforce_search_arch2vec.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import matplotlib.pyplot as plt
4 | from collections import defaultdict
5 |
6 | def plot_over_time_reinforce_search_arch2vec(name, cmap=plt.get_cmap("tab10")):
7 | length = []
8 | for i in range(1, 501):
9 | f_name = 'saved_logs/rl/dim16/run_{}_{}-model-nasbench101.json'.format(i, name)
10 | if not os.path.exists(f_name):
11 | continue
12 | f = open(f_name)
13 | data = json.load(f)
14 | length.append(len(data['runtime']))
15 | f.close()
16 |
17 | data_avg = defaultdict(list)
18 | test_regret_avg = defaultdict(list)
19 | valid_regret_avg = defaultdict(list)
20 |
21 | fig = plt.figure()
22 | ax = fig.add_subplot(1, 2, 1)
23 | ax_test = fig.add_subplot(1, 2, 2)
24 | for i in range(1, 501):
25 | f_name = 'saved_logs/rl/dim16/run_{}_{}-model-nasbench101.json'.format(i, name)
26 | if not os.path.exists(f_name):
27 | continue
28 | f = open(f_name)
29 | data = json.load(f)
30 | for idx in range(min(length)):
31 | data_avg[idx].append(data['runtime'][idx])
32 | valid_regret_avg[idx].append(data['regret_validation'][idx])
33 | test_regret_avg[idx].append(data['regret_test'][idx])
34 | f.close()
35 |
36 | time_plot = []
37 | valid_plot = []
38 | test_plot = []
39 | for idx in range(min(length)):
40 | if sum(data_avg[idx]) / len(data_avg[idx]) > 1e6:
41 | continue
42 | time_plot.append(sum(data_avg[idx]) / len(data_avg[idx]))
43 | valid_plot.append(sum(valid_regret_avg[idx]) / len(valid_regret_avg[idx]))
44 | test_plot.append(sum(test_regret_avg[idx]) / len(test_regret_avg[idx]))
45 |
46 | ax.plot(time_plot, valid_plot, color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'RL'))
47 | ax_test.plot(time_plot, test_plot, '--', color=cmap(6), lw=2, label='{}: {}'.format('arch2vec', 'RL'))
48 | ax.set_xscale('log')
49 | ax.set_yscale('log')
50 | ax.set_xlabel('estimated wall-clock time [s]')
51 | ax.set_ylabel('validation regret')
52 | ax.legend()
53 | ax_test.set_xscale('log')
54 | ax_test.set_yscale('log')
55 | ax_test.set_xlabel('estimated wall-clock time [s]')
56 | ax_test.set_ylabel('test regret')
57 | ax_test.legend()
58 |
59 | save_data = {'time_plot': time_plot, 'valid_plot': valid_plot, 'test_plot': test_plot}
60 | with open('results/{}-{}-nasbench-101.json'.format('RL', name), 'w') as f_w:
61 | json.dump(save_data, f_w)
62 |
63 | plt.show()
64 |
65 | if __name__ == '__main__':
66 | name = 'arch2vec'
67 | plot_over_time_reinforce_search_arch2vec(name)
68 |
69 |
--------------------------------------------------------------------------------
/plot_scripts/summarize_nasbench201.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import numpy as np
4 | #from prettytable import PrettyTable
5 |
6 |
7 | #t = PrettyTable(['Method', 'CIFAR-10 val', 'CIFAR-10 test', 'CIFAR-100 val', 'CIFAR-100 test', 'ImageNet-16-120 val', 'ImageNet-16-120 test'])
8 | #t = PrettyTable(['Method', 'CIFAR-10 val', 'CIFAR-10 test'])
9 |
10 | def get_summary(dataset, file_name, data_dir, val_test, N_runs):
11 | val_acc = []
12 | test_acc = []
13 | for k in range(1, N_runs+1):
14 | file_name_ = file_name.format(dataset, k)
15 | file_path = os.path.join(data_dir, file_name_)
16 | if os.path.isfile(file_path):
17 | with open(file_path, 'r') as f:
18 | acc_dict = json.load(f)
19 | val_acc.append(acc_dict[val_test[0]]) # using average instead of individual
20 | test_acc.append(acc_dict[val_test[1]])
21 | val_acc = np.array(val_acc)
22 | test_acc = np.array(test_acc)
23 |
24 | return val_acc.mean(), val_acc.std(), test_acc.mean(), test_acc.std()
25 |
26 |
27 | # RL (ours)
28 | row = ['arch2vec-RL']
29 | data_dir = 'saved_logs/rl/dim16/'
30 | datasets = {'cifar10_valid_converged':500, 'cifar100':500, 'ImageNet16_120':500}
31 | file_name = 'nasbench201_{}_run_{}_full.json'
32 | val_test = ['val_acc_avg', 'test_acc_avg']
33 | for i, (dataset, N_runs) in enumerate(datasets.items()):
34 | val_mean, val_std, test_mean, test_std = get_summary(dataset, file_name, data_dir, val_test, N_runs)
35 | row.append('{:.2f}+-{:.2f}'.format(val_mean, val_std))
36 | row.append('{:.2f}+-{:.2f}'.format(test_mean, test_std))
37 | print(row)
38 |
39 |
40 |
41 | # BO (ours)
42 | row = ['arch2vec-BO']
43 | data_dir = 'saved_logs/bo/dim16/'
44 | datasets = {'cifar10_valid_converged':500, 'cifar100':500, 'ImageNet16_120':500}
45 | file_name = 'nasbench201_{}_run_{}_full.json'
46 | val_test = ['val_acc_avg', 'test_acc_avg']
47 | for i, (dataset, N_runs) in enumerate(datasets.items()):
48 | val_mean, val_std, test_mean, test_std = get_summary(dataset, file_name, data_dir, val_test, N_runs)
49 | row.append('{:.2f}+-{:.2f}'.format(val_mean, val_std))
50 | row.append('{:.2f}+-{:.2f}'.format(test_mean, test_std))
51 |
52 |
53 | print(row)
54 |
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
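
The RL and BO blocks above differ only in the row label and data_dir, so the same summary can be written as one loop; a sketch that assumes get_summary and the module-level variables defined above:

    for label, data_dir in [('arch2vec-RL', 'saved_logs/rl/dim16/'),
                            ('arch2vec-BO', 'saved_logs/bo/dim16/')]:
        row = [label]
        for dataset, N_runs in datasets.items():
            val_mean, val_std, test_mean, test_std = get_summary(dataset, file_name, data_dir, val_test, N_runs)
            row.append('{:.2f}+-{:.2f}'.format(val_mean, val_std))
            row.append('{:.2f}+-{:.2f}'.format(test_mean, test_std))
        print(row)
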
/plot_scripts/try_networkx.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 |
5 | def node_match(n1, n2):
6 | if n1['op'] == n2['op']:
7 | return True
8 | else:
9 | return False
10 |
11 | def edge_match(e1, e2):
12 | return True
13 |
14 | def gen_graph(adj, ops):
15 | G = nx.DiGraph()
16 | for k, op in enumerate(ops):
17 | G.add_node(k, op=op)
18 | assert adj.shape[0] == adj.shape[1] == len(ops)
19 | for row in range(len(ops)):
20 | for col in range(row + 1, len(ops)):
21 | if adj[row, col] > 0:
22 | G.add_edge(row, col)
23 | return G
24 |
25 | def preprocess_adj_op(adj, op):
26 | def counting_trailing_false(l):
27 | count = 0
28 | for TF in l[-1::-1]:
29 | if TF:
30 | break
31 | else:
32 | count += 1
33 | return count
34 |
35 | def transform_op(op):
36 | idx2op = {0:'input', 1:'conv1x1-bn-relu', 2:'conv3x3-bn-relu', 3:'maxpool3x3', 4:'output'}
37 | return [idx2op[idx] for idx in op.argmax(axis=1)]
38 |
39 | adj = np.array(adj).astype(int)
40 | op = np.array(op).astype(int)
41 |
42 | assert op.shape[0] == adj.shape[0] == adj.shape[1]
43 | # find all zero columns
44 | adj_zero_col = counting_trailing_false(adj.any(axis=0))
45 | # find all zero rows
46 | adj_zero_row = counting_trailing_false(adj.any(axis=1))
47 |     # find all-zero rows of the one-hot op matrix
48 |     op_zero_row = counting_trailing_false(op.any(axis=1))
49 |     assert adj_zero_col == op_zero_row == adj_zero_row - 1, 'Inconsistent result {}={}={}'.format(adj_zero_col, op_zero_row, adj_zero_row - 1)
50 | N = op.shape[0] - adj_zero_col
51 | adj = adj[:N, :N]
52 | op = op[:N]
53 |
54 | return adj, transform_op(op)
55 |
56 |
57 |
58 | if __name__ == '__main__':
59 |
60 | adj1 = np.array([[0, 1, 1, 1, 0],
61 | [0, 0, 1, 0, 0],
62 | [0, 0, 0, 0, 1],
63 | [0, 0, 0, 0, 1],
64 | [0, 0, 0, 0, 0]])
65 | op1 = ['in', 'conv1x1', 'conv3x3', 'mp3x3', 'out']
66 |
67 | adj2 = np.array([[0, 1, 1, 1, 0],
68 | [0, 0, 0, 1, 0],
69 | [0, 0, 0, 0, 1],
70 | [0, 0, 0, 0, 1],
71 | [0, 0, 0, 0, 0]])
72 | op2 = ['in', 'conv1x1', 'mp3x3', 'conv3x3', 'out']
73 |
74 |
75 | adj3 = np.array([[0, 1, 1, 1, 0, 0],
76 | [0, 0, 1, 0, 0, 0],
77 | [0, 0, 0, 0, 1, 0],
78 | [0, 0, 0, 0, 1, 0],
79 | [0, 0, 0, 0, 0, 1],
80 | [0, 0, 0, 0, 0, 0]])
81 | op3 = ['in', 'conv1x1', 'conv3x3', 'mp3x3', 'out','out2']
82 |
83 | adj4 = np.array([[0, 1, 1, 1, 0, 0],
84 | [0, 0, 1, 0, 0, 0],
85 | [0, 0, 0, 0, 1, 0],
86 | [0, 0, 0, 0, 1, 0],
87 | [0, 0, 0, 0, 0, 0],
88 | [0, 0, 0, 0, 0, 0]])
89 | op4 = np.array([[1, 0, 0, 0, 0],
90 | [0, 1, 0, 0, 0],
91 | [0, 0, 1, 0, 0],
92 | [0, 0, 0, 1, 0],
93 | [0, 0, 0, 0, 1],
94 | [0, 0, 0, 0, 0]])
95 | adj4, op4 = preprocess_adj_op(adj4, op4)
96 |
97 |
98 |
99 | G1 = gen_graph(adj1, op1)
100 | G2 = gen_graph(adj2, op2)
101 | G3 = gen_graph(adj3, op3)
102 | G4 = gen_graph(adj4, op4)
103 |
104 |
105 | plt.subplot(141)
106 | nx.draw(G1, with_labels=True, font_weight='bold')
107 | plt.subplot(142)
108 | nx.draw(G2, with_labels=True, font_weight='bold')
109 | plt.subplot(143)
110 | nx.draw(G3, with_labels=True, font_weight='bold')
111 | plt.subplot(144)
112 | nx.draw(G4, with_labels=True, font_weight='bold')
113 |
114 |     print('GED(G1, G2) =', nx.graph_edit_distance(G1, G2, node_match=node_match, edge_match=edge_match))
115 |     print('GED(G2, G3) =', nx.graph_edit_distance(G2, G3, node_match=node_match, edge_match=edge_match))
116 |     plt.show()
--------------------------------------------------------------------------------
/preprocessing/gen_isomorphism_graphs.py:
--------------------------------------------------------------------------------
1 | import json
2 | import torch
3 | import numpy as np
4 | import logging
5 | import sys
6 | import os
7 | sys.path.insert(0, os.getcwd())
8 | from darts.cnn.genotypes import Genotype
9 | from darts.cnn.model import NetworkImageNet as Network
10 | from thop import profile
11 |
12 | def process(geno):
13 | for i, item in enumerate(geno):
14 | geno[i] = tuple(geno[i])
15 | return geno
16 |
17 | def transform_operations(ops):
18 | transform_dict = {'c_k-2': 0, 'c_k-1': 1, 'none': 2, 'max_pool_3x3': 3, 'avg_pool_3x3': 4, 'skip_connect': 5,
19 | 'sep_conv_3x3': 6, 'sep_conv_5x5': 7, 'dil_conv_3x3': 8, 'dil_conv_5x5': 9, 'output': 10}
20 |
21 | ops_array = np.zeros([11, 11], dtype='int8')
22 | for row, op in enumerate(ops):
23 | ops_array[row, op] = 1
24 | return ops_array
25 |
26 | def sample_arch():
27 | num_ops = len(OPS)
28 | normal = []
29 | normal_name = []
30 | for i in range(NUM_VERTICES):
31 | ops = np.random.choice(range(num_ops), NUM_VERTICES)
32 | nodes_in_normal = np.random.choice(range(i+2), 2, replace=False)
33 | normal.extend([(nodes_in_normal[0], ops[0]), (nodes_in_normal[1], ops[1])])
34 | normal_name.extend([(str(nodes_in_normal[0]), OPS[ops[0]]), (str(nodes_in_normal[1]), OPS[ops[1]])])
35 |
36 |     return normal, normal_name
37 |
38 |
39 | def build_mat_encoding(normal, normal_name, counter):
40 | adj = torch.zeros(11, 11)
41 | ops = torch.zeros(11, 11)
42 | block_0 = (normal[0], normal[1])
43 | prev_b0_n1, prev_b0_n2 = block_0[0][0], block_0[1][0]
44 | prev_b0_o1, prev_b0_o2 = block_0[0][1], block_0[1][1]
45 |
46 | block_1 = (normal[2], normal[3])
47 | prev_b1_n1, prev_b1_n2 = block_1[0][0], block_1[1][0]
48 | prev_b1_o1, prev_b1_o2 = block_1[0][1], block_1[1][1]
49 |
50 | block_2 = (normal[4], normal[5])
51 | prev_b2_n1, prev_b2_n2 = block_2[0][0], block_2[1][0]
52 | prev_b2_o1, prev_b2_o2 = block_2[0][1], block_2[1][1]
53 |
54 | block_3 = (normal[6], normal[7])
55 | prev_b3_n1, prev_b3_n2 = block_3[0][0], block_3[1][0]
56 | prev_b3_o1, prev_b3_o2 = block_3[0][1], block_3[1][1]
57 |
58 | adj[2][-1] = 1
59 | adj[3][-1] = 1
60 | adj[4][-1] = 1
61 | adj[5][-1] = 1
62 | adj[6][-1] = 1
63 | adj[7][-1] = 1
64 | adj[8][-1] = 1
65 | adj[9][-1] = 1
66 |
67 | # B0
68 | adj[prev_b0_n1][2] = 1
69 | adj[prev_b0_n2][3] = 1
70 |
71 | # B1
72 | if prev_b1_n1 == 2:
73 | adj[2][4] = 1
74 | adj[3][4] = 1
75 | else:
76 | adj[prev_b1_n1][4] = 1
77 |
78 | if prev_b1_n2 == 2:
79 | adj[2][5] = 1
80 | adj[3][5] = 1
81 | else:
82 | adj[prev_b1_n2][5] = 1
83 |
84 | # B2
85 | if prev_b2_n1 == 2:
86 | adj[2][6] = 1
87 | adj[3][6] = 1
88 | elif prev_b2_n1 == 3:
89 | adj[4][6] = 1
90 | adj[5][6] = 1
91 | else:
92 | adj[prev_b2_n1][6] = 1
93 |
94 | if prev_b2_n2 == 2:
95 | adj[2][7] = 1
96 | adj[3][7] = 1
97 | elif prev_b2_n2 == 3:
98 | adj[4][7] = 1
99 | adj[5][7] = 1
100 | else:
101 | adj[prev_b2_n2][7] = 1
102 |
103 | # B3
104 | if prev_b3_n1 == 2:
105 | adj[2][8] = 1
106 | adj[3][8] = 1
107 | elif prev_b3_n1 == 3:
108 | adj[4][8] = 1
109 | adj[5][8] = 1
110 | elif prev_b3_n1 == 4:
111 | adj[6][8] = 1
112 | adj[7][8] = 1
113 | else:
114 | adj[prev_b3_n1][8] = 1
115 |
116 | if prev_b3_n2 == 2:
117 | adj[2][9] = 1
118 | adj[3][9] = 1
119 | elif prev_b3_n2 == 3:
120 | adj[4][9] = 1
121 | adj[5][9] = 1
122 | elif prev_b3_n2 == 4:
123 | adj[6][9] = 1
124 | adj[7][9] = 1
125 | else:
126 | adj[prev_b3_n2][9] = 1
127 |
128 | ops[0][0] = 1
129 | ops[1][1] = 1
130 | ops[-1][-1] = 1
131 | ops[2][prev_b0_o1+2] = 1
132 | ops[3][prev_b0_o2+2] = 1
133 | ops[4][prev_b1_o1+2] = 1
134 | ops[5][prev_b1_o2+2] = 1
135 | ops[6][prev_b2_o1+2] = 1
136 | ops[7][prev_b2_o2+2] = 1
137 | ops[8][prev_b3_o1+2] = 1
138 | ops[9][prev_b3_o2+2] = 1
139 |
140 | #print("adj encoding: \n{} \n".format(adj.int()))
141 | #print("ops encoding: \n{} \n".format(ops.int()))
142 |
143 | label = torch.argmax(ops, dim=1)
144 |
145 | fingerprint = graph_util.hash_module(adj.int().numpy(), label.int().numpy().tolist())
146 | if fingerprint not in buckets:
147 | normal_cell = [(item[1], int(item[0])) for item in normal_name]
148 | reduce_cell = normal_cell.copy()
149 | genotype = Genotype(normal=normal_cell, normal_concat=[2, 3, 4, 5], reduce=reduce_cell, reduce_concat=[2, 3, 4, 5])
150 | model = Network(48, 1000, 14, False, genotype).cuda()
151 | input = torch.randn(1, 3, 224, 224).cuda()
152 | macs, params = profile(model, inputs=(input, ))
153 | if macs < 6e8:
154 | counter += 1
155 | print("counter: {}, flops: {}, params: {}".format(counter, macs, params))
156 | buckets[fingerprint] = (adj.numpy().astype('int8').tolist(), label.numpy().astype('int8').tolist(), (normal_name))
157 |
158 | if counter > 0 and counter % 1e5 == 0:
159 | with open('data/data_darts_counter{}.json'.format(counter), 'w') as f:
160 | json.dump(buckets, f)
161 |
162 | return counter
163 |
164 | if __name__ == '__main__':
165 | from nasbench.lib import graph_util
166 | OPS = ['none',
167 | 'max_pool_3x3',
168 | 'avg_pool_3x3',
169 | 'skip_connect',
170 | 'sep_conv_3x3',
171 | 'sep_conv_5x5',
172 | 'dil_conv_3x3',
173 | 'dil_conv_5x5'
174 | ]
175 | NUM_VERTICES = 4
176 | INPUT_1 = 'c_k-2'
177 | INPUT_2 = 'c_k-1'
178 | logging.basicConfig(filename='darts_preparation.log')
179 |
180 | buckets = {}
181 | counter = 0
182 | while counter <= 6e5:
183 | normal, normal_name = sample_arch()
184 | counter = build_mat_encoding(normal, normal_name, counter)
185 |
--------------------------------------------------------------------------------
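
transform_operations above one-hot encodes DARTS cells into an 11x11 ops matrix via transform_dict; the inverse mapping is handy when inspecting sampled encodings. A hypothetical helper (not part of the repo) that decodes the one-hot rows back to operation names:

    import numpy as np

    IDX2OP = {0: 'c_k-2', 1: 'c_k-1', 2: 'none', 3: 'max_pool_3x3', 4: 'avg_pool_3x3',
              5: 'skip_connect', 6: 'sep_conv_3x3', 7: 'sep_conv_5x5',
              8: 'dil_conv_3x3', 9: 'dil_conv_5x5', 10: 'output'}

    def decode_operations(ops_array):
        # invert the one-hot encoding produced by transform_operations
        return [IDX2OP[idx] for idx in np.argmax(ops_array, axis=1)]
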
/preprocessing/gen_json.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | from nasbench import api
6 | from random import randint
7 | import json
8 | import numpy as np
9 | from collections import OrderedDict
10 |
11 | # Replace this string with the path to the downloaded nasbench.tfrecord before
12 | # executing.
13 | NASBENCH_TFRECORD = 'data/nasbench_only108.tfrecord'
14 |
15 | INPUT = 'input'
16 | OUTPUT = 'output'
17 | CONV1X1 = 'conv1x1-bn-relu'
18 | CONV3X3 = 'conv3x3-bn-relu'
19 | MAXPOOL3X3 = 'maxpool3x3'
20 |
21 | def gen_data_point(nasbench):
22 |
23 | i = 0
24 | epoch = 108
25 |
26 | padding = [0, 0, 0, 0, 0, 0, 0]
27 | best_val_acc = 0
28 | best_test_acc = 0
29 |
30 | for unique_hash in nasbench.hash_iterator():
31 | fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(unique_hash)
32 | print('\nIterating over {} / {} unique models in the dataset.'.format(i, 423623))
33 | test_acc_avg = 0.0
34 | val_acc_avg = 0.0
35 | training_time = 0.0
36 | for repeat_index in range(len(computed_metrics[epoch])):
37 | assert len(computed_metrics[epoch])==3, 'len(computed_metrics[epoch]) should be 3'
38 | data_point = computed_metrics[epoch][repeat_index]
39 | val_acc_avg += data_point['final_validation_accuracy']
40 | test_acc_avg += data_point['final_test_accuracy']
41 | training_time += data_point['final_training_time']
42 | val_acc_avg = val_acc_avg/3.0
43 | test_acc_avg = test_acc_avg/3.0
44 | training_time_avg = training_time/3.0
45 | ops_array = transform_operations(fixed_metrics['module_operations'])
46 | adj_array = fixed_metrics['module_adjacency'].tolist()
47 | model_spec = api.ModelSpec(fixed_metrics['module_adjacency'], fixed_metrics['module_operations'])
48 | data = nasbench.query(model_spec, epochs=108)
49 | print('api training time: {}'.format(data['training_time']))
50 | print('real training time: {}'.format(training_time_avg))
51 |
52 |         # pad zeros so the adjacency matrix of a cell with fewer than 7 nodes becomes 7x7
53 | if len(adj_array) <= 6:
54 | for row in range(len(adj_array)):
55 | for _ in range(7-len(adj_array)):
56 | adj_array[row].append(0)
57 | for _ in range(7-len(adj_array)):
58 | adj_array.append(padding)
59 |
60 | if val_acc_avg > best_val_acc:
61 | best_val_acc = val_acc_avg
62 |
63 | if test_acc_avg > best_test_acc:
64 | best_test_acc = test_acc_avg
65 |
66 | print('best val. acc: {:.4f}, best test acc {:.4f}'.format(best_val_acc, best_test_acc))
67 |
68 | yield {i: # unique_hash
69 | {'test_accuracy': test_acc_avg,
70 | 'validation_accuracy': val_acc_avg,
71 | 'module_adjacency':adj_array,
72 | 'module_operations': ops_array.tolist(),
73 | 'training_time': training_time_avg}}
74 |
75 | i += 1
76 |
77 | def transform_operations(ops):
78 | transform_dict = {'input':0, 'conv1x1-bn-relu':1, 'conv3x3-bn-relu':2, 'maxpool3x3':3, 'output':4}
79 | ops_array = np.zeros([7,5], dtype='int8')
80 | for row, op in enumerate(ops):
81 | col = transform_dict[op]
82 | ops_array[row, col] = 1
83 | return ops_array
84 |
85 |
86 | def gen_json_file():
87 | nasbench = api.NASBench(NASBENCH_TFRECORD)
88 | nas_gen = gen_data_point(nasbench)
89 | data_dict = OrderedDict()
90 | for data_point in nas_gen:
91 | data_dict.update(data_point)
92 | with open('data/data.json', 'w') as outfile:
93 | json.dump(data_dict, outfile)
94 |
95 |
96 |
97 |
98 | if __name__ == '__main__':
99 | gen_json_file()
100 |
--------------------------------------------------------------------------------
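
Once gen_json_file has written data/data.json, a quick sanity check is to load it back and look up the best architecture; a sketch, with field names taken from the dict yielded above:

    import json

    with open('data/data.json') as f:
        data = json.load(f)

    best = max(data.values(), key=lambda d: d['validation_accuracy'])
    print('best val acc: {:.4f}, test acc: {:.4f}'.format(
        best['validation_accuracy'], best['test_accuracy']))
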
/preprocessing/nasbench201_json.py:
--------------------------------------------------------------------------------
1 | """API source: https://github.com/D-X-Y/NAS-Bench-201/blob/v1.1/nas_201_api/api.py"""
2 | from api import NASBench201API as API
3 | import numpy as np
4 | import json
5 | from collections import OrderedDict
6 |
7 | nas_bench = API('data/NAS-Bench-201-v1_0-e61699.pth')
8 |
9 |
10 |
11 | # num = len(api)
12 | # for i, arch_str in enumerate(api):
13 | # print ('{:5d}/{:5d} : {:}'.format(i, len(api), arch_str))
14 | #
15 | # info = api.query_meta_info_by_index(1) # This is an instance of `ArchResults`
16 | # res_metrics = info.get_metrics('cifar10', 'train') # This is a dict with metric names as keys
17 | # cost_metrics = info.get_comput_costs('cifar100') # This is a dict with metric names as keys, e.g., flops, params, latency
18 | #
19 | # api.show(1)
20 | # api.show(2)
21 |
22 | def info2mat(arch_index):
23 | #info.all_results
24 |
25 | info = nas_bench.query_meta_info_by_index(arch_index)
26 | ops = {'input':0, 'nor_conv_1x1':1, 'nor_conv_3x3':2, 'avg_pool_3x3':3, 'skip_connect':4, 'none':5, 'output':6}
27 | adj_mat = np.array([[0, 1, 1, 0, 1, 0, 0, 0],
28 | [0, 0, 0, 1, 0, 1 ,0 ,0],
29 | [0, 0, 0, 0, 0, 0, 1, 0],
30 | [0, 0, 0, 0, 0, 0, 1, 0],
31 | [0, 0, 0, 0, 0, 0, 0, 1],
32 | [0, 0, 0, 0, 0, 0, 0, 1],
33 | [0, 0, 0, 0, 0, 0, 0, 1],
34 | [0, 0, 0, 0, 0, 0, 0, 0]])
35 |
36 | nodes = ['input']
37 | steps = info.arch_str.split('+')
38 | steps_coding = ['0', '0', '1', '0', '1', '2']
39 | cont = 0
40 | for step in steps:
41 | step = step.strip('|').split('|')
42 | for node in step:
43 | n, idx = node.split('~')
44 | assert idx == steps_coding[cont]
45 | cont += 1
46 | nodes.append(n)
47 | nodes.append('output')
48 |
49 |     node_mat = np.zeros([8, len(ops)]).astype(int)
50 | ops_idx = [ops[k] for k in nodes]
51 | node_mat[[0,1,2,3,4,5,6,7],ops_idx] = 1
52 |
53 | # For cifar10-valid with converged
54 | valid_acc, val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=None, dataname='cifar10-valid', use_converged_LR=True)
55 | cifar10_valid_converged = { 'test_accuracy': test_acc,
56 | 'test_accuracy_avg': test_acc_avg,
57 | 'validation_accuracy':valid_acc,
58 | 'validation_accuracy_avg': val_acc_avg,
59 | 'module_adjacency':adj_mat.tolist(),
60 | 'module_operations': node_mat.tolist(),
61 | 'training_time': time_cost}
62 |
63 |
64 | # For cifar100
65 | valid_acc, val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=199, dataname='cifar100', use_converged_LR=False)
66 | cifar100 = {'test_accuracy': test_acc,
67 | 'test_accuracy_avg': test_acc_avg,
68 | 'validation_accuracy': valid_acc,
69 | 'validation_accuracy_avg': val_acc_avg,
70 | 'module_adjacency': adj_mat.tolist(),
71 | 'module_operations': node_mat.tolist(),
72 | 'training_time': time_cost}
73 |
74 | # For ImageNet16-120
75 | valid_acc, val_acc_avg, time_cost, test_acc, test_acc_avg = train_and_eval(arch_index, nepoch=199, dataname='ImageNet16-120', use_converged_LR=False)
76 | ImageNet16_120 = {'test_accuracy': test_acc,
77 | 'test_accuracy_avg': test_acc_avg,
78 | 'validation_accuracy': valid_acc,
79 | 'validation_accuracy_avg': val_acc_avg,
80 | 'module_adjacency': adj_mat.tolist(),
81 | 'module_operations': node_mat.tolist(),
82 | 'training_time': time_cost}
83 |
84 |
85 | return {'cifar10_valid_converged': cifar10_valid_converged,
86 | 'cifar100':cifar100,
87 | 'ImageNet16_120': ImageNet16_120 }
88 |
89 | def train_and_eval(arch_index, nepoch=None, dataname=None, use_converged_LR=True):
90 |     assert dataname != 'cifar10', 'Do not allow cifar10 dataset'
91 | if use_converged_LR and dataname=='cifar10-valid':
92 |         assert nepoch is None, 'When using use_converged_LR=True, set nepoch=None; the 12-epoch converged schedule is used by default.'
93 |
94 |
95 | info = nas_bench.get_more_info(arch_index, dataname, None, True)
96 | valid_acc, time_cost = info['valid-accuracy'], info['train-all-time'] + info['valid-per-time']
97 | valid_acc_avg = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, False, False)['valid-accuracy']
98 | test_acc = nas_bench.get_more_info(arch_index, 'cifar10', None, False, True)['test-accuracy']
99 | test_acc_avg = nas_bench.get_more_info(arch_index, 'cifar10', None, False, False)['test-accuracy']
100 |
101 | elif not use_converged_LR:
102 |
103 | assert isinstance(nepoch, int), 'nepoch should be int'
104 | xoinfo = nas_bench.get_more_info(arch_index, 'cifar10-valid', None, True)
105 | xocost = nas_bench.get_cost_info(arch_index, 'cifar10-valid', False)
106 | info = nas_bench.get_more_info(arch_index, dataname, nepoch, False, True)
107 | cost = nas_bench.get_cost_info(arch_index, dataname, False)
108 | # The following codes are used to estimate the time cost.
109 | # When we build NAS-Bench-201, architectures are trained on different machines and we can not use that time record.
110 | # When we create checkpoints for converged_LR, we run all experiments on 1080Ti, and thus the time for each architecture can be fairly compared.
111 | nums = {'ImageNet16-120-train': 151700, 'ImageNet16-120-valid': 3000,
112 | 'cifar10-valid-train' : 25000, 'cifar10-valid-valid' : 25000,
113 | 'cifar100-train' : 50000, 'cifar100-valid' : 5000}
114 | estimated_train_cost = xoinfo['train-per-time'] / nums['cifar10-valid-train'] * nums['{:}-train'.format(dataname)] / xocost['latency'] * cost['latency'] * nepoch
115 | estimated_valid_cost = xoinfo['valid-per-time'] / nums['cifar10-valid-valid'] * nums['{:}-valid'.format(dataname)] / xocost['latency'] * cost['latency']
116 | try:
117 | valid_acc, time_cost = info['valid-accuracy'], estimated_train_cost + estimated_valid_cost
118 | except:
119 | valid_acc, time_cost = info['est-valid-accuracy'], estimated_train_cost + estimated_valid_cost
120 | test_acc = info['test-accuracy']
121 | test_acc_avg = nas_bench.get_more_info(arch_index, dataname, None, False, False)['test-accuracy']
122 | valid_acc_avg = nas_bench.get_more_info(arch_index, dataname, None, False, False)['valid-accuracy']
123 | else:
124 | # train a model from scratch.
125 |         raise ValueError('NOT IMPLEMENTED YET')
126 | return valid_acc, valid_acc_avg, time_cost, test_acc, test_acc_avg
127 |
128 |
129 | def enumerate_dataset(dataset):
130 | for k in range(len(nas_bench)):
131 | print('{}: {}/{}'.format(dataset, k,len(nas_bench)))
132 | res = info2mat(k)
133 | yield {k:res[dataset]}
134 |
135 | def gen_json_file(dataset):
136 | data_dict = OrderedDict()
137 | enum_dataset = enumerate_dataset(dataset)
138 | for data_point in enum_dataset:
139 | data_dict.update(data_point)
140 | with open('data/{}.json'.format(dataset), 'w') as outfile:
141 | json.dump(data_dict, outfile)
142 |
143 | if __name__=='__main__':
144 |
145 | for dataset in ['cifar10_valid_converged', 'cifar100', 'ImageNet16_120']:
146 | gen_json_file(dataset)
147 |
--------------------------------------------------------------------------------
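
info2mat above recovers the node operations by splitting arch_str on '+' and '~'. A standalone sketch of that parsing, with a hypothetical arch_str in the NAS-Bench-201 format:

    arch_str = '|nor_conv_3x3~0|+|skip_connect~0|nor_conv_1x1~1|+|skip_connect~0|none~1|avg_pool_3x3~2|'
    nodes = ['input']
    for step in arch_str.split('+'):
        for node in step.strip('|').split('|'):
            op, idx = node.split('~')  # idx is the index of the source node
            nodes.append(op)
    nodes.append('output')
    print(nodes)
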
/pybnn/__init__.py:
--------------------------------------------------------------------------------
1 | from pybnn.dngo import DNGO
2 | from pybnn.bayesian_linear_regression import BayesianLinearRegression
3 | from pybnn.base_model import BaseModel
4 |
--------------------------------------------------------------------------------
/pybnn/base_model.py:
--------------------------------------------------------------------------------
1 | import abc
2 | import numpy as np
3 |
4 |
5 | class BaseModel(object):
6 | __metaclass__ = abc.ABCMeta
7 |
8 | def __init__(self):
9 | """
10 | Abstract base class for all models
11 | """
12 | self.X = None
13 | self.y = None
14 |
15 | @abc.abstractmethod
16 | def train(self, X, y):
17 | """
18 | Trains the model on the provided data.
19 |
20 | Parameters
21 | ----------
22 | X: np.ndarray (N, D)
23 | Input data points. The dimensionality of X is (N, D),
24 |             with N the number of points and D the number of input dimensions.
25 | y: np.ndarray (N,)
26 | The corresponding target values of the input data points.
27 | """
28 | pass
29 |
30 | def update(self, X, y):
31 | """
32 | Update the model with the new additional data. Override this function if your
33 |         model can do something smarter than simple retraining.
34 |
35 | Parameters
36 | ----------
37 | X: np.ndarray (N, D)
38 | Input data points. The dimensionality of X is (N, D),
39 |             with N the number of points and D the number of input dimensions.
40 | y: np.ndarray (N,)
41 | The corresponding target values of the input data points.
42 | """
43 | X = np.append(self.X, X, axis=0)
44 | y = np.append(self.y, y, axis=0)
45 | self.train(X, y)
46 |
47 | @abc.abstractmethod
48 | def predict(self, X_test):
49 | """
50 |         Predicts the mean and variance of the target values for a given set of test data points
51 |
52 | Parameters
53 | ----------
54 | X_test: np.ndarray (N, D)
55 | N Test data points with input dimensions D
56 |
57 | Returns
58 | ----------
59 | mean: ndarray (N,)
60 | Predictive mean of the test data points
61 | var: ndarray (N,)
62 | Predictive variance of the test data points
63 | """
64 | pass
65 |
66 | def _check_shapes_train(func):
67 | def func_wrapper(self, X, y, *args, **kwargs):
68 | assert X.shape[0] == y.shape[0]
69 | assert len(X.shape) == 2
70 | assert len(y.shape) == 1
71 | return func(self, X, y, *args, **kwargs)
72 | return func_wrapper
73 |
74 | def _check_shapes_predict(func):
75 | def func_wrapper(self, X, *args, **kwargs):
76 | assert len(X.shape) == 2
77 | return func(self, X, *args, **kwargs)
78 |
79 | return func_wrapper
80 |
81 | def get_json_data(self):
82 | """
83 |         JSON getter function
84 |
85 | Returns
86 | ----------
87 | dictionary
88 | """
89 | json_data = {'X': self.X if self.X is None else self.X.tolist(),
90 | 'y': self.y if self.y is None else self.y.tolist(),
91 | 'hyperparameters': ""}
92 | return json_data
93 |
94 | def get_incumbent(self):
95 | """
96 | Returns the best observed point and its function value
97 |
98 | Returns
99 | ----------
100 | incumbent: ndarray (D,)
101 | current incumbent
102 | incumbent_value: ndarray (N,)
103 | the observed value of the incumbent
104 | """
105 | best_idx = np.argmin(self.y)
106 | return self.X[best_idx], self.y[best_idx]
107 |
--------------------------------------------------------------------------------
/pybnn/bayesian_linear_regression.py:
--------------------------------------------------------------------------------
1 | import emcee
2 | import logging
3 | import numpy as np
4 |
5 | from scipy import optimize
6 | from scipy import stats
7 |
8 | from pybnn.base_model import BaseModel
9 |
10 |
11 | def linear_basis_func(x):
12 | return np.append(x, np.ones([x.shape[0], 1]), axis=1)
13 |
14 |
15 | def quadratic_basis_func(x):
16 | x = np.append(x ** 2, x, axis=1)
17 | return np.append(x, np.ones([x.shape[0], 1]), axis=1)
18 |
19 |
20 | logger = logging.getLogger(__name__)
21 |
22 |
23 | class Prior(object):
24 |
25 | def __init__(self, rng=None):
26 | if rng is None:
27 | self.rng = np.random.RandomState(np.random.randint(0, 10000))
28 | else:
29 | self.rng = rng
30 |
31 | def lnprob(self, theta):
32 | """
33 | Compute the log probability for theta = [log alpha, log beta]
34 | :param theta:
35 | :return: log p(theta)
36 | """
37 | lp = 0
38 | lp += stats.norm.pdf(theta[0], loc=0, scale=1) # log alpha
39 | lp += stats.norm.pdf(theta[1], loc=0, scale=1) # log sigma^2
40 |
41 | return lp
42 |
43 | def sample_from_prior(self, n_samples):
44 | p0 = np.zeros([n_samples, 2])
45 |
46 | # Log alpha
47 | p0[:, 0] = self.rng.normal(loc=0,
48 | scale=1,
49 | size=n_samples)
50 |
51 | # Log sigma^2
52 | p0[:, 1] = self.rng.normal(loc=-3,
53 | scale=1,
54 | size=n_samples)
55 | return p0
56 |
57 |
58 | class BayesianLinearRegression(BaseModel):
59 |
60 | def __init__(self, alpha=1, beta=1000, basis_func=linear_basis_func,
61 | prior=None, do_mcmc=True, n_hypers=20, chain_length=2000,
62 | burnin_steps=2000, rng=None):
63 | """
64 | Implementation of Bayesian linear regression. See chapter 3.3 of the book
65 | "Pattern Recognition and Machine Learning" by Bishop for more details.
66 |
67 | Parameters
68 | ----------
69 | alpha: float
70 | Specifies the variance of the prior for the weights w
71 | beta : float
72 | Defines the inverse of the noise, i.e. beta = 1 / sigma^2
73 | basis_func : function
74 |             Function handle used to transform the inputs via basis functions
75 | (see the code above for an example)
76 | prior: Prior object
77 | Prior for alpha and beta. If set to None the default prior is used
78 | do_mcmc: bool
79 | If set to true different values for alpha and beta are sampled via MCMC from the marginal log likelihood
80 | Otherwise the marginal log likelihood is optimized with scipy fmin function
81 | n_hypers : int
82 | Number of samples for alpha and beta
83 | chain_length : int
84 | The chain length of the MCMC sampler
85 | burnin_steps: int
86 | The number of burnin steps before the sampling procedure starts
87 | rng: np.random.RandomState
88 | Random number generator
89 | """
90 |
91 | if rng is None:
92 | self.rng = np.random.RandomState(np.random.randint(0, 10000))
93 | else:
94 | self.rng = rng
95 |
96 | self.X = None
97 | self.y = None
98 | self.alpha = alpha
99 | self.beta = beta
100 | self.basis_func = basis_func
101 | if prior is None:
102 | self.prior = Prior(rng=self.rng)
103 | else:
104 | self.prior = prior
105 | self.do_mcmc = do_mcmc
106 | self.n_hypers = n_hypers
107 | self.chain_length = chain_length
108 | self.burned = False
109 | self.burnin_steps = burnin_steps
110 | self.models = None
111 |
112 | def marginal_log_likelihood(self, theta):
113 | """
114 | Log likelihood of the data marginalised over the weights w. See chapter 3.5 of
115 |         the book by Bishop for a derivation.
116 |
117 | Parameters
118 | ----------
119 | theta: np.array(2,)
120 | The hyperparameter alpha and beta on a log scale
121 |
122 | Returns
123 | -------
124 | float
125 | lnlikelihood + prior
126 | """
127 |
128 | # Theta is on a log scale
129 | alpha = np.exp(theta[0])
130 | beta = 1 / np.exp(theta[1])
131 |
132 | D = self.X_transformed.shape[1]
133 | N = self.X_transformed.shape[0]
134 |
135 | A = beta * np.dot(self.X_transformed.T, self.X_transformed)
136 | A += np.eye(self.X_transformed.shape[1]) * alpha
137 | try:
138 | A_inv = np.linalg.inv(A)
139 | except np.linalg.linalg.LinAlgError:
140 | A_inv = np.linalg.inv(A + np.random.rand(A.shape[0], A.shape[1]) * 1e-8)
141 |
142 |
143 | m = beta * np.dot(A_inv, self.X_transformed.T)
144 | m = np.dot(m, self.y)
145 |
146 | mll = D / 2 * np.log(alpha)
147 | mll += N / 2 * np.log(beta)
148 | mll -= N / 2 * np.log(2 * np.pi)
149 | mll -= beta / 2. * np.linalg.norm(self.y - np.dot(self.X_transformed, m), 2)
150 | mll -= alpha / 2. * np.dot(m.T, m)
151 | mll -= 0.5 * np.log(np.linalg.det(A))
152 |
153 | if self.prior is not None:
154 | mll += self.prior.lnprob(theta)
155 |
156 | return mll
157 |
158 | def negative_mll(self, theta):
159 | """
160 | Returns the negative marginal log likelihood (for optimizing it with scipy).
161 |
162 | Parameters
163 | ----------
164 | theta: np.array(2,)
165 | The hyperparameter alpha and beta on a log scale
166 |
167 | Returns
168 | -------
169 | float
170 | negative lnlikelihood + prior
171 | """
172 | return -self.marginal_log_likelihood(theta)
173 |
174 | @BaseModel._check_shapes_train
175 | def train(self, X, y, do_optimize=True):
176 | """
177 |         First optimizes the hyperparameters if do_optimize is True, then computes
178 | the posterior distribution of the weights. See chapter 3.3 of the book by Bishop
179 | for more details.
180 |
181 | Parameters
182 | ----------
183 | X: np.ndarray (N, D)
184 | Input data points. The dimensionality of X is (N, D),
185 |             with N the number of points and D the number of features.
186 | y: np.ndarray (N,)
187 | The corresponding target values.
188 | do_optimize: boolean
189 | If set to true the hyperparameters are optimized otherwise
190 | the default hyperparameters are used.
191 | """
192 |
193 | self.X = X
194 |
195 | if self.basis_func is not None:
196 | self.X_transformed = self.basis_func(X)
197 | else:
198 | self.X_transformed = self.X
199 |
200 | self.y = y
201 |
202 | if do_optimize:
203 | if self.do_mcmc:
204 | sampler = emcee.EnsembleSampler(self.n_hypers, 2,
205 | self.marginal_log_likelihood)
206 |
207 | # Do a burn-in in the first iteration
208 | if not self.burned:
209 | # Initialize the walkers by sampling from the prior
210 | self.p0 = self.prior.sample_from_prior(self.n_hypers)
211 |
212 | # Run MCMC sampling
213 | self.p0, _, _ = sampler.run_mcmc(self.p0,
214 | self.burnin_steps,
215 | rstate0=self.rng)
216 |
217 | self.burned = True
218 |
219 | # Start sampling
220 | pos, _, _ = sampler.run_mcmc(self.p0,
221 | self.chain_length,
222 | rstate0=self.rng)
223 |
224 | # Save the current position, it will be the start point in
225 | # the next iteration
226 | self.p0 = pos
227 |
228 | # Take the last samples from each walker
229 | self.hypers = np.exp(sampler.chain[:, -1])
230 | else:
231 | # Optimize hyperparameters of the Bayesian linear regression
232 | res = optimize.fmin(self.negative_mll, self.rng.rand(2))
233 | self.hypers = [[np.exp(res[0]), np.exp(res[1])]]
234 |
235 | else:
236 | self.hypers = [[self.alpha, self.beta]]
237 |
238 | self.models = []
239 | for sample in self.hypers:
240 | alpha = sample[0]
241 | beta = sample[1]
242 |
243 | logger.debug("Alpha=%f ; Beta=%f" % (alpha, beta))
244 |
245 | S_inv = beta * np.dot(self.X_transformed.T, self.X_transformed)
246 | S_inv += np.eye(self.X_transformed.shape[1]) * alpha
247 | try:
248 | S = np.linalg.inv(S_inv)
249 | except np.linalg.linalg.LinAlgError:
250 | S = np.linalg.inv(S_inv + np.random.rand(S_inv.shape[0], S_inv.shape[1]) * 1e-8)
251 |
252 | m = beta * np.dot(np.dot(S, self.X_transformed.T), self.y)
253 |
254 | self.models.append((m, S))
255 |
256 | @BaseModel._check_shapes_predict
257 | def predict(self, X_test):
258 | r"""
259 | Returns the predictive mean and variance of the objective function at
260 | the given test points.
261 |
262 | Parameters
263 | ----------
264 | X_test: np.ndarray (N, D)
265 | N input test points
266 |
267 | Returns
268 | ----------
269 | np.array(N,)
270 | predictive mean
271 | np.array(N,)
272 | predictive variance
273 |
274 | """
275 | if self.basis_func is not None:
276 | X_transformed = self.basis_func(X_test)
277 | else:
278 | X_transformed = X_test
279 |
280 | # Marginalise predictions over hyperparameters
281 | mu = np.zeros([len(self.hypers), X_transformed.shape[0]])
282 | var = np.zeros([len(self.hypers), X_transformed.shape[0]])
283 |
284 | for i, h in enumerate(self.hypers):
285 | mu[i] = np.dot(self.models[i][0].T, X_transformed.T)
286 | var[i] = 1. / h[1] + np.diag(np.dot(np.dot(X_transformed, self.models[i][1]), X_transformed.T))
287 |
288 | m = mu.mean(axis=0)
289 | v = var.mean(axis=0)
290 | # Clip negative variances and set them to the smallest
291 | # positive float value
292 | if v.shape[0] == 1:
293 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf)
294 | else:
295 | v = np.clip(v, np.finfo(v.dtype).eps, np.inf)
296 | v[np.where((v < np.finfo(v.dtype).eps) & (v > -np.finfo(v.dtype).eps))] = 0
297 |
298 | return m, v
299 |
--------------------------------------------------------------------------------
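
A minimal usage sketch for BayesianLinearRegression on toy 1-D data (not from the paper's experiments); do_mcmc=False picks the faster scipy fmin path for the hyperparameters:

    import numpy as np
    from pybnn.bayesian_linear_regression import BayesianLinearRegression

    rng = np.random.RandomState(0)
    X = rng.rand(50, 1)
    y = 2.0 * X[:, 0] + 0.3 + rng.normal(0, 0.01, size=50)

    model = BayesianLinearRegression(do_mcmc=False, rng=rng)
    model.train(X, y)
    mean, var = model.predict(rng.rand(5, 1))
    print(mean, var)
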
/pybnn/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/AIoT-MLSys-Lab/arch2vec/ea01b0cf1295305596ee3c05fa1b6eb14e303512/pybnn/util/__init__.py
--------------------------------------------------------------------------------
/pybnn/util/normalization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def zero_one_normalization(X, lower=None, upper=None):
5 |
6 | if lower is None:
7 | lower = np.min(X, axis=0)
8 | if upper is None:
9 | upper = np.max(X, axis=0)
10 |
11 | X_normalized = np.true_divide((X - lower), (upper - lower))
12 |
13 | return X_normalized, lower, upper
14 |
15 |
16 | def zero_one_denormalization(X_normalized, lower, upper):
17 | return lower + (upper - lower) * X_normalized
18 |
19 |
20 | def zero_mean_unit_var_normalization(X, mean=None, std=None):
21 | if mean is None:
22 | mean = np.mean(X, axis=0)
23 | if std is None:
24 | std = np.std(X, axis=0)
25 |
26 | X_normalized = (X - mean) / std
27 |
28 | return X_normalized, mean, std
29 |
30 |
31 | def zero_mean_unit_var_denormalization(X_normalized, mean, std):
32 | return X_normalized * std + mean
33 |
--------------------------------------------------------------------------------
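
The normalizers return the statistics they used, so the matching denormalizer can undo them exactly; a round-trip sketch:

    import numpy as np
    from pybnn.util.normalization import (zero_mean_unit_var_normalization,
                                          zero_mean_unit_var_denormalization)

    y = np.array([0.90, 0.92, 0.94])
    y_norm, mean, std = zero_mean_unit_var_normalization(y)
    assert np.allclose(zero_mean_unit_var_denormalization(y_norm, mean, std), y)
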
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch == 1.4.0
2 | torchvision == 0.5.0
3 | tensorflow == 1.15.0
4 | emcee == 3.0.2
5 | tqdm == 4.31.1
6 | networkx == 2.2
7 | graphviz == 0.14.2
8 | thop == 0.0.31.post2004101309
9 | texttable == 1.6.3
10 | python-igraph == 0.8.3
11 |
12 |
--------------------------------------------------------------------------------
/run_scripts/extract_arch2vec.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python search_methods/reinforce.py --dim 16 --model_path model-nasbench101.pt
3 |
4 |
--------------------------------------------------------------------------------
/run_scripts/extract_arch2vec_darts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python search_methods/reinforce_darts.py --dim 16
3 |
--------------------------------------------------------------------------------
/run_scripts/extract_arch2vec_nasbench201.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name cifar10_valid_converged --latent_dim 16 --model_path model-nasbench201.pt
4 |
5 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name cifar100 --latent_dim 16 --model_path model-nasbench201.pt
6 |
7 | python search_methods/reinforce_search_NB201_8x8.py --dataset_name ImageNet16_120 --latent_dim 16 --model_path model-nasbench201.pt
8 |
9 |
--------------------------------------------------------------------------------
/run_scripts/run_bo_arch2vec_darts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python search_methods/dngo_darts.py --max_budgets 100 --inner_epochs 50 --objective 0.95 --train_portion 0.9 --dim 16 --seed 3 --output_path saved_logs/bo --init_size 16 --batch_size 5 --logging_path darts-bo
3 |
--------------------------------------------------------------------------------
/run_scripts/run_bo_arch2vec_nasbench201_ImageNet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in {16,}
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \
7 | --dataset_name ImageNet16_120 --MAX_BUDGET 1400000
8 | done
9 | done
10 |
--------------------------------------------------------------------------------
/run_scripts/run_bo_arch2vec_nasbench201_cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in {16,}
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \
7 | --dataset_name cifar100 --MAX_BUDGET 500000
8 | done
9 | done
10 |
--------------------------------------------------------------------------------
/run_scripts/run_bo_arch2vec_nasbench201_cifar10_valid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in {16,}
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/dngo_search_NB201_8x8.py --dim $i --seed $s --output_path saved_logs/bo --init_size 16 --batch_size 1 \
7 | --dataset_name cifar10_valid_converged --MAX_BUDGET 12000
8 | done
9 | done
10 |
--------------------------------------------------------------------------------
/run_scripts/run_dngo_arch2vec.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for s in {1..500}
3 | do
4 | python search_methods/dngo.py --dim 16 --seed $s --output_path saved_logs/bo --emb_path arch2vec-model-nasbench101.pt --init_size 16 --topk 5
5 | done
6 |
--------------------------------------------------------------------------------
/run_scripts/run_dngo_supervised.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for s in {1..500}
4 | do
5 | python search_methods/supervised_dngo.py --dim 16 --seed $s --init_size 16 --topk 5 --output_path saved_logs/bo
6 | done
7 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_arch2vec.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #python search_methods/reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl --saved_arch2vec --emb_path arch2vec-nasbench101.pt
4 | for s in {1..500}
5 | do
6 | python search_methods/reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl --saved_arch2vec --emb_path arch2vec-model-nasbench101.pt
7 | done
8 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_arch2vec_darts.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python search_methods/reinforce_darts.py --max_budgets 100 --inner_epochs 50 --objective 0.95 --train_portion 0.9 --dim 16 --seed 3 --bs 16 --output_path saved_logs/rl --saved_arch2vec --logging_path darts-rl
3 |
4 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_arch2vec_nasbench201_ImageNet.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in {16,}
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --gamma 0.4 --baseline 0.4 \
7 | --output_path saved_logs/rl --saved_arch2vec \
8 | --dataset_name ImageNet16_120 --MAX_BUDGET 1400000 --model_path model-nasbench201.pt
9 | done
10 | done
11 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_arch2vec_nasbench201_cifar100.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in {16,}
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --MAX_BUDGET 500000 --baseline 0.4 --gamma 0.4 --saved_arch2vec \
7 | --dataset_name cifar100 --output_path saved_logs/rl --model_path model-nasbench201.pt
8 | done
9 | done
10 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_arch2vec_nasbench201_cifar10_valid.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | for i in 16
3 | do
4 | for s in {1..500}
5 | do
6 | python search_methods/reinforce_search_NB201_8x8.py --latent_dim $i --seed $s --bs 16 --gamma 0.4 --baseline 0.4 \
7 | --output_path saved_logs/rl --saved_arch2vec \
8 | --dataset_name cifar10_valid_converged --MAX_BUDGET 12000 --model_path model-nasbench201.pt
9 | done
10 | done
11 |
--------------------------------------------------------------------------------
/run_scripts/run_reinforce_supervised.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | for s in {1..500}
4 | do
5 | python search_methods/supervised_reinforce.py --dim 16 --seed $s --bs 16 --output_path saved_logs/rl
6 | done
7 |
--------------------------------------------------------------------------------
/search_methods/dngo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | from pybnn.dngo import DNGO
5 | import argparse
6 | import json
7 | import torch
8 | import numpy as np
9 | from collections import defaultdict
10 | from torch.distributions import Normal
11 |
12 |
13 | def load_arch2vec(embedding_path):
14 | embedding = torch.load(embedding_path)
15 | print('load arch2vec from {}'.format(embedding_path))
16 | ind_list = range(len(embedding))
17 | features = [embedding[ind]['feature'] for ind in ind_list]
18 | valid_labels = [embedding[ind]['valid_accuracy'] for ind in ind_list]
19 | test_labels = [embedding[ind]['test_accuracy'] for ind in ind_list]
20 | training_time = [embedding[ind]['time'] for ind in ind_list]
21 | features = torch.stack(features, dim=0)
22 | test_labels = torch.Tensor(test_labels)
23 | valid_labels = torch.Tensor(valid_labels)
24 | training_time = torch.Tensor(training_time)
25 | print('loading finished. pretrained embeddings shape {}'.format(features.shape))
26 | return features, valid_labels, test_labels, training_time
27 |
28 |
29 | def get_init_samples(features, valid_labels, test_labels, training_time, visited):
30 | np.random.seed(args.seed)
31 | init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size]
32 | init_inds = torch.Tensor(init_inds).long()
33 | init_feat_samples = features[init_inds]
34 | init_valid_label_samples = valid_labels[init_inds]
35 | init_test_label_samples = test_labels[init_inds]
36 | init_time_samples = training_time[init_inds]
37 | for idx in init_inds:
38 | visited[idx.item()] = True
39 | return init_feat_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, visited
40 |
41 |
42 | def propose_location(ei, features, valid_labels, test_labels, training_time, visited):
43 | k = args.topk
44 | print('remaining length of indices set:', len(features) - len(visited))
45 | indices = torch.argsort(ei)[-k:]
46 | ind_dedup = []
47 | for idx in indices:
48 | if idx.item() not in visited:
49 | visited[idx.item()] = True
50 | ind_dedup.append(idx.item())
51 | ind_dedup = torch.Tensor(ind_dedup).long()
52 | proposed_x, proposed_y_valid, proposed_y_test, proposed_time = features[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup]
53 | return proposed_x, proposed_y_valid, proposed_y_test, proposed_time, visited
54 |
55 |
56 | def expected_improvement_search():
57 | """ implementation of arch2vec-DNGO """
58 | BEST_TEST_ACC = 0.943175752957662
59 | BEST_VALID_ACC = 0.9505542318026224
60 | CURR_BEST_VALID = 0.
61 | CURR_BEST_TEST = 0.
62 | MAX_BUDGET = 1.5e6
63 | window_size = 200
64 | counter = 0
65 | rt = 0.
66 | visited = {}
67 | best_trace = defaultdict(list)
68 | features, valid_labels, test_labels, training_time = load_arch2vec(os.path.join('pretrained/dim-{}'.format(args.dim), args.emb_path))
69 | features, valid_labels, test_labels, training_time = features.cpu().detach(), valid_labels.cpu().detach(), test_labels.cpu().detach(), training_time.cpu().detach()
70 | feat_samples, valid_label_samples, test_label_samples, time_samples, visited = get_init_samples(features, valid_labels, test_labels, training_time, visited)
71 |
72 | for feat, acc_valid, acc_test, t in zip(feat_samples, valid_label_samples, test_label_samples, time_samples):
73 | counter += 1
74 | rt += t.item()
75 | if acc_valid > CURR_BEST_VALID:
76 | CURR_BEST_VALID = acc_valid
77 | CURR_BEST_TEST = acc_test
78 | best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
79 | best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
80 | best_trace['time'].append(rt)
81 | best_trace['counter'].append(counter)
82 |
83 | while rt < MAX_BUDGET:
84 | print("feat_samples:", feat_samples.shape)
85 | print("valid label_samples:", valid_label_samples.shape)
86 | print("test label samples:", test_label_samples.shape)
87 | print("current best validation: {}".format(CURR_BEST_VALID))
88 | print("current best test: {}".format(CURR_BEST_TEST))
89 | print("rt: {}".format(rt))
90 | # (re)fit the DNGO surrogate on all architectures evaluated so far,
91 | # then score every architecture in the search space with EI
92 | model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False, rng=args.seed)
93 | model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True)
94 | print(model.network)
95 | m = []
96 | v = []
97 | chunks = int(features.shape[0] / window_size)
98 | if features.shape[0] % window_size > 0:
99 | chunks += 1
100 | features_split = torch.split(features, window_size, dim=0)
101 | for i in range(chunks):
102 | m_split, v_split = model.predict(features_split[i].numpy())
103 | m.extend(list(m_split))
104 | v.extend(list(v_split))
105 | mean = torch.Tensor(m)
106 | sigma = torch.Tensor(v)
107 | u = (mean - torch.Tensor([0.95]).expand_as(mean)) / sigma
108 | normal = Normal(torch.zeros_like(u), torch.ones_like(u))
109 | ucdf = normal.cdf(u)
110 | updf = torch.exp(normal.log_prob(u))
111 | ei = sigma * (updf + u * ucdf)
112 | feat_next, label_next_valid, label_next_test, time_next, visited = propose_location(ei, features, valid_labels, test_labels, training_time, visited)
113 |
114 | # add proposed networks to the pool
115 | for feat, acc_valid, acc_test, t in zip(feat_next, label_next_valid, label_next_test, time_next):
116 | if acc_valid > CURR_BEST_VALID:
117 | CURR_BEST_VALID = acc_valid
118 | CURR_BEST_TEST = acc_test
119 | feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0)
120 | valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0)
121 | test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0)
122 | counter += 1
123 | rt += t.item()
124 | best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
125 | best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
126 | best_trace['time'].append(rt)
127 | best_trace['counter'].append(counter)
128 | if rt >= MAX_BUDGET:
129 | break
130 |
131 | res = dict()
132 | res['regret_validation'] = best_trace['regret_validation']
133 | res['regret_test'] = best_trace['regret_test']
134 | res['runtime'] = best_trace['time']
135 | res['counter'] = best_trace['counter']
136 | save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
137 | if not os.path.exists(save_path):
138 | os.makedirs(save_path)
139 | print('save to {}'.format(save_path))
140 | # strip the '.pt' suffix (if present) from emb_path for the output filename
141 | s = args.emb_path[:-3] if args.emb_path.endswith('.pt') else args.emb_path
142 | fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, s)), 'w')
143 | json.dump(res, fh)
144 | fh.close()
145 |
146 |
147 | if __name__ == '__main__':
148 | parser = argparse.ArgumentParser(description="arch2vec-DNGO")
149 | parser.add_argument("--seed", type=int, default=1, help="random seed")
150 | parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
151 | parser.add_argument('--dim', type=int, default=16, help='feature dimension')
152 | parser.add_argument('--init_size', type=int, default=16, help='init samples')
153 | parser.add_argument('--topk', type=int, default=5, help='acquisition samples')
154 | parser.add_argument('--output_path', type=str, default='bo', help='bo')
155 | parser.add_argument('--emb_path', type=str, default='arch2vec.pt')
156 | args = parser.parse_args()
157 | np.random.seed(args.seed)
158 | torch.manual_seed(args.seed)
159 | torch.cuda.manual_seed_all(args.seed)
160 | torch.set_num_threads(2)
161 | expected_improvement_search()
162 |
--------------------------------------------------------------------------------
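All of the DNGO-based searchers in `search_methods/` score candidates with the same closed-form expected-improvement (EI) acquisition that appears in `expected_improvement_search()` above. Below is a minimal standalone sketch of that computation; `incumbent` is a name introduced here for illustration and plays the role of the hard-coded 0.95 target in `dngo.py`.

```python
# Minimal sketch of the closed-form EI acquisition used by the DNGO searchers.
# `mean` and `sigma` are the predictive mean and (square-root of the) predictive
# uncertainty returned by model.predict(), exactly as the files above use them.
import torch
from torch.distributions import Normal

def expected_improvement(mean: torch.Tensor, sigma: torch.Tensor,
                         incumbent: float = 0.95) -> torch.Tensor:
    u = (mean - incumbent) / sigma          # standardized improvement
    normal = Normal(torch.zeros_like(u), torch.ones_like(u))
    ucdf = normal.cdf(u)                    # probability of improvement
    updf = torch.exp(normal.log_prob(u))    # standard normal density at u
    return sigma * (updf + u * ucdf)        # EI in closed form

# e.g.: expected_improvement(torch.tensor([0.93]), torch.tensor([0.02]))
```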
/search_methods/dngo_darts.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | from pybnn.dngo import DNGO
5 | import random
6 | import argparse
7 | import json
8 | import torch
9 | import numpy as np
10 | from collections import defaultdict
11 | from torch.distributions import Normal
12 | from darts.cnn.train_search import Train
13 |
14 | def load_arch2vec(embedding_path):
15 | embedding = torch.load(embedding_path)
16 | print('load arch2vec from {}'.format(embedding_path))
17 | ind_list = range(len(embedding))
18 | features = [embedding[ind]['feature'] for ind in ind_list]
19 | genotype = [embedding[ind]['genotype'] for ind in ind_list]
20 | features = torch.stack(features, dim=0)
21 | print('loading finished. pretrained embeddings shape {}'.format(features.shape))
22 | return features, genotype
23 |
24 |
25 | def query(counter, seed, genotype, epochs):
26 | trainer = Train()
27 | rewards, rewards_test = trainer.main(counter, seed, genotype, epochs=epochs, train_portion=args.train_portion, save=args.logging_path)
28 | val_sum = 0
29 | for epoch, val_acc in rewards:
30 | val_sum += val_acc
31 | val_avg = val_sum / len(rewards)
32 | return val_avg / 100., rewards_test[-1][-1] / 100.
33 |
34 | def get_init_samples(features, genotype, visited):
35 | count = 0
36 | np.random.seed(args.seed)
37 | init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size]
38 | init_inds = torch.Tensor(init_inds).long()
39 | print('init index: {}'.format(init_inds))
40 | init_feat_samples = features[init_inds]
41 | init_geno_samples = [genotype[i.item()] for i in init_inds]
42 | init_valid_label_samples = []
43 | init_test_label_samples = []
44 |
45 | for geno in init_geno_samples:
46 | val_acc, test_acc = query(count, args.seed, geno, args.inner_epochs)
47 | init_valid_label_samples.append(val_acc)
48 | init_test_label_samples.append(test_acc)
49 | count += 1
50 |
51 | init_valid_label_samples = torch.Tensor(init_valid_label_samples)
52 | init_test_label_samples = torch.Tensor(init_test_label_samples)
53 | for idx in init_inds:
54 | visited[idx.item()] = True
55 | return init_feat_samples, init_geno_samples, init_valid_label_samples, init_test_label_samples, visited
56 |
57 |
58 | def propose_location(ei, features, genotype, visited, counter):
59 | count = counter
60 | k = args.batch_size
61 | c = 0
62 | print('remaining length of indices set:', len(features) - len(visited))
63 | indices = torch.argsort(ei)
64 | ind_dedup = []
65 | # remove random sampled indices at each step
66 | for idx in reversed(indices):
67 | if c == k:
68 | break
69 | if idx.item() not in visited:
70 | visited[idx.item()] = True
71 | ind_dedup.append(idx.item())
72 | c += 1
73 | ind_dedup = torch.Tensor(ind_dedup).long()
74 | print('proposed index: {}'.format(ind_dedup))
75 | proposed_x = features[ind_dedup]
76 | proposed_geno = [genotype[i.item()] for i in ind_dedup]
77 | proposed_val_acc = []
78 | proposed_test_acc = []
79 | for geno in proposed_geno:
80 | val_acc, test_acc = query(count, args.seed, geno, args.inner_epochs)
81 | proposed_val_acc.append(val_acc)
82 | proposed_test_acc.append(test_acc)
83 | count += 1
84 |
85 | return proposed_x, proposed_geno, torch.Tensor(proposed_val_acc), torch.Tensor(proposed_test_acc), visited
86 |
87 |
88 | def expected_improvement_search(features, genotype):
89 | """ implementation of arch2vec-DNGO on DARTS Search Space """
90 | CURR_BEST_VALID = 0.
91 | CURR_BEST_TEST = 0.
92 | CURR_BEST_GENOTYPE = None
93 | MAX_BUDGET = args.max_budgets
94 | window_size = 200
95 | counter = 0
96 | visited = {}
97 | best_trace = defaultdict(list)
98 |
99 | features = features.cpu().detach()
100 | feat_samples, geno_samples, valid_label_samples, test_label_samples, visited = get_init_samples(features, genotype, visited)
101 |
102 | for feat, geno, acc_valid, acc_test in zip(feat_samples, geno_samples, valid_label_samples, test_label_samples):
103 | counter += 1
104 | if acc_valid > CURR_BEST_VALID:
105 | CURR_BEST_VALID = acc_valid
106 | CURR_BEST_TEST = acc_test
107 | CURR_BEST_GENOTYPE = geno
108 | best_trace['validation_acc'].append(float(CURR_BEST_VALID))
109 | best_trace['test_acc'].append(float(CURR_BEST_TEST))
110 | best_trace['genotype'].append(CURR_BEST_GENOTYPE)
111 | best_trace['counter'].append(counter)
112 |
113 | while counter < MAX_BUDGET:
114 | print("feat_samples:", feat_samples.shape)
115 | print("length of genotypes:", len(geno_samples))
116 | print("valid label_samples:", valid_label_samples.shape)
117 | print("test label samples:", test_label_samples.shape)
118 | print("current best validation: {}".format(CURR_BEST_VALID))
119 | print("current best test: {}".format(CURR_BEST_TEST))
120 | print("counter: {}".format(counter))
121 | # (re)fit the DNGO surrogate on all architectures evaluated so far,
122 | # then score every candidate in the DARTS pool with EI
123 | model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False)
124 | model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True)
125 | print(model.network)
126 | m = []
127 | v = []
128 | chunks = int(features.shape[0] / window_size)
129 | if features.shape[0] % window_size > 0:
130 | chunks += 1
131 | features_split = torch.split(features, window_size, dim=0)
132 | for i in range(chunks):
133 | m_split, v_split = model.predict(features_split[i].numpy())
134 | m.extend(list(m_split))
135 | v.extend(list(v_split))
136 | mean = torch.Tensor(m)
137 | sigma = torch.Tensor(v)
138 | u = (mean - torch.Tensor([args.objective]).expand_as(mean)) / sigma
139 | normal = Normal(torch.zeros_like(u), torch.ones_like(u))
140 | ucdf = normal.cdf(u)
141 | updf = torch.exp(normal.log_prob(u))
142 | ei = sigma * (updf + u * ucdf)
143 | feat_next, geno_next, label_next_valid, label_next_test, visited = propose_location(ei, features, genotype, visited, counter)
144 |
145 | # add proposed networks to the pool
146 | for feat, geno, acc_valid, acc_test in zip(feat_next, geno_next, label_next_valid, label_next_test):
147 | feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0)
148 | geno_samples.append(geno)
149 | valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0)
150 | test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0)
151 | counter += 1
152 | if acc_valid.item() > CURR_BEST_VALID:
153 | CURR_BEST_VALID = acc_valid.item()
154 | CURR_BEST_TEST = acc_test.item()
155 | CURR_BEST_GENOTYPE = geno
156 |
157 | best_trace['validation_acc'].append(float(CURR_BEST_VALID))
158 | best_trace['test_acc'].append(float(CURR_BEST_TEST))
159 | best_trace['genotype'].append(CURR_BEST_GENOTYPE)
160 | best_trace['counter'].append(counter)
161 |
162 | if counter >= MAX_BUDGET:
163 | break
164 |
165 | res = dict()
166 | res['validation_acc'] = best_trace['validation_acc']
167 | res['test_acc'] = best_trace['test_acc']
168 | res['genotype'] = best_trace['genotype']
169 | res['counter'] = best_trace['counter']
170 | save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
171 | if not os.path.exists(save_path):
172 | os.makedirs(save_path)
173 | print('save to {}'.format(save_path))
174 | fh = open(os.path.join(save_path, 'run_{}_arch2vec_model_darts.json'.format(args.seed)), 'w')
175 | json.dump(res, fh)
176 | fh.close()
177 |
178 |
179 | if __name__ == '__main__':
180 | parser = argparse.ArgumentParser(description="arch2vec-DNGO")
181 | parser.add_argument("--seed", type=int, default=3, help="random seed")
182 | parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
183 | parser.add_argument('--dim', type=int, default=16, help='feature dimension')
184 | parser.add_argument('--objective', type=float, default=0.95, help='ei objective')
185 | parser.add_argument('--init_size', type=int, default=16, help='init samples')
186 | parser.add_argument('--batch_size', type=int, default=5, help='acquisition samples')
187 | parser.add_argument('--inner_epochs', type=int, default=50, help='inner loop epochs')
188 | parser.add_argument('--train_portion', type=float, default=0.9, help='inner loop train/val split')
189 | parser.add_argument('--max_budgets', type=int, default=100, help='max number of trials')
190 | parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='bo')
191 | parser.add_argument('--logging_path', type=str, default='', help='search logging path')
192 | args = parser.parse_args()
193 | torch.manual_seed(args.seed)
194 | embedding_path = 'pretrained/dim-{}/arch2vec-darts.pt'.format(args.dim)
195 | if not os.path.exists(embedding_path):
196 | sys.exit('missing pretrained embedding: {}'.format(embedding_path))
197 | features, genotype = load_arch2vec(embedding_path)
198 | expected_improvement_search(features, genotype)
199 |
--------------------------------------------------------------------------------
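`propose_location()` in `dngo_darts.py` walks the EI ranking from best to worst and keeps the first `batch_size` architectures that have not been evaluated yet. A self-contained sketch of that selection step follows; the function name `top_k_unvisited` is ours, not the repo's.

```python
# Standalone sketch of the proposal step: scan the EI ranking in descending
# order and keep the first k indices not yet in `visited` (a dict mapping
# already-evaluated indices to True, as in the file above).
import torch

def top_k_unvisited(ei: torch.Tensor, visited: dict, k: int) -> torch.Tensor:
    picked = []
    for idx in reversed(torch.argsort(ei)):   # descending EI
        if len(picked) == k:
            break
        if idx.item() not in visited:
            visited[idx.item()] = True
            picked.append(idx.item())
    return torch.tensor(picked, dtype=torch.long)
```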
/search_methods/dngo_search_NB201_8x8.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | from pybnn.dngo import DNGO
5 | import random
6 | import argparse
7 | import json
8 | import torch
9 | import numpy as np
10 | from collections import defaultdict
11 | from torch.distributions import Normal
12 | import time
13 |
14 |
15 | def load_arch2vec(embedding_path):
16 | embedding = torch.load(embedding_path)
17 | print('load pretrained arch2vec from {}'.format(embedding_path))
18 | random.seed(args.seed)
19 | random.shuffle(embedding)
20 | features = [embedding[ind]['feature'] for ind in range(len(embedding))]
21 | valid_labels = [embedding[ind]['valid_accuracy']/100.0 for ind in range(len(embedding))]
22 | test_labels = [embedding[ind]['test_accuracy']/100.0 for ind in range(len(embedding))]
23 | training_time = [embedding[ind]['time'] for ind in range(len(embedding))]
24 | other_info = [embedding[ind]['other_info'] for ind in range(len(embedding))]
25 | features = torch.stack(features, dim=0)
26 | valid_labels = torch.Tensor(valid_labels)
27 | test_labels = torch.Tensor(test_labels)
28 | training_time = torch.Tensor(training_time)
29 | print('loading finished. pretrained embeddings shape {}, and valid labels shape {}, and test labels shape {}'.format(features.shape, valid_labels.shape, test_labels.shape))
30 | return features, valid_labels, test_labels, training_time, other_info
31 |
32 |
33 | def get_init_samples(features, valid_labels, test_labels, training_time, other_info, visited):
34 | np.random.seed(args.seed)
35 | init_inds = np.random.permutation(list(range(features.shape[0])))[:args.init_size]
36 | init_inds = torch.Tensor(init_inds).long()
37 | init_feat_samples = features[init_inds]
38 | init_valid_label_samples = valid_labels[init_inds]
39 | init_test_label_samples = test_labels[init_inds]
40 | init_time_samples = training_time[init_inds]
41 | print('='*20, init_inds)
42 | init_other_info_samples = [other_info[k] for k in init_inds]
43 | for idx in init_inds:
44 | visited[idx.item()] = True
45 | return init_feat_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, init_other_info_samples, visited
46 |
47 |
48 | def propose_location(ei, features, valid_labels, test_labels, training_time, other_info, visited):
49 | k = args.batch_size
50 | print('remaining length of indices set:', len(features) - len(visited))
51 | indices = torch.argsort(ei)[-k:]
52 | ind_dedup = []
53 | # remove random sampled indices at each step
54 | for idx in indices:
55 | if idx.item() not in visited:
56 | visited[idx.item()] = True
57 | ind_dedup.append(idx.item())
58 | ind_dedup = torch.Tensor(ind_dedup).long()
59 | proposed_x, proposed_y_valid, proposed_y_test, proposed_time, propose_info = features[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup], [other_info[k] for k in ind_dedup]
60 | return proposed_x, proposed_y_valid, proposed_y_test, proposed_time, propose_info, visited
61 |
62 |
63 | def expected_improvement_search(features, valid_labels, test_labels, training_time, other_info):
64 | """ implementation of expected improvement search given arch2vec.
65 | :param data_path: the pretrained arch2vec path.
66 | :return: features, labels
67 | """
68 | CURR_BEST_VALID = 0.
69 | CURR_BEST_TEST = 0.
70 | CURR_BEST_INFO = None
71 | MAX_BUDGET = args.MAX_BUDGET
72 | window_size = 200
73 | counter = 0
74 | rt = 0.
75 | visited = {}
76 | best_trace = defaultdict(list)
77 |
78 | features, valid_labels, test_labels, training_time = features.cpu().detach(), valid_labels.cpu().detach(), test_labels.cpu().detach(), training_time.cpu().detach()
79 | feat_samples, valid_label_samples, test_label_samples, time_samples, other_info_sampled, visited = get_init_samples(features, valid_labels, test_labels, training_time, other_info, visited)
80 |
81 | t_start = time.time()
82 | for feat, acc_valid, acc_test, t, o_info in zip(feat_samples, valid_label_samples, test_label_samples, time_samples, other_info_sampled):
83 | counter += 1
84 | rt += t.item()
85 | if acc_valid > CURR_BEST_VALID:
86 | CURR_BEST_VALID = acc_valid
87 | CURR_BEST_TEST = acc_test
88 | CURR_BEST_INFO = o_info
89 | best_trace['validation'].append(float(CURR_BEST_VALID))
90 | best_trace['test'].append(float(CURR_BEST_TEST))
91 | best_trace['time'].append(time.time() - t_start)
92 | best_trace['counter'].append(counter)
93 |
94 | while rt < MAX_BUDGET:
95 | print("feat_samples:", feat_samples.shape)
96 | print("valid label_samples:", valid_label_samples.shape)
97 | print("test label samples:", test_label_samples.shape)
98 | print("current best validation: {}".format(CURR_BEST_VALID))
99 | print("current best test: {}".format(CURR_BEST_TEST))
100 | print("rt: {}".format(rt))
101 | # (re)fit the DNGO surrogate on all architectures evaluated so far,
102 | # then score every architecture in the search space with EI
103 | model = DNGO(num_epochs=100, n_units=128, do_mcmc=False, normalize_output=False)
104 | model.train(X=feat_samples.numpy(), y=valid_label_samples.view(-1).numpy(), do_optimize=True)
105 | print(model.network)
106 | m = []
107 | v = []
108 | chunks = int(features.shape[0] / window_size)
109 | if features.shape[0] % window_size > 0:
110 | chunks += 1
111 | features_split = torch.split(features, window_size, dim=0)
112 | for i in range(chunks):
113 | m_split, v_split = model.predict(features_split[i].numpy())
114 | m.extend(list(m_split))
115 | v.extend(list(v_split))
116 | mean = torch.Tensor(m)
117 | sigma = torch.Tensor(v)
118 | u = (mean - torch.Tensor([1.0]).expand_as(mean)) / sigma
119 | normal = Normal(torch.zeros_like(u), torch.ones_like(u))
120 | ucdf = normal.cdf(u)
121 | updf = torch.exp(normal.log_prob(u))
122 | ei = sigma * (updf + u * ucdf)
123 | feat_next, label_next_valid, label_next_test, time_next, info_next, visited = propose_location(ei, features, valid_labels, test_labels, training_time, other_info, visited)
124 |
125 | # add proposed networks to selected networks
126 | for feat, acc_valid, acc_test, t, o_info in zip(feat_next, label_next_valid, label_next_test, time_next, info_next):
127 | feat_samples = torch.cat((feat_samples, feat.view(1, -1)), dim=0)
128 | valid_label_samples = torch.cat((valid_label_samples.view(-1, 1), acc_valid.view(1, 1)), dim=0)
129 | test_label_samples = torch.cat((test_label_samples.view(-1, 1), acc_test.view(1, 1)), dim=0)
130 | counter += 1
131 | rt += t.item()
132 | if acc_valid > CURR_BEST_VALID:
133 | CURR_BEST_VALID = acc_valid
134 | CURR_BEST_TEST = acc_test
135 | CURR_BEST_INFO = o_info
136 |
137 | best_trace['validation'].append(float(CURR_BEST_VALID))
138 | best_trace['test'].append(float(CURR_BEST_TEST))
139 | best_trace['time'].append(time.time() - t_start)  # the actual elapsed search time
140 | best_trace['counter'].append(counter)
141 |
142 | if rt >= MAX_BUDGET:
143 | break
144 |
145 | res = dict()
146 | res['validation'] = best_trace['validation']
147 | res['test'] = best_trace['test']
148 | res['runtime'] = best_trace['time']
149 | res['counter'] = best_trace['counter']
150 | save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
151 | if not os.path.exists(save_path):
152 | os.mkdir(save_path)
153 | print('save to {}'.format(save_path))
154 | print('Current Best Valid {}, Test {}'.format(CURR_BEST_VALID, CURR_BEST_TEST))
155 | data_dict = {'val_acc': float(CURR_BEST_VALID), 'test_acc': float(CURR_BEST_TEST),
156 | 'val_acc_avg': float(CURR_BEST_INFO['valid_accuracy_avg']),
157 | 'test_acc_avg': float(CURR_BEST_INFO['test_accuracy_avg']), 'trace': res}
158 | save_dir = os.path.join(save_path, 'nasbench201_{}_run_{}_full.json'.format(args.dataset_name, args.seed))
159 | with open(save_dir, 'w') as f:
160 | json.dump(data_dict, f)
161 |
162 |
163 | if __name__ == '__main__':
164 | parser = argparse.ArgumentParser(description="DNGO search for NB201")
165 | parser.add_argument("--gamma", type=float, default=0, help="discount factor (default 0.99)")
166 | parser.add_argument("--seed", type=int, default=1, help="random seed")
167 | parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
168 | parser.add_argument('--dim', type=int, default=16, help='feature dimension')
169 | parser.add_argument('--init_size', type=int, default=16, help='init samples')
170 | parser.add_argument('--batch_size', type=int, default=1, help='acquisition samples')
171 | parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='rl/gd/predictor/bo (default: bo)')
172 | parser.add_argument('--saved_arch2vec', action="store_true", default=True)
173 |
174 | parser.add_argument('--dataset_name', type=str, default='ImageNet16_120',
175 | help='Select from | cifar100 | ImageNet16_120 | cifar10_valid | cifar10_valid_converged')
176 | parser.add_argument('--MAX_BUDGET', type=float, default=1200000, help='The budget in seconds')
177 |
178 | args = parser.parse_args()
179 | # for reproducibility
180 | np.random.seed(args.seed)
181 | torch.manual_seed(args.seed)
182 | torch.cuda.manual_seed_all(args.seed)
183 | embedding_path = 'pretrained/dim-{}/{}-arch2vec.pt'.format(args.dim, args.dataset_name)
184 | if not os.path.exists(embedding_path):
185 | sys.exit('missing pretrained embedding: {}'.format(embedding_path))
186 | features, valid_labels, test_labels, training_time, other_info = load_arch2vec(embedding_path)
187 | expected_improvement_search(features, valid_labels, test_labels, training_time, other_info)
188 |
--------------------------------------------------------------------------------
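`load_arch2vec()` above expects the pretrained `.pt` file to hold a shuffleable sequence of per-architecture records. A sketch of a dummy file in that layout is shown below; the field names come straight from the loader, while the values are placeholders rather than real NAS-Bench-201 measurements.

```python
# Sketch of the record layout that load_arch2vec() expects: a torch-saved
# list of per-architecture dicts. Values here are dummies for illustration.
import torch

dummy_embedding = [
    {
        'feature': torch.randn(16),     # pretrained arch2vec vector (dim 16)
        'valid_accuracy': 85.2,         # the loader divides accuracies by 100
        'test_accuracy': 84.9,
        'time': 120.0,                  # training time in seconds
        'other_info': {'valid_accuracy_avg': 84.8, 'test_accuracy_avg': 84.5},
    }
    for _ in range(8)
]
torch.save(dummy_embedding, 'dummy-arch2vec.pt')
```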
/search_methods/reinforce.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.insert(0, os.getcwd())
4 | import numpy as np
5 | import argparse
6 | import json
7 | import random
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.optim as optim
12 | from models.pretraining_nasbench101 import configs
13 | from utils.utils import load_json, preprocessing
14 | from models.model import Model
15 | from torch.distributions import MultivariateNormal
16 |
17 | class Env(object):
18 | def __init__(self, name, seed, emb_path, model_path, cfg, data_path=None, save=False):
19 | self.name = name
20 | self.model_path = model_path
21 | self.emb_path = emb_path
22 | self.seed = seed
23 | self.dir_name = 'pretrained/dim-{}'.format(args.dim)
24 | self.visited = {}
25 | self.features = []
26 | self.embedding = {}
27 | self._reset(data_path, save)
28 |
29 | def _reset(self, data_path, save):
30 | if not save:
31 | print("extract arch2vec from {}".format(os.path.join(self.dir_name, self.model_path)))
32 | if not os.path.exists(os.path.join(self.dir_name, self.model_path)):
33 | exit()
34 | dataset = load_json(data_path)
35 | self.model = Model(input_dim=5, hidden_dim=128, latent_dim=args.dim, num_hops=5, num_mlp_layers=2, dropout=0, **cfg['GAE']).cuda()
36 | self.model.load_state_dict(torch.load(os.path.join(self.dir_name, self.model_path))['model_state'])
37 | self.model.eval()
38 | with torch.no_grad():
39 | print("length of the dataset: {}".format(len(dataset)))
40 | self.f_path = os.path.join(self.dir_name, 'arch2vec-{}'.format(self.model_path))
41 | if os.path.exists(self.f_path):
42 | print('{} is already saved'.format(self.f_path))
43 | exit()
44 | print('save to {}'.format(self.f_path))
45 | for ind in range(len(dataset)):
46 | adj = torch.Tensor(dataset[str(ind)]['module_adjacency']).unsqueeze(0).cuda()
47 | ops = torch.Tensor(dataset[str(ind)]['module_operations']).unsqueeze(0).cuda()
48 | adj, ops, prep_reverse = preprocessing(adj, ops, **cfg['prep'])
49 | test_acc = dataset[str(ind)]['test_accuracy']
50 | valid_acc = dataset[str(ind)]['validation_accuracy']
51 | time = dataset[str(ind)]['training_time']
52 | x,_ = self.model._encoder(ops, adj)
53 | self.embedding[ind] = {'feature': x.squeeze(0).mean(dim=0).cpu(), 'valid_accuracy': float(valid_acc), 'test_accuracy': float(test_acc), 'time': float(time)}
54 | torch.save(self.embedding, self.f_path)
55 | print("finish arch2vec extraction")
56 | exit()
57 | else:
58 | self.f_path = os.path.join(self.dir_name, self.emb_path)
59 | print("load arch2vec from: {}".format(self.f_path))
60 | self.embedding = torch.load(self.f_path)
61 | for ind in range(len(self.embedding)):
62 | self.features.append(self.embedding[ind]['feature'])
63 | self.features = torch.stack(self.features, dim=0)
64 | print('loading finished. pretrained embeddings shape: {}'.format(self.features.shape))
65 |
66 | def get_init_state(self):
67 | """
68 | :return: feature vector (dim,), valid accuracy, test accuracy, training time
69 | """
70 | random.seed(args.seed)
71 | rand_indices = random.randint(0, self.features.shape[0] - 1)
72 | self.visited[rand_indices] = True
73 | return self.features[rand_indices], self.embedding[rand_indices]['valid_accuracy'],\
74 | self.embedding[rand_indices]['test_accuracy'], self.embedding[rand_indices]['time']
75 |
76 | def step(self, action):
77 | """
78 | action: 1 x dim
79 | self.features: N x dim
80 | """
81 | dist = torch.norm(self.features - action.cpu(), dim=1)
82 | knn = (-1 * dist).topk(dist.shape[0])  # all indices, sorted by ascending distance to the action
83 | min_dist, min_idx = knn.values, knn.indices
84 | count = 0
85 | while True:
86 | if len(self.visited) == dist.shape[0]:
87 | print("cannot find in the dataset")
88 | exit()
89 | if min_idx[count].item() not in self.visited:
90 | self.visited[min_idx[count].item()] = True
91 | break
92 | count += 1
93 |
94 | return self.features[min_idx[count].item()], self.embedding[min_idx[count].item()]['valid_accuracy'], \
95 | self.embedding[min_idx[count].item()]['test_accuracy'], self.embedding[min_idx[count].item()]['time']
96 |
97 |
98 | class Policy(nn.Module):
99 | def __init__(self, hidden_dim1, hidden_dim2):
100 | super(Policy, self).__init__()
101 | self.fc1 = nn.Linear(hidden_dim1, hidden_dim2)
102 | self.fc2 = nn.Linear(hidden_dim2, hidden_dim1)
103 | self.saved_log_probs = []
104 | self.rewards = []
105 |
106 | def forward(self, input):
107 | x = F.relu(self.fc1(input))
108 | out = self.fc2(x)
109 | return out
110 |
111 | class Policy_LSTM(nn.Module):
112 | def __init__(self, hidden_dim1, hidden_dim2):
113 | super(Policy_LSTM, self).__init__()
114 | self.lstm = torch.nn.LSTMCell(input_size=hidden_dim1, hidden_size=hidden_dim2)
115 | self.fc = nn.Linear(hidden_dim2, hidden_dim1)
116 | self.saved_log_probs = []
117 | self.rewards = []
118 | self.hx = None
119 | self.cx = None
120 |
121 | def forward(self, input):
122 | if self.hx is None and self.cx is None:
123 | self.hx, self.cx = self.lstm(input)
124 | else:
125 | self.hx, self.cx = self.lstm(input, (self.hx, self.cx))
126 | mean = self.fc(self.hx)
127 | return mean
128 |
129 | def select_action(state, policy):
130 | """
131 | MVN based action selection.
132 | :param state: 1 x dim
133 | :param policy: policy network
134 | :return: action: 1 x dim
135 | """
136 | mean = policy(state.view(1, state.shape[0]))
137 | mvn = MultivariateNormal(mean, torch.eye(state.shape[0]).cuda())
138 | action = mvn.sample()
139 | policy.saved_log_probs.append(torch.mean(mvn.log_prob(action)))
140 | return action
141 |
142 |
143 | def finish_episode(policy, optimizer):
144 | policy_loss = []
145 | # REINFORCE with a constant baseline: the return for every step in the
146 | # batch is its raw validation reward minus a fixed baseline of 0.95;
147 | # no discount factor is applied to the rewards
148 | returns = torch.Tensor(policy.rewards)
149 | returns = returns - 0.95
150 | # weight each saved log-probability by its centered return and average
151 | # the resulting policy-gradient terms into a single loss
152 | for log_prob, R in zip(policy.saved_log_probs, returns):
153 | policy_loss.append(-log_prob * R)
154 |
155 | optimizer.zero_grad()
156 | policy_loss = torch.mean(torch.stack(policy_loss, dim=0))
157 | print("average reward: {}, policy loss: {}".format(sum(policy.rewards)/len(policy.rewards), policy_loss.item()))
158 | policy_loss.backward()
159 | optimizer.step()
160 | del policy.rewards[:]
161 | del policy.saved_log_probs[:]
162 | policy.hx = None
163 | policy.cx = None
164 |
165 |
166 | def reinforce_search(env, args):
167 | """ implementation of arch2vec-REINFORCE """
168 | policy = Policy_LSTM(args.dim, 128).cuda()
169 | optimizer = optim.Adam(policy.parameters(), lr=1e-2)
170 | counter = 0
171 | BEST_VALID_ACC = 0.9505542318026224
172 | BEST_TEST_ACC = 0.943175752957662
173 | MAX_BUDGET = 1.5e6
174 | rt = 0
175 | state, _, _, time = env.get_init_state()
176 | CURR_BEST_VALID = 0
177 | CURR_BEST_TEST = 0
178 | test_trace = []
179 | valid_trace = []
180 | time_trace = []
181 | while rt < MAX_BUDGET:
182 | for c in range(args.bs):
183 | state = state.cuda()
184 | action = select_action(state, policy)
185 | state, reward, reward_test, time = env.step(action)
186 | policy.rewards.append(reward)
187 | counter += 1
188 | rt += time
189 | print('counter: {}, validation reward: {}, test reward: {}, time: {}'.format(counter, reward, reward_test, rt))
190 |
191 | if reward > CURR_BEST_VALID:
192 | CURR_BEST_VALID = reward
193 | CURR_BEST_TEST = reward_test
194 |
195 | valid_trace.append(float(BEST_VALID_ACC - CURR_BEST_VALID))
196 | test_trace.append(float(BEST_TEST_ACC - CURR_BEST_TEST))
197 | time_trace.append(rt)
198 |
199 | if rt >= MAX_BUDGET:
200 | break
201 |
202 | finish_episode(policy, optimizer)
203 |
204 | res = dict()
205 | res['regret_validation'] = valid_trace
206 | res['regret_test'] = test_trace
207 | res['runtime'] = time_trace
208 | save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
209 | if not os.path.exists(save_path):
210 | os.makedirs(save_path)
211 | print('save to {}'.format(save_path))
212 | # strip the '.pt' suffix (if present) from emb_path for the output filename
213 | s = args.emb_path[:-3] if args.emb_path.endswith('.pt') else args.emb_path
214 | fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, s)), 'w')
215 | json.dump(res, fh)
216 | fh.close()
217 |
218 |
219 |
220 | if __name__ == '__main__':
221 | parser = argparse.ArgumentParser(description="arch2vec-REINFORCE")
222 | parser.add_argument("--gamma", type=float, default=0, help="discount factor (default 0.99)")
223 | parser.add_argument("--seed", type=int, default=1, help="random seed")
224 | parser.add_argument('--cfg', type=int, default=4, help='configuration (default: 4)')
225 | parser.add_argument('--bs', type=int, default=16, help='batch size')
226 | parser.add_argument('--dim', type=int, default=16, help='feature dimension')
227 | parser.add_argument('--output_path', type=str, default='rl', help='rl/bo')
228 | parser.add_argument('--emb_path', type=str, default='arch2vec.pt')
229 | parser.add_argument('--model_path', type=str, default='model-nasbench101.pt')
230 | parser.add_argument('--saved_arch2vec', action="store_true", default=False)
231 | args = parser.parse_args()
232 | cfg = configs[args.cfg]
233 | env = Env('REINFORCE', args.seed, args.emb_path, args.model_path, cfg, data_path='data/data.json', save=args.saved_arch2vec)
234 | np.random.seed(args.seed)
235 | torch.manual_seed(args.seed)
236 | torch.cuda.manual_seed_all(args.seed)
237 | torch.set_num_threads(2)
238 | reinforce_search(env, args)
239 |
--------------------------------------------------------------------------------
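`finish_episode()` above performs plain REINFORCE against a constant baseline of 0.95. A minimal sketch of that update in isolation is given below, assuming `log_probs` are the saved `MultivariateNormal` log-probabilities (tensors that still carry gradients) and `rewards` the per-step validation accuracies; it mirrors `finish_episode()` minus the LSTM state reset.

```python
# Minimal sketch of the REINFORCE update with a fixed baseline: the loss is
# the mean of -log_prob * (reward - baseline) over the batch, followed by a
# single optimizer step.
import torch

def reinforce_step(log_probs, rewards, optimizer, baseline=0.95):
    returns = torch.tensor(rewards) - baseline
    loss = torch.mean(torch.stack(
        [-lp * r for lp, r in zip(log_probs, returns)]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```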
/search_methods/supervised_dngo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import sys
5 | sys.path.insert(0, os.getcwd())
6 | from pybnn.dngo_supervised import DNGO
7 | import json
8 | import argparse
9 | from collections import defaultdict
10 | from torch.distributions import Normal
11 |
12 | def extract_data(dataset):
13 | with open(dataset) as f:
14 | data = json.load(f)
15 | X_adj = [torch.Tensor(data[str(ind)]['module_adjacency']) for ind in range(len(data))]
16 | X_ops = [torch.Tensor(data[str(ind)]['module_operations']) for ind in range(len(data))]
17 | Y = [data[str(ind)]['validation_accuracy'] for ind in range(len(data))]
18 | Y_test = [data[str(ind)]['test_accuracy'] for ind in range(len(data))]
19 | training_time = [data[str(ind)]['training_time'] for ind in range(len(data))]
20 | X_adj = torch.stack(X_adj, dim=0)
21 | X_ops = torch.stack(X_ops, dim=0)
22 | Y = torch.Tensor(Y)
23 | Y_test = torch.Tensor(Y_test)
24 | training_time = torch.Tensor(training_time)
25 | rand_ind = torch.randperm(X_ops.shape[0])
26 | X_adj = X_adj[rand_ind]
27 | X_ops = X_ops[rand_ind]
28 | Y = Y[rand_ind]
29 | Y_test = Y_test[rand_ind]
30 | training_time = training_time[rand_ind]
31 | print('loading finished. input adj shape {}, input ops shape {} and valid labels shape {}, and test labels shape {}'.format(X_adj.shape, X_ops.shape, Y.shape, Y_test.shape))
32 | return X_adj, X_ops, Y, Y_test, training_time
33 |
34 | def get_init_samples(X_adj, X_ops, Y, Y_test, training_time, visited):
35 | np.random.seed(args.seed)
36 | init_inds = np.random.permutation(list(range(X_ops.shape[0])))[:args.init_size]
37 | init_inds = torch.Tensor(init_inds).long()
38 | init_x_adj_samples = X_adj[init_inds]
39 | init_x_ops_samples = X_ops[init_inds]
40 | init_valid_label_samples = Y[init_inds]
41 | init_test_label_samples = Y_test[init_inds]
42 | init_time_samples = training_time[init_inds]
43 | for idx in init_inds:
44 | visited[idx.item()] = True
45 | return init_x_adj_samples, init_x_ops_samples, init_valid_label_samples, init_test_label_samples, init_time_samples, visited
46 |
47 |
48 | def propose_location(ei, X_adj, X_ops, valid_labels, test_labels, training_time, visited):
49 | k = args.topk
50 | count = 0
51 | print('remaining length of indices set:', len(X_adj) - len(visited))
52 | indices = torch.argsort(ei)
53 | ind_dedup = []
54 | # remove random sampled indices at each step
55 | for idx in reversed(indices):
56 | if count == k:
57 | break
58 | if idx.item() not in visited:
59 | visited[idx.item()] = True
60 | ind_dedup.append(idx.item())
61 | count += 1
62 | ind_dedup = torch.Tensor(ind_dedup).long()
63 | proposed_x_adj, proposed_x_ops, proposed_y_valid, proposed_y_test, proposed_time = X_adj[ind_dedup], X_ops[ind_dedup], valid_labels[ind_dedup], test_labels[ind_dedup], training_time[ind_dedup]
64 | return proposed_x_adj, proposed_x_ops, proposed_y_valid, proposed_y_test, proposed_time, visited
65 |
66 |
67 | def supervised_encoding_search(X_adj, X_ops, Y, Y_test, training_time):
68 | """implementation of supervised learning based BO search"""
69 | BEST_TEST_ACC = 0.943175752957662
70 | BEST_VALID_ACC = 0.9505542318026224
71 | CURR_BEST_VALID = 0.
72 | CURR_BEST_TEST = 0.
73 | MAX_BUDGET = 1.5e6
74 | counter = 0
75 | rt = 0.
76 | best_trace = defaultdict(list)
77 | window_size = 512
78 | visited = {}
79 | X_adj_sample, X_ops_sample, Y_sample, Y_sample_test, time_sample, visited = get_init_samples(X_adj, X_ops, Y, Y_test, training_time, visited)
80 |
81 | for x_adj, x_ops, acc_valid, acc_test, t in zip(X_adj_sample, X_ops_sample, Y_sample, Y_sample_test, time_sample):
82 | counter += 1
83 | rt += t.item()
84 | if acc_valid > CURR_BEST_VALID:
85 | CURR_BEST_VALID = acc_valid
86 | CURR_BEST_TEST = acc_test
87 | best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
88 | best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
89 | best_trace['time'].append(rt)
90 | best_trace['counter'].append(counter)
91 |
92 | while rt < MAX_BUDGET:
93 | print("data adjacent matrix samples:", X_adj_sample.shape)
94 | print("data operations matrix samples:", X_ops_sample.shape)
95 | print("valid label_samples:", Y_sample.shape)
96 | print("test label samples:", Y_sample_test.shape)
97 | print("current best validation: {}".format(CURR_BEST_VALID))
98 | print("current best test: {}".format(CURR_BEST_TEST))
99 | print("rt: {}".format(rt))
100 | model = DNGO(num_epochs=100, input_dim=5, hidden_dim=128, latent_dim=args.dim, num_hops=5, num_mlp_layers=2, do_mcmc=False, normalize_output=False)
101 | model.train(X_adj_sample.numpy(), X_ops_sample.numpy(), Y_sample.view(-1).numpy(), do_optimize=True)
102 | m = []
103 | v = []
104 | chunks = int(X_adj.shape[0] / window_size)
105 | if X_adj.shape[0] % window_size > 0:
106 | chunks += 1
107 | X_adj_split = torch.split(X_adj, window_size, dim=0)
108 | X_ops_split = torch.split(X_ops, window_size, dim=0)
109 | for i in range(chunks):
110 | inputs_adj = X_adj_split[i]
111 | inputs_ops = X_ops_split[i]
112 | m_split, v_split = model.predict(inputs_ops.numpy(), inputs_adj.numpy())
113 | m.extend(list(m_split))
114 | v.extend(list(v_split))
115 | mean = torch.Tensor(m)
116 | sigma = torch.Tensor(v)
117 | u = (mean - torch.Tensor([0.95]).expand_as(mean)) / sigma
118 | normal = Normal(torch.zeros_like(u), torch.ones_like(u))
119 | ucdf = normal.cdf(u)
120 | updf = torch.exp(normal.log_prob(u))
121 | ei = sigma * (updf + u * ucdf)
122 |
123 | X_adj_next, X_ops_next, label_next_valid, label_next_test, time_next, visited = propose_location(ei, X_adj, X_ops, Y, Y_test, training_time, visited)
124 |
125 | # add proposed networks to selected networks
126 | for x_adj, x_ops, acc_valid, acc_test, t in zip(X_adj_next, X_ops_next, label_next_valid, label_next_test, time_next):
127 | X_adj_sample = torch.cat((X_adj_sample, x_adj.view(1, 7, 7)), dim=0)
128 | X_ops_sample = torch.cat((X_ops_sample, x_ops.view(1, 7, 5)), dim=0)
129 | Y_sample = torch.cat((Y_sample.view(-1, 1), acc_valid.view(1, 1)), dim=0)
130 | Y_sample_test = torch.cat((Y_sample_test.view(-1, 1), acc_test.view(1, 1)), dim=0)
131 | counter += 1
132 | rt += t.item()
133 | if acc_valid > CURR_BEST_VALID:
134 | CURR_BEST_VALID = acc_valid
135 | CURR_BEST_TEST = acc_test
136 |
137 | best_trace['regret_validation'].append(float(BEST_VALID_ACC - CURR_BEST_VALID))
138 | best_trace['regret_test'].append(float(BEST_TEST_ACC - CURR_BEST_TEST))
139 | best_trace['time'].append(rt)
140 | best_trace['counter'].append(counter)
141 |
142 | if rt >= MAX_BUDGET:
143 | break
144 |
145 | res = dict()
146 | res['regret_validation'] = best_trace['regret_validation']
147 | res['regret_test'] = best_trace['regret_test']
148 | res['runtime'] = best_trace['time']
149 | res['counter'] = best_trace['counter']
150 | save_path = os.path.join(args.output_path, 'dim{}'.format(args.dim))
151 | if not os.path.exists(save_path):
152 | os.makedirs(save_path)
153 | print('save to {}'.format(save_path))
154 | fh = open(os.path.join(save_path, 'run_{}_{}.json'.format(args.seed, args.benchmark)), 'w')
155 | json.dump(res, fh)
156 | fh.close()
157 |
158 |
159 |
160 | if __name__ == '__main__':
161 | parser = argparse.ArgumentParser(description="Supervised DNGO search")
162 | parser.add_argument("--seed", type=int, default=1, help="random seed")
163 | parser.add_argument('--dim', type=int, default=16, help='feature dimension')
164 | parser.add_argument('--init_size', type=int, default=16, help='init samples')
165 | parser.add_argument('--topk', type=int, default=5, help='acquisition samples')
166 | parser.add_argument('--benchmark', type=str, default='supervised_dngo')
167 | parser.add_argument('--output_path', type=str, default='saved_logs/bo', help='rl/bo (default: bo)')
168 | args = parser.parse_args()
169 | torch.manual_seed(args.seed)
170 | data_path = 'data/data.json'
171 | X_adj, X_ops, Y, Y_test, training_time = extract_data(data_path)
172 | supervised_encoding_search(X_adj, X_ops, Y, Y_test, training_time)
173 |
--------------------------------------------------------------------------------
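One pattern shared by every DNGO searcher in this directory: the surrogate is evaluated over the whole search space in fixed-size windows (`window_size` of 200 or 512 above) rather than one giant batch, which bounds peak memory. A sketch of that pattern follows, assuming `model.predict` accepts an `(N, dim)` numpy array and returns per-point means and variances, as pybnn's `DNGO.predict` does.

```python
# Sketch of the windowed surrogate evaluation: split the candidate features
# into fixed-size chunks and concatenate the per-chunk predictions.
import torch

def predict_in_windows(model, features: torch.Tensor, window_size: int = 200):
    means, variances = [], []
    for chunk in torch.split(features, window_size, dim=0):
        m, v = model.predict(chunk.numpy())
        means.extend(list(m))
        variances.extend(list(v))
    return torch.Tensor(means), torch.Tensor(variances)
```

Note that `torch.split` already yields the ragged final chunk, so the explicit chunk-count bookkeeping in the files above is not strictly needed.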