├── .gitignore
├── images
├── architecture.png
└── param_time_acc.png
├── nasbench
├── __init__.py
├── lib
│ ├── __init__.py
│ ├── model_metrics.proto
│ ├── config.py
│ ├── graph_util.py
│ ├── base_ops.py
│ ├── model_spec.py
│ ├── cifar.py
│ ├── model_metrics_pb2.py
│ ├── training_time.py
│ ├── evaluate.py
│ └── model_builder.py
├── scripts
│ ├── __init__.py
│ ├── augment_model.py
│ ├── generate_cifar10_tfrecords.py
│ ├── generate_graphs.py
│ └── run_evaluation.py
├── tests
│ ├── model_builder_test.py
│ ├── model_spec_test.py
│ ├── run_evaluation_test.py
│ └── graph_util_test.py
└── api.py
├── setup.py
├── CONTRIBUTING.md
├── example.py
├── README.md
└── LICENSE
/.gitignore:
--------------------------------------------------------------------------------
1 | venv/
2 | nasbench.egg-info/
3 | *.pyc
4 |
--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/nasbench/HEAD/images/architecture.png
--------------------------------------------------------------------------------
/images/param_time_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google-research/nasbench/HEAD/images/param_time_acc.png
--------------------------------------------------------------------------------
/nasbench/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/nasbench/lib/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/nasbench/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Setup the nasbench library.
16 |
17 | This file is automatically run as part of `pip install -e .`
18 | """
19 |
20 | import setuptools
21 |
setuptools.setup(
    name='nasbench',
    version='1.0',
    packages=setuptools.find_packages(),
    install_requires=[
        # The library is written against the TensorFlow 1.x API (tf.Session,
        # tf.flags, tf.app, tf.logging, tf.contrib), none of which exist in
        # TensorFlow 2.x, so cap the version below 2.0.
        'tensorflow>=1.12.0,<2.0',
    ]
)
30 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | # Pull Requests
4 |
5 | Please send in fixes or feature additions through Pull Requests.
6 |
7 | ## Contributor License Agreement
8 |
9 | Contributions to this project must be accompanied by a Contributor License
10 | Agreement. You (or your employer) retain the copyright to your contribution,
11 | this simply gives us permission to use and redistribute your contributions as
12 | part of the project. Head over to <https://cla.developers.google.com/> to see
13 | your current agreements on file or to sign a new one.
14 |
15 | You generally only need to submit a CLA once, so if you've already submitted one
16 | (even if it was for a different project), you probably don't need to do it
17 | again.
18 |
19 | ## Code reviews
20 |
21 | All submissions, including submissions by project members, require review. We
22 | use GitHub pull requests for this purpose. Consult
23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
24 | information on using pull requests.
25 |
--------------------------------------------------------------------------------
/nasbench/lib/model_metrics.proto:
--------------------------------------------------------------------------------
1 | // Metrics stored per evaluation of each ModelSpec.
2 | // NOTE: this file is for reference only, changes to this file will not affect
3 | // the code unless you compile the proto using protoc, which can be installed
4 | // from https://github.com/protocolbuffers/protobuf/releases.
5 | syntax = "proto2";
6 |
7 | package nasbench;
8 |
// All metrics recorded for one evaluation of a ModelSpec: a series of
// per-checkpoint snapshots plus values that stay fixed for the whole run.
message ModelMetrics {
  // Metrics that are evaluated at each checkpoint. Each ModelMetrics will
  // contain multiple EvaluationData messages evaluated at various points during
  // training, including the initialization before any steps are taken.
  repeated EvaluationData evaluation_data = 1;

  // Other fixed metrics (those that do not change over training) go here.

  // Parameter count of all trainable variables.
  optional int32 trainable_parameters = 2;

  // Total time for all training and evaluation (mostly used for diagnostic
  // purposes).
  optional double total_time = 3;
}
24 |
// One snapshot of training progress and accuracy, taken at a single point
// during training (see ModelMetrics.evaluation_data).
message EvaluationData {
  // Current epoch at the time of this evaluation.
  optional double current_epoch = 1;

  // Training time in seconds up to this point. Does not include evaluation
  // time.
  optional double training_time = 2;

  // Accuracy on a fixed set of 10,000 images from the train set.
  optional double train_accuracy = 3;

  // Accuracy on a held-out validation set of 10,000 images.
  optional double validation_accuracy = 4;

  // Accuracy on the test set of 10,000 images.
  optional double test_accuracy = 5;

  // Location of checkpoint file. Note: checkpoint_path will look like
  // /path/to/model_dir/model.ckpt-1234 but the actual checkpoint files may have
  // an extra ".data", ".index", ".meta" suffix. For purposes of loading a
  // checkpoint file in TensorFlow, the path without the suffix is sufficient.
  // This field may be left blank because the checkpoint can be programmatically
  // generated from the model specifications.
  optional string checkpoint_path = 6;

  // Additional sample metrics like gradient norms and covariance are too large
  // to store in file, so they need to be queried along with the checkpoints
  // from GCS directly.
}
54 |
55 |
56 |
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Runnable example, as shown in the README.md."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import app
22 | from nasbench import api
23 |
# Replace this string with the path to the downloaded nasbench.tfrecord before
# executing.
NASBENCH_TFRECORD = '/path/to/nasbench.tfrecord'

# Operation labels used in the `ops` list of a ModelSpec.
INPUT = 'input'
OUTPUT = 'output'
CONV1X1 = 'conv1x1-bn-relu'
CONV3X3 = 'conv3x3-bn-relu'
MAXPOOL3X3 = 'maxpool3x3'


def main(argv):
  """Demonstrates querying, full-metric lookup and iteration over NASBench.

  Args:
    argv: command-line arguments forwarded by `app.run`; unused.
  """
  del argv  # Unused

  # Load the data from file (this will take some time)
  nasbench = api.NASBench(NASBENCH_TFRECORD)

  # Create an Inception-like module (5x5 convolution replaced with two 3x3
  # convolutions).
  model_spec = api.ModelSpec(
      # Adjacency matrix of the module
      matrix=[[0, 1, 1, 1, 0, 1, 0],    # input layer
              [0, 0, 0, 0, 0, 0, 1],    # 1x1 conv
              [0, 0, 0, 0, 0, 0, 1],    # 3x3 conv
              [0, 0, 0, 0, 1, 0, 0],    # 5x5 conv (replaced by two 3x3's)
              [0, 0, 0, 0, 0, 0, 1],    # 5x5 conv (replaced by two 3x3's)
              [0, 0, 0, 0, 0, 0, 1],    # 3x3 max-pool
              [0, 0, 0, 0, 0, 0, 0]],   # output layer
      # Operations at the vertices of the module, matches order of matrix
      ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT])

  # Query this model from dataset, returns a dictionary containing the metrics
  # associated with this model.
  print('Querying an Inception-like model.')
  data = nasbench.query(model_spec)
  print(data)
  print(nasbench.get_budget_counters())  # prints (total time, total epochs)

  # Get all metrics (all epoch lengths, all repeats) associated with this
  # model_spec. This should be used for dataset analysis and NOT for
  # benchmarking algorithms (does not increment budget counters).
  print('\nGetting all metrics for the same Inception-like model.')
  fixed_metrics, computed_metrics = nasbench.get_metrics_from_spec(model_spec)
  print(fixed_metrics)
  for epochs in nasbench.valid_epochs:
    for repeat_index, data_point in enumerate(computed_metrics[epochs]):
      print('Epochs trained %d, repeat number: %d' % (epochs, repeat_index + 1))
      print(data_point)

  # Iterate through unique models in the dataset. Models are uniquely identified
  # by a hash.
  print('\nIterating over unique models in the dataset.')
  for unique_hash in nasbench.hash_iterator():
    fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(
        unique_hash)
    print(fixed_metrics)

    # For demo purposes, break here instead of iterating through whole set.
    break


# If you are passing command line flags to modify the default config values, you
# must use app.run(main)
if __name__ == '__main__':
  app.run(main)
90 |
--------------------------------------------------------------------------------
/nasbench/lib/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Configuration flags."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from absl import flags
22 |
FLAGS = flags.FLAGS

# Data flags (only required for generating the dataset)
flags.DEFINE_list(
    'train_data_files', [],
    'Training data files in TFRecord format. Multiple files can be passed in a'
    ' comma-separated list. The first file in the list will be used for'
    ' computing the training error.')
flags.DEFINE_string(
    'valid_data_file', '', 'Validation data in TFRecord format.')
flags.DEFINE_string(
    'test_data_file', '', 'Testing data in TFRecord format.')
flags.DEFINE_string(
    'sample_data_file', '', 'Sampled batch data in TFRecord format.')
flags.DEFINE_string(
    'data_format', 'channels_last',
    'Data format, one of [channels_last, channels_first] for NHWC and NCHW'
    ' tensor formats respectively.')
flags.DEFINE_integer(
    'num_labels', 10, 'Number of input class labels.')

# Search space parameters.
flags.DEFINE_integer(
    'module_vertices', 7,
    'Number of vertices in module matrix, including input and output.')
flags.DEFINE_integer(
    'max_edges', 9,
    'Maximum number of edges in the module matrix.')
flags.DEFINE_list(
    'available_ops', ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'],
    'Available op labels, see base_ops.py for full list of ops.')

# Model hyperparameters. The default values are exactly what is used during the
# exhaustive evaluation of all models.
flags.DEFINE_integer(
    'stem_filter_size', 128, 'Filter size after stem convolutions.')
flags.DEFINE_integer(
    'num_stacks', 3, 'Number of stacks of modules.')
flags.DEFINE_integer(
    'num_modules_per_stack', 3, 'Number of modules per stack.')
flags.DEFINE_integer(
    'batch_size', 256, 'Training batch size.')
flags.DEFINE_integer(
    'train_epochs', 108,
    'Maximum training epochs. If --train_seconds is reached first, training'
    ' may not reach --train_epochs.')
flags.DEFINE_float(
    'train_seconds', 4.0 * 60 * 60,  # 4 hours.
    'Maximum training seconds. If --train_epochs is reached first, training'
    ' may not reach --train_seconds. Used as safeguard against stalled jobs.'
    ' If train_seconds is 0.0, no time limit will be used.')
flags.DEFINE_float(
    'learning_rate', 0.1,
    'Base learning rate. Linearly scaled by --tpu_num_shards.')
flags.DEFINE_string(
    'lr_decay_method', 'COSINE_BY_STEP',
    '[COSINE_BY_TIME, COSINE_BY_STEP, STEPWISE], see model_builder.py for full'
    ' list of decay methods.')
flags.DEFINE_float(
    'momentum', 0.9, 'Momentum.')
flags.DEFINE_float(
    'weight_decay', 1e-4, 'L2 regularization weight.')
flags.DEFINE_integer(
    'max_attempts', 5,
    'Maximum number of times to try training and evaluating an individual'
    ' model before aborting.')
# List flags are given on the command line as a comma-separated string, so
# the help text shows that form rather than a bracketed Python list.
flags.DEFINE_list(
    'intermediate_evaluations', ['0.5'],
    'Intermediate evaluations relative to --train_epochs. For example, to'
    ' evaluate the model at 1/4, 1/2, 3/4 of the total epochs, use'
    ' --intermediate_evaluations=0.25,0.5,0.75. An evaluation is always done'
    ' at the start and end of training.')
flags.DEFINE_integer(
    'num_repeats', 3,
    'Number of repeats evaluated for each model in the space.')

# TPU flags
flags.DEFINE_bool(
    'use_tpu', True, 'Use TPUs for train and evaluation.')
flags.DEFINE_integer(
    'tpu_iterations_per_loop', 100, 'Iterations per loop of TPU execution.')
flags.DEFINE_integer(
    'tpu_num_shards', 2,
    'Number of TPU shards, a single TPU chip has 2 shards.')
106 |
107 |
def build_config():
  """Collect the flags declared in this module into a plain dict.

  Only flags registered by this module are included, keyed by flag name. The
  values are read as-is, so command-line overrides are reflected only once
  FLAGS has been parsed (e.g. by running under app.run).

  Returns:
    dict mapping flag name -> flag value.
  """
  config = {}
  for module_flag in FLAGS.flags_by_module_dict()[__name__]:
    config[module_flag.name] = module_flag.value
  return config
116 |
--------------------------------------------------------------------------------
/nasbench/scripts/augment_model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Augments one model with longer training and evaluates on test set."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from nasbench.lib import config as _config
22 | from nasbench.lib import evaluate
23 | from nasbench.lib import model_spec
24 | import numpy as np
25 | import tensorflow as tf # Used for app, flags, logging
26 |
# Directory passed by main() to evaluate.augment_and_evaluate() as the model
# directory. It is declared here because config.py does not define it.
tf.flags.DEFINE_string('model_dir', '', 'model directory')
FLAGS = tf.flags.FLAGS
29 |
30 |
def create_resnet20_spec(config):
  """Build a ModelSpec that approximates ResNet-20.

  The module is input -> conv3x3 -> conv3x3 -> output plus a direct
  input -> output edge. The main difference from the original ResNet is an
  extra projection layer before the conv3x3, which slightly increases the
  parameter count of this version.

  Side effect: `config` is updated to 3 stacks of 3 modules with a 16-filter
  stem (after the spec has been built).

  Args:
    config: config dict created by config.py; modified in place.

  Returns:
    ModelSpec object.
  """
  adjacency = np.array([
      [0, 1, 0, 1],  # input -> conv3x3, input -> output
      [0, 0, 1, 0],  # conv3x3 -> conv3x3
      [0, 0, 0, 1],  # conv3x3 -> output
      [0, 0, 0, 0],  # output
  ])
  operations = ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']
  spec = model_spec.ModelSpec(adjacency, operations)
  config.update(num_stacks=3, num_modules_per_stack=3, stem_filter_size=16)
  return spec
54 |
55 |
def create_resnet50_spec(config):
  """Build a ModelSpec that approximates ResNet-50.

  A single 'bottleneck3x3' op sits between input and output, with an
  additional direct input -> output edge. The main difference from the
  original ResNet is an extra projection layer before the conv1x1, which
  slightly increases the parameter count of this version.

  Side effect: `config` is updated to 3 stacks of 6 modules with a 128-filter
  stem (after the spec has been built).

  Args:
    config: config dict created by config.py; modified in place.

  Returns:
    ModelSpec object.
  """
  operations = ['input', 'bottleneck3x3', 'output']
  adjacency = np.array([
      [0, 1, 1],  # input -> bottleneck, input -> output
      [0, 0, 1],  # bottleneck -> output
      [0, 0, 0],  # output
  ])
  spec = model_spec.ModelSpec(adjacency, operations)
  config.update(num_stacks=3, num_modules_per_stack=6, stem_filter_size=128)
  return spec
78 |
79 |
def create_inception_resnet_spec(config):
  """Build a ModelSpec resembling an Inception-ResNet module.

  The module is an InceptionV2-style block (1x1 conv, 3x3 conv, two stacked
  3x3 convs and a max pool, all merged at the output) with an added
  input -> output edge. Unlike the source model there is an extra projection
  in front of the max pool, and the overall filter and module counts do not
  match the original network.

  Side effect: `config` is updated to 3 stacks of 3 modules with a 128-filter
  stem (after the spec has been built).

  Args:
    config: config dict created by config.py; modified in place.

  Returns:
    ModelSpec object.
  """
  operations = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu',
                'conv3x3-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3', 'output']
  # (source, destination) vertex pairs of the DAG.
  edges = [
      (0, 1), (0, 2), (0, 3), (0, 5), (0, 6),  # input fan-out + residual
      (1, 6), (2, 6),                          # 1x1 and 3x3 branches
      (3, 4), (4, 6),                          # stacked 3x3 branch
      (5, 6),                                  # max-pool branch
  ]
  num_vertices = len(operations)
  adjacency = np.zeros((num_vertices, num_vertices), dtype=int)
  for src, dst in edges:
    adjacency[src, dst] = 1

  spec = model_spec.ModelSpec(adjacency, operations)
  config.update(num_stacks=3, num_modules_per_stack=3, stem_filter_size=128)
  return spec
108 |
109 |
def create_best_nasbench_spec(config):
  """Build the NASBench cell with the best mean test accuracy.

  Side effect: `config` is updated to 3 stacks of 3 modules with a 128-filter
  stem (after the spec has been built).

  Args:
    config: config dict created by config.py; modified in place.

  Returns:
    ModelSpec object.
  """
  operations = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3',
                'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output']
  # (source, destination) vertex pairs of the DAG.
  edges = [
      (0, 1), (0, 2), (0, 5), (0, 6),
      (1, 5),
      (2, 3),
      (3, 4),
      (4, 5),
      (5, 6),
  ]
  num_vertices = len(operations)
  adjacency = np.zeros((num_vertices, num_vertices), dtype=int)
  for src, dst in edges:
    adjacency[src, dst] = 1

  spec = model_spec.ModelSpec(adjacency, operations)
  config.update(num_stacks=3, num_modules_per_stack=3, stem_filter_size=128)
  return spec
133 |
134 |
def main(_):
  """Retrain the best NASBench cell with a longer schedule and log the result.

  Starts from the flag-derived config, overrides the training schedule, and
  passes the best spec to evaluate.augment_and_evaluate() using --model_dir.
  """
  config = _config.build_config()

  # The default settings in config are exactly what was used to generate the
  # dataset of models. However, given more epochs and a different learning rate
  # schedule, it is possible to get higher accuracy.
  config['train_epochs'] = 200
  config['lr_decay_method'] = 'STEPWISE'
  # Disable the training time limit. config.py documents 0.0 as the value that
  # means "no time limit", so use it rather than an undocumented negative.
  config['train_seconds'] = 0.0
  spec = create_best_nasbench_spec(config)

  data = evaluate.augment_and_evaluate(spec, config, FLAGS.model_dir)
  tf.logging.info(data)


if __name__ == '__main__':
  tf.app.run(main)
152 |
--------------------------------------------------------------------------------
/nasbench/tests/model_builder_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lib/model_builder.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from nasbench.lib import model_builder
22 | import numpy as np
23 | import tensorflow as tf
24 |
25 |
class ModelBuilderTest(tf.test.TestCase):
  """Tests for channel allocation and covariance helpers in model_builder.

  Uses self.assertEqual rather than bare `assert`, which Python strips when
  run with -O and which gives no diff on failure.
  """

  def test_compute_vertex_channels_linear(self):
    """Tests modules with no branching."""
    matrix1 = np.array([[0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1],
                        [0, 0, 0, 0]])
    vc1 = model_builder.compute_vertex_channels(8, 8, matrix1)
    self.assertEqual(vc1, [8, 8, 8, 8])

    vc2 = model_builder.compute_vertex_channels(8, 16, matrix1)
    self.assertEqual(vc2, [8, 16, 16, 16])

    vc3 = model_builder.compute_vertex_channels(16, 8, matrix1)
    self.assertEqual(vc3, [16, 8, 8, 8])

    matrix2 = np.array([[0, 1],
                        [0, 0]])
    vc4 = model_builder.compute_vertex_channels(1, 1, matrix2)
    self.assertEqual(vc4, [1, 1])

    vc5 = model_builder.compute_vertex_channels(1, 5, matrix2)
    self.assertEqual(vc5, [1, 5])

    vc6 = model_builder.compute_vertex_channels(5, 1, matrix2)
    self.assertEqual(vc6, [5, 1])

  def test_compute_vertex_channels_no_output_branch(self):
    """Tests modules that branch but not at the output vertex."""
    matrix1 = np.array([[0, 1, 1, 0, 0],
                        [0, 0, 0, 1, 0],
                        [0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0]])
    vc1 = model_builder.compute_vertex_channels(8, 8, matrix1)
    self.assertEqual(vc1, [8, 8, 8, 8, 8])

    vc2 = model_builder.compute_vertex_channels(8, 16, matrix1)
    self.assertEqual(vc2, [8, 16, 16, 16, 16])

    vc3 = model_builder.compute_vertex_channels(16, 8, matrix1)
    self.assertEqual(vc3, [16, 8, 8, 8, 8])

  def test_compute_vertex_channels_output_branching(self):
    """Tests modules that branch at output."""
    matrix1 = np.array([[0, 1, 1, 0],
                        [0, 0, 0, 1],
                        [0, 0, 0, 1],
                        [0, 0, 0, 0]])
    vc1 = model_builder.compute_vertex_channels(8, 8, matrix1)
    self.assertEqual(vc1, [8, 4, 4, 8])

    vc2 = model_builder.compute_vertex_channels(8, 16, matrix1)
    self.assertEqual(vc2, [8, 8, 8, 16])

    vc3 = model_builder.compute_vertex_channels(16, 8, matrix1)
    self.assertEqual(vc3, [16, 4, 4, 8])

    vc4 = model_builder.compute_vertex_channels(8, 15, matrix1)
    self.assertEqual(vc4, [8, 8, 7, 15])

    matrix2 = np.array([[0, 1, 1, 1, 0],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0]])
    vc5 = model_builder.compute_vertex_channels(8, 8, matrix2)
    self.assertEqual(vc5, [8, 3, 3, 2, 8])

    vc6 = model_builder.compute_vertex_channels(8, 15, matrix2)
    self.assertEqual(vc6, [8, 5, 5, 5, 15])

  def test_compute_vertex_channels_max(self):
    """Tests modules where some vertices take the max channels of neighbors."""
    matrix1 = np.array([[0, 1, 0, 0, 0],
                        [0, 0, 1, 1, 0],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0]])
    vc1 = model_builder.compute_vertex_channels(8, 8, matrix1)
    self.assertEqual(vc1, [8, 4, 4, 4, 8])

    vc2 = model_builder.compute_vertex_channels(8, 9, matrix1)
    self.assertEqual(vc2, [8, 5, 5, 4, 9])

    matrix2 = np.array([[0, 1, 0, 1, 0],
                        [0, 0, 1, 0, 1],
                        [0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0]])

    vc3 = model_builder.compute_vertex_channels(8, 8, matrix2)
    self.assertEqual(vc3, [8, 4, 4, 4, 8])

    vc4 = model_builder.compute_vertex_channels(8, 15, matrix2)
    self.assertEqual(vc4, [8, 8, 7, 7, 15])

  def test_covariance_matrix_against_numpy(self):
    """Tests that the TF implementation of covariance matrix matches np.cov."""
    # Fixed seed so that any failing case can be reproduced.
    rng = np.random.RandomState(0)

    # Randomized test 100 times
    for _ in range(100):
      batch = rng.randint(50, 150)
      features = rng.randint(500, 1500)
      matrix = rng.random_sample((batch, features))

      # Build each case in its own graph so the default graph does not keep
      # accumulating large constants and ops across all iterations.
      with tf.Graph().as_default():
        tf_matrix = tf.constant(matrix, dtype=tf.float32)
        tf_cov_tensor = model_builder._covariance_matrix(tf_matrix)
        with tf.Session() as sess:
          tf_cov = sess.run(tf_cov_tensor)

      np_cov = np.cov(matrix)
      np.testing.assert_array_almost_equal(tf_cov, np_cov)


if __name__ == '__main__':
  tf.test.main()
145 |
--------------------------------------------------------------------------------
/nasbench/scripts/generate_cifar10_tfrecords.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Read CIFAR-10 data from pickled numpy arrays and writes TFRecords.
16 |
17 | Generates tf.train.Example protos and writes them to TFRecord files from the
18 | python version of the CIFAR-10 dataset downloaded from
19 | https://www.cs.toronto.edu/~kriz/cifar.html.
20 |
21 | Based on script from
22 | https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py
23 |
24 | To run:
25 | python generate_cifar10_tfrecords.py --data_dir=/tmp/cifar-tfrecord
26 | """
27 |
28 | from __future__ import absolute_import
29 | from __future__ import division
30 | from __future__ import print_function
31 |
32 | import argparse
33 | import os
34 | import sys
35 |
36 | import tarfile
37 | from six.moves import cPickle as pickle
38 | import tensorflow as tf
39 |
# Archive name, its download URL, and the folder that extracting it produces
# inside the data directory.
CIFAR_FILENAME = 'cifar-10-python.tar.gz'
CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/{}'.format(
    CIFAR_FILENAME)
CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py'
43 |
44 |
def download_and_extract(data_dir):
  """Download the CIFAR-10 python archive into `data_dir` and extract it there.

  Args:
    data_dir: directory to download the archive to and extract it into.
  """
  # download CIFAR-10 if not already downloaded.
  tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir,
                                                CIFAR_DOWNLOAD_URL)
  # The context manager closes the archive's file handle once extraction is
  # done.
  # NOTE(review): extractall() trusts member paths inside the archive. It is
  # fetched from a fixed URL, but consider a member filter if that changes.
  with tarfile.open(os.path.join(data_dir, CIFAR_FILENAME), 'r:gz') as tar:
    tar.extractall(data_dir)
51 |
52 |
def _int64_feature(value):
  """Wrap a single integer in a tf.train.Feature backed by an Int64List."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)
55 |
56 |
def _bytes_feature(value):
  """Wrap a single bytes value in a tf.train.Feature backed by a BytesList."""
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)
59 |
60 |
61 | def _get_file_names():
62 | """Returns the file names expected to exist in the input_dir."""
63 | file_names = {}
64 | for i in range(1, 5):
65 | file_names['train_%d' % i] = 'data_batch_%d' % i
66 | file_names['validation'] = 'data_batch_5'
67 | file_names['test'] = 'test_batch'
68 | return file_names
69 |
70 |
def read_pickle_from_file(filename):
  """Load one pickled CIFAR-10 batch.

  On Python 3 the pickle is loaded with encoding='bytes' (so string keys come
  back as bytes, e.g. b'data'); on Python 2 the default unpickling is used.

  NOTE(review): unpickling can execute arbitrary code. Only call this on files
  from a trusted source, such as the archive fetched from CIFAR_DOWNLOAD_URL.

  Args:
    filename: path readable by tf.gfile.

  Returns:
    The unpickled object.
  """
  running_py3 = sys.version_info >= (3, 0)
  load_kwargs = {'encoding': 'bytes'} if running_py3 else {}
  with tf.gfile.Open(filename, 'rb') as handle:
    return pickle.load(handle, **load_kwargs)
78 |
79 |
def convert_to_tfrecord(input_file, output_file):
  """Converts a file to TFRecords.

  Reads one pickled CIFAR-10 batch and writes every image/label pair as a
  serialized tf.train.Example with an 'image' feature (the raw image bytes)
  and a 'label' feature (int64).

  Args:
    input_file: path to a pickled CIFAR-10 batch.
    output_file: path of the TFRecord file to create.
  """
  print('Generating %s' % output_file)
  # Read the input before opening the writer, so a failed read does not leave
  # behind an empty output file that looks like a valid dataset.
  data_dict = read_pickle_from_file(input_file)
  data = data_dict[b'data']
  labels = data_dict[b'labels']
  num_entries_in_batch = len(labels)
  print('Converting %d images' % num_entries_in_batch)
  with tf.python_io.TFRecordWriter(output_file) as record_writer:
    for image, label in zip(data, labels):
      example = tf.train.Example(features=tf.train.Features(
          feature={
              'image': _bytes_feature(image.tobytes()),
              'label': _int64_feature(label)
          }))
      record_writer.write(example.SerializeToString())
96 |
97 |
def main(data_dir):
  """Downloads CIFAR-10 and writes the TFRecord files for every split.

  Writes one '<split>.tfrecords' file per entry of _get_file_names() plus a
  'sample.tfrecords' file holding a fixed batch of 100 validation examples.

  Args:
    data_dir: directory CIFAR-10 is downloaded/extracted into; the TFRecord
      files are written here as well.
  """
  print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL))
  download_and_extract(data_dir)
  file_names = _get_file_names()
  input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER)
  for split, batch_name in file_names.items():
    source_path = os.path.join(input_dir, batch_name)
    target_path = os.path.join(data_dir, split + '.tfrecords')
    # Best-effort removal of output left over from a previous run.
    try:
      os.remove(target_path)
    except OSError:
      pass
    # Convert to tf.train.Example and write them to TFRecords.
    convert_to_tfrecord(source_path, target_path)

  # Fixed batch of 100 examples: the first 10 of each class encountered when
  # scanning the validation batch from the front. They are written ordered by
  # label, i.e. 10 "airplane" images followed by 10 "automobile" images...
  per_class = [[] for _ in range(10)]
  collected = 0
  batch = read_pickle_from_file(
      os.path.join(input_dir, file_names['validation']))
  images = batch[b'data']
  labels = batch[b'labels']
  for idx, label in enumerate(labels):
    if len(per_class[label]) >= 10:
      continue
    per_class[label].append(
        tf.train.Example(features=tf.train.Features(
            feature={
                'image': _bytes_feature(images[idx].tobytes()),
                'label': _int64_feature(label)
            })))
    collected += 1
    if collected == 100:
      break

  sample_path = os.path.join(data_dir, 'sample.tfrecords')
  print('Generating %s' % sample_path)
  with tf.python_io.TFRecordWriter(sample_path) as record_writer:
    for class_examples in per_class:
      for example in class_examples:
        record_writer.write(example.SerializeToString())
  print('Done!')
142 |
143 |
# Script entry point: parse --data_dir and run the full conversion.
if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--data_dir',
      type=str,
      default='',
      help='Directory to download and extract CIFAR-10 to.')
  main(arg_parser.parse_args().data_dir)
154 |
--------------------------------------------------------------------------------
/nasbench/lib/graph_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions used by generate_graphs.py."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import hashlib
21 | import itertools
22 |
23 | import numpy as np
24 |
25 |
def gen_is_edge_fn(bits):
  """Generate a boolean function for the edge connectivity.

  Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
    [[0, A, B, D],
     [0, 0, C, E],
     [0, 0, 0, F],
     [0, 0, 0, 0]]

  Note that this function is agnostic to the actual matrix dimension due to
  order in which elements are filled out (column-major, starting from least
  significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
  matrix is
    [[0, A, B, D, 0],
     [0, 0, C, E, 0],
     [0, 0, 0, F, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]]

  Args:
    bits: integer which will be interpreted as a bit mask.

  Returns:
    vectorized function that returns True when an edge is present. np.vectorize
    fixes its output dtype from the first call; on a grid starting at (0, 0)
    that call returns the integer 0, so present edges come out as 1.
  """
  def is_edge(x, y):
    """Is there an edge from x to y (0-indexed)?"""
    # np.fromfunction hands over float indices unless an integer dtype is
    # requested, and `>>` rejects floats, so coerce to plain ints first.
    x, y = int(x), int(y)
    if x >= y:
      return 0
    # Map x, y to index into bit string: column y starts after the
    # y * (y - 1) / 2 entries of the columns to its left.
    index = x + (y * (y - 1) // 2)
    return (bits >> index) % 2 == 1

  return np.vectorize(is_edge)
60 |
61 |
def is_full_dag(matrix):
  """Checks that every vertex lies on a path from vertex 0 to vertex V-1.

  In other words there are no disconnected or "hanging" vertices. It suffices
  to check that:
    1) only the last row (the output vertex) may be all zeros, i.e. every
       other vertex has an out-edge, and
    2) only the first column (the input vertex) may be all zeros, i.e. every
       other vertex has an in-edge.

  Args:
    matrix: V x V upper-triangular adjacency matrix (np.ndarray).

  Returns:
    True if there are no dangling vertices.
  """
  num_vertices = np.shape(matrix)[0]
  has_dead_end = np.any(np.all(matrix[:num_vertices - 1, :] == 0, axis=1))
  has_orphan = np.any(np.all(matrix[:, 1:] == 0, axis=0))
  return not (has_dead_end or has_orphan)
88 |
89 |
def num_edges(matrix):
  """Returns the sum of all adjacency-matrix entries.

  For a 0/1 adjacency matrix this is the number of edges.
  """
  return np.asarray(matrix).sum()
93 |
94 |
def hash_module(matrix, labeling):
  """Computes an isomorphism-invariant MD5 hash of the matrix and label pair.

  Each vertex starts from the MD5 of str((out_degree, in_degree, label)).
  Then, for V rounds, every vertex is re-hashed from its sorted in-neighbor
  hashes, its sorted out-neighbor hashes and its own previous hash
  (Weisfeiler-Lehman-style refinement). The fingerprint is the MD5 of the
  sorted final vertex hashes, so it does not depend on vertex numbering.
  NOTE(review): as with WL refinement in general, distinct non-isomorphic
  graphs are not proven to receive distinct fingerprints.

  ModelSpec.hash_spec uses this to build hashes for querying the dataset, so
  the exact string formats below (str() of the tuples, '|' separators, str()
  of the sorted list) are part of that key format; changing any of them
  changes every fingerprint.

  Args:
    matrix: np.ndarray square upper-triangular adjacency matrix.
    labeling: list of int labels of length equal to both dimensions of
      matrix. NOTE(review): keep these plain Python ints; str() of a tuple
      uses repr(), and NumPy 2 reprs scalars as e.g. 'np.int64(2)', which
      would change the hash.

  Returns:
    MD5 hash of the matrix and labeling, as a hex-digest string.
  """
  vertices = np.shape(matrix)[0]
  in_edges = np.sum(matrix, axis=0).tolist()
  out_edges = np.sum(matrix, axis=1).tolist()

  assert len(in_edges) == len(out_edges) == len(labeling)
  # Initial per-vertex hash from (out-degree, in-degree, label).
  hashes = list(zip(out_edges, in_edges, labeling))
  hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes]
  # Computing this up to the diameter is probably sufficient but since the
  # operation is fast, it is okay to repeat more times.
  for _ in range(vertices):
    new_hashes = []
    for v in range(vertices):
      in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]]
      out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]]
      # Sorting the neighbor hashes makes the update independent of the order
      # in which neighbors are numbered.
      new_hashes.append(hashlib.md5(
          (''.join(sorted(in_neighbors)) + '|' +
           ''.join(sorted(out_neighbors)) + '|' +
           hashes[v]).encode('utf-8')).hexdigest())
    hashes = new_hashes
  # Sort the final vertex hashes so the fingerprint ignores vertex order.
  fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest()

  return fingerprint
128 |
129 |
def permute_graph(graph, label, permutation):
  """Permutes the graph and labels based on permutation.

  Args:
    graph: square adjacency matrix (np.ndarray or nested list).
    label: list of labels of same length as graph dimensions.
    permutation: a permutation sequence of ints of same length as graph
      dimensions; vertex v of the original graph becomes vertex
      permutation[v] of the new graph.

  Returns:
    (new_matrix, new_label) tuple: new_matrix is a boolean np.ndarray where
    vertex permutation[v] is vertex v from the original graph (entries are
    True where the original entry equals 1), and new_label is the label list
    reordered the same way.
  """
  # inverse_perm[new_v] is the original vertex that lands at position new_v.
  forward_perm = zip(permutation, list(range(len(permutation))))
  inverse_perm = [x[1] for x in sorted(forward_perm)]
  # Gather the whole permuted matrix with one fancy-indexing operation:
  # new_matrix[x, y] = graph[inverse_perm[x], inverse_perm[y]]. Unlike a
  # per-entry np.vectorize call this also works for 0-vertex graphs.
  new_matrix = np.asarray(graph)[np.ix_(inverse_perm, inverse_perm)] == 1
  new_label = [label[old_v] for old_v in inverse_perm]
  return new_matrix, new_label
150 |
151 |
def is_isomorphic(graph1, graph2):
  """Exhaustively checks whether two labeled graphs are isomorphic.

  Every one of the V! vertex permutations of graph1 is tried, so this is
  only practical for small graphs.

  Args:
    graph1: (adjacency matrix, labels) pair; the matrix may be a nested list.
    graph2: (adjacency matrix, labels) pair of the same size as graph1.

  Returns:
    True if some permutation of graph1 matches graph2 in both adjacency and
    labels, else False.
  """
  first_matrix, first_labels = np.array(graph1[0]), graph1[1]
  second_matrix, second_labels = np.array(graph2[0]), graph2[1]
  assert np.shape(first_matrix) == np.shape(second_matrix)
  assert len(first_labels) == len(second_labels)

  # Note: input and output in our constrained graphs always map to themselves
  # but this function does not enforce that, so all permutations are tried.
  candidates = (
      permute_graph(first_matrix, first_labels, perm)
      for perm in itertools.permutations(range(np.shape(first_matrix)[0])))
  return any(
      np.array_equal(permuted_matrix, second_matrix)
      and permuted_labels == second_labels
      for permuted_matrix, permuted_labels in candidates)
168 |
--------------------------------------------------------------------------------
/nasbench/lib/base_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Base operations used by the modules in this search space."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import abc
22 |
23 | import tensorflow as tf
24 |
# Tensor layouts accepted by the ops below.
# Currently, only channels_last is well supported.
VALID_DATA_FORMATS = frozenset(('channels_last', 'channels_first'))
# Smallest channel count; not referenced in this file (presumably used by
# model_builder.py -- TODO confirm).
MIN_FILTERS = 8
# Batch-norm hyperparameters passed to tf.layers.batch_normalization.
BN_MOMENTUM = 0.997
BN_EPSILON = 1e-5
30 |
31 |
def conv_bn_relu(inputs, conv_size, conv_filters, is_training, data_format):
  """Applies a stride-1, 'same'-padded convolution, batch norm, then ReLU.

  Args:
    inputs: 4-D input Tensor laid out according to data_format.
    conv_size: convolution kernel size.
    conv_filters: number of output channels.
    is_training: forwarded as `training` to batch normalization.
    data_format: 'channels_last' or 'channels_first'.

  Returns:
    Output Tensor of the ReLU.

  Raises:
    ValueError: if data_format is not one of the two supported layouts.
  """
  if data_format not in ('channels_last', 'channels_first'):
    raise ValueError('invalid data_format')
  # Batch norm normalizes over the channel axis, whose position depends on
  # the layout.
  channel_axis = 3 if data_format == 'channels_last' else 1

  conv_out = tf.layers.conv2d(
      inputs=inputs,
      filters=conv_filters,
      kernel_size=conv_size,
      strides=(1, 1),
      use_bias=False,
      kernel_initializer=tf.variance_scaling_initializer(),
      padding='same',
      data_format=data_format)
  normalized = tf.layers.batch_normalization(
      inputs=conv_out,
      axis=channel_axis,
      momentum=BN_MOMENTUM,
      epsilon=BN_EPSILON,
      training=is_training)
  return tf.nn.relu(normalized)
61 |
62 |
# A `__metaclass__` class attribute is ignored by Python 3, which left
# @abc.abstractmethod unenforced. Deriving from a base created by calling
# abc.ABCMeta directly applies the metaclass on both Python 2 and 3.
class BaseOp(abc.ABCMeta('_BaseOpABC', (object,), {})):
  """Abstract base operation class.

  Subclasses must override build(); instantiating BaseOp itself (or a
  subclass that does not override build) raises TypeError.
  """

  def __init__(self, is_training, data_format='channels_last'):
    """Stores the training flag and the validated data format.

    Args:
      is_training: training flag; the ops in this file forward it to
        conv_bn_relu.
      data_format: 'channels_last' or 'channels_first' (case-insensitive);
        stored lower-cased.

    Raises:
      ValueError: if data_format is not in VALID_DATA_FORMATS.
    """
    self.is_training = is_training
    if data_format.lower() not in VALID_DATA_FORMATS:
      raise ValueError('invalid data_format')
    self.data_format = data_format.lower()

  @abc.abstractmethod
  def build(self, inputs, channels):
    """Builds the operation with input tensors and returns an output tensor.

    Args:
      inputs: a 4-D Tensor.
      channels: int number of output channels of operation. The operation may
        choose to ignore this parameter.

    Returns:
      a 4-D Tensor with the same data format.
    """
    pass
86 |
87 |
class Identity(BaseOp):
  """Pass-through op: returns its input unchanged and ignores `channels`."""

  def build(self, inputs, channels):
    del channels  # The output keeps whatever channel count the input has.
    output = tf.identity(inputs, name='identity')
    return output
94 |
95 |
class Conv3x3BnRelu(BaseOp):
  """3x3 conv -> batch norm -> ReLU producing `channels` output channels."""

  def build(self, inputs, channels):
    kernel_size = 3
    with tf.variable_scope('Conv3x3-BN-ReLU'):
      output = conv_bn_relu(
          inputs, kernel_size, channels, self.is_training, self.data_format)
    return output
105 |
106 |
class Conv1x1BnRelu(BaseOp):
  """1x1 conv -> batch norm -> ReLU producing `channels` output channels."""

  def build(self, inputs, channels):
    kernel_size = 1
    with tf.variable_scope('Conv1x1-BN-ReLU'):
      output = conv_bn_relu(
          inputs, kernel_size, channels, self.is_training, self.data_format)
    return output
116 |
117 |
class MaxPool3x3(BaseOp):
  """3x3 max pooling with stride 1 and 'same' padding (no subsampling)."""

  def build(self, inputs, channels):
    del channels  # Pooling cannot change the channel count.
    with tf.variable_scope('MaxPool3x3'):
      pooled = tf.layers.max_pooling2d(
          inputs=inputs,
          pool_size=(3, 3),
          strides=(1, 1),
          padding='same',
          data_format=self.data_format)
    return pooled
132 |
133 |
class BottleneckConv3x3(BaseOp):
  """[1x1(/4)]+3x3+[1x1(*4)] conv. Uses BN + ReLU post-activation.

  The first two stages run at channels // 4; the last restores `channels`.
  """
  # TODO(chrisying): verify this block can reproduce results of ResNet-50.

  def build(self, inputs, channels):
    reduced = channels // 4
    stages = ((1, reduced), (3, reduced), (1, channels))
    with tf.variable_scope('BottleneckConv3x3'):
      net = inputs
      for kernel_size, filters in stages:
        net = conv_bn_relu(
            net, kernel_size, filters, self.is_training, self.data_format)
    return net
148 |
149 |
class BottleneckConv5x5(BaseOp):
  """[1x1(/4)]+5x5+[1x1(*4)] conv. Uses BN + ReLU post-activation.

  The first two stages run at channels // 4; the last restores `channels`.
  """

  def build(self, inputs, channels):
    reduced = channels // 4
    stages = ((1, reduced), (5, reduced), (1, channels))
    with tf.variable_scope('BottleneckConv5x5'):
      net = inputs
      for kernel_size, filters in stages:
        net = conv_bn_relu(
            net, kernel_size, filters, self.is_training, self.data_format)
    return net
163 |
164 |
class MaxPool3x3Conv1x1(BaseOp):
  """Stride-1 3x3 max pool followed by a 1x1 conv-BN-ReLU to `channels`."""

  def build(self, inputs, channels):
    with tf.variable_scope('MaxPool3x3-Conv1x1'):
      pooled = tf.layers.max_pooling2d(
          inputs=inputs,
          pool_size=(3, 3),
          strides=(1, 1),
          padding='same',
          data_format=self.data_format)
      # The 1x1 convolution rescales the pooled tensor to `channels`.
      projected = conv_bn_relu(
          pooled, 1, channels, self.is_training, self.data_format)
    return projected
180 |
181 |
# Registry from op name to the BaseOp subclass implementing it.
# Commas should not be used in op names
OP_MAP = dict([
    ('identity', Identity),
    ('conv3x3-bn-relu', Conv3x3BnRelu),
    ('conv1x1-bn-relu', Conv1x1BnRelu),
    ('maxpool3x3', MaxPool3x3),
    ('bottleneck3x3', BottleneckConv3x3),
    ('bottleneck5x5', BottleneckConv5x5),
    ('maxpool3x3-conv1x1', MaxPool3x3Conv1x1),
])
192 |
--------------------------------------------------------------------------------
/nasbench/lib/model_spec.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Model specification for module connectivity individuals.
16 |
17 | This module handles pruning the unused parts of the computation graph but should
18 | avoid creating any TensorFlow models (this is done inside model_builder.py).
19 | """
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import copy
26 |
27 | from nasbench.lib import graph_util
28 | import numpy as np
29 |
30 | # Graphviz is optional and only required for visualization.
31 | try:
32 | import graphviz # pylint: disable=g-import-not-at-top
33 | except ImportError:
34 | pass
35 |
36 |
class ModelSpec(object):
  """Model specification given adjacency matrix and labeling.

  After construction, `matrix` / `ops` hold the pruned graph (both None, with
  `valid_spec` False, when the input cannot reach the output), while
  `original_matrix` / `original_ops` keep deep copies of the constructor
  inputs.
  """

  def __init__(self, matrix, ops, data_format='channels_last'):
    """Initialize the module spec.

    Args:
      matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
      ops: V-length list of labels for the base ops used. The first and last
        elements are ignored because they are the input and output vertices
        which have no operations. The elements are retained to keep consistent
        indexing.
      data_format: channels_last or channels_first.

    Raises:
      ValueError: invalid matrix or ops
    """
    if not isinstance(matrix, np.ndarray):
      matrix = np.array(matrix)
    shape = np.shape(matrix)
    if len(shape) != 2 or shape[0] != shape[1]:
      raise ValueError('matrix must be square')
    if shape[0] != len(ops):
      raise ValueError('length of ops must match matrix dimensions')
    if not is_upper_triangular(matrix):
      raise ValueError('matrix must be upper triangular')

    # Both the original and pruned matrices are deep copies of the matrix and
    # ops so any changes to those after initialization are not recognized by the
    # spec.
    self.original_matrix = copy.deepcopy(matrix)
    self.original_ops = copy.deepcopy(ops)

    self.matrix = copy.deepcopy(matrix)
    self.ops = copy.deepcopy(ops)
    self.valid_spec = True
    self._prune()

    self.data_format = data_format

  def _prune(self):
    """Prune the extraneous parts of the graph.

    General procedure:
      1) Remove parts of graph not connected to input.
      2) Remove parts of graph not connected to output.
      3) Reorder the vertices so that they are consecutive after steps 1 and 2.

    These 3 steps can be combined by deleting the rows and columns of the
    vertices that are not reachable from both the input and output (in reverse).
    """
    num_vertices = np.shape(self.original_matrix)[0]

    # DFS forward from input. Only v > top is scanned because __init__ has
    # checked that the matrix is upper triangular.
    visited_from_input = set([0])
    frontier = [0]
    while frontier:
      top = frontier.pop()
      for v in range(top + 1, num_vertices):
        if self.original_matrix[top, v] and v not in visited_from_input:
          visited_from_input.add(v)
          frontier.append(v)

    # DFS backward from output
    visited_from_output = set([num_vertices - 1])
    frontier = [num_vertices - 1]
    while frontier:
      top = frontier.pop()
      for v in range(0, top):
        if self.original_matrix[v, top] and v not in visited_from_output:
          visited_from_output.add(v)
          frontier.append(v)

    # Any vertex that isn't connected to both input and output is extraneous to
    # the computation graph.
    extraneous = set(range(num_vertices)).difference(
        visited_from_input.intersection(visited_from_output))

    # If the non-extraneous graph is less than 2 vertices, the input is not
    # connected to the output and the spec is invalid.
    if len(extraneous) > num_vertices - 2:
      self.matrix = None
      self.ops = None
      self.valid_spec = False
      return

    self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
    self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
    # Delete from the highest index down so earlier deletions do not shift
    # indices still to be removed. This requires `ops` to support item
    # deletion (e.g. a list) whenever pruning removes vertices.
    for index in sorted(extraneous, reverse=True):
      del self.ops[index]

  def hash_spec(self, canonical_ops):
    """Computes the isomorphism-invariant graph hash of this spec.

    Args:
      canonical_ops: list of operations in the canonical ordering which they
        were assigned (i.e. the order provided in the config['available_ops']).

    Returns:
      MD5 hash of this spec which can be used to query the dataset.

    Raises:
      ValueError: if an op of this spec is missing from canonical_ops.
    """
    # Invert the operations back to integer label indices used in graph gen.
    # The input and output vertices get the reserved labels -1 and -2.
    # NOTE(review): self.ops is None for an invalid spec, so callers should
    # check valid_spec before hashing.
    labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
    return graph_util.hash_module(self.matrix, labeling)

  def visualize(self):
    """Creates a dot graph. Can be visualized in colab directly.

    Returns:
      graphviz.Digraph of the pruned graph, with 'input' / 'output' / op-name
      node labels.

    Raises:
      ImportError: if the optional graphviz package is not installed.
    """
    # graphviz is optional and its module-level import silently swallows
    # ImportError, which used to surface here as an opaque NameError.
    # Importing locally reports the missing dependency directly (and is just a
    # sys.modules lookup when graphviz is installed).
    import graphviz  # pylint: disable=g-import-not-at-top,redefined-outer-name

    num_vertices = np.shape(self.matrix)[0]
    g = graphviz.Digraph()
    g.node(str(0), 'input')
    for v in range(1, num_vertices - 1):
      g.node(str(v), self.ops[v])
    g.node(str(num_vertices - 1), 'output')

    for src in range(num_vertices - 1):
      for dst in range(src + 1, num_vertices):
        if self.matrix[src, dst]:
          g.edge(str(src), str(dst))

    return g
157 |
158 |
def is_upper_triangular(matrix):
  """True if matrix is 0 on diagonal and below.

  Args:
    matrix: 2-D np.ndarray or nested list.

  Returns:
    Python bool: True when every entry on or below the main diagonal is 0.
  """
  # np.tril keeps the diagonal and everything below it (zeroing the rest), so
  # any non-zero survivor breaks strict upper-triangularity. Vectorized, and
  # unlike `matrix[src, dst]` indexing it also accepts nested lists.
  return not np.any(np.tril(matrix))
167 |
--------------------------------------------------------------------------------
/nasbench/tests/model_spec_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lib/model_spec.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from nasbench.lib import model_spec
22 | import numpy as np
23 | import tensorflow as tf # Used only for tf.test
24 |
25 |
class ModelSpecTest(tf.test.TestCase):
  """Tests pruning and validation in model_spec.ModelSpec.

  Uses the TestCase assert* methods rather than bare `assert`, which is
  stripped when Python runs with -O.
  """

  def test_prune_noop(self):
    """Tests graphs which require no pruning."""
    model1 = model_spec.ModelSpec(
        np.array([[0, 1, 0],
                  [0, 0, 1],
                  [0, 0, 0]]),
        [0, 0, 0])
    self.assertTrue(model1.valid_spec)
    self.assertTrue(np.array_equal(model1.original_matrix, model1.matrix))
    # Compare against the pruned ops; comparing original_ops with itself
    # could never fail.
    self.assertEqual(model1.original_ops, model1.ops)

    model2 = model_spec.ModelSpec(
        np.array([[0, 1, 1],
                  [0, 0, 1],
                  [0, 0, 0]]),
        [0, 0, 0])
    self.assertTrue(model2.valid_spec)
    self.assertTrue(np.array_equal(model2.original_matrix, model2.matrix))
    self.assertEqual(model2.original_ops, model2.ops)

    model3 = model_spec.ModelSpec(
        np.array([[0, 1, 1, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]]),
        [0, 0, 0, 0])
    self.assertTrue(model3.valid_spec)
    self.assertTrue(np.array_equal(model3.original_matrix, model3.matrix))
    self.assertEqual(model3.original_ops, model3.ops)

  def test_prune_islands(self):
    """Tests isolated components are pruned."""
    model1 = model_spec.ModelSpec(
        np.array([[0, 1, 0, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]]),
        [1, 2, 3, 4])
    self.assertTrue(model1.valid_spec)
    self.assertTrue(np.array_equal(model1.matrix,
                                   np.array([[0, 1, 0],
                                             [0, 0, 1],
                                             [0, 0, 0]])))
    self.assertEqual(model1.ops, [1, 2, 4])

    model2 = model_spec.ModelSpec(
        np.array([[0, 1, 0, 0, 0],
                  [0, 0, 0, 0, 1],
                  [0, 0, 0, 1, 0],
                  [0, 0, 0, 0, 0],
                  [0, 0, 0, 0, 0]]),
        [1, 2, 3, 4, 5])
    self.assertTrue(model2.valid_spec)
    self.assertTrue(np.array_equal(model2.matrix,
                                   np.array([[0, 1, 0],
                                             [0, 0, 1],
                                             [0, 0, 0]])))
    self.assertEqual(model2.ops, [1, 2, 5])

  def test_prune_dangling(self):
    """Tests dangling vertices are pruned."""
    model1 = model_spec.ModelSpec(
        np.array([[0, 1, 1, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]]),
        [1, 2, 3, 4])
    self.assertTrue(model1.valid_spec)
    self.assertTrue(np.array_equal(model1.matrix,
                                   np.array([[0, 1, 0],
                                             [0, 0, 1],
                                             [0, 0, 0]])))
    self.assertEqual(model1.ops, [1, 3, 4])

    model2 = model_spec.ModelSpec(
        np.array([[0, 0, 1, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]]),
        [1, 2, 3, 4])
    self.assertTrue(model2.valid_spec)
    self.assertTrue(np.array_equal(model2.matrix,
                                   np.array([[0, 1, 0],
                                             [0, 0, 1],
                                             [0, 0, 0]])))
    self.assertEqual(model2.ops, [1, 3, 4])

  def test_prune_disconnected(self):
    """Tests graphs with no input-to-output path are marked invalid."""
    model1 = model_spec.ModelSpec(
        np.array([[0, 0],
                  [0, 0]]),
        [0, 0])
    self.assertFalse(model1.valid_spec)

    model2 = model_spec.ModelSpec(
        np.array([[0, 1, 0, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 1],
                  [0, 0, 0, 0]]),
        [1, 2, 3, 4])
    self.assertFalse(model2.valid_spec)

    model3 = model_spec.ModelSpec(
        np.array([[0, 0, 0, 0],
                  [0, 0, 1, 0],
                  [0, 0, 0, 0],
                  [0, 0, 0, 0]]),
        [1, 2, 3, 4])
    self.assertFalse(model3.valid_spec)

  def test_is_upper_triangular(self):
    """Tests is_upper_triangular is correct for square graphs."""
    m0 = np.array([[0, 0, 0, 0],
                   [0, 0, 0, 0],
                   [0, 0, 0, 0],
                   [0, 0, 0, 0]])
    self.assertTrue(model_spec.is_upper_triangular(m0))

    m1 = np.array([[0, 1, 1, 1],
                   [0, 0, 1, 1],
                   [0, 0, 0, 1],
                   [0, 0, 0, 0]])
    self.assertTrue(model_spec.is_upper_triangular(m1))

    m2 = np.array([[0, 1, 1, 1],
                   [0, 0, 1, 1],
                   [1, 0, 0, 1],
                   [0, 0, 0, 0]])
    self.assertFalse(model_spec.is_upper_triangular(m2))

    m3 = np.array([[0, 0, 0, 0],
                   [0, 0, 0, 0],
                   [1, 0, 0, 0],
                   [0, 0, 0, 0]])
    self.assertFalse(model_spec.is_upper_triangular(m3))

    m4 = np.array([[1, 0, 0, 0],
                   [1, 1, 0, 0],
                   [1, 1, 1, 0],
                   [1, 1, 1, 1]])
    self.assertFalse(model_spec.is_upper_triangular(m4))

    m5 = np.array([[0]])
    self.assertTrue(model_spec.is_upper_triangular(m5))

    m6 = np.array([[1]])
    self.assertFalse(model_spec.is_upper_triangular(m6))
176 |
177 |
# Run the test cases with TensorFlow's test runner when executed directly.
if __name__ == '__main__':
  tf.test.main()
180 |
--------------------------------------------------------------------------------
/nasbench/lib/cifar.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """CIFAR-10 data pipeline with preprocessing.
16 |
17 | The data is generated via generate_cifar10_tfrecords.py.
18 | """
19 |
20 |
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 |
25 | import functools
26 | import tensorflow as tf
27 |
# Spatial size of a CIFAR-10 image in pixels.
HEIGHT = WIDTH = 32
# Per-channel (R, G, B) statistics used by _parser to standardize images as
# (pixel - mean) / std; pixels are on the raw 0-255 scale (uint8 cast to
# float, no rescaling).
RGB_MEAN = [125.31, 122.95, 113.87]
RGB_STD = [62.99, 62.09, 66.70]
32 |
33 |
class CIFARInput(object):
  """Wrapper class for input_fn passed to TPUEstimator."""

  # Number of examples served in each mode. 'train_eval' reads only the first
  # training file, 'augment' adds the validation file to the training files,
  # and 'sample' is the fixed 100-example batch.
  _NUM_IMAGES = {
      'train': 40000,
      'train_eval': 10000,
      'valid': 10000,
      'test': 10000,
      'augment': 50000,
      'sample': 100,
  }

  def __init__(self, mode, config):
    """Initializes a CIFARInput object.

    Args:
      mode: one of [train, train_eval, valid, test, augment, sample]
      config: config dict built from config.py

    Raises:
      ValueError: invalid mode or data files
    """
    self.mode = mode
    self.config = config
    if mode == 'train':  # Training set (no validation & test)
      self.data_files = config['train_data_files']
    elif mode == 'train_eval':  # For computing train error
      # Slice rather than index so an empty train_data_files list is reported
      # by the "no data files" check below instead of raising IndexError.
      self.data_files = list(config['train_data_files'][:1])
    elif mode == 'valid':  # For computing validation error
      self.data_files = [config['valid_data_file']]
    elif mode == 'test':  # For computing the test error
      self.data_files = [config['test_data_file']]
    elif mode == 'augment':  # Training set (includes validation, no test)
      self.data_files = (config['train_data_files'] +
                         [config['valid_data_file']])
    elif mode == 'sample':  # Fixed batch of 100 samples from validation
      self.data_files = [config['sample_data_file']]
    else:
      raise ValueError('invalid mode')

    if not self.data_files:
      raise ValueError('no data files provided')

  @property
  def num_images(self):
    """Number of images in the dataset (depends on the mode)."""
    return self._NUM_IMAGES.get(self.mode)

  def input_fn(self, params):
    """Returns a CIFAR tf.data.Dataset object.

    Args:
      params: parameter dict pass by Estimator.

    Returns:
      tf.data.Dataset object
    """
    batch_size = params['batch_size']
    is_training = (self.mode == 'train' or self.mode == 'augment')

    dataset = tf.data.TFRecordDataset(self.data_files)
    dataset = dataset.prefetch(buffer_size=batch_size)

    # Repeat dataset for training modes
    if is_training:
      # Shuffle buffer with whole dataset to ensure full randomness per epoch
      dataset = dataset.cache().apply(
          tf.contrib.data.shuffle_and_repeat(
              buffer_size=self.num_images))

    # This is a hack to allow computing metrics on a fixed batch on TPU. Because
    # TPU shards the batch across cores, we replicate the fixed batch so that
    # each core contains the whole batch.
    if self.mode == 'sample':
      dataset = dataset.repeat()

    # Parse, preprocess, and batch images
    parser_fn = functools.partial(_parser, is_training)
    dataset = dataset.apply(
        tf.contrib.data.map_and_batch(
            parser_fn,
            batch_size=batch_size,
            num_parallel_batches=self.config['tpu_num_shards'],
            drop_remainder=True))

    # Assign static batch size dimension
    dataset = dataset.map(functools.partial(_set_batch_dimension, batch_size))

    # Prefetch to overlap in-feed with training
    dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

    return dataset
128 |
129 |
def _preprocess(image):
  """Applies the standard CIFAR training augmentation.

  The image is zero-padded by 4 pixels on every side, randomly cropped back
  to HEIGHT x WIDTH, and then randomly flipped horizontally.

  Args:
    image: image Tensor with shape [height, width, 3]

  Returns:
    augmented image with the same dimensions.
  """
  padded = tf.image.resize_image_with_crop_or_pad(
      image, HEIGHT + 8, WIDTH + 8)
  # NOTE(review): both random ops use the same op-level seed (0); confirm the
  # crop offsets and flip decisions are not meant to be drawn independently.
  cropped = tf.random_crop(padded, [HEIGHT, WIDTH, 3], seed=0)
  return tf.image.random_flip_left_right(cropped, seed=0)
153 |
154 |
def _parser(use_preprocessing, serialized_example):
  """Decodes one serialized tf.Example into a normalized image and a label.

  Args:
    use_preprocessing: whether to apply _preprocess (training augmentation).
    serialized_example: scalar string Tensor holding a tf.Example.

  Returns:
    (image, label): float32 [HEIGHT, WIDTH, 3] image standardized with
    RGB_MEAN / RGB_STD, and an int32 label.
  """
  feature_spec = {
      'image': tf.FixedLenFeature([], tf.string),
      'label': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(serialized_example, features=feature_spec)
  flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
  flat_pixels.set_shape([3 * HEIGHT * WIDTH])
  # The raw bytes are channel-major; move channels last (CHW -> HWC).
  chw = tf.reshape(flat_pixels, [3, HEIGHT, WIDTH])
  # TODO(chrisying): handle NCHW format
  hwc = tf.transpose(chw, [1, 2, 0])
  image = tf.cast(hwc, tf.float32)
  if use_preprocessing:
    image = _preprocess(image)
  image = image - tf.constant(RGB_MEAN, shape=[1, 1, 3])
  image = image / tf.constant(RGB_STD, shape=[1, 1, 3])
  label = tf.cast(parsed['label'], tf.int32)
  return image, label
175 |
176 |
def _set_batch_dimension(batch_size, images, labels):
  """Pins the static leading (batch) dimension of images and labels.

  Args:
    batch_size: static batch size to merge into the tensor shapes.
    images: batched image Tensor of rank 4.
    labels: batched label Tensor of rank 1.

  Returns:
    The same (images, labels) tensors with their shapes updated in place.
  """
  batched_image_shape = tf.TensorShape([batch_size, None, None, None])
  batched_label_shape = tf.TensorShape([batch_size])
  images.set_shape(images.get_shape().merge_with(batched_image_shape))
  labels.set_shape(labels.get_shape().merge_with(batched_label_shape))
  return images, labels
184 |
--------------------------------------------------------------------------------
/nasbench/scripts/generate_graphs.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Generate all graphs up to structure and label isomorphism.
16 |
17 | The goal is to generate all unique computational graphs up to some number of
18 | vertices and edges. Computational graphs can be represented by directed acyclic
19 | graphs with all components connected along some path from a specially-labeled
20 | input to output. The pseudocode for generating these is:
21 |
22 | for V in [2, ..., MAX_VERTICES]: # V includes input and output vertices
23 | generate all bitmasks of length V*(V-1)/2 # num upper triangular entries
24 | for each bitmask:
25 | convert bitmask to adjacency matrix
26 | if adjacency matrix has disconnected vertices from input/output:
27 | discard and continue to next matrix
28 | generate all labelings of ops to vertices
29 | for each labeling:
30 | compute graph hash from matrix and labels
31 | if graph hash has not been seen before:
32 | output graph (adjacency matrix + labeling)
33 |
34 | This script uses a modification on Weisfeiler-Lehman color refinement
35 | (https://ist.ac.at/mfcs13/slides/gi.pdf) for graph hashing, which is very
36 | loosely similar to the hashing approach described in
37 | https://arxiv.org/pdf/1606.00001.pdf. The general idea is to assign each vertex
38 | a hash based on the in-degree, out-degree, and operation label then iteratively
39 | hash each vertex with the hashes of its neighbors.
40 |
41 | In more detail, the iterative update involves repeating the following steps a
42 | number of times greater than or equal to the diameter of the graph:
43 | 1) For each vertex, sort the hashes of the in-neighbors.
44 | 2) For each vertex, sort the hashes of the out-neighbors.
45 | 3) For each vertex, concatenate the sorted hashes from (1), (2) and the vertex
46 | operation label.
47 | 4) For each vertex, compute the MD5 hash of the concatenated values in (3).
48 | 5) Assign the newly computed hashes to each vertex.
49 |
50 | Finally, sort the hashes of all the vertices and concat and hash one more time
51 | to obtain the final graph hash. This hash is a graph invariant as all operations
52 | are invariant under isomorphism, thus we expect no false negatives (isomorphic
53 | graphs hashed to different values).
54 |
55 | We have empirically verified that, for graphs up to 7 vertices, 9 edges, 3 ops,
56 | this algorithm does not cause "false positives" (graphs that hash to the same
57 | value but are non-isomorphic). For such graphs, this algorithm yields 423,624
58 | unique computation graphs, which is roughly 1/3rd of the total number of
59 | connected DAGs before de-duping using this hash algorithm.
60 | """
61 | from __future__ import absolute_import
62 | from __future__ import division
63 | from __future__ import print_function
64 |
65 | import itertools
66 | import json
67 | import sys
68 |
69 | from absl import app
70 | from absl import flags
71 | from absl import logging
72 |
73 | from nasbench.lib import graph_util
74 | import numpy as np
75 | import tensorflow as tf # For gfile
76 |
# Command-line flags controlling the graph enumeration.
flags.DEFINE_string(
    name='output_file',
    default='/tmp/generated_graphs.json',
    help='Output file name.')
flags.DEFINE_integer(
    name='max_vertices',
    default=7,
    help='Maximum number of vertices including input/output.')
flags.DEFINE_integer(
    name='num_ops',
    default=3,
    help='Number of operation labels.')
flags.DEFINE_integer(
    name='max_edges',
    default=9,
    help='Maximum number of edges.')
flags.DEFINE_boolean(
    name='verify_isomorphism',
    default=True,
    help=('Exhaustively verifies that each detected isomorphism is truly an '
          'isomorphism. This operation is very expensive.'))
FLAGS = flags.FLAGS
88 |
89 |
def main(_):
  """Enumerates labeled DAGs and writes one representative per hash bucket.

  For every vertex count from 2 to --max_vertices, each bitmask over the
  upper-triangular adjacency entries is decoded into a matrix. Matrices
  rejected by graph_util.is_full_dag, or having more than --max_edges edges,
  are skipped. Each remaining matrix is combined with every assignment of
  --num_ops labels to its interior vertices (the first and last vertices get
  the special labels -1 and -2) and hashed with graph_util.hash_module. The
  first graph seen for each hash is kept.

  With --verify_isomorphism, every later graph that lands in an existing
  bucket is checked with graph_util.is_isomorphic, and a mismatch is fatal.
  The buckets are written as JSON to --output_file.
  """
  labeled_graph_count = 0  # Every labeled graph visited, duplicates included.
  # fingerprint -> (adjacency matrix as nested lists, labeling) of the first
  # graph observed with that fingerprint.
  representatives = {}

  logging.info('Using %d vertices, %d op labels, max %d edges',
               FLAGS.max_vertices, FLAGS.num_ops, FLAGS.max_edges)
  for num_vertices in range(2, FLAGS.max_vertices + 1):
    num_upper_triangular = num_vertices * (num_vertices - 1) // 2
    for edge_bits in range(2 ** num_upper_triangular):
      adjacency = np.fromfunction(graph_util.gen_is_edge_fn(edge_bits),
                                  (num_vertices, num_vertices),
                                  dtype=np.int8)

      # Two guards with the same short-circuit order as a single `or`: the
      # edge count is only computed for graphs that pass the full-DAG check.
      if not graph_util.is_full_dag(adjacency):
        continue
      if graph_util.num_edges(adjacency) > FLAGS.max_edges:
        continue

      interior_label_choices = [range(FLAGS.num_ops)] * (num_vertices - 2)
      for interior in itertools.product(*interior_label_choices):
        labeled_graph_count += 1
        ops = list((-1,) + interior + (-2,))
        fingerprint = graph_util.hash_module(adjacency, ops)

        if fingerprint not in representatives:
          representatives[fingerprint] = (adjacency.tolist(), ops)
          continue
        if not FLAGS.verify_isomorphism:
          continue

        # A repeated fingerprint should mean an isomorphic graph; anything
        # else is a hash collision ("false positive") and aborts the run.
        reference = representatives[fingerprint]
        if graph_util.is_isomorphic((adjacency.tolist(), ops), reference):
          continue
        logging.fatal('Matrix:\n%s\nLabel: %s\nis not isomorphic to'
                      ' canonical matrix:\n%s\nLabel: %s',
                      str(adjacency), str(ops),
                      str(reference[0]),
                      str(reference[1]))
        sys.exit()

    logging.info('Up to %d vertices: %d graphs (%d without hashing)',
                 num_vertices, len(representatives), labeled_graph_count)

  with tf.gfile.Open(FLAGS.output_file, 'w') as output:
    json.dump(representatives, output, sort_keys=True)


if __name__ == '__main__':
  app.run(main)
141 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NASBench: A Neural Architecture Search Dataset and Benchmark
2 |
3 | This repository contains the code used for generating and interacting with the
4 | NASBench dataset. The dataset contains **423,624 unique neural networks**
5 | exhaustively generated and evaluated from a fixed graph-based search space.
6 |
7 | Each network is trained and evaluated multiple times on CIFAR-10 at various
8 | training budgets and we present the metrics in a queryable API. The current
9 | release contains over **5 million** trained and evaluated models.
10 |
11 | Our paper can be found at:
12 |
13 | [NAS-Bench-101: Towards Reproducible Neural Architecture
14 | Search](https://arxiv.org/abs/1902.09635)
15 |
16 | If you use this dataset, please cite:
17 |
18 | ```
19 | @InProceedings{pmlr-v97-ying19a,
20 | title = {{NAS}-Bench-101: Towards Reproducible Neural Architecture Search},
21 | author = {Ying, Chris and Klein, Aaron and Christiansen, Eric and Real, Esteban and Murphy, Kevin and Hutter, Frank},
22 | booktitle = {Proceedings of the 36th International Conference on Machine Learning},
23 | pages = {7105--7114},
24 | year = {2019},
25 | editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan},
26 | volume = {97},
27 | series = {Proceedings of Machine Learning Research},
28 | address = {Long Beach, California, USA},
29 | month = {09--15 Jun},
30 | publisher = {PMLR},
31 | url = {http://proceedings.mlr.press/v97/ying19a.html}}
32 | ```
33 |
34 | ## Dataset overview
35 |
36 | NASBench is a tabular dataset which maps convolutional neural network
37 | architectures to their trained and evaluated performance on CIFAR-10.
38 | Specifically, all networks share the same network "skeleton", which can be seen
39 | in Figure (a) below. What changes between different models is the "module", which is a
40 | collection of neural network operations linked in an arbitrary graph-like
41 | structure.
42 |
43 | Modules are represented by directed acyclic graphs with up to 7 vertices and 9
44 | edges. The valid operations at each vertex are "3x3 convolution", "1x1
45 | convolution", and "3x3 max-pooling". Figure (b) below shows an Inception-like
46 | cell within the dataset. Figure (c) shows a high-level overview of how the
47 | interior filter counts of each module are computed.
48 |
49 |
50 |
51 | There are exactly 423,624 computationally unique modules within this search
52 | space and each one has been trained for 4, 12, 36, and 108 epochs three times
53 | each (423K * 3 * 4 = ~5M total trained models). We report the following metrics:
54 |
55 | * training accuracy
56 | * validation accuracy
57 | * testing accuracy
58 | * number of parameters
59 | * training time
60 |
61 | The scatterplot below shows a comparison of number of parameters, training time,
62 | and mean validation accuracy of models trained for 108 epochs in the dataset.
63 |
64 |
65 |
66 | See our paper for more detailed information about the design of this search
67 | space, further implementation details, and more in-depth analysis.
68 |
69 | ## Colab
70 |
71 | You can directly use this dataset from Google Colaboratory without needing to
72 | install anything on your local machine. Click "Open in Colab" below:
73 |
74 | [Open in Colab](https://colab.research.google.com/github/google-research/nasbench/blob/master/NASBench.ipynb)
75 |
76 | ## Setup
77 |
78 | 1. Clone this repo.
79 |
80 | ```
81 | git clone https://github.com/google-research/nasbench
82 | cd nasbench
83 | ```
84 |
85 | 2. (optional) Create a virtualenv for this library.
86 |
87 | ```
88 | virtualenv venv
89 | source venv/bin/activate
90 | ```
91 |
92 | 3. Install the project along with dependencies.
93 |
94 | ```
95 | pip install -e .
96 | ```
97 |
98 | **Note:** the only required dependency is TensorFlow. The above instructions
99 | will install the CPU version of TensorFlow to the virtualenv. For other install
100 | options, see https://www.tensorflow.org/install/.
101 |
102 | ## Download the dataset
103 |
104 | The full dataset (which includes all 5M data points at all 4 epoch lengths):
105 |
106 | https://storage.googleapis.com/nasbench/nasbench_full.tfrecord
107 |
108 | Size: ~1.95 GB, SHA256: `3d64db8180fb1b0207212f9032205064312b6907a3bbc81eabea10db2f5c7e9c`
109 |
110 | ---
111 |
112 | Subset of the dataset with only models trained at 108 epochs:
113 |
114 | https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord
115 |
116 | Size: ~499 MB, SHA256: `4c39c3936e36a85269881d659e44e61a245babcb72cb374eacacf75d0e5f4fd1`
117 |
118 |
119 | ## Using the dataset
120 |
121 | Example usage (see `example.py` for a full runnable example):
122 |
123 | ```python
124 | # Load the data from file (this will take some time)
125 | nasbench = api.NASBench('/path/to/nasbench.tfrecord')
126 |
127 | # Create an Inception-like module (5x5 convolution replaced with two 3x3
128 | # convolutions).
129 | model_spec = api.ModelSpec(
130 | # Adjacency matrix of the module
131 | matrix=[[0, 1, 1, 1, 0, 1, 0], # input layer
132 | [0, 0, 0, 0, 0, 0, 1], # 1x1 conv
133 | [0, 0, 0, 0, 0, 0, 1], # 3x3 conv
134 | [0, 0, 0, 0, 1, 0, 0], # 5x5 conv (replaced by two 3x3's)
135 | [0, 0, 0, 0, 0, 0, 1], # 5x5 conv (replaced by two 3x3's)
136 | [0, 0, 0, 0, 0, 0, 1], # 3x3 max-pool
137 | [0, 0, 0, 0, 0, 0, 0]], # output layer
138 | # Operations at the vertices of the module, matches order of matrix
139 | ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT])
140 |
141 | # Query this model from dataset, returns a dictionary containing the metrics
142 | # associated with this model.
143 | data = nasbench.query(model_spec)
144 | ```
145 |
146 | See `nasbench/api.py` for more information, including the constraints on valid
147 | module matrices and operations.
148 |
149 | **Note**: it is not required to use `nasbench/api.py` to work with this dataset,
150 | you can see how to parse the dataset files from the initializer inside
151 | `nasbench/api.py` and then interact with the data however you'd like.
152 |
153 | ## How the dataset was generated
154 |
155 | The dataset generation code is provided for reference, but the dataset has
156 | already been fully generated.
157 |
158 | The list of unique computation graphs evaluated in this dataset was generated
159 | via `nasbench/scripts/generate_graphs.py`. Each of these graphs was evaluated
160 | multiple times via `nasbench/scripts/run_evaluation.py`.
161 |
162 | ## How to run the unit tests
163 |
164 | Unit tests are included for some of the algorithmically complex parts of the
165 | code. The tests can be run directly via Python. Example:
166 |
167 | ```
168 | python nasbench/tests/model_builder_test.py
169 | ```
170 |
171 | ## Disclaimer
172 |
173 | This is not an official Google product.
174 |
--------------------------------------------------------------------------------
/nasbench/lib/model_metrics_pb2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # pylint: skip-file
16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # source: model_metrics.proto
18 |
19 | import sys
# Python 2/3 compatibility shim emitted by protoc: on Python 2 `_b` is the
# identity; on Python 3 it encodes a str with latin-1, so the byte-escaped
# descriptor literal passed to it below becomes a `bytes` object.
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
21 | from google.protobuf import descriptor as _descriptor
22 | from google.protobuf import message as _message
23 | from google.protobuf import reflection as _reflection
24 | from google.protobuf import symbol_database as _symbol_database
25 | # @@protoc_insertion_point(imports)
26 |
# Default protobuf symbol database. The file descriptor and both message
# classes of this module are registered in it further down.
_sym_db = _symbol_database.Default()




# File descriptor for model_metrics.proto (package `nasbench`, proto2 syntax).
# `serialized_pb` is the serialized FileDescriptorProto produced by protoc;
# regenerate it from model_metrics.proto instead of editing it by hand.
DESCRIPTOR = _descriptor.FileDescriptor(
  name='model_metrics.proto',
  package='nasbench',
  syntax='proto2',
  serialized_options=None,
  serialized_pb=_b('\n\x13model_metrics.proto\x12\x08nasbench\"s\n\x0cModelMetrics\x12\x31\n\x0f\x65valuation_data\x18\x01 \x03(\x0b\x32\x18.nasbench.EvaluationData\x12\x1c\n\x14trainable_parameters\x18\x02 \x01(\x05\x12\x12\n\ntotal_time\x18\x03 \x01(\x01\"\xa3\x01\n\x0e\x45valuationData\x12\x15\n\rcurrent_epoch\x18\x01 \x01(\x01\x12\x15\n\rtraining_time\x18\x02 \x01(\x01\x12\x16\n\x0etrain_accuracy\x18\x03 \x01(\x01\x12\x1b\n\x13validation_accuracy\x18\x04 \x01(\x01\x12\x15\n\rtest_accuracy\x18\x05 \x01(\x01\x12\x17\n\x0f\x63heckpoint_path\x18\x06 \x01(\t')
)
39 |
40 |
41 |
42 |
# Descriptor for message nasbench.ModelMetrics (generated by protoc).
_MODELMETRICS = _descriptor.Descriptor(
  name='ModelMetrics',
  full_name='nasbench.ModelMetrics',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    # repeated EvaluationData evaluation_data = 1;
    # (type=11 TYPE_MESSAGE, label=3 LABEL_REPEATED). message_type is None
    # here and is patched to _EVALUATIONDATA later in this module.
    _descriptor.FieldDescriptor(
      name='evaluation_data', full_name='nasbench.ModelMetrics.evaluation_data', index=0,
      number=1, type=11, cpp_type=10, label=3,
      has_default_value=False, default_value=[],
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional int32 trainable_parameters = 2; (type=5 TYPE_INT32, label=1)
    _descriptor.FieldDescriptor(
      name='trainable_parameters', full_name='nasbench.ModelMetrics.trainable_parameters', index=1,
      number=2, type=5, cpp_type=1, label=1,
      has_default_value=False, default_value=0,
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional double total_time = 3; (type=1 TYPE_DOUBLE, label=1)
    _descriptor.FieldDescriptor(
      name='total_time', full_name='nasbench.ModelMetrics.total_time', index=2,
      number=3, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  # Byte range of this message's DescriptorProto inside DESCRIPTOR's
  # serialized_pb.
  serialized_start=33,
  serialized_end=148,
)
86 |
87 |
# Descriptor for message nasbench.EvaluationData (generated by protoc).
# All fields are proto2 `optional` (label=1).
_EVALUATIONDATA = _descriptor.Descriptor(
  name='EvaluationData',
  full_name='nasbench.EvaluationData',
  filename=None,
  file=DESCRIPTOR,
  containing_type=None,
  fields=[
    # optional double current_epoch = 1; (type=1 TYPE_DOUBLE)
    _descriptor.FieldDescriptor(
      name='current_epoch', full_name='nasbench.EvaluationData.current_epoch', index=0,
      number=1, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional double training_time = 2;
    _descriptor.FieldDescriptor(
      name='training_time', full_name='nasbench.EvaluationData.training_time', index=1,
      number=2, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional double train_accuracy = 3;
    _descriptor.FieldDescriptor(
      name='train_accuracy', full_name='nasbench.EvaluationData.train_accuracy', index=2,
      number=3, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional double validation_accuracy = 4;
    _descriptor.FieldDescriptor(
      name='validation_accuracy', full_name='nasbench.EvaluationData.validation_accuracy', index=3,
      number=4, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional double test_accuracy = 5;
    _descriptor.FieldDescriptor(
      name='test_accuracy', full_name='nasbench.EvaluationData.test_accuracy', index=4,
      number=5, type=1, cpp_type=5, label=1,
      has_default_value=False, default_value=float(0),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
    # optional string checkpoint_path = 6; (type=9 TYPE_STRING)
    _descriptor.FieldDescriptor(
      name='checkpoint_path', full_name='nasbench.EvaluationData.checkpoint_path', index=5,
      number=6, type=9, cpp_type=9, label=1,
      has_default_value=False, default_value=_b("").decode('utf-8'),
      message_type=None, enum_type=None, containing_type=None,
      is_extension=False, extension_scope=None,
      serialized_options=None, file=DESCRIPTOR),
  ],
  extensions=[
  ],
  nested_types=[],
  enum_types=[
  ],
  serialized_options=None,
  is_extendable=False,
  syntax='proto2',
  extension_ranges=[],
  oneofs=[
  ],
  # Byte range of this message's DescriptorProto inside DESCRIPTOR's
  # serialized_pb.
  serialized_start=151,
  serialized_end=314,
)
152 |
# Resolve the forward reference: ModelMetrics.evaluation_data holds
# EvaluationData messages, whose descriptor is defined after _MODELMETRICS.
_MODELMETRICS.fields_by_name['evaluation_data'].message_type = _EVALUATIONDATA
DESCRIPTOR.message_types_by_name['ModelMetrics'] = _MODELMETRICS
DESCRIPTOR.message_types_by_name['EvaluationData'] = _EVALUATIONDATA
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

# Concrete message classes, built at import time from the descriptors above
# and registered with the default symbol database.
ModelMetrics = _reflection.GeneratedProtocolMessageType('ModelMetrics', (_message.Message,), dict(
  DESCRIPTOR = _MODELMETRICS,
  __module__ = 'model_metrics_pb2'
  # @@protoc_insertion_point(class_scope:nasbench.ModelMetrics)
  ))
_sym_db.RegisterMessage(ModelMetrics)

EvaluationData = _reflection.GeneratedProtocolMessageType('EvaluationData', (_message.Message,), dict(
  DESCRIPTOR = _EVALUATIONDATA,
  __module__ = 'model_metrics_pb2'
  # @@protoc_insertion_point(class_scope:nasbench.EvaluationData)
  ))
_sym_db.RegisterMessage(EvaluationData)
171 |
172 |
173 | # @@protoc_insertion_point(module_scope)
174 |
--------------------------------------------------------------------------------
/nasbench/lib/training_time.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tools to measure and limit the training time of a TF model."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import collections
22 | import tensorflow as tf
23 |
# Scope under which all timing-related ops and variables are placed.
_SCOPE_NAME = 'timing'

# Names of the individual timing variables (see _get_or_create_timing_vars).
_START_VAR = 'start_timestamp'
_STEPS_VAR = 'steps'
_PREV_VAR = 'previous_time'
_TOTAL_VAR = 'total_time'

# Full name of the TF variable holding the total training time so far, e.g.
# for estimator.get_variable_value(TOTAL_TIME_NAME) after estimator.train().
# The value also includes time spent in earlier calls to train().
TOTAL_TIME_NAME = '/'.join((_SCOPE_NAME, _TOTAL_VAR))

# Internal timestamps are int64 counts of milliseconds (seconds * 1000). We use
# fixed precision because tf.float32 lacks precision for large values such as
# seconds since the epoch, and tf.float64 is not supported on TPU.
_INTERNAL_TIME_PRECISION = 1000
44 |
45 |
def _seconds_to_internal_time(seconds):
  """Rounds a seconds tensor to int64 fixed-precision (millisecond) time."""
  scaled = tf.round(seconds * _INTERNAL_TIME_PRECISION)
  return tf.to_int64(scaled)
49 |
50 |
def _internal_time_to_seconds(internal_time):
  """Converts fixed-precision (millisecond) time back to float32 seconds."""
  seconds = internal_time / _INTERNAL_TIME_PRECISION
  return tf.to_float(seconds)
54 |
55 |
# Result of limit(); both members are meant to be handed to estimator.train():
#   train_hook: a SessionRunHook, passed through the `hooks` argument.
#   saving_listener: a CheckpointSaverListener, passed through the
#     `saving_listeners` argument if and only if checkpoints are being saved.
Timing = collections.namedtuple(  # pylint: disable=g-bad-name
    'Timing', 'train_hook saving_listener')
68 |
69 |
def limit(max_train_secs=None):
  """Builds the hook and listener that measure (and optionally cap) train time.

  Time is measured directly from the timestamps of training steps, and is
  intended to exclude time spent saving checkpoints or lost to pre-emptions.

  Args:
    max_train_secs: training time limit in seconds. Training may overshoot it
      by up to one step. With None (or any falsy value, since the hook tests
      truthiness) training is not time-limited, but the timing variables are
      still created and updated.

  Returns:
    A Timing named tuple holding the train hook and the saving listener.
  """
  return Timing(
      train_hook=_TimingRunHook(max_train_secs),
      saving_listener=_TimingSaverListener())
87 |
88 |
def get_total_time():
  """Returns the `total_time` timing variable, creating it if necessary.

  The variable is obtained through `_get_or_create_timing_vars`, which uses
  `reuse=tf.AUTO_REUSE`. Calling this before the timing hooks have run
  therefore creates the timing variables instead of failing, and no error is
  raised for a missing variable.

  NOTE(review): the variables are created under the `timing` sub-scope of the
  *current* variable scope (see _get_or_create_timing_vars). Call this outside
  any `tf.variable_scope` to get `timing/total_time`. TODO confirm.

  Returns:
    The scalar float32 TF Variable holding the accumulated training time in
    seconds.
  """
  return _get_or_create_timing_vars().total_time
103 |
104 |
_TimingVars = collections.namedtuple(  # pylint: disable=g-bad-name
    '_TimingVars',
    [
        # int64 TF variable holding the internal fixed-precision timestamp
        # (milliseconds since the epoch, see _seconds_to_internal_time) of the
        # first training step after the last checkpoint save, or of the first
        # training step ever if no save has happened yet. Its initial value is
        # -1. It is re-stamped by the next step whenever `steps` is 0; `steps`,
        # not this value, is what marks a reset.
        'start_timestamp',

        # int64 TF variable counting the steps run since the last checkpoint
        # save (or since the beginning of training if no save has happened
        # yet). It is set back to 0 after each save.
        'steps',

        # float32 TF variable holding the training time, in seconds, up to
        # the last checkpoint save.
        'previous_time',

        # float32 TF variable holding the total training time, in seconds, up
        # to the last step run. It is intended to exclude gaps resulting from
        # checkpoint saving or pre-emptions.
        'total_time',
    ])
128 |
129 |
class _TimingRunHook(tf.train.SessionRunHook):
  """Hook to stop the training after a certain amount of time.

  On every training step it re-stamps `start_timestamp` when `steps` is 0,
  increments `steps`, and sets `total_time` to `previous_time` plus the
  seconds elapsed since `start_timestamp`. If a limit was given, it requests
  a stop once `total_time` exceeds it.
  """

  def __init__(self, max_train_secs=None):
    """Initializes the instance.

    Args:
      max_train_secs: the maximum number of seconds to train for. If None,
        training will not be limited by time. The check in `after_run` is a
        truthiness test, so 0 also disables the limit.
    """
    self._max_train_secs = max_train_secs

  def begin(self):
    """Builds the per-step timing ops (called once before the session is used)."""
    with tf.name_scope(_SCOPE_NAME):
      # See _get_or_create_timing_vars for the definitions of these variables.
      timing_vars = _get_or_create_timing_vars()

      # An op to produce a tensor with the latest timestamp, converted from
      # tf.timestamp() seconds to int64 internal (millisecond) time.
      self._end_op = _seconds_to_internal_time(tf.timestamp(name='end'))

      # An op to update the timing_vars.start_timestamp variable: the first
      # step after a reset (steps == 0) records the current time as the start;
      # otherwise the stored start is returned unchanged.
      self._start_op = tf.cond(
          pred=tf.equal(timing_vars.steps, 0),
          true_fn=lambda: timing_vars.start_timestamp.assign(self._end_op),
          false_fn=lambda: timing_vars.start_timestamp)

      # An op to update the step. The control dependency makes the increment
      # run only after the cond above (and hence its `steps == 0` check).
      with tf.control_dependencies([self._start_op]):
        self._step_op = timing_vars.steps.assign_add(1)

      # An op to compute the timing_vars.total_time variable: time accumulated
      # up to the last save plus the elapsed seconds since `start_timestamp`.
      self._total_op = timing_vars.total_time.assign(
          timing_vars.previous_time +
          _internal_time_to_seconds(self._end_op - self._start_op))

  def before_run(self, run_context):
    # Piggy-back the timing updates on every training session.run call.
    return tf.train.SessionRunArgs([self._total_op, self._step_op])

  def after_run(self, run_context, run_values):
    # Results arrive in the order requested in before_run.
    total_time, _ = run_values.results
    if self._max_train_secs and total_time > self._max_train_secs:
      run_context.request_stop()
172 |
173 |
class _TimingSaverListener(tf.train.CheckpointSaverListener):
  """Checkpoint listener that records the train time reached at each save.

  Before a save, the running `total_time` is copied into `previous_time`.
  After the save, `steps` is zeroed so that the next training step records a
  fresh `start_timestamp`; the time spent saving is therefore not counted.
  """

  def begin(self):
    with tf.name_scope(_SCOPE_NAME):
      tvars = _get_or_create_timing_vars()
      # previous_time <- total_time: the training time reached at this save.
      self._prev_op = tvars.previous_time.assign(tvars.total_time)
      # steps <- 0: tells the run hook to re-stamp start_timestamp next step.
      self._reset_steps_op = tvars.steps.assign(0)

  def before_save(self, session, global_step_value):
    del global_step_value  # Unused.
    session.run(self._prev_op)

  def after_save(self, session, global_step_value):
    del global_step_value  # Unused.
    session.run(self._reset_steps_op)
192 |
193 |
def _get_or_create_timing_vars():
  """Gets, or creates on first use, the four scalar timing variables.

  All of them are non-trainable and are looked up with reuse=tf.AUTO_REUSE
  under a `timing` sub-scope, so repeated calls return the same variables.

  NOTE(review): the scope re-entered here is tf.get_variable_scope(), i.e. the
  *current* variable scope. The variables therefore end up at the graph root
  only when this is called outside any tf.variable_scope. TODO confirm that
  callers rely on that.

  Returns:
    A _TimingVars named tuple.
  """
  # (name, dtype, initial value) for each variable, in _TimingVars field order.
  specs = (
      (_START_VAR, tf.int64, -1),
      (_STEPS_VAR, tf.int64, 0),
      (_PREV_VAR, tf.float32, 0.0),
      (_TOTAL_VAR, tf.float32, 0.0),
  )
  current_scope = tf.get_variable_scope()
  with tf.variable_scope(current_scope, reuse=tf.AUTO_REUSE):
    with tf.variable_scope(_SCOPE_NAME, reuse=tf.AUTO_REUSE):
      created = [
          tf.get_variable(
              var_name,
              shape=[],
              dtype=var_dtype,
              initializer=tf.constant_initializer(initial_value),
              trainable=False)
          for var_name, var_dtype, initial_value in specs
      ]
  start_timestamp, steps, previous_time, total_time = created
  return _TimingVars(
      start_timestamp=start_timestamp,
      steps=steps,
      previous_time=previous_time,
      total_time=total_time)
234 |
--------------------------------------------------------------------------------
/nasbench/scripts/run_evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Script for training a large number of networks across multiple workers.
16 |
17 | Every process running this script is assigned a monotonically increasing
18 | worker_id, from 0 up to total_workers (exclusive). Each full
19 | training-and-evaluation of a module counts as a single "work unit" with repeated
20 | runs being different units. Work units are assigned a monotonically increasing
21 | index and each worker only computes the work units that have the same index
22 | modulo total_workers as the worker_id.
23 |
24 | For example, for 3 models, each with 3 repeats, and 4 total workers:
25 | Model Number || 1 2 3 1 2 3 1 2 3
26 | Repeat Number || 1 1 1 2 2 2 3 3 3
27 | Work Unit Index || 0 1 2 3 4 5 6 7 8
28 | Assigned Worker || 0 1 2 3 0 1 2 3 0
29 |
30 | i.e. worker_id 0 will compute [model1-repeat1, model2-repeat2, model3-repeat3],
31 | worker_id 1 will compute [model2-repeat1, model3-repeat2], etc...
32 |
33 | --worker_id_offset is provided to allow launching workers in multiple flocks and
34 | is added to the --worker_id flag which is assumed to start at 0 for each new
35 | flock. --total_workers should be the total number of workers across all flocks.
36 |
37 | For basic failure recovery, each worker stores a text file with the current work
38 | unit index it is computing. Upon restarting, workers will resume at the
39 | beginning of the work unit index inside the recovery file if it exists.
40 | """
41 |
42 | from __future__ import absolute_import
43 | from __future__ import division
44 | from __future__ import print_function
45 |
46 | import json
47 | import os
48 | import re
49 |
50 | from absl import app
51 | from absl import flags
52 | from nasbench.lib import config as _config
53 | from nasbench.lib import evaluate
54 | from nasbench.lib import model_metrics_pb2
55 | from nasbench.lib import model_spec
56 | import numpy as np
57 | import tensorflow as tf
58 |
59 |
# Command-line flags controlling which work units this worker evaluates.
flags.DEFINE_string(
    name='models_file',
    default='',
    help='JSON file containing models.')
flags.DEFINE_string(
    name='remainders_file',
    default='',
    help=('JSON file containing list of remainders as tuples of (module hash, '
          'repeat num). If provided, only the runs in the list will be '
          'evaluated, otherwise, all models inside models_file will be '
          'evaluated.'))
flags.DEFINE_string(
    name='model_id_regex',
    default='^',
    help=('Regex of models to train. Model IDs are MD5 hashes which match '
          '([a-f0-9]{32}).'))
flags.DEFINE_string(
    name='output_dir',
    default='',
    help='Base output directory.')
flags.DEFINE_integer(
    name='worker_id',
    default=0,
    help='Worker ID within this flock, starting at 0.')
flags.DEFINE_integer(
    name='worker_id_offset',
    default=0,
    help='Worker ID offset added.')
flags.DEFINE_integer(
    name='total_workers',
    default=1,
    help='Total number of workers, across all flocks.')
FLAGS = flags.FLAGS
78 |
79 | CHECKPOINT_PREFIX = 'model.ckpt'
80 | RESULTS_FILE = 'results.json'
81 | # Checkpoint 1 is a side-effect of pre-initializing the model weights and can be
82 | # deleted during the clean-up step.
83 | CHECKPOINT_1_PREFIX = 'model.ckpt-1.'
84 |
85 |
class NumpyEncoder(json.JSONEncoder):
  """Converts numpy objects to JSON-serializable format.

  Intended for use as `json.dump(obj, f, cls=NumpyEncoder)`. numpy arrays
  become (nested) Python lists and numpy scalars become the closest native
  Python type. Any other unsupported object is passed to the base encoder,
  which raises TypeError.
  """

  def default(self, obj):
    if isinstance(obj, np.ndarray):
      # Matrices converted to nested lists.
      return obj.tolist()
    if isinstance(obj, np.generic):
      # Scalars converted to the closest Python type. `np.asscalar` was
      # deprecated in NumPy 1.16 and removed in 1.23; `.item()` is its
      # documented replacement.
      return obj.item()
    return json.JSONEncoder.default(self, obj)
97 |
98 |
class Evaluator(object):
  """Manages evaluating a subset of the total models.

  Work units are indexed from 0 to total_work_units - 1. This worker handles
  every index congruent to worker_id modulo total_workers. Progress is stored
  in a per-worker recovery file under `<output_dir>/_recovery/` so that a
  restarted worker resumes at the work unit it was processing.
  """

  def __init__(self,
               models_file,
               output_dir,
               worker_id=0,
               total_workers=1,
               model_id_regex='^',
               remainders_file=None):
    """Loads the models and determines where this worker should (re)start.

    Args:
      models_file: path to a JSON file mapping model ID -> (matrix, labels).
      output_dir: base output directory for results and recovery files.
      worker_id: index of this worker, in [0, total_workers).
      total_workers: total number of workers sharing the work units.
      model_id_regex: only model IDs matching this regex (via `re.match`) are
        evaluated. Not used when a remainders file is in effect.
      remainders_file: optional path to a JSON list of (model ID, repeat num)
        pairs. When non-empty, only those runs are evaluated. If None, the
        --remainders_file flag is used.

    Raises:
      ValueError: if total_workers < 1, if worker_id is outside
        [0, total_workers), or if the recovery file holds an index that is not
        assigned to this worker.
    """
    # Validate before touching the filesystem; a zero worker count would
    # otherwise surface later as a ZeroDivisionError.
    if total_workers < 1:
      raise ValueError('total_workers must be >= 1, got %d' % total_workers)
    if not 0 <= worker_id < total_workers:
      raise ValueError('worker_id must be in [0, %d), got %d' %
                       (total_workers, worker_id))

    self.config = _config.build_config()
    with tf.gfile.Open(models_file) as f:
      self.models = json.load(f)

    self.remainders = None
    self.ordered_keys = None

    if remainders_file is None:
      remainders_file = FLAGS.remainders_file

    if remainders_file:
      # Run only the modules and repeat numbers specified. Sort them so all
      # workers see a canonical ordering.
      with tf.gfile.Open(remainders_file) as f:
        self.remainders = sorted(json.load(f))
      self.num_models = len(self.remainders)
      self.total_work_units = self.num_models
    else:
      # Filter keys to only those that fit the regex and order them so all
      # workers see a canonical ordering.
      regex = re.compile(model_id_regex)
      self.ordered_keys = sorted(
          key for key in self.models.keys() if regex.match(key))
      self.num_models = len(self.ordered_keys)
      self.total_work_units = self.num_models * self.config['num_repeats']

    self.total_workers = total_workers
    self.output_dir = output_dir

    # If the worker is recovering from a restart, figure out where to restart.
    worker_recovery_dir = os.path.join(output_dir, '_recovery')
    tf.gfile.MakeDirs(worker_recovery_dir)  # Silently succeeds if exists.
    self.recovery_file = os.path.join(worker_recovery_dir, str(worker_id))
    if tf.gfile.Exists(self.recovery_file):
      with tf.gfile.Open(self.recovery_file) as f:
        self.current_index = int(f.read())
    else:
      self.current_index = worker_id
      self._write_recovery_file()

    if self.current_index % self.total_workers != worker_id:
      raise ValueError(
          'Recovery file %s holds index %d, which is not assigned to worker %d '
          'of %d.' % (self.recovery_file, self.current_index, worker_id,
                      self.total_workers))

  def run_evaluation(self):
    """Runs the worker evaluation loop.

    After each work unit, including one that aborted, the index is advanced
    and written to the recovery file. A restart therefore resumes at the first
    unit not yet attempted.
    """
    while self.current_index < self.total_work_units:
      # Perform the expensive evaluation of the model at the current index.
      self._evaluate_work_unit(self.current_index)

      self.current_index += self.total_workers
      self._write_recovery_file()

  def _write_recovery_file(self):
    """Persists self.current_index to this worker's recovery file."""
    with tf.gfile.Open(self.recovery_file, 'w') as f:
      f.write(str(self.current_index))

  def _evaluate_work_unit(self, index):
    """Runs the evaluation of the model at the specified index.

    Each worker only computes the work units whose index modulo total_workers
    equals its worker_id. Without a remainders file, indices enumerate every
    model for repeat 1, then every model for repeat 2, and so on.

    Args:
      index: int index into total work units.
    """
    if self.remainders is not None:
      assert self.ordered_keys is None
      model_id = self.remainders[index][0]
      model_repeat = self.remainders[index][1]
    else:
      model_id = self.ordered_keys[index % self.num_models]
      model_repeat = index // self.num_models + 1

    matrix, labels = self.models[model_id]
    matrix = np.array(matrix)

    # Re-label the intermediate ops to config['available_ops']. The first and
    # last vertices are always relabeled as the input and output.
    labels = (['input'] +
              [self.config['available_ops'][lab] for lab in labels[1:-1]] +
              ['output'])
    spec = model_spec.ModelSpec(matrix, labels)
    assert spec.valid_spec
    assert np.sum(spec.matrix) <= self.config['max_edges']

    # Split the directory into 16^2 roughly equal subdirectories, keyed on the
    # first two characters of the model ID.
    model_dir = os.path.join(self.output_dir,
                             model_id[:2],
                             model_id,
                             'repeat_%d' % model_repeat)
    try:
      meta = evaluate.train_and_evaluate(spec, self.config, model_dir)
    except evaluate.AbortError:
      # After hitting the retry limit, the job will continue to the next work
      # unit. These failed jobs may need to be re-run at a later point.
      return

    # Write data to model_dir.
    output_file = os.path.join(model_dir, RESULTS_FILE)
    with tf.gfile.Open(output_file, 'w') as f:
      json.dump(meta, f, cls=NumpyEncoder)

    # Delete some files to reclaim space.
    self._clean_model_dir(model_dir)

  def _clean_model_dir(self, model_dir):
    """Cleans the output model directory to reclaim disk space.

    Keeps entries whose names start with CHECKPOINT_PREFIX or RESULTS_FILE,
    except the step-1 checkpoint (CHECKPOINT_1_PREFIX). Everything else,
    including subdirectories, is deleted.
    """
    saved_prefixes = (CHECKPOINT_PREFIX, RESULTS_FILE)
    for filename in tf.gfile.ListDirectory(model_dir):
      if (filename.startswith(saved_prefixes) and
          not filename.startswith(CHECKPOINT_1_PREFIX)):
        continue
      full_filename = os.path.join(model_dir, filename)
      if tf.gfile.IsDirectory(full_filename):
        tf.gfile.DeleteRecursively(full_filename)
      else:
        tf.gfile.Remove(full_filename)
226 |
def main(args):
  """Script entry point: evaluates this worker's share of the models.

  The global worker ID is --worker_id shifted by --worker_id_offset, which
  allows workers to be launched in several flocks.
  """
  del args  # Unused
  global_worker_id = FLAGS.worker_id + FLAGS.worker_id_offset
  Evaluator(
      models_file=FLAGS.models_file,
      output_dir=FLAGS.output_dir,
      worker_id=global_worker_id,
      total_workers=FLAGS.total_workers,
      model_id_regex=FLAGS.model_id_regex,
  ).run_evaluation()


if __name__ == '__main__':
  app.run(main)
241 |
--------------------------------------------------------------------------------
/nasbench/tests/run_evaluation_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Unit tests for scripts/run_evaluation.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import json
22 | import os
23 | import tempfile
24 |
25 | from absl.testing import flagsaver
26 | from nasbench.scripts import run_evaluation
27 | import tensorflow as tf
28 |
29 |
30 |
class RunEvaluationTest(tf.test.TestCase):
  """Tests run_evaluation.Evaluator with the training step mocked out."""

  def setUp(self):
    """Set up files and directories that are expected by run_evaluation."""
    super(RunEvaluationTest, self).setUp()
    # Create temp directory for output files; it is removed in tearDown.
    self.output_dir = tempfile.mkdtemp()
    self.models_file = os.path.join(self.output_dir, 'models_file.json')

    self.toy_data = {
        'abc': ([[0, 1, 1], [0, 0, 1], [0, 0, 0]], [-1, 0, -2]),
        'abd': ([[0, 1, 0], [0, 0, 1], [0, 0, 0]], [-1, 0, -2]),
        'abe': ([[0, 0, 1], [0, 0, 0], [0, 0, 0]], [-1, 0, -2]),
    }

    with tf.gfile.Open(self.models_file, 'w') as f:
      json.dump(self.toy_data, f)

    # Create files & directories which are normally created by
    # evaluate.train_and_evaluate but have been mocked out.
    for model_id in self.toy_data:
      eval_dir = os.path.join(self.output_dir, 'ab', model_id, 'repeat_1')
      tf.gfile.MakeDirs(eval_dir)
    run_evaluation.FLAGS.train_data_files = 'unused'
    run_evaluation.FLAGS.valid_data_file = 'unused'
    run_evaluation.FLAGS.test_data_file = 'unused'
    run_evaluation.FLAGS.num_repeats = 1

  def tearDown(self):
    """Removes the temporary output directory created in setUp."""
    tf.gfile.DeleteRecursively(self.output_dir)
    super(RunEvaluationTest, self).tearDown()

  def _expected_call(self, model_id, repeat=1):
    """Returns the train_and_evaluate call expected for one work unit."""
    return tf.test.mock.call(
        tf.test.mock.ANY, tf.test.mock.ANY,
        os.path.join(self.output_dir, 'ab', model_id, 'repeat_%d' % repeat))

  def _assert_results_written(self, model_ids, repeat=1):
    """Asserts results.json exists for each model ID at the given repeat."""
    for model_id in model_ids:
      self.assertTrue(tf.gfile.Exists(os.path.join(
          self.output_dir, 'ab', model_id, 'repeat_%d' % repeat,
          'results.json')))

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_evaluate_single_worker(self, mock_eval):
    """Tests single worker code path."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'
    evaluator = run_evaluation.Evaluator(
        self.models_file, self.output_dir)
    evaluator.run_evaluation()

    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abc'),
        self._expected_call('abd'),
        self._expected_call('abe')])
    self._assert_results_written(self.toy_data)

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_evaluate_multi_worker_0(self, mock_eval):
    """Tests multi worker code path for worker 0."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'
    evaluator = run_evaluation.Evaluator(
        self.models_file, self.output_dir, worker_id=0, total_workers=2)
    evaluator.run_evaluation()

    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abc'),
        self._expected_call('abe')])
    self._assert_results_written(['abc', 'abe'])

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_evaluate_multi_worker_1(self, mock_eval):
    """Tests multi worker code path for worker 1."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'
    evaluator = run_evaluation.Evaluator(
        self.models_file, self.output_dir, worker_id=1, total_workers=2)
    evaluator.run_evaluation()

    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abd')])
    self._assert_results_written(['abd'])

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_evaluate_regex(self, mock_eval):
    """Tests regex filters models."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'
    evaluator = run_evaluation.Evaluator(
        self.models_file, self.output_dir, model_id_regex='^ab(d|e)')
    evaluator.run_evaluation()

    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abd'),
        self._expected_call('abe')])
    self._assert_results_written(['abd', 'abe'])

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_evaluate_repeat(self, mock_eval):
    """Tests evaluate with repeats."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'

    # Create extra directories not created in setUp for repeat_2.
    for model_id in self.toy_data:
      eval_dir = os.path.join(self.output_dir, 'ab', model_id, 'repeat_2')
      tf.gfile.MakeDirs(eval_dir)

    with flagsaver.flagsaver(num_repeats=2):
      evaluator = run_evaluation.Evaluator(
          self.models_file, self.output_dir)
      evaluator.run_evaluation()

    # All models are evaluated for repeat 1 before any model for repeat 2.
    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abc', repeat=1),
        self._expected_call('abd', repeat=1),
        self._expected_call('abe', repeat=1),
        self._expected_call('abc', repeat=2),
        self._expected_call('abd', repeat=2),
        self._expected_call('abe', repeat=2)])

    for repeat in (1, 2):
      self._assert_results_written(self.toy_data, repeat=repeat)

  def test_clean_model_dir(self):
    """Tests clean-up of model directory keeps only intended files."""
    model_dir = os.path.join(self.output_dir, 'ab', 'abcde', 'repeat_1')
    tf.gfile.MakeDirs(model_dir)

    # Write files which will be preserved.
    preserved_files = ['model.ckpt-0.index',
                       'model.ckpt-100.index',
                       'results.json']
    for filename in preserved_files:
      with tf.gfile.Open(os.path.join(model_dir, filename), 'w') as f:
        f.write('unused')

    # Write files which will be deleted.
    for filename in ['checkpoint',
                     'events.out.tfevents']:
      with tf.gfile.Open(os.path.join(model_dir, filename), 'w') as f:
        f.write('unused')

    # Create subdirectory which will be deleted.
    eval_dir = os.path.join(model_dir, 'eval_dir')
    tf.gfile.MakeDirs(eval_dir)
    with tf.gfile.Open(os.path.join(eval_dir, 'events.out.tfevents'), 'w') as f:
      f.write('unused')

    evaluator = run_evaluation.Evaluator(self.models_file, self.output_dir)
    evaluator._clean_model_dir(model_dir)

    # Check only intended files are preserved.
    remaining_files = tf.gfile.ListDirectory(model_dir)
    self.assertItemsEqual(remaining_files, preserved_files)

  @tf.test.mock.patch.object(run_evaluation, 'evaluate')
  def test_recovery_file(self, mock_eval):
    """Tests that evaluation recovers from restart."""
    mock_eval.train_and_evaluate.return_value = 'unused_output'

    # Write recovery file.
    recovery_dir = os.path.join(self.output_dir, '_recovery')
    tf.gfile.MakeDirs(recovery_dir)
    with tf.gfile.Open(os.path.join(recovery_dir, '0'), 'w') as f:
      f.write('2')  # Resume at 3rd entry

    evaluator = run_evaluation.Evaluator(
        self.models_file, self.output_dir)
    evaluator.run_evaluation()

    mock_eval.train_and_evaluate.assert_has_calls([
        self._expected_call('abe')])

    # Check that only 'abe' was evaluated; 'abc' and 'abd' are skipped due to
    # recovery.
    self.assertEqual(mock_eval.train_and_evaluate.call_count, 1)

    # Check that recovery file is updated after run.
    with tf.gfile.Open(evaluator.recovery_file) as f:
      new_idx = int(f.read())
    self.assertEqual(new_idx, 3)


if __name__ == '__main__':
  tf.test.main()
234 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/nasbench/lib/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Performs training and evaluation of the proposed model spec on TPU."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import time
22 |
23 | from nasbench.lib import cifar
24 | from nasbench.lib import model_builder
25 | from nasbench.lib import training_time
26 | import numpy as np
27 | import tensorflow as tf
28 |
# Exception types treated as valid reasons for a training run to fail.
VALID_EXCEPTIONS = (
    # The loss became NaN during training.
    tf.train.NanLossDuringTrainingError,
    # Out of memory.
    tf.errors.ResourceExhaustedError,
    # A NaN gradient was encountered.
    tf.errors.InvalidArgumentError,
    # The run timed out.
    tf.errors.DeadlineExceededError,
)
35 |
36 |
class AbortError(Exception):
  """Raised to signal that evaluation failed for a valid reason."""
40 |
41 |
def train_and_evaluate(spec, config, model_dir):
  """Trains and evaluates the proposed model.

  This is the routine used to create the benchmark dataset. The default values
  in config.py are exactly the values that were used.

  Args:
    spec: ModelSpec object.
    config: config dict generated from config.py.
    model_dir: directory in which to store the checkpoint files.

  Returns:
    dict containing the evaluation metadata.
  """
  evaluation_metadata = _train_and_evaluate_impl(spec, config, model_dir)
  return evaluation_metadata
57 |
58 |
def augment_and_evaluate(spec, config, model_dir, epochs_per_eval=5):
  """Trains the model on the full training set and evaluates on the test set.

  Here "augment" means training the same spec inside a larger network on the
  full training set. This usually involves more epochs, more modules/stacks,
  and a different LR schedule. Apply those changes to the config dict before
  calling this function.

  Note: this was not used to generate the NAS Benchmark dataset; see
  train_and_evaluate for that.

  Args:
    spec: ModelSpec object.
    config: config dict generated from config.py.
    model_dir: directory in which to store the checkpoint files.
    epochs_per_eval: number of epochs between evaluation runs. An evaluation is
      always run at the very start and at the very end.

  Returns:
    dict containing the evaluation metadata.
  """
  evaluation_metadata = _augment_and_evaluate_impl(
      spec, config, model_dir, epochs_per_eval)
  return evaluation_metadata
81 |
82 |
def _train_and_evaluate_impl(spec, config, model_dir):
  """Builds a _TrainAndEvaluator and runs it; see train_and_evaluate."""
  return _TrainAndEvaluator(spec, config, model_dir).run()
87 |
88 |
class _TrainAndEvaluator(object):
  """Runs the training and evaluation."""

  def __init__(self, spec, config, model_dir):
    """Initialize evaluator. See train_and_evaluate docstring."""
    self.input_train = cifar.CIFARInput('train', config)
    self.input_train_eval = cifar.CIFARInput('train_eval', config)
    self.input_valid = cifar.CIFARInput('valid', config)
    self.input_test = cifar.CIFARInput('test', config)
    self.input_sample = cifar.CIFARInput('sample', config)
    self.estimator = _create_estimator(spec, config, model_dir,
                                       self.input_train.num_images,
                                       self.input_sample.num_images)

    self.spec = spec
    self.config = config
    self.model_dir = model_dir

  @staticmethod
  def _evaluation_fractions(intermediate_evaluations):
    """Returns the fractions of training at which to evaluate the model.

    Args:
      intermediate_evaluations: iterable of values convertible to float, each
        a fraction of config['train_epochs'] (e.g. ['0.5']).

    Returns:
      list of floats in non-decreasing order that always ends with 1.0, so
      the model is evaluated at the end of training.

    Raises:
      ValueError: if the fractions are not in non-decreasing order.
    """
    # Build a real list: on Python 3 `map` returns a lazy iterator, which is
    # always truthy and supports neither indexing nor append.
    fractions = [float(fraction) for fraction in intermediate_evaluations]
    if not fractions or fractions[-1] != 1.0:
      fractions.append(1.0)
    if fractions != sorted(fractions):
      raise ValueError('intermediate_evaluations must be sorted, got %s'
                       % fractions)
    return fractions

  def run(self):
    """Runs training and evaluation.

    Each attempt wipes model_dir and trains from scratch. An attempt that
    raises one of VALID_EXCEPTIONS is retried, up to config['max_attempts']
    attempts in total.

    Returns:
      dict with 'trainable_params', 'total_time' (includes evaluation time)
      and 'evaluation_results' (one dict per evaluation, starting with the
      evaluation after the initial 1-step run).

    Raises:
      AbortError: if config['max_attempts'] attempts all failed.
      ValueError: if config['intermediate_evaluations'] is not sorted.
    """
    # The evaluation schedule depends only on the config, so parse and
    # validate it once instead of on every retry attempt.
    evaluations = self._evaluation_fractions(
        self.config['intermediate_evaluations'])

    attempts = 0
    while True:
      # Delete everything in the model dir at the start of each attempt
      try:
        tf.gfile.DeleteRecursively(self.model_dir)
      except tf.errors.NotFoundError:
        pass
      tf.gfile.MakeDirs(self.model_dir)

      try:
        # Train
        if self.config['train_seconds'] > 0.0:
          timing = training_time.limit(self.config['train_seconds'])
        else:
          timing = training_time.limit(None)

        evaluation_results = []
        start_time = time.time()

        # Train for 1 step with 0 LR to initialize the weights, then evaluate
        # once at the start for completeness, accuracies expected to be around
        # random selection. Note that batch norm moving averages change during
        # the step but the trainable weights do not.
        self.estimator.train(
            input_fn=self.input_train.input_fn,
            max_steps=1,
            hooks=[timing.train_hook],
            saving_listeners=[timing.saving_listener])
        evaluation_results.append(self._evaluate_all(0.0, 0))

        for next_evaluation in evaluations:
          epoch = next_evaluation * self.config['train_epochs']
          train_steps = int(epoch * self.input_train.num_images /
                            self.config['batch_size'])
          self.estimator.train(
              input_fn=self.input_train.input_fn,
              max_steps=train_steps,
              hooks=[timing.train_hook],
              saving_listeners=[timing.saving_listener])

          evaluation_results.append(self._evaluate_all(epoch, train_steps))

        all_time = time.time() - start_time
        break  # Break from retry loop on success
      except VALID_EXCEPTIONS as e:  # pylint: disable=catching-non-exception
        attempts += 1
        tf.logging.warning(str(e))
        if attempts >= self.config['max_attempts']:
          raise AbortError(str(e))

    metadata = {
        'trainable_params': _get_param_count(self.model_dir),
        'total_time': all_time,  # includes eval and other metric time
        'evaluation_results': evaluation_results,
    }

    return metadata

  def _evaluate_all(self, epochs, steps):
    """Evaluates on the train/valid/test sets and computes the sample metrics.

    Args:
      epochs: number of epochs trained so far (recorded as-is).
      steps: number of training steps so far (recorded as-is).

    Returns:
      dict of accuracies, cumulative training time, sample metrics and the
      wall time spent computing the sample metrics.
    """
    train_accuracy = _evaluate(self.estimator, self.input_train_eval,
                               self.config, name='train')
    valid_accuracy = _evaluate(self.estimator, self.input_valid,
                               self.config, name='valid')
    test_accuracy = _evaluate(self.estimator, self.input_test,
                              self.config, name='test')
    train_time = self.estimator.get_variable_value(
        training_time.TOTAL_TIME_NAME)

    now = time.time()
    sample_metrics = self._compute_sample_metrics()
    predict_time = time.time() - now

    return {
        'epochs': epochs,
        'training_time': train_time,
        'training_steps': steps,
        'train_accuracy': train_accuracy,
        'validation_accuracy': valid_accuracy,
        'test_accuracy': test_accuracy,
        'sample_metrics': sample_metrics,
        'predict_time': predict_time,
    }

  def _compute_sample_metrics(self):
    """Computes the metrics on a fixed batch."""
    predictions = self.estimator.predict(
        input_fn=self.input_sample.input_fn, yield_single_examples=False)
    # Built-in next() works on Python 2 and 3; generator.next() is Py2-only.
    sample_metrics = next(predictions)

    # Fix the extra batch dimension added by PREDICT
    for metric in sample_metrics:
      if metric in ['logits', 'input_grad_norm']:
        # Batch-shaped tensors take first batch
        sample_metrics[metric] = (
            sample_metrics[metric][:self.input_sample.num_images, Ellipsis])
      else:
        # Other tensors remove batch dimension
        sample_metrics[metric] = sample_metrics[metric][0, Ellipsis]

    return sample_metrics
214 |
215 |
def _augment_and_evaluate_impl(spec, config, model_dir, epochs_per_eval=5):
  """Augment and evaluate implementation, see augment_and_evaluate docstring.

  Raises:
    ValueError: if epochs_per_eval is not positive (training could never make
      progress).
  """
  if epochs_per_eval <= 0:
    raise ValueError('epochs_per_eval must be positive, got %s'
                     % (epochs_per_eval,))

  input_augment, input_test = [
      cifar.CIFARInput(m, config)
      for m in ['augment', 'test']]
  estimator = _create_estimator(spec, config, model_dir,
                                input_augment.num_images)

  if config['train_seconds'] > 0.0:
    timing = training_time.limit(config['train_seconds'])
  else:
    timing = training_time.limit(None)

  steps_per_epoch = input_augment.num_images / config['batch_size']  # float
  # Steps trained per estimator.train() call. Clamp to at least one step: a
  # small epochs_per_eval could otherwise truncate to 0, and the loop below
  # would never advance.
  steps_per_chunk = max(1, int(epochs_per_eval * steps_per_epoch))

  # Resume from the latest checkpoint if present; its path is parsed as
  # ending in '-<global step>'.
  ckpt = tf.train.latest_checkpoint(model_dir)
  current_step = int(ckpt.split('-')[-1]) if ckpt else 0
  max_steps = int(config['train_epochs'] * steps_per_epoch)

  while current_step < max_steps:
    next_step = min(current_step + steps_per_chunk, max_steps)
    estimator.train(
        input_fn=input_augment.input_fn,
        max_steps=next_step,
        hooks=[timing.train_hook],
        saving_listeners=[timing.saving_listener])
    current_step = next_step

  test_accuracy = _evaluate(estimator, input_test, config)

  metadata = {
      'trainable_params': _get_param_count(model_dir),
      'test_accuracy': test_accuracy,
  }

  return metadata
255 |
256 |
def _create_estimator(spec, config, model_dir,
                      num_train_images, num_sample_images=None):
  """Builds the TPUEstimator used to train, evaluate and predict `spec`.

  Args:
    spec: ModelSpec object.
    config: config dict generated from config.py.
    model_dir: directory for checkpoints and summaries.
    num_train_images: number of training images, forwarded to
      model_builder.build_model_fn.
    num_sample_images: size of the fixed PREDICT batch, or None.

  Returns:
    tf.contrib.tpu.TPUEstimator.
  """
  tpu_config = tf.contrib.tpu.TPUConfig(
      iterations_per_loop=config['tpu_iterations_per_loop'],
      num_shards=config['tpu_num_shards'])
  # Every train() call already ends with a checkpoint, so periodic time-based
  # checkpointing is effectively switched off with a huge interval.
  run_config = tf.contrib.tpu.RunConfig(
      model_dir=model_dir,
      keep_checkpoint_max=3,  # Keeps ckpt at start, halfway, and end
      save_checkpoints_secs=2**30,
      tpu_config=tpu_config)

  # Hack to allow PREDICT on a fixed batch on TPU: replicating the batch once
  # per shard means every TPU core operates on the entire fixed batch.
  predict_batch_size = num_sample_images
  if predict_batch_size and config['use_tpu']:
    predict_batch_size *= config['tpu_num_shards']

  return tf.contrib.tpu.TPUEstimator(
      use_tpu=config['use_tpu'],
      model_fn=model_builder.build_model_fn(spec, config, num_train_images),
      config=run_config,
      train_batch_size=config['batch_size'],
      eval_batch_size=config['batch_size'],
      predict_batch_size=predict_batch_size)
287 |
288 |
def _evaluate(estimator, input_data, config, name=None):
  """Returns the 'accuracy' metric of `estimator` on `input_data`.

  Runs num_images // batch_size evaluation steps (floor division).

  Args:
    estimator: estimator to evaluate.
    input_data: input object exposing `num_images` and `input_fn`.
    config: config dict providing 'batch_size'.
    name: optional name for this evaluation run.
  """
  num_steps = input_data.num_images // config['batch_size']
  metrics = estimator.evaluate(input_fn=input_data.input_fn,
                               steps=num_steps,
                               name=name)
  return metrics['accuracy']
297 |
298 |
def _get_param_count(model_dir):
  """Get trainable param count from the model directory.

  Args:
    model_dir: directory containing the checkpoint files.

  Returns:
    total number of scalar entries across all trainable variables of the
    graph stored with the latest checkpoint.

  Raises:
    ValueError: if model_dir contains no checkpoint.
  """
  # Start from an empty default graph so only the imported graph's variables
  # are counted.
  tf.reset_default_graph()
  checkpoint = tf.train.get_checkpoint_state(model_dir)
  if checkpoint is None or not checkpoint.model_checkpoint_path:
    raise ValueError('No checkpoint found in %s' % model_dir)
  with tf.Session() as sess:
    saver = tf.train.import_meta_graph(
        checkpoint.model_checkpoint_path + '.meta')
    saver.restore(sess, checkpoint.model_checkpoint_path)
    params = np.sum([np.prod(v.get_shape().as_list())
                     for v in tf.trainable_variables()])

  return params
311 |
312 |
--------------------------------------------------------------------------------
/nasbench/api.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """User interface for the NAS Benchmark dataset.
16 |
17 | Before using this API, download the data files from the links in the README.
18 |
19 | Usage:
20 | # Load the data from file (this will take some time)
21 | nasbench = api.NASBench('/path/to/nasbench.tfrecord')
22 |
23 | # Create an Inception-like module (5x5 convolution replaced with two 3x3
24 | # convolutions).
25 | model_spec = api.ModelSpec(
26 | # Adjacency matrix of the module
27 | matrix=[[0, 1, 1, 1, 0, 1, 0], # input layer
28 | [0, 0, 0, 0, 0, 0, 1], # 1x1 conv
29 | [0, 0, 0, 0, 0, 0, 1], # 3x3 conv
30 | [0, 0, 0, 0, 1, 0, 0], # 5x5 conv (replaced by two 3x3's)
31 | [0, 0, 0, 0, 0, 0, 1], # 5x5 conv (replaced by two 3x3's)
32 | [0, 0, 0, 0, 0, 0, 1], # 3x3 max-pool
33 | [0, 0, 0, 0, 0, 0, 0]], # output layer
34 | # Operations at the vertices of the module, matches order of matrix
35 | ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT])
36 |
37 |
38 | # Query this model from dataset
39 | data = nasbench.query(model_spec)
40 |
41 | Adjacency matrices are expected to be upper-triangular 0-1 matrices within the
42 | defined search space (at most 7 vertices and 9 edges, 3 allowed ops). The first and last
43 | operations must be 'input' and 'output'. The other operations should be from
44 | config['available_ops']. Currently, the available operations are:
45 | CONV3X3 = "conv3x3-bn-relu"
46 | CONV1X1 = "conv1x1-bn-relu"
47 | MAXPOOL3X3 = "maxpool3x3"
48 |
49 | When querying a spec, the spec will first be automatically pruned (removing
50 | unused vertices and edges along with ops). If the pruned spec is still out of
51 | the search space, an OutOfDomainError will be raised, otherwise the data is
52 | returned.
53 |
54 | The returned data object is a dictionary with the following keys:
55 | - module_adjacency: numpy array for the adjacency matrix
56 | - module_operations: list of operation labels
57 | - trainable_parameters: number of trainable parameters in the model
58 | - training_time: the total training time in seconds up to this point
59 | - train_accuracy: training accuracy
60 | - validation_accuracy: validation accuracy
61 | - test_accuracy: testing accuracy
62 |
63 | Instead of querying the dataset for a single run of a model, it is also possible
64 | to retrieve all metrics for a given spec, using:
65 |
66 | fixed_stats, computed_stats = nasbench.get_metrics_from_spec(model_spec)
67 |
68 | The fixed_stats is a dictionary with the keys:
69 | - module_adjacency
70 | - module_operations
71 | - trainable_parameters
72 |
73 | The computed_stats is a dictionary from epoch count to a list of metric
74 | dicts. For example, computed_stats[108][0] contains the metrics for the first
75 | repeat of the provided model trained to 108 epochs. The available keys are:
76 | - halfway_training_time
77 | - halfway_train_accuracy
78 | - halfway_validation_accuracy
79 | - halfway_test_accuracy
80 | - final_training_time
81 | - final_train_accuracy
82 | - final_validation_accuracy
83 | - final_test_accuracy
84 | """
85 |
86 | from __future__ import absolute_import
87 | from __future__ import division
88 | from __future__ import print_function
89 |
90 | import base64
91 | import copy
92 | import json
93 | import os
94 | import random
95 | import time
96 |
97 | from nasbench.lib import config
98 | from nasbench.lib import evaluate
99 | from nasbench.lib import model_metrics_pb2
100 | from nasbench.lib import model_spec as _model_spec
101 | import numpy as np
102 | import tensorflow as tf
103 |
# Bring ModelSpec to top-level for convenience, so callers can construct specs
# as api.ModelSpec(...) without importing nasbench.lib.model_spec directly.
# See lib/model_spec.py.
ModelSpec = _model_spec.ModelSpec
106 |
107 |
class OutOfDomainError(Exception):
  """Raised when a requested graph or epoch count is outside the search domain.

  NASBench raises it for specs that are disconnected, have too many
  vertices or edges, or use unsupported operations, and for epoch counts
  that are not present in the loaded dataset.
  """
110 |
111 |
class NASBench(object):
  """User-facing API for accessing the NASBench dataset."""

  def __init__(self, dataset_file, seed=None):
    """Initialize dataset, this should only be done once per experiment.

    Args:
      dataset_file: path to .tfrecord file containing the dataset.
      seed: random seed used for sampling queried models. Two NASBench objects
        created with the same seed will return the same data points when queried
        with the same models in the same order. By default, the seed is randomly
        generated.
    """
    self.config = config.build_config()
    # NOTE: this seeds the process-wide `random` module, which query() also
    # samples from. Other users of `random` in the same process therefore
    # affect which repeat query() returns.
    random.seed(seed)

    print('Loading dataset from file... This may take a few minutes...')
    start = time.time()

    # Stores the fixed statistics that are independent of evaluation (i.e.,
    # adjacency matrix, operations, and number of parameters).
    # hash --> metric name --> scalar
    self.fixed_statistics = {}

    # Stores the statistics that are computed via training and evaluating the
    # model on CIFAR-10. Statistics are computed for multiple repeats of each
    # model at each max epoch length.
    # hash --> epochs --> repeat index --> metric name --> scalar
    self.computed_statistics = {}

    # Valid queriable epoch lengths. {4, 12, 36, 108} for the full dataset or
    # {108} for the smaller dataset with only the 108 epochs.
    self.valid_epochs = set()

    for serialized_row in tf.python_io.tf_record_iterator(dataset_file):
      # Each row is a JSON list: [hash, epochs, flattened adjacency digits,
      # comma-separated ops, base64-encoded ModelMetrics proto].
      module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = (
          json.loads(serialized_row.decode('utf-8')))

      # The adjacency matrix is stored flattened as a string of digits;
      # reshape it back into a square matrix.
      dim = int(np.sqrt(len(raw_adjacency)))
      adjacency = np.array([int(e) for e in list(raw_adjacency)], dtype=np.int8)
      adjacency = np.reshape(adjacency, (dim, dim))
      operations = raw_operations.split(',')
      metrics = model_metrics_pb2.ModelMetrics.FromString(
          base64.b64decode(raw_metrics))

      if module_hash not in self.fixed_statistics:
        # First time seeing this module, initialize fixed statistics.
        self.fixed_statistics[module_hash] = {
            'module_adjacency': adjacency,
            'module_operations': operations,
            'trainable_parameters': metrics.trainable_parameters,
        }
        self.computed_statistics[module_hash] = {}

      self.valid_epochs.add(epochs)

      # Each data_point consists of the metrics recorded from a single
      # train-and-evaluation of a model at a specific epoch length.
      # metrics.evaluation_data[0] holds the metrics at the start of training
      # (step 0) and is unused by this API; [1] is the half-way point of
      # training and [2] the end of training.
      data_point = {}
      data_point.update(
          self._evaluation_stats('halfway', metrics.evaluation_data[1]))
      data_point.update(
          self._evaluation_stats('final', metrics.evaluation_data[2]))

      self.computed_statistics[module_hash].setdefault(epochs, []).append(
          data_point)

    elapsed = time.time() - start
    print('Loaded dataset in %d seconds' % elapsed)

    self.history = {}
    self.training_time_spent = 0.0
    self.total_epochs_spent = 0

  @staticmethod
  def _evaluation_stats(prefix, evaluation):
    """Returns the time/accuracy fields of one evaluation keyed as prefix_*."""
    return {
        prefix + '_training_time': evaluation.training_time,
        prefix + '_train_accuracy': evaluation.train_accuracy,
        prefix + '_validation_accuracy': evaluation.validation_accuracy,
        prefix + '_test_accuracy': evaluation.test_accuracy,
    }

  def query(self, model_spec, epochs=108, stop_halfway=False):
    """Fetch one of the evaluations for this model spec.

    Each call will sample one of the repeated evaluations recorded for this
    model. This means that repeated queries of the same model (or isomorphic
    models) may return identical metrics.

    This function will increment the budget counters for benchmarking purposes.
    See self.training_time_spent, and self.total_epochs_spent.

    This function also allows querying the evaluation metrics at the halfway
    point of training using stop_halfway. Using this option will increment the
    budget counters only up to the halfway point.

    Args:
      model_spec: ModelSpec object.
      epochs: number of epochs trained. Must be one of the evaluated number of
        epochs, [4, 12, 36, 108] for the full dataset.
      stop_halfway: if True, returned dict will only contain the training time
        and accuracies at the halfway point of training (num_epochs/2).
        Otherwise, returns the time and accuracies at the end of training
        (num_epochs).

    Returns:
      dict containing the evaluated data for this object.

    Raises:
      OutOfDomainError: if model_spec or epochs is outside the search space.
    """
    if epochs not in self.valid_epochs:
      raise OutOfDomainError('invalid number of epochs, must be one of %s'
                             % self.valid_epochs)

    fixed_stat, computed_stat = self.get_metrics_from_spec(model_spec)
    # Sample among the repeats actually stored for this model rather than
    # assuming exactly config['num_repeats'] of them exist.
    repeats = computed_stat[epochs]
    sampled_index = random.randint(0, len(repeats) - 1)
    computed_stat = repeats[sampled_index]

    data = {}
    data['module_adjacency'] = fixed_stat['module_adjacency']
    data['module_operations'] = fixed_stat['module_operations']
    data['trainable_parameters'] = fixed_stat['trainable_parameters']

    prefix = 'halfway' if stop_halfway else 'final'
    for metric in ('training_time', 'train_accuracy', 'validation_accuracy',
                   'test_accuracy'):
      data[metric] = computed_stat['%s_%s' % (prefix, metric)]

    self.training_time_spent += data['training_time']
    if stop_halfway:
      self.total_epochs_spent += epochs // 2
    else:
      self.total_epochs_spent += epochs

    return data

  def is_valid(self, model_spec):
    """Checks the validity of the model_spec.

    For the purposes of benchmarking, this does not increment the budget
    counters.

    Args:
      model_spec: ModelSpec object.

    Returns:
      True if model is within space.
    """
    try:
      self._check_spec(model_spec)
    except OutOfDomainError:
      return False

    return True

  def get_budget_counters(self):
    """Returns the time and budget counters."""
    return self.training_time_spent, self.total_epochs_spent

  def reset_budget_counters(self):
    """Reset the time and epoch budget counters."""
    self.training_time_spent = 0.0
    self.total_epochs_spent = 0

  def evaluate(self, model_spec, model_dir):
    """Trains and evaluates a model spec from scratch (does not query dataset).

    This function runs the same procedure that was used to generate each
    evaluation in the dataset. Because we are not querying the generated
    dataset of trained models, there are no limitations on number of vertices,
    edges, operations, or epochs. Note that the results will not exactly match
    the dataset due to randomness. By default, this uses TPUs for evaluation but
    CPU/GPU can be used by setting --use_tpu=false (GPU will require installing
    tensorflow-gpu).

    Args:
      model_spec: ModelSpec object.
      model_dir: directory to store the checkpoints, summaries, and logs.

    Returns:
      dict containing the evaluated data for this object, same structure as
      returned by query().
    """
    # Metadata contains additional metrics that aren't reported normally.
    # However, these are stored in the JSON file at the model_dir.
    metadata = evaluate.train_and_evaluate(model_spec, self.config, model_dir)
    metadata_file = os.path.join(model_dir, 'metadata.json')
    with tf.gfile.Open(metadata_file, 'w') as f:
      json.dump(metadata, f, cls=_NumpyEncoder)

    data_point = {}
    data_point['module_adjacency'] = model_spec.matrix
    data_point['module_operations'] = model_spec.ops
    data_point['trainable_parameters'] = metadata['trainable_params']

    final_evaluation = metadata['evaluation_results'][-1]
    data_point['training_time'] = final_evaluation['training_time']
    data_point['train_accuracy'] = final_evaluation['train_accuracy']
    data_point['validation_accuracy'] = final_evaluation['validation_accuracy']
    data_point['test_accuracy'] = final_evaluation['test_accuracy']

    return data_point

  def hash_iterator(self):
    """Returns iterator over all unique model hashes."""
    return self.fixed_statistics.keys()

  def get_metrics_from_hash(self, module_hash):
    """Returns the metrics for all epochs and all repeats of a hash.

    This method is for dataset analysis and should not be used for benchmarking.
    As such, it does not increment any of the budget counters.

    Args:
      module_hash: MD5 hash, i.e., the values yielded by hash_iterator().

    Returns:
      fixed stats and computed stats of the model spec provided (deep copies,
      so callers may mutate them freely).
    """
    fixed_stat = copy.deepcopy(self.fixed_statistics[module_hash])
    computed_stat = copy.deepcopy(self.computed_statistics[module_hash])
    return fixed_stat, computed_stat

  def get_metrics_from_spec(self, model_spec):
    """Returns the metrics for all epochs and all repeats of a model.

    This method is for dataset analysis and should not be used for benchmarking.
    As such, it does not increment any of the budget counters.

    Args:
      model_spec: ModelSpec object.

    Returns:
      fixed stats and computed stats of the model spec provided.

    Raises:
      OutOfDomainError: if model_spec is outside the search space.
    """
    self._check_spec(model_spec)
    module_hash = self._hash_spec(model_spec)
    return self.get_metrics_from_hash(module_hash)

  def _check_spec(self, model_spec):
    """Checks that the model spec is within the dataset.

    Raises:
      OutOfDomainError: describing the first violated constraint.
    """
    if not model_spec.valid_spec:
      raise OutOfDomainError('invalid spec, provided graph is disconnected.')

    num_vertices = len(model_spec.ops)
    num_edges = np.sum(model_spec.matrix)

    if num_vertices > self.config['module_vertices']:
      # Use the instance config: the bare name `config` is the config module
      # and is not subscriptable.
      raise OutOfDomainError('too many vertices, got %d (max vertices = %d)'
                             % (num_vertices, self.config['module_vertices']))

    if num_edges > self.config['max_edges']:
      raise OutOfDomainError('too many edges, got %d (max edges = %d)'
                             % (num_edges, self.config['max_edges']))

    if model_spec.ops[0] != 'input':
      raise OutOfDomainError('first operation should be \'input\'')
    if model_spec.ops[-1] != 'output':
      raise OutOfDomainError('last operation should be \'output\'')
    for op in model_spec.ops[1:-1]:
      if op not in self.config['available_ops']:
        raise OutOfDomainError('unsupported op %s (available ops = %s)'
                               % (op, self.config['available_ops']))

  def _hash_spec(self, model_spec):
    """Returns the MD5 hash for a provided model_spec."""
    return model_spec.hash_spec(self.config['available_ops'])
396 |
397 |
class _NumpyEncoder(json.JSONEncoder):
  """Converts numpy objects to JSON-serializable format."""

  def default(self, obj):
    """Returns a JSON-serializable stand-in for numpy arrays and scalars.

    Any other type is delegated to json.JSONEncoder.default, which raises
    TypeError.
    """
    if isinstance(obj, np.ndarray):
      # Matrices converted to nested lists
      return obj.tolist()
    elif isinstance(obj, np.generic):
      # Scalars converted to closest Python type. `item()` replaces
      # np.asscalar, which NumPy deprecated (1.16) and later removed.
      return obj.item()
    return json.JSONEncoder.default(self, obj)
409 |
--------------------------------------------------------------------------------
/nasbench/tests/graph_util_test.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tests for lib/graph_util.py."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import random
22 | from nasbench.lib import graph_util
23 | import numpy as np
24 | import tensorflow as tf # Used for tf.test
25 |
26 |
27 | class GraphUtilTest(tf.test.TestCase):
28 |
def test_gen_is_edge(self):
  """Checks that gen_is_edge_fn decodes edge bits into the expected matrices."""
  # (bits, matrix size, expected adjacency); the trailing comment on each case
  # is `bits` written in binary.
  cases = [
      (0, 3, [[0, 0, 0],
              [0, 0, 0],
              [0, 0, 0]]),  # '000'
      (3, 3, [[0, 1, 1],
              [0, 0, 0],
              [0, 0, 0]]),  # '011'
      (5, 3, [[0, 1, 0],
              [0, 0, 1],
              [0, 0, 0]]),  # '101'
      (7, 3, [[0, 1, 1],
              [0, 0, 1],
              [0, 0, 0]]),  # '111'
      (7, 4, [[0, 1, 1, 0],
              [0, 0, 1, 0],
              [0, 0, 0, 0],
              [0, 0, 0, 0]]),  # '111'
      (18, 4, [[0, 0, 1, 0],
               [0, 0, 0, 1],
               [0, 0, 0, 0],
               [0, 0, 0, 0]]),  # '010010'
      (35, 4, [[0, 1, 1, 0],
               [0, 0, 0, 0],
               [0, 0, 0, 1],
               [0, 0, 0, 0]]),  # '100011'
  ]
  for bits, size, expected in cases:
    is_edge = graph_util.gen_is_edge_fn(bits)
    actual = np.fromfunction(is_edge, (size, size), dtype=np.int8)
    self.assertTrue(np.array_equal(actual, np.array(expected)))
82 |
def test_is_full_dag(self):
  """Checks is_full_dag accepts fully connected DAGs and rejects the rest."""
  full_dags = [
      [[0, 1, 0],
       [0, 0, 1],
       [0, 0, 0]],
      [[0, 1, 1],
       [0, 0, 1],
       [0, 0, 0]],
      [[0, 1, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 1],
       [0, 0, 0, 0]],
  ]
  for adjacency in full_dags:
    self.assertTrue(graph_util.is_full_dag(np.array(adjacency)))

  partial_dags = [
      # vertex 1 not connected to input
      [[0, 0, 1],
       [0, 0, 1],
       [0, 0, 0]],
      # vertex 1 not connected to output
      [[0, 1, 1],
       [0, 0, 0],
       [0, 0, 0]],
      # 1, 3 are connected to each other but disconnected from main path
      [[0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]],
      # no path from input to output
      [[0, 0, 1, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 0],
       [0, 0, 0, 0]],
      # completely disconnected vertex
      [[0, 1, 0, 0],
       [0, 0, 0, 1],
       [0, 0, 0, 0],
       [0, 0, 0, 0]],
  ]
  for adjacency in partial_dags:
    self.assertFalse(graph_util.is_full_dag(np.array(adjacency)))
134 |
def test_hash_module(self):
  """Checks hash_module matches isomorphic graphs and separates the rest."""
  hash_module = graph_util.hash_module

  # Diamond graph with label permutation
  diamond = np.array([[0, 1, 1, 0],
                      [0, 0, 0, 1],
                      [0, 0, 0, 1],
                      [0, 0, 0, 0]])
  self.assertEqual(hash_module(diamond, [-1, 1, 2, -2]),
                   hash_module(diamond, [-1, 2, 1, -2]))

  # Simple graph with edge permutation
  graph_a = np.array([[0, 1, 1, 0, 0],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 1, 0],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 0, 0]])
  labels_a = [-1, 1, 2, 3, -2]
  graph_b = np.array([[0, 1, 0, 1, 0],
                      [0, 0, 1, 0, 0],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 0, 0]])
  labels_b = [-1, 2, 3, 1, -2]
  graph_c = np.array([[0, 1, 1, 0, 0],
                      [0, 0, 0, 1, 0],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 0, 1],
                      [0, 0, 0, 0, 0]])
  labels_c = [-1, 2, 1, 3, -2]

  hash_a = hash_module(graph_a, labels_a)
  hash_b = hash_module(graph_b, labels_b)
  hash_c = hash_module(graph_c, labels_c)
  self.assertEqual(hash_a, hash_b)
  self.assertEqual(hash_b, hash_c)
  # Same topology as graph_a but with mismatched labels.
  self.assertNotEqual(hash_module(graph_a, labels_b), hash_a)
  self.assertNotEqual(hash_module(graph_a, labels_c), hash_a)

  # Connected non-isomorphic regular graphs on 6 interior vertices (8 total)
  regular_a = np.array([[0, 1, 0, 0, 0, 0, 0, 0],
                        [0, 0, 1, 1, 0, 0, 1, 0],
                        [0, 0, 0, 0, 1, 1, 0, 0],
                        [0, 0, 0, 0, 1, 1, 0, 0],
                        [0, 0, 0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0, 0, 0, 0]])
  regular_b = np.array([[0, 1, 0, 0, 0, 0, 0, 0],
                        [0, 0, 1, 1, 0, 1, 0, 0],
                        [0, 0, 0, 0, 1, 0, 1, 0],
                        [0, 0, 0, 0, 1, 1, 0, 0],
                        [0, 0, 0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 0, 0, 1, 0],
                        [0, 0, 0, 0, 0, 0, 0, 1],
                        [0, 0, 0, 0, 0, 0, 0, 0]])
  uniform_labels = [-1, 1, 1, 1, 1, 1, 1, -2]
  self.assertNotEqual(hash_module(regular_a, uniform_labels),
                      hash_module(regular_b, uniform_labels))

  # Non-isomorphic tricky case (breaks if you don't include self)
  chain = [[0, 1, 0, 0, 0],
           [0, 0, 1, 0, 0],
           [0, 0, 0, 1, 0],
           [0, 0, 0, 0, 1],
           [0, 0, 0, 0, 0]]
  self.assertNotEqual(hash_module(np.array(chain), [-1, 1, 0, 0, -2]),
                      hash_module(np.array(chain), [-1, 0, 0, 1, -2]))

  # Non-isomorphic tricky case (breaks if you don't use directed edges)
  shortcut = [[0, 1, 0, 1],
              [0, 0, 1, 0],
              [0, 0, 0, 1],
              [0, 0, 0, 0]]
  self.assertNotEqual(hash_module(np.array(shortcut), [-1, 1, 0, -2]),
                      hash_module(np.array(shortcut), [-1, 0, 1, -2]))

  # Non-isomorphic tricky case (breaks if you only use out-neighbors and self)
  fan = [[0, 1, 1, 1, 1, 0, 0],
         [0, 0, 1, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0]]
  self.assertNotEqual(hash_module(np.array(fan), [-1, 1, 0, 0, 0, 0, -2]),
                      hash_module(np.array(fan), [-1, 0, 0, 0, 1, 0, -2]))
263 |
264 | def test_permute_graph(self):
265 | # Does not have to be DAG
266 | matrix = np.array([[1, 1, 0],
267 | [0, 0, 1],
268 | [1, 0, 1]])
269 | labels = ['a', 'b', 'c']
270 |
271 | p1, l1 = graph_util.permute_graph(matrix, labels, [2, 0, 1])
272 | self.assertTrue(np.array_equal(p1,
273 | np.array([[0, 1, 0],
274 | [0, 1, 1],
275 | [1, 0, 1]])))
276 | self.assertEqual(l1, ['b', 'c', 'a'])
277 |
278 | p1, l1 = graph_util.permute_graph(matrix, labels, [0, 2, 1])
279 | self.assertTrue(np.array_equal(p1,
280 | np.array([[1, 0, 1],
281 | [1, 1, 0],
282 | [0, 1, 0]])))
283 | self.assertEqual(l1, ['a', 'c', 'b'])
284 |
  def test_is_isomorphic(self):
    """Checks graph_util.is_isomorphic on labeled (colored) directed graphs."""
    # Reuse some tests from hash_module
    matrix1 = np.array(
        [[0, 1, 1, 0,],
         [0, 0, 0, 1],
         [0, 0, 0, 1],
         [0, 0, 0, 0]])
    label1 = [-1, 1, 2, -2]
    label2 = [-1, 2, 1, -2]

    # Vertices 1 and 2 have identical connectivity (0 -> v -> 3), so swapping
    # their labels yields an isomorphic labeled graph.
    self.assertTrue(graph_util.is_isomorphic((matrix1, label1),
                                             (matrix1, label2)))

    # Simple graph with edge permutation
    # Labeled structure of all three: one branch 0 -> [1] -> out and one
    # branch 0 -> [2] -> [3] -> out, with vertex indices shuffled.
    matrix1 = np.array(
        [[0, 1, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]])
    label1 = [-1, 1, 2, 3, -2]

    matrix2 = np.array(
        [[0, 1, 0, 1, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]])
    label2 = [-1, 2, 3, 1, -2]

    matrix3 = np.array(
        [[0, 1, 1, 0, 0],
         [0, 0, 0, 1, 0],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]])
    label3 = [-1, 2, 1, 3, -2]

    self.assertTrue(graph_util.is_isomorphic((matrix1, label1),
                                             (matrix2, label2)))
    self.assertTrue(graph_util.is_isomorphic((matrix1, label1),
                                             (matrix3, label3)))
    # Same topology as matrix2 but with the wrong labels: the two-vertex
    # branch becomes [1] -> [2] instead of [2] -> [3].
    self.assertFalse(graph_util.is_isomorphic((matrix1, label1),
                                              (matrix2, label1)))

    # Connected non-isomorphic regular graphs on 6 interior vertices (8 total)
    matrix1 = np.array(
        [[0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 1, 0, 0, 1, 0],
         [0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0]])
    matrix2 = np.array(
        [[0, 1, 0, 0, 0, 0, 0, 0],
         [0, 0, 1, 1, 0, 1, 0, 0],
         [0, 0, 0, 0, 1, 0, 1, 0],
         [0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0]])
    label1 = [-1, 1, 1, 1, 1, 1, 1, -2]

    self.assertFalse(graph_util.is_isomorphic((matrix1, label1),
                                              (matrix2, label1)))

    # Connected isomorphic regular graphs on 8 total vertices (bipartite)
    # Both matrices are symmetric (every edge appears in both directions) and
    # every row has exactly three 1s. matrix2 is matrix1 relabeled by `perm`,
    # which the sanity check below verifies.
    matrix1 = np.array(
        [[0, 0, 0, 0, 1, 1, 1, 0],
         [0, 0, 0, 0, 1, 1, 0, 1],
         [0, 0, 0, 0, 1, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1, 1],
         [1, 1, 1, 0, 0, 0, 0, 0],
         [1, 1, 0, 1, 0, 0, 0, 0],
         [1, 0, 1, 1, 0, 0, 0, 0],
         [0, 1, 1, 1, 0, 0, 0, 0]])
    matrix2 = np.array(
        [[0, 1, 0, 1, 1, 0, 0, 0],
         [1, 0, 1, 0, 0, 1, 0, 0],
         [0, 1, 0, 1, 0, 0, 1, 0],
         [1, 0, 1, 0, 0, 0, 0, 1],
         [1, 0, 0, 0, 0, 1, 0, 1],
         [0, 1, 0, 0, 1, 0, 1, 0],
         [0, 0, 1, 0, 0, 1, 0, 1],
         [0, 0, 0, 1, 1, 0, 1, 0]])
    label1 = [1, 1, 1, 1, 1, 1, 1, 1]

    # Sanity check: manual permutation
    perm = [0, 5, 7, 2, 4, 1, 3, 6]
    pm1, pl1 = graph_util.permute_graph(matrix1, label1, perm)
    self.assertTrue(np.array_equal(matrix2, pm1))
    self.assertEqual(pl1, label1)

    self.assertTrue(graph_util.is_isomorphic((matrix1, label1),
                                             (matrix2, label1)))

    # label3 is label2 carried through the same `perm` (new vertex perm[v]
    # takes old vertex v's label, as in test_permute_graph), so the colored
    # graphs remain isomorphic.
    label2 = [1, 1, 1, 1, 2, 2, 2, 2]
    label3 = [1, 2, 1, 2, 2, 1, 2, 1]

    self.assertTrue(graph_util.is_isomorphic((matrix1, label2),
                                             (matrix2, label3)))
389 |
390 | def test_random_isomorphism_hashing(self):
391 | # Tests that hash_module always provides the same hash for randomly
392 | # generated isomorphic graphs.
393 | for _ in range(1000):
394 | # Generate random graph. Note: the algorithm works (i.e. same hash ==
395 | # isomorphic graphs) for all directed graphs with coloring and does not
396 | # require the graph to be a DAG.
397 | size = random.randint(3, 20)
398 | matrix = np.random.randint(0, 2, [size, size])
399 | labels = [random.randint(0, 10) for _ in range(size)]
400 |
401 | # Generate permutation of matrix and labels.
402 | perm = np.random.permutation(size).tolist()
403 | pmatrix, plabels = graph_util.permute_graph(matrix, labels, perm)
404 |
405 | # Hashes should be identical.
406 | hash1 = graph_util.hash_module(matrix, labels)
407 | hash2 = graph_util.hash_module(pmatrix, plabels)
408 | self.assertEqual(hash1, hash2)
409 |
  def test_counterexample_bipartite(self):
    # This is a counterexample that shows that the hashing algorithm is not
    # perfectly identifiable (i.e. there are non-isomorphic graphs with the same
    # hash). If this test fails, it means the algorithm must have been changed
    # in some way that allows it to identify these graphs as non-isomorphic.
    #
    # Both graphs have the same layers: input (0) -> four label-1 vertices (1-4)
    # -> four label-2 vertices (5-8) -> output (9). Every label-1 vertex feeds
    # exactly two label-2 vertices and every label-2 vertex is fed by exactly
    # two label-1 vertices. They differ only in how the middle layer is wired:
    # in matrix1 it splits into two disjoint 2x2 blocks ({1, 2} -> {5, 6} and
    # {3, 4} -> {7, 8}); in matrix2 it forms a single 8-cycle
    # (1-5-4-8-3-7-2-6-1).
    matrix1 = np.array(
        [[0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

    matrix2 = np.array(
        [[0, 1, 1, 1, 1, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 1, 1, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 1, 1, 0],
         [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

    labels = [-1, 1, 1, 1, 1, 2, 2, 2, 2, -2]

    # This takes far too long to run so commenting it out. The graphs are
    # fairly obviously non-isomorphic from visual inspection.
    # self.assertFalse(graph_util.is_isomorphic((matrix1, labels),
    #                                           (matrix2, labels)))
    # Pins the known collision: both graphs currently hash identically.
    self.assertEqual(graph_util.hash_module(matrix1, labels),
                     graph_util.hash_module(matrix2, labels))
447 |
448 |
if __name__ == '__main__':
  # Run this module's test cases with TensorFlow's test runner.
  tf.test.main()
451 |
--------------------------------------------------------------------------------
/nasbench/lib/model_builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Builds the TensorFlow computational graph.
16 |
17 | Tensors flowing into a single vertex are added together for all vertices
18 | except the output, which is concatenated instead. Tensors flowing out of input
19 | are always added.
20 |
If interior edge channels don't match, drop the extra channels (channel counts
are guaranteed never to increase along an edge). Tensors flowing out of the
input are always projected instead.
24 | """
25 |
26 | from __future__ import absolute_import
27 | from __future__ import division
28 | from __future__ import print_function
29 |
30 | from nasbench.lib import base_ops
31 | from nasbench.lib import training_time
32 | import numpy as np
33 | import tensorflow as tf
34 |
35 |
def build_model_fn(spec, config, num_train_images):
  """Returns a model function for Estimator.

  Args:
    spec: ModelSpec object describing the module that is stacked.
    config: dict of hyperparameters. Keys read here include 'data_format',
      'stem_filter_size', 'num_stacks', 'num_modules_per_stack', 'num_labels',
      'use_tpu', 'weight_decay', 'learning_rate', 'momentum',
      'lr_decay_method', 'train_epochs', 'batch_size', and (depending on
      other settings) 'tpu_num_shards' and 'train_seconds'.
    num_train_images: number of training images, used to convert epochs into
      steps for the step-based learning-rate schedules.

  Returns:
    model_fn(features, labels, mode, params) returning a TPUEstimatorSpec.

  Raises:
    ValueError: if config['data_format'] is not 'channels_last' or
      'channels_first'.
  """
  if config['data_format'] == 'channels_last':
    channel_axis = 3
    spatial_axes = [1, 2]
  elif config['data_format'] == 'channels_first':
    # Currently this is not well supported
    channel_axis = 1
    spatial_axes = [2, 3]
  else:
    raise ValueError('invalid data_format')

  def model_fn(features, labels, mode, params):
    """Builds the model from the input features."""
    del params  # Unused
    is_training = (mode == tf.estimator.ModeKeys.TRAIN)

    # Store auxiliary activations increasing in depth of network. First
    # activation occurs immediately after the stem and the others immediately
    # follow each stack.
    aux_activations = []

    # Initial stem convolution
    with tf.variable_scope('stem'):
      net = base_ops.conv_bn_relu(
          features, 3, config['stem_filter_size'],
          is_training, config['data_format'])
    aux_activations.append(net)

    for stack_num in range(config['num_stacks']):
      channels = net.get_shape()[channel_axis].value

      # Downsample at start (except first)
      if stack_num > 0:
        net = tf.layers.max_pooling2d(
            inputs=net,
            pool_size=(2, 2),
            strides=(2, 2),
            padding='same',
            data_format=config['data_format'])

        # Double output channels each time we downsample
        channels *= 2

      with tf.variable_scope('stack{}'.format(stack_num)):
        for module_num in range(config['num_modules_per_stack']):
          with tf.variable_scope('module{}'.format(module_num)):
            net = build_module(
                spec,
                inputs=net,
                channels=channels,
                is_training=is_training)
        aux_activations.append(net)

    # Global average pool over the spatial axes chosen above.
    net = tf.reduce_mean(net, spatial_axes)

    # Fully-connected layer to labels
    logits = tf.layers.dense(
        inputs=net,
        units=config['num_labels'])

    if mode == tf.estimator.ModeKeys.PREDICT and not config['use_tpu']:
      # It is a known limitation of Estimator that the labels
      # are not passed during PREDICT mode when running on CPU/GPU
      # (https://github.com/tensorflow/tensorflow/issues/17824), thus we cannot
      # compute the loss or anything dependent on it (i.e., the gradients).
      loss = tf.constant(0.0)
    else:
      loss = tf.losses.softmax_cross_entropy(
          onehot_labels=tf.one_hot(labels, config['num_labels']),
          logits=logits)

      loss += config['weight_decay'] * tf.add_n(
          [tf.nn.l2_loss(v) for v in tf.trainable_variables()])

    # Use inference mode to compute some useful metrics on a fixed sample
    # Due to the batch being sharded on TPU, these metrics should be run on CPU
    # only to ensure that the metrics are computed on the whole batch. We add a
    # leading dimension because PREDICT expects batch-shaped tensors.
    if mode == tf.estimator.ModeKeys.PREDICT:
      trainable = tf.trainable_variables()
      parameter_norms = {
          'param:' + tensor.name:
          tf.expand_dims(tf.norm(tensor, ord=2), 0)
          for tensor in trainable
      }

      # Compute gradients of all parameters and the input simultaneously. The
      # last gradient belongs to `features`; split it off with list slicing
      # (slicing the result of zip() fails on Python 3, where zip is lazy).
      grads = tf.gradients(loss, trainable + [features])
      param_grads, input_grad = grads[:-1], grads[-1]

      param_gradient_norms = {}
      for tensor, grad in zip(trainable, param_grads):
        if grad is not None:
          norm = tf.norm(grad, ord=2)
        else:
          # No gradient path from `loss` to this tensor; report a zero norm.
          norm = tf.constant(0.0)
        param_gradient_norms['param_grad_norm:' + tensor.name] = (
            tf.expand_dims(norm, 0))

      if input_grad is not None:
        # Per-example L2 norm of the input gradient over axes 1-3.
        input_grad_norm = tf.sqrt(tf.reduce_sum(
            tf.square(input_grad), axis=[1, 2, 3]))
      else:
        input_grad_norm = tf.expand_dims(tf.constant(0.0), 0)

      covariance_matrices = {
          'cov_matrix_%d' % i:
          tf.expand_dims(_covariance_matrix(aux), 0)
          for i, aux in enumerate(aux_activations)
      }

      predictions = {
          'logits': logits,
          'loss': tf.expand_dims(loss, 0),
          'input_grad_norm': input_grad_norm,
      }
      predictions.update(parameter_norms)
      predictions.update(param_gradient_norms)
      predictions.update(covariance_matrices)

      return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions)

    if mode == tf.estimator.ModeKeys.TRAIN:
      global_step = tf.train.get_or_create_global_step()
      learning_rate = _build_learning_rate(
          config, num_train_images, global_step)

      optimizer = tf.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          momentum=config['momentum'],
          epsilon=1.0)
      if config['use_tpu']:
        optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer)

      # Update ops required for batch norm moving variables
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step)

      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          train_op=train_op)

    elif mode == tf.estimator.ModeKeys.EVAL:
      def metric_fn(labels, logits):
        predictions = tf.argmax(logits, axis=1)
        accuracy = tf.metrics.accuracy(labels, predictions)

        return {'accuracy': accuracy}

      eval_metrics = (metric_fn, [labels, logits])

      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          loss=loss,
          eval_metrics=eval_metrics)

  return model_fn


def _build_learning_rate(config, num_train_images, global_step):
  """Builds the learning-rate Tensor for config['lr_decay_method'].

  Args:
    config: hyperparameter dict (see build_model_fn).
    num_train_images: number of training images, used to convert
      config['train_epochs'] into a step count.
    global_step: the global step Tensor.

  Returns:
    Scalar learning-rate Tensor, forced to 0 at global step 0.

  Raises:
    ValueError: if config['lr_decay_method'] is not 'COSINE_BY_STEP',
      'COSINE_BY_TIME' or 'STEPWISE'.
  """
  base_lr = config['learning_rate']
  if config['use_tpu']:
    # The base rate is multiplied by the number of TPU shards.
    base_lr *= config['tpu_num_shards']

  if config['lr_decay_method'] == 'COSINE_BY_STEP':
    total_steps = int(config['train_epochs'] * num_train_images /
                      config['batch_size'])
    progress_fraction = tf.cast(global_step, tf.float32) / total_steps
    learning_rate = (0.5 * base_lr *
                     (1 + tf.cos(np.pi * progress_fraction)))

  elif config['lr_decay_method'] == 'COSINE_BY_TIME':
    # Requires training_time.limit hooks to be added to Estimator
    elapsed_time = tf.cast(training_time.get_total_time(), dtype=tf.float32)
    progress_fraction = elapsed_time / config['train_seconds']
    learning_rate = (0.5 * base_lr *
                     (1 + tf.cos(np.pi * progress_fraction)))

  elif config['lr_decay_method'] == 'STEPWISE':
    # Divide LR by 10 at 1/2, 2/3, and 5/6 of total epochs.
    total_steps = (config['train_epochs'] * num_train_images /
                   config['batch_size'])
    boundaries = [int(0.5 * total_steps),
                  int(0.667 * total_steps),
                  int(0.833 * total_steps)]
    values = [1.0 * base_lr,
              0.1 * base_lr,
              0.01 * base_lr,
              0.001 * base_lr]
    learning_rate = tf.train.piecewise_constant(
        global_step, boundaries, values)

  else:
    raise ValueError('invalid lr_decay_method')

  # Set LR to 0 for step 0 to initialize the weights without training
  return tf.where(tf.equal(global_step, 0), 0.0, learning_rate)
241 |
242 |
def build_module(spec, inputs, channels, is_training):
  """Build a custom module using a proposed model spec.

  Each interior vertex sums its incoming tensors, then applies the op named in
  spec.ops. Tensors from other interior vertices are truncated to the vertex's
  channel count; the module input is passed through a projection. The output
  vertex concatenates the interior vertices that feed it along the channel
  axis. If the input is also wired directly to the output, a projection of
  the input is added. Per-vertex channel counts come from
  compute_vertex_channels.

  Args:
    spec: ModelSpec object.
    inputs: input Tensor to this module.
    channels: output channel count.
    is_training: bool for whether this model is training.

  Returns:
    output Tensor from built module (named 'output').

  Raises:
    ValueError: if spec.data_format is not 'channels_last' or
      'channels_first'.
  """
  fmt = spec.data_format
  adjacency = spec.matrix
  num_vertices = np.shape(adjacency)[0]
  sink = num_vertices - 1

  if fmt == 'channels_last':
    channel_axis = 3
  elif fmt == 'channels_first':
    channel_axis = 1
  else:
    raise ValueError('invalid data_format')

  # widths[v] = number of output channels of vertex v
  widths = compute_vertex_channels(
      inputs.get_shape()[channel_axis].value, channels, adjacency)

  # Vertex outputs, built from the input forward; index 0 is the module input.
  vertex_outputs = [tf.identity(inputs, name='input')]
  source = vertex_outputs[0]

  sink_fan_in = []
  for vertex in range(1, sink):
    width = widths[vertex]
    with tf.variable_scope('vertex_{}'.format(vertex)):
      # Interior in-edges first (truncated if needed), then the projected input.
      summands = []
      for src in range(1, vertex):
        if adjacency[src, vertex]:
          summands.append(truncate(vertex_outputs[src], width, fmt))
      if adjacency[0, vertex]:
        summands.append(projection(source, width, is_training, fmt))

      merged = summands[0] if len(summands) == 1 else tf.add_n(summands)

      vertex_op = base_ops.OP_MAP[spec.ops[vertex]](
          is_training=is_training,
          data_format=fmt)
      result = vertex_op.build(merged, width)

    vertex_outputs.append(result)
    if adjacency[vertex, sink]:
      sink_fan_in.append(result)

  if sink_fan_in:
    if len(sink_fan_in) == 1:
      outputs = sink_fan_in[0]
    else:
      outputs = tf.concat(sink_fan_in, channel_axis)

    if adjacency[0, sink]:
      outputs += projection(source, channels, is_training, fmt)
  else:
    # No interior vertices, input directly connected to output
    assert adjacency[0, sink]
    with tf.variable_scope('output'):
      outputs = projection(source, channels, is_training, fmt)

  return tf.identity(outputs, name='output')
336 |
337 |
def projection(inputs, channels, is_training, data_format):
  """Projects `inputs` to `channels` channels with a 1x1 conv + BN + ReLU.

  This is the ResNet-style projection. Its variables live under a
  'projection' variable scope.
  """
  with tf.variable_scope('projection'):
    return base_ops.conv_bn_relu(inputs, 1, channels, is_training, data_format)
344 |
345 |
def truncate(inputs, channels, data_format):
  """Slice the inputs to channels if necessary.

  Args:
    inputs: 4-D Tensor whose channel dimension is statically known.
    channels: desired channel count; must not exceed the input channel count.
    data_format: 'channels_last' (channels on axis 3) or 'channels_first'
      (channels on axis 1).

  Returns:
    `inputs` itself if it already has `channels` channels, otherwise a slice
    keeping the first `channels` channels.

  Raises:
    ValueError: if data_format is not recognized, or if the input has fewer
      channels than requested.
  """
  # Validate explicitly (rather than via assert) so an invalid format is never
  # silently treated as channels_first when assertions are disabled.
  if data_format == 'channels_last':
    channel_axis = 3
  elif data_format == 'channels_first':
    channel_axis = 1
  else:
    raise ValueError('invalid data_format')
  input_channels = inputs.get_shape()[channel_axis].value

  if input_channels < channels:
    raise ValueError('input channel < output channels for truncate')
  if input_channels == channels:
    return inputs  # No truncation necessary

  # Truncation should only be necessary when channel division leads to
  # vertices with +1 channels. The input vertex should always be projected to
  # the minimum channel count.
  assert input_channels - channels == 1
  # Keep every other dimension whole (-1) and cut only the channel axis.
  size = [-1, -1, -1, -1]
  size[channel_axis] = channels
  return tf.slice(inputs, [0, 0, 0, 0], size)
367 |
368 |
def compute_vertex_channels(input_channels, output_channels, matrix):
  """Computes the number of channels at every vertex.

  The output channel count is split evenly among the interior vertices that
  feed the output vertex directly. When the split is uneven, the earliest of
  those vertices (in index order) each receive one extra channel until the
  remainder is used up. Every other interior vertex gets the maximum channel
  count among the interior vertices it feeds.

  Args:
    input_channels: input channel count.
    output_channels: output channel count.
    matrix: adjacency matrix for the module (pruned by model_spec).

  Returns:
    list of channel counts, in order of the vertices.
  """
  n = np.shape(matrix)[0]
  sink = n - 1

  widths = [0] * n
  widths[0] = input_channels
  widths[sink] = output_channels

  if n == 2:
    # Only input and output vertices: nothing interior to size.
    return widths

  # Column sums over rows 1..n-1 give each vertex's in-degree, ignoring edges
  # that leave the input vertex.
  in_degree = np.sum(matrix[1:], axis=0)
  share = output_channels // in_degree[sink]
  leftover = output_channels % in_degree[sink]

  # Vertices wired straight to the output split its channels.
  direct = [v for v in range(1, sink) if matrix[v, sink]]
  for rank, v in enumerate(direct):
    widths[v] = share + 1 if rank < leftover else share

  # Remaining vertices take the max over their interior out-edges, walking
  # backwards so every destination is final before it is read. Vertex n - 2 is
  # skipped because it only connects to the output.
  for v in reversed(range(1, sink - 1)):
    if not matrix[v, sink]:
      for dst in range(v + 1, sink):
        if matrix[v, dst]:
          widths[v] = max(widths[v], widths[dst])
    assert widths[v] > 0

  tf.logging.info('vertex_channels: %s', str(widths))

  # Sanity check: channels never increase along an interior edge, and the
  # vertices feeding the output add up to output_channels.
  into_sink = 0
  for v in range(1, sink):
    if matrix[v, sink]:
      into_sink += widths[v]
    assert all(widths[v] >= widths[dst]
               for dst in range(v + 1, sink) if matrix[v, dst])
  # n == 2 would mean only input/output vertices, i.e. zero fan-in.
  assert into_sink == output_channels or n == 2

  return widths
435 |
436 |
def _covariance_matrix(activations):
  """Returns the unbiased sample covariance between the examples of a batch.

  Each example's activations are flattened into a row x_i, and

    C(i,j) = (x_i - mean(x_i)) dot (x_j - mean(x_j)) / (N - 1)

  where N is the number of flattened values per example. This matches the
  default behavior of np.cov() applied to the flattened rows.

  Args:
    activations: tensor activations with batch dimension first; the batch
      dimension must be statically known.

  Returns:
    [batch, batch] shape tensor for the covariance matrix.
  """
  num_examples = activations.get_shape()[0].value
  rows = tf.reshape(activations, [num_examples, -1])
  centered = rows - tf.reduce_mean(rows, axis=1, keepdims=True)

  gram = tf.matmul(centered, tf.transpose(centered))
  denominator = tf.cast(tf.shape(rows)[1], tf.float32) - 1
  return gram / denominator
461 |
462 |
--------------------------------------------------------------------------------