├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── ENAS_cnn.png
│   ├── ENAS_rnn.png
│   ├── arial.ttf
│   ├── best_rnn_epoch27.png
│   ├── cnn.png
│   ├── cnn_cell.png
│   ├── ptb.gif
│   ├── rnn.png
│   └── wikitext.gif
├── config.py
├── dag.json
├── data
│   ├── __init__.py
│   ├── image.py
│   ├── ptb
│   │   ├── test.txt
│   │   ├── train.txt
│   │   └── valid.txt
│   ├── text.py
│   └── wikitext
│       ├── README
│       ├── test.txt
│       ├── train.txt
│       └── valid.txt
├── generate_gif.py
├── main.py
├── models
│   ├── __init__.py
│   ├── controller.py
│   ├── shared_base.py
│   ├── shared_cnn.py
│   └── shared_rnn.py
├── requirements.txt
├── run.sh
├── tensorboard.py
├── trainer.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Data
2 | *.png
3 | *.gif
4 | *.tar.gz
5 | data/cifar-10-batches-py
6 |
7 | # ipython checkpoints
8 | .ipynb_checkpoints
9 |
10 | # Log
11 | logs
12 |
13 | # ETC
14 | .vscode
15 |
16 | # Created by https://www.gitignore.io/api/python,vim
17 |
18 | ### Python ###
19 | # Byte-compiled / optimized / DLL files
20 | __pycache__/
21 | *.py[cod]
22 | *$py.class
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | env/
30 | build/
31 | develop-eggs/
32 | dist/
33 | downloads/
34 | eggs/
35 | .eggs/
36 | lib/
37 | lib64/
38 | parts/
39 | sdist/
40 | var/
41 | wheels/
42 | *.egg-info/
43 | .installed.cfg
44 | *.egg
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *,cover
65 | .hypothesis/
66 |
67 | # Translations
68 | *.mo
69 | *.pot
70 |
71 | # Django stuff:
72 | *.log
73 | local_settings.py
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # celery beat schedule file
95 | celerybeat-schedule
96 |
97 | # dotenv
98 | .env
99 |
100 | # virtualenv
101 | .venv/
102 | venv/
103 | ENV/
104 |
105 | # Spyder project settings
106 | .spyderproject
107 |
108 | # Rope project settings
109 | .ropeproject
110 |
111 |
112 | ### Vim ###
113 | # swap
114 | [._]*.s[a-v][a-z]
115 | [._]*.sw[a-p]
116 | [._]s[a-v][a-z]
117 | [._]sw[a-p]
118 | # session
119 | Session.vim
120 | # temporary
121 | .netrwhist
122 | *~
123 | # auto-generated tag files
124 | tags
125 |
126 | # End of https://www.gitignore.io/api/python,vim
127 | main.sh
128 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Neural Architecture Search (ENAS) in PyTorch
2 |
3 | PyTorch implementation of [Efficient Neural Architecture Search via Parameter Sharing](https://arxiv.org/abs/1802.03268).
4 |
5 | <img src="assets/ENAS_rnn.png" alt="ENAS" width="100%">
6 |
7 | **ENAS** reduces the computational requirement (GPU-hours) of [Neural Architecture Search](https://arxiv.org/abs/1611.01578) (**NAS**) by 1000x via parameter sharing between models that are subgraphs within a large computational graph, and achieves SOTA on `Penn Treebank` language modeling.
8 |
9 | **\*\*[Caveat] Use the official code from the authors: [link](https://github.com/melodyguan/enas)\*\***
10 |
11 |
12 | ## Prerequisites
13 |
14 | - Python 3.6+
15 | - [PyTorch==0.3.1](https://pytorch.org/get-started/previous-versions/)
16 | - tqdm, scipy, imageio, graphviz, tensorboardX
17 |
18 | ## Usage
19 |
20 | Install prerequisites with:
21 |
22 | conda install graphviz
23 | pip install -r requirements.txt
24 |
25 | To train **ENAS** to discover a recurrent cell for RNN:
26 |
27 | python main.py --network_type rnn --dataset ptb --controller_optim adam --controller_lr 0.00035 \
28 | --shared_optim sgd --shared_lr 20.0 --entropy_coeff 0.0001
29 |
30 | python main.py --network_type rnn --dataset wikitext
31 |
32 | To train **ENAS** to discover CNN architecture (in progress):
33 |
34 | python main.py --network_type cnn --dataset cifar --controller_optim momentum --controller_lr_cosine=True \
35 | --controller_lr_max 0.05 --controller_lr_min 0.0001 --entropy_coeff 0.1
36 |
37 | or you can use your own dataset by arranging files in the following layout:
38 |
39 | data
40 | ├── YOUR_TEXT_DATASET
41 | │ ├── test.txt
42 | │ ├── train.txt
43 | │ └── valid.txt
44 | ├── YOUR_IMAGE_DATASET
45 | │ ├── test
46 | │ │ ├── xxx.jpg (name doesn't matter)
47 | │ │ ├── yyy.jpg (name doesn't matter)
48 | │ │ └── ...
49 | │ ├── train
50 | │ │ ├── xxx.jpg
51 | │ │ └── ...
52 | │ └── valid
53 | │ ├── xxx.jpg
54 | │ └── ...
55 | ├── image.py
56 | └── text.py
57 |
58 | To generate a `gif` image of the discovered architectures:
59 |
60 | python generate_gif.py --model_name=ptb_2018-02-15_11-20-02 --output=sample.gif
61 |
62 | More configurations can be found [here](config.py).
63 |
64 |
65 | ## Results
66 |
67 | Efficient Neural Architecture Search (**ENAS**) is composed of two sets of learnable parameters: the controller LSTM *θ* and the shared parameters *ω*. The two sets of parameters are trained alternately, and only the trained controller is used to derive novel architectures.
68 |
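A rough sketch of that alternating loop (hypothetical helper names such as `shared_loss`, `valid_ppl`, and `baseline`; the actual logic lives in [trainer.py](trainer.py)):

    for epoch in range(args.max_epoch):
        # 1) Train the shared parameters w on architectures sampled from the controller.
        for _ in range(args.shared_max_step):
            shared_optim.zero_grad()
            dags = controller.sample(args.shared_num_sample)
            shared_loss(dags, next(train_iter)).backward()
            shared_optim.step()

        # 2) Train the controller parameters theta with REINFORCE, where the reward
        #    comes from validation perplexity (R = c / ppl) plus an entropy bonus.
        for _ in range(args.controller_max_step):
            controller_optim.zero_grad()
            dags, log_probs, entropies = controller.sample(with_details=True)
            reward = args.reward_c / valid_ppl(dags) + args.entropy_coeff * entropies
            (-log_probs * (reward - baseline)).sum().backward()
            controller_optim.step()
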
69 | ### 1. Discovering Recurrent Cells
70 |
71 | 
72 |
73 | The controller LSTM decides 1) which activation function to use and 2) which previous node to connect to.
74 |
75 | The RNN cells **ENAS** discovered for the `Penn Treebank` and `WikiText-2` datasets:
76 |
77 | <img src="assets/ptb.gif" alt="ptb" width="45%"> <img src="assets/wikitext.gif" alt="wikitext" width="45%">
78 |
79 | Best discovered ENAS cell for `Penn Treebank` at epoch 27:
80 |
81 | <img src="assets/best_rnn_epoch27.png" alt="best_rnn_epoch27" width="45%">
82 |
83 | You can see the details of training (e.g. `reward`, `entropy`, `loss`) with:
84 |
85 | tensorboard --logdir=logs --port=6006
86 |
87 |
88 | ### 2. Discovering Convolutional Neural Networks
89 |
90 | 
91 |
92 | The controller LSTM samples 1) which computation operation to use and 2) which previous node to connect to.
93 |
94 | The CNN network **ENAS** discovered for the `CIFAR-10` dataset:
95 |
96 | (in progress)
97 |
98 |
99 | ### 3. Designing Convolutional Cells
100 |
101 | (in progress)
102 |
103 |
104 | ## Reference
105 |
106 | - [Neural Architecture Search with Reinforcement Learning](https://arxiv.org/abs/1611.01578)
107 | - [Neural Optimizer Search with Reinforcement Learning](https://arxiv.org/abs/1709.07417)
108 |
109 |
110 | ## Author
111 |
112 | Taehoon Kim / [@carpedm20](http://carpedm20.github.io/)
113 |
--------------------------------------------------------------------------------
/assets/ENAS_cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ENAS_cnn.png
--------------------------------------------------------------------------------
/assets/ENAS_rnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ENAS_rnn.png
--------------------------------------------------------------------------------
/assets/arial.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/arial.ttf
--------------------------------------------------------------------------------
/assets/best_rnn_epoch27.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/best_rnn_epoch27.png
--------------------------------------------------------------------------------
/assets/cnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/cnn.png
--------------------------------------------------------------------------------
/assets/cnn_cell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/cnn_cell.png
--------------------------------------------------------------------------------
/assets/ptb.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/ptb.gif
--------------------------------------------------------------------------------
/assets/rnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/rnn.png
--------------------------------------------------------------------------------
/assets/wikitext.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/carpedm20/ENAS-pytorch/0468b8c4ddcf540c9ed6f80c27289792ff9118c9/assets/wikitext.gif
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from utils import get_logger
3 |
4 | logger = get_logger()
5 |
6 |
7 | arg_lists = []
8 | parser = argparse.ArgumentParser()
9 |
10 | def str2bool(v):
11 |     return v.lower() == 'true'
12 |
13 | def add_argument_group(name):
14 | arg = parser.add_argument_group(name)
15 | arg_lists.append(arg)
16 | return arg
17 |
18 | # Network
19 | net_arg = add_argument_group('Network')
20 | net_arg.add_argument('--network_type', type=str, choices=['rnn', 'cnn'], default='rnn')
21 |
22 | # Controller
23 | net_arg.add_argument('--num_blocks', type=int, default=12)
24 | net_arg.add_argument('--tie_weights', type=str2bool, default=True)
25 | net_arg.add_argument('--controller_hid', type=int, default=100)
26 |
27 | # Shared parameters for PTB
28 | # NOTE(brendan): See Merity config for wdrop
29 | # https://github.com/salesforce/awd-lstm-lm.
30 | net_arg.add_argument('--shared_wdrop', type=float, default=0.5)
31 | net_arg.add_argument('--shared_dropout', type=float, default=0.4) # TODO
32 | net_arg.add_argument('--shared_dropoute', type=float, default=0.1) # TODO
33 | net_arg.add_argument('--shared_dropouti', type=float, default=0.65) # TODO
34 | net_arg.add_argument('--shared_embed', type=int, default=1000) # TODO: 200, 500, 1000
35 | net_arg.add_argument('--shared_hid', type=int, default=1000)
36 | net_arg.add_argument('--shared_rnn_max_length', type=int, default=35)
37 | net_arg.add_argument('--shared_rnn_activations', type=eval,
38 | default="['tanh', 'ReLU', 'identity', 'sigmoid']")
39 | net_arg.add_argument('--shared_cnn_types', type=eval,
40 | default="['3x3', '5x5', 'sep 3x3', 'sep 5x5', 'max 3x3', 'max 5x5']")
41 |
42 | # PTB regularizations
43 | net_arg.add_argument('--activation_regularization',
44 | type=str2bool,
45 | default=False)
46 | net_arg.add_argument('--activation_regularization_amount',
47 | type=float,
48 | default=2.0)
49 | net_arg.add_argument('--temporal_activation_regularization',
50 | type=str2bool,
51 | default=False)
52 | net_arg.add_argument('--temporal_activation_regularization_amount',
53 | type=float,
54 | default=1.0)
55 | net_arg.add_argument('--norm_stabilizer_regularization',
56 | type=str2bool,
57 | default=False)
58 | net_arg.add_argument('--norm_stabilizer_regularization_amount',
59 | type=float,
60 | default=1.0)
61 | net_arg.add_argument('--norm_stabilizer_fixed_point', type=float, default=5.0)
62 |
63 | # Shared parameters for CIFAR
64 | net_arg.add_argument('--cnn_hid', type=int, default=64)
65 |
66 |
67 | # Data
68 | data_arg = add_argument_group('Data')
69 | data_arg.add_argument('--dataset', type=str, default='ptb')
70 |
71 |
72 | # Training / test parameters
73 | learn_arg = add_argument_group('Learning')
74 | learn_arg.add_argument('--mode', type=str, default='train',
75 |                        choices=['train', 'derive', 'test', 'single'],
76 |                        help='train: training ENAS, derive: deriving architectures, '
77 |                             'test: testing a trained model, single: training one dag')
78 | learn_arg.add_argument('--batch_size', type=int, default=64)
79 | learn_arg.add_argument('--test_batch_size', type=int, default=1)
80 | learn_arg.add_argument('--max_epoch', type=int, default=150)
81 | learn_arg.add_argument('--entropy_mode', type=str, default='reward', choices=['reward', 'regularizer'])
82 |
83 |
84 | # Controller
85 | learn_arg.add_argument('--ppl_square', type=str2bool, default=False)
86 | # NOTE(brendan): (Zoph and Le, 2017) page 8 states that c is a constant,
87 | # usually set at 80.
88 | learn_arg.add_argument('--reward_c', type=int, default=80,
89 | help="WE DON'T KNOW WHAT THIS VALUE SHOULD BE") # TODO
90 | # NOTE(brendan): irrelevant for actor critic.
91 | learn_arg.add_argument('--ema_baseline_decay', type=float, default=0.95) # TODO: very important
92 | learn_arg.add_argument('--discount', type=float, default=1.0) # TODO
93 | learn_arg.add_argument('--controller_max_step', type=int, default=2000,
94 | help='step for controller parameters')
95 | learn_arg.add_argument('--controller_optim', type=str, default='adam')
96 | learn_arg.add_argument('--controller_lr', type=float, default=3.5e-4,
97 | help="will be ignored if --controller_lr_cosine=True")
98 | learn_arg.add_argument('--controller_lr_cosine', type=str2bool, default=False)
99 | learn_arg.add_argument('--controller_lr_max', type=float, default=0.05,
100 | help="lr max for cosine schedule")
101 | learn_arg.add_argument('--controller_lr_min', type=float, default=0.001,
102 | help="lr min for cosine schedule")
103 | learn_arg.add_argument('--controller_grad_clip', type=float, default=0)
104 | learn_arg.add_argument('--tanh_c', type=float, default=2.5)
105 | learn_arg.add_argument('--softmax_temperature', type=float, default=5.0)
106 | learn_arg.add_argument('--entropy_coeff', type=float, default=1e-4)
107 |
108 | # Shared parameters
109 | learn_arg.add_argument('--shared_initial_step', type=int, default=0)
110 | learn_arg.add_argument('--shared_max_step', type=int, default=400,
111 | help='step for shared parameters')
112 | # NOTE(brendan): Should be 10 for CNN architectures.
113 | learn_arg.add_argument('--shared_num_sample', type=int, default=1,
114 | help='# of Monte Carlo samples')
115 | learn_arg.add_argument('--shared_optim', type=str, default='sgd')
116 | learn_arg.add_argument('--shared_lr', type=float, default=20.0)
117 | learn_arg.add_argument('--shared_decay', type=float, default=0.96)
118 | learn_arg.add_argument('--shared_decay_after', type=float, default=15)
119 | learn_arg.add_argument('--shared_l2_reg', type=float, default=1e-7)
120 | learn_arg.add_argument('--shared_grad_clip', type=float, default=0.25)
121 |
122 | # Deriving Architectures
123 | learn_arg.add_argument('--derive_num_sample', type=int, default=100)
124 |
125 |
126 | # Misc
127 | misc_arg = add_argument_group('Misc')
128 | misc_arg.add_argument('--load_path', type=str, default='')
129 | misc_arg.add_argument('--log_step', type=int, default=50)
130 | misc_arg.add_argument('--save_epoch', type=int, default=4)
131 | misc_arg.add_argument('--max_save_num', type=int, default=4)
132 | misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
133 | misc_arg.add_argument('--log_dir', type=str, default='logs')
134 | misc_arg.add_argument('--data_dir', type=str, default='data')
135 | misc_arg.add_argument('--num_gpu', type=int, default=1)
136 | misc_arg.add_argument('--random_seed', type=int, default=12345)
137 | misc_arg.add_argument('--use_tensorboard', type=str2bool, default=True)
138 | misc_arg.add_argument('--dag_path', type=str, default='')
139 |
140 | def get_args():
141 | """Parses all of the arguments above, which mostly correspond to the
142 | hyperparameters mentioned in the paper.
143 | """
144 | args, unparsed = parser.parse_known_args()
145 | if args.num_gpu > 0:
146 | setattr(args, 'cuda', True)
147 | else:
148 | setattr(args, 'cuda', False)
149 |     if len(unparsed) > 0:
150 | logger.info(f"Unparsed args: {unparsed}")
151 | return args, unparsed
152 |
--------------------------------------------------------------------------------
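
A quick sanity check of the parser above (a sketch; run from the repo root so `utils` is importable):

    from config import get_args

    args, unparsed = get_args()
    print(args.network_type, args.dataset, args.cuda)  # e.g. rnn ptb True
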
/dag.json:
--------------------------------------------------------------------------------
1 | {
2 | "-1": [
3 | [
4 | 0,
5 | "tanh"
6 | ]
7 | ],
8 | "-2": [
9 | [
10 | 0,
11 | "tanh"
12 | ]
13 | ],
14 | "0": [
15 | [
16 | 1,
17 | "tanh"
18 | ]
19 | ],
20 | "1": [
21 | [
22 | 2,
23 | "ReLU"
24 | ],
25 | [
26 | 3,
27 | "tanh"
28 | ]
29 | ],
30 | "2": [
31 | [
32 | 4,
33 | "tanh"
34 | ],
35 | [
36 | 5,
37 | "tanh"
38 | ],
39 | [
40 | 6,
41 | "ReLU"
42 | ]
43 | ],
44 | "4": [
45 | [
46 | 7,
47 | "ReLU"
48 | ]
49 | ],
50 | "7": [
51 | [
52 | 8,
53 | "ReLU"
54 | ]
55 | ],
56 | "8": [
57 | [
58 | 9,
59 | "ReLU"
60 | ],
61 | [
62 | 10,
63 | "ReLU"
64 | ],
65 | [
66 | 11,
67 | "ReLU"
68 | ]
69 | ],
70 | "3": [
71 | [
72 | 12,
73 | "avg"
74 | ]
75 | ],
76 | "5": [
77 | [
78 | 12,
79 | "avg"
80 | ]
81 | ],
82 | "6": [
83 | [
84 | 12,
85 | "avg"
86 | ]
87 | ],
88 | "9": [
89 | [
90 | 12,
91 | "avg"
92 | ]
93 | ],
94 | "10": [
95 | [
96 | 12,
97 | "avg"
98 | ]
99 | ],
100 | "11": [
101 | [
102 | 12,
103 | "avg"
104 | ]
105 | ],
106 | "12": [
107 | [
108 | 13,
109 | "h[t]"
110 | ]
111 | ]
112 | }
--------------------------------------------------------------------------------
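
For reference, a minimal sketch of loading this file into the `Node`-based defaultdict that `models/shared_rnn.py` consumes (`load_dag_sketch` is a hypothetical name; the repo's own loader lives in utils.py):

    import collections
    import json

    from utils import Node  # namedtuple with fields `id` and `name`

    def load_dag_sketch(path='dag.json'):
        with open(path) as f:
            raw = json.load(f)
        dag = collections.defaultdict(list)
        for parent, children in raw.items():
            # JSON object keys are strings; the model indexes nodes with ints
            dag[int(parent)] = [Node(node_id, name) for node_id, name in children]
        return dag
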
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import data.text
2 | import data.image
3 |
--------------------------------------------------------------------------------
/data/image.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | import torchvision.datasets as datasets
3 | import torchvision.transforms as transforms
4 |
5 |
6 | class Image(object):
7 | def __init__(self, args):
8 | if args.dataset == 'cifar10':
9 | Dataset = datasets.CIFAR10
10 |
11 | mean = [0.49139968, 0.48215827, 0.44653124]
12 | std = [0.24703233, 0.24348505, 0.26158768]
13 |
14 | normalize = transforms.Normalize(mean, std)
15 |
16 | transform = transforms.Compose([
17 | transforms.RandomCrop(32, padding=4),
18 | transforms.RandomHorizontalFlip(),
19 | transforms.ToTensor(),
20 | normalize,
21 | ])
22 |         elif args.dataset == 'MNIST':
23 |             Dataset = datasets.MNIST
24 |             normalize = transforms.Normalize((0.1307,), (0.3081,))  # standard MNIST stats
25 |             transform = transforms.Compose([transforms.ToTensor(), normalize])
26 |         else:
27 |             raise NotImplementedError(f'Unknown dataset: {args.dataset}')
26 |
27 | self.train = t.utils.data.DataLoader(
28 | Dataset(root='./data', train=True, transform=transform, download=True),
29 | batch_size=args.batch_size, shuffle=True,
30 | num_workers=args.num_workers, pin_memory=True)
31 |
32 | self.valid = t.utils.data.DataLoader(
33 | Dataset(root='./data', train=False, transform=transforms.Compose([
34 | transforms.ToTensor(),
35 | normalize,
36 | ])),
37 | batch_size=args.batch_size, shuffle=False,
38 | num_workers=args.num_workers, pin_memory=True)
39 |
40 | self.test = self.valid
41 |
--------------------------------------------------------------------------------
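
Note that `Image` expects `args.dataset == 'cifar10'` (main.py checks for 'cifar') and reads `args.num_workers`, which config.py does not define. A minimal standalone sketch:

    import argparse

    from data.image import Image

    args = argparse.Namespace(dataset='cifar10', batch_size=64, num_workers=2)
    loaders = Image(args)
    images, labels = next(iter(loaders.train))  # CIFAR-10 batch of shape [64, 3, 32, 32]
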
/data/text.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/salesforce/awd-lstm-lm
2 | import os
3 | import torch as t
4 |
5 | import collections
6 |
7 |
8 | class Dictionary(object):
9 | def __init__(self):
10 | self.word2idx = {}
11 | self.idx2word = []
12 | self.counter = collections.Counter()
13 | self.total = 0
14 |
15 | def add_word(self, word):
16 | if word not in self.word2idx:
17 | self.idx2word.append(word)
18 | self.word2idx[word] = len(self.idx2word) - 1
19 |
20 | token_id = self.word2idx[word]
21 | self.counter[token_id] += 1
22 | self.total += 1
23 |
24 | return token_id
25 |
26 | def __len__(self):
27 | return len(self.idx2word)
28 |
29 |
30 | class Corpus(object):
31 | def __init__(self, path):
32 | self.dictionary = Dictionary()
33 | self.train = self.tokenize(os.path.join(path, 'train.txt'))
34 | self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
35 | self.test = self.tokenize(os.path.join(path, 'test.txt'))
36 | self.num_tokens = len(self.dictionary)
37 |
38 | def tokenize(self, path):
39 | """Tokenizes a text file."""
40 | assert os.path.exists(path)
41 | # Add words to the dictionary
42 | with open(path, 'r') as f:
43 | tokens = 0
44 | for line in f:
45 |                 words = line.split() + ['<eos>']
46 | tokens += len(words)
47 | for word in words:
48 | self.dictionary.add_word(word)
49 |
50 | # Tokenize file content
51 | with open(path, 'r') as f:
52 | ids = t.LongTensor(tokens)
53 | token = 0
54 | for line in f:
55 |                 words = line.split() + ['<eos>']
56 | for word in words:
57 | ids[token] = self.dictionary.word2idx[word]
58 | token += 1
59 |
60 | return ids
61 |
--------------------------------------------------------------------------------
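
A minimal usage sketch of the corpus above (assumes the repo's data layout):

    from data.text import Corpus

    corpus = Corpus('data/ptb')
    print(corpus.num_tokens)     # vocabulary size, including '<eos>'
    print(corpus.train.size(0))  # number of training token ids
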
/data/wikitext/README:
--------------------------------------------------------------------------------
1 | This is raw data from the wikitext-2 dataset.
2 |
3 | See https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/
4 |
--------------------------------------------------------------------------------
/generate_gif.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | from glob import glob
5 |
6 | from utils import make_gif
7 |
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--model_name", type=str)
10 | parser.add_argument("--max_frame", type=int, default=50)
11 | parser.add_argument("--output", type=str, default="sample.gif")
12 | parser.add_argument("--title", type=str, default="")
13 |
14 | if __name__ == "__main__":
15 | args = parser.parse_args()
16 |
17 | paths = glob(f"./logs/{args.model_name}/networks/*.png")
18 | make_gif(paths, args.output,
19 | max_frame=args.max_frame,
20 | prefix=f"{args.title}\n" if args.title else "")
21 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """Entry point."""
2 | import os
3 |
4 | import torch
5 |
6 | import data
7 | import config
8 | import utils
9 | import trainer
10 |
11 | logger = utils.get_logger()
12 |
13 |
14 | def main(args): # pylint:disable=redefined-outer-name
15 | """main: Entry point."""
16 | utils.prepare_dirs(args)
17 |
18 | torch.manual_seed(args.random_seed)
19 |
20 | if args.num_gpu > 0:
21 | torch.cuda.manual_seed(args.random_seed)
22 |
23 | if args.network_type == 'rnn':
24 | dataset = data.text.Corpus(args.data_path)
25 | elif args.dataset == 'cifar':
26 |         dataset = data.image.Image(args)  # Image expects the full args namespace
27 | else:
28 | raise NotImplementedError(f"{args.dataset} is not supported")
29 |
30 | trnr = trainer.Trainer(args, dataset)
31 |
32 | if args.mode == 'train':
33 | utils.save_args(args)
34 | trnr.train()
35 | elif args.mode == 'derive':
36 | assert args.load_path != "", ("`--load_path` should be given in "
37 | "`derive` mode")
38 | trnr.derive()
39 | elif args.mode == 'test':
40 | if not args.load_path:
41 | raise Exception("[!] You should specify `load_path` to load a "
42 | "pretrained model")
43 | trnr.test()
44 | elif args.mode == 'single':
45 | if not args.dag_path:
46 | raise Exception("[!] You should specify `dag_path` to load a dag")
47 | utils.save_args(args)
48 | trnr.train(single=True)
49 | else:
50 | raise Exception(f"[!] Mode not found: {args.mode}")
51 |
52 | if __name__ == "__main__":
53 | args, unparsed = config.get_args()
54 | main(args)
55 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from models.shared_rnn import RNN
2 | from models.shared_cnn import CNN
3 | from models.controller import Controller
4 |
--------------------------------------------------------------------------------
/models/controller.py:
--------------------------------------------------------------------------------
1 | """A module with NAS controller-related code."""
2 | import collections
3 | import os
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | import utils
9 | from utils import Node
10 |
11 |
12 | def _construct_dags(prev_nodes, activations, func_names, num_blocks):
13 | """Constructs a set of DAGs based on the actions, i.e., previous nodes and
14 | activation functions, sampled from the controller/policy pi.
15 |
16 | Args:
17 | prev_nodes: Previous node actions from the policy.
18 | activations: Activations sampled from the policy.
19 |         func_names: List of activation function names, indexed by the
20 |             sampled activation ids.
20 | num_blocks: Number of blocks in the target RNN cell.
21 |
22 | Returns:
23 | A list of DAGs defined by the inputs.
24 |
25 | RNN cell DAGs are represented in the following way:
26 |
27 | 1. Each element (node) in a DAG is a list of `Node`s.
28 |
29 | 2. The `Node`s in the list dag[i] correspond to the subsequent nodes
30 | that take the output from node i as their own input.
31 |
32 | 3. dag[-1] is the node that takes input from x^{(t)} and h^{(t - 1)}.
33 | dag[-1] always feeds dag[0].
34 | dag[-1] acts as if `w_xc`, `w_hc`, `w_xh` and `w_hh` are its
35 | weights.
36 |
37 | 4. dag[N - 1] is the node that produces the hidden state passed to
38 | the next timestep. dag[N - 1] is also always a leaf node, and therefore
39 | is always averaged with the other leaf nodes and fed to the output
40 | decoder.
41 | """
42 | dags = []
43 | for nodes, func_ids in zip(prev_nodes, activations):
44 | dag = collections.defaultdict(list)
45 |
46 | # add first node
47 | dag[-1] = [Node(0, func_names[func_ids[0]])]
48 | dag[-2] = [Node(0, func_names[func_ids[0]])]
49 |
50 | # add following nodes
51 | for jdx, (idx, func_id) in enumerate(zip(nodes, func_ids[1:])):
52 | dag[utils.to_item(idx)].append(Node(jdx + 1, func_names[func_id]))
53 |
54 | leaf_nodes = set(range(num_blocks)) - dag.keys()
55 |
56 | # merge with avg
57 | for idx in leaf_nodes:
58 | dag[idx] = [Node(num_blocks, 'avg')]
59 |
60 | # TODO(brendan): This is actually y^{(t)}. h^{(t)} is node N - 1 in
61 | # the graph, where N Is the number of nodes. I.e., h^{(t)} takes
62 | # only one other node as its input.
63 | # last h[t] node
64 | last_node = Node(num_blocks + 1, 'h[t]')
65 | dag[num_blocks] = [last_node]
66 | dags.append(dag)
67 |
68 | return dags
69 |
70 |
71 | class Controller(torch.nn.Module):
72 | """Based on
73 | https://github.com/pytorch/examples/blob/master/word_language_model/model.py
74 |
75 | TODO(brendan): RL controllers do not necessarily have much to do with
76 | language models.
77 |
78 | Base the controller RNN on the GRU from:
79 | https://github.com/ikostrikov/pytorch-a2c-ppo-acktr/blob/master/model.py
80 | """
81 | def __init__(self, args):
82 | torch.nn.Module.__init__(self)
83 | self.args = args
84 |
85 | if self.args.network_type == 'rnn':
86 |             # NOTE(brendan): `num_tokens` here is just the activation function
87 |             # for every even step, and the previous node index for every odd step.
88 | self.num_tokens = [len(args.shared_rnn_activations)]
89 | for idx in range(self.args.num_blocks):
90 | self.num_tokens += [idx + 1,
91 | len(args.shared_rnn_activations)]
92 | self.func_names = args.shared_rnn_activations
93 | elif self.args.network_type == 'cnn':
94 | self.num_tokens = [len(args.shared_cnn_types),
95 | self.args.num_blocks]
96 | self.func_names = args.shared_cnn_types
97 |
98 | num_total_tokens = sum(self.num_tokens)
99 |
100 | self.encoder = torch.nn.Embedding(num_total_tokens,
101 | args.controller_hid)
102 | self.lstm = torch.nn.LSTMCell(args.controller_hid, args.controller_hid)
103 |
104 | # TODO(brendan): Perhaps these weights in the decoder should be
105 | # shared? At least for the activation functions, which all have the
106 | # same size.
107 | self.decoders = []
108 | for idx, size in enumerate(self.num_tokens):
109 | decoder = torch.nn.Linear(args.controller_hid, size)
110 | self.decoders.append(decoder)
111 |
112 | self._decoders = torch.nn.ModuleList(self.decoders)
113 |
114 | self.reset_parameters()
115 | self.static_init_hidden = utils.keydefaultdict(self.init_hidden)
116 |
117 | def _get_default_hidden(key):
118 | return utils.get_variable(
119 | torch.zeros(key, self.args.controller_hid),
120 | self.args.cuda,
121 | requires_grad=False)
122 |
123 | self.static_inputs = utils.keydefaultdict(_get_default_hidden)
124 |
125 | def reset_parameters(self):
126 | init_range = 0.1
127 | for param in self.parameters():
128 | param.data.uniform_(-init_range, init_range)
129 | for decoder in self.decoders:
130 | decoder.bias.data.fill_(0)
131 |
132 | def forward(self, # pylint:disable=arguments-differ
133 | inputs,
134 | hidden,
135 | block_idx,
136 | is_embed):
137 | if not is_embed:
138 | embed = self.encoder(inputs)
139 | else:
140 | embed = inputs
141 |
142 | hx, cx = self.lstm(embed, hidden)
143 | logits = self.decoders[block_idx](hx)
144 |
145 | logits /= self.args.softmax_temperature
146 |
147 | # exploration
148 | if self.args.mode == 'train':
149 | logits = (self.args.tanh_c*F.tanh(logits))
150 |
151 | return logits, (hx, cx)
152 |
153 | def sample(self, batch_size=1, with_details=False, save_dir=None):
154 | """Samples a set of `args.num_blocks` many computational nodes from the
155 | controller, where each node is made up of an activation function, and
156 | each node except the last also includes a previous node.
157 | """
158 | if batch_size < 1:
159 | raise Exception(f'Wrong batch_size: {batch_size} < 1')
160 |
161 | # [B, L, H]
162 | inputs = self.static_inputs[batch_size]
163 | hidden = self.static_init_hidden[batch_size]
164 |
165 | activations = []
166 | entropies = []
167 | log_probs = []
168 | prev_nodes = []
169 | # NOTE(brendan): The RNN controller alternately outputs an activation,
170 | # followed by a previous node, for each block except the last one,
171 | # which only gets an activation function. The last node is the output
172 | # node, and its previous node is the average of all leaf nodes.
173 | for block_idx in range(2*(self.args.num_blocks - 1) + 1):
174 | logits, hidden = self.forward(inputs,
175 | hidden,
176 | block_idx,
177 | is_embed=(block_idx == 0))
178 |
179 | probs = F.softmax(logits, dim=-1)
180 | log_prob = F.log_softmax(logits, dim=-1)
181 | # TODO(brendan): .mean() for entropy?
182 | entropy = -(log_prob * probs).sum(1, keepdim=False)
183 |
184 | action = probs.multinomial(num_samples=1).data
185 | selected_log_prob = log_prob.gather(
186 | 1, utils.get_variable(action, requires_grad=False))
187 |
188 | # TODO(brendan): why the [:, 0] here? Should it be .squeeze(), or
189 | # .view()? Same below with `action`.
190 | entropies.append(entropy)
191 | log_probs.append(selected_log_prob[:, 0])
192 |
193 | # 0: function, 1: previous node
194 | mode = block_idx % 2
195 | inputs = utils.get_variable(
196 | action[:, 0] + sum(self.num_tokens[:mode]),
197 | requires_grad=False)
198 |
199 | if mode == 0:
200 | activations.append(action[:, 0])
201 | elif mode == 1:
202 | prev_nodes.append(action[:, 0])
203 |
204 | prev_nodes = torch.stack(prev_nodes).transpose(0, 1)
205 | activations = torch.stack(activations).transpose(0, 1)
206 |
207 | dags = _construct_dags(prev_nodes,
208 | activations,
209 | self.func_names,
210 | self.args.num_blocks)
211 |
212 | if save_dir is not None:
213 | for idx, dag in enumerate(dags):
214 | utils.draw_network(dag,
215 | os.path.join(save_dir, f'graph{idx}.png'))
216 |
217 | if with_details:
218 | return dags, torch.cat(log_probs), torch.cat(entropies)
219 |
220 | return dags
221 |
222 | def init_hidden(self, batch_size):
223 | zeros = torch.zeros(batch_size, self.args.controller_hid)
224 | return (utils.get_variable(zeros, self.args.cuda, requires_grad=False),
225 | utils.get_variable(zeros.clone(), self.args.cuda, requires_grad=False))
226 |
--------------------------------------------------------------------------------
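
A hedged sketch of driving the controller above (assumes an `args` namespace from config.get_args(); pass --num_gpu=0 on a CPU-only machine):

    from config import get_args
    from models.controller import Controller

    args, _ = get_args()
    controller = Controller(args)

    # Sample one architecture; `with_details` also returns the REINFORCE terms.
    dags, log_probs, entropies = controller.sample(batch_size=1, with_details=True)
    print(dags[0])  # defaultdict mapping node id -> [Node(id, name), ...]
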
/models/shared_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def size(p):
6 | return np.prod(p.size())
7 |
8 | class SharedModel(torch.nn.Module):
9 | def __init__(self):
10 | torch.nn.Module.__init__(self)
11 |
12 | @property
13 | def num_parameters(self):
14 | return sum([size(param) for param in self.parameters()])
15 |
16 | def get_f(self, name):
17 | raise NotImplementedError()
18 |
19 | def get_num_cell_parameters(self, dag):
20 | raise NotImplementedError()
21 |
22 | def reset_parameters(self):
23 | raise NotImplementedError()
24 |
--------------------------------------------------------------------------------
/models/shared_cnn.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from collections import defaultdict, deque
3 |
4 | import torch as t
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 | from models.shared_base import *
10 | from utils import get_logger, get_variable, keydefaultdict
11 |
12 | logger = get_logger()
13 |
14 |
15 | def conv3x3(in_planes, out_planes, stride=1):
16 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17 | padding=1, bias=False)
18 |
19 | def conv5x5(in_planes, out_planes, stride=1):
20 | return nn.Conv2d(in_planes, out_planes, kernel_size=5, stride=stride,
21 | padding=1, bias=False)
22 |
23 | def conv(kernel, planes):
24 | if kernel == 3:
25 | _conv = conv3x3
26 | elif kernel == 5:
27 | _conv = conv5x5
28 | else:
29 |         raise NotImplementedError(f"Unknown kernel size: {kernel}")
30 |
31 | return nn.Sequential(
32 | nn.ReLU(inplace=True),
33 | _conv(planes, planes),
34 | nn.BatchNorm2d(planes),
35 | )
36 |
37 |
38 | class CNN(SharedModel):
39 | def __init__(self, args, images):
40 | super(CNN, self).__init__()
41 |
42 | self.args = args
43 | self.images = images
44 |
45 | self.w_c, self.w_h = defaultdict(dict), defaultdict(dict)
46 | self.reset_parameters()
47 |
48 | self.conv = defaultdict(dict)
49 | for idx in range(args.num_blocks):
50 | for jdx in range(idx+1, args.num_blocks):
51 |                 self.conv[idx][jdx] = conv(3, args.cnn_hid)  # TODO: kernel/planes should follow the sampled op
52 |
53 |         raise NotImplementedError("In progress...")
54 |
55 | def forward(self, inputs, dag):
56 | pass
57 |
58 |     def get_f(self, name):
59 |         name = name.lower()
60 |         raise NotImplementedError("In progress...")  # TODO: map op names such as '3x3' to functions
61 |
62 | def get_num_cell_parameters(self, dag):
63 | pass
64 |
65 | def reset_parameters(self):
66 | pass
67 |
--------------------------------------------------------------------------------
/models/shared_rnn.py:
--------------------------------------------------------------------------------
1 | """Module containing the shared RNN model."""
2 | import numpy as np
3 | import collections
4 |
5 | import torch
6 | from torch import nn
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 |
10 | import models.shared_base
11 | import utils
12 |
13 |
14 | logger = utils.get_logger()
15 |
16 |
17 | def _get_dropped_weights(w_raw, dropout_p, is_training):
18 | """Drops out weights to implement DropConnect.
19 |
20 | Args:
21 | w_raw: Full, pre-dropout, weights to be dropped out.
22 | dropout_p: Proportion of weights to drop out.
23 | is_training: True iff _shared_ model is training.
24 |
25 | Returns:
26 | The dropped weights.
27 |
28 | TODO(brendan): Why does torch.nn.functional.dropout() return:
29 | 1. `torch.autograd.Variable()` on the training loop
30 | 2. `torch.nn.Parameter()` on the controller or eval loop, when
31 | training = False...
32 |
33 | Even though the call to `_setweights` in the Smerity repo's
34 | `weight_drop.py` does not have this behaviour, and `F.dropout` always
35 | returns `torch.autograd.Variable` there, even when `training=False`?
36 |
37 | The above TODO is the reason for the hacky check for `torch.nn.Parameter`.
38 | """
39 | dropped_w = F.dropout(w_raw, p=dropout_p, training=is_training)
40 |
41 | if isinstance(dropped_w, torch.nn.Parameter):
42 | dropped_w = dropped_w.clone()
43 |
44 | return dropped_w
45 |
46 |
47 | def isnan(tensor):
48 | return np.isnan(tensor.cpu().data.numpy()).sum() > 0
49 |
50 |
51 | class EmbeddingDropout(torch.nn.Embedding):
52 | """Class for dropping out embeddings by zero'ing out parameters in the
53 | embedding matrix.
54 |
55 | This is equivalent to dropping out particular words, e.g., in the sentence
56 | 'the quick brown fox jumps over the lazy dog', dropping out 'the' would
57 | lead to the sentence '### quick brown fox jumps over ### lazy dog' (in the
58 | embedding vector space).
59 |
60 | See 'A Theoretically Grounded Application of Dropout in Recurrent Neural
61 | Networks', (Gal and Ghahramani, 2016).
62 | """
63 | def __init__(self,
64 | num_embeddings,
65 | embedding_dim,
66 | max_norm=None,
67 | norm_type=2,
68 | scale_grad_by_freq=False,
69 | sparse=False,
70 | dropout=0.1,
71 | scale=None):
72 | """Embedding constructor.
73 |
74 | Args:
75 | dropout: Dropout probability.
76 | scale: Used to scale parameters of embedding weight matrix that are
77 | not dropped out. Note that this is _in addition_ to the
78 | `1/(1 - dropout)` scaling.
79 |
80 | See `torch.nn.Embedding` for remaining arguments.
81 | """
82 | torch.nn.Embedding.__init__(self,
83 | num_embeddings=num_embeddings,
84 | embedding_dim=embedding_dim,
85 | max_norm=max_norm,
86 | norm_type=norm_type,
87 | scale_grad_by_freq=scale_grad_by_freq,
88 | sparse=sparse)
89 | self.dropout = dropout
90 | assert (dropout >= 0.0) and (dropout < 1.0), ('Dropout must be >= 0.0 '
91 | 'and < 1.0')
92 | self.scale = scale
93 |
94 | def forward(self, inputs): # pylint:disable=arguments-differ
95 | """Embeds `inputs` with the dropped out embedding weight matrix."""
96 | if self.training:
97 | dropout = self.dropout
98 | else:
99 | dropout = 0
100 |
101 | if dropout:
102 | mask = self.weight.data.new(self.weight.size(0), 1)
103 | mask.bernoulli_(1 - dropout)
104 | mask = mask.expand_as(self.weight)
105 | mask = mask / (1 - dropout)
106 | masked_weight = self.weight * Variable(mask)
107 | else:
108 | masked_weight = self.weight
109 | if self.scale and self.scale != 1:
110 | masked_weight = masked_weight * self.scale
111 |
112 | return F.embedding(inputs,
113 | masked_weight,
114 | max_norm=self.max_norm,
115 | norm_type=self.norm_type,
116 | scale_grad_by_freq=self.scale_grad_by_freq,
117 | sparse=self.sparse)
118 |
119 |
120 | class LockedDropout(nn.Module):
121 | # code from https://github.com/salesforce/awd-lstm-lm/blob/master/locked_dropout.py
122 | def __init__(self):
123 | super().__init__()
124 |
125 | def forward(self, x, dropout=0.5):
126 | if not self.training or not dropout:
127 | return x
128 | m = x.data.new(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
129 | mask = Variable(m, requires_grad=False) / (1 - dropout)
130 | mask = mask.expand_as(x)
131 | return mask * x
132 |
133 |
134 | class RNN(models.shared_base.SharedModel):
135 | """Shared RNN model."""
136 | def __init__(self, args, corpus):
137 | models.shared_base.SharedModel.__init__(self)
138 |
139 | self.args = args
140 | self.corpus = corpus
141 |
142 | self.decoder = nn.Linear(args.shared_hid, corpus.num_tokens)
143 | self.encoder = EmbeddingDropout(corpus.num_tokens,
144 | args.shared_embed,
145 | dropout=args.shared_dropoute)
146 | self.lockdrop = LockedDropout()
147 |
148 | if self.args.tie_weights:
149 | self.decoder.weight = self.encoder.weight
150 |
151 | # NOTE(brendan): Since W^{x, c} and W^{h, c} are always summed, there
152 | # is no point duplicating their bias offset parameter. Likewise for
153 | # W^{x, h} and W^{h, h}.
154 | self.w_xc = nn.Linear(args.shared_embed, args.shared_hid)
155 | self.w_xh = nn.Linear(args.shared_embed, args.shared_hid)
156 |
157 | # The raw weights are stored here because the hidden-to-hidden weights
158 | # are weight dropped on the forward pass.
159 | self.w_hc_raw = torch.nn.Parameter(
160 | torch.Tensor(args.shared_hid, args.shared_hid))
161 | self.w_hh_raw = torch.nn.Parameter(
162 | torch.Tensor(args.shared_hid, args.shared_hid))
163 | self.w_hc = None
164 | self.w_hh = None
165 |
166 | self.w_h = collections.defaultdict(dict)
167 | self.w_c = collections.defaultdict(dict)
168 |
169 | for idx in range(args.num_blocks):
170 | for jdx in range(idx + 1, args.num_blocks):
171 | self.w_h[idx][jdx] = nn.Linear(args.shared_hid,
172 | args.shared_hid,
173 | bias=False)
174 | self.w_c[idx][jdx] = nn.Linear(args.shared_hid,
175 | args.shared_hid,
176 | bias=False)
177 |
178 | self._w_h = nn.ModuleList([self.w_h[idx][jdx]
179 | for idx in self.w_h
180 | for jdx in self.w_h[idx]])
181 | self._w_c = nn.ModuleList([self.w_c[idx][jdx]
182 | for idx in self.w_c
183 | for jdx in self.w_c[idx]])
184 |
185 | if args.mode == 'train':
186 | self.batch_norm = nn.BatchNorm1d(args.shared_hid)
187 | else:
188 | self.batch_norm = None
189 |
190 | self.reset_parameters()
191 | self.static_init_hidden = utils.keydefaultdict(self.init_hidden)
192 |
193 | logger.info(f'# of parameters: {format(self.num_parameters, ",d")}')
194 |
195 | def forward(self, # pylint:disable=arguments-differ
196 | inputs,
197 | dag,
198 | hidden=None,
199 | is_train=True):
200 | time_steps = inputs.size(0)
201 | batch_size = inputs.size(1)
202 |
203 | is_train = is_train and self.args.mode in ['train']
204 |
205 | self.w_hh = _get_dropped_weights(self.w_hh_raw,
206 | self.args.shared_wdrop,
207 | self.training)
208 | self.w_hc = _get_dropped_weights(self.w_hc_raw,
209 | self.args.shared_wdrop,
210 | self.training)
211 |
212 | if hidden is None:
213 | hidden = self.static_init_hidden[batch_size]
214 |
215 | embed = self.encoder(inputs)
216 |
217 | if self.args.shared_dropouti > 0:
218 | embed = self.lockdrop(embed,
219 | self.args.shared_dropouti if is_train else 0)
220 |
221 | # TODO(brendan): The norm of hidden states are clipped here because
222 | # otherwise ENAS is especially prone to exploding activations on the
223 | # forward pass. This could probably be fixed in a more elegant way, but
224 | # it might be exposing a weakness in the ENAS algorithm as currently
225 | # proposed.
226 | #
227 | # For more details, see
228 | # https://github.com/carpedm20/ENAS-pytorch/issues/6
229 | clipped_num = 0
230 | max_clipped_norm = 0
231 | h1tohT = []
232 | logits = []
233 | for step in range(time_steps):
234 | x_t = embed[step]
235 | logit, hidden = self.cell(x_t, hidden, dag)
236 |
237 | hidden_norms = hidden.norm(dim=-1)
238 | max_norm = 25.0
239 | if hidden_norms.data.max() > max_norm:
240 | # TODO(brendan): Just directly use the torch slice operations
241 | # in PyTorch v0.4.
242 | #
243 | # This workaround for PyTorch v0.3.1 does everything in numpy,
244 | # because the PyTorch slicing and slice assignment is too
245 | # flaky.
246 | hidden_norms = hidden_norms.data.cpu().numpy()
247 |
248 | clipped_num += 1
249 | if hidden_norms.max() > max_clipped_norm:
250 | max_clipped_norm = hidden_norms.max()
251 |
252 | clip_select = hidden_norms > max_norm
253 | clip_norms = hidden_norms[clip_select]
254 |
255 | mask = np.ones(hidden.size())
256 | normalizer = max_norm/clip_norms
257 | normalizer = normalizer[:, np.newaxis]
258 |
259 | mask[clip_select] = normalizer
260 | hidden *= torch.autograd.Variable(
261 | torch.FloatTensor(mask).cuda(), requires_grad=False)
262 |
263 | logits.append(logit)
264 | h1tohT.append(hidden)
265 |
266 | if clipped_num > 0:
267 | logger.info(f'clipped {clipped_num} hidden states in one forward '
268 | f'pass. '
269 | f'max clipped hidden state norm: {max_clipped_norm}')
270 |
271 | h1tohT = torch.stack(h1tohT)
272 | output = torch.stack(logits)
273 | raw_output = output
274 | if self.args.shared_dropout > 0:
275 | output = self.lockdrop(output,
276 | self.args.shared_dropout if is_train else 0)
277 |
278 | dropped_output = output
279 |
280 | decoded = self.decoder(
281 | output.view(output.size(0)*output.size(1), output.size(2)))
282 | decoded = decoded.view(output.size(0), output.size(1), decoded.size(1))
283 |
284 | extra_out = {'dropped': dropped_output,
285 | 'hiddens': h1tohT,
286 | 'raw': raw_output}
287 | return decoded, hidden, extra_out
288 |
289 | def cell(self, x, h_prev, dag):
290 | """Computes a single pass through the discovered RNN cell."""
291 | c = {}
292 | h = {}
293 | f = {}
294 |
295 | f[0] = self.get_f(dag[-1][0].name)
296 | c[0] = F.sigmoid(self.w_xc(x) + F.linear(h_prev, self.w_hc, None))
297 | h[0] = (c[0]*f[0](self.w_xh(x) + F.linear(h_prev, self.w_hh, None)) +
298 | (1 - c[0])*h_prev)
299 |
300 | leaf_node_ids = []
301 | q = collections.deque()
302 | q.append(0)
303 |
304 | # NOTE(brendan): Computes connections from the parent nodes `node_id`
305 | # to their child nodes `next_id` recursively, skipping leaf nodes. A
306 | # leaf node is a node whose id == `self.args.num_blocks`.
307 | #
308 | # Connections between parent i and child j should be computed as
309 | # h_j = c_j*f_{ij}{(W^h_{ij}*h_i)} + (1 - c_j)*h_i,
310 | # where c_j = \sigmoid{(W^c_{ij}*h_i)}
311 | #
312 | # See Training details from Section 3.1 of the paper.
313 | #
314 | # The following algorithm does a breadth-first (since `q.popleft()` is
315 | # used) search over the nodes and computes all the hidden states.
316 | while True:
317 | if len(q) == 0:
318 | break
319 |
320 | node_id = q.popleft()
321 | nodes = dag[node_id]
322 |
323 | for next_node in nodes:
324 | next_id = next_node.id
325 | if next_id == self.args.num_blocks:
326 | leaf_node_ids.append(node_id)
327 | assert len(nodes) == 1, ('parent of leaf node should have '
328 | 'only one child')
329 | continue
330 |
331 | w_h = self.w_h[node_id][next_id]
332 | w_c = self.w_c[node_id][next_id]
333 |
334 | f[next_id] = self.get_f(next_node.name)
335 | c[next_id] = F.sigmoid(w_c(h[node_id]))
336 | h[next_id] = (c[next_id]*f[next_id](w_h(h[node_id])) +
337 | (1 - c[next_id])*h[node_id])
338 |
339 | q.append(next_id)
340 |
341 | # TODO(brendan): Instead of averaging loose ends, perhaps there should
342 | # be a set of separate unshared weights for each "loose" connection
343 | # between each node in a cell and the output.
344 | #
345 | # As it stands, all weights W^h_{ij} are doing double duty by
346 | # connecting both from i to j, as well as from i to the output.
347 |
348 | # average all the loose ends
349 | leaf_nodes = [h[node_id] for node_id in leaf_node_ids]
350 | output = torch.mean(torch.stack(leaf_nodes, 2), -1)
351 |
352 |         # Stabilizing the updates of omega.
353 | if self.batch_norm is not None:
354 | output = self.batch_norm(output)
355 |
356 | return output, h[self.args.num_blocks - 1]
357 |
358 | def init_hidden(self, batch_size):
359 | zeros = torch.zeros(batch_size, self.args.shared_hid)
360 | return utils.get_variable(zeros, self.args.cuda, requires_grad=False)
361 |
362 | def get_f(self, name):
363 | name = name.lower()
364 | if name == 'relu':
365 | f = F.relu
366 | elif name == 'tanh':
367 | f = F.tanh
368 | elif name == 'identity':
369 | f = lambda x: x
370 | elif name == 'sigmoid':
371 | f = F.sigmoid
372 | return f
373 |
374 | def get_num_cell_parameters(self, dag):
375 | num = 0
376 |
377 | num += models.shared_base.size(self.w_xc)
378 | num += models.shared_base.size(self.w_xh)
379 |
380 | q = collections.deque()
381 | q.append(0)
382 |
383 | while True:
384 | if len(q) == 0:
385 | break
386 |
387 | node_id = q.popleft()
388 | nodes = dag[node_id]
389 |
390 | for next_node in nodes:
391 | next_id = next_node.id
392 | if next_id == self.args.num_blocks:
393 | assert len(nodes) == 1, 'parent of leaf node should have only one child'
394 | continue
395 |
396 | w_h = self.w_h[node_id][next_id]
397 | w_c = self.w_c[node_id][next_id]
398 |
399 | num += models.shared_base.size(w_h)
400 | num += models.shared_base.size(w_c)
401 |
402 | q.append(next_id)
403 |
404 |         logger.debug(f'# of cell parameters: '
405 |                      f'{format(num, ",d")}')
406 | return num
407 |
408 | def reset_parameters(self):
409 | init_range = 0.025 if self.args.mode == 'train' else 0.04
410 | for param in self.parameters():
411 | param.data.uniform_(-init_range, init_range)
412 | self.decoder.bias.data.fill_(0)
413 |
--------------------------------------------------------------------------------
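
NOTE: The per-edge update inside `cell()` above is the core of the searched
RNN cell. A minimal standalone numpy sketch of one such edge (not part of the
repo; all names, shapes, and values here are illustrative only):

    import numpy as np

    def sigmoid(z):
        return 1.0 / (1.0 + np.exp(-z))

    hid = 4                          # toy hidden size
    h_i = np.random.randn(hid)       # parent hidden state h_i
    W_h = np.random.randn(hid, hid)  # W^h_{ij}: edge weight for the candidate
    W_c = np.random.randn(hid, hid)  # W^c_{ij}: edge weight for the gate
    f = np.tanh                      # activation sampled by the controller

    c_j = sigmoid(W_c @ h_i)                    # c_j = sigma(W^c_{ij} h_i)
    h_j = c_j * f(W_h @ h_i) + (1 - c_j) * h_i  # highway-style child state

`cell()` applies this update in breadth-first order over the sampled dag and
then averages the hidden states of the leaf nodes to form the cell output.
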
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipdb
2 | tqdm
3 | scipy==1.2.1
4 | imageio
5 | pygraphviz
6 | torch
7 | torchvision
8 | tensorboardX
9 | Pillow
10 | opencv-contrib-python
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | # Entropy decreases quickly but is maintained, compared to the --ema=0.95 runs
3 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=800 &
4 | #sleep 2
5 |
6 | # BAD: ran until ~5k steps; entropy was high, so shared_loss was unstable
7 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=80 &
8 | #sleep 2
9 |
10 | # BAD: entropy decreases quickly, then increases continuously
11 | #python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=800 &
12 | #sleep 2
13 |
14 | # BAD: exploded, but stayed alive longer than the reward_c=800 run above.
15 | # Entropy decreases quickly, then increases continuously.
16 | #python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=80 &
17 | #sleep 2
18 |
19 | #python main.py --ema_baseline_decay=0.9 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
20 | sleep 2
21 |
22 | #python main.py --ema_baseline_decay=0.92 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
23 | sleep 2
24 |
25 | python main.py --ema_baseline_decay=0.95 --shared_initial_step=150 --reward_c=80 --ppl_square=True &
26 | sleep 2
27 |
--------------------------------------------------------------------------------
/tensorboard.py:
--------------------------------------------------------------------------------
1 | import PIL
2 | import scipy.misc
3 | from io import BytesIO
4 | import tensorboardX as tb
5 | from tensorboardX.summary import Summary
6 |
7 |
8 | class TensorBoard(object):
9 | def __init__(self, model_dir):
10 | self.summary_writer = tb.FileWriter(model_dir)
11 |
12 | def image_summary(self, tag, value, step):
13 | for idx, img in enumerate(value):
14 | summary = Summary()
15 | bio = BytesIO()
16 |
17 | if type(img) == str:
18 | img = PIL.Image.open(img)
19 | elif type(img) == PIL.Image.Image:
20 | pass
21 | else:
22 | img = scipy.misc.toimage(img)
23 |
24 | img.save(bio, format="png")
25 | image_summary = Summary.Image(encoded_image_string=bio.getvalue())
26 | summary.value.add(tag=f"{tag}/{idx}", image=image_summary)
27 | self.summary_writer.add_summary(summary, global_step=step)
28 |
29 | def scalar_summary(self, tag, value, step):
30 |         summary = Summary(value=[Summary.Value(tag=tag, simple_value=value)])
31 | self.summary_writer.add_summary(summary, global_step=step)
32 |
--------------------------------------------------------------------------------
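
For reference, a hypothetical use of the wrapper above (paths and values are
illustrative; `assets/cnn.png` is just an example image from this repo):

    tb = TensorBoard('logs/demo')
    tb.scalar_summary('train/loss', 0.42, step=100)
    tb.image_summary('train/sample', ['assets/cnn.png'], step=100)
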
/trainer.py:
--------------------------------------------------------------------------------
1 | """The module for training ENAS."""
2 | import contextlib
3 | import glob
4 | import math
5 | import os
6 |
7 | import numpy as np
8 | import scipy.signal
9 | from tensorboard import TensorBoard
10 | import torch
11 | from torch import nn
12 | import torch.nn.parallel
13 | from torch.autograd import Variable
14 |
15 | import models
16 | import utils
17 |
18 |
19 | logger = utils.get_logger()
20 |
21 |
22 | def _apply_penalties(extra_out, args):
23 | """Based on `args`, optionally adds regularization penalty terms for
24 | activation regularization, temporal activation regularization and/or hidden
25 | state norm stabilization.
26 |
27 | Args:
28 | extra_out[*]:
29 | dropped: Post-dropout activations.
30 | hiddens: All hidden states for a batch of sequences.
31 | raw: Pre-dropout activations.
32 |
33 | Returns:
34 | The penalty term associated with all of the enabled regularizations.
35 |
36 | See:
37 | Regularizing and Optimizing LSTM Language Models (Merity et al., 2017)
38 |         Regularizing RNNs by Stabilizing Activations (Krueger & Memisevic, 2016)
39 | """
40 | penalty = 0
41 |
42 | # Activation regularization.
43 | if args.activation_regularization:
44 | penalty += (args.activation_regularization_amount *
45 | extra_out['dropped'].pow(2).mean())
46 |
47 | # Temporal activation regularization (slowness)
48 | if args.temporal_activation_regularization:
49 | raw = extra_out['raw']
50 | penalty += (args.temporal_activation_regularization_amount *
51 | (raw[1:] - raw[:-1]).pow(2).mean())
52 |
53 | # Norm stabilizer regularization
54 | if args.norm_stabilizer_regularization:
55 | penalty += (args.norm_stabilizer_regularization_amount *
56 | (extra_out['hiddens'].norm(dim=-1) -
57 | args.norm_stabilizer_fixed_point).pow(2).mean())
58 |
59 | return penalty
60 |
61 |
62 | def discount(x, amount):
63 |     return scipy.signal.lfilter([1], [1, -amount], x[::-1], axis=0)[::-1]  # y[t] = sum_k amount**k * x[t+k]
64 |
65 |
66 | def _get_optimizer(name):
67 | if name.lower() == 'sgd':
68 | optim = torch.optim.SGD
69 | elif name.lower() == 'adam':
70 | optim = torch.optim.Adam
71 |
72 | return optim
73 |
74 |
75 | def _get_no_grad_ctx_mgr():
76 |     """Returns the `torch.no_grad` context manager for PyTorch version >=
77 | 0.4, or a no-op context manager otherwise.
78 | """
79 | if float(torch.__version__[0:3]) >= 0.4:
80 | return torch.no_grad()
81 |
82 | return contextlib.suppress()
83 |
84 |
85 | def _check_abs_max_grad(abs_max_grad, model):
86 | """Checks `model` for a new largest gradient for this epoch, in order to
87 | track gradient explosions.
88 | """
89 | finite_grads = [p.grad.data
90 | for p in model.parameters()
91 | if p.grad is not None]
92 |
93 | new_max_grad = max([grad.max() for grad in finite_grads])
94 | new_min_grad = min([grad.min() for grad in finite_grads])
95 |
96 | new_abs_max_grad = max(new_max_grad, abs(new_min_grad))
97 | if new_abs_max_grad > abs_max_grad:
98 | logger.info(f'abs max grad {abs_max_grad}')
99 | return new_abs_max_grad
100 |
101 | return abs_max_grad
102 |
103 |
104 | class Trainer(object):
105 | """A class to wrap training code."""
106 | def __init__(self, args, dataset):
107 | """Constructor for training algorithm.
108 |
109 | Args:
110 | args: From command line, picked up by `argparse`.
111 | dataset: Currently only `data.text.Corpus` is supported.
112 |
113 | Initializes:
114 | - Data: train, val and test.
115 | - Model: shared and controller.
116 | - Inference: optimizers for shared and controller parameters.
117 | - Criticism: cross-entropy loss for training the shared model.
118 | """
119 | self.args = args
120 | self.controller_step = 0
121 | self.cuda = args.cuda
122 | self.dataset = dataset
123 | self.epoch = 0
124 | self.shared_step = 0
125 | self.start_epoch = 0
126 |
127 | logger.info('regularizing:')
128 | for regularizer in [('activation regularization',
129 | self.args.activation_regularization),
130 | ('temporal activation regularization',
131 | self.args.temporal_activation_regularization),
132 | ('norm stabilizer regularization',
133 | self.args.norm_stabilizer_regularization)]:
134 | if regularizer[1]:
135 | logger.info(f'{regularizer[0]}')
136 |
137 | self.train_data = utils.batchify(dataset.train,
138 | args.batch_size,
139 | self.cuda)
140 | # NOTE(brendan): The validation set data is batchified twice
141 | # separately: once for computing rewards during the Train Controller
142 | # phase (valid_data, batch size == 64), and once for evaluating ppl
143 | # over the entire validation set (eval_data, batch size == 1)
144 | self.valid_data = utils.batchify(dataset.valid,
145 | args.batch_size,
146 | self.cuda)
147 | self.eval_data = utils.batchify(dataset.valid,
148 | args.test_batch_size,
149 | self.cuda)
150 | self.test_data = utils.batchify(dataset.test,
151 | args.test_batch_size,
152 | self.cuda)
153 |
154 | self.max_length = self.args.shared_rnn_max_length
155 |
156 | if args.use_tensorboard:
157 | self.tb = TensorBoard(args.model_dir)
158 | else:
159 | self.tb = None
160 | self.build_model()
161 |
162 | if self.args.load_path:
163 | self.load_model()
164 |
165 | shared_optimizer = _get_optimizer(self.args.shared_optim)
166 | controller_optimizer = _get_optimizer(self.args.controller_optim)
167 |
168 | self.shared_optim = shared_optimizer(
169 | self.shared.parameters(),
170 | lr=self.shared_lr,
171 | weight_decay=self.args.shared_l2_reg)
172 |
173 | self.controller_optim = controller_optimizer(
174 | self.controller.parameters(),
175 | lr=self.args.controller_lr)
176 |
177 | self.ce = nn.CrossEntropyLoss()
178 |
179 | def build_model(self):
180 | """Creates and initializes the shared and controller models."""
181 | if self.args.network_type == 'rnn':
182 | self.shared = models.RNN(self.args, self.dataset)
183 | elif self.args.network_type == 'cnn':
184 | self.shared = models.CNN(self.args, self.dataset)
185 | else:
186 | raise NotImplementedError(f'Network type '
187 | f'`{self.args.network_type}` is not '
188 | f'defined')
189 | self.controller = models.Controller(self.args)
190 |
191 | if self.args.num_gpu == 1:
192 | self.shared.cuda()
193 | self.controller.cuda()
194 | elif self.args.num_gpu > 1:
195 | raise NotImplementedError('`num_gpu > 1` is in progress')
196 |
197 | def train(self, single=False):
198 | """Cycles through alternately training the shared parameters and the
199 | controller, as described in Section 2.2, Training ENAS and Deriving
200 | Architectures, of the paper.
201 |
202 | From the paper (for Penn Treebank):
203 |
204 | - In the first phase, shared parameters omega are trained for 400
205 | steps, each on a minibatch of 64 examples.
206 |
207 | - In the second phase, the controller's parameters are trained for 2000
208 | steps.
209 |
210 | Args:
211 |             single (bool): If True, the controller is not trained, and a
212 |                 fixed dag (from `utils.load_dag`) is used instead of derive().
213 | """
214 | dag = utils.load_dag(self.args) if single else None
215 |
216 | if self.args.shared_initial_step > 0:
217 | self.train_shared(self.args.shared_initial_step)
218 | self.train_controller()
219 |
220 | for self.epoch in range(self.start_epoch, self.args.max_epoch):
221 | # 1. Training the shared parameters omega of the child models
222 | self.train_shared(dag=dag)
223 |
224 | # 2. Training the controller parameters theta
225 | if not single:
226 | self.train_controller()
227 |
228 | if self.epoch % self.args.save_epoch == 0:
229 | with _get_no_grad_ctx_mgr():
230 | best_dag = dag if dag else self.derive()
231 | self.evaluate(self.eval_data,
232 | best_dag,
233 | 'val_best',
234 | max_num=self.args.batch_size*100)
235 | self.save_model()
236 |
237 | if self.epoch >= self.args.shared_decay_after:
238 | utils.update_lr(self.shared_optim, self.shared_lr)
239 |
240 | def get_loss(self, inputs, targets, hidden, dags):
241 | """Computes the loss for the same batch for M models.
242 |
243 | This amounts to an estimate of the loss, which is turned into an
244 | estimate for the gradients of the shared model.
245 | """
246 | if not isinstance(dags, list):
247 | dags = [dags]
248 |
249 | loss = 0
250 | for dag in dags:
251 | output, hidden, extra_out = self.shared(inputs, dag, hidden=hidden)
252 | output_flat = output.view(-1, self.dataset.num_tokens)
253 | sample_loss = (self.ce(output_flat, targets) /
254 | self.args.shared_num_sample)
255 | loss += sample_loss
256 |
257 |         assert len(dags) == 1, 'there would be multiple `hidden` for multiple `dags`'
258 | return loss, hidden, extra_out
259 |
260 | def train_shared(self, max_step=None, dag=None):
261 | """Train the language model for 400 steps of minibatches of 64
262 | examples.
263 |
264 | Args:
265 | max_step: Used to run extra training steps as a warm-up.
266 | dag: If not None, is used instead of calling sample().
267 |
268 | BPTT is truncated at 35 timesteps.
269 |
270 | For each weight update, gradients are estimated by sampling M models
271 | from the fixed controller policy, and averaging their gradients
272 | computed on a batch of training data.
273 | """
274 | model = self.shared
275 | model.train()
276 | self.controller.eval()
277 |
278 | hidden = self.shared.init_hidden(self.args.batch_size)
279 |
280 | if max_step is None:
281 | max_step = self.args.shared_max_step
282 | else:
283 | max_step = min(self.args.shared_max_step, max_step)
284 |
285 | abs_max_grad = 0
286 | abs_max_hidden_norm = 0
287 | step = 0
288 | raw_total_loss = 0
289 | total_loss = 0
290 | train_idx = 0
291 | # TODO(brendan): Why - 1 - 1?
292 | while train_idx < self.train_data.size(0) - 1 - 1:
293 | if step > max_step:
294 | break
295 |
296 | dags = dag if dag else self.controller.sample(
297 | self.args.shared_num_sample)
298 | inputs, targets = self.get_batch(self.train_data,
299 | train_idx,
300 | self.max_length)
301 |
302 | loss, hidden, extra_out = self.get_loss(inputs,
303 | targets,
304 | hidden,
305 | dags)
306 | hidden.detach_()
307 | raw_total_loss += loss.data
308 |
309 | loss += _apply_penalties(extra_out, self.args)
310 |
311 | # update
312 | self.shared_optim.zero_grad()
313 | loss.backward()
314 |
315 | h1tohT = extra_out['hiddens']
316 | new_abs_max_hidden_norm = utils.to_item(
317 | h1tohT.norm(dim=-1).data.max())
318 | if new_abs_max_hidden_norm > abs_max_hidden_norm:
319 | abs_max_hidden_norm = new_abs_max_hidden_norm
320 | logger.info(f'max hidden {abs_max_hidden_norm}')
321 | abs_max_grad = _check_abs_max_grad(abs_max_grad, model)
322 | torch.nn.utils.clip_grad_norm(model.parameters(),
323 | self.args.shared_grad_clip)
324 | self.shared_optim.step()
325 |
326 | total_loss += loss.data
327 |
328 | if ((step % self.args.log_step) == 0) and (step > 0):
329 | self._summarize_shared_train(total_loss, raw_total_loss)
330 | raw_total_loss = 0
331 | total_loss = 0
332 |
333 | step += 1
334 | self.shared_step += 1
335 | train_idx += self.max_length
336 |
337 | def get_reward(self, dag, entropies, hidden, valid_idx=0):
338 | """Computes the perplexity of a single sampled model on a minibatch of
339 | validation data.
340 | """
341 | if not isinstance(entropies, np.ndarray):
342 | entropies = entropies.data.cpu().numpy()
343 |
344 | inputs, targets = self.get_batch(self.valid_data,
345 | valid_idx,
346 | self.max_length,
347 | volatile=True)
348 | valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
349 | valid_loss = utils.to_item(valid_loss.data)
350 |
351 | valid_ppl = math.exp(valid_loss)
352 |
353 | # TODO: we don't know reward_c
354 | if self.args.ppl_square:
355 | # TODO: but we do know reward_c=80 in the previous paper
356 | R = self.args.reward_c / valid_ppl ** 2
357 | else:
358 | R = self.args.reward_c / valid_ppl
359 |
360 | if self.args.entropy_mode == 'reward':
361 | rewards = R + self.args.entropy_coeff * entropies
362 | elif self.args.entropy_mode == 'regularizer':
363 | rewards = R * np.ones_like(entropies)
364 | else:
365 |             raise NotImplementedError(f'Unknown entropy mode: {self.args.entropy_mode}')
366 |
367 | return rewards, hidden
368 |
369 | def train_controller(self):
370 | """Fixes the shared parameters and updates the controller parameters.
371 |
372 | The controller is updated with a score function gradient estimator
373 | (i.e., REINFORCE), with the reward being c/valid_ppl, where valid_ppl
374 | is computed on a minibatch of validation data.
375 |
376 | A moving average baseline is used.
377 |
378 | The controller is trained for 2000 steps per epoch (i.e.,
379 | first (Train Shared) phase -> second (Train Controller) phase).
380 | """
381 | model = self.controller
382 | model.train()
383 | # TODO(brendan): Why can't we call shared.eval() here? Leads to loss
384 | # being uniformly zero for the controller.
385 | # self.shared.eval()
386 |
387 | avg_reward_base = None
388 | baseline = None
389 | adv_history = []
390 | entropy_history = []
391 | reward_history = []
392 |
393 | hidden = self.shared.init_hidden(self.args.batch_size)
394 | total_loss = 0
395 | valid_idx = 0
396 | for step in range(self.args.controller_max_step):
397 | # sample models
398 | dags, log_probs, entropies = self.controller.sample(
399 | with_details=True)
400 |
401 | # calculate reward
402 | np_entropies = entropies.data.cpu().numpy()
403 | # NOTE(brendan): No gradients should be backpropagated to the
404 | # shared model during controller training, obviously.
405 | with _get_no_grad_ctx_mgr():
406 | rewards, hidden = self.get_reward(dags,
407 | np_entropies,
408 | hidden,
409 | valid_idx)
410 |
411 | # discount
412 | if 1 > self.args.discount > 0:
413 | rewards = discount(rewards, self.args.discount)
414 |
415 | reward_history.extend(rewards)
416 | entropy_history.extend(np_entropies)
417 |
418 | # moving average baseline
419 | if baseline is None:
420 | baseline = rewards
421 | else:
422 | decay = self.args.ema_baseline_decay
423 | baseline = decay * baseline + (1 - decay) * rewards
424 |
425 | adv = rewards - baseline
426 | adv_history.extend(adv)
427 |
428 | # policy loss
429 | loss = -log_probs*utils.get_variable(adv,
430 | self.cuda,
431 | requires_grad=False)
432 | if self.args.entropy_mode == 'regularizer':
433 | loss -= self.args.entropy_coeff * entropies
434 |
435 | loss = loss.sum() # or loss.mean()
436 |
437 | # update
438 | self.controller_optim.zero_grad()
439 | loss.backward()
440 |
441 | if self.args.controller_grad_clip > 0:
442 | torch.nn.utils.clip_grad_norm(model.parameters(),
443 | self.args.controller_grad_clip)
444 | self.controller_optim.step()
445 |
446 | total_loss += utils.to_item(loss.data)
447 |
448 | if ((step % self.args.log_step) == 0) and (step > 0):
449 | self._summarize_controller_train(total_loss,
450 | adv_history,
451 | entropy_history,
452 | reward_history,
453 | avg_reward_base,
454 | dags)
455 |
456 | reward_history, adv_history, entropy_history = [], [], []
457 | total_loss = 0
458 |
459 | self.controller_step += 1
460 |
461 | prev_valid_idx = valid_idx
462 | valid_idx = ((valid_idx + self.max_length) %
463 | (self.valid_data.size(0) - 1))
464 | # NOTE(brendan): Whenever we wrap around to the beginning of the
465 | # validation data, we reset the hidden states.
466 | if prev_valid_idx > valid_idx:
467 | hidden = self.shared.init_hidden(self.args.batch_size)
468 |
469 | def evaluate(self, source, dag, name, batch_size=1, max_num=None):
470 | """Evaluate on the validation set.
471 |
472 | NOTE(brendan): We should not be using the test set to develop the
473 | algorithm (basic machine learning good practices).
474 | """
475 | self.shared.eval()
476 | self.controller.eval()
477 |
478 |         data = source[:max_num * self.max_length] if max_num else source
479 |
480 | total_loss = 0
481 | hidden = self.shared.init_hidden(batch_size)
482 |
483 | pbar = range(0, data.size(0) - 1, self.max_length)
484 | for count, idx in enumerate(pbar):
485 | inputs, targets = self.get_batch(data, idx, volatile=True)
486 | output, hidden, _ = self.shared(inputs,
487 | dag,
488 | hidden=hidden,
489 | is_train=False)
490 | output_flat = output.view(-1, self.dataset.num_tokens)
491 | total_loss += len(inputs) * self.ce(output_flat, targets).data
492 | hidden.detach_()
493 |             ppl = math.exp(utils.to_item(total_loss) / (count + 1) / self.max_length)  # unused; recomputed below
494 |
495 | val_loss = utils.to_item(total_loss) / len(data)
496 | ppl = math.exp(val_loss)
497 |
498 | self.tb.scalar_summary(f'eval/{name}_loss', val_loss, self.epoch)
499 | self.tb.scalar_summary(f'eval/{name}_ppl', ppl, self.epoch)
500 | logger.info(f'eval | loss: {val_loss:8.2f} | ppl: {ppl:8.2f}')
501 |
502 | def derive(self, sample_num=None, valid_idx=0):
503 | """TODO(brendan): We are always deriving based on the very first batch
504 | of validation data? This seems wrong...
505 | """
506 | hidden = self.shared.init_hidden(self.args.batch_size)
507 |
508 | if sample_num is None:
509 | sample_num = self.args.derive_num_sample
510 |
511 | dags, _, entropies = self.controller.sample(sample_num,
512 | with_details=True)
513 |
514 | max_R = 0
515 | best_dag = None
516 | for dag in dags:
517 | R, _ = self.get_reward(dag, entropies, hidden, valid_idx)
518 | if R.max() > max_R:
519 | max_R = R.max()
520 | best_dag = dag
521 |
522 | logger.info(f'derive | max_R: {max_R:8.6f}')
523 | fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
524 | f'{max_R:6.4f}-best.png')
525 | path = os.path.join(self.args.model_dir, 'networks', fname)
526 | utils.draw_network(best_dag, path)
527 | self.tb.image_summary('derive/best', [path], self.epoch)
528 |
529 | return best_dag
530 |
531 | @property
532 | def shared_lr(self):
533 | degree = max(self.epoch - self.args.shared_decay_after + 1, 0)
534 | return self.args.shared_lr * (self.args.shared_decay ** degree)
535 |
536 | @property
537 | def controller_lr(self):
538 | return self.args.controller_lr
539 |
540 | def get_batch(self, source, idx, length=None, volatile=False):
541 | # code from
542 | # https://github.com/pytorch/examples/blob/master/word_language_model/main.py
543 | length = min(length if length else self.max_length,
544 | len(source) - 1 - idx)
545 | data = Variable(source[idx:idx + length], volatile=volatile)
546 | target = Variable(source[idx + 1:idx + 1 + length].view(-1),
547 | volatile=volatile)
548 | return data, target
549 |
550 | @property
551 | def shared_path(self):
552 | return f'{self.args.model_dir}/shared_epoch{self.epoch}_step{self.shared_step}.pth'
553 |
554 | @property
555 | def controller_path(self):
556 | return f'{self.args.model_dir}/controller_epoch{self.epoch}_step{self.controller_step}.pth'
557 |
558 | def get_saved_models_info(self):
559 | paths = glob.glob(os.path.join(self.args.model_dir, '*.pth'))
560 | paths.sort()
561 |
562 | def get_numbers(items, delimiter, idx, replace_word, must_contain=''):
563 | return list(set([int(
564 | name.split(delimiter)[idx].replace(replace_word, ''))
565 |                 for name in items if must_contain in name]))
566 |
567 | basenames = [os.path.basename(path.rsplit('.', 1)[0]) for path in paths]
568 | epochs = get_numbers(basenames, '_', 1, 'epoch')
569 | shared_steps = get_numbers(basenames, '_', 2, 'step', 'shared')
570 | controller_steps = get_numbers(basenames, '_', 2, 'step', 'controller')
571 |
572 | epochs.sort()
573 | shared_steps.sort()
574 | controller_steps.sort()
575 |
576 | return epochs, shared_steps, controller_steps
577 |
578 | def save_model(self):
579 | torch.save(self.shared.state_dict(), self.shared_path)
580 | logger.info(f'[*] SAVED: {self.shared_path}')
581 |
582 | torch.save(self.controller.state_dict(), self.controller_path)
583 | logger.info(f'[*] SAVED: {self.controller_path}')
584 |
585 | epochs, shared_steps, controller_steps = self.get_saved_models_info()
586 |
587 | for epoch in epochs[:-self.args.max_save_num]:
588 | paths = glob.glob(
589 | os.path.join(self.args.model_dir, f'*_epoch{epoch}_*.pth'))
590 |
591 | for path in paths:
592 | utils.remove_file(path)
593 |
594 | def load_model(self):
595 | epochs, shared_steps, controller_steps = self.get_saved_models_info()
596 |
597 | if len(epochs) == 0:
598 | logger.info(f'[!] No checkpoint found in {self.args.model_dir}...')
599 | return
600 |
601 | self.epoch = self.start_epoch = max(epochs)
602 | self.shared_step = max(shared_steps)
603 | self.controller_step = max(controller_steps)
604 |
605 | if self.args.num_gpu == 0:
606 | map_location = lambda storage, loc: storage
607 | else:
608 | map_location = None
609 |
610 | self.shared.load_state_dict(
611 | torch.load(self.shared_path, map_location=map_location))
612 | logger.info(f'[*] LOADED: {self.shared_path}')
613 |
614 | self.controller.load_state_dict(
615 | torch.load(self.controller_path, map_location=map_location))
616 | logger.info(f'[*] LOADED: {self.controller_path}')
617 |
618 | def _summarize_controller_train(self,
619 | total_loss,
620 | adv_history,
621 | entropy_history,
622 | reward_history,
623 | avg_reward_base,
624 | dags):
625 | """Logs the controller's progress for this training epoch."""
626 | cur_loss = total_loss / self.args.log_step
627 |
628 | avg_adv = np.mean(adv_history)
629 | avg_entropy = np.mean(entropy_history)
630 | avg_reward = np.mean(reward_history)
631 |
632 | if avg_reward_base is None:
633 | avg_reward_base = avg_reward
634 |
635 | logger.info(
636 | f'| epoch {self.epoch:3d} | lr {self.controller_lr:.5f} '
637 | f'| R {avg_reward:.5f} | entropy {avg_entropy:.4f} '
638 | f'| loss {cur_loss:.5f}')
639 |
640 | # Tensorboard
641 | if self.tb is not None:
642 | self.tb.scalar_summary('controller/loss',
643 | cur_loss,
644 | self.controller_step)
645 | self.tb.scalar_summary('controller/reward',
646 | avg_reward,
647 | self.controller_step)
648 | self.tb.scalar_summary('controller/reward-B_per_epoch',
649 | avg_reward - avg_reward_base,
650 | self.controller_step)
651 | self.tb.scalar_summary('controller/entropy',
652 | avg_entropy,
653 | self.controller_step)
654 | self.tb.scalar_summary('controller/adv',
655 | avg_adv,
656 | self.controller_step)
657 |
658 | paths = []
659 | for dag in dags:
660 | fname = (f'{self.epoch:03d}-{self.controller_step:06d}-'
661 | f'{avg_reward:6.4f}.png')
662 | path = os.path.join(self.args.model_dir, 'networks', fname)
663 | utils.draw_network(dag, path)
664 | paths.append(path)
665 |
666 | self.tb.image_summary('controller/sample',
667 | paths,
668 | self.controller_step)
669 |
670 | def _summarize_shared_train(self, total_loss, raw_total_loss):
671 | """Logs a set of training steps."""
672 | cur_loss = utils.to_item(total_loss) / self.args.log_step
673 | # NOTE(brendan): The raw loss, without adding in the activation
674 | # regularization terms, should be used to compute ppl.
675 | cur_raw_loss = utils.to_item(raw_total_loss) / self.args.log_step
676 | ppl = math.exp(cur_raw_loss)
677 |
678 | logger.info(f'| epoch {self.epoch:3d} '
679 | f'| lr {self.shared_lr:4.2f} '
680 | f'| raw loss {cur_raw_loss:.2f} '
681 | f'| loss {cur_loss:.2f} '
682 | f'| ppl {ppl:8.2f}')
683 |
684 | # Tensorboard
685 | if self.tb is not None:
686 | self.tb.scalar_summary('shared/loss',
687 | cur_loss,
688 | self.shared_step)
689 | self.tb.scalar_summary('shared/perplexity',
690 | ppl,
691 | self.shared_step)
692 |
--------------------------------------------------------------------------------
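
The controller update in `train_controller()` above is REINFORCE with an
exponential-moving-average baseline. A standalone numpy sketch of a single
step (toy values only; the real code uses torch tensors so that gradients
flow through `log_probs`):

    import numpy as np

    reward_c, valid_ppl = 80.0, 95.0
    R = reward_c / valid_ppl                  # reward, as in get_reward()

    decay = 0.95                              # ema_baseline_decay
    baseline = 0.70                           # running baseline (toy value)
    baseline = decay * baseline + (1 - decay) * R
    adv = R - baseline                        # advantage

    log_probs = np.array([-1.2, -0.8, -2.1])  # log-probs of sampled decisions
    loss = -(log_probs * adv).sum()           # policy-gradient loss to minimize
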
/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | from collections import defaultdict
4 | import collections
5 | from datetime import datetime
6 | import os
7 | import json
8 | import logging
9 |
10 | import numpy as np
11 | import pygraphviz as pgv
12 |
13 | import torch
14 | from torch.autograd import Variable
15 |
16 | from PIL import Image
17 | from PIL import ImageFont
18 | from PIL import ImageDraw
19 |
20 |
21 | try:
22 | import scipy.misc
23 | imread = scipy.misc.imread
24 | imresize = scipy.misc.imresize
25 | imsave = imwrite = scipy.misc.imsave
26 | except (ImportError, AttributeError):
27 | import cv2
28 | imread = cv2.imread
29 |     imresize = cv2.resize
30 | imsave = imwrite = cv2.imwrite
31 |
32 |
33 | ##########################
34 | # Network visualization
35 | ##########################
36 |
37 | def add_node(graph, node_id, label, shape='box', style='filled'):
38 | if label.startswith('x'):
39 | color = 'white'
40 | elif label.startswith('h'):
41 | color = 'skyblue'
42 | elif label == 'tanh':
43 | color = 'yellow'
44 | elif label == 'ReLU':
45 | color = 'pink'
46 | elif label == 'identity':
47 | color = 'orange'
48 | elif label == 'sigmoid':
49 | color = 'greenyellow'
50 | elif label == 'avg':
51 | color = 'seagreen3'
52 | else:
53 | color = 'white'
54 |
55 | if not any(label.startswith(word) for word in ['x', 'avg', 'h']):
56 | label = f"{label}\n({node_id})"
57 |
58 | graph.add_node(
59 | node_id, label=label, color='black', fillcolor=color,
60 | shape=shape, style=style,
61 | )
62 |
63 | def draw_network(dag, path):
64 | makedirs(os.path.dirname(path))
65 | graph = pgv.AGraph(directed=True, strict=True,
66 |                        fontname='Helvetica', arrowtype='open')  # arrowtype seems to have no effect
67 |
68 | checked_ids = [-2, -1, 0]
69 |
70 | if -1 in dag:
71 | add_node(graph, -1, 'x[t]')
72 | if -2 in dag:
73 | add_node(graph, -2, 'h[t-1]')
74 |
75 | add_node(graph, 0, dag[-1][0].name)
76 |
77 | for idx in dag:
78 | for node in dag[idx]:
79 | if node.id not in checked_ids:
80 | add_node(graph, node.id, node.name)
81 | checked_ids.append(node.id)
82 | graph.add_edge(idx, node.id)
83 |
84 | graph.layout(prog='dot')
85 | graph.draw(path)
86 |
87 | def make_gif(paths, gif_path, max_frame=50, prefix=""):
88 | import imageio
89 |
90 | paths.sort()
91 |
92 |     skip_frame = max(len(paths) // max_frame, 1)
93 | paths = paths[::skip_frame]
94 |
95 | images = [imageio.imread(path) for path in paths]
96 | max_h, max_w, max_c = np.max(
97 | np.array([image.shape for image in images]), 0)
98 |
99 | for idx, image in enumerate(images):
100 | h, w, c = image.shape
101 | blank = np.ones([max_h, max_w, max_c], dtype=np.uint8) * 255
102 |
103 | pivot_h, pivot_w = (max_h-h)//2, (max_w-w)//2
104 | blank[pivot_h:pivot_h+h,pivot_w:pivot_w+w,:c] = image
105 |
106 | images[idx] = blank
107 |
108 | try:
109 | images = [Image.fromarray(image) for image in images]
110 | draws = [ImageDraw.Draw(image) for image in images]
111 | font = ImageFont.truetype("assets/arial.ttf", 30)
112 |
113 | steps = [int(os.path.basename(path).rsplit('.', 1)[0].split('-')[1]) for path in paths]
114 | for step, draw in zip(steps, draws):
115 | draw.text((max_h//20, max_h//20),
116 | f"{prefix}step: {format(step, ',d')}", (0, 0, 0), font=font)
117 | except IndexError:
118 | pass
119 |
120 | imageio.mimsave(gif_path, [np.array(img) for img in images], duration=0.5)
121 |
122 |
123 | ##########################
124 | # Torch
125 | ##########################
126 |
127 | def detach(h):
128 | if type(h) == Variable:
129 | return Variable(h.data)
130 | else:
131 | return tuple(detach(v) for v in h)
132 |
133 | def get_variable(inputs, cuda=False, **kwargs):
134 | if type(inputs) in [list, np.ndarray]:
135 | inputs = torch.Tensor(inputs)
136 | if cuda:
137 | out = Variable(inputs.cuda(), **kwargs)
138 | else:
139 | out = Variable(inputs, **kwargs)
140 | return out
141 |
142 | def update_lr(optimizer, lr):
143 | for param_group in optimizer.param_groups:
144 | param_group['lr'] = lr
145 |
146 | def batchify(data, bsz, use_cuda):
147 | # code from https://github.com/pytorch/examples/blob/master/word_language_model/main.py
148 | nbatch = data.size(0) // bsz
149 | data = data.narrow(0, 0, nbatch * bsz)
150 | data = data.view(bsz, -1).t().contiguous()
151 | if use_cuda:
152 | data = data.cuda()
153 | return data
154 |
155 |
156 | ##########################
157 | # ETC
158 | ##########################
159 |
160 | Node = collections.namedtuple('Node', ['id', 'name'])
161 |
162 |
163 | class keydefaultdict(defaultdict):
164 | def __missing__(self, key):
165 | if self.default_factory is None:
166 | raise KeyError(key)
167 | else:
168 | ret = self[key] = self.default_factory(key)
169 | return ret
170 |
171 |
172 | def to_item(x):
173 | """Converts x, possibly scalar and possibly tensor, to a Python scalar."""
174 | if isinstance(x, (float, int)):
175 | return x
176 |
177 | if float(torch.__version__[0:3]) < 0.4:
178 | assert (x.dim() == 1) and (len(x) == 1)
179 | return x[0]
180 |
181 | return x.item()
182 |
183 |
184 | def get_logger(name=__file__, level=logging.INFO):
185 | logger = logging.getLogger(name)
186 |
187 | if getattr(logger, '_init_done__', None):
188 | logger.setLevel(level)
189 | return logger
190 |
191 | logger._init_done__ = True
192 | logger.propagate = False
193 | logger.setLevel(level)
194 |
195 | formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s")
196 | handler = logging.StreamHandler()
197 | handler.setFormatter(formatter)
198 | handler.setLevel(0)
199 |
200 | del logger.handlers[:]
201 | logger.addHandler(handler)
202 |
203 | return logger
204 |
205 |
206 | logger = get_logger()
207 |
208 |
209 | def prepare_dirs(args):
210 | """Sets the directories for the model, and creates those directories.
211 |
212 | Args:
213 | args: Parsed from `argparse` in the `config` module.
214 | """
215 | if args.load_path:
216 | if args.load_path.startswith(args.log_dir):
217 | args.model_dir = args.load_path
218 | else:
219 | if args.load_path.startswith(args.dataset):
220 | args.model_name = args.load_path
221 | else:
222 | args.model_name = "{}_{}".format(args.dataset, args.load_path)
223 | else:
224 | args.model_name = "{}_{}".format(args.dataset, get_time())
225 |
226 | if not hasattr(args, 'model_dir'):
227 | args.model_dir = os.path.join(args.log_dir, args.model_name)
228 | args.data_path = os.path.join(args.data_dir, args.dataset)
229 |
230 | for path in [args.log_dir, args.data_dir, args.model_dir]:
231 | if not os.path.exists(path):
232 | makedirs(path)
233 |
234 | def get_time():
235 | return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
236 |
237 | def save_args(args):
238 | param_path = os.path.join(args.model_dir, "params.json")
239 |
240 | logger.info("[*] MODEL dir: %s" % args.model_dir)
241 | logger.info("[*] PARAM path: %s" % param_path)
242 |
243 | with open(param_path, 'w') as fp:
244 | json.dump(args.__dict__, fp, indent=4, sort_keys=True)
245 |
246 | def save_dag(args, dag, name):
247 | save_path = os.path.join(args.model_dir, name)
248 | logger.info("[*] Save dag : {}".format(save_path))
249 | json.dump(dag, open(save_path, 'w'))
250 |
251 | def load_dag(args):
252 | load_path = os.path.join(args.dag_path)
253 | logger.info("[*] Load dag : {}".format(load_path))
254 | with open(load_path) as f:
255 | dag = json.load(f)
256 | dag = {int(k): [Node(el[0], el[1]) for el in v] for k, v in dag.items()}
257 | save_dag(args, dag, "dag.json")
258 | draw_network(dag, os.path.join(args.model_dir, "dag.png"))
259 | return dag
260 |
261 | def makedirs(path):
262 | if not os.path.exists(path):
263 | logger.info("[*] Make directories : {}".format(path))
264 | os.makedirs(path)
265 |
266 | def remove_file(path):
267 | if os.path.exists(path):
268 | logger.info("[*] Removed: {}".format(path))
269 | os.remove(path)
270 |
271 | def backup_file(path):
272 | root, ext = os.path.splitext(path)
273 | new_path = "{}.backup_{}{}".format(root, get_time(), ext)
274 |
275 | os.rename(path, new_path)
276 | logger.info("[*] {} has backup: {}".format(path, new_path))
277 |
--------------------------------------------------------------------------------
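
Finally, an illustrative shape check for `batchify()` above (assuming it is
imported from `utils`): a 1-D stream of token ids becomes a
(sequence_length, batch_size) matrix, with remainder tokens dropped:

    import torch
    from utils import batchify

    stream = torch.arange(23)             # 23 token ids
    batched = batchify(stream, 4, use_cuda=False)
    print(batched.shape)                  # torch.Size([5, 4]); 3 ids dropped
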