├── .gitignore ├── images ├── architecture.png └── param_time_acc.png ├── nasbench ├── __init__.py ├── lib │ ├── __init__.py │ ├── model_metrics.proto │ ├── config.py │ ├── graph_util.py │ ├── base_ops.py │ ├── model_spec.py │ ├── cifar.py │ ├── model_metrics_pb2.py │ ├── training_time.py │ ├── evaluate.py │ └── model_builder.py ├── scripts │ ├── __init__.py │ ├── augment_model.py │ ├── generate_cifar10_tfrecords.py │ ├── generate_graphs.py │ └── run_evaluation.py ├── tests │ ├── model_builder_test.py │ ├── model_spec_test.py │ ├── run_evaluation_test.py │ └── graph_util_test.py └── api.py ├── setup.py ├── CONTRIBUTING.md ├── example.py ├── README.md └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | nasbench.egg-info/ 3 | *.pyc 4 | -------------------------------------------------------------------------------- /images/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/nasbench/HEAD/images/architecture.png -------------------------------------------------------------------------------- /images/param_time_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-research/nasbench/HEAD/images/param_time_acc.png -------------------------------------------------------------------------------- /nasbench/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /nasbench/lib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /nasbench/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Setup the nasbench library. 16 | 17 | This file is automatically run as part of `pip install -e .` 18 | """ 19 | 20 | import setuptools 21 | 22 | setuptools.setup( 23 | name='nasbench', 24 | version='1.0', 25 | packages=setuptools.find_packages(), 26 | install_requires=[ 27 | 'tensorflow>=1.12.0', 28 | ] 29 | ) 30 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | ## Pull Requests 4 | 5 | Please send in fixes or feature additions through Pull Requests.
6 | 7 | ## Contributor License Agreement 8 | 9 | Contributions to this project must be accompanied by a Contributor License 10 | Agreement. You (or your employer) retain the copyright to your contribution, 11 | this simply gives us permission to use and redistribute your contributions as 12 | part of the project. Head over to https://cla.developers.google.com/ to see 13 | your current agreements on file or to sign a new one. 14 | 15 | You generally only need to submit a CLA once, so if you've already submitted one 16 | (even if it was for a different project), you probably don't need to do it 17 | again. 18 | 19 | ## Code reviews 20 | 21 | All submissions, including submissions by project members, require review. We 22 | use GitHub pull requests for this purpose. Consult 23 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 24 | information on using pull requests. 25 | -------------------------------------------------------------------------------- /nasbench/lib/model_metrics.proto: -------------------------------------------------------------------------------- 1 | // Metrics stored per evaluation of each ModelSpec. 2 | // NOTE: this file is for reference only, changes to this file will not affect 3 | // the code unless you compile the proto using protoc, which can be installed 4 | // from https://github.com/protocolbuffers/protobuf/releases. 5 | syntax = "proto2"; 6 | 7 | package nasbench; 8 | 9 | message ModelMetrics { 10 | // Metrics that are evaluated at each checkpoint. Each ModelMetrics will 11 | // contain multiple EvaluationData messages evaluated at various points during 12 | // training, including the initialization before any steps are taken. 13 | repeated EvaluationData evaluation_data = 1; 14 | 15 | // Other fixed metrics (do not change over training) go here. 16 | 17 | // Parameter count of all trainable variables.
18 | optional int32 trainable_parameters = 2; 19 | 20 | // Total time for all training and evaluation (mostly used for diagnostic 21 | // purposes). 22 | optional double total_time = 3; 23 | } 24 | 25 | message EvaluationData { 26 | // Current epoch at the time of this evaluation. 27 | optional double current_epoch = 1; 28 | 29 | // Training time in seconds up to this point. Does not include evaluation 30 | // time. 31 | optional double training_time = 2; 32 | 33 | // Accuracy on a fixed set of 10,000 images from the train set. 34 | optional double train_accuracy = 3; 35 | 36 | // Accuracy on a held-out validation set of 10,000 images. 37 | optional double validation_accuracy = 4; 38 | 39 | // Accuracy on the test set of 10,000 images. 40 | optional double test_accuracy = 5; 41 | 42 | // Location of checkpoint file. Note: checkpoint_path will look like 43 | // /path/to/model_dir/model.ckpt-1234 but the actual checkpoint files may have 44 | // an extra ".data", ".index", ".meta" suffix. For purposes of loading a 45 | // checkpoint file in TensorFlow, the path without the suffix is sufficient. 46 | // This field may be left blank because the checkpoint can be programmatically 47 | // generated from the model specifications. 48 | optional string checkpoint_path = 6; 49 | 50 | // Additional sample metrics like gradient norms and covariance are too large 51 | // to store in file, so they need to be queried along with the checkpoints 52 | // from GCS directly. 53 | } 54 | 55 | 56 | -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Runnable example, as shown in the README.md."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from nasbench import api

# Replace this string with the path to the downloaded nasbench.tfrecord before
# executing.
NASBENCH_TFRECORD = '/path/to/nasbench.tfrecord'

INPUT = 'input'
OUTPUT = 'output'
CONV1X1 = 'conv1x1-bn-relu'
CONV3X3 = 'conv3x3-bn-relu'
MAXPOOL3X3 = 'maxpool3x3'


def main(argv):
  """Walks through the main NASBench query APIs on one Inception-like cell.

  Loads the dataset from NASBENCH_TFRECORD, queries one model spec and prints
  the budget counters, prints every stored metric for the same spec, and
  finally prints the fixed metrics of the first model hash in the dataset.

  Args:
    argv: command-line arguments forwarded by `app.run`; unused.
  """
  del argv  # Unused

  # Load the data from file (this will take some time)
  nasbench = api.NASBench(NASBENCH_TFRECORD)

  # Create an Inception-like module (5x5 convolution replaced with two 3x3
  # convolutions).
  model_spec = api.ModelSpec(
      # Adjacency matrix of the module
      matrix=[[0, 1, 1, 1, 0, 1, 0],    # input layer
              [0, 0, 0, 0, 0, 0, 1],    # 1x1 conv
              [0, 0, 0, 0, 0, 0, 1],    # 3x3 conv
              [0, 0, 0, 0, 1, 0, 0],    # 5x5 conv (replaced by two 3x3's)
              [0, 0, 0, 0, 0, 0, 1],    # 5x5 conv (replaced by two 3x3's)
              [0, 0, 0, 0, 0, 0, 1],    # 3x3 max-pool
              [0, 0, 0, 0, 0, 0, 0]],   # output layer
      # Operations at the vertices of the module, matches order of matrix
      ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT])

  # Query this model from dataset, returns a dictionary containing the metrics
  # associated with this model.
  print('Querying an Inception-like model.')
  data = nasbench.query(model_spec)
  print(data)
  print(nasbench.get_budget_counters())   # prints (total time, total epochs)

  # Get all metrics (all epoch lengths, all repeats) associated with this
  # model_spec. This should be used for dataset analysis and NOT for
  # benchmarking algorithms (does not increment budget counters).
  print('\nGetting all metrics for the same Inception-like model.')
  fixed_metrics, computed_metrics = nasbench.get_metrics_from_spec(model_spec)
  print(fixed_metrics)
  for epochs in nasbench.valid_epochs:
    for repeat_index, data_point in enumerate(computed_metrics[epochs]):
      print('Epochs trained %d, repeat number: %d' % (epochs, repeat_index + 1))
      print(data_point)

  # Iterate through unique models in the dataset. Models are uniquely
  # identified by a hash.
  print('\nIterating over unique models in the dataset.')
  for unique_hash in nasbench.hash_iterator():
    fixed_metrics, computed_metrics = nasbench.get_metrics_from_hash(
        unique_hash)
    print(fixed_metrics)

    # For demo purposes, break here instead of iterating through whole set.
    break


# If you are passing command line flags to modify the default config values, you
# must use app.run(main)
if __name__ == '__main__':
  app.run(main)

# --------------------------------------------------------------------------------
# /nasbench/lib/config.py:
# --------------------------------------------------------------------------------
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Configuration flags."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import flags

FLAGS = flags.FLAGS

# Data flags (only required for generating the dataset)
flags.DEFINE_list(
    'train_data_files', [],
    'Training data files in TFRecord format. Multiple files can be passed in a'
    ' comma-separated list. The first file in the list will be used for'
    ' computing the training error.')
flags.DEFINE_string(
    'valid_data_file', '', 'Validation data in TFRecord format.')
flags.DEFINE_string(
    'test_data_file', '', 'Testing data in TFRecord format.')
flags.DEFINE_string(
    'sample_data_file', '', 'Sampled batch data in TFRecord format.')
flags.DEFINE_string(
    'data_format', 'channels_last',
    'Data format, one of [channels_last, channels_first] for NHWC and NCHW'
    ' tensor formats respectively.')
flags.DEFINE_integer(
    'num_labels', 10, 'Number of input class labels.')

# Search space parameters.
flags.DEFINE_integer(
    'module_vertices', 7,
    'Number of vertices in module matrix, including input and output.')
flags.DEFINE_integer(
    'max_edges', 9,
    'Maximum number of edges in the module matrix.')
flags.DEFINE_list(
    'available_ops', ['conv3x3-bn-relu', 'conv1x1-bn-relu', 'maxpool3x3'],
    'Available op labels, see base_ops.py for full list of ops.')

# Model hyperparameters. The default values are exactly what is used during the
# exhaustive evaluation of all models.
flags.DEFINE_integer(
    'stem_filter_size', 128, 'Filter size after stem convolutions.')
flags.DEFINE_integer(
    'num_stacks', 3, 'Number of stacks of modules.')
flags.DEFINE_integer(
    'num_modules_per_stack', 3, 'Number of modules per stack.')
flags.DEFINE_integer(
    'batch_size', 256, 'Training batch size.')
flags.DEFINE_integer(
    'train_epochs', 108,
    'Maximum training epochs. If --train_seconds is reached first, training'
    ' may not reach --train_epochs.')
flags.DEFINE_float(
    'train_seconds', 4.0 * 60 * 60,
    'Maximum training seconds. If --train_epochs is reached first, training'
    ' may not reach --train_seconds. Used as safeguard against stalled jobs.'
    ' If train_seconds is 0.0, no time limit will be used.')
flags.DEFINE_float(
    'learning_rate', 0.1,
    'Base learning rate. Linearly scaled by --tpu_num_shards.')
flags.DEFINE_string(
    'lr_decay_method', 'COSINE_BY_STEP',
    '[COSINE_BY_TIME, COSINE_BY_STEP, STEPWISE], see model_builder.py for full'
    ' list of decay methods.')
flags.DEFINE_float(
    'momentum', 0.9, 'Momentum.')
flags.DEFINE_float(
    'weight_decay', 1e-4, 'L2 regularization weight.')
flags.DEFINE_integer(
    'max_attempts', 5,
    'Maximum number of times to try training and evaluating an individual'
    ' model before aborting.')
# DEFINE_list splits its value on commas, so the help must show the
# comma-separated form; bracketed list syntax would yield '[0.25' etc.
flags.DEFINE_list(
    'intermediate_evaluations', ['0.5'],
    'Intermediate evaluations relative to --train_epochs. For example, to'
    ' evaluate the model at 1/4, 1/2, 3/4 of the total epochs, use'
    ' --intermediate_evaluations=0.25,0.5,0.75. An evaluation is always done'
    ' at the start and end of training.')
flags.DEFINE_integer(
    'num_repeats', 3,
    'Number of repeats evaluated for each model in the space.')

# TPU flags
flags.DEFINE_bool(
    'use_tpu', True, 'Use TPUs for train and evaluation.')
flags.DEFINE_integer(
    'tpu_iterations_per_loop', 100, 'Iterations per loop of TPU execution.')
flags.DEFINE_integer(
    'tpu_num_shards', 2,
    'Number of TPU shards, a single TPU chip has 2 shards.')


def build_config():
  """Build config from flags defined in this module.

  Only flags registered by this module are included; flags defined by other
  modules (e.g. a script's own --model_dir) are not part of the result.

  Returns:
    dict mapping each flag name defined in this module to its current `.value`.
  """
  return {
      flag.name: flag.value
      for flag in FLAGS.flags_by_module_dict()[__name__]
  }

# --------------------------------------------------------------------------------
# /nasbench/scripts/augment_model.py:
# --------------------------------------------------------------------------------
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Augments one model with longer training and evaluates on test set."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from nasbench.lib import config as _config
from nasbench.lib import evaluate
from nasbench.lib import model_spec
import numpy as np
import tensorflow as tf  # Used for app, flags, logging

tf.flags.DEFINE_string('model_dir', '', 'model directory')
FLAGS = tf.flags.FLAGS


def _set_network_shape(config, num_stacks, num_modules_per_stack,
                       stem_filter_size):
  """Overwrites the network-level size hyperparameters in `config` in place."""
  config['num_stacks'] = num_stacks
  config['num_modules_per_stack'] = num_modules_per_stack
  config['stem_filter_size'] = stem_filter_size


def create_resnet20_spec(config):
  """Construct a ResNet-20-like spec.

  The main difference is that there is an extra projection layer before the
  conv3x3 whereas the original ResNet doesn't have this. This increases the
  parameter count of this version slightly.

  Args:
    config: config dict created by config.py. Modified in place: the
      'num_stacks', 'num_modules_per_stack' and 'stem_filter_size' entries are
      overwritten to match this architecture.

  Returns:
    ModelSpec object.
  """
  spec = model_spec.ModelSpec(
      np.array([[0, 1, 0, 1],
                [0, 0, 1, 0],
                [0, 0, 0, 1],
                [0, 0, 0, 0]]),
      ['input', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output'])
  _set_network_shape(config, num_stacks=3, num_modules_per_stack=3,
                     stem_filter_size=16)
  return spec


def create_resnet50_spec(config):
  """Construct a ResNet-50-like spec.

  The main difference is that there is an extra projection layer before the
  conv1x1 whereas the original ResNet doesn't have this. This increases the
  parameter count of this version slightly.

  Args:
    config: config dict created by config.py. Modified in place: the
      'num_stacks', 'num_modules_per_stack' and 'stem_filter_size' entries are
      overwritten to match this architecture.

  Returns:
    ModelSpec object.
  """
  spec = model_spec.ModelSpec(
      np.array([[0, 1, 1],
                [0, 0, 1],
                [0, 0, 0]]),
      ['input', 'bottleneck3x3', 'output'])
  _set_network_shape(config, num_stacks=3, num_modules_per_stack=6,
                     stem_filter_size=128)
  return spec


def create_inception_resnet_spec(config):
  """Construct an Inception-ResNet like spec.

  This spec is very similar to the InceptionV2 module with an added
  residual connection except that there is an extra projection in front of the
  max pool. The overall network filter counts and module counts do not match
  the actual source model.

  Args:
    config: config dict created by config.py. Modified in place: the
      'num_stacks', 'num_modules_per_stack' and 'stem_filter_size' entries are
      overwritten to match this architecture.

  Returns:
    ModelSpec object.
  """
  spec = model_spec.ModelSpec(
      np.array([[0, 1, 1, 1, 0, 1, 1],
                [0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 0, 0, 0]]),
      ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu',
       'conv3x3-bn-relu', 'maxpool3x3', 'output'])
  _set_network_shape(config, num_stacks=3, num_modules_per_stack=3,
                     stem_filter_size=128)
  return spec


def create_best_nasbench_spec(config):
  """Construct the best spec in the NASBench dataset w.r.t. mean test accuracy.

  Args:
    config: config dict created by config.py. Modified in place: the
      'num_stacks', 'num_modules_per_stack' and 'stem_filter_size' entries are
      overwritten to match this architecture.

  Returns:
    ModelSpec object.
  """
  spec = model_spec.ModelSpec(
      np.array([[0, 1, 1, 0, 0, 1, 1],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 1],
                [0, 0, 0, 0, 0, 0, 0]]),
      ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'maxpool3x3',
       'conv3x3-bn-relu', 'conv3x3-bn-relu', 'output'])
  _set_network_shape(config, num_stacks=3, num_modules_per_stack=3,
                     stem_filter_size=128)
  return spec


def main(_):
  """Retrains the best NASBench spec with a longer schedule and logs results.

  Args:
    _: positional argv forwarded by `tf.app.run`; unused.
  """
  config = _config.build_config()

  # The default settings in config are exactly what was used to generate the
  # dataset of models. However, given more epochs and a different learning rate
  # schedule, it is possible to get higher accuracy.
  config['train_epochs'] = 200
  config['lr_decay_method'] = 'STEPWISE'
  # Disable training time limit. 0.0 is the sentinel documented by the
  # --train_seconds flag in config.py ("If train_seconds is 0.0, no time limit
  # will be used"); the previous -1 relied on undocumented behavior.
  config['train_seconds'] = 0.0
  spec = create_best_nasbench_spec(config)

  data = evaluate.augment_and_evaluate(spec, config, FLAGS.model_dir)
  tf.logging.info(data)


if __name__ == '__main__':
  tf.app.run(main)

# --------------------------------------------------------------------------------
# /nasbench/tests/model_builder_test.py:
# --------------------------------------------------------------------------------
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lib/model_builder.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from nasbench.lib import model_builder 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | class ModelBuilderTest(tf.test.TestCase): 27 | 28 | def test_compute_vertex_channels_linear(self): 29 | """Tests modules with no branching.""" 30 | matrix1 = np.array([[0, 1, 0, 0], 31 | [0, 0, 1, 0], 32 | [0, 0, 0, 1], 33 | [0, 0, 0, 0]]) 34 | vc1 = model_builder.compute_vertex_channels(8, 8, matrix1) 35 | assert vc1 == [8, 8, 8, 8] 36 | 37 | vc2 = model_builder.compute_vertex_channels(8, 16, matrix1) 38 | assert vc2 == [8, 16, 16, 16] 39 | 40 | vc3 = model_builder.compute_vertex_channels(16, 8, matrix1) 41 | assert vc3 == [16, 8, 8, 8] 42 | 43 | matrix2 = np.array([[0, 1], 44 | [0, 0]]) 45 | vc4 = model_builder.compute_vertex_channels(1, 1, matrix2) 46 | assert vc4 == [1, 1] 47 | 48 | vc5 = model_builder.compute_vertex_channels(1, 5, matrix2) 49 | assert vc5 == [1, 5] 50 | 51 | vc5 = model_builder.compute_vertex_channels(5, 1, matrix2) 52 | assert vc5 == [5, 1] 53 | 54 | def test_compute_vertex_channels_no_output_branch(self): 55 | """Tests modules that branch but not at the output vertex.""" 56 | matrix1 = np.array([[0, 1, 1, 0, 0], 57 | [0, 0, 0, 1, 0], 58 | [0, 0, 0, 1, 0], 59 | [0, 0, 0, 0, 1], 60 | [0, 0, 0, 0, 0]]) 61 | vc1 = model_builder.compute_vertex_channels(8, 8, matrix1) 62 | assert vc1 == [8, 8, 8, 8, 8] 63 | 64 | vc2 = model_builder.compute_vertex_channels(8, 16, matrix1) 65 | assert vc2 == [8, 16, 16, 16, 16] 66 | 67 | vc3 = model_builder.compute_vertex_channels(16, 8, matrix1) 68 | assert vc3 == [16, 8, 8, 8, 8] 69 | 70 | def test_compute_vertex_channels_output_branching(self): 71 | """Tests modules that branch at output.""" 72 | matrix1 = 
np.array([[0, 1, 1, 0], 73 | [0, 0, 0, 1], 74 | [0, 0, 0, 1], 75 | [0, 0, 0, 0]]) 76 | vc1 = model_builder.compute_vertex_channels(8, 8, matrix1) 77 | assert vc1 == [8, 4, 4, 8] 78 | 79 | vc2 = model_builder.compute_vertex_channels(8, 16, matrix1) 80 | assert vc2 == [8, 8, 8, 16] 81 | 82 | vc3 = model_builder.compute_vertex_channels(16, 8, matrix1) 83 | assert vc3 == [16, 4, 4, 8] 84 | 85 | vc4 = model_builder.compute_vertex_channels(8, 15, matrix1) 86 | assert vc4 == [8, 8, 7, 15] 87 | 88 | matrix2 = np.array([[0, 1, 1, 1, 0], 89 | [0, 0, 0, 0, 1], 90 | [0, 0, 0, 0, 1], 91 | [0, 0, 0, 0, 1], 92 | [0, 0, 0, 0, 0]]) 93 | vc5 = model_builder.compute_vertex_channels(8, 8, matrix2) 94 | assert vc5 == [8, 3, 3, 2, 8] 95 | 96 | vc6 = model_builder.compute_vertex_channels(8, 15, matrix2) 97 | assert vc6 == [8, 5, 5, 5, 15] 98 | 99 | def test_compute_vertex_channels_max(self): 100 | """Tests modules where some vertices take the max channels of neighbors.""" 101 | matrix1 = np.array([[0, 1, 0, 0, 0], 102 | [0, 0, 1, 1, 0], 103 | [0, 0, 0, 0, 1], 104 | [0, 0, 0, 0, 1], 105 | [0, 0, 0, 0, 0]]) 106 | vc1 = model_builder.compute_vertex_channels(8, 8, matrix1) 107 | assert vc1 == [8, 4, 4, 4, 8] 108 | 109 | vc2 = model_builder.compute_vertex_channels(8, 9, matrix1) 110 | assert vc2 == [8, 5, 5, 4, 9] 111 | 112 | matrix2 = np.array([[0, 1, 0, 1, 0], 113 | [0, 0, 1, 0, 1], 114 | [0, 0, 0, 1, 0], 115 | [0, 0, 0, 0, 1], 116 | [0, 0, 0, 0, 0]]) 117 | 118 | vc3 = model_builder.compute_vertex_channels(8, 8, matrix2) 119 | assert vc3 == [8, 4, 4, 4, 8] 120 | 121 | vc4 = model_builder.compute_vertex_channels(8, 15, matrix2) 122 | assert vc4 == [8, 8, 7, 7, 15] 123 | 124 | def test_covariance_matrix_against_numpy(self): 125 | """Tests that the TF implementation of covariance matrix matchs np.cov.""" 126 | 127 | # Randomized test 100 times 128 | for _ in range(100): 129 | batch = np.random.randint(50, 150) 130 | features = np.random.randint(500, 1500) 131 | matrix = 
np.random.random((batch, features)) 132 | 133 | tf_matrix = tf.constant(matrix, dtype=tf.float32) 134 | tf_cov_tensor = model_builder._covariance_matrix(tf_matrix) 135 | 136 | with tf.Session() as sess: 137 | tf_cov = sess.run(tf_cov_tensor) 138 | 139 | np_cov = np.cov(matrix) 140 | np.testing.assert_array_almost_equal(tf_cov, np_cov) 141 | 142 | 143 | if __name__ == '__main__': 144 | tf.test.main() 145 | -------------------------------------------------------------------------------- /nasbench/scripts/generate_cifar10_tfrecords.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Read CIFAR-10 data from pickled numpy arrays and writes TFRecords. 16 | 17 | Generates tf.train.Example protos and writes them to TFRecord files from the 18 | python version of the CIFAR-10 dataset downloaded from 19 | https://www.cs.toronto.edu/~kriz/cifar.html. 
20 | 21 | Based on script from 22 | https://github.com/tensorflow/models/blob/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py 23 | 24 | To run: 25 | python generate_cifar10_tfrecords.py --data_dir=/tmp/cifar-tfrecord 26 | """ 27 | 28 | from __future__ import absolute_import 29 | from __future__ import division 30 | from __future__ import print_function 31 | 32 | import argparse 33 | import os 34 | import sys 35 | 36 | import tarfile 37 | from six.moves import cPickle as pickle 38 | import tensorflow as tf 39 | 40 | CIFAR_FILENAME = 'cifar-10-python.tar.gz' 41 | CIFAR_DOWNLOAD_URL = 'https://www.cs.toronto.edu/~kriz/' + CIFAR_FILENAME 42 | CIFAR_LOCAL_FOLDER = 'cifar-10-batches-py' 43 | 44 | 45 | def download_and_extract(data_dir): 46 | # download CIFAR-10 if not already downloaded. 47 | tf.contrib.learn.datasets.base.maybe_download(CIFAR_FILENAME, data_dir, 48 | CIFAR_DOWNLOAD_URL) 49 | tarfile.open(os.path.join(data_dir, CIFAR_FILENAME), 50 | 'r:gz').extractall(data_dir) 51 | 52 | 53 | def _int64_feature(value): 54 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 55 | 56 | 57 | def _bytes_feature(value): 58 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 59 | 60 | 61 | def _get_file_names(): 62 | """Returns the file names expected to exist in the input_dir.""" 63 | file_names = {} 64 | for i in range(1, 5): 65 | file_names['train_%d' % i] = 'data_batch_%d' % i 66 | file_names['validation'] = 'data_batch_5' 67 | file_names['test'] = 'test_batch' 68 | return file_names 69 | 70 | 71 | def read_pickle_from_file(filename): 72 | with tf.gfile.Open(filename, 'rb') as f: 73 | if sys.version_info >= (3, 0): 74 | data_dict = pickle.load(f, encoding='bytes') 75 | else: 76 | data_dict = pickle.load(f) 77 | return data_dict 78 | 79 | 80 | def convert_to_tfrecord(input_file, output_file): 81 | """Converts a file to TFRecords.""" 82 | print('Generating %s' % output_file) 83 | with 
tf.python_io.TFRecordWriter(output_file) as record_writer: 84 | data_dict = read_pickle_from_file(input_file) 85 | data = data_dict[b'data'] 86 | labels = data_dict[b'labels'] 87 | num_entries_in_batch = len(labels) 88 | print('Converting %d images' % num_entries_in_batch) 89 | for i in range(num_entries_in_batch): 90 | example = tf.train.Example(features=tf.train.Features( 91 | feature={ 92 | 'image': _bytes_feature(data[i].tobytes()), 93 | 'label': _int64_feature(labels[i]) 94 | })) 95 | record_writer.write(example.SerializeToString()) 96 | 97 | 98 | def main(data_dir): 99 | print('Download from {} and extract.'.format(CIFAR_DOWNLOAD_URL)) 100 | download_and_extract(data_dir) 101 | file_names = _get_file_names() 102 | input_dir = os.path.join(data_dir, CIFAR_LOCAL_FOLDER) 103 | for mode, f in file_names.items(): 104 | input_file = os.path.join(input_dir, f) 105 | output_file = os.path.join(data_dir, mode + '.tfrecords') 106 | try: 107 | os.remove(output_file) 108 | except OSError: 109 | pass 110 | # Convert to tf.train.Example and write the to TFRecords. 111 | convert_to_tfrecord(input_file, output_file) 112 | 113 | # Save fixed batch of 100 examples (first 10 of each class sampled at the 114 | # front of the validation set). Ordered by label, i.e. 10 "airplane" images 115 | # followed by 10 "automobile" images... 
116 | images = [[] for _ in range(10)] 117 | num_images = 0 118 | input_file = os.path.join(input_dir, file_names['validation']) 119 | data_dict = read_pickle_from_file(input_file) 120 | data = data_dict[b'data'] 121 | labels = data_dict[b'labels'] 122 | for i in range(len(labels)): 123 | label = labels[i] 124 | if len(images[label]) < 10: 125 | images[label].append( 126 | tf.train.Example(features=tf.train.Features( 127 | feature={ 128 | 'image': _bytes_feature(data[i].tobytes()), 129 | 'label': _int64_feature(label) 130 | }))) 131 | num_images += 1 132 | if num_images == 100: 133 | break 134 | 135 | output_file = os.path.join(data_dir, 'sample.tfrecords') 136 | print('Generating %s' % output_file) 137 | with tf.python_io.TFRecordWriter(output_file) as record_writer: 138 | for label_images in images: 139 | for example in label_images: 140 | record_writer.write(example.SerializeToString()) 141 | print('Done!') 142 | 143 | 144 | if __name__ == '__main__': 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument( 147 | '--data_dir', 148 | type=str, 149 | default='', 150 | help='Directory to download and extract CIFAR-10 to.') 151 | 152 | args = parser.parse_args() 153 | main(args.data_dir) 154 | -------------------------------------------------------------------------------- /nasbench/lib/graph_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions used by generate_graph.py.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import hashlib 21 | import itertools 22 | 23 | import numpy as np 24 | 25 | 26 | def gen_is_edge_fn(bits): 27 | """Generate a boolean function for the edge connectivity. 28 | 29 | Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is 30 | [[0, A, B, D], 31 | [0, 0, C, E], 32 | [0, 0, 0, F], 33 | [0, 0, 0, 0]] 34 | 35 | Note that this function is agnostic to the actual matrix dimension due to 36 | order in which elements are filled out (column-major, starting from least 37 | significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5 38 | matrix is 39 | [[0, A, B, D, 0], 40 | [0, 0, C, E, 0], 41 | [0, 0, 0, F, 0], 42 | [0, 0, 0, 0, 0], 43 | [0, 0, 0, 0, 0]] 44 | 45 | Args: 46 | bits: integer which will be interpreted as a bit mask. 47 | 48 | Returns: 49 | vectorized function that returns True when an edge is present. 50 | """ 51 | def is_edge(x, y): 52 | """Is there an edge from x to y (0-indexed)?""" 53 | if x >= y: 54 | return 0 55 | # Map x, y to index into bit string 56 | index = x + (y * (y - 1) // 2) 57 | return (bits >> index) % 2 == 1 58 | 59 | return np.vectorize(is_edge) 60 | 61 | 62 | def is_full_dag(matrix): 63 | """Full DAG == all vertices on a path from vert 0 to (V-1). 64 | 65 | i.e. no disconnected or "hanging" vertices. 66 | 67 | It is sufficient to check for: 68 | 1) no rows of 0 except for row V-1 (only output vertex has no out-edges) 69 | 2) no cols of 0 except for col 0 (only input vertex has no in-edges) 70 | 71 | Args: 72 | matrix: V x V upper-triangular adjacency matrix 73 | 74 | Returns: 75 | True if the there are no dangling vertices. 
76 | """ 77 | shape = np.shape(matrix) 78 | 79 | rows = matrix[:shape[0]-1, :] == 0 80 | rows = np.all(rows, axis=1) # Any row with all 0 will be True 81 | rows_bad = np.any(rows) 82 | 83 | cols = matrix[:, 1:] == 0 84 | cols = np.all(cols, axis=0) # Any col with all 0 will be True 85 | cols_bad = np.any(cols) 86 | 87 | return (not rows_bad) and (not cols_bad) 88 | 89 | 90 | def num_edges(matrix): 91 | """Computes number of edges in adjacency matrix.""" 92 | return np.sum(matrix) 93 | 94 | 95 | def hash_module(matrix, labeling): 96 | """Computes a graph-invariance MD5 hash of the matrix and label pair. 97 | 98 | Args: 99 | matrix: np.ndarray square upper-triangular adjacency matrix. 100 | labeling: list of int labels of length equal to both dimensions of 101 | matrix. 102 | 103 | Returns: 104 | MD5 hash of the matrix and labeling. 105 | """ 106 | vertices = np.shape(matrix)[0] 107 | in_edges = np.sum(matrix, axis=0).tolist() 108 | out_edges = np.sum(matrix, axis=1).tolist() 109 | 110 | assert len(in_edges) == len(out_edges) == len(labeling) 111 | hashes = list(zip(out_edges, in_edges, labeling)) 112 | hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes] 113 | # Computing this up to the diameter is probably sufficient but since the 114 | # operation is fast, it is okay to repeat more times. 
115 | for _ in range(vertices): 116 | new_hashes = [] 117 | for v in range(vertices): 118 | in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]] 119 | out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]] 120 | new_hashes.append(hashlib.md5( 121 | (''.join(sorted(in_neighbors)) + '|' + 122 | ''.join(sorted(out_neighbors)) + '|' + 123 | hashes[v]).encode('utf-8')).hexdigest()) 124 | hashes = new_hashes 125 | fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest() 126 | 127 | return fingerprint 128 | 129 | 130 | def permute_graph(graph, label, permutation): 131 | """Permutes the graph and labels based on permutation. 132 | 133 | Args: 134 | graph: np.ndarray adjacency matrix. 135 | label: list of labels of same length as graph dimensions. 136 | permutation: a permutation list of ints of same length as graph dimensions. 137 | 138 | Returns: 139 | np.ndarray where vertex permutation[v] is vertex v from the original graph 140 | """ 141 | # vertex permutation[v] in new graph is vertex v in the old graph 142 | forward_perm = zip(permutation, list(range(len(permutation)))) 143 | inverse_perm = [x[1] for x in sorted(forward_perm)] 144 | edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1 145 | new_matrix = np.fromfunction(np.vectorize(edge_fn), 146 | (len(label), len(label)), 147 | dtype=np.int8) 148 | new_label = [label[inverse_perm[i]] for i in range(len(label))] 149 | return new_matrix, new_label 150 | 151 | 152 | def is_isomorphic(graph1, graph2): 153 | """Exhaustively checks if 2 graphs are isomorphic.""" 154 | matrix1, label1 = np.array(graph1[0]), graph1[1] 155 | matrix2, label2 = np.array(graph2[0]), graph2[1] 156 | assert np.shape(matrix1) == np.shape(matrix2) 157 | assert len(label1) == len(label2) 158 | 159 | vertices = np.shape(matrix1)[0] 160 | # Note: input and output in our constrained graphs always map to themselves 161 | # but this script does not enforce that. 
162 | for perm in itertools.permutations(range(0, vertices)): 163 | pmatrix1, plabel1 = permute_graph(matrix1, label1, perm) 164 | if np.array_equal(pmatrix1, matrix2) and plabel1 == label2: 165 | return True 166 | 167 | return False 168 | -------------------------------------------------------------------------------- /nasbench/lib/base_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Base operations used by the modules in this search space.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import abc 22 | 23 | import tensorflow as tf 24 | 25 | # Currently, only channels_last is well supported. 
26 | VALID_DATA_FORMATS = frozenset(['channels_last', 'channels_first']) 27 | MIN_FILTERS = 8 28 | BN_MOMENTUM = 0.997 29 | BN_EPSILON = 1e-5 30 | 31 | 32 | def conv_bn_relu(inputs, conv_size, conv_filters, is_training, data_format): 33 | """Convolution followed by batch norm and ReLU.""" 34 | if data_format == 'channels_last': 35 | axis = 3 36 | elif data_format == 'channels_first': 37 | axis = 1 38 | else: 39 | raise ValueError('invalid data_format') 40 | 41 | net = tf.layers.conv2d( 42 | inputs=inputs, 43 | filters=conv_filters, 44 | kernel_size=conv_size, 45 | strides=(1, 1), 46 | use_bias=False, 47 | kernel_initializer=tf.variance_scaling_initializer(), 48 | padding='same', 49 | data_format=data_format) 50 | 51 | net = tf.layers.batch_normalization( 52 | inputs=net, 53 | axis=axis, 54 | momentum=BN_MOMENTUM, 55 | epsilon=BN_EPSILON, 56 | training=is_training) 57 | 58 | net = tf.nn.relu(net) 59 | 60 | return net 61 | 62 | 63 | class BaseOp(object): 64 | """Abstract base operation class.""" 65 | __metaclass__ = abc.ABCMeta 66 | 67 | def __init__(self, is_training, data_format='channels_last'): 68 | self.is_training = is_training 69 | if data_format.lower() not in VALID_DATA_FORMATS: 70 | raise ValueError('invalid data_format') 71 | self.data_format = data_format.lower() 72 | 73 | @abc.abstractmethod 74 | def build(self, inputs, channels): 75 | """Builds the operation with input tensors and returns an output tensor. 76 | 77 | Args: 78 | inputs: a 4-D Tensor. 79 | channels: int number of output channels of operation. The operation may 80 | choose to ignore this parameter. 81 | 82 | Returns: 83 | a 4-D Tensor with the same data format. 
84 | """ 85 | pass 86 | 87 | 88 | class Identity(BaseOp): 89 | """Identity operation (ignores channels).""" 90 | 91 | def build(self, inputs, channels): 92 | del channels # Unused 93 | return tf.identity(inputs, name='identity') 94 | 95 | 96 | class Conv3x3BnRelu(BaseOp): 97 | """3x3 convolution with batch norm and ReLU activation.""" 98 | 99 | def build(self, inputs, channels): 100 | with tf.variable_scope('Conv3x3-BN-ReLU'): 101 | net = conv_bn_relu( 102 | inputs, 3, channels, self.is_training, self.data_format) 103 | 104 | return net 105 | 106 | 107 | class Conv1x1BnRelu(BaseOp): 108 | """1x1 convolution with batch norm and ReLU activation.""" 109 | 110 | def build(self, inputs, channels): 111 | with tf.variable_scope('Conv1x1-BN-ReLU'): 112 | net = conv_bn_relu( 113 | inputs, 1, channels, self.is_training, self.data_format) 114 | 115 | return net 116 | 117 | 118 | class MaxPool3x3(BaseOp): 119 | """3x3 max pool with no subsampling.""" 120 | 121 | def build(self, inputs, channels): 122 | del channels # Unused 123 | with tf.variable_scope('MaxPool3x3'): 124 | net = tf.layers.max_pooling2d( 125 | inputs=inputs, 126 | pool_size=(3, 3), 127 | strides=(1, 1), 128 | padding='same', 129 | data_format=self.data_format) 130 | 131 | return net 132 | 133 | 134 | class BottleneckConv3x3(BaseOp): 135 | """[1x1(/4)]+3x3+[1x1(*4)] conv. Uses BN + ReLU post-activation.""" 136 | # TODO(chrisying): verify this block can reproduce results of ResNet-50. 137 | 138 | def build(self, inputs, channels): 139 | with tf.variable_scope('BottleneckConv3x3'): 140 | net = conv_bn_relu( 141 | inputs, 1, channels // 4, self.is_training, self.data_format) 142 | net = conv_bn_relu( 143 | net, 3, channels // 4, self.is_training, self.data_format) 144 | net = conv_bn_relu( 145 | net, 1, channels, self.is_training, self.data_format) 146 | 147 | return net 148 | 149 | 150 | class BottleneckConv5x5(BaseOp): 151 | """[1x1(/4)]+5x5+[1x1(*4)] conv. 
Uses BN + ReLU post-activation.""" 152 | 153 | def build(self, inputs, channels): 154 | with tf.variable_scope('BottleneckConv5x5'): 155 | net = conv_bn_relu( 156 | inputs, 1, channels // 4, self.is_training, self.data_format) 157 | net = conv_bn_relu( 158 | net, 5, channels // 4, self.is_training, self.data_format) 159 | net = conv_bn_relu( 160 | net, 1, channels, self.is_training, self.data_format) 161 | 162 | return net 163 | 164 | 165 | class MaxPool3x3Conv1x1(BaseOp): 166 | """3x3 max pool with no subsampling followed by 1x1 for rescaling.""" 167 | 168 | def build(self, inputs, channels): 169 | with tf.variable_scope('MaxPool3x3-Conv1x1'): 170 | net = tf.layers.max_pooling2d( 171 | inputs=inputs, 172 | pool_size=(3, 3), 173 | strides=(1, 1), 174 | padding='same', 175 | data_format=self.data_format) 176 | 177 | net = conv_bn_relu(net, 1, channels, self.is_training, self.data_format) 178 | 179 | return net 180 | 181 | 182 | # Commas should not be used in op names 183 | OP_MAP = { 184 | 'identity': Identity, 185 | 'conv3x3-bn-relu': Conv3x3BnRelu, 186 | 'conv1x1-bn-relu': Conv1x1BnRelu, 187 | 'maxpool3x3': MaxPool3x3, 188 | 'bottleneck3x3': BottleneckConv3x3, 189 | 'bottleneck5x5': BottleneckConv5x5, 190 | 'maxpool3x3-conv1x1': MaxPool3x3Conv1x1, 191 | } 192 | -------------------------------------------------------------------------------- /nasbench/lib/model_spec.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Model specification for module connectivity individuals. 16 | 17 | This module handles pruning the unused parts of the computation graph but should 18 | avoid creating any TensorFlow models (this is done inside model_builder.py). 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import copy 26 | 27 | from nasbench.lib import graph_util 28 | import numpy as np 29 | 30 | # Graphviz is optional and only required for visualization. 31 | try: 32 | import graphviz # pylint: disable=g-import-not-at-top 33 | except ImportError: 34 | pass 35 | 36 | 37 | class ModelSpec(object): 38 | """Model specification given adjacency matrix and labeling.""" 39 | 40 | def __init__(self, matrix, ops, data_format='channels_last'): 41 | """Initialize the module spec. 42 | 43 | Args: 44 | matrix: ndarray or nested list with shape [V, V] for the adjacency matrix. 45 | ops: V-length list of labels for the base ops used. The first and last 46 | elements are ignored because they are the input and output vertices 47 | which have no operations. The elements are retained to keep consistent 48 | indexing. 49 | data_format: channels_last or channels_first. 
50 | 51 | Raises: 52 | ValueError: invalid matrix or ops 53 | """ 54 | if not isinstance(matrix, np.ndarray): 55 | matrix = np.array(matrix) 56 | shape = np.shape(matrix) 57 | if len(shape) != 2 or shape[0] != shape[1]: 58 | raise ValueError('matrix must be square') 59 | if shape[0] != len(ops): 60 | raise ValueError('length of ops must match matrix dimensions') 61 | if not is_upper_triangular(matrix): 62 | raise ValueError('matrix must be upper triangular') 63 | 64 | # Both the original and pruned matrices are deep copies of the matrix and 65 | # ops so any changes to those after initialization are not recognized by the 66 | # spec. 67 | self.original_matrix = copy.deepcopy(matrix) 68 | self.original_ops = copy.deepcopy(ops) 69 | 70 | self.matrix = copy.deepcopy(matrix) 71 | self.ops = copy.deepcopy(ops) 72 | self.valid_spec = True 73 | self._prune() 74 | 75 | self.data_format = data_format 76 | 77 | def _prune(self): 78 | """Prune the extraneous parts of the graph. 79 | 80 | General procedure: 81 | 1) Remove parts of graph not connected to input. 82 | 2) Remove parts of graph not connected to output. 83 | 3) Reorder the vertices so that they are consecutive after steps 1 and 2. 84 | 85 | These 3 steps can be combined by deleting the rows and columns of the 86 | vertices that are not reachable from both the input and output (in reverse). 
87 | """ 88 | num_vertices = np.shape(self.original_matrix)[0] 89 | 90 | # DFS forward from input 91 | visited_from_input = set([0]) 92 | frontier = [0] 93 | while frontier: 94 | top = frontier.pop() 95 | for v in range(top + 1, num_vertices): 96 | if self.original_matrix[top, v] and v not in visited_from_input: 97 | visited_from_input.add(v) 98 | frontier.append(v) 99 | 100 | # DFS backward from output 101 | visited_from_output = set([num_vertices - 1]) 102 | frontier = [num_vertices - 1] 103 | while frontier: 104 | top = frontier.pop() 105 | for v in range(0, top): 106 | if self.original_matrix[v, top] and v not in visited_from_output: 107 | visited_from_output.add(v) 108 | frontier.append(v) 109 | 110 | # Any vertex that isn't connected to both input and output is extraneous to 111 | # the computation graph. 112 | extraneous = set(range(num_vertices)).difference( 113 | visited_from_input.intersection(visited_from_output)) 114 | 115 | # If the non-extraneous graph is less than 2 vertices, the input is not 116 | # connected to the output and the spec is invalid. 117 | if len(extraneous) > num_vertices - 2: 118 | self.matrix = None 119 | self.ops = None 120 | self.valid_spec = False 121 | return 122 | 123 | self.matrix = np.delete(self.matrix, list(extraneous), axis=0) 124 | self.matrix = np.delete(self.matrix, list(extraneous), axis=1) 125 | for index in sorted(extraneous, reverse=True): 126 | del self.ops[index] 127 | 128 | def hash_spec(self, canonical_ops): 129 | """Computes the isomorphism-invariant graph hash of this spec. 130 | 131 | Args: 132 | canonical_ops: list of operations in the canonical ordering which they 133 | were assigned (i.e. the order provided in the config['available_ops']). 134 | 135 | Returns: 136 | MD5 hash of this spec which can be used to query the dataset. 137 | """ 138 | # Invert the operations back to integer label indices used in graph gen. 
139 | labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2] 140 | return graph_util.hash_module(self.matrix, labeling) 141 | 142 | def visualize(self): 143 | """Creates a dot graph. Can be visualized in colab directly.""" 144 | num_vertices = np.shape(self.matrix)[0] 145 | g = graphviz.Digraph() 146 | g.node(str(0), 'input') 147 | for v in range(1, num_vertices - 1): 148 | g.node(str(v), self.ops[v]) 149 | g.node(str(num_vertices - 1), 'output') 150 | 151 | for src in range(num_vertices - 1): 152 | for dst in range(src + 1, num_vertices): 153 | if self.matrix[src, dst]: 154 | g.edge(str(src), str(dst)) 155 | 156 | return g 157 | 158 | 159 | def is_upper_triangular(matrix): 160 | """True if matrix is 0 on diagonal and below.""" 161 | for src in range(np.shape(matrix)[0]): 162 | for dst in range(0, src + 1): 163 | if matrix[src, dst] != 0: 164 | return False 165 | 166 | return True 167 | -------------------------------------------------------------------------------- /nasbench/tests/model_spec_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Tests for lib/model_spec.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from nasbench.lib import model_spec 22 | import numpy as np 23 | import tensorflow as tf # Used only for tf.test 24 | 25 | 26 | class ModelSpecTest(tf.test.TestCase): 27 | 28 | def test_prune_noop(self): 29 | """Tests graphs which require no pruning.""" 30 | model1 = model_spec.ModelSpec( 31 | np.array([[0, 1, 0], 32 | [0, 0, 1], 33 | [0, 0, 0]]), 34 | [0, 0, 0]) 35 | assert model1.valid_spec 36 | assert np.array_equal(model1.original_matrix, model1.matrix) 37 | assert model1.original_ops == model1.original_ops 38 | 39 | model2 = model_spec.ModelSpec( 40 | np.array([[0, 1, 1], 41 | [0, 0, 1], 42 | [0, 0, 0]]), 43 | [0, 0, 0]) 44 | assert model2.valid_spec 45 | assert np.array_equal(model2.original_matrix, model2.matrix) 46 | assert model2.original_ops == model2.ops 47 | 48 | model3 = model_spec.ModelSpec( 49 | np.array([[0, 1, 1, 0], 50 | [0, 0, 0, 1], 51 | [0, 0, 0, 1], 52 | [0, 0, 0, 0]]), 53 | [0, 0, 0, 0]) 54 | assert model3.valid_spec 55 | assert np.array_equal(model3.original_matrix, model3.matrix) 56 | assert model3.original_ops == model3.ops 57 | 58 | def test_prune_islands(self): 59 | """Tests isolated components are pruned.""" 60 | model1 = model_spec.ModelSpec( 61 | np.array([[0, 1, 0, 0], 62 | [0, 0, 0, 1], 63 | [0, 0, 0, 0], 64 | [0, 0, 0, 0]]), 65 | [1, 2, 3, 4]) 66 | assert model1.valid_spec 67 | assert np.array_equal(model1.matrix, 68 | np.array([[0, 1, 0], 69 | [0, 0, 1], 70 | [0, 0, 0]])) 71 | assert model1.ops == [1, 2, 4] 72 | 73 | model2 = model_spec.ModelSpec( 74 | np.array([[0, 1, 0, 0, 0], 75 | [0, 0, 0, 0, 1], 76 | [0, 0, 0, 1, 0], 77 | [0, 0, 0, 0, 0], 78 | [0, 0, 0, 0, 0]]), 79 | [1, 2, 3, 4, 5]) 80 | assert model2.valid_spec 81 | assert np.array_equal(model2.matrix, 82 | np.array([[0, 1, 0], 83 | [0, 0, 1], 84 | [0, 0, 0]])) 85 | assert model2.ops == [1, 2, 
5] 86 | 87 | def test_prune_dangling(self): 88 | """Tests dangling vertices are pruned.""" 89 | model1 = model_spec.ModelSpec( 90 | np.array([[0, 1, 1, 0], 91 | [0, 0, 0, 0], 92 | [0, 0, 0, 1], 93 | [0, 0, 0, 0]]), 94 | [1, 2, 3, 4]) 95 | assert model1.valid_spec 96 | assert np.array_equal(model1.matrix, 97 | np.array([[0, 1, 0], 98 | [0, 0, 1], 99 | [0, 0, 0]])) 100 | assert model1.ops == [1, 3, 4] 101 | 102 | model2 = model_spec.ModelSpec( 103 | np.array([[0, 0, 1, 0], 104 | [0, 0, 0, 1], 105 | [0, 0, 0, 1], 106 | [0, 0, 0, 0]]), 107 | [1, 2, 3, 4]) 108 | assert model2.valid_spec 109 | assert np.array_equal(model2.matrix, 110 | np.array([[0, 1, 0], 111 | [0, 0, 1], 112 | [0, 0, 0]])) 113 | assert model2.ops == [1, 3, 4] 114 | 115 | def test_prune_disconnected(self): 116 | """Tests graphs where with no input to output path are marked invalid.""" 117 | model1 = model_spec.ModelSpec( 118 | np.array([[0, 0], 119 | [0, 0]]), 120 | [0, 0]) 121 | assert not model1.valid_spec 122 | 123 | model2 = model_spec.ModelSpec( 124 | np.array([[0, 1, 0, 0], 125 | [0, 0, 0, 0], 126 | [0, 0, 0, 1], 127 | [0, 0, 0, 0]]), 128 | [1, 2, 3, 4]) 129 | assert not model2.valid_spec 130 | 131 | model3 = model_spec.ModelSpec( 132 | np.array([[0, 0, 0, 0], 133 | [0, 0, 1, 0], 134 | [0, 0, 0, 0], 135 | [0, 0, 0, 0]]), 136 | [1, 2, 3, 4]) 137 | assert not model3.valid_spec 138 | 139 | def test_is_upper_triangular(self): 140 | """Tests is_uppper_triangular correct for square graphs.""" 141 | m0 = np.array([[0, 0, 0, 0], 142 | [0, 0, 0, 0], 143 | [0, 0, 0, 0], 144 | [0, 0, 0, 0]]) 145 | assert model_spec.is_upper_triangular(m0) 146 | 147 | m1 = np.array([[0, 1, 1, 1], 148 | [0, 0, 1, 1], 149 | [0, 0, 0, 1], 150 | [0, 0, 0, 0]]) 151 | assert model_spec.is_upper_triangular(m1) 152 | 153 | m2 = np.array([[0, 1, 1, 1], 154 | [0, 0, 1, 1], 155 | [1, 0, 0, 1], 156 | [0, 0, 0, 0]]) 157 | assert not model_spec.is_upper_triangular(m2) 158 | 159 | m3 = np.array([[0, 0, 0, 0], 160 | [0, 0, 0, 0], 161 | [1, 
0, 0, 0], 162 | [0, 0, 0, 0]]) 163 | assert not model_spec.is_upper_triangular(m3) 164 | 165 | m4 = np.array([[1, 0, 0, 0], 166 | [1, 1, 0, 0], 167 | [1, 1, 1, 0], 168 | [1, 1, 1, 1]]) 169 | assert not model_spec.is_upper_triangular(m4) 170 | 171 | m5 = np.array([[0]]) 172 | assert model_spec.is_upper_triangular(m5) 173 | 174 | m6 = np.array([[1]]) 175 | assert not model_spec.is_upper_triangular(m6) 176 | 177 | 178 | if __name__ == '__main__': 179 | tf.test.main() 180 | -------------------------------------------------------------------------------- /nasbench/lib/cifar.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """CIFAR-10 data pipeline with preprocessing. 16 | 17 | The data is generated via generate_cifar10_tfrecords.py. 18 | """ 19 | 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import functools 26 | import tensorflow as tf 27 | 28 | WIDTH = 32 29 | HEIGHT = 32 30 | RGB_MEAN = [125.31, 122.95, 113.87] 31 | RGB_STD = [62.99, 62.09, 66.70] 32 | 33 | 34 | class CIFARInput(object): 35 | """Wrapper class for input_fn passed to TPUEstimator.""" 36 | 37 | def __init__(self, mode, config): 38 | """Initializes a CIFARInput object. 
39 | 40 | Args: 41 | mode: one of [train, valid, test, augment, sample] 42 | config: config dict built from config.py 43 | 44 | Raises: 45 | ValueError: invalid mode or data files 46 | """ 47 | self.mode = mode 48 | self.config = config 49 | if mode == 'train': # Training set (no validation & test) 50 | self.data_files = config['train_data_files'] 51 | elif mode == 'train_eval': # For computing train error 52 | self.data_files = [config['train_data_files'][0]] 53 | elif mode == 'valid': # For computing validation error 54 | self.data_files = [config['valid_data_file']] 55 | elif mode == 'test': # For computing the test error 56 | self.data_files = [config['test_data_file']] 57 | elif mode == 'augment': # Training set (includes validation, no test) 58 | self.data_files = (config['train_data_files'] + 59 | [config['valid_data_file']]) 60 | elif mode == 'sample': # Fixed batch of 100 samples from validation 61 | self.data_files = [config['sample_data_file']] 62 | else: 63 | raise ValueError('invalid mode') 64 | 65 | if not self.data_files: 66 | raise ValueError('no data files provided') 67 | 68 | @property 69 | def num_images(self): 70 | """Number of images in the dataset (depends on the mode).""" 71 | if self.mode == 'train': 72 | return 40000 73 | elif self.mode == 'train_eval': 74 | return 10000 75 | elif self.mode == 'valid': 76 | return 10000 77 | elif self.mode == 'test': 78 | return 10000 79 | elif self.mode == 'augment': 80 | return 50000 81 | elif self.mode == 'sample': 82 | return 100 83 | 84 | def input_fn(self, params): 85 | """Returns a CIFAR tf.data.Dataset object. 86 | 87 | Args: 88 | params: parameter dict pass by Estimator. 
89 | 90 | Returns: 91 | tf.data.Dataset object 92 | """ 93 | batch_size = params['batch_size'] 94 | is_training = (self.mode == 'train' or self.mode == 'augment') 95 | 96 | dataset = tf.data.TFRecordDataset(self.data_files) 97 | dataset = dataset.prefetch(buffer_size=batch_size) 98 | 99 | # Repeat dataset for training modes 100 | if is_training: 101 | # Shuffle buffer with whole dataset to ensure full randomness per epoch 102 | dataset = dataset.cache().apply( 103 | tf.contrib.data.shuffle_and_repeat( 104 | buffer_size=self.num_images)) 105 | 106 | # This is a hack to allow computing metrics on a fixed batch on TPU. Because 107 | # TPU shards the batch acrosss cores, we replicate the fixed batch so that 108 | # each core contains the whole batch. 109 | if self.mode == 'sample': 110 | dataset = dataset.repeat() 111 | 112 | # Parse, preprocess, and batch images 113 | parser_fn = functools.partial(_parser, is_training) 114 | dataset = dataset.apply( 115 | tf.contrib.data.map_and_batch( 116 | parser_fn, 117 | batch_size=batch_size, 118 | num_parallel_batches=self.config['tpu_num_shards'], 119 | drop_remainder=True)) 120 | 121 | # Assign static batch size dimension 122 | dataset = dataset.map(functools.partial(_set_batch_dimension, batch_size)) 123 | 124 | # Prefetch to overlap in-feed with training 125 | dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE) 126 | 127 | return dataset 128 | 129 | 130 | def _preprocess(image): 131 | """Perform standard CIFAR preprocessing. 132 | 133 | Pads the image then performs a random crop. 134 | Then, image is flipped horizontally randomly. 135 | 136 | Args: 137 | image: image Tensor with shape [height, width, 3] 138 | 139 | Returns: 140 | preprocessed image with the same dimensions. 
141 | """ 142 | # Pad 4 pixels on all sides with 0 143 | image = tf.image.resize_image_with_crop_or_pad( 144 | image, HEIGHT + 8, WIDTH + 8) 145 | 146 | # Random crop 147 | image = tf.random_crop(image, [HEIGHT, WIDTH, 3], seed=0) 148 | 149 | # Random flip 150 | image = tf.image.random_flip_left_right(image, seed=0) 151 | 152 | return image 153 | 154 | 155 | def _parser(use_preprocessing, serialized_example): 156 | """Parses a single tf.Example into image and label tensors.""" 157 | features = tf.parse_single_example( 158 | serialized_example, 159 | features={ 160 | 'image': tf.FixedLenFeature([], tf.string), 161 | 'label': tf.FixedLenFeature([], tf.int64), 162 | }) 163 | image = tf.decode_raw(features['image'], tf.uint8) 164 | image.set_shape([3 * HEIGHT * WIDTH]) 165 | image = tf.reshape(image, [3, HEIGHT, WIDTH]) 166 | # TODO(chrisying): handle NCHW format 167 | image = tf.transpose(image, [1, 2, 0]) 168 | image = tf.cast(image, tf.float32) 169 | if use_preprocessing: 170 | image = _preprocess(image) 171 | image -= tf.constant(RGB_MEAN, shape=[1, 1, 3]) 172 | image /= tf.constant(RGB_STD, shape=[1, 1, 3]) 173 | label = tf.cast(features['label'], tf.int32) 174 | return image, label 175 | 176 | 177 | def _set_batch_dimension(batch_size, images, labels): 178 | images.set_shape(images.get_shape().merge_with( 179 | tf.TensorShape([batch_size, None, None, None]))) 180 | labels.set_shape(labels.get_shape().merge_with( 181 | tf.TensorShape([batch_size]))) 182 | 183 | return images, labels 184 | -------------------------------------------------------------------------------- /nasbench/scripts/generate_graphs.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Generate all graphs up to structure and label isomorphism. 16 | 17 | The goal is to generate all unique computational graphs up to some number of 18 | vertices and edges. Computational graphs can be represented by directed acyclic 19 | graphs with all components connected along some path from a specially-labeled 20 | input to output. The pseudocode for generating these is: 21 | 22 | for V in [2, ..., MAX_VERTICES]: # V includes input and output vertices 23 | generate all bitmasks of length V*(V-1)/2 # num upper triangular entries 24 | for each bitmask: 25 | convert bitmask to adjacency matrix 26 | if adjacency matrix has disconnected vertices from input/output: 27 | discard and continue to next matrix 28 | generate all labelings of ops to vertices 29 | for each labeling: 30 | compute graph hash from matrix and labels 31 | if graph hash has not been seen before: 32 | output graph (adjacency matrix + labeling) 33 | 34 | This script uses a modification on Weisfeiler-Lehman color refinement 35 | (https://ist.ac.at/mfcs13/slides/gi.pdf) for graph hashing, which is very 36 | loosely similar to the hashing approach described in 37 | https://arxiv.org/pdf/1606.00001.pdf. The general idea is to assign each vertex 38 | a hash based on the in-degree, out-degree, and operation label then iteratively 39 | hash each vertex with the hashes of its neighbors. 
40 | 41 | In more detail, the iterative update involves repeating the following steps a 42 | number of times greater than or equal to the diameter of the graph: 43 | 1) For each vertex, sort the hashes of the in-neighbors. 44 | 2) For each vertex, sort the hashes of the out-neighbors. 45 | 3) For each vertex, concatenate the sorted hashes from (1), (2) and the vertex 46 | operation label. 47 | 4) For each vertex, compute the MD5 hash of the concatenated values in (3). 48 | 5) Assign the newly computed hashes to each vertex. 49 | 50 | Finally, sort the hashes of all the vertices and concat and hash one more time 51 | to obtain the final graph hash. This hash is a graph invariant as all operations 52 | are invariant under isomorphism, thus we expect no false negatives (isomorphic 53 | graphs hashed to different values). 54 | 55 | We have empirically verified that, for graphs up to 7 vertices, 9 edges, 3 ops, 56 | this algorithm does not cause "false positives" (graphs that hash to the same 57 | value but are non-isomorphic). For such graphs, this algorithm yields 423,624 58 | unique computation graphs, which is roughly 1/3rd of the total number of 59 | connected DAGs before de-duping using this hash algorithm. 
60 | """ 61 | from __future__ import absolute_import 62 | from __future__ import division 63 | from __future__ import print_function 64 | 65 | import itertools 66 | import json 67 | import sys 68 | 69 | from absl import app 70 | from absl import flags 71 | from absl import logging 72 | 73 | from nasbench.lib import graph_util 74 | import numpy as np 75 | import tensorflow as tf # For gfile 76 | 77 | flags.DEFINE_string('output_file', '/tmp/generated_graphs.json', 78 | 'Output file name.') 79 | flags.DEFINE_integer('max_vertices', 7, 80 | 'Maximum number of vertices including input/output.') 81 | flags.DEFINE_integer('num_ops', 3, 'Number of operation labels.') 82 | flags.DEFINE_integer('max_edges', 9, 'Maximum number of edges.') 83 | flags.DEFINE_boolean('verify_isomorphism', True, 84 | 'Exhaustively verifies that each detected isomorphism' 85 | ' is truly an isomorphism. This operation is very' 86 | ' expensive.') 87 | FLAGS = flags.FLAGS 88 | 89 | 90 | def main(_): 91 | total_graphs = 0 # Total number of graphs (including isomorphisms) 92 | # hash --> (matrix, label) for the canonical graph associated with each hash 93 | buckets = {} 94 | 95 | logging.info('Using %d vertices, %d op labels, max %d edges', 96 | FLAGS.max_vertices, FLAGS.num_ops, FLAGS.max_edges) 97 | for vertices in range(2, FLAGS.max_vertices+1): 98 | for bits in range(2 ** (vertices * (vertices-1) // 2)): 99 | # Construct adj matrix from bit string 100 | matrix = np.fromfunction(graph_util.gen_is_edge_fn(bits), 101 | (vertices, vertices), 102 | dtype=np.int8) 103 | 104 | # Discard any graphs which can be pruned or exceed constraints 105 | if (not graph_util.is_full_dag(matrix) or 106 | graph_util.num_edges(matrix) > FLAGS.max_edges): 107 | continue 108 | 109 | # Iterate through all possible labelings 110 | for labeling in itertools.product(*[range(FLAGS.num_ops) 111 | for _ in range(vertices-2)]): 112 | total_graphs += 1 113 | labeling = [-1] + list(labeling) + [-2] 114 | fingerprint = 
graph_util.hash_module(matrix, labeling) 115 | 116 | if fingerprint not in buckets: 117 | buckets[fingerprint] = (matrix.tolist(), labeling) 118 | 119 | # This catches the "false positive" case of two models which are not 120 | # isomorphic hashing to the same bucket. 121 | elif FLAGS.verify_isomorphism: 122 | canonical_graph = buckets[fingerprint] 123 | if not graph_util.is_isomorphic( 124 | (matrix.tolist(), labeling), canonical_graph): 125 | logging.fatal('Matrix:\n%s\nLabel: %s\nis not isomorphic to' 126 | ' canonical matrix:\n%s\nLabel: %s', 127 | str(matrix), str(labeling), 128 | str(canonical_graph[0]), 129 | str(canonical_graph[1])) 130 | sys.exit() 131 | 132 | logging.info('Up to %d vertices: %d graphs (%d without hashing)', 133 | vertices, len(buckets), total_graphs) 134 | 135 | with tf.gfile.Open(FLAGS.output_file, 'w') as f: 136 | json.dump(buckets, f, sort_keys=True) 137 | 138 | 139 | if __name__ == '__main__': 140 | app.run(main) 141 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NASBench: A Neural Architecture Search Dataset and Benchmark 2 | 3 | This repository contains the code used for generating and interacting with the 4 | NASBench dataset. The dataset contains **423,624 unique neural networks** 5 | exhaustively generated and evaluated from a fixed graph-based search space. 6 | 7 | Each network is trained and evaluated multiple times on CIFAR-10 at various 8 | training budgets and we present the metrics in a queriable API. The current 9 | release contains over **5 million** trained and evaluated models. 
10 | 11 | Our paper can be found at: 12 | 13 | [NAS-Bench-101: Towards Reproducible Neural Architecture 14 | Search](https://arxiv.org/abs/1902.09635) 15 | 16 | If you use this dataset, please cite: 17 | 18 | ``` 19 | @InProceedings{pmlr-v97-ying19a, 20 | title = {{NAS}-Bench-101: Towards Reproducible Neural Architecture Search}, 21 | author = {Ying, Chris and Klein, Aaron and Christiansen, Eric and Real, Esteban and Murphy, Kevin and Hutter, Frank}, 22 | booktitle = {Proceedings of the 36th International Conference on Machine Learning}, 23 | pages = {7105--7114}, 24 | year = {2019}, 25 | editor = {Chaudhuri, Kamalika and Salakhutdinov, Ruslan}, 26 | volume = {97}, 27 | series = {Proceedings of Machine Learning Research}, 28 | address = {Long Beach, California, USA}, 29 | month = {09--15 Jun}, 30 | publisher = {PMLR}, 31 | url = {http://proceedings.mlr.press/v97/ying19a.html}} 32 | ``` 33 | 34 | ## Dataset overview 35 | 36 | NASBench is a tabular dataset which maps convolutional neural network 37 | architectures to their trained and evaluated performance on CIFAR-10. 38 | Specifically, all networks share the same network "skeleton", which can be seen 39 | in Figure (a) below. What changes between different models is the "module", which is a 40 | collection of neural network operations linked in an arbitrary graph-like 41 | structure. 42 | 43 | Modules are represented by directed acyclic graphs with up to 7 vertices and 9 44 | edges. The valid operations at each interior vertex are "3x3 convolution", "1x1 45 | convolution", and "3x3 max-pooling". Figure (b) below shows an Inception-like 46 | cell within the dataset. Figure (c) shows a high-level overview of how the 47 | interior filter counts of each module are computed. 48 | 49 | 50 | 51 | There are exactly 423,624 computationally unique modules within this search 52 | space and each one has been trained for 4, 12, 36, and 108 epochs three times 53 | each (423K * 3 * 4 = ~5M total trained models).
We report the following metrics: 54 | 55 | * training accuracy 56 | * validation accuracy 57 | * testing accuracy 58 | * number of parameters 59 | * training time 60 | 61 | The scatterplot below shows a comparison of number of parameters, training time, 62 | and mean validation accuracy of models trained for 108 epochs in the dataset. 63 | 64 | 65 | 66 | See our paper for more detailed information about the design of this search 67 | space, further implementation details, and more in-depth analysis. 68 | 69 | ## Colab 70 | 71 | You can directly use this dataset from Google Colaboratory without needing to 72 | install anything on your local machine. Click "Open in Colab" below: 73 | 74 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-research/nasbench/blob/master/NASBench.ipynb) 75 | 76 | ## Setup 77 | 78 | 1. Clone this repo. 79 | 80 | ``` 81 | git clone https://github.com/google-research/nasbench 82 | cd nasbench 83 | ``` 84 | 85 | 2. (optional) Create a virtualenv for this library. 86 | 87 | ``` 88 | virtualenv venv 89 | source venv/bin/activate 90 | ``` 91 | 92 | 3. Install the project along with dependencies. 93 | 94 | ``` 95 | pip install -e . 96 | ``` 97 | 98 | **Note:** the only required dependency is TensorFlow. The above instructions 99 | will install the CPU version of TensorFlow to the virtualenv. For other install 100 | options, see https://www.tensorflow.org/install/. 
101 | 102 | ## Download the dataset 103 | 104 | The full dataset (which includes all 5M data points at all 4 epoch lengths): 105 | 106 | https://storage.googleapis.com/nasbench/nasbench_full.tfrecord 107 | 108 | Size: ~1.95 GB, SHA256: `3d64db8180fb1b0207212f9032205064312b6907a3bbc81eabea10db2f5c7e9c` 109 | 110 | --- 111 | 112 | Subset of the dataset with only models trained at 108 epochs: 113 | 114 | https://storage.googleapis.com/nasbench/nasbench_only108.tfrecord 115 | 116 | Size: ~499 MB, SHA256: `4c39c3936e36a85269881d659e44e61a245babcb72cb374eacacf75d0e5f4fd1` 117 | 118 | 119 | ## Using the dataset 120 | 121 | Example usage (see `example.py` for a full runnable example): 122 | 123 | ```python 124 | # Load the data from file (this will take some time) 125 | nasbench = api.NASBench('/path/to/nasbench.tfrecord') 126 | 127 | # Create an Inception-like module (5x5 convolution replaced with two 3x3 128 | # convolutions). 129 | model_spec = api.ModelSpec( 130 | # Adjacency matrix of the module 131 | matrix=[[0, 1, 1, 1, 0, 1, 0], # input layer 132 | [0, 0, 0, 0, 0, 0, 1], # 1x1 conv 133 | [0, 0, 0, 0, 0, 0, 1], # 3x3 conv 134 | [0, 0, 0, 0, 1, 0, 0], # 5x5 conv (replaced by two 3x3's) 135 | [0, 0, 0, 0, 0, 0, 1], # 5x5 conv (replaced by two 3x3's) 136 | [0, 0, 0, 0, 0, 0, 1], # 3x3 max-pool 137 | [0, 0, 0, 0, 0, 0, 0]], # output layer 138 | # Operations at the vertices of the module, matches order of matrix 139 | ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT]) 140 | 141 | # Query this model from dataset, returns a dictionary containing the metrics 142 | # associated with this model. 143 | data = nasbench.query(model_spec) 144 | ``` 145 | 146 | See `nasbench/api.py` for more information, including the constraints on valid 147 | module matrices and operations. 
148 | 149 | **Note**: it is not required to use `nasbench/api.py` to work with this dataset; 150 | you can see how to parse the dataset files from the initializer inside 151 | `nasbench/api.py` and then interact with the data however you'd like. 152 | 153 | ## How the dataset was generated 154 | 155 | The dataset generation code is provided for reference, but the dataset has 156 | already been fully generated. 157 | 158 | The list of unique computation graphs evaluated in this dataset was generated 159 | via `nasbench/scripts/generate_graphs.py`. Each of these graphs was evaluated 160 | multiple times via `nasbench/scripts/run_evaluation.py`. 161 | 162 | ## How to run the unit tests 163 | 164 | Unit tests are included for some of the algorithmically complex parts of the 165 | code. The tests can be run directly via Python. Example: 166 | 167 | ``` 168 | python nasbench/tests/model_builder_test.py 169 | ``` 170 | 171 | ## Disclaimer 172 | 173 | This is not an official Google product. 174 | -------------------------------------------------------------------------------- /nasbench/lib/model_metrics_pb2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # pylint: skip-file 16 | # Generated by the protocol buffer compiler. DO NOT EDIT!
17 | # source: model_metrics.proto 18 | 19 | import sys 20 | _b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))  # Py2/Py3 shim: identity on Python 2, latin-1 str->bytes on Python 3 21 | from google.protobuf import descriptor as _descriptor 22 | from google.protobuf import message as _message 23 | from google.protobuf import reflection as _reflection 24 | from google.protobuf import symbol_database as _symbol_database 25 | # @@protoc_insertion_point(imports) 26 | 27 | _sym_db = _symbol_database.Default() 28 | 29 | 30 | 31 | 32 | DESCRIPTOR = _descriptor.FileDescriptor(  # file descriptor for model_metrics.proto (package nasbench) 33 | name='model_metrics.proto', 34 | package='nasbench', 35 | syntax='proto2', 36 | serialized_options=None, 37 | serialized_pb=_b('\n\x13model_metrics.proto\x12\x08nasbench\"s\n\x0cModelMetrics\x12\x31\n\x0f\x65valuation_data\x18\x01 \x03(\x0b\x32\x18.nasbench.EvaluationData\x12\x1c\n\x14trainable_parameters\x18\x02 \x01(\x05\x12\x12\n\ntotal_time\x18\x03 \x01(\x01\"\xa3\x01\n\x0e\x45valuationData\x12\x15\n\rcurrent_epoch\x18\x01 \x01(\x01\x12\x15\n\rtraining_time\x18\x02 \x01(\x01\x12\x16\n\x0etrain_accuracy\x18\x03 \x01(\x01\x12\x1b\n\x13validation_accuracy\x18\x04 \x01(\x01\x12\x15\n\rtest_accuracy\x18\x05 \x01(\x01\x12\x17\n\x0f\x63heckpoint_path\x18\x06 \x01(\t') 38 | ) 39 | 40 | 41 | 42 | 43 | _MODELMETRICS = _descriptor.Descriptor(  # message nasbench.ModelMetrics 44 | name='ModelMetrics', 45 | full_name='nasbench.ModelMetrics', 46 | filename=None, 47 | file=DESCRIPTOR, 48 | containing_type=None, 49 | fields=[ 50 | _descriptor.FieldDescriptor( 51 | name='evaluation_data', full_name='nasbench.ModelMetrics.evaluation_data', index=0, 52 | number=1, type=11, cpp_type=10, label=3, 53 | has_default_value=False, default_value=[], 54 | message_type=None, enum_type=None, containing_type=None, 55 | is_extension=False, extension_scope=None, 56 | serialized_options=None, file=DESCRIPTOR), 57 | _descriptor.FieldDescriptor( 58 | name='trainable_parameters', full_name='nasbench.ModelMetrics.trainable_parameters', index=1, 59 | number=2, type=5, cpp_type=1, label=1, 60 |
has_default_value=False, default_value=0, 61 | message_type=None, enum_type=None, containing_type=None, 62 | is_extension=False, extension_scope=None, 63 | serialized_options=None, file=DESCRIPTOR), 64 | _descriptor.FieldDescriptor( 65 | name='total_time', full_name='nasbench.ModelMetrics.total_time', index=2, 66 | number=3, type=1, cpp_type=5, label=1, 67 | has_default_value=False, default_value=float(0), 68 | message_type=None, enum_type=None, containing_type=None, 69 | is_extension=False, extension_scope=None, 70 | serialized_options=None, file=DESCRIPTOR), 71 | ], 72 | extensions=[ 73 | ], 74 | nested_types=[], 75 | enum_types=[ 76 | ], 77 | serialized_options=None, 78 | is_extendable=False, 79 | syntax='proto2', 80 | extension_ranges=[], 81 | oneofs=[ 82 | ], 83 | serialized_start=33, 84 | serialized_end=148, 85 | ) 86 | 87 | 88 | _EVALUATIONDATA = _descriptor.Descriptor(  # message nasbench.EvaluationData 89 | name='EvaluationData', 90 | full_name='nasbench.EvaluationData', 91 | filename=None, 92 | file=DESCRIPTOR, 93 | containing_type=None, 94 | fields=[ 95 | _descriptor.FieldDescriptor( 96 | name='current_epoch', full_name='nasbench.EvaluationData.current_epoch', index=0, 97 | number=1, type=1, cpp_type=5, label=1, 98 | has_default_value=False, default_value=float(0), 99 | message_type=None, enum_type=None, containing_type=None, 100 | is_extension=False, extension_scope=None, 101 | serialized_options=None, file=DESCRIPTOR), 102 | _descriptor.FieldDescriptor( 103 | name='training_time', full_name='nasbench.EvaluationData.training_time', index=1, 104 | number=2, type=1, cpp_type=5, label=1, 105 | has_default_value=False, default_value=float(0), 106 | message_type=None, enum_type=None, containing_type=None, 107 | is_extension=False, extension_scope=None, 108 | serialized_options=None, file=DESCRIPTOR), 109 | _descriptor.FieldDescriptor( 110 | name='train_accuracy', full_name='nasbench.EvaluationData.train_accuracy', index=2, 111 | number=3, type=1, cpp_type=5, label=1, 112 |
has_default_value=False, default_value=float(0), 113 | message_type=None, enum_type=None, containing_type=None, 114 | is_extension=False, extension_scope=None, 115 | serialized_options=None, file=DESCRIPTOR), 116 | _descriptor.FieldDescriptor( 117 | name='validation_accuracy', full_name='nasbench.EvaluationData.validation_accuracy', index=3, 118 | number=4, type=1, cpp_type=5, label=1, 119 | has_default_value=False, default_value=float(0), 120 | message_type=None, enum_type=None, containing_type=None, 121 | is_extension=False, extension_scope=None, 122 | serialized_options=None, file=DESCRIPTOR), 123 | _descriptor.FieldDescriptor( 124 | name='test_accuracy', full_name='nasbench.EvaluationData.test_accuracy', index=4, 125 | number=5, type=1, cpp_type=5, label=1, 126 | has_default_value=False, default_value=float(0), 127 | message_type=None, enum_type=None, containing_type=None, 128 | is_extension=False, extension_scope=None, 129 | serialized_options=None, file=DESCRIPTOR), 130 | _descriptor.FieldDescriptor( 131 | name='checkpoint_path', full_name='nasbench.EvaluationData.checkpoint_path', index=5, 132 | number=6, type=9, cpp_type=9, label=1, 133 | has_default_value=False, default_value=_b("").decode('utf-8'), 134 | message_type=None, enum_type=None, containing_type=None, 135 | is_extension=False, extension_scope=None, 136 | serialized_options=None, file=DESCRIPTOR), 137 | ], 138 | extensions=[ 139 | ], 140 | nested_types=[], 141 | enum_types=[ 142 | ], 143 | serialized_options=None, 144 | is_extendable=False, 145 | syntax='proto2', 146 | extension_ranges=[], 147 | oneofs=[ 148 | ], 149 | serialized_start=151, 150 | serialized_end=314, 151 | ) 152 | 153 | _MODELMETRICS.fields_by_name['evaluation_data'].message_type = _EVALUATIONDATA  # resolves the message type of the repeated (label=3) evaluation_data field 154 | DESCRIPTOR.message_types_by_name['ModelMetrics'] = _MODELMETRICS 155 | DESCRIPTOR.message_types_by_name['EvaluationData'] = _EVALUATIONDATA 156 | _sym_db.RegisterFileDescriptor(DESCRIPTOR) 157 | 158 | ModelMetrics =
_reflection.GeneratedProtocolMessageType('ModelMetrics', (_message.Message,), dict( 159 | DESCRIPTOR = _MODELMETRICS, 160 | __module__ = 'model_metrics_pb2' 161 | # @@protoc_insertion_point(class_scope:nasbench.ModelMetrics) 162 | )) 163 | _sym_db.RegisterMessage(ModelMetrics) 164 | 165 | EvaluationData = _reflection.GeneratedProtocolMessageType('EvaluationData', (_message.Message,), dict( 166 | DESCRIPTOR = _EVALUATIONDATA, 167 | __module__ = 'model_metrics_pb2' 168 | # @@protoc_insertion_point(class_scope:nasbench.EvaluationData) 169 | )) 170 | _sym_db.RegisterMessage(EvaluationData) 171 | 172 | 173 | # @@protoc_insertion_point(module_scope) 174 | -------------------------------------------------------------------------------- /nasbench/lib/training_time.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tools to measure and limit the training time of a TF model.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import collections 22 | import tensorflow as tf 23 | 24 | # Name of scope where to put all timing-related ops and variables.
25 | _SCOPE_NAME = 'timing' 26 | 27 | # Variable names: 28 | _START_VAR = 'start_timestamp' 29 | _STEPS_VAR = 'steps' 30 | _PREV_VAR = 'previous_time' 31 | _TOTAL_VAR = 'total_time' 32 | 33 | # The name of the TF variable that will hold the total training time so far. 34 | # Get with estimator.get_variable_value(TOTAL_TIME_NAME) after 35 | # running estimator.train(). Note that this time includes the time spent in 36 | # previous calls to train() as well. 37 | TOTAL_TIME_NAME = '%s/%s' % (_SCOPE_NAME, _TOTAL_VAR) 38 | 39 | # We have a fixed temporal precision of one millisecond. 40 | # We use fixed precision to represent seconds since the epoch, as a tf.int64, 41 | # because tf.float32 lacks precision for large values and tf.float64 is not 42 | # supported on TPU. 43 | _INTERNAL_TIME_PRECISION = 1000 44 | 45 | 46 | def _seconds_to_internal_time(seconds): 47 | """Converts seconds to fixed-precision time (int64 milliseconds).""" 48 | return tf.to_int64(tf.round(seconds * _INTERNAL_TIME_PRECISION)) 49 | 50 | 51 | def _internal_time_to_seconds(internal_time): 52 | """Converts fixed-precision time (milliseconds) to float32 seconds.""" 53 | return tf.to_float(internal_time / _INTERNAL_TIME_PRECISION) 54 | 55 | 56 | Timing = collections.namedtuple( # pylint: disable=g-bad-name 57 | 'Timing', 58 | [ 59 | # A SessionRunHook instance that must be passed to estimator.train() 60 | # through its `hooks` arg. 61 | 'train_hook', 62 | 63 | # A CheckpointSaverListener instance. This must be passed to 64 | # estimator.train() through its `saving_listeners` arg if and only if 65 | # checkpoints are being saved. 66 | 'saving_listener', 67 | ]) 68 | 69 | 70 | def limit(max_train_secs=None): 71 | """Provides hooks and ops to measure/limit the training time of a model. 72 | 73 | This is done by direct measurement of the time spent on training steps. It 74 | excludes time spent saving checkpoints or due to pre-emptions. 75 | 76 | Args: 77 | max_train_secs: the desired training time limit.
It is possible that this 78 | may be exceeded by the time it takes to run 1 step. If None, training will 79 | not be limited by time but timing variables will still be created. 80 | 81 | Returns: 82 | A Timing named tuple. 83 | """ 84 | train_hook = _TimingRunHook(max_train_secs) 85 | saving_listener = _TimingSaverListener() 86 | return Timing(train_hook=train_hook, saving_listener=saving_listener) 87 | 88 | 89 | def get_total_time(): 90 | """Returns the timing/total_time variable, regardless of current scope. 91 | 92 | The variable is created on first use (see _get_or_create_timing_vars), so no 93 | separate setup call is needed before calling this. 94 | 95 | Returns: 96 | A float32 scalar TF Variable holding the training time in seconds. It is 97 | updated on every training step by the train_hook returned from limit(), 98 | as previous_time plus the time elapsed since the first step after the 99 | last checkpoint save. 100 | """ 101 | timing_vars = _get_or_create_timing_vars() 102 | return timing_vars.total_time 103 | 104 | 105 | _TimingVars = collections.namedtuple( # pylint: disable=g-bad-name 106 | '_TimingVars', 107 | [ 108 | # TF int64 variable storing the timestamp (in _INTERNAL_TIME_PRECISION 109 | # units, i.e. milliseconds) of the first training step after the last 110 | # checkpoint save, or of the first training step ever. Initialized to -1; 111 | # refreshed on the next step whenever `steps` is 0 (i.e. after a save). 112 | 'start_timestamp', 113 | 114 | # TF variable to be used to store the number of steps since the last 115 | # checkpoint save (or the beginning of training if no save has happened 116 | # yet). 117 | 'steps', 118 | 119 | # TF variable to be used to store the training time up to the last 120 | # checkpoint saved. 121 | 'previous_time', 122 | 123 | # TF variable to be used to accumulate the total training time up 124 | # to the last step run. This time will not include gaps resulting from 125 | # checkpoint saving or pre-emptions.
126 | 'total_time', 127 | ]) 128 | 129 | 130 | class _TimingRunHook(tf.train.SessionRunHook): 131 | """Hook to stop the training after a certain amount of time.""" 132 | 133 | def __init__(self, max_train_secs=None): 134 | """Initializes the instance. 135 | 136 | Args: 137 | max_train_secs: the maximum number of seconds to train for. If None, 138 | training will not be limited by time. 139 | """ 140 | self._max_train_secs = max_train_secs 141 | 142 | def begin(self): 143 | with tf.name_scope(_SCOPE_NAME): 144 | # See _TimingVars for the meaning of these variables. 145 | timing_vars = _get_or_create_timing_vars() 146 | 147 | # An op to produce a tensor with the latest timestamp. 148 | self._end_op = _seconds_to_internal_time(tf.timestamp(name='end'))  # int64 milliseconds since the epoch 149 | 150 | # An op to update the timing_vars.start_timestamp variable. 151 | self._start_op = tf.cond( 152 | pred=tf.equal(timing_vars.steps, 0),  # first step since the last save (or ever) 153 | true_fn=lambda: timing_vars.start_timestamp.assign(self._end_op), 154 | false_fn=lambda: timing_vars.start_timestamp) 155 | 156 | # An op to update the step. 157 | with tf.control_dependencies([self._start_op]): 158 | self._step_op = timing_vars.steps.assign_add(1) 159 | 160 | # An op to compute the timing_vars.total_time variable.
161 | self._total_op = timing_vars.total_time.assign( 162 | timing_vars.previous_time + 163 | _internal_time_to_seconds(self._end_op - self._start_op)) 164 | 165 | def before_run(self, run_context): 166 | return tf.train.SessionRunArgs([self._total_op, self._step_op]) 167 | 168 | def after_run(self, run_context, run_values): 169 | total_time, _ = run_values.results 170 | if self._max_train_secs and total_time > self._max_train_secs:  # a falsy limit (None or 0) disables the check 171 | run_context.request_stop() 172 | 173 | 174 | class _TimingSaverListener(tf.train.CheckpointSaverListener): 175 | """Saving listener to store the train time up to the last checkpoint save.""" 176 | 177 | def begin(self): 178 | with tf.name_scope(_SCOPE_NAME): 179 | timing_vars = _get_or_create_timing_vars() 180 | 181 | # An op to update the timing_vars.previous_time variable. 182 | self._prev_op = timing_vars.previous_time.assign(timing_vars.total_time) 183 | 184 | # Marks that timing_vars.start_timestamp should be reset in the next step. 185 | self._reset_steps_op = timing_vars.steps.assign(0) 186 | 187 | def before_save(self, session, global_step_value): 188 | session.run(self._prev_op)  # previous_time <- total_time 189 | 190 | def after_save(self, session, global_step_value): 191 | session.run(self._reset_steps_op) 192 | 193 | 194 | def _get_or_create_timing_vars(): 195 | """Gets or creates the variables used to measure training time. 196 | 197 | Returns: 198 | A _TimingVars named tuple. 199 | """ 200 | # We always create the timing variables at root_scope / _SCOPE_NAME, 201 | # regardless of the scope from where this is called.
202 | root_scope = tf.get_variable_scope()  # NOTE(review): this is the *current* variable scope, not the root; callers nested in a tf.variable_scope would get differently-named vars — confirm callers only invoke this at top level 203 | with tf.variable_scope(root_scope, reuse=tf.AUTO_REUSE): 204 | with tf.variable_scope(_SCOPE_NAME, reuse=tf.AUTO_REUSE): 205 | start_timestamp = tf.get_variable( 206 | _START_VAR, 207 | shape=[], 208 | dtype=tf.int64, 209 | initializer=tf.constant_initializer(-1), 210 | trainable=False) 211 | steps = tf.get_variable( 212 | _STEPS_VAR, 213 | shape=[], 214 | dtype=tf.int64, 215 | initializer=tf.constant_initializer(0), 216 | trainable=False) 217 | previous_time = tf.get_variable( 218 | _PREV_VAR, 219 | shape=[], 220 | dtype=tf.float32, 221 | initializer=tf.constant_initializer(0.0), 222 | trainable=False) 223 | total_time = tf.get_variable( 224 | _TOTAL_VAR, 225 | shape=[], 226 | dtype=tf.float32, 227 | initializer=tf.constant_initializer(0.0), 228 | trainable=False) 229 | return _TimingVars( 230 | start_timestamp=start_timestamp, 231 | steps=steps, 232 | previous_time=previous_time, 233 | total_time=total_time) 234 | -------------------------------------------------------------------------------- /nasbench/scripts/run_evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Script for training a large number of networks across multiple workers.
16 | 17 | Every process running this script is assigned a unique worker_id in the range 18 | [0, total_workers). Each full 19 | training-and-evaluation of a module counts as a single "work unit" with repeated 20 | runs being different units. Work units are assigned a monotonically increasing 21 | index and each worker only computes the work units whose index modulo 22 | total_workers equals its worker_id. 23 | 24 | For example, for 3 models, each with 3 repeats, and 4 total workers: 25 | Model Number || 1 2 3 1 2 3 1 2 3 26 | Repeat Number || 1 1 1 2 2 2 3 3 3 27 | Work Unit Index || 0 1 2 3 4 5 6 7 8 28 | Assigned Worker || 0 1 2 3 0 1 2 3 0 29 | 30 | i.e. worker_id 0 will compute [model1-repeat1, model2-repeat2, model3-repeat3], 31 | worker_id 1 will compute [model2-repeat1, model3-repeat2], etc. 32 | 33 | --worker_id_offset is provided to allow launching workers in multiple flocks and 34 | is added to the --worker_id flag which is assumed to start at 0 for each new 35 | flock. --total_workers should be the total number of workers across all flocks. 36 | 37 | For basic failure recovery, each worker stores a text file with the current work 38 | unit index it is computing. Upon restarting, workers will resume at the 39 | beginning of the work unit index inside the recovery file if it exists.
40 | """ 41 | 42 | from __future__ import absolute_import 43 | from __future__ import division 44 | from __future__ import print_function 45 | 46 | import json 47 | import os 48 | import re 49 | 50 | from absl import app 51 | from absl import flags 52 | from nasbench.lib import config as _config 53 | from nasbench.lib import evaluate 54 | from nasbench.lib import model_metrics_pb2 55 | from nasbench.lib import model_spec 56 | import numpy as np 57 | import tensorflow as tf 58 | 59 | 60 | flags.DEFINE_string('models_file', '',  # expected to map model hash -> [matrix, labels], e.g. the JSON written by generate_graphs.py 61 | 'JSON file containing models.') 62 | flags.DEFINE_string('remainders_file', '',  # read directly via FLAGS inside Evaluator.__init__ 63 | 'JSON file containing list of remainders as tuples of' 64 | ' (module hash, repeat num). If provided, only the runs in' 65 | ' the list will be evaluated, otherwise, all models inside' 66 | ' models_file will be evaluated.') 67 | flags.DEFINE_string('model_id_regex', '^', 68 | 'Regex of models to train. Model IDs are MD5 hashes' 69 | ' which match ([a-f0-9]{32}).') 70 | flags.DEFINE_string('output_dir', '', 'Base output directory.') 71 | flags.DEFINE_integer('worker_id', 0, 72 | 'Worker ID within this flock, starting at 0.') 73 | flags.DEFINE_integer('worker_id_offset', 0,  # per the module docstring, added to --worker_id for multi-flock launches 74 | 'Worker ID offset added.') 75 | flags.DEFINE_integer('total_workers', 1, 76 | 'Total number of workers, across all flocks.') 77 | FLAGS = flags.FLAGS 78 | 79 | CHECKPOINT_PREFIX = 'model.ckpt' 80 | RESULTS_FILE = 'results.json' 81 | # Checkpoint 1 is a side-effect of pre-initializing the model weights and can be 82 | # deleted during the clean-up step. 83 | CHECKPOINT_1_PREFIX = 'model.ckpt-1.'
84 | 85 | 86 | class NumpyEncoder(json.JSONEncoder): 87 | """Converts numpy objects to JSON-serializable format.""" 88 | 89 | def default(self, obj): 90 | if isinstance(obj, np.ndarray): 91 | # Matrices converted to nested lists 92 | return obj.tolist() 93 | elif isinstance(obj, np.generic): 94 | # Scalars converted to closest Python type 95 | return np.asscalar(obj) 96 | return json.JSONEncoder.default(self, obj) 97 | 98 | 99 | class Evaluator(object): 100 | """Manages evaluating a subset of the total models.""" 101 | 102 | def __init__(self, 103 | models_file, 104 | output_dir, 105 | worker_id=0, 106 | total_workers=1, 107 | model_id_regex='^'): 108 | self.config = _config.build_config() 109 | with tf.gfile.Open(models_file) as f: 110 | self.models = json.load(f) 111 | 112 | self.remainders = None 113 | self.ordered_keys = None 114 | 115 | if FLAGS.remainders_file: 116 | # Run only the modules and repeat numbers specified 117 | with tf.gfile.Open(FLAGS.remainders_file) as f: 118 | self.remainders = json.load(f) 119 | self.remainders = sorted(self.remainders) 120 | self.num_models = len(self.remainders) 121 | self.total_work_units = self.num_models 122 | else: 123 | # Filter keys to only those that fit the regex and order them so all 124 | # workers see a canonical ordering. 
125 | regex = re.compile(model_id_regex) 126 | evaluated_keys = [key for key in self.models.keys() if regex.match(key)] 127 | self.ordered_keys = sorted(evaluated_keys) 128 | self.num_models = len(self.ordered_keys) 129 | self.total_work_units = self.num_models * self.config['num_repeats'] 130 | 131 | self.total_workers = total_workers 132 | 133 | # If the worker is recovering from a restart, figure out where to restart 134 | worker_recovery_dir = os.path.join(output_dir, '_recovery') 135 | tf.gfile.MakeDirs(worker_recovery_dir) # Silently succeeds if exists 136 | self.recovery_file = os.path.join(worker_recovery_dir, str(worker_id)) 137 | if tf.gfile.Exists(self.recovery_file): 138 | with tf.gfile.Open(self.recovery_file) as f: 139 | self.current_index = int(f.read()) 140 | else: 141 | self.current_index = worker_id 142 | with tf.gfile.Open(self.recovery_file, 'w') as f: 143 | f.write(str(self.current_index)) 144 | 145 | assert self.current_index % self.total_workers == worker_id 146 | self.output_dir = output_dir 147 | 148 | def run_evaluation(self): 149 | """Runs the worker evaluation loop.""" 150 | while self.current_index < self.total_work_units: 151 | # Perform the expensive evaluation of the model at the current index 152 | self._evaluate_work_unit(self.current_index) 153 | 154 | self.current_index += self.total_workers 155 | with tf.gfile.Open(self.recovery_file, 'w') as f: 156 | f.write(str(self.current_index)) 157 | 158 | def _evaluate_work_unit(self, index): 159 | """Runs the evaluation of the model at the specified index. 160 | 161 | The index records the current index of the work unit being evaluated. Each 162 | worker will only compute the work units with index modulo total_workers 163 | equal to the worker_id. 164 | 165 | Args: 166 | index: int index into total work units. 
167 | """ 168 | if self.remainders: 169 | assert self.ordered_keys is None 170 | model_id = self.remainders[index][0] 171 | model_repeat = self.remainders[index][1] 172 | else: 173 | model_id = self.ordered_keys[index % self.num_models] 174 | model_repeat = index // self.num_models + 1 175 | 176 | matrix, labels = self.models[model_id] 177 | matrix = np.array(matrix) 178 | 179 | # Re-label to config['available_ops'] 180 | labels = (['input'] + 181 | [self.config['available_ops'][lab] for lab in labels[1:-1]] + 182 | ['output']) 183 | spec = model_spec.ModelSpec(matrix, labels) 184 | assert spec.valid_spec 185 | assert np.sum(spec.matrix) <= self.config['max_edges'] 186 | 187 | # Split the directory into 16^2 roughly equal subdirectories 188 | model_dir = os.path.join(self.output_dir, 189 | model_id[:2], 190 | model_id, 191 | 'repeat_%d' % model_repeat) 192 | try: 193 | meta = evaluate.train_and_evaluate(spec, self.config, model_dir) 194 | except evaluate.AbortError: 195 | # After hitting the retry limit, the job will continue to the next work 196 | # unit. These failed jobs may need to be re-run at a later point. 
197 | return 198 | 199 | # Write data to model_dir 200 | output_file = os.path.join(model_dir, RESULTS_FILE) 201 | with tf.gfile.Open(output_file, 'w') as f: 202 | json.dump(meta, f, cls=NumpyEncoder) 203 | 204 | # Delete some files to reclaim space 205 | self._clean_model_dir(model_dir) 206 | 207 | def _clean_model_dir(self, model_dir): 208 | """Cleans the output model directory to reclaim disk space.""" 209 | saved_prefixes = [CHECKPOINT_PREFIX, RESULTS_FILE] 210 | all_files = tf.gfile.ListDirectory(model_dir) 211 | files_to_keep = set() 212 | for filename in all_files: 213 | for prefix in saved_prefixes: 214 | if (filename.startswith(prefix) and 215 | not filename.startswith(CHECKPOINT_1_PREFIX)): 216 | files_to_keep.add(filename) 217 | 218 | for filename in all_files: 219 | if filename not in files_to_keep: 220 | full_filename = os.path.join(model_dir, filename) 221 | if tf.gfile.IsDirectory(full_filename): 222 | tf.gfile.DeleteRecursively(full_filename) 223 | else: 224 | tf.gfile.Remove(full_filename) 225 | 226 | 227 | def main(args): 228 | del args # Unused 229 | worker_id = FLAGS.worker_id + FLAGS.worker_id_offset 230 | evaluator = Evaluator( 231 | models_file=FLAGS.models_file, 232 | output_dir=FLAGS.output_dir, 233 | worker_id=worker_id, 234 | total_workers=FLAGS.total_workers, 235 | model_id_regex=FLAGS.model_id_regex) 236 | evaluator.run_evaluation() 237 | 238 | 239 | if __name__ == '__main__': 240 | app.run(main) 241 | -------------------------------------------------------------------------------- /nasbench/tests/run_evaluation_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Unit tests for scripts/run_evaluation.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import json 22 | import os 23 | import tempfile 24 | 25 | from absl.testing import flagsaver 26 | from nasbench.scripts import run_evaluation 27 | import tensorflow as tf 28 | 29 | 30 | 31 | class RunEvaluationTest(tf.test.TestCase): 32 | 33 | def setUp(self): 34 | """Set up files and directories that are expected by run_evaluation.""" 35 | # Create temp directory for output files 36 | self.output_dir = tempfile.mkdtemp() 37 | self.models_file = os.path.join(self.output_dir, 'models_file.json') 38 | 39 | self.toy_data = { 40 | 'abc': ([[0, 1, 1], [0, 0, 1], [0, 0, 0]], [-1, 0, -2]), 41 | 'abd': ([[0, 1, 0], [0, 0, 1], [0, 0, 0]], [-1, 0, -2]), 42 | 'abe': ([[0, 0, 1], [0, 0, 0], [0, 0, 0]], [-1, 0, -2]), 43 | } 44 | 45 | with tf.gfile.Open(self.models_file, 'w') as f: 46 | json.dump(self.toy_data, f) 47 | 48 | # Create files & directories which are normally created by 49 | # evaluate.train_and_evaluate but have been mocked out. 
50 | for model_id in self.toy_data: 51 | eval_dir = os.path.join(self.output_dir, 'ab', model_id, 'repeat_1') 52 | tf.gfile.MakeDirs(eval_dir) 53 | run_evaluation.FLAGS.train_data_files = 'unused' 54 | run_evaluation.FLAGS.valid_data_file = 'unused' 55 | run_evaluation.FLAGS.test_data_file = 'unused' 56 | run_evaluation.FLAGS.num_repeats = 1 57 | 58 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 59 | def test_evaluate_single_worker(self, mock_eval): 60 | """Tests single worker code path.""" 61 | mock_eval.train_and_evaluate.return_value = 'unused_output' 62 | evaluator = run_evaluation.Evaluator( 63 | self.models_file, self.output_dir) 64 | evaluator.run_evaluation() 65 | 66 | expected_dir = os.path.join(self.output_dir, 'ab') 67 | mock_eval.train_and_evaluate.assert_has_calls([ 68 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 69 | os.path.join(expected_dir, 'abc', 'repeat_1')), 70 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 71 | os.path.join(expected_dir, 'abd', 'repeat_1')), 72 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 73 | os.path.join(expected_dir, 'abe', 'repeat_1'))]) 74 | 75 | for model_id in self.toy_data: 76 | self.assertTrue(tf.gfile.Exists( 77 | os.path.join(expected_dir, model_id, 'repeat_1', 'results.json'))) 78 | 79 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 80 | def test_evaluate_multi_worker_0(self, mock_eval): 81 | """Tests multi worker code path for worker 0.""" 82 | mock_eval.train_and_evaluate.return_value = 'unused_output' 83 | evaluator = run_evaluation.Evaluator( 84 | self.models_file, self.output_dir, worker_id=0, total_workers=2) 85 | evaluator.run_evaluation() 86 | 87 | expected_dir = os.path.join(self.output_dir, 'ab') 88 | mock_eval.train_and_evaluate.assert_has_calls([ 89 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 90 | os.path.join(expected_dir, 'abc', 'repeat_1')), 91 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 92 | os.path.join(expected_dir, 'abe', 
'repeat_1'))]) 93 | 94 | for model_id in ['abc', 'abe']: 95 | self.assertTrue(tf.gfile.Exists( 96 | os.path.join(expected_dir, model_id, 'repeat_1', 'results.json'))) 97 | 98 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 99 | def test_evaluate_multi_worker_1(self, mock_eval): 100 | """Tests multi worker code path for worker 1.""" 101 | mock_eval.train_and_evaluate.return_value = 'unused_output' 102 | evaluator = run_evaluation.Evaluator( 103 | self.models_file, self.output_dir, worker_id=1, total_workers=2) 104 | evaluator.run_evaluation() 105 | 106 | expected_dir = os.path.join(self.output_dir, 'ab') 107 | mock_eval.train_and_evaluate.assert_has_calls([ 108 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 109 | os.path.join(expected_dir, 'abd', 'repeat_1'))]) 110 | 111 | self.assertTrue(tf.gfile.Exists( 112 | os.path.join(expected_dir, 'abd', 'repeat_1', 'results.json'))) 113 | 114 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 115 | def test_evaluate_regex(self, mock_eval): 116 | """Tests regex filters models.""" 117 | mock_eval.train_and_evaluate.return_value = 'unused_output' 118 | evaluator = run_evaluation.Evaluator( 119 | self.models_file, self.output_dir, model_id_regex='^ab(d|e)') 120 | evaluator.run_evaluation() 121 | 122 | expected_dir = os.path.join(self.output_dir, 'ab') 123 | mock_eval.train_and_evaluate.assert_has_calls([ 124 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 125 | os.path.join(expected_dir, 'abd', 'repeat_1')), 126 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 127 | os.path.join(expected_dir, 'abe', 'repeat_1'))]) 128 | 129 | for model_id in ['abd', 'abe']: 130 | self.assertTrue(tf.gfile.Exists( 131 | os.path.join(expected_dir, model_id, 'repeat_1', 'results.json'))) 132 | 133 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 134 | def test_evaluate_repeat(self, mock_eval): 135 | """Tests evaluate with repeats.""" 136 | mock_eval.train_and_evaluate.return_value = 'unused_output' 137 
| 138 | # Create extra directories not created in setUp for repeat_2 139 | for model_id in self.toy_data: 140 | eval_dir = os.path.join(self.output_dir, 'ab', model_id, 'repeat_2') 141 | tf.gfile.MakeDirs(eval_dir) 142 | 143 | with flagsaver.flagsaver(num_repeats=2): 144 | evaluator = run_evaluation.Evaluator( 145 | self.models_file, self.output_dir) 146 | evaluator.run_evaluation() 147 | 148 | expected_dir = os.path.join(self.output_dir, 'ab') 149 | mock_eval.train_and_evaluate.assert_has_calls([ 150 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 151 | os.path.join(expected_dir, 'abc', 'repeat_1')), 152 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 153 | os.path.join(expected_dir, 'abd', 'repeat_1')), 154 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 155 | os.path.join(expected_dir, 'abe', 'repeat_1')), 156 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 157 | os.path.join(expected_dir, 'abc', 'repeat_2')), 158 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 159 | os.path.join(expected_dir, 'abd', 'repeat_2')), 160 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 161 | os.path.join(expected_dir, 'abe', 'repeat_2'))]) 162 | 163 | for model_id in self.toy_data: 164 | for repeat in range(2): 165 | self.assertTrue(tf.gfile.Exists( 166 | os.path.join(expected_dir, model_id, 167 | 'repeat_%d' % (repeat + 1), 'results.json'))) 168 | 169 | def test_clean_model_dir(self): 170 | """Tests clean-up of model directory keeps only intended files.""" 171 | model_dir = os.path.join(self.output_dir, 'ab', 'abcde', 'repeat_1') 172 | tf.gfile.MakeDirs(model_dir) 173 | 174 | # Write files which will be preserved 175 | preserved_files = ['model.ckpt-0.index', 176 | 'model.ckpt-100.index', 177 | 'results.json'] 178 | for filename in preserved_files: 179 | with tf.gfile.Open(os.path.join(model_dir, filename), 'w') as f: 180 | f.write('unused') 181 | 182 | # Write files which will be deleted 183 | for filename in ['checkpoint', 184 | 
'events.out.tfevents']: 185 | with tf.gfile.Open(os.path.join(model_dir, filename), 'w') as f: 186 | f.write('unused') 187 | 188 | # Create subdirectory which will be deleted 189 | eval_dir = os.path.join(model_dir, 'eval_dir') 190 | tf.gfile.MakeDirs(eval_dir) 191 | with tf.gfile.Open(os.path.join(eval_dir, 'events.out.tfevents'), 'w') as f: 192 | f.write('unused') 193 | 194 | evaluator = run_evaluation.Evaluator(self.models_file, self.output_dir) 195 | evaluator._clean_model_dir(model_dir) 196 | 197 | # Check only intended files are preserved 198 | remaining_files = tf.gfile.ListDirectory(model_dir) 199 | self.assertItemsEqual(remaining_files, preserved_files) 200 | 201 | @tf.test.mock.patch.object(run_evaluation, 'evaluate') 202 | def test_recovery_file(self, mock_eval): 203 | """Tests that evaluation recovers from restart.""" 204 | mock_eval.train_and_evaluate.return_value = 'unused_output' 205 | 206 | # Write recovery file 207 | recovery_dir = os.path.join(self.output_dir, '_recovery') 208 | tf.gfile.MakeDirs(recovery_dir) 209 | with tf.gfile.Open(os.path.join(recovery_dir, '0'), 'w') as f: 210 | f.write('2') # Resume at 3rd entry 211 | 212 | evaluator = run_evaluation.Evaluator( 213 | self.models_file, self.output_dir) 214 | evaluator.run_evaluation() 215 | 216 | expected_dir = os.path.join(self.output_dir, 'ab') 217 | mock_eval.train_and_evaluate.assert_has_calls([ 218 | tf.test.mock.call(tf.test.mock.ANY, tf.test.mock.ANY, 219 | os.path.join(expected_dir, 'abe', 'repeat_1'))]) 220 | 221 | # Check that only 'abe' was evaluated, 'abc' and 'abe' are skipped due to 222 | # recovery. 
223 | call_args = mock_eval.train_and_evaluate.call_args_list 224 | self.assertEqual(len(call_args), 1) 225 | 226 | # Check that recovery file is updated after run 227 | with tf.gfile.Open(evaluator.recovery_file) as f: 228 | new_idx = int(f.read()) 229 | self.assertEqual(new_idx, 3) 230 | 231 | 232 | if __name__ == '__main__': 233 | tf.test.main() 234 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 
30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /nasbench/lib/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Performs training and evaluation of the proposed model spec on TPU.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import time 22 | 23 | from nasbench.lib import cifar 24 | from nasbench.lib import model_builder 25 | from nasbench.lib import training_time 26 | import numpy as np 27 | import tensorflow as tf 28 | 29 | VALID_EXCEPTIONS = ( 30 | tf.train.NanLossDuringTrainingError, # NaN loss 31 | tf.errors.ResourceExhaustedError, # OOM 32 | tf.errors.InvalidArgumentError, # NaN gradient 33 | tf.errors.DeadlineExceededError, # Timed out 34 | ) 35 | 36 | 37 | class AbortError(Exception): 38 | """Signals that evaluation failed for a valid reason.""" 39 | pass 40 | 41 | 42 | def train_and_evaluate(spec, config, model_dir): 43 | """Train and evaluate the proposed model. 44 | 45 | This method trains and evaluates the model for the creation of the benchmark 46 | dataset. The default values from the config.py are exactly the values used. 47 | 48 | Args: 49 | spec: ModelSpec object. 50 | config: config dict generated from config.py. 51 | model_dir: directory to store the checkpoint files. 52 | 53 | Returns: 54 | dict containing the evaluation metadata. 55 | """ 56 | return _train_and_evaluate_impl(spec, config, model_dir) 57 | 58 | 59 | def augment_and_evaluate(spec, config, model_dir, epochs_per_eval=5): 60 | """Trains the model on the full training set and evaluates on test set. 61 | 62 | "Augment" specifically refers to training the same spec in a larger network on 63 | the full training set. Typically this involves increasing the epoch count, 64 | number of modules/stacks, and changing the LR schedule. These changes should 65 | be made to the config dict before calling this method. 66 | 67 | Note: this method was not used for generating the NAS Benchmark dataset. See 68 | train_and_evaluate instead. 69 | 70 | Args: 71 | spec: ModelSpec object. 
72 | config: config dict generated from config.py. 73 | model_dir: directory to store the checkpoint files. 74 | epochs_per_eval: number of epochs per evaluation run. Evaluation is always 75 | run at the very start and end. 76 | 77 | Returns: 78 | dict containing the evaluation metadata. 79 | """ 80 | return _augment_and_evaluate_impl(spec, config, model_dir, epochs_per_eval) 81 | 82 | 83 | def _train_and_evaluate_impl(spec, config, model_dir): 84 | """Train and evaluate implementation, see train_and_evaluate docstring.""" 85 | evaluator = _TrainAndEvaluator(spec, config, model_dir) 86 | return evaluator.run() 87 | 88 | 89 | class _TrainAndEvaluator(object): 90 | """Runs the training and evaluation.""" 91 | 92 | def __init__(self, spec, config, model_dir): 93 | """Initialize evaluator. See train_and_evaluate docstring.""" 94 | self.input_train = cifar.CIFARInput('train', config) 95 | self.input_train_eval = cifar.CIFARInput('train_eval', config) 96 | self.input_valid = cifar.CIFARInput('valid', config) 97 | self.input_test = cifar.CIFARInput('test', config) 98 | self.input_sample = cifar.CIFARInput('sample', config) 99 | self.estimator = _create_estimator(spec, config, model_dir, 100 | self.input_train.num_images, 101 | self.input_sample.num_images) 102 | 103 | self.spec = spec 104 | self.config = config 105 | self.model_dir = model_dir 106 | 107 | def run(self): 108 | """Runs training and evaluation.""" 109 | attempts = 0 110 | while True: 111 | # Delete everything in the model dir at the start of each attempt 112 | try: 113 | tf.gfile.DeleteRecursively(self.model_dir) 114 | except tf.errors.NotFoundError: 115 | pass 116 | tf.gfile.MakeDirs(self.model_dir) 117 | 118 | try: 119 | # Train 120 | if self.config['train_seconds'] > 0.0: 121 | timing = training_time.limit(self.config['train_seconds']) 122 | else: 123 | timing = training_time.limit(None) 124 | 125 | evaluations = map(float, self.config['intermediate_evaluations']) 126 | if not evaluations or 
evaluations[-1] != 1.0: 127 | evaluations.append(1.0) 128 | assert evaluations == sorted(evaluations) 129 | 130 | evaluation_results = [] 131 | start_time = time.time() 132 | 133 | # Train for 1 step with 0 LR to initialize the weights, then evaluate 134 | # once at the start for completeness, accuracies expected to be around 135 | # random selection. Note that batch norm moving averages change during 136 | # the step but the trainable weights do not. 137 | self.estimator.train( 138 | input_fn=self.input_train.input_fn, 139 | max_steps=1, 140 | hooks=[timing.train_hook], 141 | saving_listeners=[timing.saving_listener]) 142 | evaluation_results.append(self._evaluate_all(0.0, 0)) 143 | 144 | for next_evaluation in evaluations: 145 | epoch = next_evaluation * self.config['train_epochs'] 146 | train_steps = int(epoch * self.input_train.num_images / 147 | self.config['batch_size']) 148 | self.estimator.train( 149 | input_fn=self.input_train.input_fn, 150 | max_steps=train_steps, 151 | hooks=[timing.train_hook], 152 | saving_listeners=[timing.saving_listener]) 153 | 154 | evaluation_results.append(self._evaluate_all(epoch, train_steps)) 155 | 156 | all_time = time.time() - start_time 157 | break # Break from retry loop on success 158 | except VALID_EXCEPTIONS as e: # pylint: disable=catching-non-exception 159 | attempts += 1 160 | tf.logging.warning(str(e)) 161 | if attempts >= self.config['max_attempts']: 162 | raise AbortError(str(e)) 163 | 164 | metadata = { 165 | 'trainable_params': _get_param_count(self.model_dir), 166 | 'total_time': all_time, # includes eval and other metric time 167 | 'evaluation_results': evaluation_results, 168 | } 169 | 170 | return metadata 171 | 172 | def _evaluate_all(self, epochs, steps): 173 | """Runs all the evaluations.""" 174 | train_accuracy = _evaluate(self.estimator, self.input_train_eval, 175 | self.config, name='train') 176 | valid_accuracy = _evaluate(self.estimator, self.input_valid, 177 | self.config, name='valid') 178 | 
test_accuracy = _evaluate(self.estimator, self.input_test, 179 | self.config, name='test') 180 | train_time = self.estimator.get_variable_value( 181 | training_time.TOTAL_TIME_NAME) 182 | 183 | now = time.time() 184 | sample_metrics = self._compute_sample_metrics() 185 | predict_time = time.time() - now 186 | 187 | return { 188 | 'epochs': epochs, 189 | 'training_time': train_time, 190 | 'training_steps': steps, 191 | 'train_accuracy': train_accuracy, 192 | 'validation_accuracy': valid_accuracy, 193 | 'test_accuracy': test_accuracy, 194 | 'sample_metrics': sample_metrics, 195 | 'predict_time': predict_time, 196 | } 197 | 198 | def _compute_sample_metrics(self): 199 | """Computes the metrics on a fixed batch.""" 200 | sample_metrics = self.estimator.predict( 201 | input_fn=self.input_sample.input_fn, yield_single_examples=False).next() 202 | 203 | # Fix the extra batch dimension added by PREDICT 204 | for metric in sample_metrics: 205 | if metric in ['logits', 'input_grad_norm']: 206 | # Batch-shaped tensors take first batch 207 | sample_metrics[metric] = ( 208 | sample_metrics[metric][:self.input_sample.num_images, Ellipsis]) 209 | else: 210 | # Other tensors remove batch dimension 211 | sample_metrics[metric] = sample_metrics[metric][0, Ellipsis] 212 | 213 | return sample_metrics 214 | 215 | 216 | def _augment_and_evaluate_impl(spec, config, model_dir, epochs_per_eval=5): 217 | """Augment and evaluate implementation, see augment_and_evaluate docstring.""" 218 | input_augment, input_test = [ 219 | cifar.CIFARInput(m, config) 220 | for m in ['augment', 'test']] 221 | estimator = _create_estimator(spec, config, model_dir, 222 | input_augment.num_images) 223 | 224 | if config['train_seconds'] > 0.0: 225 | timing = training_time.limit(config['train_seconds']) 226 | else: 227 | timing = training_time.limit(None) 228 | 229 | steps_per_epoch = input_augment.num_images / config['batch_size'] # float 230 | ckpt = tf.train.latest_checkpoint(model_dir) 231 | if not ckpt: 232 
| current_step = 0 233 | else: 234 | current_step = int(ckpt.split('-')[-1]) 235 | max_steps = int(config['train_epochs'] * steps_per_epoch) 236 | 237 | while current_step < max_steps: 238 | next_step = current_step + int(epochs_per_eval * steps_per_epoch) 239 | next_step = min(next_step, max_steps) 240 | estimator.train( 241 | input_fn=input_augment.input_fn, 242 | max_steps=next_step, 243 | hooks=[timing.train_hook], 244 | saving_listeners=[timing.saving_listener]) 245 | current_step = next_step 246 | 247 | test_accuracy = _evaluate(estimator, input_test, config) 248 | 249 | metadata = { 250 | 'trainable_params': _get_param_count(model_dir), 251 | 'test_accuracy': test_accuracy, 252 | } 253 | 254 | return metadata 255 | 256 | 257 | def _create_estimator(spec, config, model_dir, 258 | num_train_images, num_sample_images=None): 259 | """Creates the TPUEstimator object.""" 260 | # Estimator will save a checkpoint at the end of every train() call. Disable 261 | # automatic checkpoints by setting the time interval between checkpoints to 262 | # a very large value. 263 | run_config = tf.contrib.tpu.RunConfig( 264 | model_dir=model_dir, 265 | keep_checkpoint_max=3, # Keeps ckpt at start, halfway, and end 266 | save_checkpoints_secs=2**30, 267 | tpu_config=tf.contrib.tpu.TPUConfig( 268 | iterations_per_loop=config['tpu_iterations_per_loop'], 269 | num_shards=config['tpu_num_shards'])) 270 | 271 | # This is a hack to allow PREDICT on a fixed batch on TPU. By replicating the 272 | # batch by the number of shards, this ensures each TPU core operates on the 273 | # entire fixed batch. 
274 | if num_sample_images and config['use_tpu']: 275 | num_sample_images *= config['tpu_num_shards'] 276 | 277 | estimator = tf.contrib.tpu.TPUEstimator( 278 | use_tpu=config['use_tpu'], 279 | model_fn=model_builder.build_model_fn( 280 | spec, config, num_train_images), 281 | config=run_config, 282 | train_batch_size=config['batch_size'], 283 | eval_batch_size=config['batch_size'], 284 | predict_batch_size=num_sample_images) 285 | 286 | return estimator 287 | 288 | 289 | def _evaluate(estimator, input_data, config, name=None): 290 | """Evaluate the estimator on the input data.""" 291 | steps = input_data.num_images // config['batch_size'] 292 | results = estimator.evaluate( 293 | input_fn=input_data.input_fn, 294 | steps=steps, 295 | name=name) 296 | return results['accuracy'] 297 | 298 | 299 | def _get_param_count(model_dir): 300 | """Get trainable param count from the model directory.""" 301 | tf.reset_default_graph() 302 | checkpoint = tf.train.get_checkpoint_state(model_dir) 303 | with tf.Session() as sess: 304 | saver = tf.train.import_meta_graph( 305 | checkpoint.model_checkpoint_path + '.meta') 306 | saver.restore(sess, checkpoint.model_checkpoint_path) 307 | params = np.sum([np.prod(v.get_shape().as_list()) 308 | for v in tf.trainable_variables()]) 309 | 310 | return params 311 | 312 | -------------------------------------------------------------------------------- /nasbench/api.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """User interface for the NAS Benchmark dataset. 16 | 17 | Before using this API, download the data files from the links in the README. 18 | 19 | Usage: 20 | # Load the data from file (this will take some time) 21 | nasbench = api.NASBench('/path/to/nasbench.tfrecord') 22 | 23 | # Create an Inception-like module (5x5 convolution replaced with two 3x3 24 | # convolutions). 25 | model_spec = api.ModelSpec( 26 | # Adjacency matrix of the module 27 | matrix=[[0, 1, 1, 1, 0, 1, 0], # input layer 28 | [0, 0, 0, 0, 0, 0, 1], # 1x1 conv 29 | [0, 0, 0, 0, 0, 0, 1], # 3x3 conv 30 | [0, 0, 0, 0, 1, 0, 0], # 5x5 conv (replaced by two 3x3's) 31 | [0, 0, 0, 0, 0, 0, 1], # 5x5 conv (replaced by two 3x3's) 32 | [0, 0, 0, 0, 0, 0, 1], # 3x3 max-pool 33 | [0, 0, 0, 0, 0, 0, 0]], # output layer 34 | # Operations at the vertices of the module, matches order of matrix 35 | ops=[INPUT, CONV1X1, CONV3X3, CONV3X3, CONV3X3, MAXPOOL3X3, OUTPUT]) 36 | 37 | 38 | # Query this model from dataset 39 | data = nasbench.query(model_spec) 40 | 41 | Adjacency matrices are expected to be upper-triangular 0-1 matrices within the 42 | defined search space (7 vertices, 9 edges, 3 allowed ops). The first and last 43 | operations must be 'input' and 'output'. The other operations should be from 44 | config['available_ops']. 
Currently, the available operations are: 45 | CONV3X3 = "conv3x3-bn-relu" 46 | CONV1X1 = "conv1x1-bn-relu" 47 | MAXPOOL3X3 = "maxpool3x3" 48 | 49 | When querying a spec, the spec will first be automatically pruned (removing 50 | unused vertices and edges along with ops). If the pruned spec is still out of 51 | the search space, an OutOfDomainError will be raised, otherwise the data is 52 | returned. 53 | 54 | The returned data object is a dictionary with the following keys: 55 | - module_adjacency: numpy array for the adjacency matrix 56 | - module_operations: list of operation labels 57 | - trainable_parameters: number of trainable parameters in the model 58 | - training_time: the total training time in seconds up to this point 59 | - train_accuracy: training accuracy 60 | - validation_accuracy: validation_accuracy 61 | - test_accuracy: testing accuracy 62 | 63 | Instead of querying the dataset for a single run of a model, it is also possible 64 | to retrieve all metrics for a given spec, using: 65 | 66 | fixed_stats, computed_stats = nasbench.get_metrics_from_spec(model_spec) 67 | 68 | The fixed_stats is a dictionary with the keys: 69 | - module_adjacency 70 | - module_operations 71 | - trainable_parameters 72 | 73 | The computed_stats is a dictionary from epoch count to a list of metric 74 | dicts. For example, computed_stats[108][0] contains the metrics for the first 75 | repeat of the provided model trained to 108 epochs. 
The available keys are: 76 | - halfway_training_time 77 | - halfway_train_accuracy 78 | - halfway_validation_accuracy 79 | - halfway_test_accuracy 80 | - final_training_time 81 | - final_train_accuracy 82 | - final_validation_accuracy 83 | - final_test_accuracy 84 | """ 85 | 86 | from __future__ import absolute_import 87 | from __future__ import division 88 | from __future__ import print_function 89 | 90 | import base64 91 | import copy 92 | import json 93 | import os 94 | import random 95 | import time 96 | 97 | from nasbench.lib import config 98 | from nasbench.lib import evaluate 99 | from nasbench.lib import model_metrics_pb2 100 | from nasbench.lib import model_spec as _model_spec 101 | import numpy as np 102 | import tensorflow as tf 103 | 104 | # Bring ModelSpec to top-level for convenience. See lib/model_spec.py. 105 | ModelSpec = _model_spec.ModelSpec 106 | 107 | 108 | class OutOfDomainError(Exception): 109 | """Indicates that the requested graph is outside of the search domain.""" 110 | 111 | 112 | class NASBench(object): 113 | """User-facing API for accessing the NASBench dataset.""" 114 | 115 | def __init__(self, dataset_file, seed=None): 116 | """Initialize dataset, this should only be done once per experiment. 117 | 118 | Args: 119 | dataset_file: path to .tfrecord file containing the dataset. 120 | seed: random seed used for sampling queried models. Two NASBench objects 121 | created with the same seed will return the same data points when queried 122 | with the same models in the same order. By default, the seed is randomly 123 | generated. 124 | """ 125 | self.config = config.build_config() 126 | random.seed(seed) 127 | 128 | print('Loading dataset from file... This may take a few minutes...') 129 | start = time.time() 130 | 131 | # Stores the fixed statistics that are independent of evaluation (i.e., 132 | # adjacency matrix, operations, and number of parameters). 
133 | # hash --> metric name --> scalar 134 | self.fixed_statistics = {} 135 | 136 | # Stores the statistics that are computed via training and evaluating the 137 | # model on CIFAR-10. Statistics are computed for multiple repeats of each 138 | # model at each max epoch length. 139 | # hash --> epochs --> repeat index --> metric name --> scalar 140 | self.computed_statistics = {} 141 | 142 | # Valid queriable epoch lengths. {4, 12, 36, 108} for the full dataset or 143 | # {108} for the smaller dataset with only the 108 epochs. 144 | self.valid_epochs = set() 145 | 146 | for serialized_row in tf.python_io.tf_record_iterator(dataset_file): 147 | # Parse the data from the data file. 148 | module_hash, epochs, raw_adjacency, raw_operations, raw_metrics = ( 149 | json.loads(serialized_row.decode('utf-8'))) 150 | 151 | dim = int(np.sqrt(len(raw_adjacency))) 152 | adjacency = np.array([int(e) for e in list(raw_adjacency)], dtype=np.int8) 153 | adjacency = np.reshape(adjacency, (dim, dim)) 154 | operations = raw_operations.split(',') 155 | metrics = model_metrics_pb2.ModelMetrics.FromString( 156 | base64.b64decode(raw_metrics)) 157 | 158 | if module_hash not in self.fixed_statistics: 159 | # First time seeing this module, initialize fixed statistics. 160 | new_entry = {} 161 | new_entry['module_adjacency'] = adjacency 162 | new_entry['module_operations'] = operations 163 | new_entry['trainable_parameters'] = metrics.trainable_parameters 164 | self.fixed_statistics[module_hash] = new_entry 165 | self.computed_statistics[module_hash] = {} 166 | 167 | self.valid_epochs.add(epochs) 168 | 169 | if epochs not in self.computed_statistics[module_hash]: 170 | self.computed_statistics[module_hash][epochs] = [] 171 | 172 | # Each data_point consists of the metrics recorded from a single 173 | # train-and-evaluation of a model at a specific epoch length. 
174 | data_point = {} 175 | 176 | # Note: metrics.evaluation_data[0] contains the computed metrics at the 177 | # start of training (step 0) but this is unused by this API. 178 | 179 | # Evaluation statistics at the half-way point of training 180 | half_evaluation = metrics.evaluation_data[1] 181 | data_point['halfway_training_time'] = half_evaluation.training_time 182 | data_point['halfway_train_accuracy'] = half_evaluation.train_accuracy 183 | data_point['halfway_validation_accuracy'] = ( 184 | half_evaluation.validation_accuracy) 185 | data_point['halfway_test_accuracy'] = half_evaluation.test_accuracy 186 | 187 | # Evaluation statistics at the end of training 188 | final_evaluation = metrics.evaluation_data[2] 189 | data_point['final_training_time'] = final_evaluation.training_time 190 | data_point['final_train_accuracy'] = final_evaluation.train_accuracy 191 | data_point['final_validation_accuracy'] = ( 192 | final_evaluation.validation_accuracy) 193 | data_point['final_test_accuracy'] = final_evaluation.test_accuracy 194 | 195 | self.computed_statistics[module_hash][epochs].append(data_point) 196 | 197 | elapsed = time.time() - start 198 | print('Loaded dataset in %d seconds' % elapsed) 199 | 200 | self.history = {} 201 | self.training_time_spent = 0.0 202 | self.total_epochs_spent = 0 203 | 204 | def query(self, model_spec, epochs=108, stop_halfway=False): 205 | """Fetch one of the evaluations for this model spec. 206 | 207 | Each call will sample one of the config['num_repeats'] evaluations of the 208 | model. This means that repeated queries of the same model (or isomorphic 209 | models) may return identical metrics. 210 | 211 | This function will increment the budget counters for benchmarking purposes. 212 | See self.training_time_spent, and self.total_epochs_spent. 213 | 214 | This function also allows querying the evaluation metrics at the halfway 215 | point of training using stop_halfway. 
Using this option will increment the 216 | budget counters only up to the halfway point. 217 | 218 | Args: 219 | model_spec: ModelSpec object. 220 | epochs: number of epochs trained. Must be one of the evaluated number of 221 | epochs, [4, 12, 36, 108] for the full dataset. 222 | stop_halfway: if True, returned dict will only contain the training time 223 | and accuracies at the halfway point of training (num_epochs/2). 224 | Otherwise, returns the time and accuracies at the end of training 225 | (num_epochs). 226 | 227 | Returns: 228 | dict containing the evaluated data for this object. 229 | 230 | Raises: 231 | OutOfDomainError: if model_spec or num_epochs is outside the search space. 232 | """ 233 | if epochs not in self.valid_epochs: 234 | raise OutOfDomainError('invalid number of epochs, must be one of %s' 235 | % self.valid_epochs) 236 | 237 | fixed_stat, computed_stat = self.get_metrics_from_spec(model_spec) 238 | sampled_index = random.randint(0, self.config['num_repeats'] - 1) 239 | computed_stat = computed_stat[epochs][sampled_index] 240 | 241 | data = {} 242 | data['module_adjacency'] = fixed_stat['module_adjacency'] 243 | data['module_operations'] = fixed_stat['module_operations'] 244 | data['trainable_parameters'] = fixed_stat['trainable_parameters'] 245 | 246 | if stop_halfway: 247 | data['training_time'] = computed_stat['halfway_training_time'] 248 | data['train_accuracy'] = computed_stat['halfway_train_accuracy'] 249 | data['validation_accuracy'] = computed_stat['halfway_validation_accuracy'] 250 | data['test_accuracy'] = computed_stat['halfway_test_accuracy'] 251 | else: 252 | data['training_time'] = computed_stat['final_training_time'] 253 | data['train_accuracy'] = computed_stat['final_train_accuracy'] 254 | data['validation_accuracy'] = computed_stat['final_validation_accuracy'] 255 | data['test_accuracy'] = computed_stat['final_test_accuracy'] 256 | 257 | self.training_time_spent += data['training_time'] 258 | if stop_halfway: 259 | 
self.total_epochs_spent += epochs // 2 260 | else: 261 | self.total_epochs_spent += epochs 262 | 263 | return data 264 | 265 | def is_valid(self, model_spec): 266 | """Checks the validity of the model_spec. 267 | 268 | For the purposes of benchmarking, this does not increment the budget 269 | counters. 270 | 271 | Args: 272 | model_spec: ModelSpec object. 273 | 274 | Returns: 275 | True if model is within space. 276 | """ 277 | try: 278 | self._check_spec(model_spec) 279 | except OutOfDomainError: 280 | return False 281 | 282 | return True 283 | 284 | def get_budget_counters(self): 285 | """Returns the time and budget counters.""" 286 | return self.training_time_spent, self.total_epochs_spent 287 | 288 | def reset_budget_counters(self): 289 | """Reset the time and epoch budget counters.""" 290 | self.training_time_spent = 0.0 291 | self.total_epochs_spent = 0 292 | 293 | def evaluate(self, model_spec, model_dir): 294 | """Trains and evaluates a model spec from scratch (does not query dataset). 295 | 296 | This function runs the same procedure that was used to generate each 297 | evaluation in the dataset. Because we are not querying the generated 298 | dataset of trained models, there are no limitations on number of vertices, 299 | edges, operations, or epochs. Note that the results will not exactly match 300 | the dataset due to randomness. By default, this uses TPUs for evaluation but 301 | CPU/GPU can be used by setting --use_tpu=false (GPU will require installing 302 | tensorflow-gpu). 303 | 304 | Args: 305 | model_spec: ModelSpec object. 306 | model_dir: directory to store the checkpoints, summaries, and logs. 307 | 308 | Returns: 309 | dict contained the evaluated data for this object, same structure as 310 | returned by query(). 311 | """ 312 | # Metadata contains additional metrics that aren't reported normally. 313 | # However, these are stored in the JSON file at the model_dir. 
314 | metadata = evaluate.train_and_evaluate(model_spec, self.config, model_dir) 315 | metadata_file = os.path.join(model_dir, 'metadata.json') 316 | with tf.gfile.Open(metadata_file, 'w') as f: 317 | json.dump(metadata, f, cls=_NumpyEncoder) 318 | 319 | data_point = {} 320 | data_point['module_adjacency'] = model_spec.matrix 321 | data_point['module_operations'] = model_spec.ops 322 | data_point['trainable_parameters'] = metadata['trainable_params'] 323 | 324 | final_evaluation = metadata['evaluation_results'][-1] 325 | data_point['training_time'] = final_evaluation['training_time'] 326 | data_point['train_accuracy'] = final_evaluation['train_accuracy'] 327 | data_point['validation_accuracy'] = final_evaluation['validation_accuracy'] 328 | data_point['test_accuracy'] = final_evaluation['test_accuracy'] 329 | 330 | return data_point 331 | 332 | def hash_iterator(self): 333 | """Returns iterator over all unique model hashes.""" 334 | return self.fixed_statistics.keys() 335 | 336 | def get_metrics_from_hash(self, module_hash): 337 | """Returns the metrics for all epochs and all repeats of a hash. 338 | 339 | This method is for dataset analysis and should not be used for benchmarking. 340 | As such, it does not increment any of the budget counters. 341 | 342 | Args: 343 | module_hash: MD5 hash, i.e., the values yielded by hash_iterator(). 344 | 345 | Returns: 346 | fixed stats and computed stats of the model spec provided. 347 | """ 348 | fixed_stat = copy.deepcopy(self.fixed_statistics[module_hash]) 349 | computed_stat = copy.deepcopy(self.computed_statistics[module_hash]) 350 | return fixed_stat, computed_stat 351 | 352 | def get_metrics_from_spec(self, model_spec): 353 | """Returns the metrics for all epochs and all repeats of a model. 354 | 355 | This method is for dataset analysis and should not be used for benchmarking. 356 | As such, it does not increment any of the budget counters. 357 | 358 | Args: 359 | model_spec: ModelSpec object. 
360 | 361 | Returns: 362 | fixed stats and computed stats of the model spec provided. 363 | """ 364 | self._check_spec(model_spec) 365 | module_hash = self._hash_spec(model_spec) 366 | return self.get_metrics_from_hash(module_hash) 367 | 368 | def _check_spec(self, model_spec): 369 | """Checks that the model spec is within the dataset.""" 370 | if not model_spec.valid_spec: 371 | raise OutOfDomainError('invalid spec, provided graph is disconnected.') 372 | 373 | num_vertices = len(model_spec.ops) 374 | num_edges = np.sum(model_spec.matrix) 375 | 376 | if num_vertices > self.config['module_vertices']: 377 | raise OutOfDomainError('too many vertices, got %d (max vertices = %d)' 378 | % (num_vertices, config['module_vertices'])) 379 | 380 | if num_edges > self.config['max_edges']: 381 | raise OutOfDomainError('too many edges, got %d (max edges = %d)' 382 | % (num_edges, self.config['max_edges'])) 383 | 384 | if model_spec.ops[0] != 'input': 385 | raise OutOfDomainError('first operation should be \'input\'') 386 | if model_spec.ops[-1] != 'output': 387 | raise OutOfDomainError('last operation should be \'output\'') 388 | for op in model_spec.ops[1:-1]: 389 | if op not in self.config['available_ops']: 390 | raise OutOfDomainError('unsupported op %s (available ops = %s)' 391 | % (op, self.config['available_ops'])) 392 | 393 | def _hash_spec(self, model_spec): 394 | """Returns the MD5 hash for a provided model_spec.""" 395 | return model_spec.hash_spec(self.config['available_ops']) 396 | 397 | 398 | class _NumpyEncoder(json.JSONEncoder): 399 | """Converts numpy objects to JSON-serializable format.""" 400 | 401 | def default(self, obj): 402 | if isinstance(obj, np.ndarray): 403 | # Matrices converted to nested lists 404 | return obj.tolist() 405 | elif isinstance(obj, np.generic): 406 | # Scalars converted to closest Python type 407 | return np.asscalar(obj) 408 | return json.JSONEncoder.default(self, obj) 409 | 
-------------------------------------------------------------------------------- /nasbench/tests/graph_util_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tests for lib/graph_util.py.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import random 22 | from nasbench.lib import graph_util 23 | import numpy as np 24 | import tensorflow as tf # Used for tf.test 25 | 26 | 27 | class GraphUtilTest(tf.test.TestCase): 28 | 29 | def test_gen_is_edge(self): 30 | """Tests gen_is_edge generates correct graphs.""" 31 | fn = graph_util.gen_is_edge_fn(0) # '000' 32 | arr = np.fromfunction(fn, (3, 3), dtype=np.int8) 33 | self.assertTrue(np.array_equal(arr, 34 | np.array([[0, 0, 0], 35 | [0, 0, 0], 36 | [0, 0, 0]]))) 37 | 38 | fn = graph_util.gen_is_edge_fn(3) # '011' 39 | arr = np.fromfunction(fn, (3, 3), dtype=np.int8) 40 | self.assertTrue(np.array_equal(arr, 41 | np.array([[0, 1, 1], 42 | [0, 0, 0], 43 | [0, 0, 0]]))) 44 | 45 | fn = graph_util.gen_is_edge_fn(5) # '101' 46 | arr = np.fromfunction(fn, (3, 3), dtype=np.int8) 47 | self.assertTrue(np.array_equal(arr, 48 | np.array([[0, 1, 0], 49 | [0, 0, 1], 50 | [0, 0, 0]]))) 51 | 52 | fn = graph_util.gen_is_edge_fn(7) # '111' 53 | arr = 
np.fromfunction(fn, (3, 3), dtype=np.int8) 54 | self.assertTrue(np.array_equal(arr, 55 | np.array([[0, 1, 1], 56 | [0, 0, 1], 57 | [0, 0, 0]]))) 58 | 59 | fn = graph_util.gen_is_edge_fn(7) # '111' 60 | arr = np.fromfunction(fn, (4, 4), dtype=np.int8) 61 | self.assertTrue(np.array_equal(arr, 62 | np.array([[0, 1, 1, 0], 63 | [0, 0, 1, 0], 64 | [0, 0, 0, 0], 65 | [0, 0, 0, 0]]))) 66 | 67 | fn = graph_util.gen_is_edge_fn(18) # '010010' 68 | arr = np.fromfunction(fn, (4, 4), dtype=np.int8) 69 | self.assertTrue(np.array_equal(arr, 70 | np.array([[0, 0, 1, 0], 71 | [0, 0, 0, 1], 72 | [0, 0, 0, 0], 73 | [0, 0, 0, 0]]))) 74 | 75 | fn = graph_util.gen_is_edge_fn(35) # '100011' 76 | arr = np.fromfunction(fn, (4, 4), dtype=np.int8) 77 | self.assertTrue(np.array_equal(arr, 78 | np.array([[0, 1, 1, 0], 79 | [0, 0, 0, 0], 80 | [0, 0, 0, 1], 81 | [0, 0, 0, 0]]))) 82 | 83 | def test_is_full_dag(self): 84 | """Tests is_full_dag classifies DAGs.""" 85 | self.assertTrue(graph_util.is_full_dag(np.array( 86 | [[0, 1, 0], 87 | [0, 0, 1], 88 | [0, 0, 0]]))) 89 | 90 | self.assertTrue(graph_util.is_full_dag(np.array( 91 | [[0, 1, 1], 92 | [0, 0, 1], 93 | [0, 0, 0]]))) 94 | 95 | self.assertTrue(graph_util.is_full_dag(np.array( 96 | [[0, 1, 1, 0], 97 | [0, 0, 0, 1], 98 | [0, 0, 0, 1], 99 | [0, 0, 0, 0]]))) 100 | 101 | # vertex 1 not connected to input 102 | self.assertFalse(graph_util.is_full_dag(np.array( 103 | [[0, 0, 1], 104 | [0, 0, 1], 105 | [0, 0, 0]]))) 106 | 107 | # vertex 1 not connected to output 108 | self.assertFalse(graph_util.is_full_dag(np.array( 109 | [[0, 1, 1], 110 | [0, 0, 0], 111 | [0, 0, 0]]))) 112 | 113 | # 1, 3 are connected to each other but disconnected from main path 114 | self.assertFalse(graph_util.is_full_dag(np.array( 115 | [[0, 0, 1, 0, 0], 116 | [0, 0, 0, 1, 0], 117 | [0, 0, 0, 0, 1], 118 | [0, 0, 0, 0, 0], 119 | [0, 0, 0, 0, 0]]))) 120 | 121 | # no path from input to output 122 | self.assertFalse(graph_util.is_full_dag(np.array( 123 | [[0, 0, 1, 0], 124 | [0, 
0, 0, 1], 125 | [0, 0, 0, 0], 126 | [0, 0, 0, 0]]))) 127 | 128 | # completely disconnected vertex 129 | self.assertFalse(graph_util.is_full_dag(np.array( 130 | [[0, 1, 0, 0], 131 | [0, 0, 0, 1], 132 | [0, 0, 0, 0], 133 | [0, 0, 0, 0]]))) 134 | 135 | def test_hash_module(self): 136 | # Diamond graph with label permutation 137 | matrix1 = np.array( 138 | [[0, 1, 1, 0,], 139 | [0, 0, 0, 1], 140 | [0, 0, 0, 1], 141 | [0, 0, 0, 0]]) 142 | label1 = [-1, 1, 2, -2] 143 | label2 = [-1, 2, 1, -2] 144 | 145 | hash1 = graph_util.hash_module(matrix1, label1) 146 | hash2 = graph_util.hash_module(matrix1, label2) 147 | self.assertEqual(hash1, hash2) 148 | 149 | # Simple graph with edge permutation 150 | matrix1 = np.array( 151 | [[0, 1, 1, 0, 0], 152 | [0, 0, 0, 0, 1], 153 | [0, 0, 0, 1, 0], 154 | [0, 0, 0, 0, 1], 155 | [0, 0, 0, 0, 0]]) 156 | label1 = [-1, 1, 2, 3, -2] 157 | 158 | matrix2 = np.array( 159 | [[0, 1, 0, 1, 0], 160 | [0, 0, 1, 0, 0], 161 | [0, 0, 0, 0, 1], 162 | [0, 0, 0, 0, 1], 163 | [0, 0, 0, 0, 0]]) 164 | label2 = [-1, 2, 3, 1, -2] 165 | 166 | matrix3 = np.array( 167 | [[0, 1, 1, 0, 0], 168 | [0, 0, 0, 1, 0], 169 | [0, 0, 0, 0, 1], 170 | [0, 0, 0, 0, 1], 171 | [0, 0, 0, 0, 0]]) 172 | label3 = [-1, 2, 1, 3, -2] 173 | 174 | hash1 = graph_util.hash_module(matrix1, label1) 175 | hash2 = graph_util.hash_module(matrix2, label2) 176 | hash3 = graph_util.hash_module(matrix3, label3) 177 | self.assertEqual(hash1, hash2) 178 | self.assertEqual(hash2, hash3) 179 | 180 | hash4 = graph_util.hash_module(matrix1, label2) 181 | self.assertNotEqual(hash4, hash1) 182 | 183 | hash5 = graph_util.hash_module(matrix1, label3) 184 | self.assertNotEqual(hash5, hash1) 185 | 186 | # Connected non-isomorphic regular graphs on 6 interior vertices (8 total) 187 | matrix1 = np.array( 188 | [[0, 1, 0, 0, 0, 0, 0, 0], 189 | [0, 0, 1, 1, 0, 0, 1, 0], 190 | [0, 0, 0, 0, 1, 1, 0, 0], 191 | [0, 0, 0, 0, 1, 1, 0, 0], 192 | [0, 0, 0, 0, 0, 0, 1, 0], 193 | [0, 0, 0, 0, 0, 0, 1, 0], 194 | [0, 0, 0, 0, 
0, 0, 0, 1], 195 | [0, 0, 0, 0, 0, 0, 0, 0]]) 196 | matrix2 = np.array( 197 | [[0, 1, 0, 0, 0, 0, 0, 0], 198 | [0, 0, 1, 1, 0, 1, 0, 0], 199 | [0, 0, 0, 0, 1, 0, 1, 0], 200 | [0, 0, 0, 0, 1, 1, 0, 0], 201 | [0, 0, 0, 0, 0, 0, 1, 0], 202 | [0, 0, 0, 0, 0, 0, 1, 0], 203 | [0, 0, 0, 0, 0, 0, 0, 1], 204 | [0, 0, 0, 0, 0, 0, 0, 0]]) 205 | label1 = [-1, 1, 1, 1, 1, 1, 1, -2] 206 | 207 | hash1 = graph_util.hash_module(matrix1, label1) 208 | hash2 = graph_util.hash_module(matrix2, label1) 209 | self.assertNotEqual(hash1, hash2) 210 | 211 | # Non-isomorphic tricky case (breaks if you don't include self) 212 | hash1 = graph_util.hash_module( 213 | np.array([[0, 1, 0, 0, 0], 214 | [0, 0, 1, 0, 0], 215 | [0, 0, 0, 1, 0], 216 | [0, 0, 0, 0, 1], 217 | [0, 0, 0, 0, 0]]), 218 | [-1, 1, 0, 0, -2]) 219 | 220 | hash2 = graph_util.hash_module( 221 | np.array([[0, 1, 0, 0, 0], 222 | [0, 0, 1, 0, 0], 223 | [0, 0, 0, 1, 0], 224 | [0, 0, 0, 0, 1], 225 | [0, 0, 0, 0, 0]]), 226 | [-1, 0, 0, 1, -2]) 227 | self.assertNotEqual(hash1, hash2) 228 | 229 | # Non-isomorphic tricky case (breaks if you don't use directed edges) 230 | hash1 = graph_util.hash_module( 231 | np.array([[0, 1, 0, 1], 232 | [0, 0, 1, 0], 233 | [0, 0, 0, 1], 234 | [0, 0, 0, 0]]), 235 | [-1, 1, 0, -2]) 236 | 237 | hash2 = graph_util.hash_module( 238 | np.array([[0, 1, 0, 1], 239 | [0, 0, 1, 0], 240 | [0, 0, 0, 1], 241 | [0, 0, 0, 0]]), 242 | [-1, 0, 1, -2]) 243 | self.assertNotEqual(hash1, hash2) 244 | 245 | # Non-isomorphic tricky case (breaks if you only use out-neighbors and self) 246 | hash1 = graph_util.hash_module(np.array([[0, 1, 1, 1, 1, 0, 0], 247 | [0, 0, 1, 0, 0, 0, 0], 248 | [0, 0, 0, 0, 0, 0, 1], 249 | [0, 0, 0, 0, 0, 1, 0], 250 | [0, 0, 0, 0, 0, 1, 0], 251 | [0, 0, 0, 0, 0, 0, 1], 252 | [0, 0, 0, 0, 0, 0, 0]]), 253 | [-1, 1, 0, 0, 0, 0, -2]) 254 | hash2 = graph_util.hash_module(np.array([[0, 1, 1, 1, 1, 0, 0], 255 | [0, 0, 1, 0, 0, 0, 0], 256 | [0, 0, 0, 0, 0, 0, 1], 257 | [0, 0, 0, 0, 0, 1, 0], 258 | [0, 0, 0, 
0, 0, 1, 0], 259 | [0, 0, 0, 0, 0, 0, 1], 260 | [0, 0, 0, 0, 0, 0, 0]]), 261 | [-1, 0, 0, 0, 1, 0, -2]) 262 | self.assertNotEqual(hash1, hash2) 263 | 264 | def test_permute_graph(self): 265 | # Does not have to be DAG 266 | matrix = np.array([[1, 1, 0], 267 | [0, 0, 1], 268 | [1, 0, 1]]) 269 | labels = ['a', 'b', 'c'] 270 | 271 | p1, l1 = graph_util.permute_graph(matrix, labels, [2, 0, 1]) 272 | self.assertTrue(np.array_equal(p1, 273 | np.array([[0, 1, 0], 274 | [0, 1, 1], 275 | [1, 0, 1]]))) 276 | self.assertEqual(l1, ['b', 'c', 'a']) 277 | 278 | p1, l1 = graph_util.permute_graph(matrix, labels, [0, 2, 1]) 279 | self.assertTrue(np.array_equal(p1, 280 | np.array([[1, 0, 1], 281 | [1, 1, 0], 282 | [0, 1, 0]]))) 283 | self.assertEqual(l1, ['a', 'c', 'b']) 284 | 285 | def test_is_isomorphic(self): 286 | # Reuse some tests from hash_module 287 | matrix1 = np.array( 288 | [[0, 1, 1, 0,], 289 | [0, 0, 0, 1], 290 | [0, 0, 0, 1], 291 | [0, 0, 0, 0]]) 292 | label1 = [-1, 1, 2, -2] 293 | label2 = [-1, 2, 1, -2] 294 | 295 | self.assertTrue(graph_util.is_isomorphic((matrix1, label1), 296 | (matrix1, label2))) 297 | 298 | # Simple graph with edge permutation 299 | matrix1 = np.array( 300 | [[0, 1, 1, 0, 0], 301 | [0, 0, 0, 0, 1], 302 | [0, 0, 0, 1, 0], 303 | [0, 0, 0, 0, 1], 304 | [0, 0, 0, 0, 0]]) 305 | label1 = [-1, 1, 2, 3, -2] 306 | 307 | matrix2 = np.array( 308 | [[0, 1, 0, 1, 0], 309 | [0, 0, 1, 0, 0], 310 | [0, 0, 0, 0, 1], 311 | [0, 0, 0, 0, 1], 312 | [0, 0, 0, 0, 0]]) 313 | label2 = [-1, 2, 3, 1, -2] 314 | 315 | matrix3 = np.array( 316 | [[0, 1, 1, 0, 0], 317 | [0, 0, 0, 1, 0], 318 | [0, 0, 0, 0, 1], 319 | [0, 0, 0, 0, 1], 320 | [0, 0, 0, 0, 0]]) 321 | label3 = [-1, 2, 1, 3, -2] 322 | 323 | self.assertTrue(graph_util.is_isomorphic((matrix1, label1), 324 | (matrix2, label2))) 325 | self.assertTrue(graph_util.is_isomorphic((matrix1, label1), 326 | (matrix3, label3))) 327 | self.assertFalse(graph_util.is_isomorphic((matrix1, label1), 328 | (matrix2, label1))) 329 | 330 | 
# Connected non-isomorphic regular graphs on 6 interior vertices (8 total) 331 | matrix1 = np.array( 332 | [[0, 1, 0, 0, 0, 0, 0, 0], 333 | [0, 0, 1, 1, 0, 0, 1, 0], 334 | [0, 0, 0, 0, 1, 1, 0, 0], 335 | [0, 0, 0, 0, 1, 1, 0, 0], 336 | [0, 0, 0, 0, 0, 0, 1, 0], 337 | [0, 0, 0, 0, 0, 0, 1, 0], 338 | [0, 0, 0, 0, 0, 0, 0, 1], 339 | [0, 0, 0, 0, 0, 0, 0, 0]]) 340 | matrix2 = np.array( 341 | [[0, 1, 0, 0, 0, 0, 0, 0], 342 | [0, 0, 1, 1, 0, 1, 0, 0], 343 | [0, 0, 0, 0, 1, 0, 1, 0], 344 | [0, 0, 0, 0, 1, 1, 0, 0], 345 | [0, 0, 0, 0, 0, 0, 1, 0], 346 | [0, 0, 0, 0, 0, 0, 1, 0], 347 | [0, 0, 0, 0, 0, 0, 0, 1], 348 | [0, 0, 0, 0, 0, 0, 0, 0]]) 349 | label1 = [-1, 1, 1, 1, 1, 1, 1, -2] 350 | 351 | self.assertFalse(graph_util.is_isomorphic((matrix1, label1), 352 | (matrix2, label1))) 353 | 354 | # Connected isomorphic regular graphs on 8 total vertices (bipartite) 355 | matrix1 = np.array( 356 | [[0, 0, 0, 0, 1, 1, 1, 0], 357 | [0, 0, 0, 0, 1, 1, 0, 1], 358 | [0, 0, 0, 0, 1, 0, 1, 1], 359 | [0, 0, 0, 0, 0, 1, 1, 1], 360 | [1, 1, 1, 0, 0, 0, 0, 0], 361 | [1, 1, 0, 1, 0, 0, 0, 0], 362 | [1, 0, 1, 1, 0, 0, 0, 0], 363 | [0, 1, 1, 1, 0, 0, 0, 0]]) 364 | matrix2 = np.array( 365 | [[0, 1, 0, 1, 1, 0, 0, 0], 366 | [1, 0, 1, 0, 0, 1, 0, 0], 367 | [0, 1, 0, 1, 0, 0, 1, 0], 368 | [1, 0, 1, 0, 0, 0, 0, 1], 369 | [1, 0, 0, 0, 0, 1, 0, 1], 370 | [0, 1, 0, 0, 1, 0, 1, 0], 371 | [0, 0, 1, 0, 0, 1, 0, 1], 372 | [0, 0, 0, 1, 1, 0, 1, 0]]) 373 | label1 = [1, 1, 1, 1, 1, 1, 1, 1] 374 | 375 | # Sanity check: manual permutation 376 | perm = [0, 5, 7, 2, 4, 1, 3, 6] 377 | pm1, pl1 = graph_util.permute_graph(matrix1, label1, perm) 378 | self.assertTrue(np.array_equal(matrix2, pm1)) 379 | self.assertEqual(pl1, label1) 380 | 381 | self.assertTrue(graph_util.is_isomorphic((matrix1, label1), 382 | (matrix2, label1))) 383 | 384 | label2 = [1, 1, 1, 1, 2, 2, 2, 2] 385 | label3 = [1, 2, 1, 2, 2, 1, 2, 1] 386 | 387 | self.assertTrue(graph_util.is_isomorphic((matrix1, label2), 388 | (matrix2, label3))) 389 | 
390 | def test_random_isomorphism_hashing(self): 391 | # Tests that hash_module always provides the same hash for randomly 392 | # generated isomorphic graphs. 393 | for _ in range(1000): 394 | # Generate random graph. Note: the algorithm works (i.e. same hash == 395 | # isomorphic graphs) for all directed graphs with coloring and does not 396 | # require the graph to be a DAG. 397 | size = random.randint(3, 20) 398 | matrix = np.random.randint(0, 2, [size, size]) 399 | labels = [random.randint(0, 10) for _ in range(size)] 400 | 401 | # Generate permutation of matrix and labels. 402 | perm = np.random.permutation(size).tolist() 403 | pmatrix, plabels = graph_util.permute_graph(matrix, labels, perm) 404 | 405 | # Hashes should be identical. 406 | hash1 = graph_util.hash_module(matrix, labels) 407 | hash2 = graph_util.hash_module(pmatrix, plabels) 408 | self.assertEqual(hash1, hash2) 409 | 410 | def test_counterexample_bipartite(self): 411 | # This is a counter example that shows that the hashing algorithm is not 412 | # perfectly identifiable (i.e. there are non-isomorphic graphs with the same 413 | # hash). If this tests fails, it means the algorithm must have been changed 414 | # in some way that allows it to identify these graphs as non-isomoprhic. 
415 | matrix1 = np.array( 416 | [[0, 1, 1, 1, 1, 0, 0, 0, 0, 0], 417 | [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], 418 | [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], 419 | [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], 420 | [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], 421 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 422 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 423 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 424 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 425 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) 426 | 427 | matrix2 = np.array( 428 | [[0, 1, 1, 1, 1, 0, 0, 0, 0, 0], 429 | [0, 0, 0, 0, 0, 1, 1, 0, 0, 0], 430 | [0, 0, 0, 0, 0, 0, 1, 1, 0, 0], 431 | [0, 0, 0, 0, 0, 0, 0, 1, 1, 0], 432 | [0, 0, 0, 0, 0, 1, 0, 0, 1, 0], 433 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 434 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 435 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 436 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 437 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]) 438 | 439 | labels = [-1, 1, 1, 1, 1, 2, 2, 2, 2, -2] 440 | 441 | # This takes far too long to run so commenting it out. The graphs are 442 | # non-isomorphic fairly obviously from visual inspection. 443 | # self.assertFalse(graph_util.is_isomorphic((matrix1, labels), 444 | # (matrix2, labels))) 445 | self.assertEqual(graph_util.hash_module(matrix1, labels), 446 | graph_util.hash_module(matrix2, labels)) 447 | 448 | 449 | if __name__ == '__main__': 450 | tf.test.main() 451 | -------------------------------------------------------------------------------- /nasbench/lib/model_builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Builds the TensorFlow computational graph. 16 | 17 | Tensors flowing into a single vertex are added together for all vertices 18 | except the output, which is concatenated instead. Tensors flowing out of input 19 | are always added. 20 | 21 | If interior edge channels don't match, drop the extra channels (channels are 22 | guaranteed non-decreasing). Tensors flowing out of the input as always 23 | projected instead. 24 | """ 25 | 26 | from __future__ import absolute_import 27 | from __future__ import division 28 | from __future__ import print_function 29 | 30 | from nasbench.lib import base_ops 31 | from nasbench.lib import training_time 32 | import numpy as np 33 | import tensorflow as tf 34 | 35 | 36 | def build_model_fn(spec, config, num_train_images): 37 | """Returns a model function for Estimator.""" 38 | if config['data_format'] == 'channels_last': 39 | channel_axis = 3 40 | elif config['data_format'] == 'channels_first': 41 | # Currently this is not well supported 42 | channel_axis = 1 43 | else: 44 | raise ValueError('invalid data_format') 45 | 46 | def model_fn(features, labels, mode, params): 47 | """Builds the model from the input features.""" 48 | del params # Unused 49 | is_training = (mode == tf.estimator.ModeKeys.TRAIN) 50 | 51 | # Store auxiliary activations increasing in depth of network. First 52 | # activation occurs immediately after the stem and the others immediately 53 | # follow each stack. 
54 | aux_activations = [] 55 | 56 | # Initial stem convolution 57 | with tf.variable_scope('stem'): 58 | net = base_ops.conv_bn_relu( 59 | features, 3, config['stem_filter_size'], 60 | is_training, config['data_format']) 61 | aux_activations.append(net) 62 | 63 | for stack_num in range(config['num_stacks']): 64 | channels = net.get_shape()[channel_axis].value 65 | 66 | # Downsample at start (except first) 67 | if stack_num > 0: 68 | net = tf.layers.max_pooling2d( 69 | inputs=net, 70 | pool_size=(2, 2), 71 | strides=(2, 2), 72 | padding='same', 73 | data_format=config['data_format']) 74 | 75 | # Double output channels each time we downsample 76 | channels *= 2 77 | 78 | with tf.variable_scope('stack{}'.format(stack_num)): 79 | for module_num in range(config['num_modules_per_stack']): 80 | with tf.variable_scope('module{}'.format(module_num)): 81 | net = build_module( 82 | spec, 83 | inputs=net, 84 | channels=channels, 85 | is_training=is_training) 86 | aux_activations.append(net) 87 | 88 | # Global average pool 89 | if config['data_format'] == 'channels_last': 90 | net = tf.reduce_mean(net, [1, 2]) 91 | elif config['data_format'] == 'channels_first': 92 | net = tf.reduce_mean(net, [2, 3]) 93 | else: 94 | raise ValueError('invalid data_format') 95 | 96 | # Fully-connected layer to labels 97 | logits = tf.layers.dense( 98 | inputs=net, 99 | units=config['num_labels']) 100 | 101 | if mode == tf.estimator.ModeKeys.PREDICT and not config['use_tpu']: 102 | # It is a known limitation of Estimator that the labels 103 | # are not passed during PREDICT mode when running on CPU/GPU 104 | # (https://github.com/tensorflow/tensorflow/issues/17824), thus we cannot 105 | # compute the loss or anything dependent on it (i.e., the gradients). 
106 | loss = tf.constant(0.0) 107 | else: 108 | loss = tf.losses.softmax_cross_entropy( 109 | onehot_labels=tf.one_hot(labels, config['num_labels']), 110 | logits=logits) 111 | 112 | loss += config['weight_decay'] * tf.add_n( 113 | [tf.nn.l2_loss(v) for v in tf.trainable_variables()]) 114 | 115 | # Use inference mode to compute some useful metrics on a fixed sample 116 | # Due to the batch being sharded on TPU, these metrics should be run on CPU 117 | # only to ensure that the metrics are computed on the whole batch. We add a 118 | # leading dimension because PREDICT expects batch-shaped tensors. 119 | if mode == tf.estimator.ModeKeys.PREDICT: 120 | parameter_norms = { 121 | 'param:' + tensor.name: 122 | tf.expand_dims(tf.norm(tensor, ord=2), 0) 123 | for tensor in tf.trainable_variables() 124 | } 125 | 126 | # Compute gradients of all parameters and the input simultaneously 127 | all_params_names = [] 128 | all_params_tensors = [] 129 | for tensor in tf.trainable_variables(): 130 | all_params_names.append('param_grad_norm:' + tensor.name) 131 | all_params_tensors.append(tensor) 132 | all_params_names.append('input_grad_norm') 133 | all_params_tensors.append(features) 134 | 135 | grads = tf.gradients(loss, all_params_tensors) 136 | 137 | param_gradient_norms = {} 138 | for name, grad in zip(all_params_names, grads)[:-1]: 139 | if grad is not None: 140 | param_gradient_norms[name] = ( 141 | tf.expand_dims(tf.norm(grad, ord=2), 0)) 142 | else: 143 | param_gradient_norms[name] = ( 144 | tf.expand_dims(tf.constant(0.0), 0)) 145 | 146 | if grads[-1] is not None: 147 | input_grad_norm = tf.sqrt(tf.reduce_sum( 148 | tf.square(grads[-1]), axis=[1, 2, 3])) 149 | else: 150 | input_grad_norm = tf.expand_dims(tf.constant(0.0), 0) 151 | 152 | covariance_matrices = { 153 | 'cov_matrix_%d' % i: 154 | tf.expand_dims(_covariance_matrix(aux), 0) 155 | for i, aux in enumerate(aux_activations) 156 | } 157 | 158 | predictions = { 159 | 'logits': logits, 160 | 'loss': 
tf.expand_dims(loss, 0), 161 | 'input_grad_norm': input_grad_norm, 162 | } 163 | predictions.update(parameter_norms) 164 | predictions.update(param_gradient_norms) 165 | predictions.update(covariance_matrices) 166 | 167 | return tf.contrib.tpu.TPUEstimatorSpec(mode=mode, predictions=predictions) 168 | 169 | if mode == tf.estimator.ModeKeys.TRAIN: 170 | global_step = tf.train.get_or_create_global_step() 171 | base_lr = config['learning_rate'] 172 | if config['use_tpu']: 173 | base_lr *= config['tpu_num_shards'] 174 | 175 | if config['lr_decay_method'] == 'COSINE_BY_STEP': 176 | total_steps = int(config['train_epochs'] * num_train_images / 177 | config['batch_size']) 178 | progress_fraction = tf.cast(global_step, tf.float32) / total_steps 179 | learning_rate = (0.5 * base_lr * 180 | (1 + tf.cos(np.pi * progress_fraction))) 181 | 182 | elif config['lr_decay_method'] == 'COSINE_BY_TIME': 183 | # Requires training_time.limit hooks to be added to Estimator 184 | elapsed_time = tf.cast(training_time.get_total_time(), dtype=tf.float32) 185 | progress_fraction = elapsed_time / config['train_seconds'] 186 | learning_rate = (0.5 * base_lr * 187 | (1 + tf.cos(np.pi * progress_fraction))) 188 | 189 | elif config['lr_decay_method'] == 'STEPWISE': 190 | # divide LR by 10 at 1/2, 2/3, and 5/6 of total epochs 191 | total_steps = (config['train_epochs'] * num_train_images / 192 | config['batch_size']) 193 | boundaries = [int(0.5 * total_steps), 194 | int(0.667 * total_steps), 195 | int(0.833 * total_steps)] 196 | values = [1.0 * base_lr, 197 | 0.1 * base_lr, 198 | 0.01 * base_lr, 199 | 0.0001 * base_lr] 200 | learning_rate = tf.train.piecewise_constant( 201 | global_step, boundaries, values) 202 | 203 | else: 204 | raise ValueError('invalid lr_decay_method') 205 | 206 | # Set LR to 0 for step 0 to initialize the weights without training 207 | learning_rate = tf.where(tf.equal(global_step, 0), 0.0, learning_rate) 208 | 209 | optimizer = tf.train.RMSPropOptimizer( 210 | 
learning_rate=learning_rate, 211 | momentum=config['momentum'], 212 | epsilon=1.0) 213 | if config['use_tpu']: 214 | optimizer = tf.contrib.tpu.CrossShardOptimizer(optimizer) 215 | 216 | # Update ops required for batch norm moving variables 217 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 218 | with tf.control_dependencies(update_ops): 219 | train_op = optimizer.minimize(loss, global_step) 220 | 221 | return tf.contrib.tpu.TPUEstimatorSpec( 222 | mode=mode, 223 | loss=loss, 224 | train_op=train_op) 225 | 226 | elif mode == tf.estimator.ModeKeys.EVAL: 227 | def metric_fn(labels, logits): 228 | predictions = tf.argmax(logits, axis=1) 229 | accuracy = tf.metrics.accuracy(labels, predictions) 230 | 231 | return {'accuracy': accuracy} 232 | 233 | eval_metrics = (metric_fn, [labels, logits]) 234 | 235 | return tf.contrib.tpu.TPUEstimatorSpec( 236 | mode=mode, 237 | loss=loss, 238 | eval_metrics=eval_metrics) 239 | 240 | return model_fn 241 | 242 | 243 | def build_module(spec, inputs, channels, is_training): 244 | """Build a custom module using a proposed model spec. 245 | 246 | Builds the model using the adjacency matrix and op labels specified. Channels 247 | controls the module output channel count but the interior channels are 248 | determined via equally splitting the channel count whenever there is a 249 | concatenation of Tensors. 250 | 251 | Args: 252 | spec: ModelSpec object. 253 | inputs: input Tensors to this module. 254 | channels: output channel count. 255 | is_training: bool for whether this model is training. 256 | 257 | Returns: 258 | output Tensor from built module. 
259 | 260 | Raises: 261 | ValueError: invalid spec 262 | """ 263 | num_vertices = np.shape(spec.matrix)[0] 264 | 265 | if spec.data_format == 'channels_last': 266 | channel_axis = 3 267 | elif spec.data_format == 'channels_first': 268 | channel_axis = 1 269 | else: 270 | raise ValueError('invalid data_format') 271 | 272 | input_channels = inputs.get_shape()[channel_axis].value 273 | # vertex_channels[i] = number of output channels of vertex i 274 | vertex_channels = compute_vertex_channels( 275 | input_channels, channels, spec.matrix) 276 | 277 | # Construct tensors from input forward 278 | tensors = [tf.identity(inputs, name='input')] 279 | 280 | final_concat_in = [] 281 | for t in range(1, num_vertices - 1): 282 | with tf.variable_scope('vertex_{}'.format(t)): 283 | # Create interior connections, truncating if necessary 284 | add_in = [truncate(tensors[src], vertex_channels[t], spec.data_format) 285 | for src in range(1, t) if spec.matrix[src, t]] 286 | 287 | # Create add connection from projected input 288 | if spec.matrix[0, t]: 289 | add_in.append(projection( 290 | tensors[0], 291 | vertex_channels[t], 292 | is_training, 293 | spec.data_format)) 294 | 295 | if len(add_in) == 1: 296 | vertex_input = add_in[0] 297 | else: 298 | vertex_input = tf.add_n(add_in) 299 | 300 | # Perform op at vertex t 301 | op = base_ops.OP_MAP[spec.ops[t]]( 302 | is_training=is_training, 303 | data_format=spec.data_format) 304 | vertex_value = op.build(vertex_input, vertex_channels[t]) 305 | 306 | tensors.append(vertex_value) 307 | if spec.matrix[t, num_vertices - 1]: 308 | final_concat_in.append(tensors[t]) 309 | 310 | # Construct final output tensor by concating all fan-in and adding input. 
311 | if not final_concat_in: 312 | # No interior vertices, input directly connected to output 313 | assert spec.matrix[0, num_vertices - 1] 314 | with tf.variable_scope('output'): 315 | outputs = projection( 316 | tensors[0], 317 | channels, 318 | is_training, 319 | spec.data_format) 320 | 321 | else: 322 | if len(final_concat_in) == 1: 323 | outputs = final_concat_in[0] 324 | else: 325 | outputs = tf.concat(final_concat_in, channel_axis) 326 | 327 | if spec.matrix[0, num_vertices - 1]: 328 | outputs += projection( 329 | tensors[0], 330 | channels, 331 | is_training, 332 | spec.data_format) 333 | 334 | outputs = tf.identity(outputs, name='output') 335 | return outputs 336 | 337 | 338 | def projection(inputs, channels, is_training, data_format): 339 | """1x1 projection (as in ResNet) followed by batch normalization and ReLU.""" 340 | with tf.variable_scope('projection'): 341 | net = base_ops.conv_bn_relu(inputs, 1, channels, is_training, data_format) 342 | 343 | return net 344 | 345 | 346 | def truncate(inputs, channels, data_format): 347 | """Slice the inputs to channels if necessary.""" 348 | if data_format == 'channels_last': 349 | input_channels = inputs.get_shape()[3].value 350 | else: 351 | assert data_format == 'channels_first' 352 | input_channels = inputs.get_shape()[1].value 353 | 354 | if input_channels < channels: 355 | raise ValueError('input channel < output channels for truncate') 356 | elif input_channels == channels: 357 | return inputs # No truncation necessary 358 | else: 359 | # Truncation should only be necessary when channel division leads to 360 | # vertices with +1 channels. The input vertex should always be projected to 361 | # the minimum channel count. 
362 | assert input_channels - channels == 1 363 | if data_format == 'channels_last': 364 | return tf.slice(inputs, [0, 0, 0, 0], [-1, -1, -1, channels]) 365 | else: 366 | return tf.slice(inputs, [0, 0, 0, 0], [-1, channels, -1, -1]) 367 | 368 | 369 | def compute_vertex_channels(input_channels, output_channels, matrix): 370 | """Computes the number of channels at every vertex. 371 | 372 | Given the input channels and output channels, this calculates the number of 373 | channels at each interior vertex. Interior vertices have the same number of 374 | channels as the max of the channels of the vertices it feeds into. The output 375 | channels are divided amongst the vertices that are directly connected to it. 376 | When the division is not even, some vertices may receive an extra channel to 377 | compensate. 378 | 379 | Args: 380 | input_channels: input channel count. 381 | output_channels: output channel count. 382 | matrix: adjacency matrix for the module (pruned by model_spec). 383 | 384 | Returns: 385 | list of channel counts, in order of the vertices. 386 | """ 387 | num_vertices = np.shape(matrix)[0] 388 | 389 | vertex_channels = [0] * num_vertices 390 | vertex_channels[0] = input_channels 391 | vertex_channels[num_vertices - 1] = output_channels 392 | 393 | if num_vertices == 2: 394 | # Edge case where module only has input and output vertices 395 | return vertex_channels 396 | 397 | # Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is 398 | # the dst vertex. Summing over 0 gives the in-degree count of each vertex. 
399 | in_degree = np.sum(matrix[1:], axis=0) 400 | interior_channels = output_channels // in_degree[num_vertices - 1] 401 | correction = output_channels % in_degree[num_vertices - 1] # Remainder to add 402 | 403 | # Set channels of vertices that flow directly to output 404 | for v in range(1, num_vertices - 1): 405 | if matrix[v, num_vertices - 1]: 406 | vertex_channels[v] = interior_channels 407 | if correction: 408 | vertex_channels[v] += 1 409 | correction -= 1 410 | 411 | # Set channels for all other vertices to the max of the out edges, going 412 | # backwards. (num_vertices - 2) index skipped because it only connects to 413 | # output. 414 | for v in range(num_vertices - 3, 0, -1): 415 | if not matrix[v, num_vertices - 1]: 416 | for dst in range(v + 1, num_vertices - 1): 417 | if matrix[v, dst]: 418 | vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst]) 419 | assert vertex_channels[v] > 0 420 | 421 | tf.logging.info('vertex_channels: %s', str(vertex_channels)) 422 | 423 | # Sanity check, verify that channels never increase and final channels add up. 424 | final_fan_in = 0 425 | for v in range(1, num_vertices - 1): 426 | if matrix[v, num_vertices - 1]: 427 | final_fan_in += vertex_channels[v] 428 | for dst in range(v + 1, num_vertices - 1): 429 | if matrix[v, dst]: 430 | assert vertex_channels[v] >= vertex_channels[dst] 431 | assert final_fan_in == output_channels or num_vertices == 2 432 | # num_vertices == 2 means only input/output nodes, so 0 fan-in 433 | 434 | return vertex_channels 435 | 436 | 437 | def _covariance_matrix(activations): 438 | """Computes the unbiased covariance matrix of the samples within the batch. 439 | 440 | Computes the sample covariance between the samples in the batch. Specifically, 441 | 442 | C(i,j) = (x_i - mean(x_i)) dot (x_j - mean(x_j)) / (N - 1) 443 | 444 | Matches the default behavior of np.cov(). 445 | 446 | Args: 447 | activations: tensor activations with batch dimension first. 
448 | 449 | Returns: 450 | [batch, batch] shape tensor for the covariance matrix. 451 | """ 452 | batch_size = activations.get_shape()[0].value 453 | flattened = tf.reshape(activations, [batch_size, -1]) 454 | means = tf.reduce_mean(flattened, axis=1, keepdims=True) 455 | 456 | centered = flattened - means 457 | squared = tf.matmul(centered, tf.transpose(centered)) 458 | cov = squared / (tf.cast(tf.shape(flattened)[1], tf.float32) - 1) 459 | 460 | return cov 461 | 462 | --------------------------------------------------------------------------------