├── .gitignore ├── LICENSE ├── README.md ├── assets ├── ENAS_cnn.png ├── ENAS_rnn.png ├── arial.ttf ├── best_rnn_epoch27.png ├── cnn.png ├── cnn_cell.png ├── ptb.gif ├── rnn.png └── wikitext.gif ├── config.py ├── dag.json ├── data ├── __init__.py ├── image.py ├── ptb │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── text.py └── wikitext │ ├── README │ ├── test.txt │ ├── train.txt │ └── valid.txt ├── generate_gif.py ├── main.py ├── models ├── __init__.py ├── controller.py ├── shared_base.py ├── shared_cnn.py └── shared_rnn.py ├── requirements.txt ├── run.sh ├── tensorboard.py ├── trainer.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Data 2 | *.png 3 | *.gif 4 | *.tar.gz 5 | data/cifar-10-batches-py 6 | 7 | # ipython checkpoints 8 | .ipynb_checkpoints 9 | 10 | # Log 11 | logs 12 | 13 | # ETC 14 | .vscode 15 | 16 | # Created by https://www.gitignore.io/api/python,vim 17 | 18 | ### Python ### 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | env/ 30 | build/ 31 | develop-eggs/ 32 | dist/ 33 | downloads/ 34 | eggs/ 35 | .eggs/ 36 | lib/ 37 | lib64/ 38 | parts/ 39 | sdist/ 40 | var/ 41 | wheels/ 42 | *.egg-info/ 43 | .installed.cfg 44 | *.egg 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *,cover 65 | .hypothesis/ 66 | 67 | # Translations 68 | *.mo 69 | *.pot 70 | 71 | # Django stuff: 72 | *.log 73 | local_settings.py 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # celery beat schedule file 95 | celerybeat-schedule 96 | 97 | # dotenv 98 | .env 99 | 100 | # virtualenv 101 | .venv/ 102 | venv/ 103 | ENV/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | 108 | # Rope project settings 109 | .ropeproject 110 | 111 | 112 | ### Vim ### 113 | # swap 114 | [._]*.s[a-v][a-z] 115 | [._]*.sw[a-p] 116 | [._]s[a-v][a-z] 117 | [._]sw[a-p] 118 | # session 119 | Session.vim 120 | # temporary 121 | .netrwhist 122 | *~ 123 | # auto-generated tag files 124 | tags 125 | 126 | # End of https://www.gitignore.io/api/python,vim 127 | main.sh 128 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Neural Architecture Search (ENAS) in PyTorch 2 | 3 | PyTorch implementation of [Efficient Neural Architecture Search via Parameters Sharing](https://arxiv.org/abs/1802.03268). 4 | 5 |

![ENAS_rnn](./assets/ENAS_rnn.png)

6 |
7 | **ENAS** reduces the computational requirement (GPU-hours) of [Neural Architecture Search](https://arxiv.org/abs/1611.01578) (**NAS**) by 1000x via parameter sharing between models that are subgraphs within a large computational graph. It achieves SOTA results on `Penn Treebank` language modeling.
8 |
9 | **\*\*[Caveat] Use official code from the authors: [link](https://github.com/melodyguan/enas)\*\***
10 |
11 |
12 | ## Prerequisites
13 |
14 | - Python 3.6+
15 | - [PyTorch==0.3.1](https://pytorch.org/get-started/previous-versions/)
16 | - tqdm, scipy, imageio, graphviz, tensorboardX
17 |
18 | ## Usage
19 |
20 | Install prerequisites with:
21 |
22 |     conda install graphviz
23 |     pip install -r requirements.txt
24 |
25 | To train **ENAS** to discover a recurrent cell for an RNN:
26 |
27 |     python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \
28 |         --shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001
29 |
30 |     python main.py --network_type rnn --dataset wikitext
31 |
32 | To train **ENAS** to discover a CNN architecture (in progress):
33 |
34 |     python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr_cosine=True \
35 |         --controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1
36 |
37 | You can also use your own dataset by placing files as follows:
38 |
39 |     data
40 |     ├── YOUR_TEXT_DATASET
41 |     │   ├── test.txt
42 |     │   ├── train.txt
43 |     │   └── valid.txt
44 |     ├── YOUR_IMAGE_DATASET
45 |     │   ├── test
46 |     │   │   ├── xxx.jpg (name doesn't matter)
47 |     │   │   ├── yyy.jpg (name doesn't matter)
48 |     │   │   └── ...
49 |     │   ├── train
50 |     │   │   ├── xxx.jpg
51 |     │   │   └── ...
52 |     │   └── valid
53 |     │       ├── xxx.jpg
54 |     │       └── ...
55 |     ├── image.py
56 |     └── text.py
57 |
58 | To generate a `gif` of the sampled architectures:
59 |
60 |     python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif
61 |
62 | More configurations can be found [here](config.py).
63 |
64 |
65 | ## Results
66 |
67 | Efficient Neural Architecture Search (**ENAS**) is composed of two sets of learnable parameters: the controller LSTM parameters *θ* and the shared model parameters *ω*. The two sets are trained alternately, and only the trained controller is then used to derive novel architectures.
68 |
69 | ### 1. Discovering Recurrent Cells
70 |
71 | ![rnn](./assets/rnn.png)
72 |
73 | The controller LSTM decides 1) which activation function to use and 2) which previous node to connect to.
74 |
75 | The RNN cells **ENAS** discovered for the `Penn Treebank` and `WikiText-2` datasets:
76 |
77 | ![ptb](./assets/ptb.gif) ![wikitext](./assets/wikitext.gif)
78 |
79 | Best discovered ENAS cell for `Penn Treebank` at epoch 27:
80 |
81 | ![ptb](./assets/best_rnn_epoch27.png)
82 |
83 | You can see the details of training (e.g. `reward`, `entropy`, `loss`) with:
84 |
85 |     tensorboard --logdir=logs --port=6006
86 |
87 |
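A discovered recurrent cell is saved as a JSON DAG (see [dag.json](dag.json)) that maps each parent node id to a list of `[child_id, activation]` pairs, and such a file can be trained on its own with `--mode single --dag_path dag.json`. A minimal sketch for inspecting one, assuming only the layout visible in `dag.json` (nodes `-1` and `-2` feed the first node from `x[t]` and `h[t-1]`, and leaf nodes are merged with `avg`):

    import json

    with open('dag.json') as f:
        dag = json.load(f)

    # Print every edge of the cell, sorted by parent node id.
    for parent, children in sorted(dag.items(), key=lambda kv: int(kv[0])):
        for child_id, activation in children:
            print(f'node {parent} -> node {child_id} [{activation}]')
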
88 | ### 2. Discovering Convolutional Neural Networks
89 |
90 | ![cnn](./assets/cnn.png)
91 |
92 | The controller LSTM samples 1) which computation operation to use and 2) which previous node to connect to.
93 |
94 | The CNN network **ENAS** discovered for the `CIFAR-10` dataset:
95 |
96 | (in progress)
97 |
98 |
99 | ### 3. Designing Convolutional Cells
100 |
101 | (in progress)
102 |
103 |
104 | ## Reference
105 |
106 | - [Neural Architecture Search with Reinforcement Learning](https://arxiv.org/abs/1611.01578)
107 | - [Neural Optimizer Search with Reinforcement Learning](https://arxiv.org/abs/1709.07417)
108 |
109 |
110 | ## Author
111 |
112 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
113 |
--------------------------------------------------------------------------------
/assets/ENAS_cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ENAS_cnn.png
--------------------------------------------------------------------------------
/assets/ENAS_rnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ENAS_rnn.png
--------------------------------------------------------------------------------
/assets/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/arial.ttf
--------------------------------------------------------------------------------
/assets/best_rnn_epoch27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/best_rnn_epoch27.png
--------------------------------------------------------------------------------
/assets/cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/cnn.png
--------------------------------------------------------------------------------
/assets/cnn_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/cnn_cell.png
--------------------------------------------------------------------------------
/assets/ptb.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ptb.gif
--------------------------------------------------------------------------------
/assets/rnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/rnn.png
--------------------------------------------------------------------------------
/assets/wikitext.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/wikitext.gif
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import get_logger
3 |
4 | logger = get_logger()
5 |
6 |
7 | arg_lists = []
8 | parser = argparse.ArgumentParser()
9 |
10 | def str2bool(v):
11 |     return v.lower() in ('true',)
12 |
13 | def add_argument_group(name):
14 |     arg = parser.add_argument_group(name)
15 | arg_lists.append(arg) 16 | return arg 17 | 18 | # Network 19 | net_arg = add_argument_group('Network') 20 | net_arg.add_argument('--network_type', type=str, choices=['rnn', 'cnn'], default='rnn') 21 | 22 | # Controller 23 | net_arg.add_argument('--num_blocks', type=int, default=12) 24 | net_arg.add_argument('--tie_weights', type=str2bool, default=True) 25 | net_arg.add_argument('--controller_hid', type=int, default=100) 26 | 27 | # Shared parameters for PTB 28 | # NOTE(brendan): See Merity config for wdrop 29 | # https://github.com/salesforce/awd-lstm-lm. 30 | net_arg.add_argument('--shared_wdrop', type=float, default=0.5) 31 | net_arg.add_argument('--shared_dropout', type=float, default=0.4) # TODO 32 | net_arg.add_argument('--shared_dropoute', type=float, default=0.1) # TODO 33 | net_arg.add_argument('--shared_dropouti', type=float, default=0.65) # TODO 34 | net_arg.add_argument('--shared_embed', type=int, default=1000) # TODO: 200, 500, 1000 35 | net_arg.add_argument('--shared_hid', type=int, default=1000) 36 | net_arg.add_argument('--shared_rnn_max_length', type=int, default=35) 37 | net_arg.add_argument('--shared_rnn_activations', type=eval, 38 | default="['tanh', 'ReLU', 'identity', 'sigmoid']") 39 | net_arg.add_argument('--shared_cnn_types', type=eval, 40 | default="['3x3', '5x5', 'sep 3x3', 'sep 5x5', 'max 3x3', 'max 5x5']") 41 | 42 | # PTB regularizations 43 | net_arg.add_argument('--activation_regularization', 44 | type=str2bool, 45 | default=False) 46 | net_arg.add_argument('--activation_regularization_amount', 47 | type=float, 48 | default=2.0) 49 | net_arg.add_argument('--temporal_activation_regularization', 50 | type=str2bool, 51 | default=False) 52 | net_arg.add_argument('--temporal_activation_regularization_amount', 53 | type=float, 54 | default=1.0) 55 | net_arg.add_argument('--norm_stabilizer_regularization', 56 | type=str2bool, 57 | default=False) 58 | net_arg.add_argument('--norm_stabilizer_regularization_amount', 59 | type=float, 60 | default=1.0) 61 | net_arg.add_argument('--norm_stabilizer_fixed_point', type=float, default=5.0) 62 | 63 | # Shared parameters for CIFAR 64 | net_arg.add_argument('--cnn_hid', type=int, default=64) 65 | 66 | 67 | # Data 68 | data_arg = add_argument_group('Data') 69 | data_arg.add_argument('--dataset', type=str, default='ptb') 70 | 71 | 72 | # Training / test parameters 73 | learn_arg = add_argument_group('Learning') 74 | learn_arg.add_argument('--mode', type=str, default='train', 75 | choices=['train', 'derive', 'test', 'single'], 76 | help='train: Training ENAS, derive: Deriving Architectures,\ 77 | single: training one dag') 78 | learn_arg.add_argument('--batch_size', type=int, default=64) 79 | learn_arg.add_argument('--test_batch_size', type=int, default=1) 80 | learn_arg.add_argument('--max_epoch', type=int, default=150) 81 | learn_arg.add_argument('--entropy_mode', type=str, default='reward', choices=['reward', 'regularizer']) 82 | 83 | 84 | # Controller 85 | learn_arg.add_argument('--ppl_square', type=str2bool, default=False) 86 | # NOTE(brendan): (Zoph and Le, 2017) page 8 states that c is a constant, 87 | # usually set at 80. 88 | learn_arg.add_argument('--reward_c', type=int, default=80, 89 | help="WE DON'T KNOW WHAT THIS VALUE SHOULD BE") # TODO 90 | # NOTE(brendan): irrelevant for actor critic. 
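# A sketch of how an EMA baseline is typically applied in this kind of
# REINFORCE setup (an assumption about the trainer, which is not shown here):
#   baseline = decay * baseline + (1 - decay) * reward
#   advantage = reward - baseline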
91 | learn_arg.add_argument('--ema_baseline_decay', type=float, default=0.95) # TODO: very important 92 | learn_arg.add_argument('--discount', type=float, default=1.0) # TODO 93 | learn_arg.add_argument('--controller_max_step', type=int, default=2000, 94 | help='step for controller parameters') 95 | learn_arg.add_argument('--controller_optim', type=str, default='adam') 96 | learn_arg.add_argument('--controller_lr', type=float, default=3.5e-4, 97 | help="will be ignored if --controller_lr_cosine=True") 98 | learn_arg.add_argument('--controller_lr_cosine', type=str2bool, default=False) 99 | learn_arg.add_argument('--controller_lr_max', type=float, default=0.05, 100 | help="lr max for cosine schedule") 101 | learn_arg.add_argument('--controller_lr_min', type=float, default=0.001, 102 | help="lr min for cosine schedule") 103 | learn_arg.add_argument('--controller_grad_clip', type=float, default=0) 104 | learn_arg.add_argument('--tanh_c', type=float, default=2.5) 105 | learn_arg.add_argument('--softmax_temperature', type=float, default=5.0) 106 | learn_arg.add_argument('--entropy_coeff', type=float, default=1e-4) 107 | 108 | # Shared parameters 109 | learn_arg.add_argument('--shared_initial_step', type=int, default=0) 110 | learn_arg.add_argument('--shared_max_step', type=int, default=400, 111 | help='step for shared parameters') 112 | # NOTE(brendan): Should be 10 for CNN architectures. 113 | learn_arg.add_argument('--shared_num_sample', type=int, default=1, 114 | help='# of Monte Carlo samples') 115 | learn_arg.add_argument('--shared_optim', type=str, default='sgd') 116 | learn_arg.add_argument('--shared_lr', type=float, default=20.0) 117 | learn_arg.add_argument('--shared_decay', type=float, default=0.96) 118 | learn_arg.add_argument('--shared_decay_after', type=float, default=15) 119 | learn_arg.add_argument('--shared_l2_reg', type=float, default=1e-7) 120 | learn_arg.add_argument('--shared_grad_clip', type=float, default=0.25) 121 | 122 | # Deriving Architectures 123 | learn_arg.add_argument('--derive_num_sample', type=int, default=100) 124 | 125 | 126 | # Misc 127 | misc_arg = add_argument_group('Misc') 128 | misc_arg.add_argument('--load_path', type=str, default='') 129 | misc_arg.add_argument('--log_step', type=int, default=50) 130 | misc_arg.add_argument('--save_epoch', type=int, default=4) 131 | misc_arg.add_argument('--max_save_num', type=int, default=4) 132 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN']) 133 | misc_arg.add_argument('--log_dir', type=str, default='logs') 134 | misc_arg.add_argument('--data_dir', type=str, default='data') 135 | misc_arg.add_argument('--num_gpu', type=int, default=1) 136 | misc_arg.add_argument('--random_seed', type=int, default=12345) 137 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True) 138 | misc_arg.add_argument('--dag_path', type=str, default='') 139 | 140 | def get_args(): 141 | """Parses all of the arguments above, which mostly correspond to the 142 | hyperparameters mentioned in the paper. 
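    Example (hypothetical values):

        args, unparsed = get_args()
        args.dataset  # 'ptb' by default
        args.cuda     # True iff --num_gpu > 0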
143 | """ 144 | args, unparsed = parser.parse_known_args() 145 | if args.num_gpu > 0: 146 | setattr(args, 'cuda', True) 147 | else: 148 | setattr(args, 'cuda', False) 149 | if len(unparsed) > 1: 150 | logger.info(f"Unparsed args: {unparsed}") 151 | return args, unparsed 152 | -------------------------------------------------------------------------------- /dag.json: -------------------------------------------------------------------------------- 1 | { 2 | "-1": [ 3 | [ 4 | 0, 5 | "tanh" 6 | ] 7 | ], 8 | "-2": [ 9 | [ 10 | 0, 11 | "tanh" 12 | ] 13 | ], 14 | "0": [ 15 | [ 16 | 1, 17 | "tanh" 18 | ] 19 | ], 20 | "1": [ 21 | [ 22 | 2, 23 | "ReLU" 24 | ], 25 | [ 26 | 3, 27 | "tanh" 28 | ] 29 | ], 30 | "2": [ 31 | [ 32 | 4, 33 | "tanh" 34 | ], 35 | [ 36 | 5, 37 | "tanh" 38 | ], 39 | [ 40 | 6, 41 | "ReLU" 42 | ] 43 | ], 44 | "4": [ 45 | [ 46 | 7, 47 | "ReLU" 48 | ] 49 | ], 50 | "7": [ 51 | [ 52 | 8, 53 | "ReLU" 54 | ] 55 | ], 56 | "8": [ 57 | [ 58 | 9, 59 | "ReLU" 60 | ], 61 | [ 62 | 10, 63 | "ReLU" 64 | ], 65 | [ 66 | 11, 67 | "ReLU" 68 | ] 69 | ], 70 | "3": [ 71 | [ 72 | 12, 73 | "avg" 74 | ] 75 | ], 76 | "5": [ 77 | [ 78 | 12, 79 | "avg" 80 | ] 81 | ], 82 | "6": [ 83 | [ 84 | 12, 85 | "avg" 86 | ] 87 | ], 88 | "9": [ 89 | [ 90 | 12, 91 | "avg" 92 | ] 93 | ], 94 | "10": [ 95 | [ 96 | 12, 97 | "avg" 98 | ] 99 | ], 100 | "11": [ 101 | [ 102 | 12, 103 | "avg" 104 | ] 105 | ], 106 | "12": [ 107 | [ 108 | 13, 109 | "h[t]" 110 | ] 111 | ] 112 | } -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import data.text 2 | import data.image 3 | -------------------------------------------------------------------------------- /data/image.py: -------------------------------------------------------------------------------- 1 | import torch as t 2 | import torchvision.datasets as datasets 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class Image(object): 7 | def __init__(self, args): 8 | if args.dataset == 'cifar10': 9 | Dataset = datasets.CIFAR10 10 | 11 | mean = [0.49139968, 0.48215827, 0.44653124] 12 | std = [0.24703233, 0.24348505, 0.26158768] 13 | 14 | normalize = transforms.Normalize(mean, std) 15 | 16 | transform = transforms.Compose([ 17 | transforms.RandomCrop(32, padding=4), 18 | transforms.RandomHorizontalFlip(), 19 | transforms.ToTensor(), 20 | normalize, 21 | ]) 22 | elif args.dataset == 'MNIST': 23 | Dataset = datasets.MNIST 24 | else: 25 | raise NotImplementedError(f'Unknown dataset: {args.dataset}') 26 | 27 | self.train = t.utils.data.DataLoader( 28 | Dataset(root='./data', train=True, transform=transform, download=True), 29 | batch_size=args.batch_size, shuffle=True, 30 | num_workers=args.num_workers, pin_memory=True) 31 | 32 | self.valid = t.utils.data.DataLoader( 33 | Dataset(root='./data', train=False, transform=transforms.Compose([ 34 | transforms.ToTensor(), 35 | normalize, 36 | ])), 37 | batch_size=args.batch_size, shuffle=False, 38 | num_workers=args.num_workers, pin_memory=True) 39 | 40 | self.test = self.valid 41 | -------------------------------------------------------------------------------- /data/text.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/salesforce/awd-lstm-lm 2 | import os 3 | import torch as t 4 | 5 | import collections 6 | 7 | 8 | class Dictionary(object): 9 | def __init__(self): 10 | self.word2idx = {} 11 | self.idx2word = [] 12 | self.counter = 
collections.Counter()
13 |         self.total = 0
14 |
15 |     def add_word(self, word):
16 |         if word not in self.word2idx:
17 |             self.idx2word.append(word)
18 |             self.word2idx[word] = len(self.idx2word) - 1
19 |
20 |         token_id = self.word2idx[word]
21 |         self.counter[token_id] += 1
22 |         self.total += 1
23 |
24 |         return token_id
25 |
26 |     def __len__(self):
27 |         return len(self.idx2word)
28 |
29 |
30 | class Corpus(object):
31 |     def __init__(self, path):
32 |         self.dictionary = Dictionary()
33 |         self.train = self.tokenize(os.path.join(path, 'train.txt'))
34 |         self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
35 |         self.test = self.tokenize(os.path.join(path, 'test.txt'))
36 |         self.num_tokens = len(self.dictionary)
37 |
38 |     def tokenize(self, path):
39 |         """Tokenizes a text file."""
40 |         assert os.path.exists(path)
41 |         # Add words to the dictionary
42 |         with open(path, 'r') as f:
43 |             tokens = 0
44 |             for line in f:
45 |                 words = line.split() + ['<eos>']
46 |                 tokens += len(words)
47 |                 for word in words:
48 |                     self.dictionary.add_word(word)
49 |
50 |         # Tokenize file content
51 |         with open(path, 'r') as f:
52 |             ids = t.LongTensor(tokens)
53 |             token = 0
54 |             for line in f:
55 |                 words = line.split() + ['<eos>']
56 |                 for word in words:
57 |                     ids[token] = self.dictionary.word2idx[word]
58 |                     token += 1
59 |
60 |         return ids
--------------------------------------------------------------------------------
/data/wikitext/README:
--------------------------------------------------------------------------------
1 | This is raw data from the wikitext-2 dataset.
2 |
3 | See https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
--------------------------------------------------------------------------------
/generate_gif.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | from glob import glob
5 |
6 | from utils import make_gif
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--model_name", type=str)
10 | parser.add_argument("--max_frame", type=int, default=50)
11 | parser.add_argument("--output", type=str, default="sample.gif")
12 | parser.add_argument("--title", type=str, default="")
13 |
14 | if __name__ == "__main__":
15 |     args = parser.parse_args()
16 |
17 |     paths = glob(f"./logs/{args.model_name}/networks/*.png")
18 |     make_gif(paths, args.output,
19 |              max_frame=args.max_frame,
20 |              prefix=f"{args.title}\n" if args.title else "")
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """Entry point."""
2 | import os
3 |
4 | import torch
5 |
6 | import data
7 | import config
8 | import utils
9 | import trainer
10 |
11 | logger = utils.get_logger()
12 |
13 |
14 | def main(args):  # pylint:disable=redefined-outer-name
15 |     """main: Entry point."""
16 |     utils.prepare_dirs(args)
17 |
18 |     torch.manual_seed(args.random_seed)
19 |
20 |     if args.num_gpu > 0:
21 |         torch.cuda.manual_seed(args.random_seed)
22 |
23 |     if args.network_type == 'rnn':
24 |         dataset = data.text.Corpus(args.data_path)
25 |     elif args.dataset == 'cifar':
26 |         dataset = data.image.Image(args)  # Image() takes the args object (see data/image.py)
27 |     else:
28 |         raise NotImplementedError(f"{args.dataset} is not supported")
29 |
30 |     trnr = trainer.Trainer(args, dataset)
31 |
32 |     if args.mode == 'train':
33 |         utils.save_args(args)
34 |         trnr.train()
35 |     elif args.mode == 'derive':
36 |         assert args.load_path != "", ("`--load_path` should be given in "
37 |                                       "`derive` mode")
38 |         trnr.derive()
39 |     elif args.mode == 'test':
40 |         if not args.load_path:
41 |             raise Exception("[!] You should specify `load_path` to load a "
42 |                             "pretrained model")
43 |         trnr.test()
44 |     elif args.mode == 'single':
45 |         if not args.dag_path:
46 |             raise Exception("[!] You should specify `dag_path` to load a dag")
47 |         utils.save_args(args)
48 |         trnr.train(single=True)
49 |     else:
50 |         raise Exception(f"[!] Mode not found: {args.mode}")
51 |
52 | if __name__ == "__main__":
53 |     args, unparsed = config.get_args()
54 |     main(args)
55 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.shared_rnn import RNN
2 | from models.shared_cnn import CNN
3 | from models.controller import Controller
--------------------------------------------------------------------------------
/models/controller.py:
--------------------------------------------------------------------------------
1 | """A module with NAS controller-related code."""
2 | import collections
3 | import os
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | import utils
9 | from utils import Node
10 |
11 |
12 | def _construct_dags(prev_nodes, activations, func_names, num_blocks):
13 |     """Constructs a set of DAGs based on the actions, i.e., previous nodes and
14 |     activation functions, sampled from the controller/policy pi.
15 |
16 |     Args:
17 |         prev_nodes: Previous node actions from the policy.
18 |         activations: Activations sampled from the policy.
19 |         func_names: List of the available activation function names.
20 |         num_blocks: Number of blocks in the target RNN cell.
21 |
22 |     Returns:
23 |         A list of DAGs defined by the inputs.
24 |
25 |     RNN cell DAGs are represented in the following way:
26 |
27 |     1. Each element (node) in a DAG is a list of `Node`s.
28 |
29 |     2. The `Node`s in the list dag[i] correspond to the subsequent nodes
30 |        that take the output from node i as their own input.
31 |
32 |     3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
33 |        dag[-1] always feeds dag[0].
34 |        dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
35 |        weights.
36 |
37 |     4. dag[N - 1] is the node that produces the hidden state passed to
38 |        the next timestep. dag[N - 1] is also always a leaf node, and therefore
39 |        is always averaged with the other leaf nodes and fed to the output
40 |        decoder.
41 |     """
42 |     dags = []
43 |     for nodes, func_ids in zip(prev_nodes, activations):
44 |         dag = collections.defaultdict(list)
45 |
46 |         # add first node
47 |         dag[-1] = [Node(0, func_names[func_ids[0]])]
48 |         dag[-2] = [Node(0, func_names[func_ids[0]])]
49 |
50 |         # add following nodes
51 |         for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
52 |             dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))
53 |
54 |         leaf_nodes = set(range(num_blocks)) - dag.keys()
55 |
56 |         # merge with avg
57 |         for idx in leaf_nodes:
58 |             dag[idx] = [Node(num_blocks, 'avg')]
59 |
60 |         # TODO(brendan): This is actually y^{(t)}. h^{(t)} is node N - 1 in
61 |         # the graph, where N is the number of nodes. I.e., h^{(t)} takes
62 |         # only one other node as its input.
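        # For example, with num_blocks = 2 and sampled actions
        # activations = [tanh, ReLU] and prev_nodes = [0], the code above
        # yields
        #   dag = {-1: [Node(0, 'tanh')], -2: [Node(0, 'tanh')],
        #          0: [Node(1, 'ReLU')], 1: [Node(2, 'avg')]}
        # and the h[t] node appended below adds dag[2] = [Node(3, 'h[t]')].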
63 |         # last h[t] node
64 |         last_node = Node(num_blocks + 1, 'h[t]')
65 |         dag[num_blocks] = [last_node]
66 |         dags.append(dag)
67 |
68 |     return dags
69 |
70 |
71 | class Controller(torch.nn.Module):
72 |     """Based on
73 |     https://github.com/pytorch/examples/blob/master/word_language_model/model.py
74 |
75 |     TODO(brendan): RL controllers do not necessarily have much to do with
76 |     language models.
77 |
78 |     Base the controller RNN on the GRU from:
79 |     https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
80 |     """
81 |     def __init__(self, args):
82 |         torch.nn.Module.__init__(self)
83 |         self.args = args
84 |
85 |         if self.args.network_type == 'rnn':
86 |             # NOTE(brendan): `num_tokens` holds the number of choices for each
87 |             #     decision: activations at even steps, previous nodes at odd steps.
88 |             self.num_tokens = [len(args.shared_rnn_activations)]
89 |             for idx in range(self.args.num_blocks):
90 |                 self.num_tokens += [idx + 1,
91 |                                     len(args.shared_rnn_activations)]
92 |             self.func_names = args.shared_rnn_activations
93 |         elif self.args.network_type == 'cnn':
94 |             self.num_tokens = [len(args.shared_cnn_types),
95 |                                self.args.num_blocks]
96 |             self.func_names = args.shared_cnn_types
97 |
98 |         num_total_tokens = sum(self.num_tokens)
99 |
100 |         self.encoder = torch.nn.Embedding(num_total_tokens,
101 |                                           args.controller_hid)
102 |         self.lstm = torch.nn.LSTMCell(args.controller_hid, args.controller_hid)
103 |
104 |         # TODO(brendan): Perhaps these weights in the decoder should be
105 |         # shared? At least for the activation functions, which all have the
106 |         # same size.
107 |         self.decoders = []
108 |         for idx, size in enumerate(self.num_tokens):
109 |             decoder = torch.nn.Linear(args.controller_hid, size)
110 |             self.decoders.append(decoder)
111 |
112 |         self._decoders = torch.nn.ModuleList(self.decoders)
113 |
114 |         self.reset_parameters()
115 |         self.static_init_hidden = utils.keydefaultdict(self.init_hidden)
116 |
117 |         def _get_default_hidden(key):
118 |             return utils.get_variable(
119 |                 torch.zeros(key, self.args.controller_hid),
120 |                 self.args.cuda,
121 |                 requires_grad=False)
122 |
123 |         self.static_inputs = utils.keydefaultdict(_get_default_hidden)
124 |
125 |     def reset_parameters(self):
126 |         init_range = 0.1
127 |         for param in self.parameters():
128 |             param.data.uniform_(-init_range, init_range)
129 |         for decoder in self.decoders:
130 |             decoder.bias.data.fill_(0)
131 |
132 |     def forward(self,  # pylint:disable=arguments-differ
133 |                 inputs,
134 |                 hidden,
135 |                 block_idx,
136 |                 is_embed):
137 |         if not is_embed:
138 |             embed = self.encoder(inputs)
139 |         else:
140 |             embed = inputs
141 |
142 |         hx, cx = self.lstm(embed, hidden)
143 |         logits = self.decoders[block_idx](hx)
144 |
145 |         logits /= self.args.softmax_temperature
146 |
147 |         # exploration
148 |         if self.args.mode == 'train':
149 |             logits = (self.args.tanh_c*F.tanh(logits))
150 |
151 |         return logits, (hx, cx)
152 |
153 |     def sample(self, batch_size=1, with_details=False, save_dir=None):
154 |         """Samples a set of `args.num_blocks` many computational nodes from the
155 |         controller, where each node is made up of an activation function, and
156 |         each node except the first also includes a previous node.
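        For example, with `num_blocks = 4` the controller makes
        2*(num_blocks - 1) + 1 = 7 decisions, alternating between activation
        and previous-node choices:

            act_0, prev_1, act_1, prev_2, act_2, prev_3, act_3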
157 | """ 158 | if batch_size < 1: 159 | raise Exception(f'Wrong batch_size: {batch_size} < 1') 160 | 161 | # [B, L, H] 162 | inputs = self.static_inputs[batch_size] 163 | hidden = self.static_init_hidden[batch_size] 164 | 165 | activations = [] 166 | entropies = [] 167 | log_probs = [] 168 | prev_nodes = [] 169 | # NOTE(brendan): The RNN controller alternately outputs an activation, 170 | # followed by a previous node, for each block except the last one, 171 | # which only gets an activation function. The last node is the output 172 | # node, and its previous node is the average of all leaf nodes. 173 | for block_idx in range(2*(self.args.num_blocks - 1) + 1): 174 | logits, hidden = self.forward(inputs, 175 | hidden, 176 | block_idx, 177 | is_embed=(block_idx == 0)) 178 | 179 | probs = F.softmax(logits, dim=-1) 180 | log_prob = F.log_softmax(logits, dim=-1) 181 | # TODO(brendan): .mean() for entropy? 182 | entropy = -(log_prob * probs).sum(1, keepdim=False) 183 | 184 | action = probs.multinomial(num_samples=1).data 185 | selected_log_prob = log_prob.gather( 186 | 1, utils.get_variable(action, requires_grad=False)) 187 | 188 | # TODO(brendan): why the [:, 0] here? Should it be .squeeze(), or 189 | # .view()? Same below with `action`. 190 | entropies.append(entropy) 191 | log_probs.append(selected_log_prob[:, 0]) 192 | 193 | # 0: function, 1: previous node 194 | mode = block_idx % 2 195 | inputs = utils.get_variable( 196 | action[:, 0] + sum(self.num_tokens[:mode]), 197 | requires_grad=False) 198 | 199 | if mode == 0: 200 | activations.append(action[:, 0]) 201 | elif mode == 1: 202 | prev_nodes.append(action[:, 0]) 203 | 204 | prev_nodes = torch.stack(prev_nodes).transpose(0, 1) 205 | activations = torch.stack(activations).transpose(0, 1) 206 | 207 | dags = _construct_dags(prev_nodes, 208 | activations, 209 | self.func_names, 210 | self.args.num_blocks) 211 | 212 | if save_dir is not None: 213 | for idx, dag in enumerate(dags): 214 | utils.draw_network(dag, 215 | os.path.join(save_dir, f'graph{idx}.png')) 216 | 217 | if with_details: 218 | return dags, torch.cat(log_probs), torch.cat(entropies) 219 | 220 | return dags 221 | 222 | def init_hidden(self, batch_size): 223 | zeros = torch.zeros(batch_size, self.args.controller_hid) 224 | return (utils.get_variable(zeros, self.args.cuda, requires_grad=False), 225 | utils.get_variable(zeros.clone(), self.args.cuda, requires_grad=False)) 226 | -------------------------------------------------------------------------------- /models/shared_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def size(p): 6 | return np.prod(p.size()) 7 | 8 | class SharedModel(torch.nn.Module): 9 | def __init__(self): 10 | torch.nn.Module.__init__(self) 11 | 12 | @property 13 | def num_parameters(self): 14 | return sum([size(param) for param in self.parameters()]) 15 | 16 | def get_f(self, name): 17 | raise NotImplementedError() 18 | 19 | def get_num_cell_parameters(self, dag): 20 | raise NotImplementedError() 21 | 22 | def reset_parameters(self): 23 | raise NotImplementedError() 24 | -------------------------------------------------------------------------------- /models/shared_cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict, deque 3 | 4 | import torch as t 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | from 
models.shared_base import *
10 | from utils import get_logger, get_variable, keydefaultdict
11 |
12 | logger = get_logger()
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 |     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 |                      padding=1, bias=False)
18 |
19 | def conv5x5(in_planes, out_planes, stride=1):
20 |     return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
21 |                      padding=2, bias=False)  # padding=2 preserves spatial size for a 5x5 kernel
22 |
23 | def conv(kernel, planes):
24 |     if kernel == 3:
25 |         _conv = conv3x3
26 |     elif kernel == 5:
27 |         _conv = conv5x5
28 |     else:
29 |         raise NotImplementedError(f"Unknown kernel size: {kernel}")
30 |
31 |     return nn.Sequential(
32 |         nn.ReLU(inplace=True),
33 |         _conv(planes, planes),
34 |         nn.BatchNorm2d(planes),
35 |     )
36 |
37 |
38 | class CNN(SharedModel):
39 |     def __init__(self, args, images):
40 |         super(CNN, self).__init__()
41 |
42 |         self.args = args
43 |         self.images = images
44 |
45 |         self.w_c, self.w_h = defaultdict(dict), defaultdict(dict)
46 |         self.reset_parameters()
47 |
48 |         self.conv = defaultdict(dict)
49 |         for idx in range(args.num_blocks):
50 |             for jdx in range(idx+1, args.num_blocks):
51 |                 self.conv[idx][jdx] = conv(3, args.cnn_hid)  # placeholder 3x3 op (WIP)
52 |
53 |         raise NotImplementedError("In progress...")
54 |
55 |     def forward(self, inputs, dag):
56 |         pass
57 |
58 |     def get_f(self, name):
59 |         name = name.lower()
60 |         raise NotImplementedError(f"get_f is not implemented for CNN yet: {name}")
61 |
62 |     def get_num_cell_parameters(self, dag):
63 |         pass
64 |
65 |     def reset_parameters(self):
66 |         pass
--------------------------------------------------------------------------------
/models/shared_rnn.py:
--------------------------------------------------------------------------------
1 | """Module containing the shared RNN model."""
2 | import numpy as np
3 | import collections
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | import models.shared_base
11 | import utils
12 |
13 |
14 | logger = utils.get_logger()
15 |
16 |
17 | def _get_dropped_weights(w_raw, dropout_p, is_training):
18 |     """Drops out weights to implement DropConnect.
19 |
20 |     Args:
21 |         w_raw: Full, pre-dropout, weights to be dropped out.
22 |         dropout_p: Proportion of weights to drop out.
23 |         is_training: True iff _shared_ model is training.
24 |
25 |     Returns:
26 |         The dropped weights.
27 |
28 |     TODO(brendan): Why does torch.nn.functional.dropout() return:
29 |     1. `torch.autograd.Variable()` on the training loop
30 |     2. `torch.nn.Parameter()` on the controller or eval loop, when
31 |     training = False...
32 |
33 |     Even though the call to `_setweights` in the Smerity repo's
34 |     `weight_drop.py` does not have this behaviour, and `F.dropout` always
35 |     returns `torch.autograd.Variable` there, even when `training=False`?
36 |
37 |     The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
38 |     """
39 |     dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)
40 |
41 |     if isinstance(dropped_w, torch.nn.Parameter):
42 |         dropped_w = dropped_w.clone()
43 |
44 |     return dropped_w
45 |
46 |
47 | def isnan(tensor):
48 |     return np.isnan(tensor.cpu().data.numpy()).sum() > 0
49 |
50 |
51 | class EmbeddingDropout(torch.nn.Embedding):
52 |     """Class for dropping out embeddings by zeroing out parameters in the
53 |     embedding matrix.
54 | 55 | This is equivalent to dropping out particular words, e.g., in the sentence 56 | 'the quick brown fox jumps over the lazy dog', dropping out 'the' would 57 | lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the 58 | embedding vector space). 59 | 60 | See 'A Theoretically Grounded Application of Dropout in Recurrent Neural 61 | Networks', (Gal and Ghahramani, 2016). 62 | """ 63 | def __init__(self, 64 | num_embeddings, 65 | embedding_dim, 66 | max_norm=None, 67 | norm_type=2, 68 | scale_grad_by_freq=False, 69 | sparse=False, 70 | dropout=0.1, 71 | scale=None): 72 | """Embedding constructor. 73 | 74 | Args: 75 | dropout: Dropout probability. 76 | scale: Used to scale parameters of embedding weight matrix that are 77 | not dropped out. Note that this is _in addition_ to the 78 | `1/(1 - dropout)` scaling. 79 | 80 | See `torch.nn.Embedding` for remaining arguments. 81 | """ 82 | torch.nn.Embedding.__init__(self, 83 | num_embeddings=num_embeddings, 84 | embedding_dim=embedding_dim, 85 | max_norm=max_norm, 86 | norm_type=norm_type, 87 | scale_grad_by_freq=scale_grad_by_freq, 88 | sparse=sparse) 89 | self.dropout = dropout 90 | assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 ' 91 | 'and < 1.0') 92 | self.scale = scale 93 | 94 | def forward(self, inputs): # pylint:disable=arguments-differ 95 | """Embeds `inputs` with the dropped out embedding weight matrix.""" 96 | if self.training: 97 | dropout = self.dropout 98 | else: 99 | dropout = 0 100 | 101 | if dropout: 102 | mask = self.weight.data.new(self.weight.size(0), 1) 103 | mask.bernoulli_(1 - dropout) 104 | mask = mask.expand_as(self.weight) 105 | mask = mask / (1 - dropout) 106 | masked_weight = self.weight * Variable(mask) 107 | else: 108 | masked_weight = self.weight 109 | if self.scale and self.scale != 1: 110 | masked_weight = masked_weight * self.scale 111 | 112 | return F.embedding(inputs, 113 | masked_weight, 114 | max_norm=self.max_norm, 115 | norm_type=self.norm_type, 116 | scale_grad_by_freq=self.scale_grad_by_freq, 117 | sparse=self.sparse) 118 | 119 | 120 | class LockedDropout(nn.Module): 121 | # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py 122 | def __init__(self): 123 | super().__init__() 124 | 125 | def forward(self, x, dropout=0.5): 126 | if not self.training or not dropout: 127 | return x 128 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout) 129 | mask = Variable(m, requires_grad=False) / (1 - dropout) 130 | mask = mask.expand_as(x) 131 | return mask * x 132 | 133 | 134 | class RNN(models.shared_base.SharedModel): 135 | """Shared RNN model.""" 136 | def __init__(self, args, corpus): 137 | models.shared_base.SharedModel.__init__(self) 138 | 139 | self.args = args 140 | self.corpus = corpus 141 | 142 | self.decoder = nn.Linear(args.shared_hid, corpus.num_tokens) 143 | self.encoder = EmbeddingDropout(corpus.num_tokens, 144 | args.shared_embed, 145 | dropout=args.shared_dropoute) 146 | self.lockdrop = LockedDropout() 147 | 148 | if self.args.tie_weights: 149 | self.decoder.weight = self.encoder.weight 150 | 151 | # NOTE(brendan): Since W^{x, c} and W^{h, c} are always summed, there 152 | # is no point duplicating their bias offset parameter. Likewise for 153 | # W^{x, h} and W^{h, h}. 
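        # Concretely, in cell() below the first node computes
        #   c_0 = sigmoid(w_xc(x) + h_prev @ w_hc^T)
        #   h_0 = c_0 * f_0(w_xh(x) + h_prev @ w_hh^T) + (1 - c_0) * h_prev
        # so the single bias for each sum lives in w_xc (and in w_xh).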
154 | self.w_xc = nn.Linear(args.shared_embed, args.shared_hid) 155 | self.w_xh = nn.Linear(args.shared_embed, args.shared_hid) 156 | 157 | # The raw weights are stored here because the hidden-to-hidden weights 158 | # are weight dropped on the forward pass. 159 | self.w_hc_raw = torch.nn.Parameter( 160 | torch.Tensor(args.shared_hid, args.shared_hid)) 161 | self.w_hh_raw = torch.nn.Parameter( 162 | torch.Tensor(args.shared_hid, args.shared_hid)) 163 | self.w_hc = None 164 | self.w_hh = None 165 | 166 | self.w_h = collections.defaultdict(dict) 167 | self.w_c = collections.defaultdict(dict) 168 | 169 | for idx in range(args.num_blocks): 170 | for jdx in range(idx + 1, args.num_blocks): 171 | self.w_h[idx][jdx] = nn.Linear(args.shared_hid, 172 | args.shared_hid, 173 | bias=False) 174 | self.w_c[idx][jdx] = nn.Linear(args.shared_hid, 175 | args.shared_hid, 176 | bias=False) 177 | 178 | self._w_h = nn.ModuleList([self.w_h[idx][jdx] 179 | for idx in self.w_h 180 | for jdx in self.w_h[idx]]) 181 | self._w_c = nn.ModuleList([self.w_c[idx][jdx] 182 | for idx in self.w_c 183 | for jdx in self.w_c[idx]]) 184 | 185 | if args.mode == 'train': 186 | self.batch_norm = nn.BatchNorm1d(args.shared_hid) 187 | else: 188 | self.batch_norm = None 189 | 190 | self.reset_parameters() 191 | self.static_init_hidden = utils.keydefaultdict(self.init_hidden) 192 | 193 | logger.info(f'# of parameters: {format(self.num_parameters, ",d")}') 194 | 195 | def forward(self, # pylint:disable=arguments-differ 196 | inputs, 197 | dag, 198 | hidden=None, 199 | is_train=True): 200 | time_steps = inputs.size(0) 201 | batch_size = inputs.size(1) 202 | 203 | is_train = is_train and self.args.mode in ['train'] 204 | 205 | self.w_hh = _get_dropped_weights(self.w_hh_raw, 206 | self.args.shared_wdrop, 207 | self.training) 208 | self.w_hc = _get_dropped_weights(self.w_hc_raw, 209 | self.args.shared_wdrop, 210 | self.training) 211 | 212 | if hidden is None: 213 | hidden = self.static_init_hidden[batch_size] 214 | 215 | embed = self.encoder(inputs) 216 | 217 | if self.args.shared_dropouti > 0: 218 | embed = self.lockdrop(embed, 219 | self.args.shared_dropouti if is_train else 0) 220 | 221 | # TODO(brendan): The norm of hidden states are clipped here because 222 | # otherwise ENAS is especially prone to exploding activations on the 223 | # forward pass. This could probably be fixed in a more elegant way, but 224 | # it might be exposing a weakness in the ENAS algorithm as currently 225 | # proposed. 226 | # 227 | # For more details, see 228 | # https://github.com/carpedm20/ENAS-pytorch/issues/6 229 | clipped_num = 0 230 | max_clipped_norm = 0 231 | h1tohT = [] 232 | logits = [] 233 | for step in range(time_steps): 234 | x_t = embed[step] 235 | logit, hidden = self.cell(x_t, hidden, dag) 236 | 237 | hidden_norms = hidden.norm(dim=-1) 238 | max_norm = 25.0 239 | if hidden_norms.data.max() > max_norm: 240 | # TODO(brendan): Just directly use the torch slice operations 241 | # in PyTorch v0.4. 242 | # 243 | # This workaround for PyTorch v0.3.1 does everything in numpy, 244 | # because the PyTorch slicing and slice assignment is too 245 | # flaky. 
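                # (Rows whose norm exceeds max_norm are rescaled below so
                # that their norm becomes exactly max_norm.)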
246 | hidden_norms = hidden_norms.data.cpu().numpy() 247 | 248 | clipped_num += 1 249 | if hidden_norms.max() > max_clipped_norm: 250 | max_clipped_norm = hidden_norms.max() 251 | 252 | clip_select = hidden_norms > max_norm 253 | clip_norms = hidden_norms[clip_select] 254 | 255 | mask = np.ones(hidden.size()) 256 | normalizer = max_norm/clip_norms 257 | normalizer = normalizer[:, np.newaxis] 258 | 259 | mask[clip_select] = normalizer 260 | hidden *= torch.autograd.Variable( 261 | torch.FloatTensor(mask).cuda(), requires_grad=False) 262 | 263 | logits.append(logit) 264 | h1tohT.append(hidden) 265 | 266 | if clipped_num > 0: 267 | logger.info(f'clipped {clipped_num} hidden states in one forward ' 268 | f'pass. ' 269 | f'max clipped hidden state norm: {max_clipped_norm}') 270 | 271 | h1tohT = torch.stack(h1tohT) 272 | output = torch.stack(logits) 273 | raw_output = output 274 | if self.args.shared_dropout > 0: 275 | output = self.lockdrop(output, 276 | self.args.shared_dropout if is_train else 0) 277 | 278 | dropped_output = output 279 | 280 | decoded = self.decoder( 281 | output.view(output.size(0)*output.size(1), output.size(2))) 282 | decoded = decoded.view(output.size(0), output.size(1), decoded.size(1)) 283 | 284 | extra_out = {'dropped': dropped_output, 285 | 'hiddens': h1tohT, 286 | 'raw': raw_output} 287 | return decoded, hidden, extra_out 288 | 289 | def cell(self, x, h_prev, dag): 290 | """Computes a single pass through the discovered RNN cell.""" 291 | c = {} 292 | h = {} 293 | f = {} 294 | 295 | f[0] = self.get_f(dag[-1][0].name) 296 | c[0] = F.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None)) 297 | h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) + 298 | (1 - c[0])*h_prev) 299 | 300 | leaf_node_ids = [] 301 | q = collections.deque() 302 | q.append(0) 303 | 304 | # NOTE(brendan): Computes connections from the parent nodes `node_id` 305 | # to their child nodes `next_id` recursively, skipping leaf nodes. A 306 | # leaf node is a node whose id == `self.args.num_blocks`. 307 | # 308 | # Connections between parent i and child j should be computed as 309 | # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i, 310 | # where c_j = \sigmoid{(W^c_{ij}*h_i)} 311 | # 312 | # See Training details from Section 3.1 of the paper. 313 | # 314 | # The following algorithm does a breadth-first (since `q.popleft()` is 315 | # used) search over the nodes and computes all the hidden states. 316 | while True: 317 | if len(q) == 0: 318 | break 319 | 320 | node_id = q.popleft() 321 | nodes = dag[node_id] 322 | 323 | for next_node in nodes: 324 | next_id = next_node.id 325 | if next_id == self.args.num_blocks: 326 | leaf_node_ids.append(node_id) 327 | assert len(nodes) == 1, ('parent of leaf node should have ' 328 | 'only one child') 329 | continue 330 | 331 | w_h = self.w_h[node_id][next_id] 332 | w_c = self.w_c[node_id][next_id] 333 | 334 | f[next_id] = self.get_f(next_node.name) 335 | c[next_id] = F.sigmoid(w_c(h[node_id])) 336 | h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) + 337 | (1 - c[next_id])*h[node_id]) 338 | 339 | q.append(next_id) 340 | 341 | # TODO(brendan): Instead of averaging loose ends, perhaps there should 342 | # be a set of separate unshared weights for each "loose" connection 343 | # between each node in a cell and the output. 344 | # 345 | # As it stands, all weights W^h_{ij} are doing double duty by 346 | # connecting both from i to j, as well as from i to the output. 
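        # The merge below computes output = mean_i(h_i) over all leaf nodes i,
        # i.e. an unweighted average of the loose-end hidden states.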
347 | 
348 |         # average all the loose ends
349 |         leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
350 |         output = torch.mean(torch.stack(leaf_nodes, 2), -1)
351 | 
352 |         # stabilize the updates of omega (the shared parameters)
353 |         if self.batch_norm is not None:
354 |             output = self.batch_norm(output)
355 | 
356 |         return output, h[self.args.num_blocks - 1]
357 | 
358 |     def init_hidden(self, batch_size):
359 |         zeros = torch.zeros(batch_size, self.args.shared_hid)
360 |         return utils.get_variable(zeros, self.args.cuda, requires_grad=False)
361 | 
362 |     def get_f(self, name):
363 |         name = name.lower()
364 |         if name == 'relu':
365 |             f = F.relu
366 |         elif name == 'tanh':
367 |             f = F.tanh
368 |         elif name == 'identity':
369 |             f = lambda x: x
370 |         elif name == 'sigmoid':
371 |             f = F.sigmoid
372 |         else:
373 |             raise NotImplementedError(f'Unknown activation: {name}')
374 |         return f
375 | 
376 |     def get_num_cell_parameters(self, dag):
377 |         num = 0
378 | 
379 |         num += models.shared_base.size(self.w_xc)
380 |         num += models.shared_base.size(self.w_xh)
381 | 
382 |         q = collections.deque()
383 |         q.append(0)
384 | 
385 |         while True:
386 |             if len(q) == 0:
387 |                 break
388 | 
389 |             node_id = q.popleft()
390 |             nodes = dag[node_id]
391 | 
392 |             for next_node in nodes:
393 |                 next_id = next_node.id
394 |                 if next_id == self.args.num_blocks:
395 |                     assert len(nodes) == 1, 'parent of leaf node should have only one child'
396 |                     continue
397 | 
398 |                 w_h = self.w_h[node_id][next_id]
399 |                 w_c = self.w_c[node_id][next_id]
400 | 
401 |                 num += models.shared_base.size(w_h)
402 |                 num += models.shared_base.size(w_c)
403 | 
404 |                 q.append(next_id)
405 | 
406 |         logger.debug(f'# of cell parameters: '
407 |                      f'{format(num, ",d")}')
408 |         return num
409 | 
410 |     def reset_parameters(self):
411 |         init_range = 0.025 if self.args.mode == 'train' else 0.04
412 |         for param in self.parameters():
413 |             param.data.uniform_(-init_range, init_range)
414 |         self.decoder.bias.data.fill_(0)
415 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipdb
2 | tqdm
3 | scipy==1.2.1
4 | imageio
5 | pygraphviz
6 | torch
7 | torchvision
8 | tensorboardX
9 | Pillow
10 | opencv-contrib-python
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Entropy decreases fast but holds steady, compared to the --ema=0.95 runs below
3 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=800 &
4 | #sleep 2
5 | 
6 | # BAD # alive until step 5k; entropy stayed high, so shared_loss was unstable
7 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=80 &
8 | #sleep 2
9 | 
10 | # BAD # entropy decreases fast, then increases continuously
11 | #python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=800 &
12 | #sleep 2
13 | 
14 | # BAD # explodes, but stays alive longer than the reward_c=800 run above
15 | # Entropy decreases fast, then increases continuously
16 | #python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=80 &
17 | #sleep 2
18 | 
19 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
20 | sleep 2
21 | 
22 | #python main.py --ema_baseline_decay=0.92 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
23 | sleep 2
24 | 
25 | python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
26 | sleep 2
27 | 
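The flags swept above feed the controller's reward, defined further down in trainer.py: `get_reward` computes `R = reward_c / valid_ppl` (or `reward_c / valid_ppl ** 2` when `--ppl_square` is set), and `train_controller` maintains an exponential moving-average baseline with decay `--ema_baseline_decay`. A minimal standalone sketch of that arithmetic, with made-up perplexity values:

# Standalone sketch (not part of the repo); the perplexities are illustrative.
reward_c, decay, ppl_square = 80, 0.95, True

baseline = None
for ppl in [900.0, 400.0, 150.0]:  # pretend valid_ppl at successive steps
    R = reward_c / ppl ** 2 if ppl_square else reward_c / ppl
    # moving-average baseline, as in Trainer.train_controller
    baseline = R if baseline is None else decay * baseline + (1 - decay) * R
    adv = R - baseline  # the advantage that REINFORCE actually scales
    print(f'ppl={ppl:6.1f} R={R:.6f} baseline={baseline:.6f} adv={adv:.6f}')

Note how a decay of 0.95 makes the baseline track the reward more slowly than 0.9, which is exactly the stability trade-off the commented-out runs were probing.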
--------------------------------------------------------------------------------
/tensorboard.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import scipy.misc
3 | from io import BytesIO
4 | import tensorboardX as tb
5 | from tensorboardX.summary import Summary
6 | 
7 | 
8 | class TensorBoard(object):
9 |     def __init__(self, model_dir):
10 |         self.summary_writer = tb.FileWriter(model_dir)
11 | 
12 |     def image_summary(self, tag, value, step):
13 |         for idx, img in enumerate(value):
14 |             summary = Summary()
15 |             bio = BytesIO()
16 | 
17 |             if type(img) == str:
18 |                 img = PIL.Image.open(img)
19 |             elif type(img) == PIL.Image.Image:
20 |                 pass
21 |             else:
22 |                 img = scipy.misc.toimage(img)
23 | 
24 |             img.save(bio, format="png")
25 |             image_summary = Summary.Image(encoded_image_string=bio.getvalue())
26 |             summary.value.add(tag=f"{tag}/{idx}", image=image_summary)
27 |             self.summary_writer.add_summary(summary, global_step=step)
28 | 
29 |     def scalar_summary(self, tag, value, step):
30 |         summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
31 |         self.summary_writer.add_summary(summary, global_step=step)
32 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | """The module for training ENAS."""
2 | import contextlib
3 | import glob
4 | import math
5 | import os
6 | 
7 | import numpy as np
8 | import scipy.signal
9 | from tensorboard import TensorBoard
10 | import torch
11 | from torch import nn
12 | import torch.nn.parallel
13 | from torch.autograd import Variable
14 | 
15 | import models
16 | import utils
17 | 
18 | 
19 | logger = utils.get_logger()
20 | 
21 | 
22 | def _apply_penalties(extra_out, args):
23 |     """Based on `args`, optionally adds regularization penalty terms for
24 |     activation regularization, temporal activation regularization and/or hidden
25 |     state norm stabilization.
26 | 
27 |     Args:
28 |         extra_out[*]:
29 |             dropped: Post-dropout activations.
30 |             hiddens: All hidden states for a batch of sequences.
31 |             raw: Pre-dropout activations.
32 | 
33 |     Returns:
34 |         The penalty term associated with all of the enabled regularizations.
35 | 
36 |     See:
37 |         Regularizing and Optimizing LSTM Language Models (Merity et al., 2017)
38 |         Regularizing RNNs by Stabilizing Activations (Krueger & Memisevic, 2016)
39 |     """
40 |     penalty = 0
41 | 
42 |     # Activation regularization.
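    # (i.e. penalty += alpha * mean(dropped ** 2) over the post-dropout
    # activations in extra_out['dropped']; the two penalties after it follow
    # the same pattern on the raw outputs and on the hidden-state norms.)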
43 |     if args.activation_regularization:
44 |         penalty += (args.activation_regularization_amount *
45 |                     extra_out['dropped'].pow(2).mean())
46 | 
47 |     # Temporal activation regularization (slowness)
48 |     if args.temporal_activation_regularization:
49 |         raw = extra_out['raw']
50 |         penalty += (args.temporal_activation_regularization_amount *
51 |                     (raw[1:] - raw[:-1]).pow(2).mean())
52 | 
53 |     # Norm stabilizer regularization
54 |     if args.norm_stabilizer_regularization:
55 |         penalty += (args.norm_stabilizer_regularization_amount *
56 |                     (extra_out['hiddens'].norm(dim=-1) -
57 |                      args.norm_stabilizer_fixed_point).pow(2).mean())
58 | 
59 |     return penalty
60 | 
61 | 
62 | def discount(x, amount):
63 |     return scipy.signal.lfilter([1], [1, -amount], x[::-1], axis=0)[::-1]
64 | 
65 | 
66 | def _get_optimizer(name):
67 |     if name.lower() == 'sgd':
68 |         return torch.optim.SGD
69 |     if name.lower() == 'adam':
70 |         return torch.optim.Adam
71 | 
72 |     raise NotImplementedError(f'Unknown optimizer: {name}')
73 | 
74 | 
75 | def _get_no_grad_ctx_mgr():
76 |     """Returns the `torch.no_grad` context manager for PyTorch version >=
77 |     0.4, or a no-op context manager otherwise.
78 |     """
79 |     if float(torch.__version__[0:3]) >= 0.4:
80 |         return torch.no_grad()
81 | 
82 |     return contextlib.suppress()
83 | 
84 | 
85 | def _check_abs_max_grad(abs_max_grad, model):
86 |     """Checks `model` for a new largest gradient for this epoch, in order to
87 |     track gradient explosions.
88 |     """
89 |     finite_grads = [p.grad.data
90 |                     for p in model.parameters()
91 |                     if p.grad is not None]
92 | 
93 |     new_max_grad = max([grad.max() for grad in finite_grads])
94 |     new_min_grad = min([grad.min() for grad in finite_grads])
95 | 
96 |     new_abs_max_grad = max(new_max_grad, abs(new_min_grad))
97 |     if new_abs_max_grad > abs_max_grad:
98 |         logger.info(f'abs max grad {new_abs_max_grad}')
99 |         return new_abs_max_grad
100 | 
101 |     return abs_max_grad
102 | 
103 | 
104 | class Trainer(object):
105 |     """A class to wrap training code."""
106 |     def __init__(self, args, dataset):
107 |         """Constructor for training algorithm.
108 | 
109 |         Args:
110 |             args: From command line, picked up by `argparse`.
111 |             dataset: Currently only `data.text.Corpus` is supported.
112 | 
113 |         Initializes:
114 |             - Data: train, val and test.
115 |             - Model: shared and controller.
116 |             - Inference: optimizers for shared and controller parameters.
117 |             - Criticism: cross-entropy loss for training the shared model.
118 | """ 119 | self.args = args 120 | self.controller_step = 0 121 | self.cuda = args.cuda 122 | self.dataset = dataset 123 | self.epoch = 0 124 | self.shared_step = 0 125 | self.start_epoch = 0 126 | 127 | logger.info('regularizing:') 128 | for regularizer in [('activation regularization', 129 | self.args.activation_regularization), 130 | ('temporal activation regularization', 131 | self.args.temporal_activation_regularization), 132 | ('norm stabilizer regularization', 133 | self.args.norm_stabilizer_regularization)]: 134 | if regularizer[1]: 135 | logger.info(f'{regularizer[0]}') 136 | 137 | self.train_data = utils.batchify(dataset.train, 138 | args.batch_size, 139 | self.cuda) 140 | # NOTE(brendan): The validation set data is batchified twice 141 | # separately: once for computing rewards during the Train Controller 142 | # phase (valid_data, batch size == 64), and once for evaluating ppl 143 | # over the entire validation set (eval_data, batch size == 1) 144 | self.valid_data = utils.batchify(dataset.valid, 145 | args.batch_size, 146 | self.cuda) 147 | self.eval_data = utils.batchify(dataset.valid, 148 | args.test_batch_size, 149 | self.cuda) 150 | self.test_data = utils.batchify(dataset.test, 151 | args.test_batch_size, 152 | self.cuda) 153 | 154 | self.max_length = self.args.shared_rnn_max_length 155 | 156 | if args.use_tensorboard: 157 | self.tb = TensorBoard(args.model_dir) 158 | else: 159 | self.tb = None 160 | self.build_model() 161 | 162 | if self.args.load_path: 163 | self.load_model() 164 | 165 | shared_optimizer = _get_optimizer(self.args.shared_optim) 166 | controller_optimizer = _get_optimizer(self.args.controller_optim) 167 | 168 | self.shared_optim = shared_optimizer( 169 | self.shared.parameters(), 170 | lr=self.shared_lr, 171 | weight_decay=self.args.shared_l2_reg) 172 | 173 | self.controller_optim = controller_optimizer( 174 | self.controller.parameters(), 175 | lr=self.args.controller_lr) 176 | 177 | self.ce = nn.CrossEntropyLoss() 178 | 179 | def build_model(self): 180 | """Creates and initializes the shared and controller models.""" 181 | if self.args.network_type == 'rnn': 182 | self.shared = models.RNN(self.args, self.dataset) 183 | elif self.args.network_type == 'cnn': 184 | self.shared = models.CNN(self.args, self.dataset) 185 | else: 186 | raise NotImplementedError(f'Network type ' 187 | f'`{self.args.network_type}` is not ' 188 | f'defined') 189 | self.controller = models.Controller(self.args) 190 | 191 | if self.args.num_gpu == 1: 192 | self.shared.cuda() 193 | self.controller.cuda() 194 | elif self.args.num_gpu > 1: 195 | raise NotImplementedError('`num_gpu > 1` is in progress') 196 | 197 | def train(self, single=False): 198 | """Cycles through alternately training the shared parameters and the 199 | controller, as described in Section 2.2, Training ENAS and Deriving 200 | Architectures, of the paper. 201 | 202 | From the paper (for Penn Treebank): 203 | 204 | - In the first phase, shared parameters omega are trained for 400 205 | steps, each on a minibatch of 64 examples. 206 | 207 | - In the second phase, the controller's parameters are trained for 2000 208 | steps. 209 | 210 | Args: 211 | single (bool): If True it won't train the controller and use the 212 | same dag instead of derive(). 
213 |         """
214 |         dag = utils.load_dag(self.args) if single else None
215 | 
216 |         if self.args.shared_initial_step > 0:
217 |             self.train_shared(self.args.shared_initial_step)
218 |             self.train_controller()
219 | 
220 |         for self.epoch in range(self.start_epoch, self.args.max_epoch):
221 |             # 1. Training the shared parameters omega of the child models
222 |             self.train_shared(dag=dag)
223 | 
224 |             # 2. Training the controller parameters theta
225 |             if not single:
226 |                 self.train_controller()
227 | 
228 |             if self.epoch % self.args.save_epoch == 0:
229 |                 with _get_no_grad_ctx_mgr():
230 |                     best_dag = dag if dag else self.derive()
231 |                     self.evaluate(self.eval_data,
232 |                                   best_dag,
233 |                                   'val_best',
234 |                                   max_num=self.args.batch_size*100)
235 |                 self.save_model()
236 | 
237 |             if self.epoch >= self.args.shared_decay_after:
238 |                 utils.update_lr(self.shared_optim, self.shared_lr)
239 | 
240 |     def get_loss(self, inputs, targets, hidden, dags):
241 |         """Computes the loss for the same batch for M models.
242 | 
243 |         This amounts to an estimate of the loss, which is turned into an
244 |         estimate for the gradients of the shared model.
245 |         """
246 |         if not isinstance(dags, list):
247 |             dags = [dags]
248 | 
249 |         loss = 0
250 |         for dag in dags:
251 |             output, hidden, extra_out = self.shared(inputs, dag, hidden=hidden)
252 |             output_flat = output.view(-1, self.dataset.num_tokens)
253 |             sample_loss = (self.ce(output_flat, targets) /
254 |                            self.args.shared_num_sample)
255 |             loss += sample_loss
256 | 
257 |         assert len(dags) == 1, 'there are multiple `hidden` states for multiple `dags`'
258 |         return loss, hidden, extra_out
259 | 
260 |     def train_shared(self, max_step=None, dag=None):
261 |         """Train the language model for 400 steps of minibatches of 64
262 |         examples.
263 | 
264 |         Args:
265 |             max_step: Used to run extra training steps as a warm-up.
266 |             dag: If not None, is used instead of calling sample().
267 | 
268 |         BPTT is truncated at 35 timesteps.
269 | 
270 |         For each weight update, gradients are estimated by sampling M models
271 |         from the fixed controller policy, and averaging their gradients
272 |         computed on a batch of training data.
273 |         """
274 |         model = self.shared
275 |         model.train()
276 |         self.controller.eval()
277 | 
278 |         hidden = self.shared.init_hidden(self.args.batch_size)
279 | 
280 |         if max_step is None:
281 |             max_step = self.args.shared_max_step
282 |         else:
283 |             max_step = min(self.args.shared_max_step, max_step)
284 | 
285 |         abs_max_grad = 0
286 |         abs_max_hidden_norm = 0
287 |         step = 0
288 |         raw_total_loss = 0
289 |         total_loss = 0
290 |         train_idx = 0
291 |         # TODO(brendan): Why - 1 - 1? (One -1 leaves room for the one-step-ahead targets; the second looks like an extra safety margin.)
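        # Illustrative indexing: with max_length=35 the loop visits
        # train_idx = 0, 35, 70, ...; get_batch() then takes inputs
        # data[idx:idx + 35] and targets data[idx + 1:idx + 36], which is
        # why the loop must stop short of the last position.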
292 |         while train_idx < self.train_data.size(0) - 1 - 1:
293 |             if step > max_step:
294 |                 break
295 | 
296 |             dags = dag if dag else self.controller.sample(
297 |                 self.args.shared_num_sample)
298 |             inputs, targets = self.get_batch(self.train_data,
299 |                                              train_idx,
300 |                                              self.max_length)
301 | 
302 |             loss, hidden, extra_out = self.get_loss(inputs,
303 |                                                     targets,
304 |                                                     hidden,
305 |                                                     dags)
306 |             hidden.detach_()
307 |             raw_total_loss += loss.data
308 | 
309 |             loss += _apply_penalties(extra_out, self.args)
310 | 
311 |             # update
312 |             self.shared_optim.zero_grad()
313 |             loss.backward()
314 | 
315 |             h1tohT = extra_out['hiddens']
316 |             new_abs_max_hidden_norm = utils.to_item(
317 |                 h1tohT.norm(dim=-1).data.max())
318 |             if new_abs_max_hidden_norm > abs_max_hidden_norm:
319 |                 abs_max_hidden_norm = new_abs_max_hidden_norm
320 |                 logger.info(f'max hidden {abs_max_hidden_norm}')
321 |             abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
322 |             torch.nn.utils.clip_grad_norm(model.parameters(),
323 |                                           self.args.shared_grad_clip)
324 |             self.shared_optim.step()
325 | 
326 |             total_loss += loss.data
327 | 
328 |             if ((step % self.args.log_step) == 0) and (step > 0):
329 |                 self._summarize_shared_train(total_loss, raw_total_loss)
330 |                 raw_total_loss = 0
331 |                 total_loss = 0
332 | 
333 |             step += 1
334 |             self.shared_step += 1
335 |             train_idx += self.max_length
336 | 
337 |     def get_reward(self, dag, entropies, hidden, valid_idx=0):
338 |         """Computes the perplexity of a single sampled model on a minibatch of
339 |         validation data.
340 |         """
341 |         if not isinstance(entropies, np.ndarray):
342 |             entropies = entropies.data.cpu().numpy()
343 | 
344 |         inputs, targets = self.get_batch(self.valid_data,
345 |                                          valid_idx,
346 |                                          self.max_length,
347 |                                          volatile=True)
348 |         valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
349 |         valid_loss = utils.to_item(valid_loss.data)
350 | 
351 |         valid_ppl = math.exp(valid_loss)
352 | 
353 |         # TODO: we don't know reward_c
354 |         if self.args.ppl_square:
355 |             # TODO: but we do know reward_c=80 in the previous paper
356 |             R = self.args.reward_c / valid_ppl ** 2
357 |         else:
358 |             R = self.args.reward_c / valid_ppl
359 | 
360 |         if self.args.entropy_mode == 'reward':
361 |             rewards = R + self.args.entropy_coeff * entropies
362 |         elif self.args.entropy_mode == 'regularizer':
363 |             rewards = R * np.ones_like(entropies)
364 |         else:
365 |             raise NotImplementedError(f'Unknown entropy mode: {self.args.entropy_mode}')
366 | 
367 |         return rewards, hidden
368 | 
369 |     def train_controller(self):
370 |         """Fixes the shared parameters and updates the controller parameters.
371 | 
372 |         The controller is updated with a score function gradient estimator
373 |         (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
374 |         is computed on a minibatch of validation data.
375 | 
376 |         A moving average baseline is used.
377 | 
378 |         The controller is trained for 2000 steps per epoch (i.e.,
379 |         first (Train Shared) phase -> second (Train Controller) phase).
380 |         """
381 |         model = self.controller
382 |         model.train()
383 |         # TODO(brendan): Why can't we call shared.eval() here? Leads to loss
384 |         # being uniformly zero for the controller.
385 | # self.shared.eval() 386 | 387 | avg_reward_base = None 388 | baseline = None 389 | adv_history = [] 390 | entropy_history = [] 391 | reward_history = [] 392 | 393 | hidden = self.shared.init_hidden(self.args.batch_size) 394 | total_loss = 0 395 | valid_idx = 0 396 | for step in range(self.args.controller_max_step): 397 | # sample models 398 | dags, log_probs, entropies = self.controller.sample( 399 | with_details=True) 400 | 401 | # calculate reward 402 | np_entropies = entropies.data.cpu().numpy() 403 | # NOTE(brendan): No gradients should be backpropagated to the 404 | # shared model during controller training, obviously. 405 | with _get_no_grad_ctx_mgr(): 406 | rewards, hidden = self.get_reward(dags, 407 | np_entropies, 408 | hidden, 409 | valid_idx) 410 | 411 | # discount 412 | if 1 > self.args.discount > 0: 413 | rewards = discount(rewards, self.args.discount) 414 | 415 | reward_history.extend(rewards) 416 | entropy_history.extend(np_entropies) 417 | 418 | # moving average baseline 419 | if baseline is None: 420 | baseline = rewards 421 | else: 422 | decay = self.args.ema_baseline_decay 423 | baseline = decay * baseline + (1 - decay) * rewards 424 | 425 | adv = rewards - baseline 426 | adv_history.extend(adv) 427 | 428 | # policy loss 429 | loss = -log_probs*utils.get_variable(adv, 430 | self.cuda, 431 | requires_grad=False) 432 | if self.args.entropy_mode == 'regularizer': 433 | loss -= self.args.entropy_coeff * entropies 434 | 435 | loss = loss.sum() # or loss.mean() 436 | 437 | # update 438 | self.controller_optim.zero_grad() 439 | loss.backward() 440 | 441 | if self.args.controller_grad_clip > 0: 442 | torch.nn.utils.clip_grad_norm(model.parameters(), 443 | self.args.controller_grad_clip) 444 | self.controller_optim.step() 445 | 446 | total_loss += utils.to_item(loss.data) 447 | 448 | if ((step % self.args.log_step) == 0) and (step > 0): 449 | self._summarize_controller_train(total_loss, 450 | adv_history, 451 | entropy_history, 452 | reward_history, 453 | avg_reward_base, 454 | dags) 455 | 456 | reward_history, adv_history, entropy_history = [], [], [] 457 | total_loss = 0 458 | 459 | self.controller_step += 1 460 | 461 | prev_valid_idx = valid_idx 462 | valid_idx = ((valid_idx + self.max_length) % 463 | (self.valid_data.size(0) - 1)) 464 | # NOTE(brendan): Whenever we wrap around to the beginning of the 465 | # validation data, we reset the hidden states. 466 | if prev_valid_idx > valid_idx: 467 | hidden = self.shared.init_hidden(self.args.batch_size) 468 | 469 | def evaluate(self, source, dag, name, batch_size=1, max_num=None): 470 | """Evaluate on the validation set. 471 | 472 | NOTE(brendan): We should not be using the test set to develop the 473 | algorithm (basic machine learning good practices). 
474 |         """
475 |         self.shared.eval()
476 |         self.controller.eval()
477 | 
478 |         data = source[:max_num * self.max_length] if max_num else source
479 | 
480 |         total_loss = 0
481 |         hidden = self.shared.init_hidden(batch_size)
482 | 
483 |         pbar = range(0, data.size(0) - 1, self.max_length)
484 |         for count, idx in enumerate(pbar):
485 |             inputs, targets = self.get_batch(data, idx, volatile=True)
486 |             output, hidden, _ = self.shared(inputs,
487 |                                             dag,
488 |                                             hidden=hidden,
489 |                                             is_train=False)
490 |             output_flat = output.view(-1, self.dataset.num_tokens)
491 |             total_loss += len(inputs) * self.ce(output_flat, targets).data
492 |             hidden.detach_()
493 |             # (no running ppl here: the final ppl is computed once below from total_loss)
494 | 
495 |         val_loss = utils.to_item(total_loss) / len(data)
496 |         ppl = math.exp(val_loss)
497 | 
498 |         self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
499 |         self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
500 |         logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')
501 | 
502 |     def derive(self, sample_num=None, valid_idx=0):
503 |         """TODO(brendan): We are always deriving based on the very first batch
504 |         of validation data? This seems wrong...
505 |         """
506 |         hidden = self.shared.init_hidden(self.args.batch_size)
507 | 
508 |         if sample_num is None:
509 |             sample_num = self.args.derive_num_sample
510 | 
511 |         dags, _, entropies = self.controller.sample(sample_num,
512 |                                                     with_details=True)
513 | 
514 |         max_R = 0
515 |         best_dag = None
516 |         for dag in dags:
517 |             R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
518 |             if R.max() > max_R:
519 |                 max_R = R.max()
520 |                 best_dag = dag
521 | 
522 |         logger.info(f'derive | max_R: {max_R:8.6f}')
523 |         fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
524 |                  f'{max_R:6.4f}-best.png')
525 |         path = os.path.join(self.args.model_dir, 'networks', fname)
526 |         utils.draw_network(best_dag, path)
527 |         self.tb.image_summary('derive/best', [path], self.epoch)
528 | 
529 |         return best_dag
530 | 
531 |     @property
532 |     def shared_lr(self):
533 |         degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
534 |         return self.args.shared_lr * (self.args.shared_decay ** degree)
535 | 
536 |     @property
537 |     def controller_lr(self):
538 |         return self.args.controller_lr
539 | 
540 |     def get_batch(self, source, idx, length=None, volatile=False):
541 |         # code from
542 |         # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
543 |         length = min(length if length else self.max_length,
544 |                      len(source) - 1 - idx)
545 |         data = Variable(source[idx:idx + length], volatile=volatile)
546 |         target = Variable(source[idx + 1:idx + 1 + length].view(-1),
547 |                           volatile=volatile)
548 |         return data, target
549 | 
550 |     @property
551 |     def shared_path(self):
552 |         return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'
553 | 
554 |     @property
555 |     def controller_path(self):
556 |         return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'
557 | 
558 |     def get_saved_models_info(self):
559 |         paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
560 |         paths.sort()
561 | 
562 |         def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
563 |             return list(set([int(
564 |                 name.split(delimiter)[idx].replace(replace_word, ''))
565 |                 for name in basenames if must_contain in name]))
566 | 
567 |         basenames = [os.path.basename(path.rsplit('.', 1)[0]) for path in paths]
568 |         epochs = get_numbers(basenames, '_', 1, 'epoch')
569 |         shared_steps = get_numbers(basenames, '_', 2,
'step', 'shared') 570 | controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller') 571 | 572 | epochs.sort() 573 | shared_steps.sort() 574 | controller_steps.sort() 575 | 576 | return epochs, shared_steps, controller_steps 577 | 578 | def save_model(self): 579 | torch.save(self.shared.state_dict(), self.shared_path) 580 | logger.info(f'[*] SAVED: {self.shared_path}') 581 | 582 | torch.save(self.controller.state_dict(), self.controller_path) 583 | logger.info(f'[*] SAVED: {self.controller_path}') 584 | 585 | epochs, shared_steps, controller_steps = self.get_saved_models_info() 586 | 587 | for epoch in epochs[:-self.args.max_save_num]: 588 | paths = glob.glob( 589 | os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth')) 590 | 591 | for path in paths: 592 | utils.remove_file(path) 593 | 594 | def load_model(self): 595 | epochs, shared_steps, controller_steps = self.get_saved_models_info() 596 | 597 | if len(epochs) == 0: 598 | logger.info(f'[!] No checkpoint found in {self.args.model_dir}...') 599 | return 600 | 601 | self.epoch = self.start_epoch = max(epochs) 602 | self.shared_step = max(shared_steps) 603 | self.controller_step = max(controller_steps) 604 | 605 | if self.args.num_gpu == 0: 606 | map_location = lambda storage, loc: storage 607 | else: 608 | map_location = None 609 | 610 | self.shared.load_state_dict( 611 | torch.load(self.shared_path, map_location=map_location)) 612 | logger.info(f'[*] LOADED: {self.shared_path}') 613 | 614 | self.controller.load_state_dict( 615 | torch.load(self.controller_path, map_location=map_location)) 616 | logger.info(f'[*] LOADED: {self.controller_path}') 617 | 618 | def _summarize_controller_train(self, 619 | total_loss, 620 | adv_history, 621 | entropy_history, 622 | reward_history, 623 | avg_reward_base, 624 | dags): 625 | """Logs the controller's progress for this training epoch.""" 626 | cur_loss = total_loss / self.args.log_step 627 | 628 | avg_adv = np.mean(adv_history) 629 | avg_entropy = np.mean(entropy_history) 630 | avg_reward = np.mean(reward_history) 631 | 632 | if avg_reward_base is None: 633 | avg_reward_base = avg_reward 634 | 635 | logger.info( 636 | f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} ' 637 | f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} ' 638 | f'| loss {cur_loss:.5f}') 639 | 640 | # Tensorboard 641 | if self.tb is not None: 642 | self.tb.scalar_summary('controller/loss', 643 | cur_loss, 644 | self.controller_step) 645 | self.tb.scalar_summary('controller/reward', 646 | avg_reward, 647 | self.controller_step) 648 | self.tb.scalar_summary('controller/reward-B_per_epoch', 649 | avg_reward - avg_reward_base, 650 | self.controller_step) 651 | self.tb.scalar_summary('controller/entropy', 652 | avg_entropy, 653 | self.controller_step) 654 | self.tb.scalar_summary('controller/adv', 655 | avg_adv, 656 | self.controller_step) 657 | 658 | paths = [] 659 | for dag in dags: 660 | fname = (f'{self.epoch:03d}-{self.controller_step:06d}-' 661 | f'{avg_reward:6.4f}.png') 662 | path = os.path.join(self.args.model_dir, 'networks', fname) 663 | utils.draw_network(dag, path) 664 | paths.append(path) 665 | 666 | self.tb.image_summary('controller/sample', 667 | paths, 668 | self.controller_step) 669 | 670 | def _summarize_shared_train(self, total_loss, raw_total_loss): 671 | """Logs a set of training steps.""" 672 | cur_loss = utils.to_item(total_loss) / self.args.log_step 673 | # NOTE(brendan): The raw loss, without adding in the activation 674 | # regularization terms, should be used to compute ppl. 
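        # For scale (illustrative): a raw loss of 4.60 nats/token gives
        # ppl = exp(4.60) ~= 99.5; exponentiating the penalized loss instead
        # would overstate the perplexity, which is why train_shared() tracks
        # raw_total_loss separately.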
675 |         cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
676 |         ppl = math.exp(cur_raw_loss)
677 | 
678 |         logger.info(f'| epoch {self.epoch:3d} '
679 |                     f'| lr {self.shared_lr:4.2f} '
680 |                     f'| raw loss {cur_raw_loss:.2f} '
681 |                     f'| loss {cur_loss:.2f} '
682 |                     f'| ppl {ppl:8.2f}')
683 | 
684 |         # Tensorboard
685 |         if self.tb is not None:
686 |             self.tb.scalar_summary('shared/loss',
687 |                                    cur_loss,
688 |                                    self.shared_step)
689 |             self.tb.scalar_summary('shared/perplexity',
690 |                                    ppl,
691 |                                    self.shared_step)
692 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | 
3 | from collections import defaultdict
4 | import collections
5 | from datetime import datetime
6 | import os
7 | import json
8 | import logging
9 | 
10 | import numpy as np
11 | import pygraphviz as pgv
12 | 
13 | import torch
14 | from torch.autograd import Variable
15 | 
16 | from PIL import Image
17 | from PIL import ImageFont
18 | from PIL import ImageDraw
19 | 
20 | 
21 | try:
22 |     import scipy.misc
23 |     imread = scipy.misc.imread
24 |     imresize = scipy.misc.imresize
25 |     imsave = imwrite = scipy.misc.imsave
26 | except (ImportError, AttributeError):
27 |     import cv2
28 |     imread = cv2.imread
29 |     imresize = cv2.resize  # cv2 has no imresize; resize is the closest equivalent (different argument conventions)
30 |     imsave = imwrite = cv2.imwrite
31 | 
32 | 
33 | ##########################
34 | # Network visualization
35 | ##########################
36 | 
37 | def add_node(graph, node_id, label, shape='box', style='filled'):
38 |     if label.startswith('x'):
39 |         color = 'white'
40 |     elif label.startswith('h'):
41 |         color = 'skyblue'
42 |     elif label == 'tanh':
43 |         color = 'yellow'
44 |     elif label == 'ReLU':
45 |         color = 'pink'
46 |     elif label == 'identity':
47 |         color = 'orange'
48 |     elif label == 'sigmoid':
49 |         color = 'greenyellow'
50 |     elif label == 'avg':
51 |         color = 'seagreen3'
52 |     else:
53 |         color = 'white'
54 | 
55 |     if not any(label.startswith(word) for word in ['x', 'avg', 'h']):
56 |         label = f"{label}\n({node_id})"
57 | 
58 |     graph.add_node(
59 |         node_id, label=label, color='black', fillcolor=color,
60 |         shape=shape, style=style,
61 |     )
62 | 
63 | def draw_network(dag, path):
64 |     makedirs(os.path.dirname(path))
65 |     graph = pgv.AGraph(directed=True, strict=True,
66 |                        fontname='Helvetica', arrowtype='open')  # arrowtype does not seem to work?
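    # `dag` maps each parent node id to a list of Node(id, name) children
    # (see the Node namedtuple below): -1 is the input x[t], -2 the previous
    # hidden state h[t-1], and a child whose id equals num_blocks is the
    # averaged output (see models/shared_rnn.py). Illustrative example:
    #     {-1: [Node(0, 'tanh')], -2: [Node(0, 'tanh')],
    #      0: [Node(1, 'ReLU')], 1: [Node(2, 'avg')]}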
67 | 
68 |     checked_ids = [-2, -1, 0]
69 | 
70 |     if -1 in dag:
71 |         add_node(graph, -1, 'x[t]')
72 |     if -2 in dag:
73 |         add_node(graph, -2, 'h[t-1]')
74 | 
75 |     add_node(graph, 0, dag[-1][0].name)
76 | 
77 |     for idx in dag:
78 |         for node in dag[idx]:
79 |             if node.id not in checked_ids:
80 |                 add_node(graph, node.id, node.name)
81 |                 checked_ids.append(node.id)
82 |             graph.add_edge(idx, node.id)
83 | 
84 |     graph.layout(prog='dot')
85 |     graph.draw(path)
86 | 
87 | def make_gif(paths, gif_path, max_frame=50, prefix=""):
88 |     import imageio
89 | 
90 |     paths.sort()
91 | 
92 |     skip_frame = max(len(paths) // max_frame, 1)  # avoid a zero stride when there are fewer than max_frame paths
93 |     paths = paths[::skip_frame]
94 | 
95 |     images = [imageio.imread(path) for path in paths]
96 |     max_h, max_w, max_c = np.max(
97 |         np.array([image.shape for image in images]), 0)
98 | 
99 |     for idx, image in enumerate(images):
100 |         h, w, c = image.shape
101 |         blank = np.ones([max_h, max_w, max_c], dtype=np.uint8) * 255
102 | 
103 |         pivot_h, pivot_w = (max_h-h)//2, (max_w-w)//2
104 |         blank[pivot_h:pivot_h+h,pivot_w:pivot_w+w,:c] = image
105 | 
106 |         images[idx] = blank
107 | 
108 |     try:
109 |         images = [Image.fromarray(image) for image in images]
110 |         draws = [ImageDraw.Draw(image) for image in images]
111 |         font = ImageFont.truetype("assets/arial.ttf", 30)
112 | 
113 |         steps = [int(os.path.basename(path).rsplit('.', 1)[0].split('-')[1]) for path in paths]
114 |         for step, draw in zip(steps, draws):
115 |             draw.text((max_h//20, max_h//20),
116 |                       f"{prefix}step: {format(step, ',d')}", (0, 0, 0), font=font)
117 |     except IndexError:
118 |         pass
119 | 
120 |     imageio.mimsave(gif_path, [np.array(img) for img in images], duration=0.5)
121 | 
122 | 
123 | ##########################
124 | # Torch
125 | ##########################
126 | 
127 | def detach(h):
128 |     if type(h) == Variable:
129 |         return Variable(h.data)
130 |     else:
131 |         return tuple(detach(v) for v in h)
132 | 
133 | def get_variable(inputs, cuda=False, **kwargs):
134 |     if type(inputs) in [list, np.ndarray]:
135 |         inputs = torch.Tensor(inputs)
136 |     if cuda:
137 |         out = Variable(inputs.cuda(), **kwargs)
138 |     else:
139 |         out = Variable(inputs, **kwargs)
140 |     return out
141 | 
142 | def update_lr(optimizer, lr):
143 |     for param_group in optimizer.param_groups:
144 |         param_group['lr'] = lr
145 | 
146 | def batchify(data, bsz, use_cuda):
147 |     # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
148 |     nbatch = data.size(0) // bsz
149 |     data = data.narrow(0, 0, nbatch * bsz)
150 |     data = data.view(bsz, -1).t().contiguous()
151 |     if use_cuda:
152 |         data = data.cuda()
153 |     return data
154 | 
155 | 
156 | ##########################
157 | # ETC
158 | ##########################
159 | 
160 | Node = collections.namedtuple('Node', ['id', 'name'])
161 | 
162 | 
163 | class keydefaultdict(defaultdict):
164 |     def __missing__(self, key):
165 |         if self.default_factory is None:
166 |             raise KeyError(key)
167 |         else:
168 |             ret = self[key] = self.default_factory(key)
169 |             return ret
170 | 
171 | 
172 | def to_item(x):
173 |     """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
174 |     if isinstance(x, (float, int)):
175 |         return x
176 | 
177 |     if float(torch.__version__[0:3]) < 0.4:
178 |         assert (x.dim() == 1) and (len(x) == 1)
179 |         return x[0]
180 | 
181 |     return x.item()
182 | 
183 | 
184 | def get_logger(name=__file__, level=logging.INFO):
185 |     logger = logging.getLogger(name)
186 | 
187 |     if getattr(logger, '_init_done__', None):
188 |         logger.setLevel(level)
189 |         return logger
190 | 
191 |     logger._init_done__ = True
192 | 
logger.propagate = False 193 | logger.setLevel(level) 194 | 195 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s") 196 | handler = logging.StreamHandler() 197 | handler.setFormatter(formatter) 198 | handler.setLevel(0) 199 | 200 | del logger.handlers[:] 201 | logger.addHandler(handler) 202 | 203 | return logger 204 | 205 | 206 | logger = get_logger() 207 | 208 | 209 | def prepare_dirs(args): 210 | """Sets the directories for the model, and creates those directories. 211 | 212 | Args: 213 | args: Parsed from `argparse` in the `config` module. 214 | """ 215 | if args.load_path: 216 | if args.load_path.startswith(args.log_dir): 217 | args.model_dir = args.load_path 218 | else: 219 | if args.load_path.startswith(args.dataset): 220 | args.model_name = args.load_path 221 | else: 222 | args.model_name = "{}_{}".format(args.dataset, args.load_path) 223 | else: 224 | args.model_name = "{}_{}".format(args.dataset, get_time()) 225 | 226 | if not hasattr(args, 'model_dir'): 227 | args.model_dir = os.path.join(args.log_dir, args.model_name) 228 | args.data_path = os.path.join(args.data_dir, args.dataset) 229 | 230 | for path in [args.log_dir, args.data_dir, args.model_dir]: 231 | if not os.path.exists(path): 232 | makedirs(path) 233 | 234 | def get_time(): 235 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S") 236 | 237 | def save_args(args): 238 | param_path = os.path.join(args.model_dir, "params.json") 239 | 240 | logger.info("[*] MODEL dir: %s" % args.model_dir) 241 | logger.info("[*] PARAM path: %s" % param_path) 242 | 243 | with open(param_path, 'w') as fp: 244 | json.dump(args.__dict__, fp, indent=4, sort_keys=True) 245 | 246 | def save_dag(args, dag, name): 247 | save_path = os.path.join(args.model_dir, name) 248 | logger.info("[*] Save dag : {}".format(save_path)) 249 | json.dump(dag, open(save_path, 'w')) 250 | 251 | def load_dag(args): 252 | load_path = os.path.join(args.dag_path) 253 | logger.info("[*] Load dag : {}".format(load_path)) 254 | with open(load_path) as f: 255 | dag = json.load(f) 256 | dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()} 257 | save_dag(args, dag, "dag.json") 258 | draw_network(dag, os.path.join(args.model_dir, "dag.png")) 259 | return dag 260 | 261 | def makedirs(path): 262 | if not os.path.exists(path): 263 | logger.info("[*] Make directories : {}".format(path)) 264 | os.makedirs(path) 265 | 266 | def remove_file(path): 267 | if os.path.exists(path): 268 | logger.info("[*] Removed: {}".format(path)) 269 | os.remove(path) 270 | 271 | def backup_file(path): 272 | root, ext = os.path.splitext(path) 273 | new_path = "{}.backup_{}{}".format(root, get_time(), ext) 274 | 275 | os.rename(path, new_path) 276 | logger.info("[*] {} has backup: {}".format(path, new_path)) 277 | --------------------------------------------------------------------------------