├── .gitignore
├── Contributors.md
├── LICENSE.txt
├── README.md
├── assets
├── architecture.png
└── param_time_acc.png
├── main.py
├── nasbench_pytorch
├── __init__.py
├── datasets
│ ├── __init__.py
│ └── cifar10.py
├── model
│ ├── __init__.py
│ ├── base_ops.py
│ ├── graph_util.py
│ ├── model.py
│ └── model_spec.py
└── trainer.py
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | # Created by https://www.toptal.com/developers/gitignore/api/pycharm
3 | # Edit at https://www.toptal.com/developers/gitignore?templates=pycharm
4 |
5 | ### PyCharm ###
6 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
7 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
8 | data/*
9 | .idea/*
10 | checkpoint/*
11 | dist/*
12 | build/*
13 |
14 | *.ipynb
15 |
16 | **/__pycache__/**
17 |
18 | *.egg-info/*
19 |
20 | # User-specific stuff
21 | .idea/**/workspace.xml
22 | .idea/**/tasks.xml
23 | .idea/**/usage.statistics.xml
24 | .idea/**/dictionaries
25 | .idea/**/shelf
26 |
27 | # Generated files
28 | .idea/**/contentModel.xml
29 |
30 | # Sensitive or high-churn files
31 | .idea/**/dataSources/
32 | .idea/**/dataSources.ids
33 | .idea/**/dataSources.local.xml
34 | .idea/**/sqlDataSources.xml
35 | .idea/**/dynamic.xml
36 | .idea/**/uiDesigner.xml
37 | .idea/**/dbnavigator.xml
38 |
39 | # Gradle
40 | .idea/**/gradle.xml
41 | .idea/**/libraries
42 |
43 | # Gradle and Maven with auto-import
44 | # When using Gradle or Maven with auto-import, you should exclude module files,
45 | # since they will be recreated, and may cause churn. Uncomment if using
46 | # auto-import.
47 | # .idea/artifacts
48 | # .idea/compiler.xml
49 | # .idea/jarRepositories.xml
50 | # .idea/modules.xml
51 | # .idea/*.iml
52 | # .idea/modules
53 | # *.iml
54 | # *.ipr
55 |
56 | # CMake
57 | cmake-build-*/
58 |
59 | # Mongo Explorer plugin
60 | .idea/**/mongoSettings.xml
61 |
62 | # File-based project format
63 | *.iws
64 |
65 | # IntelliJ
66 | out/
67 |
68 | # mpeltonen/sbt-idea plugin
69 | .idea_modules/
70 |
71 | # JIRA plugin
72 | atlassian-ide-plugin.xml
73 |
74 | # Cursive Clojure plugin
75 | .idea/replstate.xml
76 |
77 | # Crashlytics plugin (for Android Studio and IntelliJ)
78 | com_crashlytics_export_strings.xml
79 | crashlytics.properties
80 | crashlytics-build.properties
81 | fabric.properties
82 |
83 | # Editor-based Rest Client
84 | .idea/httpRequests
85 |
86 | # Android studio 3.1+ serialized cache file
87 | .idea/caches/build_file_checksums.ser
88 |
89 | ### PyCharm Patch ###
90 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
91 |
92 | # *.iml
93 | # modules.xml
94 | # .idea/misc.xml
95 | # *.ipr
96 |
97 | # Sonarlint plugin
98 | # https://plugins.jetbrains.com/plugin/7973-sonarlint
99 | .idea/**/sonarlint/
100 |
101 | # SonarQube Plugin
102 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin
103 | .idea/**/sonarIssues.xml
104 |
105 | # Markdown Navigator plugin
106 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced
107 | .idea/**/markdown-navigator.xml
108 | .idea/**/markdown-navigator-enh.xml
109 | .idea/**/markdown-navigator/
110 |
111 | # Cache file creation bug
112 | # See https://youtrack.jetbrains.com/issue/JBR-2257
113 | .idea/$CACHE_FILE$
114 |
115 | # CodeStream plugin
116 | # https://plugins.jetbrains.com/plugin/12206-codestream
117 | .idea/codestream.xml
118 |
119 | # End of https://www.toptal.com/developers/gitignore/api/pycharm
120 |
--------------------------------------------------------------------------------
/Contributors.md:
--------------------------------------------------------------------------------
1 | # Contributors
2 | - [@romulus0914](https://github.com/romulus0914) (Romulus Hong)
3 | - Author of the code - NAS-Bench-101 implementation in PyTorch
4 | - [@gabikadlecova](https://github.com/gabikadlecova)
5 | - Maintainer of the repository
6 | - Package structure, reproducibility
7 | ---------
8 | - [@abhash-er](https://github.com/abhash-er/) (Abhash Jha)
9 | - Modified the model code so that casting to double is possible
10 | - [@longerHost](https://github.com/longerHost)
11 | - Reproducibility of the original NAS-Bench-101
12 | - Comparison of training results and API results
13 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # NASBench-PyTorch
2 | NASBench-PyTorch is a PyTorch implementation of the search space
3 | [NAS-Bench-101](https://github.com/google-research/nasbench) including the training of the networks[**](#note). The original
4 | implementation is written in TensorFlow, and this project contains
5 | some files from the original repository (in the directory
6 | `nasbench_pytorch/model/`).
7 |
8 | **Important:** if you want to reproduce the original results, please refer to the
9 | [Reproducibility](#repro) section.
10 |
11 | # Overview
12 | A PyTorch implementation of *training* of NAS-Bench-101 dataset: [NAS-Bench-101: Towards Reproducible Neural Architecture Search](https://arxiv.org/abs/1902.09635).
13 | The dataset contains 423,624 unique neural networks exhaustively generated and evaluated from a fixed graph-based search space.
14 |
15 | # Usage
16 | You need to have PyTorch installed.
17 |
18 | You can install the package by running `pip install nasbench_pytorch`. The second possibility is to install from source code:
19 |
20 | 1. Clone this repo
21 | ```
22 | git clone https://github.com/romulus0914/NASBench-PyTorch
23 | cd NASBench-PyTorch
24 | ```
25 |
26 | 2. Install the project
27 | ```
28 | pip install -e .
29 | ```
30 |
31 | The file `main.py` contains an example training of a network. To see
32 | the different parameters, run:
33 |
34 | ```
35 | python main.py --help
36 | ```
37 |
38 | ### Train a network by hash
39 | To train a network whose architecture is queried from NAS-Bench-101
40 | using its unique hash, install the original [nasbench](https://github.com/google-research/nasbench)
41 | repository. Follow the instructions in the README, note that you
42 | need to install TensorFlow. If you need TensorFlow 2.x, install
43 | [this fork](https://github.com/gabrielasuchopar/nasbench) of the
44 | repository instead.
45 |
46 | Then, you can get the PyTorch architecture of a network like this:
47 |
48 | ```python
49 | from nasbench_pytorch.model import Network as NBNetwork
50 | from nasbench import api
51 |
52 |
53 | nasbench_path = '$path_to_downloaded_nasbench'
54 | nb = api.NASBench(nasbench_path)
55 |
56 | net_hash = '$some_hash' # you can get hashes using nasbench.hash_iterator()
57 | m = nb.get_metrics_from_hash(net_hash)
58 | ops = m[0]['module_operations']
59 | adjacency = m[0]['module_adjacency']
60 |
61 | net = NBNetwork((adjacency, ops))
62 | ```
63 |
64 | Then, you can train it just like the example network in `main.py`.
65 |
66 | # Architecture
67 | Example architecture (picture from the original repository)
68 | 
69 |
70 | # Reproducibility
71 | The code should closely match the TensorFlow version (including the hyperparameters), but there are some differences:
72 | - RMSProp implementation in TensorFlow and PyTorch is **different**
73 | - For more information refer to [here](https://github.com/pytorch/pytorch/issues/32545) and [here](https://github.com/pytorch/pytorch/issues/23796).
74 | - Optionally, you can install pytorch-image-models where a [TensorFlow-like RMSProp](https://github.com/rwightman/pytorch-image-models/blob/main/timm/optim/rmsprop_tf.py#L5) is implemented
75 | - `pip install timm`
76 | - Then, pass `--optimizer rmsprop_tf` to `main.py` to use it
77 |
78 |
79 | - You can turn gradient clipping off by setting `--grad_clip_off True`
80 |
81 |
82 | - The original training was on TPUs, this code enables only GPU and CPU training
83 | - Input data augmentation methods are the same, but due to randomness they are not applied in the same manner
84 | - Cause: Batches and images cannot be shuffled as in the original TPU training, and the augmentation seed is also different
85 | - Results may still differ due to TensorFlow/PyTorch implementation differences
86 |
87 | Refer to this [issue](https://github.com/romulus0914/NASBench-PyTorch/issues/6) for more information and for comparison with API results.
88 |
89 | # Disclaimer
90 | Modified from [NASBench: A Neural Architecture Search Dataset and Benchmark](https://github.com/google-research/nasbench).
91 | *graph_util.py* and *model_spec.py* are directly copied from the original repo. Original license can be found [here](https://github.com/google-research/nasbench/blob/master/LICENSE).
92 |
93 |
94 | <a name="note"></a>** Please note that this repo is only used to train one possible architecture in the search space, not to generate all possible graphs and train them.
95 |
--------------------------------------------------------------------------------
/assets/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/assets/architecture.png
--------------------------------------------------------------------------------
/assets/param_time_acc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/assets/param_time_acc.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 |
8 | from nasbench_pytorch.datasets.cifar10 import prepare_dataset
9 | from nasbench_pytorch.model import Network
10 | from nasbench_pytorch.model import ModelSpec
11 | from nasbench_pytorch.trainer import train, test
12 |
13 | matrix = [[0, 1, 1, 1, 0, 1, 0],
14 | [0, 0, 0, 0, 0, 0, 1],
15 | [0, 0, 0, 0, 0, 0, 1],
16 | [0, 0, 0, 0, 1, 0, 0],
17 | [0, 0, 0, 0, 0, 0, 1],
18 | [0, 0, 0, 0, 0, 0, 1],
19 | [0, 0, 0, 0, 0, 0, 0]]
20 |
21 | operations = ['input', 'conv1x1-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu', 'conv3x3-bn-relu',
22 | 'maxpool3x3', 'output']
23 |
24 |
25 | def save_checkpoint(net, postfix='cifar10'):
26 | print('--- Saving Checkpoint ---')
27 |
28 | if not os.path.isdir('checkpoint'):
29 | os.mkdir('checkpoint')
30 |
31 | torch.save(net.state_dict(), './checkpoint/ckpt_' + postfix + '.pt')
32 |
33 | def reload_checkpoint(path, device=None):
34 | print('--- Reloading Checkpoint ---')
35 |
36 | assert os.path.isdir('checkpoint'), '[Error] No checkpoint directory found!'
37 | return torch.load(path, map_location=device)
38 |
39 |
40 | if __name__ == '__main__':
41 | parser = argparse.ArgumentParser(description='NASBench')
42 | parser.add_argument('--random_state', default=1, type=int, help='Random seed.')
43 | parser.add_argument('--data_root', default='./data/', type=str, help='Path where cifar will be downloaded.')
44 | parser.add_argument('--in_channels', default=3, type=int, help='Number of input channels.')
45 | parser.add_argument('--stem_out_channels', default=128, type=int, help='output channels of stem convolution')
46 | parser.add_argument('--num_stacks', default=3, type=int, help='#stacks of modules')
47 | parser.add_argument('--num_modules_per_stack', default=3, type=int, help='#modules per stack')
48 | parser.add_argument('--batch_size', default=256, type=int, help='batch size')
49 | parser.add_argument('--test_batch_size', default=256, type=int, help='test set batch size')
50 | parser.add_argument('--epochs', default=108, type=int, help='#epochs of training')
51 | parser.add_argument('--validation_size', default=10000, type=int, help="Size of the validation set to split off.")
52 | parser.add_argument('--num_workers', default=0, type=int, help="Number of parallel workers for the train dataset.")
53 | parser.add_argument('--learning_rate', default=0.2, type=float, help='base learning rate')
54 | parser.add_argument('--optimizer', default='rmsprop', type=str, help='Optimizer (sgd, rmsprop or rmsprop_tf)')
55 | parser.add_argument('--rmsprop_eps', default=1.0, type=float, help='RMSProp eps parameter.')
56 | parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
57 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='L2 regularization weight')
58 | parser.add_argument('--grad_clip', default=5, type=float, help='gradient clipping')
59 | parser.add_argument('--grad_clip_off', default=False, type=bool, help='If True, turn off gradient clipping.')
60 | parser.add_argument('--batch_norm_momentum', default=0.997, type=float, help='Batch normalization momentum')
61 | parser.add_argument('--batch_norm_eps', default=1e-5, type=float, help='Batch normalization epsilon')
62 | parser.add_argument('--load_checkpoint', default='', type=str, help='Reload model from checkpoint')
63 | parser.add_argument('--num_labels', default=10, type=int, help='#classes')
64 | parser.add_argument('--device', default='cuda', type=str, help='Device for network training.')
65 | parser.add_argument('--print_freq', default=100, type=int, help='Batch print frequency.')
66 | parser.add_argument('--tf_like', default=False, type=bool,
67 | help='If true, use same weight initialization as in the tensorflow version.')
68 |
69 | args = parser.parse_args()
70 |
71 | # cifar10 dataset
72 | dataset = prepare_dataset(args.batch_size, test_batch_size=args.test_batch_size, root=args.data_root,
73 | validation_size=args.validation_size, random_state=args.random_state,
74 | set_global_seed=True, num_workers=args.num_workers)
75 |
76 | train_loader, test_loader, test_size = dataset['train'], dataset['test'], dataset['test_size']
77 | valid_loader = dataset['validation'] if args.validation_size > 0 else None
78 |
79 | # model
80 | spec = ModelSpec(matrix, operations)
81 | net = Network(spec, num_labels=args.num_labels, in_channels=args.in_channels,
82 | stem_out_channels=args.stem_out_channels, num_stacks=args.num_stacks,
83 | num_modules_per_stack=args.num_modules_per_stack,
84 | momentum=args.batch_norm_momentum, eps=args.batch_norm_eps, tf_like=args.tf_like)
85 |
86 | if args.load_checkpoint != '':
87 | net.load_state_dict(reload_checkpoint(args.load_checkpoint))
88 | net.to(args.device)
89 |
90 | criterion = nn.CrossEntropyLoss()
91 |
92 | if args.optimizer.lower() == 'sgd':
93 | optimizer = optim.SGD
94 | optimizer_kwargs = {}
95 | elif args.optimizer.lower() == 'rmsprop':
96 | optimizer = optim.RMSprop
97 | optimizer_kwargs = {'eps': args.rmsprop_eps}
98 | elif args.optimizer.lower() == 'rmsprop_tf':
99 | from timm.optim import RMSpropTF
100 | optimizer = RMSpropTF
101 | optimizer_kwargs = {'eps': args.rmsprop_eps}
102 | else:
103 | raise ValueError(f"Invalid optimizer {args.optimizer}, possible: SGD, RMSProp")
104 |
105 | optimizer = optimizer(net.parameters(), lr=args.learning_rate, momentum=args.momentum,
106 | weight_decay=args.weight_decay, **optimizer_kwargs)
107 | scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
108 |
109 | result = train(net, train_loader, loss=criterion, optimizer=optimizer, scheduler=scheduler,
110 | grad_clip=args.grad_clip if not args.grad_clip_off else None,
111 | num_epochs=args.epochs, num_validation=args.validation_size, validation_loader=valid_loader,
112 | device=args.device, print_frequency=args.print_freq)
113 |
114 | last_epoch = {k: v[-1] for k, v in result.items() if len(v) > 0}
115 | print(f"Final train metrics: {last_epoch}")
116 |
117 | result = test(net, test_loader, loss=criterion, num_tests=test_size, device=args.device)
118 | print(f"\nFinal test metrics: {result}")
119 |
120 | save_checkpoint(net)
121 |
--------------------------------------------------------------------------------
/nasbench_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/nasbench_pytorch/__init__.py
--------------------------------------------------------------------------------
/nasbench_pytorch/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/romulus0914/NASBench-PyTorch/37bc332bc6802e0375d8cca839d54f47524487fb/nasbench_pytorch/datasets/__init__.py
--------------------------------------------------------------------------------
/nasbench_pytorch/datasets/cifar10.py:
--------------------------------------------------------------------------------
1 | """
2 | Specific transforms and constants have been extracted from
3 | https://github.com/google-research/nasbench/blob/master/nasbench/lib/cifar.py
4 | """
5 | import random
6 | from functools import partial
7 | import numpy as np
8 | import torch
9 | import torchvision
10 | import torchvision.transforms as transforms
11 |
12 | from torch.utils.data.sampler import SubsetRandomSampler
13 |
14 |
15 | def train_valid_split(dataset_size, valid_size, random_state=None):
16 | random = np.random.RandomState(seed=random_state) if random_state is not None else np.random
17 | valid_inds = random.choice(dataset_size, size=valid_size, replace=False)
18 |
19 | train_inds = np.delete(np.arange(dataset_size), valid_inds)
20 |
21 | return SubsetRandomSampler(train_inds), SubsetRandomSampler(valid_inds)
22 |
23 |
24 | def seed_worker(seed, worker_id):
25 | seed = seed if seed is not None else 0
26 | worker_seed = seed + worker_id
27 | np.random.seed(worker_seed)
28 | random.seed(worker_seed)
29 |
30 |
31 | def prepare_dataset(batch_size, test_batch_size=256, root='./data/', use_validation=True, split_from_end=True,
32 | validation_size=10000, random_state=None, set_global_seed=False, no_valid_transform=True,
33 | num_workers=0, num_val_workers=0, num_test_workers=0):
34 | """
35 | Download the CIFAR-10 dataset and prepare train and test DataLoaders (optionally also validation loader).
36 |
37 | Args:
38 | batch_size: Batch size for the train (and validation) loader.
39 | test_batch_size: Batch size for the test loader.
40 | root: Directory path to download the CIFAR-10 dataset to.
41 | use_validation: If False, don't split off the validation set.
42 | split_from_end: If True, split off `validation_size` images from the end, if False, choose images randomly.
43 | validation_size: Size of the validation dataset to split off the train set.
44 |
45 | random_state: Seed for the random functions (generators from numpy and random)
46 | set_global_seed: If True, call np.random.seed(random_state) and random.seed(random_state). Useful when
47 | using 0 workers (because otherwise RandomCrop will return different results every call), but affects
48 | the seed in the whole program.
49 |
50 | no_valid_transform: If True, don't use RandomCrop and RandomFlip for the validation set.
51 | num_workers: Number of workers for the train loader.
52 | num_val_workers: Number of workers for the validation loader.
53 | num_test_workers: Number of workers for the test loader.
54 |
55 | Returns:
56 | if validation_size > 0:
57 | train loader, train size, validation loader, validation size, test loader, test size
58 | otherwise:
59 | train loader, train size, test loader, test size
60 |
61 | The sizes are dataset sizes, not the number of batches.
62 |
63 | """
64 |
65 | if set_global_seed:
66 | seed_worker(random_state, 0)
67 |
68 | if random_state is not None:
69 | worker_fn = partial(seed_worker, random_state)
70 | else:
71 | worker_fn = None
72 |
73 | print('\n--- Preparing CIFAR10 Data ---')
74 |
75 | train_transform = transforms.Compose([
76 | transforms.RandomCrop(32, padding=4),
77 | transforms.RandomHorizontalFlip(),
78 | transforms.ToTensor(),
79 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
80 | ])
81 |
82 | test_transform = transforms.Compose([
83 | transforms.ToTensor(),
84 | transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
85 | ])
86 |
87 | train_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
88 | valid_set = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=test_transform)
89 | valid_set = valid_set if no_valid_transform else train_set
90 | train_size = len(train_set)
91 |
92 | if use_validation:
93 | if split_from_end:
94 | # get last n images
95 | indices = np.arange(len(train_set))
96 | train_set = torch.utils.data.Subset(train_set, indices[:-validation_size])
97 | valid_set = torch.utils.data.Subset(valid_set, indices[-validation_size:])
98 | train_sampler, valid_sampler = None, None
99 | else:
100 | # split off random validation set
101 | train_sampler, valid_sampler = train_valid_split(train_size, validation_size, random_state=random_state)
102 |
103 | # shuffle is True if split_from_end otherwise False
104 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=split_from_end,
105 | sampler=train_sampler, num_workers=num_workers,
106 | worker_init_fn=worker_fn)
107 | valid_loader = torch.utils.data.DataLoader(valid_set, batch_size=batch_size, shuffle=False,
108 | sampler=valid_sampler, num_workers=num_val_workers,
109 | worker_init_fn=worker_fn)
110 | else:
111 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True,
112 | num_workers=num_workers, worker_init_fn=worker_fn)
113 | valid_loader = None
114 |
115 | test_set = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=test_transform)
116 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size, shuffle=False,
117 | num_workers=num_test_workers, worker_init_fn=worker_fn)
118 | test_size = len(test_set)
119 |
120 | print('--- CIFAR10 Data Prepared ---\n')
121 |
122 | data = {
123 | 'train': train_loader,
124 | 'train_size': train_size,
125 | 'test': test_loader,
126 | 'test_size': test_size
127 | }
128 |
129 | if validation_size > 0:
130 | data['train_size'] = train_size - validation_size
131 | data['validation'] = valid_loader
132 | data['validation_size'] = validation_size
133 |
134 | return data
135 |
--------------------------------------------------------------------------------
/nasbench_pytorch/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Network
2 | from .model_spec import ModelSpec
3 |
4 | __all__ = [Network, ModelSpec]
5 |
--------------------------------------------------------------------------------
/nasbench_pytorch/model/base_ops.py:
--------------------------------------------------------------------------------
1 | """Base operations used by the modules in this search space."""
2 |
3 | from __future__ import absolute_import
4 | from __future__ import division
5 | from __future__ import print_function
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 |
class ConvBnRelu(nn.Module):
    """Convolution -> batch normalization -> ReLU, applied in sequence."""

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, momentum=0.1, eps=1e-5):
        super(ConvBnRelu, self).__init__()

        # Bias is disabled on the conv since batch norm supplies the shift.
        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels, eps=eps, momentum=momentum)
        self.conv_bn_relu = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        out = self.conv_bn_relu(x)
        return out
23 |
class Conv3x3BnRelu(nn.Module):
    """3x3 convolution with batch norm and ReLU activation."""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(Conv3x3BnRelu, self).__init__()
        # 3x3 kernel, stride 1, padding 1: spatial resolution is preserved.
        self.conv3x3 = ConvBnRelu(in_channels, out_channels, 3, 1, 1, **kwargs)

    def forward(self, x):
        return self.conv3x3(x)
34 |
class Conv1x1BnRelu(nn.Module):
    """1x1 convolution with batch norm and ReLU activation."""

    def __init__(self, in_channels, out_channels, **kwargs):
        super(Conv1x1BnRelu, self).__init__()
        # 1x1 kernel, stride 1, no padding: a channel-wise projection.
        self.conv1x1 = ConvBnRelu(in_channels, out_channels, 1, 1, 0, **kwargs)

    def forward(self, x):
        return self.conv1x1(x)
45 |
class MaxPool3x3(nn.Module):
    """3x3 max pool with no subsampling (stride 1, padding 1)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        # in_channels/out_channels are accepted only so the constructor matches
        # the shared op signature in OP_MAP; pooling never changes channels.
        super(MaxPool3x3, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size, stride, padding)

    def forward(self, x):
        return self.maxpool(x)
56 |
# Commas should not be used in op names
# Maps the op labels used in a spec's `ops` list to their module classes.
# Every class takes (in_channels, out_channels) as its first two arguments.
OP_MAP = {
    'conv3x3-bn-relu': Conv3x3BnRelu,
    'conv1x1-bn-relu': Conv1x1BnRelu,
    'maxpool3x3': MaxPool3x3
}
63 |
--------------------------------------------------------------------------------
/nasbench_pytorch/model/graph_util.py:
--------------------------------------------------------------------------------
1 | # Copyright 2019 The Google Research Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Utility functions used by generate_graph.py."""
16 | from __future__ import absolute_import
17 | from __future__ import division
18 | from __future__ import print_function
19 |
20 | import hashlib
21 | import itertools
22 |
23 | import numpy as np
24 |
25 |
def gen_is_edge_fn(bits):
  """Generate a boolean function for the edge connectivity.

  Given a bitstring FEDCBA and a 4x4 matrix, the generated matrix is
    [[0, A, B, D],
     [0, 0, C, E],
     [0, 0, 0, F],
     [0, 0, 0, 0]]

  Note that this function is agnostic to the actual matrix dimension due to
  order in which elements are filled out (column-major, starting from least
  significant bit). For example, the same FEDCBA bitstring (0-padded) on a 5x5
  matrix is
    [[0, A, B, D, 0],
     [0, 0, C, E, 0],
     [0, 0, 0, F, 0],
     [0, 0, 0, 0, 0],
     [0, 0, 0, 0, 0]]

  Args:
    bits: integer which will be interpreted as a bit mask.

  Returns:
    vectorized function that returns True when an edge is present.
  """
  def is_edge(src, dst):
    """Is there an edge from src to dst (0-indexed)?"""
    # Lower triangle and diagonal are always empty.
    if src >= dst:
      return 0
    # Bits fill the strict upper triangle column-major, starting at the
    # least-significant bit.
    bit_pos = src + dst * (dst - 1) // 2
    return ((bits >> bit_pos) & 1) == 1

  return np.vectorize(is_edge)
60 |
61 |
def is_full_dag(matrix):
  """Full DAG == all vertices on a path from vert 0 to (V-1).

  i.e. no disconnected or "hanging" vertices.

  It is sufficient to check for:
    1) no rows of 0 except for row V-1 (only output vertex has no out-edges)
    2) no cols of 0 except for col 0 (only input vertex has no in-edges)

  Args:
    matrix: V x V upper-triangular adjacency matrix

  Returns:
    True if there are no dangling vertices.
  """
  num_rows = np.shape(matrix)[0]

  # Every vertex except the output must have at least one outgoing edge.
  has_out_edge = np.any(matrix[:num_rows - 1, :] != 0, axis=1)
  # Every vertex except the input must have at least one incoming edge.
  has_in_edge = np.any(matrix[:, 1:] != 0, axis=0)

  return bool(np.all(has_out_edge) and np.all(has_in_edge))
88 |
89 |
def num_edges(matrix):
  """Count the edges (nonzero entries) of an adjacency matrix."""
  total = np.sum(matrix)
  return total
93 |
94 |
def hash_module(matrix, labeling):
  """Computes a graph-invariance MD5 hash of the matrix and label pair.

  Args:
    matrix: np.ndarray square upper-triangular adjacency matrix.
    labeling: list of int labels of length equal to both dimensions of
      matrix.

  Returns:
    MD5 hash of the matrix and labeling.
  """
  def md5(text):
    # Hex digest of a single string.
    return hashlib.md5(text.encode('utf-8')).hexdigest()

  vertices = np.shape(matrix)[0]
  in_edges = np.sum(matrix, axis=0).tolist()
  out_edges = np.sum(matrix, axis=1).tolist()
  assert len(in_edges) == len(out_edges) == len(labeling)

  # Initial per-vertex hash: (out-degree, in-degree, label).
  hashes = [md5(str(triple)) for triple in zip(out_edges, in_edges, labeling)]

  # Iteratively mix each vertex hash with its neighbours' hashes. Running up
  # to the graph diameter would suffice, but |V| rounds is cheap and safe.
  for _ in range(vertices):
    updated = []
    for v in range(vertices):
      preds = sorted(hashes[u] for u in range(vertices) if matrix[u, v])
      succs = sorted(hashes[w] for w in range(vertices) if matrix[v, w])
      updated.append(md5(''.join(preds) + '|' + ''.join(succs) + '|' + hashes[v]))
    hashes = updated

  # Sorting makes the fingerprint invariant to vertex order.
  return md5(str(sorted(hashes)))
128 |
129 |
def permute_graph(graph, label, permutation):
  """Permutes the graph and labels based on permutation.

  Args:
    graph: np.ndarray adjacency matrix.
    label: list of labels of same length as graph dimensions.
    permutation: a permutation list of ints of same length as graph dimensions.

  Returns:
    np.ndarray where vertex permutation[v] is vertex v from the original graph
  """
  # inverse_perm[new_index] = old_index, i.e. the inverse of `permutation`.
  inverse_perm = [0] * len(permutation)
  for old_index, new_index in enumerate(permutation):
    inverse_perm[new_index] = old_index

  size = len(label)
  # Edge (i, j) exists in the new graph iff it existed between the
  # corresponding original vertices.
  new_matrix = np.fromfunction(
      np.vectorize(lambda i, j: graph[inverse_perm[i], inverse_perm[j]] == 1),
      (size, size),
      dtype=np.int8)
  new_label = [label[old] for old in inverse_perm]
  return new_matrix, new_label
150 |
151 |
def is_isomorphic(graph1, graph2):
  """Exhaustively checks if 2 graphs are isomorphic."""
  matrix1, label1 = np.array(graph1[0]), graph1[1]
  matrix2, label2 = np.array(graph2[0]), graph2[1]
  assert np.shape(matrix1) == np.shape(matrix2)
  assert len(label1) == len(label2)

  # Brute force: permute graph1 every possible way and compare against graph2.
  # Note: input and output in our constrained graphs always map to themselves
  # but this script does not enforce that.
  vertices = np.shape(matrix1)[0]
  return any(
      np.array_equal(pmatrix, matrix2) and plabel == label2
      for pmatrix, plabel in (
          permute_graph(matrix1, label1, perm)
          for perm in itertools.permutations(range(0, vertices))))
168 |
--------------------------------------------------------------------------------
/nasbench_pytorch/model/model.py:
--------------------------------------------------------------------------------
1 | """Builds the Pytorch computational graph.
2 |
3 | Tensors flowing into a single vertex are added together for all vertices
4 | except the output, which is concatenated instead. Tensors flowing out of input
5 | are always added.
6 |
7 | If interior edge channels don't match, drop the extra channels (channels are
guaranteed non-decreasing). Tensors flowing out of the input are always
projected instead.
10 | """
11 |
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | import numpy as np
17 | import math
18 |
19 | from nasbench_pytorch.model.base_ops import *
20 | from nasbench_pytorch.model.model_spec import ModelSpec
21 |
22 | import torch
23 | import torch.nn as nn
24 | from torch.nn.init import _calculate_fan_in_and_fan_out
25 |
26 |
class Network(nn.Module):
    def __init__(self, spec, num_labels=10, in_channels=3, stem_out_channels=128, num_stacks=3, num_modules_per_stack=3,
                 momentum=0.997, eps=1e-5, tf_like=False):
        """
        NAS-Bench-101 network: a stem convolution followed by `num_stacks`
        stacks of cells, with a 2x2 max-pool (and channel doubling) between
        consecutive stacks, ending in global average pooling and a linear head.

        Args:
            spec: ModelSpec from nasbench, or a tuple (adjacency matrix, ops)
            num_labels: Number of output labels.
            in_channels: Number of input image channels.
            stem_out_channels: Number of output stem channels. Other hidden channels are computed and depend on this
                number.

            num_stacks: Number of stacks, in every stacks the cells have the same number of channels.
            num_modules_per_stack: Number of cells per stack.
            momentum: BatchNorm momentum for the stem and cell input projections.
            eps: BatchNorm epsilon for the stem and cell input projections.
            tf_like: If True, initialize weights TensorFlow-style (fan-in normal
                for convolutions, Xavier uniform for the classifier).
        """
        super(Network, self).__init__()

        if isinstance(spec, tuple):
            spec = ModelSpec(spec[0], spec[1])

        # Indices into self.layers that hold Cell modules (stem/pools excluded).
        self.cell_indices = set()

        self.tf_like = tf_like
        self.layers = nn.ModuleList([])

        # initial stem convolution
        out_channels = stem_out_channels
        stem_conv = ConvBnRelu(in_channels, out_channels, 3, 1, 1, momentum=momentum, eps=eps)
        self.layers.append(stem_conv)

        # stacked cells
        in_channels = out_channels
        for stack_num in range(num_stacks):
            # downsample before every stack except the first, doubling channels
            if stack_num > 0:
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                self.layers.append(downsample)

                out_channels *= 2

            for module_num in range(num_modules_per_stack):
                cell = Cell(spec, in_channels, out_channels, momentum=momentum, eps=eps)
                self.layers.append(cell)
                in_channels = out_channels

                self.cell_indices.add(len(self.layers) - 1)

        self.classifier = nn.Linear(out_channels, num_labels)

        self._initialize_weights()

    def forward(self, x):
        """Run stem + cell stacks, then global average pool and classify."""
        for _, layer in enumerate(self.layers):
            x = layer(x)
        # Global average pooling over the spatial dimensions (H, W).
        out = torch.mean(x, (2, 3))
        out = self.classifier(out)

        return out

    def _initialize_weights(self):
        """Initialize conv/BN/linear weights; TF-style when self.tf_like is set."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                if self.tf_like:
                    # Normal init with std = 1 / sqrt(fan_in).
                    fan_in, _ = _calculate_fan_in_and_fan_out(m.weight)
                    torch.nn.init.normal_(m.weight, mean=0, std=1.0 / torch.sqrt(torch.tensor(fan_in)))
                else:
                    # He-style init scaled by the fan-out (k*k*out_channels).
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2.0 / n))

                if m.bias is not None:
                    m.bias.data.zero_()

            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                if self.tf_like:
                    torch.nn.init.xavier_uniform_(m.weight)
                else:
                    m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()
108 |
109 |
class Cell(nn.Module):
    """
    Builds the model using the adjacency matrix and op labels specified. Channels
    control the module output channel count but the interior channels are
    determined via equally splitting the channel count whenever there is a
    concatenation of Tensors.
    """
    def __init__(self, spec, in_channels, out_channels, momentum=0.1, eps=1e-5):
        super(Cell, self).__init__()

        # Empty parameter used only to look up which device the module is on.
        self.dev_param = nn.Parameter(torch.empty(0))

        self.matrix = spec.matrix
        self.num_vertices = np.shape(self.matrix)[0]

        # vertex_channels[i] = number of output channels of vertex i
        self.vertex_channels = compute_vertex_channels(in_channels, out_channels, self.matrix)
        #self.vertex_channels = [in_channels] + [out_channels] * (self.num_vertices - 1)

        # operation for each node (index 0 is the input vertex -> placeholder)
        self.vertex_op = nn.ModuleList([Placeholder()])
        for t in range(1, self.num_vertices-1):
            op = OP_MAP[spec.ops[t]](self.vertex_channels[t], self.vertex_channels[t])
            self.vertex_op.append(op)

        # operation for input on each vertex: a 1x1 projection where an edge
        # from the input exists, a placeholder otherwise
        self.input_op = nn.ModuleList([Placeholder()])
        for t in range(1, self.num_vertices):
            if self.matrix[0, t]:
                self.input_op.append(projection(in_channels, self.vertex_channels[t], momentum=momentum, eps=eps))
            else:
                self.input_op.append(Placeholder())

        # Input op of the output vertex (projection if input->output edge exists).
        self.last_inop: nn.Module = self.input_op[self.num_vertices - 1]

    def forward(self, x):
        # tensors[i] holds the output of vertex i; vertex 0 is the cell input.
        tensors = [x]

        out_concat = []
        # range(1, self.num_vertices - 1),
        for t, (inmod, outmod) in enumerate(zip(self.input_op, self.vertex_op)):
            if 0 < t < (self.num_vertices - 1):

                # Gather inputs from earlier interior vertices, truncating the
                # channels of sources that carry extra channels.
                fan_in = []
                for src in range(1, t):
                    if self.matrix[src, t]:
                        fan_in.append(truncate(tensors[src], torch.tensor(self.vertex_channels[t])))

                # The cell input is projected (never truncated) when connected.
                if self.matrix[0, t]:
                    l = inmod(x)
                    fan_in.append(l)

                # perform operation on node: sum all incoming tensors first
                vertex_input = torch.zeros_like(fan_in[0]).to(self.dev_param.device)
                for val in fan_in:
                    vertex_input += val

                vertex_output = outmod(vertex_input)

                tensors.append(vertex_output)
                if self.matrix[t, self.num_vertices-1]:
                    out_concat.append(tensors[t])

        if not out_concat:
            # No interior vertex feeds the output, so the input must connect
            # directly to the output; project it.
            assert self.matrix[0, self.num_vertices-1]
            outputs = self.last_inop(tensors[0])
        else:
            if len(out_concat) == 1:
                outputs = out_concat[0]
            else:
                # Concatenate along the channel dimension.
                outputs = torch.cat(out_concat, 1)

            # A direct input->output edge is added (not concatenated).
            if self.matrix[0, self.num_vertices-1]:
                outputs = outputs + self.last_inop(tensors[0])

        return outputs
186 |
187 |
def projection(in_channels, out_channels, momentum=0.1, eps=1e-5):
    """1x1 projection (as in ResNet) followed by batch normalization and ReLU."""
    return ConvBnRelu(in_channels, out_channels, kernel_size=1, momentum=momentum, eps=eps)
191 |
192 |
def truncate(inputs, channels):
    """Slice the inputs to channels if necessary."""
    input_channels = inputs.size()[1]

    # Guard clauses: growing channels is impossible, matching needs no work.
    if input_channels < channels:
        raise ValueError('input channel < output channels for truncate')
    if input_channels == channels:
        return inputs

    # Truncation should only happen when channel division leaves a vertex
    # with one extra channel; the input vertex is always projected down to
    # the minimum channel count instead.
    assert input_channels - channels == 1
    return inputs[:, :channels, :, :]
206 |
207 |
def compute_vertex_channels(in_channels, out_channels, matrix):
    """Computes the number of channels at every vertex.

    Given the input channels and output channels, this calculates the number of
    channels at each interior vertex. Interior vertices have the same number of
    channels as the max of the channels of the vertices it feeds into. The output
    channels are divided amongst the vertices that are directly connected to it.
    When the division is not even, some vertices may receive an extra channel to
    compensate.

    Code from https://github.com/google-research/nasbench/

    Returns:
      list of channel counts, in order of the vertices.
    """
    if isinstance(matrix, torch.Tensor):
        matrix = matrix.numpy()

    num_vertices = np.shape(matrix)[0]
    output_idx = num_vertices - 1

    channels = [0] * num_vertices
    channels[0] = in_channels
    channels[output_idx] = out_channels

    if num_vertices == 2:
        # Edge case: only input and output vertices, nothing interior.
        return channels

    # In-degree of every vertex, ignoring edges leaving the input vertex
    # (axis 0 is the source, axis 1 the destination).
    in_degree = np.sum(matrix[1:], axis=0)
    interior_channels = out_channels // in_degree[output_idx]
    correction = out_channels % in_degree[output_idx]  # leftover channels

    # Vertices feeding the output split its channels evenly; the first
    # `correction` of them get one extra channel each.
    for v in range(1, output_idx):
        if matrix[v, output_idx]:
            channels[v] = interior_channels
            if correction:
                channels[v] += 1
                correction -= 1

    # Remaining vertices take the max channel count of their successors,
    # walking backwards. Vertex output_idx - 1 is skipped: it can only feed
    # the output, so it was handled above.
    for v in range(output_idx - 2, 0, -1):
        if not matrix[v, output_idx]:
            for dst in range(v + 1, output_idx):
                if matrix[v, dst]:
                    channels[v] = max(channels[v], channels[dst])
        assert channels[v] > 0

    # Sanity check: channels never increase along an edge, and the channels
    # feeding the output sum exactly to out_channels.
    final_fan_in = 0
    for v in range(1, output_idx):
        if matrix[v, output_idx]:
            final_fan_in += channels[v]
        for dst in range(v + 1, output_idx):
            if matrix[v, dst]:
                assert channels[v] >= channels[dst]
    assert final_fan_in == out_channels or num_vertices == 2
    # num_vertices == 2 means only input/output nodes, so 0 fan-in

    return [int(c) for c in channels]
272 |
273 |
class Placeholder(torch.nn.Module):
    """Identity module holding one dummy scalar parameter.

    Used as a stand-in op where a slot in a ModuleList needs a module but no
    computation (e.g. the input vertex of a cell).
    """

    def __init__(self):
        super().__init__()
        # Scalar parameter so the module is registered with parameters.
        self.a = torch.nn.Parameter(torch.randn(()))

    def forward(self, inputs):
        # Pass the input through unchanged.
        return inputs
281 |
--------------------------------------------------------------------------------
/nasbench_pytorch/model/model_spec.py:
--------------------------------------------------------------------------------
1 |
2 | # Copyright 2019 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Model specification for module connectivity individuals.
17 |
18 | This module handles pruning the unused parts of the computation graph but should
19 | avoid creating any TensorFlow models (this is done inside model_builder.py).
20 | """
21 |
22 | from __future__ import absolute_import
23 | from __future__ import division
24 | from __future__ import print_function
25 |
26 | import copy
27 | import numpy as np
28 | import torch
29 |
30 | from nasbench_pytorch.model import graph_util
31 |
32 | # Graphviz is optional and only required for visualization.
33 | try:
34 | import graphviz # pylint: disable=g-import-not-at-top
35 | except ImportError:
36 | pass
37 |
38 |
class ModelSpec(object):
  """Model specification given adjacency matrix and labeling."""

  def __init__(self, matrix, ops, data_format='channels_last'):
    """Initialize the module spec.

    Args:
      matrix: ndarray or nested list with shape [V, V] for the adjacency matrix.
      ops: V-length list of labels for the base ops used. The first and last
        elements are ignored because they are the input and output vertices
        which have no operations. The elements are retained to keep consistent
        indexing.
      data_format: channels_last or channels_first.

    Raises:
      ValueError: invalid matrix or ops
    """

    if not isinstance(matrix, np.ndarray):
      matrix = np.array(matrix)
    shape = np.shape(matrix)
    if len(shape) != 2 or shape[0] != shape[1]:
      raise ValueError('matrix must be square')
    if shape[0] != len(ops):
      raise ValueError('length of ops must match matrix dimensions')
    if not is_upper_triangular(matrix):
      raise ValueError('matrix must be upper triangular')

    # Both the original and pruned matrices are deep copies of the matrix and
    # ops so any changes to those after initialization are not recognized by the
    # spec.
    self.original_matrix = copy.deepcopy(matrix)
    self.original_ops = copy.deepcopy(ops)

    self.matrix = copy.deepcopy(matrix)
    self.ops = copy.deepcopy(ops)
    # Set to False by _prune() when input and output are not connected.
    self.valid_spec = True
    self._prune()

    # Bug fix: an invalid spec has matrix/ops set to None by _prune(), and
    # torch.tensor(None) would raise instead of leaving valid_spec=False
    # for the caller to inspect. Only convert valid (pruned) matrices.
    if self.matrix is not None:
      self.matrix = torch.tensor(self.matrix)

    self.data_format = data_format

  def _prune(self):
    """Prune the extraneous parts of the graph.

    General procedure:
      1) Remove parts of graph not connected to input.
      2) Remove parts of graph not connected to output.
      3) Reorder the vertices so that they are consecutive after steps 1 and 2.

    These 3 steps can be combined by deleting the rows and columns of the
    vertices that are not reachable from both the input and output (in reverse).
    """
    num_vertices = np.shape(self.original_matrix)[0]

    # DFS forward from input
    visited_from_input = set([0])
    frontier = [0]
    while frontier:
      top = frontier.pop()
      for v in range(top + 1, num_vertices):
        if self.original_matrix[top, v] and v not in visited_from_input:
          visited_from_input.add(v)
          frontier.append(v)

    # DFS backward from output
    visited_from_output = set([num_vertices - 1])
    frontier = [num_vertices - 1]
    while frontier:
      top = frontier.pop()
      for v in range(0, top):
        if self.original_matrix[v, top] and v not in visited_from_output:
          visited_from_output.add(v)
          frontier.append(v)

    # Any vertex that isn't connected to both input and output is extraneous to
    # the computation graph.
    extraneous = set(range(num_vertices)).difference(
        visited_from_input.intersection(visited_from_output))

    # If the non-extraneous graph is less than 2 vertices, the input is not
    # connected to the output and the spec is invalid.
    if len(extraneous) > num_vertices - 2:
      self.matrix = None
      self.ops = None
      self.valid_spec = False
      return

    self.matrix = np.delete(self.matrix, list(extraneous), axis=0)
    self.matrix = np.delete(self.matrix, list(extraneous), axis=1)
    # Delete from the end so earlier indices stay valid.
    for index in sorted(extraneous, reverse=True):
      del self.ops[index]

  def hash_spec(self, canonical_ops):
    """Computes the isomorphism-invariant graph hash of this spec.

    Args:
      canonical_ops: list of operations in the canonical ordering which they
        were assigned (i.e. the order provided in the config['available_ops']).

    Returns:
      MD5 hash of this spec which can be used to query the dataset.
    """
    # Invert the operations back to integer label indices used in graph gen.
    # -1 and -2 are reserved for the input and output vertices.
    labeling = [-1] + [canonical_ops.index(op) for op in self.ops[1:-1]] + [-2]
    return graph_util.hash_module(self.matrix, labeling)

  def visualize(self):
    """Creates a dot graph. Can be visualized in colab directly."""
    num_vertices = np.shape(self.matrix)[0]
    g = graphviz.Digraph()
    g.node(str(0), 'input')
    for v in range(1, num_vertices - 1):
      g.node(str(v), self.ops[v])
    g.node(str(num_vertices - 1), 'output')

    for src in range(num_vertices - 1):
      for dst in range(src + 1, num_vertices):
        if self.matrix[src, dst]:
          g.edge(str(src), str(dst))

    return g


def is_upper_triangular(matrix):
  """True if matrix is 0 on diagonal and below."""
  for src in range(np.shape(matrix)[0]):
    # Check the diagonal entry and everything to its left.
    for dst in range(0, src + 1):
      if matrix[src, dst] != 0:
        return False

  return True
172 |
--------------------------------------------------------------------------------
/nasbench_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
def train(net, train_loader, loss=None, optimizer=None, scheduler=None, grad_clip=5, num_epochs=10,
          num_validation=None, validation_loader=None, device=None, print_frequency=200,
          checkpoint_every_k=None, checkpoint_func=None):
    """
    Train a network from the NAS-bench-101 search space on a dataset (`train_loader`).

    Args:
        net: Network to train.
        train_loader: Train data loader.
        loss: Loss, default is CrossEntropyLoss.
        optimizer: Optimizer, default is RMSprop; possible: 'sgd', 'rmsprop', 'adam', or an optimizer object.
        scheduler: Default is CosineAnnealingLR.
        grad_clip: Gradient clipping norm; pass None to disable clipping.
        num_epochs: Number of training epochs.
        num_validation: Number of validation examples (for print purposes).
        validation_loader: Optional validation set.
        device: Device to train on, default is cpu.
        print_frequency: How often to print info about batches.
        checkpoint_every_k: Every k epochs, save a checkpoint.
        checkpoint_func: Custom function to save the checkpoint, signature: func(net, metric_dict, epoch num)

    Returns:
        Final train (and validation) metrics as a dict of per-epoch lists.
    """

    net = net.to(device)

    # defaults
    if loss is None:
        loss = nn.CrossEntropyLoss()

    # Accept either a ready optimizer object or a string name (RMSprop default).
    if optimizer is not None and not isinstance(optimizer, str):
        pass
    elif optimizer is None or optimizer.lower() == 'rmsprop':
        optimizer = torch.optim.RMSprop(net.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4, eps=1.0)
    elif optimizer.lower() == 'sgd':
        optimizer = torch.optim.SGD(net.parameters(), lr=0.025, momentum=0.9, weight_decay=1e-4)
    elif optimizer.lower() == 'adam':
        optimizer = torch.optim.Adam(net.parameters())

    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)

    # training

    n_batches = len(train_loader)

    metric_dict = {'train_loss': [], 'train_accuracy': [], 'val_loss': [], 'val_accuracy': []}
    for epoch in range(num_epochs):
        # checkpoint using a user defined function
        # NOTE(review): this fires at the *start* of the epoch, so a checkpoint
        # labelled `epoch + 1` holds weights trained for only `epoch` epochs —
        # confirm this is intended.
        if checkpoint_every_k is not None and (epoch + 1) % checkpoint_every_k == 0:
            checkpoint_func(net, metric_dict, epoch + 1)

        net.train()

        # Running metrics accumulated on CPU tensors.
        train_loss = torch.tensor(0.0)
        correct = torch.tensor(0)
        total = 0

        batch_idx = 0
        for batch_idx, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # forward
            outputs = net(inputs)

            # back-propagation
            optimizer.zero_grad()
            curr_loss = loss(outputs, targets)
            curr_loss.backward()
            if grad_clip is not None:
                nn.utils.clip_grad_norm_(net.parameters(), grad_clip)
            optimizer.step()

            # metrics (detached and moved to CPU so no graph is kept alive)
            train_loss += curr_loss.detach().cpu()
            _, predict = torch.max(outputs.data, 1)
            total += targets.size(0)
            correct += predict.eq(targets.data).sum().detach().cpu()

            if (batch_idx % print_frequency) == 0:
                print(f'Epoch={epoch}/{num_epochs} Batch={batch_idx + 1}/{n_batches} | '
                      f'Loss={train_loss / (batch_idx + 1):.3f}, '
                      f'Acc={correct / total:.3f}({correct}/{total})')

        # Epoch averages: loss over batches, accuracy over examples.
        last_loss = train_loss / (batch_idx + 1)
        acc = correct / total

        # save metrics
        metric_dict['train_loss'].append(last_loss.item())
        metric_dict['train_accuracy'].append(acc.item())

        if validation_loader is not None:
            test_metrics = test(net, validation_loader, loss, num_tests=num_validation, device=device)
            metric_dict['val_loss'].append(test_metrics['test_loss'])
            metric_dict['val_accuracy'].append(test_metrics['test_accuracy'])

        print('--------------------')
        # Learning rate schedule steps once per epoch.
        scheduler.step()

    return metric_dict
108 |
109 |
110 | def test(net, test_loader, loss=None, num_tests=None, device=None):
111 | """
112 | Evaluate the network on a test set.
113 |
114 | Args:
115 | net: Network for testing.
116 | test_loader: Test dataset.
117 | loss: Loss function, default is CrossEntropyLoss.
118 | num_tests: Number of test examples (for print purposes).
119 | device: Device to use.
120 |
121 | Returns:
122 | Test metrics.
123 | """
124 | net = net.to(device)
125 | net.eval()
126 |
127 | if loss is None:
128 | loss = nn.CrossEntropyLoss()
129 |
130 | test_loss = 0
131 | correct = 0
132 | n_tests = 0
133 |
134 | with torch.no_grad():
135 | for batch_idx, (inputs, targets) in enumerate(test_loader):
136 | inputs, targets = inputs.to(device), targets.to(device)
137 |
138 | outputs = net(inputs)
139 |
140 | curr_loss = loss(outputs, targets)
141 | test_loss += curr_loss.detach()
142 | _, predict = torch.max(outputs.data, 1)
143 | correct += predict.eq(targets.data).sum().detach()
144 |
145 | if num_tests is None:
146 | n_tests += len(targets)
147 |
148 | if num_tests is None:
149 | num_tests = n_tests
150 |
151 | print(f'Testing: Loss={(test_loss / len(test_loader)):.3f}, Acc={(correct / num_tests):.3f}'
152 | f'({correct}/{num_tests})')
153 |
154 | last_loss = test_loss / len(test_loader) if len(test_loader) > 0 else np.inf
155 | acc = correct / num_tests
156 |
157 | return {'test_loss': last_loss.item(), 'test_accuracy': acc.item()}
158 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | long_description = file: README.md
3 | long_description_content_type = text/markdown
license_files = LICENSE.txt
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import setuptools

# Minimal packaging script; long_description and the license file are
# configured in setup.cfg.
setuptools.setup(
    name='nasbench_pytorch',
    version='1.3.1',
    license='Apache License 2.0',
    author='Romulus Hong, Gabriela Kadlecová',
    packages=setuptools.find_packages()
)
10 |
--------------------------------------------------------------------------------