├── .gitignore ├── Contributors.md ├── LICENSE.txt ├── README.md ├── assets ├── architecture.png └── param_time_acc.png ├── main.py ├── nasbench_pytorch ├── __init__.py ├── datasets │ ├── __init__.py │ └── cifar10.py ├── model │ ├── __init__.py │ ├── base_ops.py │ ├── graph_util.py │ ├── model.py │ └── model_spec.py └── trainer.py ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm 3 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm 4 | 5 | ### PyCharm ### 6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 8 | data/* 9 | .idea/* 10 | checkpoint/* 11 | dist/* 12 | build/* 13 | 14 | *.ipynb 15 | 16 | **/__pycache__/** 17 | 18 | *.egg-info/* 19 | 20 | # User-specific stuff 21 | .idea/**/workspace.xml 22 | .idea/**/tasks.xml 23 | .idea/**/usage.statistics.xml 24 | .idea/**/dictionaries 25 | .idea/**/shelf 26 | 27 | # Generated files 28 | .idea/**/contentModel.xml 29 | 30 | # Sensitive or high-churn files 31 | .idea/**/dataSources/ 32 | .idea/**/dataSources.ids 33 | .idea/**/dataSources.local.xml 34 | .idea/**/sqlDataSources.xml 35 | .idea/**/dynamic.xml 36 | .idea/**/uiDesigner.xml 37 | .idea/**/dbnavigator.xml 38 | 39 | # Gradle 40 | .idea/**/gradle.xml 41 | .idea/**/libraries 42 | 43 | # Gradle and Maven with auto-import 44 | # When using Gradle or Maven with auto-import, you should exclude module files, 45 | # since they will be recreated, and may cause churn. Uncomment if using 46 | # auto-import. 47 | # .idea/artifacts 48 | # .idea/compiler.xml 49 | # .idea/jarRepositories.xml 50 | # .idea/modules.xml 51 | # .idea/*.iml 52 | # .idea/modules 53 | # *.iml 54 | # *.ipr 55 | 56 | # CMake 57 | cmake-build-*/ 58 | 59 | # Mongo Explorer plugin 60 | .idea/**/mongoSettings.xml 61 | 62 | # File-based project format 63 | *.iws 64 | 65 | # IntelliJ 66 | out/ 67 | 68 | # mpeltonen/sbt-idea plugin 69 | .idea_modules/ 70 | 71 | # JIRA plugin 72 | atlassian-ide-plugin.xml 73 | 74 | # Cursive Clojure plugin 75 | .idea/replstate.xml 76 | 77 | # Crashlytics plugin (for Android Studio and IntelliJ) 78 | com_crashlytics_export_strings.xml 79 | crashlytics.properties 80 | crashlytics-build.properties 81 | fabric.properties 82 | 83 | # Editor-based Rest Client 84 | .idea/httpRequests 85 | 86 | # Android studio 3.1+ serialized cache file 87 | .idea/caches/build_file_checksums.ser 88 | 89 | ### PyCharm Patch ### 90 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 91 | 92 | # *.iml 93 | # modules.xml 94 | # .idea/misc.xml 95 | # *.ipr 96 | 97 | # Sonarlint plugin 98 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 99 | .idea/**/sonarlint/ 100 | 101 | # SonarQube Plugin 102 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 103 | .idea/**/sonarIssues.xml 104 | 105 | # Markdown Navigator plugin 106 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 107 | .idea/**/markdown-navigator.xml 108 | .idea/**/markdown-navigator-enh.xml 109 | .idea/**/markdown-navigator/ 110 | 111 | # Cache file creation bug 112 | # See https://youtrack.jetbrains.com/issue/JBR-2257 113 | .idea/$CACHE_FILE$ 114 | 115 | # CodeStream plugin 116 | # https://plugins.jetbrains.com/plugin/12206-codestream 117 | .idea/codestream.xml 118 | 119 | # End of https://www.toptal.com/developers/gitignore/api/pycharm 120 | -------------------------------------------------------------------------------- /Contributors.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | - [@romulus0914](https://github.com/romulus0914) (Romulus Hong) 3 | - Author of the code - NAS-Bench-101 implementation in PyTorch 4 | - [@gabikadlecova](https://github.com/gabikadlecova) 5 | - Maintainer of the repository 6 | - Package structure, reproducibility 7 | --------- 8 | - [@abhash-er](https://github.com/abhash-er/) (Abhash Jha) 9 | - Modified the model code so that cast to double is possible 10 | - [@longerHost](https://github.com/longerHost) 11 | - Reproducibility of the original NAS-Bench-101 12 | - Comparison of training results and API results 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NASBench-PyTorch 2 | NASBench-PyTorch is a PyTorch implementation of the search space 3 | [NAS-Bench-101](https://github.com/google-research/nasbench) including the training of the networks[**](#note). The original 4 | implementation is written in TensorFlow, and this projects contains 5 | some files from the original repository (in the directory 6 | `nasbench_pytorch/model/`). 7 | 8 | **Important:** if you want to reproduce the original results, please refer to the 9 | [Reproducibility](#repro) section. 10 | 11 | # Overview 12 | A PyTorch implementation of *training* of NAS-Bench-101 dataset: [NAS-Bench-101: Towards Reproducible Neural Architecture Search](https://arxiv.org/abs/1902.09635). 13 | The dataset contains 423,624 unique neural networks exhaustively generated and evaluated from a fixed graph-based search space. 14 | 15 | # Usage 16 | You need to have PyTorch installed. 17 | 18 | You can install the package by running `pip install nasbench_pytorch`. The second possibility is to install from source code: 19 | 20 | 1. Clone this repo 21 | ``` 22 | git clone https://github.com/romulus0914/NASBench-PyTorch 23 | cd NASBench-PyTorch 24 | ``` 25 | 26 | 2. Install the project 27 | ``` 28 | pip install -e . 29 | ``` 30 | 31 | The file `main.py` contains an example training of a network. To see 32 | the different parameters, run: 33 | 34 | ``` 35 | python main.py --help 36 | ``` 37 | 38 | ### Train a network by hash 39 | To train a network whose architecture is queried from NAS-Bench-101 40 | using its unique hash, install the original [nasbench](https://github.com/google-research/nasbench) 41 | repository. Follow the instructions in the README, note that you 42 | need to install TensorFlow. If you need TensorFlow 2.x, install 43 | [this fork](https://github.com/gabrielasuchopar/nasbench) of the 44 | repository instead. 45 | 46 | Then, you can get the PyTorch architecture of a network like this: 47 | 48 | ```python 49 | from nasbench_pytorch.model import Network as NBNetwork 50 | from nasbench import api 51 | 52 | 53 | nasbench_path = '$path_to_downloaded_nasbench' 54 | nb = api.NASBench(nasbench_path) 55 | 56 | net_hash = '$some_hash' # you can get hashes using nasbench.hash_iterator() 57 | m = nb.get_metrics_from_hash(net_hash) 58 | ops = m[0]['module_operations'] 59 | adjacency = m[0]['module_adjacency'] 60 | 61 | net = NBNetwork((adjacency, ops)) 62 | ``` 63 | 64 | Then, you can train it just like the example network in `main.py`. 65 | 66 | # Architecture 67 | Example architecture (picture from the original repository) 68 | ![archtecture](./assets/architecture.png) 69 | 70 | # Reproducibility 71 | The code should closely match the TensorFlow version (including the hyperparameters), but there are some differences: 72 | - RMSProp implementation in TensorFlow and PyTorch is **different** 73 | - For more information refer to [here](https://github.com/pytorch/pytorch/issues/32545) and [here](https://github.com/pytorch/pytorch/issues/23796). 74 | - Optionally, you can install pytorch-image-models where a [TensorFlow-like RMSProp](https://github.com/rwightman/pytorch-image-models/blob/main/timm/optim/rmsprop_tf.py#L5) is implemented 75 | - `pip install timm` 76 | - Then, pass `--optimizer rmsprop_tf` to `main.py` to use it 77 | 78 | 79 | - You can turn gradient clipping off by setting `--grad_clip_off True` 80 | 81 | 82 | - The original training was on TPUs, this code enables only GPU and CPU training 83 | - Input data augmentation methods are the same, but due to randomness they are not applied in the same manner 84 | - Cause: Batches and images cannot be shuffled as in the original TPU training, and the augmentation seed is also different 85 | - Results may still differ due to TensorFlow/PyTorch implementation differences 86 | 87 | Refer to this [issue](https://github.com/romulus0914/NASBench-PyTorch/issues/6) for more information and for comparison with API results. 88 | 89 | # Disclaimer 90 | Modified from [NASBench: A Neural Architecture Search Dataset and Benchmark](https://github.com/google-research/nasbench). 91 | *graph_util.py* and *model_spec.py* are directly copied from the original repo. Original license can be found [here](https://github.com/google-research/nasbench/blob/master/LICENSE). 92 | 93 | 94 | **Please note that this repo is only used to train one possible architecture in the search space, not to generate all possible graphs and train them. 95 | -------------------------------------------------------------------------------- /assets/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/assets/architecture.png -------------------------------------------------------------------------------- /assets/param_time_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/assets/param_time_acc.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from nasbench_pytorch.datasets.cifar10 import prepare_dataset 9 | from nasbench_pytorch.model import Network 10 | from nasbench_pytorch.model import ModelSpec 11 | from nasbench_pytorch.trainer import train, test 12 | 13 | matrix = [[0, 1, 1, 1, 0, 1, 0], 14 | [0, 0, 0, 0, 0, 0, 1], 15 | [0, 0, 0, 0, 0, 0, 1], 16 | [0, 0, 0, 0, 1, 0, 0], 17 | [0, 0, 0, 0, 0, 0, 1], 18 | [0, 0, 0, 0, 0, 0, 1], 19 | [0, 0, 0, 0, 0, 0, 0]] 20 | 21 | operations = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 22 | 'maxpool3x3', 'output'] 23 | 24 | 25 | def save_checkpoint(net, postfix='cifar10'): 26 | print('--- Saving Checkpoint ---') 27 | 28 | if not os.path.isdir('checkpoint'): 29 | os.mkdir('checkpoint') 30 | 31 | torch.save(net.state_dict(), './checkpoint/ckpt_' + postfix + '.pt') 32 | 33 | def reload_checkpoint(path, device=None): 34 | print('--- Reloading Checkpoint ---') 35 | 36 | assert os.path.isdir('checkpoint'), '[Error] No checkpoint directory found!' 37 | return torch.load(path, map_location=device) 38 | 39 | 40 | if __name__ == '__main__': 41 | parser = argparse.ArgumentParser(description='NASBench') 42 | parser.add_argument('--random_state', default=1, type=int, help='Random seed.') 43 | parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.') 44 | parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.') 45 | parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution') 46 | parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules') 47 | parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack') 48 | parser.add_argument('--batch_size', default=256, type=int, help='batch size') 49 | parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size') 50 | parser.add_argument('--epochs', default=108, type=int, help='#epochs of training') 51 | parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.") 52 | parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.") 53 | parser.add_argument('--learning_rate', default=0.2, type=float, help='base learning rate') 54 | parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)') 55 | parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.') 56 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum') 57 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight') 58 | parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping') 59 | parser.add_argument('--grad_clip_off', default=False, type=bool, help='If True, turn off gradient clipping.') 60 | parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum') 61 | parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon') 62 | parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint') 63 | parser.add_argument('--num_labels', default=10, type=int, help='#classes') 64 | parser.add_argument('--device', default='cuda', type=str, help='Device for network training.') 65 | parser.add_argument('--print_freq', default=100, type=int, help='Batch print frequency.') 66 | parser.add_argument('--tf_like', default=False, type=bool, 67 | help='If true, use same weight initialization as in the tensorflow version.') 68 | 69 | args = parser.parse_args() 70 | 71 | # cifar10 dataset 72 | dataset = prepare_dataset(args.batch_size, test_batch_size=args.test_batch_size, root=args.data_root, 73 | validation_size=args.validation_size, random_state=args.random_state, 74 | set_global_seed=True, num_workers=args.num_workers) 75 | 76 | train_loader, test_loader, test_size = dataset['train'], dataset['test'], dataset['test_size'] 77 | valid_loader = dataset['validation'] if args.validation_size > 0 else None 78 | 79 | # model 80 | spec = ModelSpec(matrix, operations) 81 | net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels, 82 | stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks, 83 | num_modules_per_stack=args.num_modules_per_stack, 84 | momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like) 85 | 86 | if args.load_checkpoint != '': 87 | net.load_state_dict(reload_checkpoint(args.load_checkpoint)) 88 | net.to(args.device) 89 | 90 | criterion = nn.CrossEntropyLoss() 91 | 92 | if args.optimizer.lower() == 'sgd': 93 | optimizer = optim.SGD 94 | optimizer_kwargs = {} 95 | elif args.optimizer.lower() == 'rmsprop': 96 | optimizer = optim.RMSprop 97 | optimizer_kwargs = {'eps': args.rmsprop_eps} 98 | elif args.optimizer.lower() == 'rmsprop_tf': 99 | from timm.optim import RMSpropTF 100 | optimizer = RMSpropTF 101 | optimizer_kwargs = {'eps': args.rmsprop_eps} 102 | else: 103 | raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp") 104 | 105 | optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum, 106 | weight_decay=args.weight_decay, **optimizer_kwargs) 107 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs) 108 | 109 | result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler, 110 | grad_clip=args.grad_clip if not args.grad_clip_off else None, 111 | num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader, 112 | device=args.device, print_frequency=args.print_freq) 113 | 114 | last_epoch = {k: v[-1] for k, v in result.items() if len(v) > 0} 115 | print(f"Final train metrics: {last_epoch}") 116 | 117 | result = test(net, test_loader, loss=criterion, num_tests=test_size, device=args.device) 118 | print(f"\nFinal test metrics: {result}") 119 | 120 | save_checkpoint(net) 121 | -------------------------------------------------------------------------------- /nasbench_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/nasbench_pytorch/__init__.py -------------------------------------------------------------------------------- /nasbench_pytorch/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/nasbench_pytorch/datasets/__init__.py -------------------------------------------------------------------------------- /nasbench_pytorch/datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | """ 2 | Specific transforms and constants have been extracted from 3 | https://github.com/google-research/nasbench/blob/master/nasbench/lib/cifar.py 4 | """ 5 | import random 6 | from functools import partial 7 | import numpy as np 8 | import torch 9 | import torchvision 10 | import torchvision.transforms as transforms 11 | 12 | from torch.utils.data.sampler import SubsetRandomSampler 13 | 14 | 15 | def train_valid_split(dataset_size, valid_size, random_state=None): 16 | random = np.random.RandomState(seed=random_state) if random_state is not None else np.random 17 | valid_inds = random.choice(dataset_size, size=valid_size, replace=False) 18 | 19 | train_inds = np.delete(np.arange(dataset_size), valid_inds) 20 | 21 | return SubsetRandomSampler(train_inds), SubsetRandomSampler(valid_inds) 22 | 23 | 24 | def seed_worker(seed, worker_id): 25 | seed = seed if seed is not None else 0 26 | worker_seed = seed + worker_id 27 | np.random.seed(worker_seed) 28 | random.seed(worker_seed) 29 | 30 | 31 | def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validation=True, split_from_end=True, 32 | validation_size=10000, random_state=None, set_global_seed=False, no_valid_transform=True, 33 | num_workers=0, num_val_workers=0, num_test_workers=0): 34 | """ 35 | Download the CIFAR-10 dataset and prepare train and test DataLoaders (optionally also validation loader). 36 | 37 | Args: 38 | batch_size: Batch size for the train (and validation) loader. 39 | test_batch_size: Batch size for the test loader. 40 | root: Directory path to download the CIFAR-10 dataset to. 41 | use_validation: If False, don't split off the validation set. 42 | split_from_end: If True, split off `validation_size` images from the end, if False, choose images randomly. 43 | validation_size: Size of the validation dataset to split off the train set. 44 | 45 | random_state: Seed for the random functions (generators from numpy and random) 46 | set_global_seed: If True, call np.random.seed(random_state) and random.seed(random_state). Useful when 47 | using 0 workers (because otherwise RandomCrop will return different results every call), but affects 48 | the seed in the whole program. 49 | 50 | no_valid_transform: If True, don't use RandomCrop and RandomFlip for the validation set. 51 | num_workers: Number of workers for the train loader. 52 | num_val_workers: Number of workers for the validation loader. 53 | num_test_workers: Number of workers for the test loader. 54 | 55 | Returns: 56 | if validation_size > 0: 57 | train loader, train size, validation loader, validation size, test loader, test size 58 | otherwise: 59 | train loader, train size, test loader, test size 60 | 61 | The sizes are dataset sizes, not the number of batches. 62 | 63 | """ 64 | 65 | if set_global_seed: 66 | seed_worker(random_state, 0) 67 | 68 | if random_state is not None: 69 | worker_fn = partial(seed_worker, random_state) 70 | else: 71 | worker_fn = None 72 | 73 | print('\n--- Preparing CIFAR10 Data ---') 74 | 75 | train_transform = transforms.Compose([ 76 | transforms.RandomCrop(32, padding=4), 77 | transforms.RandomHorizontalFlip(), 78 | transforms.ToTensor(), 79 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 80 | ]) 81 | 82 | test_transform = transforms.Compose([ 83 | transforms.ToTensor(), 84 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), 85 | ]) 86 | 87 | train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform) 88 | valid_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=test_transform) 89 | valid_set = valid_set if no_valid_transform else train_set 90 | train_size = len(train_set) 91 | 92 | if use_validation: 93 | if split_from_end: 94 | # get last n images 95 | indices = np.arange(len(train_set)) 96 | train_set = torch.utils.data.Subset(train_set, indices[:-validation_size]) 97 | valid_set = torch.utils.data.Subset(valid_set, indices[-validation_size:]) 98 | train_sampler, valid_sampler = None, None 99 | else: 100 | # split off random validation set 101 | train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state) 102 | 103 | # shuffle is True if split_from_end otherwise False 104 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=split_from_end, 105 | sampler=train_sampler, num_workers=num_workers, 106 | worker_init_fn=worker_fn) 107 | valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False, 108 | sampler=valid_sampler, num_workers=num_val_workers, 109 | worker_init_fn=worker_fn) 110 | else: 111 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, 112 | num_workers=num_workers, worker_init_fn=worker_fn) 113 | valid_loader = None 114 | 115 | test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform) 116 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size, shuffle=False, 117 | num_workers=num_test_workers, worker_init_fn=worker_fn) 118 | test_size = len(test_set) 119 | 120 | print('--- CIFAR10 Data Prepared ---\n') 121 | 122 | data = { 123 | 'train': train_loader, 124 | 'train_size': train_size, 125 | 'test': test_loader, 126 | 'test_size': test_size 127 | } 128 | 129 | if validation_size > 0: 130 | data['train_size'] = train_size - validation_size 131 | data['validation'] = valid_loader 132 | data['validation_size'] = validation_size 133 | 134 | return data 135 | -------------------------------------------------------------------------------- /nasbench_pytorch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Network 2 | from .model_spec import ModelSpec 3 | 4 | __all__ = [Network, ModelSpec] 5 | -------------------------------------------------------------------------------- /nasbench_pytorch/model/base_ops.py: -------------------------------------------------------------------------------- 1 | """Base operations used by the modules in this search space.""" 2 | 3 | from __future__ import absolute_import 4 | from __future__ import division 5 | from __future__ import print_function 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | class ConvBnRelu(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, momentum=0.1, eps=1e-5): 13 | super(ConvBnRelu, self).__init__() 14 | 15 | self.conv_bn_relu = nn.Sequential( 16 | nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False), 17 | nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum), 18 | nn.ReLU() 19 | ) 20 | 21 | def forward(self, x): 22 | return self.conv_bn_relu(x) 23 | 24 | class Conv3x3BnRelu(nn.Module): 25 | """3x3 convolution with batch norm and ReLU activation.""" 26 | def __init__(self, in_channels, out_channels, **kwargs): 27 | super(Conv3x3BnRelu, self).__init__() 28 | 29 | self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, **kwargs) 30 | 31 | def forward(self, x): 32 | x = self.conv3x3(x) 33 | return x 34 | 35 | class Conv1x1BnRelu(nn.Module): 36 | """1x1 convolution with batch norm and ReLU activation.""" 37 | def __init__(self, in_channels, out_channels, **kwargs): 38 | super(Conv1x1BnRelu, self).__init__() 39 | 40 | self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, **kwargs) 41 | 42 | def forward(self, x): 43 | x = self.conv1x1(x) 44 | return x 45 | 46 | class MaxPool3x3(nn.Module): 47 | """3x3 max pool with no subsampling.""" 48 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1): 49 | super(MaxPool3x3, self).__init__() 50 | 51 | self.maxpool = nn.MaxPool2d(kernel_size, stride, padding) 52 | 53 | def forward(self, x): 54 | x = self.maxpool(x) 55 | return x 56 | 57 | # Commas should not be used in op names 58 | OP_MAP = { 59 | 'conv3x3-bn-relu': Conv3x3BnRelu, 60 | 'conv1x1-bn-relu': Conv1x1BnRelu, 61 | 'maxpool3x3': MaxPool3x3 62 | } 63 | -------------------------------------------------------------------------------- /nasbench_pytorch/model/graph_util.py: -------------------------------------------------------------------------------- 1 | # Copyright 2019 The Google Research Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Utility functions used by generate_graph.py.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import hashlib 21 | import itertools 22 | 23 | import numpy as np 24 | 25 | 26 | def gen_is_edge_fn(bits): 27 | """Generate a boolean function for the edge connectivity. 28 | 29 | Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is 30 | [[0, A, B, D], 31 | [0, 0, C, E], 32 | [0, 0, 0, F], 33 | [0, 0, 0, 0]] 34 | 35 | Note that this function is agnostic to the actual matrix dimension due to 36 | order in which elements are filled out (column-major, starting from least 37 | significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5 38 | matrix is 39 | [[0, A, B, D, 0], 40 | [0, 0, C, E, 0], 41 | [0, 0, 0, F, 0], 42 | [0, 0, 0, 0, 0], 43 | [0, 0, 0, 0, 0]] 44 | 45 | Args: 46 | bits: integer which will be interpreted as a bit mask. 47 | 48 | Returns: 49 | vectorized function that returns True when an edge is present. 50 | """ 51 | def is_edge(x, y): 52 | """Is there an edge from x to y (0-indexed)?""" 53 | if x >= y: 54 | return 0 55 | # Map x, y to index into bit string 56 | index = x + (y * (y - 1) // 2) 57 | return (bits >> index) % 2 == 1 58 | 59 | return np.vectorize(is_edge) 60 | 61 | 62 | def is_full_dag(matrix): 63 | """Full DAG == all vertices on a path from vert 0 to (V-1). 64 | 65 | i.e. no disconnected or "hanging" vertices. 66 | 67 | It is sufficient to check for: 68 | 1) no rows of 0 except for row V-1 (only output vertex has no out-edges) 69 | 2) no cols of 0 except for col 0 (only input vertex has no in-edges) 70 | 71 | Args: 72 | matrix: V x V upper-triangular adjacency matrix 73 | 74 | Returns: 75 | True if the there are no dangling vertices. 76 | """ 77 | shape = np.shape(matrix) 78 | 79 | rows = matrix[:shape[0]-1, :] == 0 80 | rows = np.all(rows, axis=1) # Any row with all 0 will be True 81 | rows_bad = np.any(rows) 82 | 83 | cols = matrix[:, 1:] == 0 84 | cols = np.all(cols, axis=0) # Any col with all 0 will be True 85 | cols_bad = np.any(cols) 86 | 87 | return (not rows_bad) and (not cols_bad) 88 | 89 | 90 | def num_edges(matrix): 91 | """Computes number of edges in adjacency matrix.""" 92 | return np.sum(matrix) 93 | 94 | 95 | def hash_module(matrix, labeling): 96 | """Computes a graph-invariance MD5 hash of the matrix and label pair. 97 | 98 | Args: 99 | matrix: np.ndarray square upper-triangular adjacency matrix. 100 | labeling: list of int labels of length equal to both dimensions of 101 | matrix. 102 | 103 | Returns: 104 | MD5 hash of the matrix and labeling. 105 | """ 106 | vertices = np.shape(matrix)[0] 107 | in_edges = np.sum(matrix, axis=0).tolist() 108 | out_edges = np.sum(matrix, axis=1).tolist() 109 | 110 | assert len(in_edges) == len(out_edges) == len(labeling) 111 | hashes = list(zip(out_edges, in_edges, labeling)) 112 | hashes = [hashlib.md5(str(h).encode('utf-8')).hexdigest() for h in hashes] 113 | # Computing this up to the diameter is probably sufficient but since the 114 | # operation is fast, it is okay to repeat more times. 115 | for _ in range(vertices): 116 | new_hashes = [] 117 | for v in range(vertices): 118 | in_neighbors = [hashes[w] for w in range(vertices) if matrix[w, v]] 119 | out_neighbors = [hashes[w] for w in range(vertices) if matrix[v, w]] 120 | new_hashes.append(hashlib.md5( 121 | (''.join(sorted(in_neighbors)) + '|' + 122 | ''.join(sorted(out_neighbors)) + '|' + 123 | hashes[v]).encode('utf-8')).hexdigest()) 124 | hashes = new_hashes 125 | fingerprint = hashlib.md5(str(sorted(hashes)).encode('utf-8')).hexdigest() 126 | 127 | return fingerprint 128 | 129 | 130 | def permute_graph(graph, label, permutation): 131 | """Permutes the graph and labels based on permutation. 132 | 133 | Args: 134 | graph: np.ndarray adjacency matrix. 135 | label: list of labels of same length as graph dimensions. 136 | permutation: a permutation list of ints of same length as graph dimensions. 137 | 138 | Returns: 139 | np.ndarray where vertex permutation[v] is vertex v from the original graph 140 | """ 141 | # vertex permutation[v] in new graph is vertex v in the old graph 142 | forward_perm = zip(permutation, list(range(len(permutation)))) 143 | inverse_perm = [x[1] for x in sorted(forward_perm)] 144 | edge_fn = lambda x, y: graph[inverse_perm[x], inverse_perm[y]] == 1 145 | new_matrix = np.fromfunction(np.vectorize(edge_fn), 146 | (len(label), len(label)), 147 | dtype=np.int8) 148 | new_label = [label[inverse_perm[i]] for i in range(len(label))] 149 | return new_matrix, new_label 150 | 151 | 152 | def is_isomorphic(graph1, graph2): 153 | """Exhaustively checks if 2 graphs are isomorphic.""" 154 | matrix1, label1 = np.array(graph1[0]), graph1[1] 155 | matrix2, label2 = np.array(graph2[0]), graph2[1] 156 | assert np.shape(matrix1) == np.shape(matrix2) 157 | assert len(label1) == len(label2) 158 | 159 | vertices = np.shape(matrix1)[0] 160 | # Note: input and output in our constrained graphs always map to themselves 161 | # but this script does not enforce that. 162 | for perm in itertools.permutations(range(0, vertices)): 163 | pmatrix1, plabel1 = permute_graph(matrix1, label1, perm) 164 | if np.array_equal(pmatrix1, matrix2) and plabel1 == label2: 165 | return True 166 | 167 | return False 168 | -------------------------------------------------------------------------------- /nasbench_pytorch/model/model.py: -------------------------------------------------------------------------------- 1 | """Builds the Pytorch computational graph. 2 | 3 | Tensors flowing into a single vertex are added together for all vertices 4 | except the output, which is concatenated instead. Tensors flowing out of input 5 | are always added. 6 | 7 | If interior edge channels don't match, drop the extra channels (channels are 8 | guaranteed non-decreasing). Tensors flowing out of the input as always 9 | projected instead. 10 | """ 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | import numpy as np 17 | import math 18 | 19 | from nasbench_pytorch.model.base_ops import * 20 | from nasbench_pytorch.model.model_spec import ModelSpec 21 | 22 | import torch 23 | import torch.nn as nn 24 | from torch.nn.init import _calculate_fan_in_and_fan_out 25 | 26 | 27 | class Network(nn.Module): 28 | def __init__(self, spec, num_labels=10, in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3, 29 | momentum=0.997, eps=1e-5, tf_like=False): 30 | """ 31 | 32 | Args: 33 | spec: ModelSpec from nasbench, or a tuple (adjacency matrix, ops) 34 | num_labels: Number of output labels. 35 | in_channels: Number of input image channels. 36 | stem_out_channels: Number of output stem channels. Other hidden channels are computed and depend on this 37 | number. 38 | 39 | num_stacks: Number of stacks, in every stacks the cells have the same number of channels. 40 | num_modules_per_stack: Number of cells per stack. 41 | """ 42 | super(Network, self).__init__() 43 | 44 | if isinstance(spec, tuple): 45 | spec = ModelSpec(spec[0], spec[1]) 46 | 47 | self.cell_indices = set() 48 | 49 | self.tf_like = tf_like 50 | self.layers = nn.ModuleList([]) 51 | 52 | # initial stem convolution 53 | out_channels = stem_out_channels 54 | stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, momentum=momentum, eps=eps) 55 | self.layers.append(stem_conv) 56 | 57 | # stacked cells 58 | in_channels = out_channels 59 | for stack_num in range(num_stacks): 60 | # downsample after every but the last cell 61 | if stack_num > 0: 62 | downsample = nn.MaxPool2d(kernel_size=2, stride=2) 63 | self.layers.append(downsample) 64 | 65 | out_channels *= 2 66 | 67 | for module_num in range(num_modules_per_stack): 68 | cell = Cell(spec, in_channels, out_channels, momentum=momentum, eps=eps) 69 | self.layers.append(cell) 70 | in_channels = out_channels 71 | 72 | self.cell_indices.add(len(self.layers) - 1) 73 | 74 | self.classifier = nn.Linear(out_channels, num_labels) 75 | 76 | self._initialize_weights() 77 | 78 | def forward(self, x): 79 | for _, layer in enumerate(self.layers): 80 | x = layer(x) 81 | out = torch.mean(x, (2, 3)) 82 | out = self.classifier(out) 83 | 84 | return out 85 | 86 | def _initialize_weights(self): 87 | for m in self.modules(): 88 | if isinstance(m, nn.Conv2d): 89 | if self.tf_like: 90 | fan_in, _ = _calculate_fan_in_and_fan_out(m.weight) 91 | torch.nn.init.normal_(m.weight, mean=0, std=1.0 / torch.sqrt(torch.tensor(fan_in))) 92 | else: 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2.0 / n)) 95 | 96 | if m.bias is not None: 97 | m.bias.data.zero_() 98 | 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.Linear): 103 | if self.tf_like: 104 | torch.nn.init.xavier_uniform_(m.weight) 105 | else: 106 | m.weight.data.normal_(0, 0.01) 107 | m.bias.data.zero_() 108 | 109 | 110 | class Cell(nn.Module): 111 | """ 112 | Builds the model using the adjacency matrix and op labels specified. Channels 113 | control the module output channel count but the interior channels are 114 | determined via equally splitting the channel count whenever there is a 115 | concatenation of Tensors. 116 | """ 117 | def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5): 118 | super(Cell, self).__init__() 119 | 120 | self.dev_param = nn.Parameter(torch.empty(0)) 121 | 122 | self.matrix = spec.matrix 123 | self.num_vertices = np.shape(self.matrix)[0] 124 | 125 | # vertex_channels[i] = number of output channels of vertex i 126 | self.vertex_channels = compute_vertex_channels(in_channels, out_channels, self.matrix) 127 | #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1) 128 | 129 | # operation for each node 130 | self.vertex_op = nn.ModuleList([Placeholder()]) 131 | for t in range(1, self.num_vertices-1): 132 | op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t]) 133 | self.vertex_op.append(op) 134 | 135 | # operation for input on each vertex 136 | self.input_op = nn.ModuleList([Placeholder()]) 137 | for t in range(1, self.num_vertices): 138 | if self.matrix[0, t]: 139 | self.input_op.append(projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps)) 140 | else: 141 | self.input_op.append(Placeholder()) 142 | 143 | self.last_inop : projection = self.input_op[self.num_vertices - 1] 144 | 145 | def forward(self, x): 146 | tensors = [x] 147 | 148 | out_concat = [] 149 | # range(1, self.num_vertices - 1), 150 | for t, (inmod, outmod) in enumerate(zip(self.input_op, self.vertex_op)): 151 | if 0 < t < (self.num_vertices - 1): 152 | 153 | fan_in = [] 154 | for src in range(1, t): 155 | if self.matrix[src, t]: 156 | fan_in.append(truncate(tensors[src], torch.tensor(self.vertex_channels[t]))) 157 | 158 | if self.matrix[0, t]: 159 | l = inmod(x) 160 | fan_in.append(l) 161 | 162 | # perform operation on node 163 | vertex_input = torch.zeros_like(fan_in[0]).to(self.dev_param.device) 164 | for val in fan_in: 165 | vertex_input += val 166 | 167 | vertex_output = outmod(vertex_input) 168 | 169 | tensors.append(vertex_output) 170 | if self.matrix[t, self.num_vertices-1]: 171 | out_concat.append(tensors[t]) 172 | 173 | if not out_concat: 174 | assert self.matrix[0, self.num_vertices-1] 175 | outputs = self.last_inop(tensors[0]) 176 | else: 177 | if len(out_concat) == 1: 178 | outputs = out_concat[0] 179 | else: 180 | outputs = torch.cat(out_concat, 1) 181 | 182 | if self.matrix[0, self.num_vertices-1]: 183 | outputs = outputs + self.last_inop(tensors[0]) 184 | 185 | return outputs 186 | 187 | 188 | def projection(in_channels, out_channels, momentum=0.1, eps=1e-5): 189 | """1x1 projection (as in ResNet) followed by batch normalization and ReLU.""" 190 | return ConvBnRelu(in_channels, out_channels, 1, momentum=momentum, eps=eps) 191 | 192 | 193 | def truncate(inputs, channels): 194 | """Slice the inputs to channels if necessary.""" 195 | input_channels = inputs.size()[1] 196 | if input_channels < channels: 197 | raise ValueError('input channel < output channels for truncate') 198 | elif input_channels == channels: 199 | return inputs # No truncation necessary 200 | else: 201 | # Truncation should only be necessary when channel division leads to 202 | # vertices with +1 channels. The input vertex should always be projected to 203 | # the minimum channel count. 204 | assert input_channels - channels == 1 205 | return inputs[:, :channels, :, :] 206 | 207 | 208 | def compute_vertex_channels(in_channels, out_channels, matrix): 209 | """Computes the number of channels at every vertex. 210 | 211 | Given the input channels and output channels, this calculates the number of 212 | channels at each interior vertex. Interior vertices have the same number of 213 | channels as the max of the channels of the vertices it feeds into. The output 214 | channels are divided amongst the vertices that are directly connected to it. 215 | When the division is not even, some vertices may receive an extra channel to 216 | compensate. 217 | 218 | Code from https://github.com/google-research/nasbench/ 219 | 220 | Returns: 221 | list of channel counts, in order of the vertices. 222 | """ 223 | if isinstance(matrix, torch.Tensor): 224 | matrix = matrix.numpy() 225 | 226 | num_vertices = np.shape(matrix)[0] 227 | 228 | vertex_channels = [0] * num_vertices 229 | vertex_channels[0] = in_channels 230 | vertex_channels[num_vertices - 1] = out_channels 231 | 232 | if num_vertices == 2: 233 | # Edge case where module only has input and output vertices 234 | return vertex_channels 235 | 236 | # Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is 237 | # the dst vertex. Summing over 0 gives the in-degree count of each vertex. 238 | in_degree = np.sum(matrix[1:], axis=0) 239 | interior_channels = out_channels // in_degree[num_vertices - 1] 240 | correction = out_channels % in_degree[num_vertices - 1] # Remainder to add 241 | 242 | # Set channels of vertices that flow directly to output 243 | for v in range(1, num_vertices - 1): 244 | if matrix[v, num_vertices - 1]: 245 | vertex_channels[v] = interior_channels 246 | if correction: 247 | vertex_channels[v] += 1 248 | correction -= 1 249 | 250 | # Set channels for all other vertices to the max of the out edges, going 251 | # backwards. (num_vertices - 2) index skipped because it only connects to 252 | # output. 253 | for v in range(num_vertices - 3, 0, -1): 254 | if not matrix[v, num_vertices - 1]: 255 | for dst in range(v + 1, num_vertices - 1): 256 | if matrix[v, dst]: 257 | vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst]) 258 | assert vertex_channels[v] > 0 259 | 260 | # Sanity check, verify that channels never increase and final channels add up. 261 | final_fan_in = 0 262 | for v in range(1, num_vertices - 1): 263 | if matrix[v, num_vertices - 1]: 264 | final_fan_in += vertex_channels[v] 265 | for dst in range(v + 1, num_vertices - 1): 266 | if matrix[v, dst]: 267 | assert vertex_channels[v] >= vertex_channels[dst] 268 | assert final_fan_in == out_channels or num_vertices == 2 269 | # num_vertices == 2 means only input/output nodes, so 0 fan-in 270 | 271 | return [int(v) for v in vertex_channels] 272 | 273 | 274 | class Placeholder(torch.nn.Module): 275 | def __init__(self): 276 | super().__init__() 277 | self.a = torch.nn.Parameter(torch.randn(())) 278 | 279 | def forward(self, x): 280 | return x 281 | -------------------------------------------------------------------------------- /nasbench_pytorch/model/model_spec.py: -------------------------------------------------------------------------------- 1 | 2 | # Copyright 2019 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Model specification for module connectivity individuals. 17 | 18 | This module handles pruning the unused parts of the computation graph but should 19 | avoid creating any TensorFlow models (this is done inside model_builder.py). 20 | """ 21 | 22 | from __future__ import absolute_import 23 | from __future__ import division 24 | from __future__ import print_function 25 | 26 | import copy 27 | import numpy as np 28 | import torch 29 | 30 | from nasbench_pytorch.model import graph_util 31 | 32 | # Graphviz is optional and only required for visualization. 33 | try: 34 | import graphviz # pylint: disable=g-import-not-at-top 35 | except ImportError: 36 | pass 37 | 38 | 39 | class ModelSpec(object): 40 | """Model specification given adjacency matrix and labeling.""" 41 | 42 | def __init__(self, matrix, ops, data_format='channels_last'): 43 | """Initialize the module spec. 44 | 45 | Args: 46 | matrix: ndarray or nested list with shape [V, V] for the adjacency matrix. 47 | ops: V-length list of labels for the base ops used. The first and last 48 | elements are ignored because they are the input and output vertices 49 | which have no operations. The elements are retained to keep consistent 50 | indexing. 51 | data_format: channels_last or channels_first. 52 | 53 | Raises: 54 | ValueError: invalid matrix or ops 55 | """ 56 | 57 | if not isinstance(matrix, np.ndarray): 58 | matrix = np.array(matrix) 59 | shape = np.shape(matrix) 60 | if len(shape) != 2 or shape[0] != shape[1]: 61 | raise ValueError('matrix must be square') 62 | if shape[0] != len(ops): 63 | raise ValueError('length of ops must match matrix dimensions') 64 | if not is_upper_triangular(matrix): 65 | raise ValueError('matrix must be upper triangular') 66 | 67 | # Both the original and pruned matrices are deep copies of the matrix and 68 | # ops so any changes to those after initialization are not recognized by the 69 | # spec. 70 | self.original_matrix = copy.deepcopy(matrix) 71 | self.original_ops = copy.deepcopy(ops) 72 | 73 | self.matrix = copy.deepcopy(matrix) 74 | self.ops = copy.deepcopy(ops) 75 | self.valid_spec = True 76 | self._prune() 77 | 78 | self.matrix = torch.tensor(self.matrix) 79 | 80 | self.data_format = data_format 81 | 82 | def _prune(self): 83 | """Prune the extraneous parts of the graph. 84 | 85 | General procedure: 86 | 1) Remove parts of graph not connected to input. 87 | 2) Remove parts of graph not connected to output. 88 | 3) Reorder the vertices so that they are consecutive after steps 1 and 2. 89 | 90 | These 3 steps can be combined by deleting the rows and columns of the 91 | vertices that are not reachable from both the input and output (in reverse). 92 | """ 93 | num_vertices = np.shape(self.original_matrix)[0] 94 | 95 | # DFS forward from input 96 | visited_from_input = set([0]) 97 | frontier = [0] 98 | while frontier: 99 | top = frontier.pop() 100 | for v in range(top + 1, num_vertices): 101 | if self.original_matrix[top, v] and v not in visited_from_input: 102 | visited_from_input.add(v) 103 | frontier.append(v) 104 | 105 | # DFS backward from output 106 | visited_from_output = set([num_vertices - 1]) 107 | frontier = [num_vertices - 1] 108 | while frontier: 109 | top = frontier.pop() 110 | for v in range(0, top): 111 | if self.original_matrix[v, top] and v not in visited_from_output: 112 | visited_from_output.add(v) 113 | frontier.append(v) 114 | 115 | # Any vertex that isn't connected to both input and output is extraneous to 116 | # the computation graph. 117 | extraneous = set(range(num_vertices)).difference( 118 | visited_from_input.intersection(visited_from_output)) 119 | 120 | # If the non-extraneous graph is less than 2 vertices, the input is not 121 | # connected to the output and the spec is invalid. 122 | if len(extraneous) > num_vertices - 2: 123 | self.matrix = None 124 | self.ops = None 125 | self.valid_spec = False 126 | return 127 | 128 | self.matrix = np.delete(self.matrix, list(extraneous), axis=0) 129 | self.matrix = np.delete(self.matrix, list(extraneous), axis=1) 130 | for index in sorted(extraneous, reverse=True): 131 | del self.ops[index] 132 | 133 | def hash_spec(self, canonical_ops): 134 | """Computes the isomorphism-invariant graph hash of this spec. 135 | 136 | Args: 137 | canonical_ops: list of operations in the canonical ordering which they 138 | were assigned (i.e. the order provided in the config['available_ops']). 139 | 140 | Returns: 141 | MD5 hash of this spec which can be used to query the dataset. 142 | """ 143 | # Invert the operations back to integer label indices used in graph gen. 144 | labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2] 145 | return graph_util.hash_module(self.matrix, labeling) 146 | 147 | def visualize(self): 148 | """Creates a dot graph. Can be visualized in colab directly.""" 149 | num_vertices = np.shape(self.matrix)[0] 150 | g = graphviz.Digraph() 151 | g.node(str(0), 'input') 152 | for v in range(1, num_vertices - 1): 153 | g.node(str(v), self.ops[v]) 154 | g.node(str(num_vertices - 1), 'output') 155 | 156 | for src in range(num_vertices - 1): 157 | for dst in range(src + 1, num_vertices): 158 | if self.matrix[src, dst]: 159 | g.edge(str(src), str(dst)) 160 | 161 | return g 162 | 163 | 164 | def is_upper_triangular(matrix): 165 | """True if matrix is 0 on diagonal and below.""" 166 | for src in range(np.shape(matrix)[0]): 167 | for dst in range(0, src + 1): 168 | if matrix[src, dst] != 0: 169 | return False 170 | 171 | return True 172 | -------------------------------------------------------------------------------- /nasbench_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_clip=5, num_epochs=10, 8 | num_validation=None, validation_loader=None, device=None, print_frequency=200, 9 | checkpoint_every_k=None, checkpoint_func=None): 10 | """ 11 | Train a network from the NAS-bench-101 search space on a dataset (`train_loader`). 12 | 13 | Args: 14 | net: Network to train. 15 | train_loader: Train data loader. 16 | loss: Loss, default is CrossEntropyLoss. 17 | optimizer: Optimizer, default is SGD, possible: 'sgd', 'rmsprop', 'adam', or an optimizer object. 18 | scheduler: Default is CosineAnnealingLR. 19 | grad_clip: Gradient clipping parameter. 20 | num_epochs: Number of training epochs. 21 | num_validation: Number of validation examples (for print purposes). 22 | validation_loader: Optional validation set. 23 | device: Device to train on, default is cpu. 24 | print_frequency: How often to print info about batches. 25 | checkpoint_every_k: Every k epochs, save a checkpoint. 26 | checkpoint_func: Custom function to save the checkpoint, signature: func(net, metric_dict, epoch num) 27 | 28 | Returns: 29 | Final train (and validation) metrics. 30 | """ 31 | 32 | net = net.to(device) 33 | 34 | # defaults 35 | if loss is None: 36 | loss = nn.CrossEntropyLoss() 37 | 38 | if optimizer is not None and not isinstance(optimizer, str): 39 | pass 40 | elif optimizer is None or optimizer.lower() == 'rmsprop': 41 | optimizer = torch.optim.RMSprop(net.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4, eps=1.0) 42 | elif optimizer.lower() == 'sgd': 43 | optimizer = torch.optim.SGD(net.parameters(), lr=0.025, momentum=0.9, weight_decay=1e-4) 44 | elif optimizer.lower() == 'adam': 45 | optimizer = torch.optim.Adam(net.parameters()) 46 | 47 | if scheduler is None: 48 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs) 49 | 50 | # training 51 | 52 | n_batches = len(train_loader) 53 | 54 | metric_dict = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []} 55 | for epoch in range(num_epochs): 56 | # checkpoint using a user defined function 57 | if checkpoint_every_k is not None and (epoch + 1) % checkpoint_every_k == 0: 58 | checkpoint_func(net, metric_dict, epoch + 1) 59 | 60 | net.train() 61 | 62 | train_loss = torch.tensor(0.0) 63 | correct = torch.tensor(0) 64 | total = 0 65 | 66 | batch_idx = 0 67 | for batch_idx, (inputs, targets) in enumerate(train_loader): 68 | inputs, targets = inputs.to(device), targets.to(device) 69 | 70 | # forward 71 | outputs = net(inputs) 72 | 73 | # back-propagation 74 | optimizer.zero_grad() 75 | curr_loss = loss(outputs, targets) 76 | curr_loss.backward() 77 | if grad_clip is not None: 78 | nn.utils.clip_grad_norm_(net.parameters(), grad_clip) 79 | optimizer.step() 80 | 81 | # metrics 82 | train_loss += curr_loss.detach().cpu() 83 | _, predict = torch.max(outputs.data, 1) 84 | total += targets.size(0) 85 | correct += predict.eq(targets.data).sum().detach().cpu() 86 | 87 | if (batch_idx % print_frequency) == 0: 88 | print(f'Epoch={epoch}/{num_epochs} Batch={batch_idx + 1}/{n_batches} | ' 89 | f'Loss={train_loss / (batch_idx + 1):.3f}, ' 90 | f'Acc={correct / total:.3f}({correct}/{total})') 91 | 92 | last_loss = train_loss / (batch_idx + 1) 93 | acc = correct / total 94 | 95 | # save metrics 96 | metric_dict['train_loss'].append(last_loss.item()) 97 | metric_dict['train_accuracy'].append(acc.item()) 98 | 99 | if validation_loader is not None: 100 | test_metrics = test(net, validation_loader, loss, num_tests=num_validation, device=device) 101 | metric_dict['val_loss'].append(test_metrics['test_loss']) 102 | metric_dict['val_accuracy'].append(test_metrics['test_accuracy']) 103 | 104 | print('--------------------') 105 | scheduler.step() 106 | 107 | return metric_dict 108 | 109 | 110 | def test(net, test_loader, loss=None, num_tests=None, device=None): 111 | """ 112 | Evaluate the network on a test set. 113 | 114 | Args: 115 | net: Network for testing. 116 | test_loader: Test dataset. 117 | loss: Loss function, default is CrossEntropyLoss. 118 | num_tests: Number of test examples (for print purposes). 119 | device: Device to use. 120 | 121 | Returns: 122 | Test metrics. 123 | """ 124 | net = net.to(device) 125 | net.eval() 126 | 127 | if loss is None: 128 | loss = nn.CrossEntropyLoss() 129 | 130 | test_loss = 0 131 | correct = 0 132 | n_tests = 0 133 | 134 | with torch.no_grad(): 135 | for batch_idx, (inputs, targets) in enumerate(test_loader): 136 | inputs, targets = inputs.to(device), targets.to(device) 137 | 138 | outputs = net(inputs) 139 | 140 | curr_loss = loss(outputs, targets) 141 | test_loss += curr_loss.detach() 142 | _, predict = torch.max(outputs.data, 1) 143 | correct += predict.eq(targets.data).sum().detach() 144 | 145 | if num_tests is None: 146 | n_tests += len(targets) 147 | 148 | if num_tests is None: 149 | num_tests = n_tests 150 | 151 | print(f'Testing: Loss={(test_loss / len(test_loader)):.3f}, Acc={(correct / num_tests):.3f}' 152 | f'({correct}/{num_tests})') 153 | 154 | last_loss = test_loss / len(test_loader) if len(test_loader) > 0 else np.inf 155 | acc = correct / num_tests 156 | 157 | return {'test_loss': last_loss.item(), 'test_accuracy': acc.item()} 158 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | long_description = file: README.md 3 | long_description_content_type = text/markdown 4 | license_files=LICENSE.txt -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | setuptools.setup( 4 | name='nasbench_pytorch', 5 | version='1.3.1', 6 | license='Apache License 2.0', 7 | author='Romulus Hong, Gabriela Kadlecová', 8 | packages=setuptools.find_packages() 9 | ) 10 | --------------------------------------------------------------------------------