├── .gitignore
├── .gitmodules
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── _assets
└── framework.png
├── difflogic
├── __init__.py
├── cli.py
├── dataset
│ ├── __init__.py
│ ├── graph
│ │ ├── __init__.py
│ │ ├── dataset.py
│ │ └── family.py
│ └── utils.py
├── envs
│ ├── __init__.py
│ ├── algorithmic
│ │ ├── __init__.py
│ │ ├── quickaccess.py
│ │ └── sort_envs.py
│ ├── blocksworld
│ │ ├── __init__.py
│ │ ├── block.py
│ │ ├── envs.py
│ │ ├── quickaccess.py
│ │ └── represent.py
│ ├── graph
│ │ ├── __init__.py
│ │ ├── graph.py
│ │ ├── graph_env.py
│ │ └── quickaccess.py
│ └── utils.py
├── nn
│ ├── __init__.py
│ ├── baselines
│ │ ├── __init__.py
│ │ ├── lstm.py
│ │ └── memory_net.py
│ ├── neural_logic
│ │ ├── __init__.py
│ │ ├── layer.py
│ │ └── modules
│ │ │ ├── __init__.py
│ │ │ ├── _utils.py
│ │ │ ├── dimension.py
│ │ │ ├── input_transform.py
│ │ │ └── neural_logic.py
│ └── rl
│ │ ├── __init__.py
│ │ └── reinforce.py
├── thutils.py
├── tqdm_utils.py
└── train
│ ├── __init__.py
│ └── train.py
├── models
├── blocksworld.pth
├── path.pth
└── sort.pth
├── requirements.txt
├── scripts
├── blocksworld
│ ├── README.md
│ └── learn_policy.py
└── graph
│ ├── README.md
│ ├── learn_graph_tasks.py
│ └── learn_policy.py
├── third_party
└── js_lib
│ ├── d3.v4.js
│ └── jquery-3.3.1.js
└── vis
├── README.md
├── blocksworld.html
├── blocksworld.json
├── path.html
├── path.json
├── sort.html
└── sort.json
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
103 | # sublime workspace
104 | *.sublime-workspace
105 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/Jacinle"]
2 | path = third_party/Jacinle
3 | url = https://github.com/vacancy/Jacinle
4 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # How to Contribute
2 |
3 | We'd love to accept your patches and contributions to this project. There are
4 | just a few small guidelines you need to follow.
5 |
6 | ## Contributor License Agreement
7 |
8 | Contributions to this project must be accompanied by a Contributor License
9 | Agreement. You (or your employer) retain the copyright to your contribution;
10 | this simply gives us permission to use and redistribute your contributions as
11 | part of the project. Head over to <https://cla.developers.google.com/> to see
12 | your current agreements on file or to sign a new one.
13 |
14 | You generally only need to submit a CLA once, so if you've already submitted one
15 | (even if it was for a different project), you probably don't need to do it
16 | again.
17 |
18 | ## Code reviews
19 |
20 | All submissions, including submissions by project members, require review. We
21 | use GitHub pull requests for this purpose. Consult
22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
23 | information on using pull requests.
24 |
25 | ## Community Guidelines
26 |
27 | This project follows [Google's Open Source Community
28 | Guidelines](https://opensource.google.com/conduct/).
29 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
203 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Neural Logic Machines
2 | A PyTorch implementation of Neural Logic Machines (NLM). **Please note that this is not an officially supported Google product.**
3 |
4 |
5 | ![Framework](_assets/framework.png)
6 |
7 |
8 | The Neural Logic Machine (NLM) is a neural-symbolic architecture for both inductive learning and logic reasoning. NLMs use tensors to represent logic predicates, grounding each predicate as
9 | True or False over a fixed set of objects. Based on this tensor representation, rules are implemented
10 | as neural operators that can be applied to premise tensors to generate conclusion tensors.
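As a toy illustration of this tensor representation (a standalone NumPy sketch, not code from this repo): a binary predicate over `n` objects becomes an `n × n` tensor, and a rule with an existentially quantified variable becomes a tensor operation with a reduction.

```python
import numpy as np

# Ground a binary predicate Edge(x, y) over n = 4 objects as an n x n tensor:
# entry (x, y) is 1 iff Edge(x, y) holds in this world.
edge = np.zeros((4, 4), dtype=int)
edge[0, 1] = edge[1, 2] = edge[2, 3] = 1

# A rule with an existentially quantified variable, e.g.
#   TwoHop(x, y) <- exists z: Edge(x, z) AND Edge(z, y),
# becomes a tensor operation: AND is an elementwise product, and the
# existential quantifier over z is a max-reduction along the z axis.
two_hop = (edge[:, :, None] * edge[None, :, :]).max(axis=1)

print(two_hop)  # 1 exactly at (0, 2) and (1, 3)
```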
11 |
12 | **[Neural Logic Machines](https://arxiv.org/pdf/1904.11694.pdf)**
13 |
14 | [Honghua Dong](http://dhh1995.github.io)\*,
15 | [Jiayuan Mao](http://jiayuanm.com)\*,
16 | [Tian Lin](https://www.linkedin.com/in/tianl),
17 | [Chong Wang](https://chongw.github.io/),
18 | [Lihong Li](https://lihongli.github.io/), and
19 | [Denny Zhou](https://dennyzhou.github.io/)
20 |
21 | (\*: indicates equal contribution.)
22 |
23 | In International Conference on Learning Representations (ICLR) 2019
24 |
25 | [[Paper]](https://arxiv.org/pdf/1904.11694.pdf)
26 | [[Project Page]](https://sites.google.com/view/neural-logic-machines)
27 |
28 | ```
29 | @inproceedings{
30 | dong2018neural,
31 | title = {Neural Logic Machines},
32 | author = {Honghua Dong and Jiayuan Mao and Tian Lin and Chong Wang and Lihong Li and Denny Zhou},
33 | booktitle = {International Conference on Learning Representations},
34 | year = {2019},
35 | url = {https://openreview.net/forum?id=B1xY-hRctX},
36 | }
37 | ```
38 |
39 | ## Prerequisites
40 | * Python 3
41 | * PyTorch 0.4.0
42 | * [Jacinle](https://github.com/vacancy/Jacinle). We use the version [ed90c3a](https://github.com/vacancy/Jacinle/tree/ed90c3a70a133eb9c6c2f4ea2cc3d907de7ffd57) for this repo.
43 | * Other required Python packages are specified in `requirements.txt`; see Installation below.
44 |
45 | ## Installation
46 |
47 | Clone this repository:
48 |
49 | ```
50 | git clone https://github.com/google/neural-logic-machines --recursive
51 | ```
52 |
53 | Install [Jacinle](https://github.com/vacancy/Jacinle), which is included as a submodule. You need to add its bin path to your global `PATH` environment variable:
54 |
55 | ```
56 | export PATH=<path_to_this_repo>/third_party/Jacinle/bin:$PATH
57 | ```
58 |
59 | Create a conda environment for NLM and install the requirements. This includes the Python packages required
60 | by both Jacinle and NLM; most of them are already included in the `anaconda` meta-package:
61 |
62 | ```
63 | conda create -n nlm anaconda
64 | conda install pytorch torchvision -c pytorch
65 | ```
66 |
67 | ## Usage
68 |
69 | This repo contains 10 graph-related reasoning tasks (using supervised learning)
70 | and 3 decision-making tasks (using reinforcement learning).
71 |
72 | We also provide pre-trained models for the 3 decision-making tasks in the [models](models) directory.
73 |
74 | Take the [Blocks World](scripts/blocksworld) task as an example.
75 |
76 | ``` shell
77 | # To train the model:
78 | $ jac-run scripts/blocksworld/learn_policy.py --task final
79 | # To test the model:
80 | $ jac-run scripts/blocksworld/learn_policy.py --task final --test-only --load models/blocksworld.pth
81 | # Add [--test-epoch-size T] to control the number of test cases.
82 | # E.g., T=20 gives a quick test, which usually takes ~2 minutes on CPUs.
83 | # Sample output of testing for number=10 and number=50:
84 | > Evaluation:
85 | length = 12.500000
86 | number = 10.000000
87 | score = 0.885000
88 | succ = 1.000000
89 | > Evaluation:
90 | length = 85.800000
91 | number = 50.000000
92 | score = 0.152000
93 | succ = 1.000000
94 | ```
95 |
96 | Please refer to the [graph](scripts/graph) directory for training/inference details of other tasks.
97 |
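For reference, each sample produced by the supervised graph datasets in `difflogic/dataset/graph/dataset.py` is a plain dict of NumPy arrays. Below is a minimal standalone sketch of the out-degree task's sample format; the adjacency matrix is made up purely for illustration.

```python
import numpy as np

# A hypothetical 4-node directed graph as an adjacency matrix:
# edges[i, j] == 1 iff there is an edge from node i to node j.
edges = np.array([
    [0, 1, 1, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 0],
    [1, 0, 0, 0],
])

# GraphOutDegreeDataset-style sample: the target marks, for each node x,
# whether out-degree(x) equals the queried degree.
degree = 1
sample = dict(
    n=edges.shape[0],
    relations=np.expand_dims(edges, axis=-1),  # shape (n, n, 1)
    target=(edges.sum(axis=1) == degree).astype('float'),
)

print(sample['target'])  # nodes 1 and 3 have out-degree exactly 1
```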
98 | ### Useful Command-line options
99 | - Use `jac-crun GPU_ID FILE --use-gpu GPU_ID` instead of `jac-run FILE` to run on the GPU with id `GPU_ID`.
100 | - `--model {nlm, memnet}` [default: `nlm`]: choose `memnet` to use [Memory Networks](https://arxiv.org/abs/1503.08895) as the baseline.
101 | - `--runs N`: perform `N` runs.
102 | - `--dump-dir DUMP_DIR`: the directory to dump logs/summaries/checkpoints/plays to.
103 | - `--dump-play`: dump plays in JSON format for visualization with our [html visualizer](vis). (Not applicable to [graph tasks](scripts/graph/learn_graph_tasks.py).)
104 | - `--test-number-begin B --test-number-step S --test-number-end E`: \
105 |   defines the range of sizes of the test instances.
106 | - `--test-epoch-size SIZE`: the number of test instances.
107 |
108 | For the complete list of command-line options, see `jac-run FILE -h` (e.g. `jac-run scripts/blocksworld/learn_policy.py -h`).
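For example, several of these options can be combined in one invocation; the task name and numeric values below are hypothetical, purely to show the flag syntax:

``` shell
$ jac-crun 0 scripts/graph/learn_policy.py --task path --use-gpu 0 --runs 3 \
    --dump-dir dumps/path --dump-play \
    --test-number-begin 10 --test-number-step 10 --test-number-end 50 \
    --test-epoch-size 100
```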
109 |
--------------------------------------------------------------------------------
/_assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/_assets/framework.png
--------------------------------------------------------------------------------
/difflogic/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/difflogic/cli.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Command line interface, print or format args as string."""
17 |
18 | from jacinle.utils.printing import kvformat
19 | from jacinle.utils.printing import kvprint
20 |
21 |
22 | def print_args(args):
23 | kvprint(args.__dict__)
24 |
25 |
26 | def format_args(args):
27 | return kvformat(args.__dict__)
28 |
--------------------------------------------------------------------------------
/difflogic/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/difflogic/dataset/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .dataset import *
18 |
--------------------------------------------------------------------------------
/difflogic/dataset/graph/dataset.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement datasets classes for graph and family tree tasks."""
17 |
18 | import numpy as np
19 |
20 | from torch.utils.data.dataset import Dataset
21 | from torchvision import datasets
22 |
23 | import jacinle.random as random
24 |
25 | from .family import randomly_generate_family
26 | from ...envs.graph import get_random_graph_generator
27 |
28 | __all__ = [
29 | 'GraphOutDegreeDataset', 'GraphConnectivityDataset', 'GraphAdjacentDataset',
30 | 'FamilyTreeDataset'
31 | ]
32 |
33 |
34 | class GraphDatasetBase(Dataset):
35 | """Base dataset class for graphs.
36 |
37 |   Args:
38 |     epoch_size: The number of batches in each epoch.
39 |     nmin: The minimal number of nodes in the graph.
40 |     pmin: The lower bound of the parameter p of the graph generator.
41 |     nmax: The maximal number of nodes in the graph;
42 |         defaults to nmin.
43 |     pmax: The upper bound of the parameter p of the graph generator;
44 |         defaults to pmin.
45 |     directed: Generate directed graphs if directed=True.
46 |     gen_method: Controls the graph generation method. If gen_method='dnc',
47 |         generate graphs in a way similar to the DNC paper; otherwise use the
48 |         Erdos-Renyi model (each edge exists independently with probability p).
49 |   """
50 |
51 | def __init__(self,
52 | epoch_size,
53 | nmin,
54 | pmin,
55 | nmax=None,
56 | pmax=None,
57 | directed=False,
58 | gen_method='dnc'):
59 | self._epoch_size = epoch_size
60 | self._nmin = nmin
61 | self._nmax = nmin if nmax is None else nmax
62 | assert self._nmin <= self._nmax
63 | self._pmin = pmin
64 | self._pmax = pmin if pmax is None else pmax
65 | assert self._pmin <= self._pmax
66 | self._directed = directed
67 | self._gen_method = gen_method
68 |
69 | def _gen_graph(self, item):
70 | n = self._nmin + item % (self._nmax - self._nmin + 1)
71 | p = self._pmin + random.rand() * (self._pmax - self._pmin)
72 | gen = get_random_graph_generator(self._gen_method)
73 | return gen(n, p, directed=self._directed)
74 |
75 | def __len__(self):
76 | return self._epoch_size
77 |
78 |
79 | class GraphOutDegreeDataset(GraphDatasetBase):
80 | """The dataset for out-degree task in graphs."""
81 |
82 | def __init__(self,
83 | degree,
84 | epoch_size,
85 | nmin,
86 | pmin,
87 | nmax=None,
88 | pmax=None,
89 | directed=False,
90 | gen_method='dnc'):
91 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method)
92 | self._degree = degree
93 |
94 | def __getitem__(self, item):
95 | graph = self._gen_graph(item)
96 | # The goal is to predict whether out-degree(x) == self._degree for all x.
97 | return dict(
98 | n=graph.nr_nodes,
99 | relations=np.expand_dims(graph.get_edges(), axis=-1),
100 | target=(graph.get_out_degree() == self._degree).astype('float'),
101 | )
102 |
103 |
104 | class GraphConnectivityDataset(GraphDatasetBase):
105 | """The dataset for connectivity task in graphs."""
106 |
107 | def __init__(self,
108 | dist_limit,
109 | epoch_size,
110 | nmin,
111 | pmin,
112 | nmax=None,
113 | pmax=None,
114 | directed=False,
115 | gen_method='dnc'):
116 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method)
117 | self._dist_limit = dist_limit
118 |
119 | def __getitem__(self, item):
120 | graph = self._gen_graph(item)
121 |     # The goal is to predict whether (x, y) are connected within a limited
122 |     # number of steps, i.e. dist(x, y) <= self._dist_limit, for all pairs (x, y).
123 | return dict(
124 | n=graph.nr_nodes,
125 | relations=np.expand_dims(graph.get_edges(), axis=-1),
126 | target=graph.get_connectivity(self._dist_limit, exclude_self=True),
127 | )
128 |
129 |
130 | class GraphAdjacentDataset(GraphDatasetBase):
131 | """The dataset for adjacent task in graphs."""
132 |
133 | def __init__(self,
134 | nr_colors,
135 | epoch_size,
136 | nmin,
137 | pmin,
138 | nmax=None,
139 | pmax=None,
140 | directed=False,
141 | gen_method='dnc',
142 | is_train=True,
143 | is_mnist_colors=False,
144 | mnist_dir='../data'):
145 |
146 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method)
147 | self._nr_colors = nr_colors
148 | self._is_mnist_colors = is_mnist_colors
149 | # When taking MNIST digits as inputs, fetch MNIST dataset.
150 | if self._is_mnist_colors:
151 | assert nr_colors == 10
152 | self.mnist = datasets.MNIST(
153 | mnist_dir, train=is_train, download=True, transform=None)
154 |
155 | def __getitem__(self, item):
156 | graph = self._gen_graph(item)
157 | n = graph.nr_nodes
158 | if self._is_mnist_colors:
159 |       m = len(self.mnist)
160 | digits = []
161 | colors = []
162 | for i in range(n):
163 | x = random.randint(m)
164 |         digit, color = self.mnist[x]
165 | digits.append(np.array(digit)[np.newaxis])
166 | colors.append(color)
167 | digits, colors = np.array(digits), np.array(colors)
168 | else:
169 | colors = random.randint(self._nr_colors, size=n)
170 | states = np.zeros((n, self._nr_colors))
171 | adjacent = np.zeros((n, self._nr_colors))
172 |     # The goal is to predict, for each node x and each color c, whether x
173 |     # or one of its adjacent nodes has color c.
174 | for i in range(n):
175 | states[i, colors[i]] = 1
176 | adjacent[i, colors[i]] = 1
177 | for j in range(n):
178 | if graph.has_edge(i, j):
179 | adjacent[i, colors[j]] = 1
180 | if self._is_mnist_colors:
181 | states = digits
182 | return dict(
183 | n=n,
184 | relations=np.expand_dims(graph.get_edges(), axis=-1),
185 | states=states,
186 | colors=colors,
187 | target=adjacent,
188 | )
189 |
190 |
191 | class FamilyTreeDataset(Dataset):
192 | """The dataset for family tree tasks."""
193 |
194 | def __init__(self,
195 | task,
196 | epoch_size,
197 | nmin,
198 | nmax=None,
199 | p_marriage=0.8,
200 | balance_sample=False):
201 | super().__init__()
202 | self._task = task
203 | self._epoch_size = epoch_size
204 | self._nmin = nmin
205 | self._nmax = nmin if nmax is None else nmax
206 | assert self._nmin <= self._nmax
207 | self._p_marriage = p_marriage
208 | self._balance_sample = balance_sample
209 | self._data = []
210 |
211 | def _gen_family(self, item):
212 | n = self._nmin + item % (self._nmax - self._nmin + 1)
213 | return randomly_generate_family(n, self._p_marriage)
214 |
215 | def __getitem__(self, item):
216 | while len(self._data) == 0:
217 | family = self._gen_family(item)
218 | relations = family.relations[:, :, 2:]
219 | if self._task == 'has-father':
220 | target = family.has_father()
221 | elif self._task == 'has-daughter':
222 | target = family.has_daughter()
223 | elif self._task == 'has-sister':
224 | target = family.has_sister()
225 | elif self._task == 'parents':
226 | target = family.get_parents()
227 | elif self._task == 'grandparents':
228 | target = family.get_grandparents()
229 | elif self._task == 'uncle':
230 | target = family.get_uncle()
231 | elif self._task == 'maternal-great-uncle':
232 | target = family.get_maternal_great_uncle()
233 | else:
234 | assert False, '{} is not supported.'.format(self._task)
235 |
236 | if not self._balance_sample:
237 | return dict(n=family.nr_people, relations=relations, target=target)
238 |
239 |       # In the balance_sample case, the data format is different. This branch is unused.
240 | def get_positions(x):
241 | return list(np.vstack(np.where(x)).T)
242 |
243 | def append_data(pos, target):
244 | states = np.zeros((family.nr_people, 2))
245 | states[pos[0], 0] = states[pos[1], 1] = 1
246 | self._data.append(dict(n=family.nr_people,
247 | relations=relations,
248 | states=states,
249 | target=target))
250 |
251 | positive = get_positions(target == 1)
252 | if len(positive) == 0:
253 | continue
254 | negative = get_positions(target == 0)
255 | np.random.shuffle(negative)
256 | negative = negative[:len(positive)]
257 | for i in positive:
258 | append_data(i, 1)
259 | for i in negative:
260 | append_data(i, 0)
261 |
262 | return self._data.pop()
263 |
264 | def __len__(self):
265 | return self._epoch_size
266 |
--------------------------------------------------------------------------------
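The `__getitem__` above builds a one-hot `states` matrix and an `adjacent` target that marks, for each node, which colors appear on the node itself or on any of its neighbors. A standalone numpy sketch of that target construction (the helper name `build_targets` and the tiny 3-node graph are made up for illustration):

```python
import numpy as np

def build_targets(colors, edges, nr_colors):
    """Build one-hot states and the adjacent-color target.

    colors: length-n array of color ids.
    edges: n x n 0/1 adjacency matrix.
    adjacent[i, c] == 1 iff node i, or one of its neighbors, has color c
    (a node counts as adjacent to itself, matching the dataset code above).
    """
    n = len(colors)
    states = np.zeros((n, nr_colors))
    adjacent = np.zeros((n, nr_colors))
    for i in range(n):
        states[i, colors[i]] = 1
        adjacent[i, colors[i]] = 1
        for j in range(n):
            if edges[i, j]:
                adjacent[i, colors[j]] = 1
    return states, adjacent

# Tiny example: 3 nodes, a single edge 0-1, colors [0, 1, 1].
colors = np.array([0, 1, 1])
edges = np.array([[0, 1, 0],
                  [1, 0, 0],
                  [0, 0, 0]])
states, adjacent = build_targets(colors, edges, nr_colors=2)
```

Node 2 has no neighbors, so its target only records its own color.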
/difflogic/dataset/graph/family.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement family tree generator and family tree class."""
17 |
18 | import jacinle.random as random
19 | import numpy as np
20 |
21 | __all__ = ['Family', 'randomly_generate_family']
22 |
23 |
24 | class Family(object):
25 | """Family tree class to support queries about relations between family members.
26 |
27 | Args:
28 | nr_people: The number of people in the family tree.
 29 |     relations: The relations between family members, a matrix of shape
 30 |       [nr_people, nr_people, 6]. The six relations are, in order:
 31 |       husband, wife, father, mother, son, daughter.
32 | """
33 |
34 | def __init__(self, nr_people, relations):
35 | self._n = nr_people
36 | self._relations = relations
37 |
38 | def mul(self, x, y):
39 | return np.clip(np.matmul(x, y), 0, 1)
40 |
41 | @property
42 | def nr_people(self):
43 | return self._n
44 |
45 | @property
46 | def relations(self):
47 | return self._relations
48 |
49 | @property
50 | def father(self):
51 | return self._relations[:, :, 2]
52 |
53 | @property
54 | def mother(self):
55 | return self._relations[:, :, 3]
56 |
57 | @property
58 | def son(self):
59 | return self._relations[:, :, 4]
60 |
61 | @property
62 | def daughter(self):
63 | return self._relations[:, :, 5]
64 |
65 | def has_father(self):
66 | return self.father.max(axis=1)
67 |
68 | def has_daughter(self):
69 | return self.daughter.max(axis=1)
70 |
71 | def has_sister(self):
72 | daughter_cnt = self.daughter.sum(axis=1)
73 | is_daughter = np.clip(self.daughter.sum(axis=0), 0, 1)
74 | return ((np.matmul(self.father, daughter_cnt) - is_daughter) >
75 | 0).astype('float')
 76 |     # The wrong implementation below would count a woman as her own sister.
77 | # return self.mul(self.father, self.daughter).max(axis=1)
78 |
79 | def get_parents(self):
80 | return np.clip(self.father + self.mother, 0, 1)
81 |
82 | def get_grandfather(self):
83 | return self.mul(self.get_parents(), self.father)
84 |
85 | def get_grandmother(self):
86 | return self.mul(self.get_parents(), self.mother)
87 |
88 | def get_grandparents(self):
89 | parents = self.get_parents()
90 | return self.mul(parents, parents)
91 |
92 | def get_uncle(self):
93 | return np.clip(self.mul(self.get_grandparents(), self.son) - self.father, 0, 1)
 94 |     # The wrong implementation below does not exclude the father.
95 | # return self.mul(self.get_grandparents(), self.son)
96 |
97 | def get_maternal_great_uncle(self):
98 | return self.mul(self.mul(self.get_grandmother(), self.mother), self.son)
99 |
100 |
101 | def randomly_generate_family(n, p_marriage=0.8, verbose=False):
102 | """Randomly generate family trees.
103 |
104 |   Mimic the process of a family growing over time. Each time a new person is
105 |   created, randomly sample the gender and the parents (which may be none,
106 |   indicating the parents are outside the family tree). Also maintain a list
107 |   of singles of each gender. With probability $p_marriage, randomly pick one
108 |   person from each list to be married. Finally, randomly permute the people.
109 |
110 |   Args:
111 |     n: The number of people in the family tree.
112 |     p_marriage: The probability that a marriage happens at each step.
113 |     verbose: Print the marriage and child-birth events if verbose=True.
114 | Returns:
115 | A family tree instance of $n people.
116 | """
117 | assert n > 0
118 | ids = list(random.permutation(n))
119 |
120 | single_m = []
121 | single_w = []
122 | couples = [None]
123 | # The relations are: husband, wife, father, mother, son, daughter
124 | rel = np.zeros((n, n, 6))
125 | fathers = [None for i in range(n)]
126 | mothers = [None for i in range(n)]
127 |
128 | def add_couple(man, woman):
129 | """Add a couple relation among (man, woman)."""
130 | couples.append((man, woman))
131 | rel[woman, man, 0] = 1 # husband
132 | rel[man, woman, 1] = 1 # wife
133 | if verbose:
134 | print('couple', man, woman)
135 |
136 | def add_child(parents, child, gender):
137 | """Add a child relation between parents and the child according to gender."""
138 | father, mother = parents
139 | fathers[child] = father
140 | mothers[child] = mother
141 | rel[child, father, 2] = 1 # father
142 | rel[child, mother, 3] = 1 # mother
143 | if gender == 0: # son
144 | rel[father, child, 4] = 1
145 | rel[mother, child, 4] = 1
146 | else: # daughter
147 | rel[father, child, 5] = 1
148 | rel[mother, child, 5] = 1
149 | if verbose:
150 | print('child', father, mother, child, gender)
151 |
152 | def check_relations(man, woman):
153 |     """Forbid marriage between close relatives (siblings and cousins)."""
154 | if fathers[man] is None or fathers[woman] is None:
155 | return True
156 | if fathers[man] == fathers[woman]:
157 | return False
158 |
159 | def same_parent(x, y):
160 | return fathers[x] is not None and fathers[y] is not None and fathers[
161 | x] == fathers[y]
162 |
163 | for x in [fathers[man], mothers[man]]:
164 | for y in [fathers[woman], mothers[woman]]:
165 | if same_parent(man, y) or same_parent(woman, x) or same_parent(x, y):
166 | return False
167 | return True
168 |
169 | while ids:
170 | x = ids.pop()
171 | gender = random.randint(2)
172 | parents = random.choice(couples)
173 | if gender == 0:
174 | single_m.append(x)
175 | else:
176 | single_w.append(x)
177 | if parents is not None:
178 | add_child(parents, x, gender)
179 |
180 | if random.rand() < p_marriage and len(single_m) > 0 and len(single_w) > 0:
181 | mi = random.randint(len(single_m))
182 | wi = random.randint(len(single_w))
183 | man = single_m[mi]
184 | woman = single_w[wi]
185 | if check_relations(man, woman):
186 | add_couple(man, woman)
187 | del single_m[mi]
188 | del single_w[wi]
189 |
190 | return Family(n, rel)
191 |
192 |
--------------------------------------------------------------------------------
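The `Family` queries above compose relations with a clipped matrix product: for example, the grandfather relation is the father relation applied twice. A minimal numpy sketch of that composition rule on a hand-built 3-person chain (the indices are chosen purely for illustration):

```python
import numpy as np

def compose(x, y):
    # Boolean relation composition: (x . y)[i, k] == 1 iff there is
    # some j with x[i, j] == 1 and y[j, k] == 1. The clip keeps the
    # result a 0/1 matrix even when several j's match.
    return np.clip(np.matmul(x, y), 0, 1)

# 3 people: person 1 is the father of person 0,
#           person 2 is the father of person 1.
father = np.zeros((3, 3))
father[0, 1] = 1
father[1, 2] = 1

# Father-of-father gives the grandfather relation.
grandfather = compose(father, father)
```

The same pattern underlies `get_grandparents`, `get_uncle`, and `get_maternal_great_uncle` above, with subtraction used to exclude unwanted members (e.g. the father from the uncles).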
/difflogic/dataset/utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Utility functions for custom datasets."""
17 |
18 | import collections
19 | import copy
20 | import jacinle.random as random
21 | import numpy as np
22 |
23 | __all__ = [
24 | 'ValidActionDataset',
25 | 'RandomlyIterDataset',
26 | ]
27 |
28 |
29 | class ValidActionDataset(object):
30 | """Collect data and sample batches of whether actions are valid or not.
31 |
 32 |   The data are collected into bins according to the number of objects in each
 33 |   case; at most $maxn bins are maintained. Each bin has a fixed capacity: when
 34 |   a bin is full, newly added examples replace the earliest ones in the bin
 35 |   (implemented using a bounded deque).
36 |
37 | Args:
38 | capacity: upper bound of the number of training instances in each bin.
39 | maxn: The maximum number of objects in the collected cases.
40 | """
41 |
42 | def __init__(self, capacity=5000, maxn=30):
43 | super().__init__()
44 | self.largest_n = 0
45 | self.maxn = maxn
 46 |     # It would be better to use a defaultdict here instead of a list.
47 | self.data = [[collections.deque(maxlen=capacity)
48 | for i in range(2)]
49 | for j in range(maxn + 1)]
50 |
51 | def append(self, n, state, action, valid):
 52 |     """Add a data point: in a case of n objects with $state, $action is $valid."""
53 | assert n <= self.maxn
54 | valid = int(valid)
55 | self.data[n][valid].append((state, action))
56 | self.largest_n = max(self.largest_n, n)
57 |
58 | def _sample(self, data, num, label):
59 | """Sample a batch of size $num from the data, with already determined label."""
60 | # assert num <= len(data)
61 | states, actions = [], []
62 | for _ in range(num):
63 | ind = random.randint(len(data))
64 | state, action = data[ind]
65 | states.append(state)
66 | actions.append([action])
67 | return np.array(states), np.array(actions), np.ones((num,)) * label
68 |
69 | def sample_batch(self, batch_size, n=None):
70 | """Sample a batch of data for $n objects."""
 71 |     # By default, use the bin with the largest number of objects.
72 | if n is None:
73 | n = self.largest_n
74 | data = self.data[n]
 75 |     # The numbers of positive and negative examples are not strictly equal
 76 |     # when batch_size % 2 != 0; a warning should be added.
77 | num = batch_size // 2
 78 |     # If there are no negative examples, use all positive ones.
79 | c = 1 - int(len(data[0]) > 0)
80 | states1, actions1, labels1 = self._sample(data[c], num, c)
 81 |     # If there are no positive examples, use all negative ones.
82 | c = int(len(data[1]) > 0)
83 | states2, actions2, labels2 = self._sample(data[c], batch_size - num, c)
84 | return (np.vstack([states1, states2]),
85 | np.vstack([actions1, actions2]).squeeze(axis=-1),
86 | np.concatenate([labels1, labels2], axis=0))
87 |
88 |
89 | class RandomlyIterDataset(object):
90 | """Collect data and iterate the dataset in random order."""
91 |
92 | def __init__(self):
93 | super().__init__()
94 | self.data = []
95 | self.ind = 0
96 |
97 | def append(self, data):
98 | self.data.append(data)
99 |
100 | @property
101 | def size(self):
102 | return len(self.data)
103 |
104 | def reset(self):
105 | self.ind = 0
106 |
107 | def get(self):
108 |     """Iterate over the dataset in random order."""
109 |     # Shuffle before each pass over the data begins.
110 |     # Iteration is best kept separate from collection.
111 | if self.ind == 0:
112 | random.shuffle(self.data)
113 | ret = self.data[self.ind]
114 | self.ind += 1
115 | if self.ind == self.size:
116 | self.ind = 0
117 | return copy.deepcopy(ret)
118 |
--------------------------------------------------------------------------------
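`sample_batch` above splits each batch between the negative and positive bins, but silently falls back to a single class when one bin is empty. That label-selection logic can be sketched in isolation (`pick_labels` is a hypothetical helper that only mirrors the two fallback expressions above; it does not sample actual states):

```python
def pick_labels(neg, pos, batch_size):
    """Return the label (0/1) assigned to each slot of a batch.

    neg, pos: the negative and positive bins (any sized containers).
    Half the batch comes from the negative bin and half from the
    positive bin; if a bin is empty, its half falls back to the other
    class, mirroring `c = 1 - int(len(data[0]) > 0)` and
    `c = int(len(data[1]) > 0)` in sample_batch above.
    """
    num = batch_size // 2
    first = 0 if len(neg) > 0 else 1   # first half: negatives if any exist
    second = 1 if len(pos) > 0 else 0  # second half: positives if any exist
    return [first] * num + [second] * (batch_size - num)
```

Note the rounding: with an odd `batch_size`, the second half gets one extra slot, which is exactly the imbalance the comment in `sample_batch` warns about.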
/difflogic/envs/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
--------------------------------------------------------------------------------
/difflogic/envs/algorithmic/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .sort_envs import *
18 | from .quickaccess import *
19 |
--------------------------------------------------------------------------------
/difflogic/envs/algorithmic/quickaccess.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Quick access for algorithmic environments."""
17 |
18 | from jaclearn.rl.proxy import LimitLengthProxy
19 |
20 | from .sort_envs import ListSortingEnv
21 | from ..utils import get_action_mapping_sorting
22 | from ..utils import MapActionProxy
23 |
24 | __all__ = ['get_sort_env', 'make']
25 |
26 |
27 | def get_sort_env(n, exclude_self=True):
28 | env_cls = ListSortingEnv
29 | p = env_cls(n)
30 | p = LimitLengthProxy(p, n * 2)
31 | mapping = get_action_mapping_sorting(n, exclude_self=exclude_self)
32 | p = MapActionProxy(p, mapping)
33 | return p
34 |
35 |
36 | def make(task, *args, **kwargs):
37 | if task == 'sort':
38 | return get_sort_env(*args, **kwargs)
39 | else:
40 | raise ValueError('Unknown task: {}.'.format(task))
41 |
--------------------------------------------------------------------------------
/difflogic/envs/algorithmic/sort_envs.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The environment class for sorting tasks."""
17 |
18 | import numpy as np
19 |
20 | import jacinle.random as random
21 | from jacinle.utils.meta import notnone_property
22 | from jaclearn.rl.env import SimpleRLEnvBase
23 |
24 |
25 | class ListSortingEnv(SimpleRLEnvBase):
26 | """Environment for sorting a random permutation.
27 |
28 | Args:
29 | nr_numbers: The number of numbers in the array.
30 | """
31 |
32 | def __init__(self, nr_numbers):
33 | super().__init__()
34 | self._nr_numbers = nr_numbers
35 | self._array = None
36 |
37 | @notnone_property
38 | def array(self):
39 | return self._array
40 |
41 | @property
42 | def nr_numbers(self):
43 | return self._nr_numbers
44 |
45 | def get_state(self):
46 | """Compute the state given the array."""
47 | x, y = np.meshgrid(self.array, self.array)
48 | number_relations = np.stack([x < y, x == y, x > y], axis=-1).astype('float')
49 | index = np.array(list(range(self._nr_numbers)))
50 | x, y = np.meshgrid(index, index)
51 | position_relations = np.stack([x < y, x == y, x > y],
52 | axis=-1).astype('float')
53 | return np.concatenate([number_relations, position_relations], axis=-1)
54 |
55 | def _calculate_optimal(self):
56 | """Calculate the optimal number of steps for sorting the array."""
57 | a = self._array
58 | b = [0 for i in range(len(a))]
59 | cnt = 0
60 | for i, x in enumerate(a):
61 | if b[i] == 0:
62 | j = x
63 | b[i] = 1
64 | while b[j] == 0:
65 | b[j] = 1
66 | j = a[j]
67 | assert i == j
68 | cnt += 1
69 | return len(a) - cnt
70 |
71 | def _restart(self):
72 | """Restart: Generate a random permutation."""
73 | self._array = random.permutation(self._nr_numbers)
74 | self._set_current_state(self.get_state())
75 | self.optimal = self._calculate_optimal()
76 |
77 | def _action(self, action):
 78 |     """The action is a tuple (i, j); performing it swaps a[i] and a[j]."""
79 | a = self._array
80 | i, j = action
81 | x, y = a[i], a[j]
82 | a[i], a[j] = y, x
83 | self._set_current_state(self.get_state())
84 | for i in range(self._nr_numbers):
85 | if a[i] != i:
86 | return 0, False
87 | return 1, True
88 |
--------------------------------------------------------------------------------
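`_calculate_optimal` above relies on the classic fact that the minimum number of swaps needed to sort a permutation equals its length minus its number of cycles. A standalone sketch of the same computation, with the cycle counting made explicit (`min_swaps_to_sort` is an illustrative name, not part of the codebase):

```python
def min_swaps_to_sort(a):
    """Minimum number of swaps to turn permutation a into the identity.

    Each cycle of length k needs k - 1 swaps, so the total over all
    cycles is len(a) minus the number of cycles.
    """
    visited = [False] * len(a)
    cycles = 0
    for i in range(len(a)):
        if not visited[i]:
            cycles += 1
            # Walk the cycle starting at i, marking every member visited.
            j = i
            while not visited[j]:
                visited[j] = True
                j = a[j]
    return len(a) - cycles
```

This is the value stored in `self.optimal` at restart, used to judge how far a policy is from the shortest solution.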
/difflogic/envs/blocksworld/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .block import *
18 | from .represent import *
19 | from .envs import *
20 | from .quickaccess import *
21 |
--------------------------------------------------------------------------------
/difflogic/envs/blocksworld/block.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement random blocksworld generator and BlocksWorld class."""
17 |
18 | import jacinle.random as random
19 |
20 | __all__ = ['Block', 'BlocksWorld', 'randomly_generate_world']
21 |
22 |
23 | class Block(object):
 24 |   """A single block in the blocks world, stored in a tree structure."""
25 |
26 | def __init__(self, index, father=None):
27 | self.index = index
28 | self.father = father
29 | self.children = []
30 |
31 | @property
32 | def is_ground(self):
33 | return self.father is None
34 |
35 | @property
36 | def placeable(self):
37 | if self.is_ground:
38 | return True
39 | return len(self.children) == 0
40 |
41 | @property
42 | def moveable(self):
43 | if self.is_ground:
44 | return False
45 | return len(self.children) == 0
46 |
47 | def remove_from_father(self):
48 | assert self in self.father.children
49 | self.father.children.remove(self)
50 | self.father = None
51 |
52 | def add_to(self, other):
53 | self.father = other
54 | other.children.append(self)
55 |
56 |
57 | class BlockStorage(object):
58 | """The storage of blocks with an order.
59 |
60 | Args:
61 | blocks: The list of instances of Block class.
 62 |     random_order: An optional permutation of the blocks; None leaves the order unchanged.
63 | """
64 |
65 | def __init__(self, blocks, random_order=None):
66 | super().__init__()
67 | self._blocks = blocks
68 | self.set_random_order(random_order)
69 |
70 | def __getitem__(self, item):
71 | if self._random_order is None:
72 | return self._blocks[item]
73 | return self._blocks[self._random_order[item]]
74 |
75 | def __len__(self):
76 | return len(self._blocks)
77 |
78 | @property
79 | def raw(self):
80 | return self._blocks.copy()
81 |
82 | @property
83 | def random_order(self):
84 | return self._random_order
85 |
86 | def set_random_order(self, random_order):
87 | if random_order is None:
88 | self._random_order = None
89 | self._inv_random_order = None
90 | return
91 |
92 | self._random_order = random_order
93 | self._inv_random_order = sorted(
94 | range(len(random_order)), key=lambda x: random_order[x])
95 |
96 | def index(self, i):
97 | if self._random_order is None:
98 | return i
99 | return self._random_order[i]
100 |
101 | def inv_index(self, i):
102 | if self._random_order is None:
103 | return i
104 | return self._inv_random_order[i]
105 |
106 | def permute(self, array):
107 | if self._random_order is None:
108 | return array
109 | return [array[self._random_order[i]] for i in range(len(self._blocks))]
110 |
111 |
112 | class BlocksWorld(object):
113 | """The blocks world class implement queries and movements."""
114 |
115 | def __init__(self, blocks, random_order=None):
116 | super().__init__()
117 | self.blocks = BlockStorage(blocks, random_order)
118 |
119 | @property
120 | def size(self):
121 | return len(self.blocks)
122 |
123 | def move(self, x, y):
124 | if x != y and self.moveable(x, y):
125 | self.blocks[x].remove_from_father()
126 | self.blocks[x].add_to(self.blocks[y])
127 |
128 | def moveable(self, x, y):
129 | return self.blocks[x].moveable and self.blocks[y].placeable
130 |
131 |
132 | def randomly_generate_world(nr_blocks, random_order=False, one_stack=False):
133 | """Randomly generate a blocks world case.
134 |
135 |   Similar to classical random tree generation, incrementally add new blocks:
136 |   for each new block, randomly sample a valid father and stack the block on it.
137 |
138 | Args:
139 | nr_blocks: The number of blocks in the world.
140 |     random_order: If True, randomly permute the indexes of the blocks;
141 |       otherwise leave the raw order unchanged (the default).
142 |     one_stack: A special case with only one stack of blocks. If True, each
143 |       new node takes the most recently added node as its father.
144 |
145 | Returns:
146 | A BlocksWorld instance which is randomly generated.
147 | """
148 | blocks = [Block(0, None)]
149 | leafs = [blocks[0]]
150 |
151 | for i in range(1, nr_blocks + 1):
152 | other = random.choice_list(leafs)
153 | this = Block(i)
154 | this.add_to(other)
155 | if not other.placeable or one_stack:
156 | leafs.remove(other)
157 | blocks.append(this)
158 | leafs.append(this)
159 |
160 | order = None
161 | if random_order:
162 | order = random.permutation(len(blocks))
163 |
164 | return BlocksWorld(blocks, random_order=order)
165 |
--------------------------------------------------------------------------------
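`randomly_generate_world` above keeps a list of stackable positions (`leafs`) and stacks each new block on a random one, removing a non-ground target once it is occupied (the ground can hold any number of stacks). A plain-list sketch of the same invariant (`random_world` and the seed are illustrative, and it uses Python's stdlib RNG rather than `jacinle.random`):

```python
import random

def random_world(nr_blocks, seed=0):
    """Sketch of the generator above with plain lists.

    Block 0 is the ground (always placeable); each new block is stacked
    on a random block that currently has nothing on top of it, or on the
    ground. Returns the father index of each block (None for the ground).
    """
    rng = random.Random(seed)
    father = [None]   # block 0 is the ground
    leafs = [0]       # positions that can still be stacked on
    for i in range(1, nr_blocks + 1):
        target = rng.choice(leafs)
        father.append(target)
        if target != 0:        # the ground stays placeable forever
            leafs.remove(target)
        leafs.append(i)        # the new block is now a free top
    return father

world = random_world(5)
```

The resulting structure is a forest of stacks rooted at the ground, which is exactly what `BlocksWorld` stores via `Block.father`/`Block.children`.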
/difflogic/envs/blocksworld/envs.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The environment class for blocks world tasks."""
17 |
18 | import numpy as np
19 |
20 | import jacinle.random as random
21 | from jaclearn.rl.env import SimpleRLEnvBase
22 |
23 | from .block import randomly_generate_world
24 | from .represent import get_coordinates
25 | from .represent import decorate
26 |
27 | __all__ = ['FinalBlocksWorldEnv']
28 |
29 |
30 | class BlocksWorldEnv(SimpleRLEnvBase):
31 | """The Base BlocksWorld environment.
32 |
33 | Args:
34 | nr_blocks: The number of blocks.
35 | random_order: randomly permute the indexes of the blocks. This
36 | option prevents the models from memorizing the configurations.
37 | decorate: if True, the coordinates in the states will also include the
38 | world index (default: 0) and the block index (starting from 0).
39 | prob_unchange: The probability that an action is not effective.
40 | prob_fall: The probability that an action will make the object currently
41 | moving fall on the ground.
42 | """
43 |
44 | def __init__(self,
45 | nr_blocks,
46 | random_order=False,
47 | decorate=False,
48 | prob_unchange=0.0,
49 | prob_fall=0.0):
50 | super().__init__()
51 | self.nr_blocks = nr_blocks
52 | self.nr_objects = nr_blocks + 1
53 | self.random_order = random_order
54 | self.decorate = decorate
55 | self.prob_unchange = prob_unchange
56 | self.prob_fall = prob_fall
57 |
58 | def _restart(self):
59 | self.world = randomly_generate_world(
60 | self.nr_blocks, random_order=self.random_order)
61 | self._set_current_state(self._get_decorated_states())
62 | self.is_over = False
63 | self.cached_result = self._get_result()
64 |
65 | def _get_decorated_states(self, world_id=0):
66 | state = get_coordinates(self.world)
67 | if self.decorate:
68 | state = decorate(state, self.nr_objects, world_id)
69 | return state
70 |
71 |
72 | class FinalBlocksWorldEnv(BlocksWorldEnv):
73 | """The BlocksWorld environment for the final task."""
74 |
75 | def __init__(self,
76 | nr_blocks,
77 | random_order=False,
78 | shape_only=False,
79 | fix_ground=False,
80 | prob_unchange=0.0,
81 | prob_fall=0.0):
82 | super().__init__(nr_blocks, random_order, True, prob_unchange, prob_fall)
83 | self.shape_only = shape_only
84 | self.fix_ground = fix_ground
85 |
86 | def _restart(self):
87 | self.start_world = randomly_generate_world(
88 | self.nr_blocks, random_order=False)
89 | self.final_world = randomly_generate_world(
90 | self.nr_blocks, random_order=False)
91 | self.world = self.start_world
92 | if self.random_order:
93 | n = self.world.size
94 | # Ground is fixed as index 0 if fix_ground is True
95 | ground_ind = 0 if self.fix_ground else random.randint(n)
96 |
97 | def get_order():
98 | raw_order = random.permutation(n - 1)
99 | order = []
100 | for i in range(n - 1):
101 | if i == ground_ind:
102 | order.append(0)
103 | order.append(raw_order[i] + 1)
104 | if ground_ind == n - 1:
105 | order.append(0)
106 | return order
107 |
108 | self.start_world.blocks.set_random_order(get_order())
109 | self.final_world.blocks.set_random_order(get_order())
110 |
111 | self._prepare_worlds()
112 | self.start_state = decorate(
113 | self._get_coordinates(self.start_world), self.nr_objects, 0)
114 | self.final_state = decorate(
115 | self._get_coordinates(self.final_world), self.nr_objects, 1)
116 |
117 | self.is_over = False
118 | self.cached_result = self._get_result()
119 |
120 | def _prepare_worlds(self):
121 | pass
122 |
123 | def _action(self, action):
124 | assert self.start_world is not None, 'you need to call restart() first'
125 |
126 | if self.is_over:
127 | return 0, True
128 | r, is_over = self.cached_result
129 | if is_over:
130 | self.is_over = True
131 | return r, is_over
132 |
133 | x, y = action
134 | assert 0 <= x <= self.nr_blocks and 0 <= y <= self.nr_blocks
135 |
136 | p = random.rand()
137 | if p >= self.prob_unchange:
138 | if p < self.prob_unchange + self.prob_fall:
139 | y = self.start_world.blocks.inv_index(0) # fall to ground
140 | self.start_world.move(x, y)
141 | self.start_state = decorate(
142 | self._get_coordinates(self.start_world), self.nr_objects, 0)
143 | r, is_over = self._get_result()
144 | if is_over:
145 | self.is_over = True
146 | return r, is_over
147 |
148 | def _get_current_state(self):
149 | assert self.start_world is not None, 'Should call restart() first.'
150 | return np.vstack([self.start_state, self.final_state])
151 |
152 | def _get_result(self):
153 | sorted_start_state = self._get_coordinates(self.start_world, sort=True)
154 | sorted_final_state = self._get_coordinates(self.final_world, sort=True)
155 | if (sorted_start_state == sorted_final_state).all():
156 | return 1, True
157 | else:
158 | return 0, False
159 |
160 | def _get_coordinates(self, world, sort=False):
161 |     # If shape_only=True, only the shapes of the blocks need to match.
162 |     # If shape_only=False, the indexes of the blocks should also match.
163 | coordinates = get_coordinates(world, absolute=not self.shape_only)
164 | if sort:
165 | if not self.shape_only:
166 | coordinates = decorate(coordinates, self.nr_objects, 0)
167 | coordinates = np.array(sorted(list(map(tuple, coordinates))))
168 | return coordinates
169 |
--------------------------------------------------------------------------------
/difflogic/envs/blocksworld/quickaccess.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Quick access for blocksworld environments."""
17 |
18 | from jaclearn.rl.proxy import LimitLengthProxy
19 |
20 | from .envs import FinalBlocksWorldEnv
21 | from ..utils import get_action_mapping_blocksworld
22 | from ..utils import MapActionProxy
23 |
24 | __all__ = ['get_final_env', 'make']
25 |
26 |
27 | def get_final_env(nr_blocks,
28 | random_order=False,
29 | exclude_self=True,
30 | shape_only=False,
31 | fix_ground=False,
32 | limit_length=None):
33 | """Get the blocksworld environment for the final task."""
34 | p = FinalBlocksWorldEnv(
35 | nr_blocks,
36 | random_order=random_order,
37 | shape_only=shape_only,
38 | fix_ground=fix_ground)
39 | p = LimitLengthProxy(p, limit_length or nr_blocks * 4)
40 | mapping = get_action_mapping_blocksworld(nr_blocks, exclude_self=exclude_self)
41 | p = MapActionProxy(p, mapping)
42 | return p
43 |
44 |
45 | def make(task, *args, **kwargs):
46 | if task == 'final':
47 | return get_final_env(*args, **kwargs)
48 | else:
49 | raise ValueError('Unknown task: {}.'.format(task))
50 |
--------------------------------------------------------------------------------
/difflogic/envs/blocksworld/represent.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement queries for different representations of blocksworld."""
17 |
18 | import numpy as np
19 |
20 | __all__ = [
21 | 'get_world_string', 'get_coordinates', 'get_is_ground', 'get_moveable',
22 | 'get_placeable', 'decorate'
23 | ]
24 |
25 |
26 | def get_world_string(world):
27 |   """Format the blocks world instance as a human-readable string."""
28 | index_mapping = {b.index: i for i, b in enumerate(world.blocks)}
29 | raw_blocks = world.blocks.raw
30 |
31 | result = ''
32 |
33 | def dfs(block, indent):
34 | nonlocal result
35 |
36 | result += '{}Block #{}: (IsGround={}, Moveable={}, Placeable={})\n'.format(
37 | ' ' * (indent * 2), index_mapping[block.index], block.is_ground,
38 | block.moveable, block.placeable)
39 | for c in block.children:
40 | dfs(c, indent + 1)
41 |
42 | dfs(raw_blocks[0], 0)
43 | return result
44 |
45 |
46 | def get_coordinates(world, absolute=False):
47 | """Get the coordinates of each block in the blocks world."""
48 | coordinates = [None for _ in range(world.size)]
49 | raw_blocks = world.blocks.raw
50 |
51 | def dfs(block):
52 | """Use depth-first-search to get the coordinate of each block."""
53 | if block.is_ground:
54 | coordinates[block.index] = (0, 0)
55 | for j, c in enumerate(block.children):
56 |         # When using absolute coordinates, the block with index x placed
57 |         # directly on the ground gets coordinate (x, 1).
58 | x = world.blocks.inv_index(c.index) if absolute else j
59 | coordinates[c.index] = (x, 1)
60 | dfs(c)
61 | else:
62 | coor = coordinates[block.index]
63 | assert coor is not None
64 | x, y = coor
65 | for c in block.children:
66 | coordinates[c.index] = (x, y + 1)
67 | dfs(c)
68 |
69 | dfs(raw_blocks[0])
70 | coordinates = world.blocks.permute(coordinates)
71 | return np.array(coordinates)
72 |
73 |
74 | def get_is_ground(world):
75 | return np.array([block.is_ground for block in world.blocks])
76 |
77 |
78 | def get_moveable(world):
79 | return np.array([block.moveable for block in world.blocks])
80 |
81 |
82 | def get_placeable(world):
83 | return np.array([block.placeable for block in world.blocks])
84 |
85 |
86 | def decorate(state, nr_objects, world_id=None):
87 | """Append world index and object index information to state."""
88 | info = []
89 | if world_id is not None:
90 | info.append(np.ones((nr_objects, 1)) * world_id)
91 | info.extend([np.array(range(nr_objects))[:, np.newaxis], state])
92 | return np.hstack(info)
93 |
--------------------------------------------------------------------------------
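The `decorate` helper above simply prepends an optional world-id column and an object-index column to a state matrix. A self-contained restatement with a tiny worked example (the 3-block state below is hypothetical):

```python
import numpy as np

def decorate(state, nr_objects, world_id=None):
    """Prepend a world-id column (optional) and an object-index column
    to an (nr_objects, k) state matrix."""
    info = []
    if world_id is not None:
        info.append(np.ones((nr_objects, 1)) * world_id)
    info.extend([np.arange(nr_objects)[:, np.newaxis], state])
    return np.hstack(info)

# Three objects with 2-d coordinates, tagged as world 0.
state = np.array([[0, 0], [1, 1], [1, 2]])
decorated = decorate(state, 3, world_id=0)
# Columns are now: world_id, object_index, x, y.
```

This is how the environment distinguishes the start world (id 0) from the goal world (id 1) when the two states are stacked into one input.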
/difflogic/envs/graph/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .graph import *
18 | from .graph_env import *
19 | from .quickaccess import *
20 |
--------------------------------------------------------------------------------
/difflogic/envs/graph/graph.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement random graph generators and Graph class."""
17 |
18 | import copy
19 | import numpy as np
20 |
21 | import jacinle.random as random
22 |
23 | __all__ = ['Graph', 'randomly_generate_graph_er', 'randomly_generate_graph_dnc',
24 | 'get_random_graph_generator']
25 |
26 |
27 | class Graph(object):
28 | """Store a graph using adjacency matrix.
29 |
30 | Args:
31 | nr_nodes: The number of nodes in the graph.
32 | edges: The adjacency matrix of the graph.
33 | """
34 |
35 | def __init__(self, nr_nodes, edges, coordinates=None):
36 | edges = edges.astype('int32')
37 | assert edges.min() >= 0 and edges.max() <= 1
38 | self._nr_nodes = nr_nodes
39 | self._edges = edges
40 | self._coordinates = coordinates
41 | self._shortest = None
42 | self.extra_info = {}
43 |
44 | @property
45 | def nr_nodes(self):
46 | return self._nr_nodes
47 |
48 | def get_edges(self):
49 | return copy.copy(self._edges)
50 |
51 | def get_coordinates(self):
52 | return self._coordinates
53 |
54 | def get_relations(self):
55 | """Return edges and identity matrix."""
56 | return np.stack([self.get_edges(), np.eye(self.nr_nodes)], axis=-1)
57 |
58 | def has_edge(self, x, y):
59 | return self._edges[x, y] == 1
60 |
61 | def get_out_degree(self):
62 | """Return the out degree of each node."""
63 | return np.sum(self._edges, axis=1)
64 |
65 | def get_shortest(self):
66 | """Return the length of shortest path between nodes."""
67 | if self._shortest is not None:
68 | return self._shortest
69 |
70 | n = self.nr_nodes
71 | edges = self.get_edges()
72 |
73 | # n + 1 indicates unreachable.
74 | shortest = np.ones((n, n)) * (n + 1)
75 | shortest[np.where(edges == 1)] = 1
76 | # Make sure that shortest[x, x] = 0
77 | shortest -= shortest * np.eye(n)
78 | shortest = shortest.astype('int32')
79 |
80 |     # Floyd-Warshall algorithm
81 | for k in range(n):
82 | for i in range(n):
83 | for j in range(n):
84 | if i != j:
85 | shortest[i, j] = min(shortest[i, j],
86 | shortest[i, k] + shortest[k, j])
87 | self._shortest = shortest
88 | return self._shortest
89 |
90 | def get_connectivity(self, k=None, exclude_self=True):
91 | """Calculate the k-connectivity.
92 |
93 | Args:
94 |       k: The maximum number of steps; unlimited if k=None or k<0.
95 |       exclude_self: Remove the diagonal entries connectivity[x, x] if True.
96 | Returns:
97 | A numpy.ndarray representing the k-connectivity for each pair of nodes.
98 | """
99 | shortest = self.get_shortest()
100 | if k is None or k < 0:
101 | k = self.nr_nodes
102 | k = min(k, self.nr_nodes)
103 | conn = (shortest <= k).astype('int32')
104 | if exclude_self:
105 | n = self.nr_nodes
106 | inds = np.where(~np.eye(n, dtype=bool))
107 | conn = conn[inds]
108 | conn.resize(n, n - 1)
109 | return conn
110 |
111 |
112 | def randomly_generate_graph_er(n, p, directed=False):
113 | """Randomly generate a graph by sampling the existence of each edge.
114 |
115 |   Each edge between two nodes exists with probability $p (directed) or
116 |   1 - (1-$p)^2 (undirected).
117 |
118 | Args:
119 | n: The number of nodes in the graph.
120 |     p: The probability that each directed edge exists.
121 | directed: Directed or Undirected graph. Default: False (undirected)
122 |
123 | Returns:
124 | A Graph class representing randomly generated graph.
125 | """
126 | edges = (random.rand(n, n) < p).astype('float')
127 | edges -= edges * np.eye(n)
128 | if not directed:
129 | edges = np.maximum(edges, edges.T)
130 | return Graph(n, edges)
131 |
132 |
133 | def randomly_generate_graph_dnc(n, p=None, directed=False):
134 | """Random graph generation method as in DNC.
135 |
136 | As described in Differentiable neural computers (DNC),
137 | (https://www.nature.com/articles/nature20101.epdf?author_access_token=ImTXBI8aWbYxYQ51Plys8NRgN0jAjWel9jnR3ZoTv0MggmpDmwljGswxVdeocYSurJ3hxupzWuRNeGvvXnoO8o4jTJcnAyhGuZzXJ1GEaD-Z7E6X_a9R-xqJ9TfJWBqz)
138 | Sample $n nodes in a unit square. Then sample the out-degree (m) of each node
139 | and connect it to its $m nearest neighbors (Euclidean distance) in the unit square.
140 |
141 | Args:
142 | n: The number of nodes in the graph.
143 | p: Control the sampling of the out-degree.
144 | If p=None, the default range is [1, n // 3].
145 | If p is float, the range is [1, int(n * p)].
146 | If p is int, the range is [1, p].
147 |       If p is a tuple, the range is [p[0], p[1]].
148 | directed: Directed or Undirected graph. Default: False (undirected)
149 |
150 | Returns:
151 | A Graph class representing randomly generated graph.
152 | """
153 | edges = np.zeros((n, n), dtype='float')
154 | pos = random.rand(n, 2)
155 |
156 | def dist(x, y):
157 |     return ((x - y)**2).mean()  # monotone in Euclidean distance, so ordering is preserved
158 |
159 | if isinstance(p, tuple):
160 | lower, upper = p
161 | else:
162 | lower = 1
163 | if p is None:
164 | upper = n // 3
165 | elif isinstance(p, int):
166 | upper = p
167 | elif isinstance(p, float):
168 | upper = int(n * p)
169 | else:
170 | assert False, 'Unknown argument type: {}'.format(type(p))
171 | upper = max(upper, 1)
172 | lower = max(lower, 1)
173 | upper = min(upper, n - 1)
174 |
175 | for i in range(n):
176 | d = []
177 | k = random.randint(upper - lower + 1) + lower
178 | for j in range(n):
179 | if i != j:
180 | d.append((dist(pos[i], pos[j]), j))
181 | d.sort()
182 | for j in range(k):
183 | edges[i, d[j][1]] = 1
184 | if not directed:
185 | edges = np.maximum(edges, edges.T)
186 | return Graph(n, edges, pos)
187 |
188 |
189 | def get_random_graph_generator(name):
190 | if name == 'dnc':
191 | return randomly_generate_graph_dnc
192 | return randomly_generate_graph_er
193 |
--------------------------------------------------------------------------------
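`Graph.get_shortest` above is a plain O(n^3) Floyd-Warshall pass over the adjacency matrix, with n + 1 standing in for "unreachable". The same computation as a standalone function, run on a hypothetical 4-node path graph:

```python
import numpy as np

def floyd_shortest(edges):
    """All-pairs shortest path lengths; n + 1 marks unreachable pairs."""
    n = edges.shape[0]
    shortest = np.ones((n, n)) * (n + 1)
    shortest[np.where(edges == 1)] = 1
    shortest -= shortest * np.eye(n)  # distance from a node to itself is 0
    shortest = shortest.astype('int32')
    # Floyd-Warshall: allow intermediate node k on each pass.
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if i != j:
                    shortest[i, j] = min(shortest[i, j],
                                         shortest[i, k] + shortest[k, j])
    return shortest

# Undirected path graph 0 - 1 - 2 - 3.
edges = np.zeros((4, 4))
for a, b in [(0, 1), (1, 2), (2, 3)]:
    edges[a, b] = edges[b, a] = 1
dist = floyd_shortest(edges)
```

The cached result is what both `get_connectivity` and `PathGraphEnv` consume, so the cubic cost is paid once per generated graph.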
/difflogic/envs/graph/graph_env.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The environment class for graph tasks."""
17 |
18 | import numpy as np
19 |
20 | import jacinle.random as random
21 | from jacinle.utils.meta import notnone_property
22 | from jaclearn.rl.env import SimpleRLEnvBase
23 |
24 | from .graph import get_random_graph_generator
25 |
26 | __all__ = ['GraphEnvBase', 'PathGraphEnv']
27 |
28 |
29 | class GraphEnvBase(SimpleRLEnvBase):
30 | """The base class for Graph Environment.
31 |
32 | Args:
33 | nr_nodes: The number of nodes in the graph.
34 | pmin: The lower bound of the parameter controlling the graph generation.
35 | pmax: The upper bound of the parameter controlling the graph generation,
36 |       defaults to $pmin.
37 |     directed: Generate a directed graph if directed=True.
38 |     gen_method: Controls the graph generation method.
39 |       If gen_method='dnc', use a method similar to the DNC paper.
40 |       Otherwise, use the Erdos-Renyi model (each edge exists with some probability).
41 | """
42 |
43 | def __init__(self,
44 | nr_nodes,
45 | pmin,
46 | pmax=None,
47 | directed=False,
48 | gen_method='dnc'):
49 | super().__init__()
50 | self._nr_nodes = nr_nodes
51 | self._pmin = pmin
52 | self._pmax = pmin if pmax is None else pmax
53 | self._directed = directed
54 | self._gen_method = gen_method
55 | self._graph = None
56 |
57 | @notnone_property
58 | def graph(self):
59 | return self._graph
60 |
61 | @property
62 | def nr_nodes(self):
63 | return self._nr_nodes
64 |
65 | def _restart(self):
66 | """Restart the environment."""
67 | self._gen_graph()
68 |
69 | def _gen_graph(self):
70 |     """Generate the graph using the specified method."""
71 | n = self._nr_nodes
72 | p = self._pmin + random.rand() * (self._pmax - self._pmin)
73 | assert self._gen_method in ['edge', 'dnc']
74 | gen = get_random_graph_generator(self._gen_method)
75 | self._graph = gen(n, p, self._directed)
76 |
77 |
78 | class PathGraphEnv(GraphEnvBase):
79 | """Env for Finding a path from starting node to the destination."""
80 |
81 | def __init__(self,
82 | nr_nodes,
83 | dist_range,
84 | pmin,
85 | pmax=None,
86 | directed=False,
87 | gen_method='dnc'):
88 | super().__init__(nr_nodes, pmin, pmax, directed, gen_method)
89 | self._dist_range = dist_range
90 |
91 | @property
92 | def dist(self):
93 | return self._dist
94 |
95 | def _restart(self):
96 | super()._restart()
97 | self._dist = self._sample_dist()
98 | self._task = None
99 | while True:
100 | self._task = self._gen()
101 | if self._task is not None:
102 | break
103 | # Generate another graph if fail to find two nodes with desired distance.
104 | self._gen_graph()
105 | self._current = self._task[0]
106 | self._set_current_state(self._task)
107 | self._steps = 0
108 |
109 | def _sample_dist(self):
110 | """Sample the distance between the starting node and the destination."""
111 | lower, upper = self._dist_range
112 | upper = min(upper, self._nr_nodes - 1)
113 | return random.randint(upper - lower + 1) + lower
114 |
115 | def _gen(self):
116 | """Sample the starting node and the destination according to the distance."""
117 | dist_matrix = self._graph.get_shortest()
118 | st, ed = np.where(dist_matrix == self.dist)
119 | if len(st) == 0:
120 | return None
121 | ind = random.randint(len(st))
122 | return st[ind], ed[ind]
123 |
124 | def _action(self, target):
125 | """Move to the target node from current node if has_edge(current -> target)."""
126 | if self._current == self._task[1]:
127 | return 1, True
128 | if self._graph.has_edge(self._current, target):
129 | self._current = target
130 | self._set_current_state((self._current, self._task[1]))
131 | if self._current == self._task[1]:
132 | return 1, True
133 | self._steps += 1
134 | if self._steps >= self.dist:
135 | return 0, True
136 | return 0, False
137 |
--------------------------------------------------------------------------------
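`PathGraphEnv._gen` above selects a (start, destination) pair by masking the all-pairs distance matrix with `np.where` and sampling one hit; if no pair at the desired distance exists it returns None, and the caller regenerates the graph. A standalone sketch of just that selection step (the distance matrix below is hand-written, not produced by a generator):

```python
import numpy as np

def sample_pair(dist_matrix, dist, rng=np.random):
    """Return a random (start, end) pair at exactly distance `dist`, or None."""
    st, ed = np.where(dist_matrix == dist)
    if len(st) == 0:
        return None  # the caller regenerates the graph in this case
    ind = rng.randint(len(st))
    return st[ind], ed[ind]

# Shortest-path matrix of the path graph 0 - 1 - 2 - 3.
dist_matrix = np.array([
    [0, 1, 2, 3],
    [1, 0, 1, 2],
    [2, 1, 0, 1],
    [3, 2, 1, 0],
])
pair = sample_pair(dist_matrix, 3)     # only (0, 3) and (3, 0) qualify
missing = sample_pair(dist_matrix, 9)  # no pair at distance 9
```

Sampling uniformly over all qualifying pairs keeps the task distribution unbiased for a given target distance.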
/difflogic/envs/graph/quickaccess.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Quick access for graph environments."""
17 |
18 | from .graph_env import PathGraphEnv
19 |
20 | __all__ = ['get_path_env', 'make']
21 |
22 |
23 | def get_path_env(n, dist_range, pmin, pmax, directed=False, gen_method='dnc'):
24 | env_cls = PathGraphEnv
25 | p = env_cls(
26 | n, dist_range, pmin, pmax, directed=directed, gen_method=gen_method)
27 | return p
28 |
29 |
30 | def make(task, *args, **kwargs):
31 | if task == 'path':
32 | return get_path_env(*args, **kwargs)
33 | else:
34 | raise ValueError('Unknown task: {}.'.format(task))
35 |
--------------------------------------------------------------------------------
/difflogic/envs/utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Utility functions for custom environments."""
17 |
18 | from jaclearn.rl.env import ProxyRLEnvBase
19 | from jaclearn.rl.space import DiscreteActionSpace
20 |
21 | __all__ = ['MapActionProxy', 'get_action_mapping', 'get_action_mapping_graph',
22 | 'get_action_mapping_sorting', 'get_action_mapping_blocksworld']
23 |
24 |
25 | class MapActionProxy(ProxyRLEnvBase):
26 | """RL Env proxy to map actions using provided mapping function."""
27 |
28 | def __init__(self, other, mapping):
29 | super().__init__(other)
30 | self._mapping = mapping
31 |
32 | @property
33 | def mapping(self):
34 | return self._mapping
35 |
36 | def map_action(self, action):
37 | assert action < len(self._mapping)
38 | return self._mapping[action]
39 |
40 | def _get_action_space(self):
41 | return DiscreteActionSpace(len(self._mapping))
42 |
43 | def _action(self, action):
44 | return self.proxy.action(self.map_action(action))
45 |
46 |
47 | def get_action_mapping(n, exclude_self=True):
48 |   """In a matrix view, this is a mapping from a 1-d index to a 2-d coordinate."""
49 | mapping = [
50 | (i, j) for i in range(n) for j in range(n) if (i != j or not exclude_self)
51 | ]
52 | return mapping
53 |
54 | get_action_mapping_graph = get_action_mapping
55 | get_action_mapping_sorting = get_action_mapping
56 |
57 |
58 | def get_action_mapping_blocksworld(nr_blocks, exclude_self=True):
59 | return get_action_mapping_graph(nr_blocks + 1, exclude_self)
60 |
--------------------------------------------------------------------------------
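`get_action_mapping` above flattens the n x n pairwise choice into a flat action list, which `MapActionProxy` then exposes as a `DiscreteActionSpace`. A dependency-free check of the mapping's size and ordering:

```python
def get_action_mapping(n, exclude_self=True):
    # Enumerate ordered pairs (i, j); drop i == j unless exclude_self=False.
    return [(i, j) for i in range(n) for j in range(n)
            if (i != j or not exclude_self)]

mapping = get_action_mapping(3)                     # n * (n - 1) actions
full = get_action_mapping(3, exclude_self=False)    # n * n actions
```

For blocksworld the call uses `nr_blocks + 1` because the ground counts as an extra "block" that objects can be moved onto.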
/difflogic/nn/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 |
--------------------------------------------------------------------------------
/difflogic/nn/baselines/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .lstm import *
18 | from .memory_net import *
19 |
--------------------------------------------------------------------------------
/difflogic/nn/baselines/lstm.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The LSTM baseline."""
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 |
22 | import jacinle.random as random
23 | from difflogic.nn.neural_logic.modules._utils import meshgrid
24 | from jactorch.functional.shape import broadcast
25 |
26 | __all__ = ['LSTMBaseline']
27 |
28 |
29 | class LSTMBaseline(nn.Module):
30 | """LSTM baseline model."""
31 | def __init__(self,
32 | input_dim,
33 | feature_dim,
34 | num_layers=2,
35 | hidden_size=512,
36 | code_length=8):
37 | super().__init__()
38 | current_dim = input_dim + code_length * 2
39 | self.feature_dim = feature_dim
40 | assert feature_dim == 1 or feature_dim == 2, ('only support attributes or '
41 | 'relations')
42 | self.num_layers = num_layers
43 | self.hidden_size = hidden_size
44 | self.code_length = code_length
45 | self.lstm = nn.LSTM(
46 | current_dim,
47 | hidden_size,
48 | num_layers,
49 | batch_first=True,
50 | bidirectional=True)
51 |
52 | def forward(self, relations, attributes=None):
53 | batch_size, nr = relations.size()[:2]
54 | assert nr == relations.size(2)
55 |
56 | id_shape = list(relations.size()[:-1])
57 | ids = [
58 | random.permutation(2**self.code_length - 1)[:nr] + 1
59 | for i in range(batch_size)
60 | ]
61 | ids = np.vstack(ids)
62 | binary_ids = self.binarize_code(ids)
63 | zeros = torch.tensor(
64 | np.zeros(binary_ids.shape),
65 | dtype=relations.dtype,
66 | device=relations.device)
67 | binary_ids = torch.tensor(
68 | binary_ids, dtype=relations.dtype, device=relations.device)
69 | binary_ids2 = torch.cat(meshgrid(binary_ids, dim=1), dim=-1)
70 |
71 | if attributes is None:
72 | rels = [binary_ids2, relations]
73 | else:
74 | padding = torch.zeros(
75 | *binary_ids2.size()[:-1],
76 | attributes.size(-1),
77 | dtype=relations.dtype,
78 | device=relations.device)
79 | rels = [binary_ids2, padding, relations]
80 | rels = torch.cat(rels, dim=-1)
81 | input_seq = rels.view(batch_size, -1, rels.size(-1))
82 | if attributes is not None:
83 | assert nr == attributes.size(1)
84 | padding = torch.zeros(
85 | *binary_ids.size()[:-1],
86 | relations.size(-1),
87 | dtype=relations.dtype,
88 | device=relations.device)
89 | attributes = torch.cat([binary_ids, zeros, attributes, padding], dim=-1)
90 | input_seq = torch.cat([input_seq, attributes], dim=1)
91 |
92 | h0 = torch.zeros(
93 | self.num_layers * 2,
94 | batch_size,
95 | self.hidden_size,
96 | dtype=relations.dtype,
97 | device=relations.device)
98 | c0 = torch.zeros(
99 | self.num_layers * 2,
100 | batch_size,
101 | self.hidden_size,
102 | dtype=relations.dtype,
103 | device=relations.device)
104 | out, _ = self.lstm(input_seq, (h0, c0))
105 | out = out[:, -1]
106 |
107 | if self.feature_dim == 1:
108 | expanded_feature = broadcast(out.unsqueeze(dim=1), 1, nr)
109 | return torch.cat([binary_ids, expanded_feature], dim=-1)
110 | else:
111 | expanded_feature = broadcast(out.unsqueeze(dim=1), 1, nr)
112 | expanded_feature = broadcast(expanded_feature.unsqueeze(dim=1), 1, nr)
113 | return torch.cat([binary_ids2, expanded_feature], dim=-1)
114 |
115 | def binarize_code(self, x):
116 | m = self.code_length
117 | code = np.zeros((x.shape + (m,)))
118 | for i in range(m)[::-1]:
119 | code[:, :, i] = (x >= 2**i).astype('float')
120 | x = x - code[:, :, i] * 2**i
121 | return code
122 |
123 | def get_output_dim(self):
124 | return self.hidden_size * 2 + self.code_length * self.feature_dim
125 |
--------------------------------------------------------------------------------
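Both baselines tag items with random integer IDs and expand them into fixed-width binary codes via `binarize_code`, peeling bits off from the most significant; bit i lands at index i, so the code reads least-significant-bit first. A standalone NumPy restatement for 2-d integer inputs:

```python
import numpy as np

def binarize_code(x, m):
    """Expand a (batch, nr) integer array into m bits; code[..., i] is bit i."""
    code = np.zeros(x.shape + (m,))
    for i in range(m)[::-1]:  # peel off bits from the most significant down
        code[:, :, i] = (x >= 2**i).astype('float')
        x = x - code[:, :, i] * 2**i
    return code

ids = np.array([[5, 2]])        # shape (batch=1, nr=2)
code = binarize_code(ids, 4)
# 5 -> bits [1, 0, 1, 0] (LSB first), 2 -> bits [0, 1, 0, 0]
```

The models draw the IDs from `random.permutation(2**code_length - 1) + 1`, so every item in a batch gets a distinct nonzero code.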
/difflogic/nn/baselines/memory_net.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import functools
18 | import numpy as np
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 |
23 | import jacinle.random as random
24 | from difflogic.nn.neural_logic.modules._utils import meshgrid
25 | from jactorch.quickstart.models import MLPModel
26 |
27 | __all__ = ['MemoryNet']
28 |
29 |
30 | class MemoryNet(nn.Module):
31 |
32 | def __init__(self, input_dim, feature_dim, queries, hidden_dim, key_dim,
33 | value_dim, id_dim):
34 | super().__init__()
35 | self.feature_dim = feature_dim
36 | assert feature_dim == 1 or feature_dim == 2, \
37 | 'only support attributes or relations'
38 | self.queries = queries
39 | self.hidden_dim = hidden_dim
40 | self.key_dim = key_dim
41 | self.value_dim = value_dim
42 | self.id_dim = id_dim
43 |
44 | current_dim = id_dim * 2 + input_dim
45 | self.key_embed = MLPModel(current_dim, key_dim, [])
46 | self.value_embed = MLPModel(current_dim, value_dim, [])
47 | self.query_embed = MLPModel(id_dim * 2, key_dim, [])
48 | self.to_query = MLPModel(hidden_dim, key_dim, [])
49 | self.lstm_cell = nn.LSTMCell(value_dim, hidden_dim)
50 |
51 | def forward(self, relations, attributes=None):
52 | batch_size, nr = relations.size()[:2]
53 | assert nr == relations.size(2)
54 | create_zeros = functools.partial(
55 | torch.zeros, dtype=relations.dtype, device=relations.device)
56 |
57 | id_shape = list(relations.size()[:-1])
58 | ids = [random.permutation(2 ** self.id_dim - 1)[:nr] + 1 \
59 | for i in range(batch_size)]
60 | ids = np.vstack(ids)
61 | binary_ids = self.binarize_code(ids)
62 | zeros = create_zeros(binary_ids.shape)
63 | binary_ids = torch.tensor(
64 | binary_ids, dtype=relations.dtype, device=relations.device)
65 | binary_ids2 = torch.cat(meshgrid(binary_ids, dim=1), dim=-1)
66 | padded_binary_ids = torch.cat([binary_ids, zeros], dim=-1)
67 |
68 | def embed(embed, x):
69 | input_size = x.size()[:-1]
70 | input_channel = x.size(-1)
71 | f = x.view(-1, input_channel)
72 | f = embed(f)
73 | return f.view(*input_size, -1)
74 |
75 | if attributes is None:
76 | rels = [binary_ids2, relations]
77 | else:
78 | padding = create_zeros(*binary_ids2.size()[:-1], attributes.size(-1))
79 | rels = [binary_ids2, padding, relations]
80 | rels = torch.cat(rels, dim=-1)
81 | memory = rels.view(batch_size, -1, rels.size(-1))
82 | if attributes is not None:
83 | assert nr == attributes.size(1)
84 | padding = create_zeros(*padded_binary_ids.size()[:-1], relations.size(-1))
85 | attributes = torch.cat([padded_binary_ids, attributes, padding], dim=-1)
86 | memory = torch.cat([memory, attributes], dim=1)
87 | keys = embed(self.key_embed, memory).transpose(1, 2)
88 | values = embed(self.value_embed, memory)
89 |
90 | query = padded_binary_ids if self.feature_dim == 1 else binary_ids2
91 | nr_items = nr**self.feature_dim
92 | query = embed(self.query_embed, query).view(batch_size, nr_items, -1)
93 |
94 | h0 = create_zeros(batch_size * nr_items, self.hidden_dim)
95 | c0 = create_zeros(batch_size * nr_items, self.hidden_dim)
96 | for i in range(self.queries):
97 | attention = F.softmax(torch.bmm(query, keys), dim=-1)
98 | value = torch.bmm(attention, values)
99 | value = value.view(-1, value.size(-1))
100 |
101 | h0, c0 = self.lstm_cell(value, (h0, c0))
102 | query = self.to_query(h0).view(batch_size, nr_items, self.key_dim)
103 |
104 | if self.feature_dim == 1:
105 | out = h0.view(batch_size, nr, self.hidden_dim)
106 | else:
107 | out = h0.view(batch_size, nr, nr, self.hidden_dim)
108 | return out
109 |
110 | def binarize_code(self, x):
111 | m = self.id_dim
112 | code = np.zeros((x.shape + (m,)))
113 | for i in range(m)[::-1]:
114 | code[:, :, i] = (x >= 2**i).astype('float')
115 | x = x - code[:, :, i] * 2**i
116 | return code
117 |
118 | def get_output_dim(self):
119 | return self.hidden_dim
120 |
121 | __hyperparams__ = ('queries', 'hidden_dim', 'key_dim', 'value_dim', 'id_dim')
122 |
123 | __hyperparam_defaults__ = {
124 | 'queries': 4,
125 | 'hidden_dim': 64,
126 | 'key_dim': 16,
127 | 'value_dim': 32,
128 | 'id_dim': 8,
129 | }
130 |
131 | @classmethod
132 | def make_memnet_parser(cls, parser, defaults, prefix=None):
133 | for k, v in cls.__hyperparam_defaults__.items():
134 | defaults.setdefault(k, v)
135 |
136 | if prefix is None:
137 | prefix = '--'
138 | else:
139 | prefix = '--' + str(prefix) + '-'
140 |
141 | parser.add_argument(
142 | prefix + 'queries',
143 | type=int,
144 | default=defaults['queries'],
145 | metavar='N',
146 | help='number of queries')
147 | parser.add_argument(
148 | prefix + 'hidden-dim',
149 | type=int,
150 | default=defaults['hidden_dim'],
151 | metavar='N',
152 | help='hidden dimension of LSTM cell')
153 | parser.add_argument(
154 | prefix + 'key-dim',
155 | type=int,
156 | default=defaults['key_dim'],
157 | metavar='N',
158 | help='dimension of key vector')
159 | parser.add_argument(
160 | prefix + 'value-dim',
161 | type=int,
162 | default=defaults['value_dim'],
163 | metavar='N',
164 | help='dimension of value vector')
165 | parser.add_argument(
166 | prefix + 'id-dim',
167 | type=int,
168 | default=defaults['id_dim'],
169 | metavar='N',
170 | help='dimension of id vector')
171 |
172 | @classmethod
173 | def from_args(cls, input_dim, feature_dim, args, prefix=None, **kwargs):
174 | if prefix is None:
175 | prefix = ''
176 | else:
177 | prefix = str(prefix) + '_'
178 |
179 | init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
180 | init_params.update(kwargs)
181 |
182 | return cls(input_dim, feature_dim, **init_params)
183 |
--------------------------------------------------------------------------------
/difflogic/nn/neural_logic/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .modules.input_transform import *
18 | from .modules.neural_logic import *
19 |
20 | from .layer import *
21 |
--------------------------------------------------------------------------------
/difflogic/nn/neural_logic/layer.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement Neural Logic Layers and Machines."""
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from jacinle.logging import get_logger
22 |
23 | from .modules.dimension import Expander, Reducer, Permutation
24 | from .modules.neural_logic import LogicInference
25 |
26 | __all__ = ['LogicLayer', 'LogicMachine']
27 |
28 | logger = get_logger(__file__)
29 |
30 |
31 | def _get_tuple_n(x, n, tp):
32 | """Get a length-n list of type tp."""
33 | assert tp is not list
34 | if isinstance(x, tp):
35 | x = [x,] * n
36 | assert len(x) == n, 'Parameters should be {} or list of N elements.'.format(
37 | tp)
38 | for i in x:
39 | assert isinstance(i, tp), 'Elements of list should be {}.'.format(tp)
40 | return x
41 |
42 |
43 | class LogicLayer(nn.Module):
44 | """Logic Layers do one-step differentiable logic deduction.
45 |
46 |   The predicates are grouped by their number of variables. The inter-group
47 |   deduction is done by expansion/reduction; the intra-group deduction is done
48 |   by the logic model.
49 |
50 | Args:
51 |     breadth: The breadth (maximum arity) of the logic layer.
52 |     input_dims: The number of input channels of each input group; should be
53 |         consistent with the inputs. Use dims=0 and input=None to indicate
54 |         that a group has no input.
55 |     output_dims: The number of output channels of each group; a single value
56 |         is broadcast to all groups.
57 |     logic_hidden_dim: The hidden dimensions of the logic model.
58 |     exclude_self: If True, disallow multiple occurrences of the same
59 |         variable.
60 |     residual: If True, use residual connections.
61 | """
62 |
63 | def __init__(
64 | self,
65 | breadth,
66 | input_dims,
67 | output_dims,
68 | logic_hidden_dim,
69 | exclude_self=True,
70 | residual=False,
71 | ):
72 | super().__init__()
73 | assert breadth > 0, 'Does not support breadth <= 0.'
74 | if breadth > 3:
75 |       logger.warning(
76 |           'Using LogicLayer with breadth > 3 may cause speed and memory issues.')
77 |
78 | self.max_order = breadth
79 | self.residual = residual
80 |
81 | input_dims = _get_tuple_n(input_dims, self.max_order + 1, int)
82 | output_dims = _get_tuple_n(output_dims, self.max_order + 1, int)
83 |
84 | self.logic, self.dim_perms, self.dim_expanders, self.dim_reducers = [
85 | nn.ModuleList() for _ in range(4)
86 | ]
87 | for i in range(self.max_order + 1):
88 | # collect current_dim from group i-1, i and i+1.
89 | current_dim = input_dims[i]
90 | if i > 0:
91 | expander = Expander(i - 1)
92 | self.dim_expanders.append(expander)
93 | current_dim += expander.get_output_dim(input_dims[i - 1])
94 | else:
95 | self.dim_expanders.append(None)
96 |
97 | if i + 1 < self.max_order + 1:
98 | reducer = Reducer(i + 1, exclude_self)
99 | self.dim_reducers.append(reducer)
100 | current_dim += reducer.get_output_dim(input_dims[i + 1])
101 | else:
102 | self.dim_reducers.append(None)
103 |
104 | if current_dim == 0:
105 | self.dim_perms.append(None)
106 | self.logic.append(None)
107 | output_dims[i] = 0
108 | else:
109 | perm = Permutation(i)
110 | self.dim_perms.append(perm)
111 | current_dim = perm.get_output_dim(current_dim)
112 | self.logic.append(
113 | LogicInference(current_dim, output_dims[i], logic_hidden_dim))
114 |
115 | self.input_dims = input_dims
116 | self.output_dims = output_dims
117 |
118 | if self.residual:
119 | for i in range(len(input_dims)):
120 | self.output_dims[i] += input_dims[i]
121 |
122 | def forward(self, inputs):
123 | assert len(inputs) == self.max_order + 1
124 | outputs = []
125 | for i in range(self.max_order + 1):
126 | # collect input f from group i-1, i and i+1.
127 | f = []
128 | if i > 0 and self.input_dims[i - 1] > 0:
129 | n = inputs[i].size(1) if i == 1 else None
130 | f.append(self.dim_expanders[i](inputs[i - 1], n))
131 | if i < len(inputs) and self.input_dims[i] > 0:
132 | f.append(inputs[i])
133 | if i + 1 < len(inputs) and self.input_dims[i + 1] > 0:
134 | f.append(self.dim_reducers[i](inputs[i + 1]))
135 | if len(f) == 0:
136 | output = None
137 | else:
138 | f = torch.cat(f, dim=-1)
139 | f = self.dim_perms[i](f)
140 | output = self.logic[i](f)
141 | if self.residual and self.input_dims[i] > 0:
142 | output = torch.cat([inputs[i], output], dim=-1)
143 | outputs.append(output)
144 | return outputs
145 |
146 | __hyperparams__ = (
147 | 'breadth',
148 | 'input_dims',
149 | 'output_dims',
150 | 'logic_hidden_dim',
151 | 'exclude_self',
152 | 'residual',
153 | )
154 |
155 | __hyperparam_defaults__ = {
156 | 'exclude_self': True,
157 | 'residual': False,
158 | }
159 |
160 | @classmethod
161 | def make_nlm_parser(cls, parser, defaults, prefix=None):
162 | for k, v in cls.__hyperparam_defaults__.items():
163 | defaults.setdefault(k, v)
164 |
165 | if prefix is None:
166 | prefix = '--'
167 | else:
168 | prefix = '--' + str(prefix) + '-'
169 |
170 | parser.add_argument(
171 | prefix + 'breadth',
172 |         type=int,
173 | default=defaults['breadth'],
174 | metavar='N',
175 | help='breadth of the logic layer')
176 | parser.add_argument(
177 | prefix + 'logic-hidden-dim',
178 | type=int,
179 | nargs='+',
180 | default=defaults['logic_hidden_dim'],
181 | metavar='N',
182 | help='hidden dim of the logic model')
183 | parser.add_argument(
184 | prefix + 'exclude-self',
185 | type='bool',
186 | default=defaults['exclude_self'],
187 | metavar='B',
188 |         help='disallow multiple occurrences of the same variable')
189 | parser.add_argument(
190 | prefix + 'residual',
191 | type='bool',
192 | default=defaults['residual'],
193 | metavar='B',
194 | help='use residual connections')
195 |
196 | @classmethod
197 | def from_args(cls, input_dims, output_dims, args, prefix=None, **kwargs):
198 | if prefix is None:
199 | prefix = ''
200 | else:
201 | prefix = str(prefix) + '_'
202 |
203 | setattr(args, prefix + 'input_dims', input_dims)
204 | setattr(args, prefix + 'output_dims', output_dims)
205 | init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
206 | init_params.update(kwargs)
207 |
208 | return cls(**init_params)
209 |
210 |
211 | class LogicMachine(nn.Module):
212 | """Neural Logic Machine consists of multiple logic layers."""
213 |
214 | def __init__(
215 | self,
216 | depth,
217 | breadth,
218 | input_dims,
219 | output_dims,
220 | logic_hidden_dim,
221 | exclude_self=True,
222 | residual=False,
223 | io_residual=False,
224 | recursion=False,
225 | connections=None,
226 | ):
227 | super().__init__()
228 | self.depth = depth
229 | self.breadth = breadth
230 | self.residual = residual
231 | self.io_residual = io_residual
232 | self.recursion = recursion
233 | self.connections = connections
234 |
235 | assert not (self.residual and self.io_residual), \
236 | 'Only one type of residual connection is allowed at the same time.'
237 |
238 | # element-wise addition for vector
239 | def add_(x, y):
240 | for i in range(len(y)):
241 | x[i] += y[i]
242 | return x
243 |
244 | self.layers = nn.ModuleList()
245 | current_dims = input_dims
246 | total_output_dims = [0 for _ in range(self.breadth + 1)
247 | ] # for IO residual only
248 | for i in range(depth):
249 |       # With IO residual, the inputs are concatenated to every layer's input.
250 | if i > 0 and io_residual:
251 | add_(current_dims, input_dims)
252 |       # output_dims as a list (or list of lists) is not supported yet.
253 | layer = LogicLayer(breadth, current_dims, output_dims, logic_hidden_dim,
254 | exclude_self, residual)
255 | current_dims = layer.output_dims
256 | current_dims = self._mask(current_dims, i, 0)
257 | if io_residual:
258 | add_(total_output_dims, current_dims)
259 | self.layers.append(layer)
260 |
261 | if io_residual:
262 | self.output_dims = total_output_dims
263 | else:
264 | self.output_dims = current_dims
265 |
266 |   # Mask out specific group entries in layer i, as specified by self.connections.
267 |   # For debugging only.
268 | def _mask(self, a, i, masked_value):
269 | if self.connections is not None:
270 | assert i < len(self.connections)
271 | mask = self.connections[i]
272 | if mask is not None:
273 | assert len(mask) == len(a)
274 | a = [x if y else masked_value for x, y in zip(a, mask)]
275 | return a
276 |
277 | def forward(self, inputs, depth=None):
278 | outputs = [None for _ in range(self.breadth + 1)]
279 | f = inputs
280 |
281 | # depth: the actual depth used for inference
282 | if depth is None:
283 | depth = self.depth
284 | if not self.recursion:
285 | depth = min(depth, self.depth)
286 |
287 | def merge(x, y):
288 | if x is None:
289 | return y
290 | if y is None:
291 | return x
292 | return torch.cat([x, y], dim=-1)
293 |
294 | layer = None
295 | last_layer = None
296 | for i in range(depth):
297 | if i > 0 and self.io_residual:
298 | for j, inp in enumerate(inputs):
299 | f[j] = merge(f[j], inp)
300 |       # To enable recursion, use rolling variables layer/last_layer
301 |       # for weight sharing with period 2, i.e. layers 0,1,2,1,2,1,2,...
302 | if self.recursion and i >= 3:
303 | assert not self.residual
304 | layer, last_layer = last_layer, layer
305 | else:
306 | last_layer = layer
307 | layer = self.layers[i]
308 |
309 | f = layer(f)
310 | f = self._mask(f, i, None)
311 | if self.io_residual:
312 | for j, out in enumerate(f):
313 | outputs[j] = merge(outputs[j], out)
314 | if not self.io_residual:
315 | outputs = f
316 | return outputs
317 |
318 | __hyperparams__ = (
319 | 'depth',
320 | 'breadth',
321 | 'input_dims',
322 | 'output_dims',
323 | 'logic_hidden_dim',
324 | 'exclude_self',
325 | 'io_residual',
326 | 'residual',
327 | 'recursion',
328 | )
329 |
330 | __hyperparam_defaults__ = {
331 | 'exclude_self': True,
332 | 'io_residual': False,
333 | 'residual': False,
334 | 'recursion': False,
335 | }
336 |
337 | @classmethod
338 | def make_nlm_parser(cls, parser, defaults, prefix=None):
339 | for k, v in cls.__hyperparam_defaults__.items():
340 | defaults.setdefault(k, v)
341 |
342 | if prefix is None:
343 | prefix = '--'
344 | else:
345 | prefix = '--' + str(prefix) + '-'
346 |
347 | parser.add_argument(
348 | prefix + 'depth',
349 | type=int,
350 | default=defaults['depth'],
351 | metavar='N',
352 | help='depth of the logic machine')
353 | parser.add_argument(
354 | prefix + 'breadth',
355 | type=int,
356 | default=defaults['breadth'],
357 | metavar='N',
358 | help='breadth of the logic machine')
359 | parser.add_argument(
360 | prefix + 'logic-hidden-dim',
361 | type=int,
362 | nargs='+',
363 | default=defaults['logic_hidden_dim'],
364 | metavar='N',
365 | help='hidden dim of the logic model')
366 | parser.add_argument(
367 | prefix + 'exclude-self',
368 | type='bool',
369 | default=defaults['exclude_self'],
370 | metavar='B',
371 |         help='disallow multiple occurrences of the same variable')
372 | parser.add_argument(
373 | prefix + 'io-residual',
374 | type='bool',
375 | default=defaults['io_residual'],
376 | metavar='B',
377 | help='use input/output-only residual connections')
378 | parser.add_argument(
379 | prefix + 'residual',
380 | type='bool',
381 | default=defaults['residual'],
382 | metavar='B',
383 | help='use residual connections')
384 | parser.add_argument(
385 | prefix + 'recursion',
386 | type='bool',
387 | default=defaults['recursion'],
388 | metavar='B',
389 | help='use recursion weight sharing')
390 |
391 | @classmethod
392 | def from_args(cls, input_dims, output_dims, args, prefix=None, **kwargs):
393 | if prefix is None:
394 | prefix = ''
395 | else:
396 | prefix = str(prefix) + '_'
397 |
398 | setattr(args, prefix + 'input_dims', input_dims)
399 | setattr(args, prefix + 'output_dims', output_dims)
400 | init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
401 | init_params.update(kwargs)
402 |
403 | return cls(**init_params)
404 |
--------------------------------------------------------------------------------
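The channel bookkeeping in `LogicLayer.__init__` can be sketched as a standalone calculation: for arity group i, the MLP input concatenates the expansion of group i-1 (dimension unchanged), group i itself, and the reduction of group i+1 (dimension doubled since both exists and forall quantifiers are taken), and `Permutation` then multiplies the total by i!. A minimal sketch, ignoring the residual option (`logic_layer_input_dim` is a hypothetical helper, not part of the library):

```python
from math import factorial

def logic_layer_input_dim(input_dims, i):
    """Channels fed to the per-group MLP of a LogicLayer for arity group i.

    Mirrors the bookkeeping in LogicLayer.__init__: Expander keeps the
    dimension, Reducer (with exists and forall) doubles it, and
    Permutation multiplies the concatenated total by i!.
    """
    d = input_dims[i]
    if i > 0:
        d += input_dims[i - 1]      # expanded from group i-1
    if i + 1 < len(input_dims):
        d += 2 * input_dims[i + 1]  # reduced from group i+1 (exists + forall)
    return d * factorial(i)         # i! permutations of the variables
```

For example, with `input_dims = [4, 8, 8]` (breadth 2), group 1 receives 8 + 4 + 16 = 28 channels.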
/difflogic/nn/neural_logic/modules/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | # This file is part of DifferentiableLogic-PyTorch.
17 |
--------------------------------------------------------------------------------
/difflogic/nn/neural_logic/modules/_utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Utility functions for tensor masking."""
17 |
18 | import torch
19 |
20 | import torch.autograd as ag
21 | from jactorch.functional import meshgrid, meshgrid_exclude_self
22 |
23 | __all__ = ['meshgrid', 'meshgrid_exclude_self', 'exclude_mask', 'mask_value']
24 |
25 |
26 | def exclude_mask(inputs, cnt=2, dim=1):
27 | """Produce an exclusive mask.
28 |
29 |   Specifically, for cnt=2, given an n * n array a[i, j], it produces
30 |   an n * n mask where mask[i, j] = 1 if and only if i != j.
31 |
32 | Args:
33 | inputs: The tensor to be masked.
34 | cnt: The operation is performed over [dim, dim + cnt) axes.
35 | dim: The starting dimension for the exclusive mask.
36 |
37 | Returns:
38 |     A mask that makes sure the coordinates are mutually exclusive.
39 | """
40 | assert cnt > 0
41 | if dim < 0:
42 | dim += inputs.dim()
43 | n = inputs.size(dim)
44 | for i in range(1, cnt):
45 | assert n == inputs.size(dim + i)
46 |
47 | rng = torch.arange(0, n, dtype=torch.long, device=inputs.device)
48 | q = []
49 | for i in range(cnt):
50 | p = rng
51 | for j in range(cnt):
52 | if i != j:
53 | p = p.unsqueeze(j)
54 | p = p.expand((n,) * cnt)
55 | q.append(p)
56 | mask = q[0] == q[0]
57 | # Mutually Exclusive
58 | for i in range(cnt):
59 | for j in range(cnt):
60 | if i != j:
61 | mask *= q[i] != q[j]
62 | for i in range(dim):
63 | mask.unsqueeze_(0)
64 | for j in range(inputs.dim() - dim - cnt):
65 | mask.unsqueeze_(-1)
66 |
67 | return mask.expand(inputs.size()).float()
68 |
69 |
70 | def mask_value(inputs, mask, value):
71 | assert inputs.size() == mask.size()
72 | return inputs * mask + value * (1 - mask)
73 |
--------------------------------------------------------------------------------
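For cnt=2, the mask built by `exclude_mask` reduces to mask[i, j] = 1 iff i != j. A NumPy sketch of that special case (the function names here are illustrative, not part of the library):

```python
import numpy as np

def exclude_mask_2d(n):
    # mask[i, j] = 1.0 iff i != j: the cnt=2 case of exclude_mask
    i = np.arange(n).reshape(n, 1)
    j = np.arange(n).reshape(1, n)
    return (i != j).astype(float)

def mask_value_np(inputs, mask, value):
    # keep inputs where mask == 1, substitute `value` elsewhere
    return inputs * mask + value * (1 - mask)
```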
/difflogic/nn/neural_logic/modules/dimension.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The dimension related operations."""
17 |
18 | import itertools
19 |
20 | import torch
21 | import torch.nn as nn
22 |
23 | from jactorch.functional import broadcast
24 |
25 | from ._utils import exclude_mask, mask_value
26 |
27 | __all__ = ['Expander', 'Reducer', 'Permutation']
28 |
29 |
30 | class Expander(nn.Module):
31 | """Capture a free variable into predicates, implemented by broadcast."""
32 |
33 | def __init__(self, dim):
34 | super().__init__()
35 | self.dim = dim
36 |
37 | def forward(self, inputs, n=None):
38 | if self.dim == 0:
39 | assert n is not None
40 | elif n is None:
41 | n = inputs.size(self.dim)
42 | dim = self.dim + 1
43 | return broadcast(inputs.unsqueeze(dim), dim, n)
44 |
45 | def get_output_dim(self, input_dim):
46 | return input_dim
47 |
48 |
49 | class Reducer(nn.Module):
50 | """Reduce out a variable via quantifiers (exists/forall), implemented by max/min-pooling."""
51 |
52 | def __init__(self, dim, exclude_self=True, exists=True):
53 | super().__init__()
54 | self.dim = dim
55 | self.exclude_self = exclude_self
56 | self.exists = exists
57 |
58 | def forward(self, inputs):
59 | shape = inputs.size()
60 | inp0, inp1 = inputs, inputs
61 | if self.exclude_self:
62 | mask = exclude_mask(inputs, cnt=self.dim, dim=-1 - self.dim)
63 | inp0 = mask_value(inputs, mask, 0.0)
64 | inp1 = mask_value(inputs, mask, 1.0)
65 |
66 | if self.exists:
67 | shape = shape[:-2] + (shape[-1] * 2,)
68 | exists = torch.max(inp0, dim=-2)[0]
69 | forall = torch.min(inp1, dim=-2)[0]
70 | return torch.stack((exists, forall), dim=-1).view(shape)
71 |
72 | shape = shape[:-2] + (shape[-1],)
73 | return torch.max(inp0, dim=-2)[0].view(shape)
74 |
75 | def get_output_dim(self, input_dim):
76 | if self.exists:
77 | return input_dim * 2
78 | return input_dim
79 |
80 |
81 | class Permutation(nn.Module):
82 |   """Create r! new predicates by permuting the axes of r-arity predicates."""
83 |
84 | def __init__(self, dim):
85 | super().__init__()
86 | self.dim = dim
87 |
88 | def forward(self, inputs):
89 | if self.dim <= 1:
90 | return inputs
91 | nr_dims = len(inputs.size())
92 | # Assume the last dim is channel.
93 | index = tuple(range(nr_dims - 1))
94 | start_dim = nr_dims - 1 - self.dim
95 | assert start_dim > 0
96 | res = []
97 | for i in itertools.permutations(index[start_dim:]):
98 | p = index[:start_dim] + i + (nr_dims - 1,)
99 | res.append(inputs.permute(p))
100 | return torch.cat(res, dim=-1)
101 |
102 | def get_output_dim(self, input_dim):
103 | mul = 1
104 | for i in range(self.dim):
105 | mul *= i + 1
106 | return input_dim * mul
107 |
--------------------------------------------------------------------------------
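The quantifier semantics of `Reducer` can be seen on a toy example: with exists=True, reducing a unary relation max-pools for the existential reading and min-pools for the universal one, then stacks both (hence the doubled output dimension). A NumPy sketch under these assumptions:

```python
import numpy as np

# A soft unary relation r over 4 objects, one channel: r[x] ~ P(r holds for x).
r = np.array([[0.9], [0.1], [0.0], [0.3]])  # shape (4, 1)

exists = r.max(axis=0)  # "exists x. r(x)" via max-pooling
forall = r.min(axis=0)  # "forall x. r(x)" via min-pooling
reduced = np.stack([exists, forall], axis=-1).reshape(-1)  # channel dim doubles
```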
/difflogic/nn/neural_logic/modules/input_transform.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement transformation for input tensors."""
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from jacinle.utils.enum import JacEnum
22 |
23 | from ._utils import meshgrid, meshgrid_exclude_self
24 |
25 | __all__ = ['InputTransformMethod', 'InputTransform']
26 |
27 |
28 | class InputTransformMethod(JacEnum):
29 | CONCAT = 'concat'
30 | DIFF = 'diff'
31 | CMP = 'cmp'
32 |
33 |
34 | class InputTransform(nn.Module):
35 |   """Transform unary predicates into binary predicates via pairwise operations."""
36 |
37 | def __init__(self, method, exclude_self=True):
38 | super().__init__()
39 | self.method = InputTransformMethod.from_string(method)
40 | self.exclude_self = exclude_self
41 |
42 | def forward(self, inputs):
43 | assert inputs.dim() == 3
44 |
45 | x, y = meshgrid(inputs, dim=1)
46 |
47 | if self.method is InputTransformMethod.CONCAT:
48 | combined = torch.cat((x, y), dim=3)
49 | elif self.method is InputTransformMethod.DIFF:
50 | combined = x - y
51 | elif self.method is InputTransformMethod.CMP:
52 | combined = torch.cat([x < y, x == y, x > y], dim=3)
53 | else:
54 | raise ValueError('Unknown input transform method: {}.'.format(
55 | self.method))
56 |
57 | if self.exclude_self:
58 | combined = meshgrid_exclude_self(combined, dim=1)
59 | return combined.float()
60 |
61 | def get_output_dim(self, input_dim):
62 | if self.method is InputTransformMethod.CONCAT:
63 | return input_dim * 2
64 | elif self.method is InputTransformMethod.DIFF:
65 | return input_dim
66 | elif self.method is InputTransformMethod.CMP:
67 | return input_dim * 3
68 | else:
69 | raise ValueError('Unknown input transform method: {}.'.format(
70 | self.method))
71 |
72 | def __repr__(self):
73 | return '{name}({method}, exclude_self={exclude_self})'.format(
74 | name=self.__class__.__name__, **self.__dict__)
75 |
--------------------------------------------------------------------------------
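The CONCAT branch of `InputTransform` pairs every two objects and concatenates their feature vectors, turning a (batch, n, c) unary tensor into a (batch, n, n, 2c) binary one. A NumPy sketch of that meshgrid-and-concat step, without the exclude_self masking (`concat_transform` is an illustrative name):

```python
import numpy as np

def concat_transform(u):
    """u[b, i, c] -> out[b, i, j, 2c] with out[b, i, j] = concat(u[b, i], u[b, j])."""
    b, n, c = u.shape
    x = np.broadcast_to(u[:, :, None, :], (b, n, n, c))  # x[b, i, j] = u[b, i]
    y = np.broadcast_to(u[:, None, :, :], (b, n, n, c))  # y[b, i, j] = u[b, j]
    return np.concatenate([x, y], axis=3)
```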
/difflogic/nn/neural_logic/modules/neural_logic.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """MLP-based implementation for logic and logits inference."""
17 |
18 | import torch.nn as nn
19 |
20 | from jactorch.quickstart.models import MLPModel
21 |
22 | __all__ = ['LogicInference', 'LogitsInference']
23 |
24 |
25 | class InferenceBase(nn.Module):
26 |   """MLP model with parameters shared among all axes except the channel axis."""
27 |
28 | def __init__(self, input_dim, output_dim, hidden_dim):
29 | super().__init__()
30 | self.input_dim = input_dim
31 | self.output_dim = output_dim
32 | self.hidden_dim = hidden_dim
33 | self.layer = nn.Sequential(MLPModel(input_dim, output_dim, hidden_dim))
34 |
35 | def forward(self, inputs):
36 | input_size = inputs.size()[:-1]
37 | input_channel = inputs.size(-1)
38 |
39 | f = inputs.view(-1, input_channel)
40 | f = self.layer(f)
41 | f = f.view(*input_size, -1)
42 | return f
43 |
44 | def get_output_dim(self, input_dim):
45 | return self.output_dim
46 |
47 |
48 | class LogicInference(InferenceBase):
49 | """MLP layer with sigmoid activation."""
50 |
51 | def __init__(self, input_dim, output_dim, hidden_dim):
52 | super().__init__(input_dim, output_dim, hidden_dim)
53 | self.layer.add_module(str(len(self.layer)), nn.Sigmoid())
54 |
55 |
56 | class LogitsInference(InferenceBase):
57 | pass
58 |
--------------------------------------------------------------------------------
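`InferenceBase.forward` applies one MLP to the channel axis only, sharing its parameters across all leading axes via a flatten/apply/unflatten pattern. A NumPy sketch of that pattern with an arbitrary channel-wise function (`apply_shared` is an illustrative name):

```python
import numpy as np

def apply_shared(x, f):
    """Apply f to the last (channel) axis, sharing it across all other axes."""
    lead = x.shape[:-1]
    out = f(x.reshape(-1, x.shape[-1]))  # flatten everything but channels
    return out.reshape(*lead, -1)        # restore the leading axes

x = np.arange(24, dtype=float).reshape(2, 3, 4)
y = apply_shared(x, lambda v: v.sum(axis=1, keepdims=True))
```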
/difflogic/nn/rl/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .reinforce import *
18 |
--------------------------------------------------------------------------------
/difflogic/nn/rl/reinforce.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Implement REINFORCE loss."""
17 |
18 | import torch.nn as nn
19 |
20 | __all__ = ['REINFORCELoss']
21 |
22 |
23 | class REINFORCELoss(nn.Module):
24 | """Implement the loss function for REINFORCE algorithm."""
25 |
26 | def __init__(self, entropy_beta=None):
27 | super().__init__()
28 |     self.nll = nn.NLLLoss(reduction='none')
29 | self.entropy_beta = entropy_beta
30 |
31 | def forward(self, policy, action, discount_reward, entropy_beta=None):
32 | monitors = dict()
33 | entropy = -(policy * policy.log()).sum(dim=1).mean()
34 | nll = self.nll(policy, action)
35 | loss = (nll * discount_reward).mean()
36 | if entropy_beta is None:
37 | entropy_beta = self.entropy_beta
38 | if entropy_beta is not None:
39 | monitors['reinforce_loss'] = loss
40 | monitors['entropy_loss'] = -entropy * entropy_beta
41 | loss -= entropy * entropy_beta
42 | monitors['entropy'] = entropy
43 | return loss, monitors
44 |
--------------------------------------------------------------------------------
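The objective computed by `REINFORCELoss` is the REINFORCE surrogate loss with an optional entropy bonus. A NumPy sketch of the textbook form, mean(-log pi(a) * R) - beta * H (`reinforce_loss` is an illustrative name; a real implementation would also clamp log(0)):

```python
import numpy as np

def reinforce_loss(policy, actions, returns, entropy_beta=0.0):
    """policy[b, a]: action probabilities; returns[b]: discounted returns."""
    actions = np.asarray(actions)
    returns = np.asarray(returns, dtype=float)
    logp = np.log(policy[np.arange(len(actions)), actions])
    loss = np.mean(-logp * returns)                            # policy-gradient term
    entropy = -np.sum(policy * np.log(policy), axis=1).mean()  # exploration bonus
    return loss - entropy_beta * entropy
```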
/difflogic/thutils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Utility functions for PyTorch."""
17 |
18 | import torch
19 | import torch.nn.functional as F
20 |
21 | from jactorch.utils.meta import as_float
22 | from jactorch.utils.meta import as_tensor
23 |
24 | __all__ = [
25 | 'binary_accuracy', 'rms', 'monitor_saturation', 'monitor_paramrms',
26 | 'monitor_gradrms'
27 | ]
28 |
29 |
30 | def binary_accuracy(label, raw_pred, eps=1e-20, return_float=True):
31 |   """Get the accuracy for a binary classification problem."""
32 | pred = as_tensor(raw_pred).squeeze(-1)
33 | pred = (pred > 0.5).float()
34 | label = as_tensor(label).float()
35 |   # $acc is the micro accuracy: number of correct predictions / total.
36 | acc = label.eq(pred).float()
37 |
38 |   # $balanced_accuracy is the macro accuracy, balanced across the two classes.
39 | nr_total = torch.ones(
40 | label.size(), dtype=label.dtype, device=label.device).sum(dim=-1)
41 | nr_pos = label.sum(dim=-1)
42 | nr_neg = nr_total - nr_pos
43 | pos_cnt = (acc * label).sum(dim=-1)
44 | neg_cnt = acc.sum(dim=-1) - pos_cnt
45 | balanced_acc = ((pos_cnt + eps) / (nr_pos + eps) + (neg_cnt + eps) /
46 | (nr_neg + eps)) / 2.0
47 |
48 |   # $sat is the saturation rate of the prediction,
49 |   # measuring how close the predictions are to 0 or 1.
50 | sat = 1 - (raw_pred - pred).abs()
51 | if return_float:
52 | acc = as_float(acc.mean())
53 | balanced_acc = as_float(balanced_acc.mean())
54 | sat_mean = as_float(sat.mean())
55 | sat_min = as_float(sat.min())
56 | else:
57 | sat_mean = sat.mean(dim=-1)
58 | sat_min = sat.min(dim=-1)[0]
59 |
60 | return {
61 | 'accuracy': acc,
62 | 'balanced_accuracy': balanced_acc,
63 | 'satuation/mean': sat_mean,
64 | 'satuation/min': sat_min,
65 | }
66 |
67 |
68 | def rms(p):
69 | """Root mean square function."""
70 | return as_float((as_tensor(p)**2).mean()**0.5)
71 |
72 |
73 | def monitor_saturation(model):
74 | """Monitor the saturation rate."""
75 | monitors = {}
76 | for name, p in model.named_parameters():
77 |     p = torch.sigmoid(p)
78 | sat = 1 - (p - (p > 0.5).float()).abs()
79 | monitors['sat/' + name] = sat
80 | return monitors
81 |
82 |
83 | def monitor_paramrms(model):
84 | """Monitor the rms of the parameters."""
85 | monitors = {}
86 | for name, p in model.named_parameters():
87 | monitors['paramrms/' + name] = rms(p)
88 | return monitors
89 |
90 |
91 | def monitor_gradrms(model):
92 | """Monitor the rms of the gradients of the parameters."""
93 | monitors = {}
94 | for name, p in model.named_parameters():
95 | if p.grad is not None:
96 | monitors['gradrms/' + name] = rms(p.grad) / max(rms(p), 1e-8)
97 | return monitors
98 |
--------------------------------------------------------------------------------
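The balanced (macro) accuracy computed in `binary_accuracy` averages the per-class recalls, so a classifier cannot look good by always predicting the majority class. A NumPy sketch omitting the eps smoothing (`balanced_accuracy` is an illustrative name):

```python
import numpy as np

def balanced_accuracy(label, pred):
    """Average of positive-class and negative-class recall."""
    pos, neg = label == 1, label == 0
    recall_pos = (pred[pos] == 1).mean()  # recall on positives
    recall_neg = (pred[neg] == 0).mean()  # recall on negatives
    return (recall_pos + recall_neg) / 2.0
```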
/difflogic/tqdm_utils.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The utility functions for tqdm."""
17 |
18 | from jacinle.utils.tqdm import tqdm_pbar
19 |
20 | __all__ = ['tqdm_for']
21 |
22 |
23 | def tqdm_for(total, func):
24 |   """Wrapper of a for loop with a message shown on the progress bar."""
25 |   # Breaking out of the loop is not supported for now.
26 | with tqdm_pbar(total=total) as pbar:
27 | for i in range(total):
28 | message = func(i)
29 | pbar.set_description(message)
30 | pbar.update()
31 |
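A minimal, dependency-free sketch of how the `tqdm_for` pattern above behaves (`FakePbar` is a hypothetical stand-in for `tqdm_pbar`, so the sketch runs without jacinle installed):

```python
class FakePbar:
    """Hypothetical stand-in for tqdm_pbar, tracking progress in plain Python."""

    def __init__(self, total):
        self.total, self.n, self.desc = total, 0, ''

    def set_description(self, msg):
        self.desc = msg

    def update(self, k=1):
        self.n += k


def tqdm_for_sketch(total, func, pbar_cls=FakePbar):
    # Mirrors tqdm_for: call func(i), show its message, advance the bar.
    pbar = pbar_cls(total)
    for i in range(total):
        pbar.set_description(func(i))
        pbar.update()
    return pbar


pbar = tqdm_for_sketch(3, lambda i: 'step {}'.format(i))
print(pbar.n, pbar.desc)  # 3 step 2
```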
--------------------------------------------------------------------------------
/difflogic/train/__init__.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from .train import *
18 |
--------------------------------------------------------------------------------
/models/blocksworld.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/blocksworld.pth
--------------------------------------------------------------------------------
/models/path.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/path.pth
--------------------------------------------------------------------------------
/models/sort.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/sort.pth
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tqdm
3 |
--------------------------------------------------------------------------------
/scripts/blocksworld/README.md:
--------------------------------------------------------------------------------
1 | # Blocks World
2 |
3 | ## Command
4 | Please run these commands from the root directory of the project.
5 |
6 | ``` shell
7 | # Train
8 | $ jac-run scripts/blocksworld/learn_policy.py --task final
9 | # Test
10 | $ jac-run scripts/blocksworld/learn_policy.py --task final --test-only --load CHECKPOINT
11 | # Add [--test-epoch-size T] to control the number of test cases.
12 | ```
13 |
--------------------------------------------------------------------------------
/scripts/graph/README.md:
--------------------------------------------------------------------------------
1 | # Graph Tasks, Shortest Path and Sorting
2 |
3 | A set of graph-related reasoning tasks and a sorting task.
4 |
5 | ## Graph tasks
6 |
7 | ### Family tree tasks
8 | ``` shell
9 | # Train: add --train-number 20 --test-number-begin 20 --test-number-step 20 --test-number-end 100
10 | $ jac-run scripts/graph/learn_graph_tasks.py --task has-father
11 | $ jac-run scripts/graph/learn_graph_tasks.py --task has-sister
12 | $ jac-run scripts/graph/learn_graph_tasks.py --task grandparents --epochs 100 --early-stop 1e-7
13 | $ jac-run scripts/graph/learn_graph_tasks.py --task uncle --epochs 200 --early-stop 1e-8
14 | $ jac-run scripts/graph/learn_graph_tasks.py --task maternal-great-uncle --epochs 20 --epoch-size 2500 --early-stop 1e-8
15 | # Test
16 | $ jac-run scripts/graph/learn_graph_tasks.py --task TASK --test-only --load $CHECKPOINT
17 | ```
18 | We use `loss < thresh` as the criterion for qualifying models.
19 |
20 | ### General graph tasks
21 | ``` shell
22 | # Train
23 | # AdjacentToRed
24 | $ jac-run scripts/graph/learn_graph_tasks.py --task adjacent --gen-graph-colors 4
25 | # 4-Connectivity
26 | $ jac-run scripts/graph/learn_graph_tasks.py --task connectivity
27 | # 6-Connectivity
28 | $ jac-run scripts/graph/learn_graph_tasks.py --task connectivity --connectivity-dist-limit 6 --early-stop 1e-6 \
29 | --nlm-depth 8 --nlm-residual True --gen-graph-pmin 0.1 --gen-graph-pmax 0.3 --gen-graph-method dnc
30 | # 1-Outdegree
31 | $ jac-run scripts/graph/learn_graph_tasks.py --task outdegree --outdegree-n 1
32 | # 2-Outdegree
33 | $ jac-run scripts/graph/learn_graph_tasks.py --task outdegree --outdegree-n 2 \
34 | --nlm-depth 5 --nlm-breadth 4 --nlm-residual True
35 |
36 | # Test
37 | $ jac-run scripts/graph/learn_graph_tasks.py --task TASK --test-only --load $CHECKPOINT \
38 | --nlm-depth DEPTH --nlm-breadth BREADTH
39 | ```
40 |
41 | ### MNIST Input
42 | We modified the `AdjacentToRed` task by replacing the color indicators with MNIST digits. The NLM is integrated with a LeNet and optimized jointly.
43 |
44 | ``` shell
45 | $ jac-run scripts/graph/learn_graph_tasks.py --task adjacent-mnist \
46 | --nlm-depth 2 --nlm-breadth 2 --nlm-attributes 16 --nlm-residual True --gen-graph-colors 10
47 | ```
48 |
49 | ## Shortest path
50 |
51 | The command below provides a suitable set of parameters for this task.
52 | ``` shell
53 | # Train
54 | $ jac-run scripts/graph/learn_policy.py --task path
55 | # Test
56 | $ jac-run scripts/graph/learn_policy.py --task path --test-only --load $CHECKPOINT
57 | ```
58 | For all available arguments see `jac-run scripts/graph/learn_policy.py -h`.
59 |
60 | ## Sorting
61 |
62 | The command below provides a suitable set of parameters for this task.
63 | ``` shell
64 | # Train
65 | $ jac-run scripts/graph/learn_policy.py --task sort --nlm-depth 3 --nlm-breadth 2 \
66 | --curriculum-graduate 10 --entropy-beta 0.01 \
67 | --mining-epoch-size 200 --mining-dataset-size 20 --mining-interval 2
68 | # Test
69 | $ jac-run scripts/graph/learn_policy.py --task sort --test-only \
70 |     --nlm-depth 3 --nlm-breadth 2 --load $CHECKPOINT
71 | ```
72 |
--------------------------------------------------------------------------------
/scripts/graph/learn_graph_tasks.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The script for family tree or general graph experiments."""
17 |
18 | import copy
19 | import collections
20 | import functools
21 | import os
22 | import json
23 |
24 | import numpy as np
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 |
29 | import jacinle.random as random
30 | import jacinle.io as io
31 | import jactorch.nn as jacnn
32 |
33 | from difflogic.cli import format_args
34 | from difflogic.dataset.graph import GraphOutDegreeDataset, \
35 | GraphConnectivityDataset, GraphAdjacentDataset, FamilyTreeDataset
36 | from difflogic.nn.baselines import MemoryNet
37 | from difflogic.nn.neural_logic import LogicMachine, LogicInference, LogitsInference
38 | from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self
39 | from difflogic.nn.rl.reinforce import REINFORCELoss
40 | from difflogic.thutils import binary_accuracy
41 | from difflogic.train import TrainerBase
42 |
43 | from jacinle.cli.argument import JacArgumentParser
44 | from jacinle.logging import get_logger, set_output_file
45 | from jacinle.utils.container import GView
46 | from jacinle.utils.meter import GroupMeters
47 | from jactorch.data.dataloader import JacDataLoader
48 | from jactorch.optim.accum_grad import AccumGrad
49 | from jactorch.optim.quickaccess import get_optimizer
50 | from jactorch.train.env import TrainerEnv
51 | from jactorch.utils.meta import as_cuda, as_numpy, as_tensor
52 |
53 | TASKS = [
54 | 'outdegree', 'connectivity', 'adjacent', 'adjacent-mnist', 'has-father',
55 | 'has-sister', 'grandparents', 'uncle', 'maternal-great-uncle'
56 | ]
57 |
58 | parser = JacArgumentParser()
59 |
60 | parser.add_argument(
61 | '--model',
62 | default='nlm',
63 | choices=['nlm', 'memnet'],
64 | help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks')
65 |
66 | # NLM parameters, used when model is 'nlm'
67 | nlm_group = parser.add_argument_group('Neural Logic Machines')
68 | LogicMachine.make_nlm_parser(
69 | nlm_group, {
70 | 'depth': 4,
71 | 'breadth': 3,
72 | 'exclude_self': True,
73 | 'logic_hidden_dim': []
74 | },
75 | prefix='nlm')
76 | nlm_group.add_argument(
77 | '--nlm-attributes',
78 | type=int,
79 | default=8,
80 | metavar='N',
81 | help='number of output attributes in each group of each layer of the LogicMachine'
82 | )
83 |
84 | # MemNN parameters, used when model is 'memnet'
85 | memnet_group = parser.add_argument_group('Memory Networks')
86 | MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet')
87 |
88 | # task related
89 | task_group = parser.add_argument_group('Task')
90 | task_group.add_argument(
91 |     '--task', required=True, choices=TASKS, help='task choices')
92 | task_group.add_argument(
93 | '--train-number',
94 | type=int,
95 | default=10,
96 | metavar='N',
97 | help='size of training instances')
98 | task_group.add_argument(
99 | '--adjacent-pred-colors', type=int, default=4, metavar='N')
100 | task_group.add_argument('--outdegree-n', type=int, default=2, metavar='N')
101 | task_group.add_argument(
102 | '--connectivity-dist-limit', type=int, default=4, metavar='N')
103 |
104 | data_gen_group = parser.add_argument_group('Data Generation')
105 | data_gen_group.add_argument(
106 | '--gen-graph-method',
107 | default='edge',
108 | choices=['dnc', 'edge'],
109 |     help='method used to generate random graphs')
110 | data_gen_group.add_argument(
111 | '--gen-graph-pmin',
112 | type=float,
113 | default=0.0,
114 | metavar='F',
115 |     help='lower bound of the edge probability p controlling graph sparsity')
116 | data_gen_group.add_argument(
117 | '--gen-graph-pmax',
118 | type=float,
119 | default=0.3,
120 | metavar='F',
121 |     help='upper bound of the edge probability p controlling graph sparsity')
122 | data_gen_group.add_argument(
123 | '--gen-graph-colors',
124 | type=int,
125 | default=4,
126 | metavar='N',
127 | help='number of colors in adjacent task')
128 | data_gen_group.add_argument(
129 | '--gen-directed', action='store_true', help='directed graph')
130 |
131 | train_group = parser.add_argument_group('Train')
132 | train_group.add_argument(
133 | '--seed',
134 | type=int,
135 | default=None,
136 | metavar='SEED',
137 | help='seed of jacinle.random')
138 | train_group.add_argument(
139 | '--use-gpu', action='store_true', help='use GPU or not')
140 | train_group.add_argument(
141 | '--optimizer',
142 | default='AdamW',
143 | choices=['SGD', 'Adam', 'AdamW'],
144 | help='optimizer choices')
145 | train_group.add_argument(
146 | '--lr',
147 | type=float,
148 | default=0.005,
149 | metavar='F',
150 | help='initial learning rate')
151 | train_group.add_argument(
152 | '--lr-decay',
153 | type=float,
154 | default=1.0,
155 | metavar='F',
156 | help='exponential decay of learning rate per lesson')
157 | train_group.add_argument(
158 | '--accum-grad',
159 | type=int,
160 | default=1,
161 | metavar='N',
162 | help='accumulated gradient for batches (default: 1)')
163 | train_group.add_argument(
164 | '--ohem-size',
165 | type=int,
166 | default=0,
167 | metavar='N',
168 | help='size of online hard negative mining')
169 | train_group.add_argument(
170 | '--batch-size',
171 | type=int,
172 | default=4,
173 | metavar='N',
174 | help='batch size for training')
175 | train_group.add_argument(
176 | '--test-batch-size',
177 | type=int,
178 | default=4,
179 | metavar='N',
180 | help='batch size for testing')
181 | train_group.add_argument(
182 | '--early-stop-loss-thresh',
183 | type=float,
184 | default=1e-5,
185 | metavar='F',
186 | help='threshold of loss for early stop')
187 |
188 | # Note that nr_examples_per_epoch = epoch_size * batch_size
189 | TrainerBase.make_trainer_parser(
190 | parser, {
191 | 'epochs': 50,
192 | 'epoch_size': 250,
193 | 'test_epoch_size': 250,
194 | 'test_number_begin': 10,
195 | 'test_number_step': 10,
196 | 'test_number_end': 50,
197 | })
198 |
199 | io_group = parser.add_argument_group('Input/Output')
200 | io_group.add_argument(
201 | '--dump-dir', type=str, default=None, metavar='DIR', help='dump dir')
202 | io_group.add_argument(
203 | '--load-checkpoint',
204 | type=str,
205 | default=None,
206 | metavar='FILE',
207 | help='load parameters from checkpoint')
208 |
209 | schedule_group = parser.add_argument_group('Schedule')
210 | schedule_group.add_argument(
211 | '--runs', type=int, default=1, metavar='N', help='number of runs')
212 | schedule_group.add_argument(
213 | '--save-interval',
214 | type=int,
215 | default=10,
216 | metavar='N',
217 |     help='the interval (in epochs) at which to save a checkpoint')
218 | schedule_group.add_argument(
219 | '--test-interval',
220 | type=int,
221 | default=None,
222 | metavar='N',
223 |     help='the interval (in epochs) at which to run tests')
224 | schedule_group.add_argument(
225 | '--test-only', action='store_true', help='test-only mode')
226 |
227 | logger = get_logger(__file__)
228 |
229 | args = parser.parse_args()
230 |
231 | args.use_gpu = args.use_gpu and torch.cuda.is_available()
232 |
233 | if args.dump_dir is not None:
234 | io.mkdir(args.dump_dir)
235 | args.log_file = os.path.join(args.dump_dir, 'log.log')
236 | set_output_file(args.log_file)
237 | else:
238 | args.checkpoints_dir = None
239 | args.summary_file = None
240 |
241 | if args.seed is not None:
242 | import jacinle.random as random
243 | random.reset_global_seed(args.seed)
244 |
245 | args.task_is_outdegree = args.task in ['outdegree']
246 | args.task_is_connectivity = args.task in ['connectivity']
247 | args.task_is_adjacent = args.task in ['adjacent', 'adjacent-mnist']
248 | args.task_is_family_tree = args.task in [
249 | 'has-father', 'has-sister', 'grandparents', 'uncle', 'maternal-great-uncle'
250 | ]
251 | args.task_is_mnist_input = args.task in ['adjacent-mnist']
252 | args.task_is_1d_output = args.task in [
253 | 'outdegree', 'adjacent', 'adjacent-mnist', 'has-father', 'has-sister'
254 | ]
255 |
256 |
257 | class LeNet(nn.Module):
258 |
259 | def __init__(self):
260 | super().__init__()
261 | self.conv1 = jacnn.Conv2dLayer(
262 | 1, 10, kernel_size=5, batch_norm=True, activation='relu')
263 | self.conv2 = jacnn.Conv2dLayer(
264 | 10,
265 | 20,
266 | kernel_size=5,
267 | batch_norm=True,
268 | dropout=False,
269 | activation='relu')
270 | self.fc1 = nn.Linear(320, 50)
271 | self.fc2 = nn.Linear(50, 10)
272 |
273 | def forward(self, x):
274 | x = F.max_pool2d(self.conv1(x), 2)
275 | x = F.max_pool2d(self.conv2(x), 2)
276 | x = x.view(-1, 320)
277 | x = F.relu(self.fc1(x))
278 | x = self.fc2(x)
279 | return x
280 |
281 |
282 | class Model(nn.Module):
283 |   """The model for family tree or general graph tasks."""
284 |
285 | def __init__(self):
286 | super().__init__()
287 |
288 | # inputs
289 | input_dim = 4 if args.task_is_family_tree else 1
290 | self.feature_axis = 1 if args.task_is_1d_output else 2
291 |
292 | # features
293 | if args.model == 'nlm':
294 | input_dims = [0 for _ in range(args.nlm_breadth + 1)]
295 | if args.task_is_adjacent:
296 | input_dims[1] = args.gen_graph_colors
297 | if args.task_is_mnist_input:
298 | self.lenet = LeNet()
299 | input_dims[2] = input_dim
300 |
301 | self.features = LogicMachine.from_args(
302 | input_dims, args.nlm_attributes, args, prefix='nlm')
303 | output_dim = self.features.output_dims[self.feature_axis]
304 |
305 | elif args.model == 'memnet':
306 | if args.task_is_adjacent:
307 | input_dim += args.gen_graph_colors
308 | self.feature = MemoryNet.from_args(
309 | input_dim, self.feature_axis, args, prefix='memnet')
310 | output_dim = self.feature.get_output_dim()
311 |
312 | # target
313 | target_dim = args.adjacent_pred_colors if args.task_is_adjacent else 1
314 | self.pred = LogicInference(output_dim, target_dim, [])
315 |
316 | # losses
317 | if args.ohem_size > 0:
318 | from jactorch.nn.losses import BinaryCrossEntropyLossWithProbs as BCELoss
319 | self.loss = BCELoss(average='none')
320 | else:
321 | self.loss = nn.BCELoss()
322 |
323 | def forward(self, feed_dict):
324 | feed_dict = GView(feed_dict)
325 |
326 | # properties
327 | if args.task_is_adjacent:
328 | states = feed_dict.states.float()
329 | else:
330 | states = None
331 |
332 | # relations
333 | relations = feed_dict.relations.float()
334 | batch_size, nr = relations.size()[:2]
335 |
336 | if args.model == 'nlm':
337 | if args.task_is_adjacent and args.task_is_mnist_input:
338 | states_shape = states.size()
339 | states = states.view((-1,) + states_shape[2:])
340 | states = self.lenet(states)
341 | states = states.view(states_shape[:2] + (-1,))
342 | states = F.sigmoid(states)
343 |
344 | inp = [None for _ in range(args.nlm_breadth + 1)]
345 | inp[1] = states
346 | inp[2] = relations
347 |
348 | depth = None
349 | if args.nlm_recursion:
350 | depth = 1
351 | while 2**depth + 1 < nr:
352 | depth += 1
353 | depth = depth * 2 + 1
354 | feature = self.features(inp, depth=depth)[self.feature_axis]
355 | elif args.model == 'memnet':
356 | feature = self.feature(relations, states)
357 | if args.task_is_adjacent and args.task_is_mnist_input:
358 | raise NotImplementedError()
359 |
360 | pred = self.pred(feature)
361 | if not args.task_is_adjacent:
362 | pred = pred.squeeze(-1)
363 | if args.task_is_connectivity:
364 | pred = meshgrid_exclude_self(pred) # exclude self-cycle
365 |
366 | if self.training:
367 | monitors = dict()
368 | target = feed_dict.target.float()
369 |
370 | if args.task_is_adjacent:
371 | target = target[:, :, :args.adjacent_pred_colors]
372 |
373 | monitors.update(binary_accuracy(target, pred, return_float=False))
374 |
375 | loss = self.loss(pred, target)
376 | # ohem loss is unused.
377 | if args.ohem_size > 0:
378 | loss = loss.view(-1).topk(args.ohem_size)[0].mean()
379 | return loss, monitors, dict(pred=pred)
380 | else:
381 | return dict(pred=pred)
382 |
383 |
384 | def make_dataset(n, epoch_size, is_train):
385 | pmin, pmax = args.gen_graph_pmin, args.gen_graph_pmax
386 | if args.task_is_outdegree:
387 | return GraphOutDegreeDataset(
388 | args.outdegree_n,
389 | epoch_size,
390 | n,
391 | pmin=pmin,
392 | pmax=pmax,
393 | directed=args.gen_directed,
394 | gen_method=args.gen_graph_method)
395 | elif args.task_is_connectivity:
396 | nmin, nmax = n, n
397 | if is_train and args.nlm_recursion:
398 | nmin = 2
399 | return GraphConnectivityDataset(
400 | args.connectivity_dist_limit,
401 | epoch_size,
402 | nmin,
403 | pmin,
404 | nmax,
405 | pmax,
406 | directed=args.gen_directed,
407 | gen_method=args.gen_graph_method)
408 | elif args.task_is_adjacent:
409 | return GraphAdjacentDataset(
410 | args.gen_graph_colors,
411 | epoch_size,
412 | n,
413 | pmin=pmin,
414 | pmax=pmax,
415 | directed=args.gen_directed,
416 | gen_method=args.gen_graph_method,
417 | is_train=is_train,
418 | is_mnist_colors=args.task_is_mnist_input)
419 | else:
420 | return FamilyTreeDataset(args.task, epoch_size, n, p_marriage=1.0)
421 |
422 |
423 | class MyTrainer(TrainerBase):
424 | def save_checkpoint(self, name):
425 | if args.checkpoints_dir is not None:
426 | checkpoint_file = os.path.join(args.checkpoints_dir,
427 | 'checkpoint_{}.pth'.format(name))
428 | super().save_checkpoint(checkpoint_file)
429 |
430 | def _dump_meters(self, meters, mode):
431 | if args.summary_file is not None:
432 | meters_kv = meters._canonize_values('avg')
433 | meters_kv['mode'] = mode
434 | meters_kv['epoch'] = self.current_epoch
435 | with open(args.summary_file, 'a') as f:
436 | f.write(io.dumps_json(meters_kv))
437 | f.write('\n')
438 |
439 | data_iterator = {}
440 |
441 | def _prepare_dataset(self, epoch_size, mode):
442 | assert mode in ['train', 'test']
443 | if mode == 'train':
444 | batch_size = args.batch_size
445 | number = args.train_number
446 | else:
447 | batch_size = args.test_batch_size
448 | number = self.test_number
449 |
450 | # The actual number of instances in an epoch is epoch_size * batch_size.
451 | dataset = make_dataset(number, epoch_size * batch_size, mode == 'train')
452 | dataloader = JacDataLoader(
453 | dataset,
454 | shuffle=True,
455 | batch_size=batch_size,
456 | num_workers=min(epoch_size, 4))
457 | self.data_iterator[mode] = dataloader.__iter__()
458 |
459 | def _get_data(self, index, meters, mode):
460 |     feed_dict = next(self.data_iterator[mode])
461 | meters.update(number=feed_dict['n'].data.numpy().mean())
462 | if args.use_gpu:
463 | feed_dict = as_cuda(feed_dict)
464 | return feed_dict
465 |
466 | def _get_result(self, index, meters, mode):
467 | feed_dict = self._get_data(index, meters, mode)
468 | output_dict = self.model(feed_dict)
469 |
470 | target = feed_dict['target']
471 | if args.task_is_adjacent:
472 | target = target[:, :, :args.adjacent_pred_colors]
473 | result = binary_accuracy(target, output_dict['pred'])
474 | succ = result['accuracy'] == 1.0
475 |
476 | meters.update(succ=succ)
477 | meters.update(result, n=target.size(0))
478 | message = '> {} iter={iter}, accuracy={accuracy:.4f}, \
479 | balance_acc={balanced_accuracy:.4f}'.format(
480 | mode, iter=index, **meters.val)
481 | return message, dict(succ=succ, feed_dict=feed_dict)
482 |
483 | def _get_train_data(self, index, meters):
484 | return self._get_data(index, meters, mode='train')
485 |
486 | def _train_epoch(self, epoch_size):
487 | meters = super()._train_epoch(epoch_size)
488 |
489 | i = self.current_epoch
490 | if args.save_interval is not None and i % args.save_interval == 0:
491 | self.save_checkpoint(str(i))
492 | if args.test_interval is not None and i % args.test_interval == 0:
493 | self.test()
494 | return meters
495 |
496 | def _early_stop(self, meters):
497 | return meters.avg['loss'] < args.early_stop_loss_thresh
498 |
499 |
500 | def main(run_id):
501 | if args.dump_dir is not None:
502 | if args.runs > 1:
503 | args.current_dump_dir = os.path.join(args.dump_dir,
504 | 'run_{}'.format(run_id))
505 | io.mkdir(args.current_dump_dir)
506 | else:
507 | args.current_dump_dir = args.dump_dir
508 |
509 | args.summary_file = os.path.join(args.current_dump_dir, 'summary.json')
510 | args.checkpoints_dir = os.path.join(args.current_dump_dir, 'checkpoints')
511 | io.mkdir(args.checkpoints_dir)
512 |
513 | logger.info(format_args(args))
514 |
515 | model = Model()
516 | if args.use_gpu:
517 | model.cuda()
518 | optimizer = get_optimizer(args.optimizer, model, args.lr)
519 | if args.accum_grad > 1:
520 | optimizer = AccumGrad(optimizer, args.accum_grad)
521 | trainer = MyTrainer.from_args(model, optimizer, args)
522 |
523 | if args.load_checkpoint is not None:
524 | trainer.load_checkpoint(args.load_checkpoint)
525 |
526 | if args.test_only:
527 | return None, trainer.test()
528 |
529 | final_meters = trainer.train()
530 | trainer.save_checkpoint('last')
531 |
532 | return trainer.early_stopped, trainer.test()
533 |
534 |
535 | if __name__ == '__main__':
536 | stats = []
537 | nr_graduated = 0
538 |
539 | for i in range(args.runs):
540 | graduated, test_meters = main(i)
541 | logger.info('run {}'.format(i + 1))
542 |
543 | if test_meters is not None:
544 | for j, meters in enumerate(test_meters):
545 | if len(stats) <= j:
546 | stats.append(GroupMeters())
547 | stats[j].update(
548 | number=meters.avg['number'], test_acc=meters.avg['accuracy'])
549 |
550 | for meters in stats:
551 | logger.info('number {}, test_acc {}'.format(meters.avg['number'],
552 | meters.avg['test_acc']))
553 |
554 | if not args.test_only:
555 | nr_graduated += int(graduated)
556 | logger.info('graduate_ratio {}'.format(nr_graduated / (i + 1)))
557 | if graduated:
558 | for j, meters in enumerate(test_meters):
559 | stats[j].update(grad_test_acc=meters.avg['accuracy'])
560 | if nr_graduated > 0:
561 | for meters in stats:
562 | logger.info('number {}, grad_test_acc {}'.format(
563 | meters.avg['number'], meters.avg['grad_test_acc']))
564 |
--------------------------------------------------------------------------------
/scripts/graph/learn_policy.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python3
2 | #
3 | # Copyright 2018 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # https://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """The script for sorting or shortest path experiments."""
17 |
18 | import collections
19 | import copy
20 | import functools
21 | import json
22 | import os
23 |
24 | import numpy as np
25 | import torch
26 | import torch.nn as nn
27 | import torch.nn.functional as F
28 |
29 | import jacinle.random as random
30 | import jacinle.io as io
31 |
32 | from difflogic.cli import format_args
33 | from difflogic.nn.baselines import MemoryNet
34 | from difflogic.nn.neural_logic import LogicMachine, LogitsInference
35 | from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self
36 | from difflogic.nn.rl.reinforce import REINFORCELoss
37 | from difflogic.train import MiningTrainerBase
38 |
39 | from jacinle.cli.argument import JacArgumentParser
40 | from jacinle.logging import get_logger, set_output_file
41 | from jacinle.utils.container import GView
42 | from jacinle.utils.meter import GroupMeters
43 | from jactorch.optim.accum_grad import AccumGrad
44 | from jactorch.optim.quickaccess import get_optimizer
45 | from jactorch.utils.meta import as_cuda, as_numpy, as_tensor
46 |
47 | parser = JacArgumentParser()
48 |
49 | parser.add_argument(
50 | '--model',
51 | default='nlm',
52 | choices=['nlm', 'memnet'],
53 | help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks')
54 |
55 | # NLM parameters, used when model is 'nlm'.
56 | nlm_group = parser.add_argument_group('Neural Logic Machines')
57 | LogicMachine.make_nlm_parser(
58 | nlm_group, {
59 | 'depth': 5,
60 | 'breadth': 3,
61 | 'residual': True,
62 | 'exclude_self': True,
63 | 'logic_hidden_dim': []
64 | },
65 | prefix='nlm')
66 | nlm_group.add_argument(
67 | '--nlm-attributes',
68 | type=int,
69 | default=8,
70 | metavar='N',
71 | help='number of output attributes in each group of each layer of the LogicMachine'
72 | )
73 |
74 | # MemNN parameters, used when model is 'memnet'.
75 | memnet_group = parser.add_argument_group('Memory Networks')
76 | MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet')
77 |
78 | parser.add_argument(
79 |     '--task', required=True, choices=['sort', 'path'], help='task choices')
80 |
81 | data_gen_group = parser.add_argument_group('Data Generation')
82 | data_gen_group.add_argument(
83 | '--gen-method',
84 | default='dnc',
85 | choices=['dnc', 'edge'],
86 |     help='method used to generate random graphs')
87 | data_gen_group.add_argument(
88 | '--gen-graph-pmin',
89 | type=float,
90 | default=0.3,
91 | metavar='F',
92 |     help='lower bound of the edge probability p controlling graph sparsity')
93 | data_gen_group.add_argument(
94 | '--gen-graph-pmax',
95 | type=float,
96 | default=0.3,
97 | metavar='F',
98 |     help='upper bound of the edge probability p controlling graph sparsity')
99 | data_gen_group.add_argument(
100 | '--gen-max-len',
101 | type=int,
102 | default=5,
103 | metavar='N',
104 | help='maximum length of shortest path during training')
105 | data_gen_group.add_argument(
106 | '--gen-test-len',
107 | type=int,
108 | default=4,
109 | metavar='N',
110 | help='length of shortest path during testing')
111 | data_gen_group.add_argument(
112 | '--gen-directed', action='store_true', help='directed graph')
113 |
114 | MiningTrainerBase.make_trainer_parser(
115 | parser, {
116 | 'epochs': 400,
117 | 'epoch_size': 100,
118 | 'test_epoch_size': 1000,
119 | 'test_number_begin': 10,
120 | 'test_number_step': 10,
121 | 'test_number_end': 50,
122 | 'curriculum_start': 3,
123 | 'curriculum_step': 1,
124 | 'curriculum_graduate': 12,
125 | 'curriculum_thresh_relax': 0.005,
126 | 'sample_array_capacity': 3,
127 | 'enable_mining': True,
128 | 'mining_interval': 6,
129 | 'mining_epoch_size': 3000,
130 | 'mining_dataset_size': 300,
131 | 'inherit_neg_data': True,
132 | 'prob_pos_data': 0.5
133 | })
134 |
135 | train_group = parser.add_argument_group('Train')
136 | train_group.add_argument('--seed', type=int, default=None, metavar='SEED')
137 | train_group.add_argument(
138 | '--use-gpu', action='store_true', help='use GPU or not')
139 | train_group.add_argument(
140 | '--optimizer',
141 | default='AdamW',
142 | choices=['SGD', 'Adam', 'AdamW'],
143 | help='optimizer choices')
144 | train_group.add_argument(
145 | '--lr',
146 | type=float,
147 | default=0.005,
148 | metavar='F',
149 | help='initial learning rate')
150 | train_group.add_argument(
151 | '--lr-decay',
152 | type=float,
153 | default=0.9,
154 | metavar='F',
155 | help='exponential decay of learning rate per lesson')
156 | train_group.add_argument(
157 | '--accum-grad',
158 | type=int,
159 | default=1,
160 | metavar='N',
161 | help='accumulated gradient (default: 1)')
162 | train_group.add_argument(
163 | '--candidate-relax',
164 | type=int,
165 | default=0,
166 | metavar='N',
167 |     help='number of threshold relaxations for candidate selection')
168 |
169 | rl_group = parser.add_argument_group('Reinforcement Learning')
170 | rl_group.add_argument(
171 | '--gamma',
172 | type=float,
173 | default=0.99,
174 | metavar='F',
175 | help='discount factor for accumulated reward function in reinforcement learning'
176 | )
177 | rl_group.add_argument(
178 | '--penalty',
179 | type=float,
180 | default=-0.01,
181 | metavar='F',
182 | help='a small penalty each step')
183 | rl_group.add_argument(
184 | '--entropy-beta',
185 | type=float,
186 | default=0.1,
187 | metavar='F',
188 | help='entropy loss scaling factor')
189 | rl_group.add_argument(
190 | '--entropy-beta-decay',
191 | type=float,
192 | default=0.8,
193 | metavar='F',
194 | help='entropy beta exponential decay factor')
195 |
196 | io_group = parser.add_argument_group('Input/Output')
197 | io_group.add_argument(
198 | '--dump-dir', default=None, metavar='DIR', help='dump dir')
199 | io_group.add_argument(
200 | '--dump-play',
201 | action='store_true',
202 | help='dump the trajectory of the plays for visualization')
203 | io_group.add_argument(
204 | '--dump-fail-only', action='store_true', help='dump failure cases only')
205 | io_group.add_argument(
206 | '--load-checkpoint',
207 | default=None,
208 | metavar='FILE',
209 | help='load parameters from checkpoint')
210 |
211 | schedule_group = parser.add_argument_group('Schedule')
212 | schedule_group.add_argument(
213 | '--runs', type=int, default=1, metavar='N', help='number of runs')
214 | schedule_group.add_argument(
215 | '--early-drop-epochs',
216 | type=int,
217 | default=40,
218 | metavar='N',
219 |     help='maximum number of epochs to spend on each lesson before early drop')
220 | schedule_group.add_argument(
221 | '--save-interval',
222 | type=int,
223 | default=10,
224 | metavar='N',
225 |     help='the interval (in epochs) at which to save checkpoints')
226 | schedule_group.add_argument(
227 | '--test-interval',
228 | type=int,
229 | default=None,
230 | metavar='N',
231 |     help='the interval (in epochs) at which to run tests')
232 | schedule_group.add_argument(
233 | '--test-only', action='store_true', help='test-only mode')
234 | schedule_group.add_argument(
235 | '--test-not-graduated',
236 | action='store_true',
237 |     help='also test models that have not graduated')
238 |
239 | args = parser.parse_args()
240 |
241 | args.use_gpu = args.use_gpu and torch.cuda.is_available()
242 | args.dump_play = args.dump_play and (args.dump_dir is not None)
243 |
244 | if args.dump_dir is not None:
245 | io.mkdir(args.dump_dir)
246 | args.log_file = os.path.join(args.dump_dir, 'log.log')
247 | set_output_file(args.log_file)
248 | else:
249 | args.checkpoints_dir = None
250 | args.summary_file = None
251 |
252 | if args.seed is not None:
253 | import jacinle.random as random
254 | random.reset_global_seed(args.seed)
255 |
256 | args.is_path_task = args.task in ['path']
257 | args.is_sort_task = args.task in ['sort']
258 | if args.is_path_task:
259 | from difflogic.envs.graph import make as make_env
260 | make_env = functools.partial(
261 | make_env,
262 | pmin=args.gen_graph_pmin,
263 | pmax=args.gen_graph_pmax,
264 | directed=args.gen_directed,
265 | gen_method=args.gen_method)
266 | elif args.is_sort_task:
267 | from difflogic.envs.algorithmic import make as make_env
268 |
269 | logger = get_logger(__file__)
270 |
271 |
272 | class Model(nn.Module):
273 | """The model for sorting or shortest path tasks."""
274 |
275 | def __init__(self):
276 | super().__init__()
277 |
278 | self.feature_axis = 1 if args.is_path_task else 2
279 | if args.model == 'memnet':
280 | current_dim = 4 if args.is_path_task else 6
281 | self.feature = MemoryNet.from_args(
282 | current_dim, self.feature_axis, args, prefix='memnet')
283 | current_dim = self.feature.get_output_dim()
284 | else:
285 | input_dims = [0 for i in range(args.nlm_breadth + 1)]
286 | if args.is_path_task:
287 | input_dims[1] = 2
288 | input_dims[2] = 2
289 | elif args.is_sort_task:
290 | input_dims[2] = 6
291 |
292 | self.features = LogicMachine.from_args(
293 | input_dims, args.nlm_attributes, args, prefix='nlm')
294 | if args.is_path_task:
295 | current_dim = self.features.output_dims[1]
296 | elif args.task == 'sort':
297 | current_dim = self.features.output_dims[2]
298 |
299 | self.pred = LogitsInference(current_dim, 1, [])
300 | self.loss = REINFORCELoss()
301 | self.pred_loss = nn.BCELoss()
302 |
303 | def forward(self, feed_dict):
304 | feed_dict = GView(feed_dict)
305 | states = None
306 | if args.is_path_task:
307 | states = feed_dict.states.float()
308 | relations = feed_dict.relations.float()
309 | elif args.is_sort_task:
310 | relations = feed_dict.states.float()
311 |
312 | def get_features(states, relations, depth=None):
313 | inp = [None for i in range(args.nlm_breadth + 1)]
314 | inp[1] = states
315 | inp[2] = relations
316 | features = self.features(inp, depth=depth)
317 | return features
318 |
319 | if args.model == 'memnet':
320 | f = self.feature(relations, states)
321 | else:
322 | f = get_features(states, relations)[self.feature_axis]
323 |       if self.feature_axis == 2:  # sorting task
324 | f = meshgrid_exclude_self(f)
325 |
326 | logits = self.pred(f).squeeze(dim=-1).view(relations.size(0), -1)
327 |     # Clamp the policy to a minimal value to avoid NaN in the loss.
328 | policy = F.softmax(logits, dim=-1).clamp(min=1e-20)
329 |
330 | if self.training:
331 | loss, monitors = self.loss(policy, feed_dict.actions,
332 | feed_dict.discount_rewards,
333 | feed_dict.entropy_beta)
334 | return loss, monitors, dict()
335 | else:
336 | return dict(policy=policy, logits=logits)
337 |
338 |
339 | def make_data(traj, gamma):
340 | Q = 0
341 | discount_rewards = []
342 | for reward in traj['rewards'][::-1]:
343 | Q = Q * gamma + reward
344 | discount_rewards.append(Q)
345 | discount_rewards.reverse()
346 |
347 | traj['states'] = as_tensor(np.array(traj['states']))
348 | if args.is_path_task:
349 | traj['relations'] = as_tensor(np.array(traj['relations']))
350 | traj['actions'] = as_tensor(np.array(traj['actions']))
351 | traj['discount_rewards'] = as_tensor(np.array(discount_rewards)).float()
352 | return traj
353 |
354 |
355 | def run_episode(env,
356 | model,
357 | number,
358 | play_name='',
359 | dump=False,
360 | eval_only=False,
361 | use_argmax=False,
362 | need_restart=False,
363 | entropy_beta=0.0):
364 | """Run one episode using the model with $number nodes/numbers."""
365 | is_over = False
366 | traj = collections.defaultdict(list)
367 | score = 0
368 | moves = []
369 | # If dump_play=True, store the states and actions in a json file
370 | # for visualization.
371 | dump_play = args.dump_play and dump
372 |
373 | if need_restart:
374 | env.restart()
375 |
376 | if args.is_path_task:
377 | optimal = env.unwrapped.dist
378 | relation = env.unwrapped.graph.get_edges()
379 | relation = np.stack([relation, relation.T], axis=-1)
380 | st, ed = env.current_state
381 | nodes_trajectory = [int(st)]
382 | destination = int(ed)
383 | policies = []
384 | elif args.is_sort_task:
385 | optimal = env.unwrapped.optimal
386 | array = [str(i) for i in env.unwrapped.array]
387 |
388 | while not is_over:
389 | if args.is_path_task:
390 | st, ed = env.current_state
391 | state = np.zeros((relation.shape[0], 2))
392 | state[st, 0] = 1
393 | state[ed, 1] = 1
394 | feed_dict = dict(states=np.array([state]), relations=np.array([relation]))
395 | elif args.is_sort_task:
396 | state = env.current_state
397 | feed_dict = dict(states=np.array([state]))
398 | feed_dict['entropy_beta'] = as_tensor(entropy_beta).float()
399 | feed_dict = as_tensor(feed_dict)
400 | if args.use_gpu:
401 | feed_dict = as_cuda(feed_dict)
402 |
403 | with torch.set_grad_enabled(not eval_only):
404 | output_dict = model(feed_dict)
405 |
406 | policy = output_dict['policy']
407 | p = as_numpy(policy.data[0])
408 | action = p.argmax() if use_argmax else random.choice(len(p), p=p)
409 | reward, is_over = env.action(action)
410 |
411 |     # Collect information about the moves.
412 | if dump_play:
413 | if args.is_path_task:
414 | moves.append(int(action))
415 | nodes_trajectory.append(int(env.current_state[0]))
416 | logits = as_numpy(output_dict['logits'].data[0])
417 | tops = np.argsort(p)[-10:][::-1]
418 | tops = list(
419 | map(lambda x: (int(x), float(p[x]), float(logits[x])), tops))
420 | policies.append(tops)
421 | if args.is_sort_task:
422 | # Need to ensure that env.utils.MapActionProxy is the outermost class.
423 | mapped_x, mapped_y = env.mapping[action]
424 | moves.append([mapped_x, mapped_y])
425 |
426 |     # For now, assume reward=1 only on success and reward=0 otherwise.
427 |     # Adjust the reward and derive the success flag from it.
428 | if reward == 0 and args.penalty is not None:
429 | reward = args.penalty
430 | succ = 1 if is_over and reward > 0.99 else 0
431 |
432 | score += reward
433 | traj['states'].append(state)
434 | if args.is_path_task:
435 | traj['relations'].append(relation)
436 | traj['rewards'].append(reward)
437 | traj['actions'].append(action)
438 |
439 |   # Dump a JSON file storing information about the play.
440 | if dump_play and not (args.dump_fail_only and succ):
441 | if args.is_path_task:
442 | num = env.unwrapped.nr_nodes
443 | graph = relation[:, :, 0].tolist()
444 | coordinates = env.unwrapped.graph.get_coordinates().tolist()
445 | json_str = json.dumps(
446 | dict(
447 | graph=graph,
448 | coordinates=coordinates,
449 | policies=policies,
450 | destination=destination,
451 | current=nodes_trajectory,
452 | moves=moves))
453 | if args.is_sort_task:
454 | num = env.unwrapped.nr_numbers
455 | json_str = json.dumps(dict(array=array, moves=moves))
456 | dump_file = os.path.join(args.current_dump_dir,
457 | '{}_size{}.json'.format(play_name, num))
458 | with open(dump_file, 'w') as f:
459 | f.write(json_str)
460 |
461 | length = len(traj['rewards'])
462 | return succ, score, traj, length, optimal
463 |
464 |
465 | class MyTrainer(MiningTrainerBase):
466 | def save_checkpoint(self, name):
467 | if args.checkpoints_dir is not None:
468 | checkpoint_file = os.path.join(args.checkpoints_dir,
469 | 'checkpoint_{}.pth'.format(name))
470 | super().save_checkpoint(checkpoint_file)
471 |
472 | def _dump_meters(self, meters, mode):
473 | if args.summary_file is not None:
474 | meters_kv = meters._canonize_values('avg')
475 | meters_kv['mode'] = mode
476 | meters_kv['epoch'] = self.current_epoch
477 | with open(args.summary_file, 'a') as f:
478 | f.write(io.dumps_json(meters_kv))
479 | f.write('\n')
480 |
481 | def _prepare_dataset(self, epoch_size, mode):
482 | pass
483 |
484 | def _get_player(self, number, mode):
485 | if args.is_path_task:
486 | test_len = args.gen_test_len
487 | dist_range = (test_len, test_len) if mode == 'test' \
488 | else (1, args.gen_max_len)
489 | player = make_env(args.task, number, dist_range=dist_range)
490 | else:
491 | player = make_env(args.task, number)
492 | player.restart()
493 | return player
494 |
495 | def _get_result_given_player(self, index, meters, number, player, mode):
496 | assert mode in ['train', 'test', 'mining', 'inherit']
497 | params = dict(
498 | eval_only=True,
499 | number=number,
500 | play_name='{}_epoch{}_episode{}'.format(mode, self.current_epoch,
501 | index))
502 | backup = None
503 | if mode == 'train':
504 | params['eval_only'] = False
505 | params['entropy_beta'] = self.entropy_beta
506 | meters.update(lr=self.lr, entropy_beta=self.entropy_beta)
507 | elif mode == 'test':
508 | params['dump'] = True
509 | params['use_argmax'] = True
510 | else:
511 | backup = copy.deepcopy(player)
512 | params['use_argmax'] = self.is_candidate
513 | succ, score, traj, length, optimal = \
514 | run_episode(player, self.model, **params)
515 | meters.update(
516 | number=number, succ=succ, score=score, length=length, optimal=optimal)
517 |
518 | if mode == 'train':
519 | feed_dict = make_data(traj, args.gamma)
520 | feed_dict['entropy_beta'] = as_tensor(self.entropy_beta).float()
521 |
522 | if args.use_gpu:
523 | feed_dict = as_cuda(feed_dict)
524 | return feed_dict
525 | else:
526 |       message = ('> {} iter={iter}, number={number}, succ={succ}, '
527 |                  'score={score:.4f}, length={length}, optimal={optimal}').format(
528 |                      mode, iter=index, **meters.val)
529 | return message, dict(succ=succ, number=number, backup=backup)
530 |
531 | def _extract_info(self, extra):
532 | return extra['succ'], extra['number'], extra['backup']
533 |
534 | def _get_accuracy(self, meters):
535 | return meters.avg['succ']
536 |
537 | def _get_threshold(self):
538 | candidate_relax = 0 if self.is_candidate else args.candidate_relax
539 | return super()._get_threshold() - \
540 | self.curriculum_thresh_relax * candidate_relax
541 |
542 | def _upgrade_lesson(self):
543 | super()._upgrade_lesson()
544 |     # Progressively adjust lr & entropy_beta for each new lesson.
545 | self.lr *= args.lr_decay
546 | self.entropy_beta *= args.entropy_beta_decay
547 | self.set_learning_rate(self.lr)
548 |
549 | def _train_epoch(self, epoch_size):
550 | meters = super()._train_epoch(epoch_size)
551 |
552 | i = self.current_epoch
553 | if args.save_interval is not None and i % args.save_interval == 0:
554 | self.save_checkpoint(str(i))
555 | if args.test_interval is not None and i % args.test_interval == 0:
556 | self.test()
557 |
558 | return meters
559 |
560 | def _early_stop(self, meters):
561 | t = args.early_drop_epochs
562 | if t is not None and self.current_epoch > t * (self.nr_upgrades + 1):
563 | return True
564 | return super()._early_stop(meters)
565 |
566 | def train(self):
567 | self.lr = args.lr
568 | self.entropy_beta = args.entropy_beta
569 | return super().train()
570 |
571 |
572 | def main(run_id):
573 | if args.dump_dir is not None:
574 | if args.runs > 1:
575 | args.current_dump_dir = os.path.join(args.dump_dir,
576 | 'run_{}'.format(run_id))
577 | io.mkdir(args.current_dump_dir)
578 | else:
579 | args.current_dump_dir = args.dump_dir
580 | args.checkpoints_dir = os.path.join(args.current_dump_dir, 'checkpoints')
581 | io.mkdir(args.checkpoints_dir)
582 | args.summary_file = os.path.join(args.current_dump_dir, 'summary.json')
583 |
584 | logger.info(format_args(args))
585 |
586 | model = Model()
587 | if args.use_gpu:
588 | model.cuda()
589 | optimizer = get_optimizer(args.optimizer, model, args.lr)
590 | if args.accum_grad > 1:
591 | optimizer = AccumGrad(optimizer, args.accum_grad)
592 |
593 | trainer = MyTrainer.from_args(model, optimizer, args)
594 |
595 | if args.load_checkpoint is not None:
596 | trainer.load_checkpoint(args.load_checkpoint)
597 |
598 | if args.test_only:
599 | trainer.current_epoch = 0
600 | return None, trainer.test()
601 |
602 | graduated = trainer.train()
603 | trainer.save_checkpoint('last')
604 | test_meters = trainer.test() if graduated or args.test_not_graduated else None
605 | return graduated, test_meters
606 |
607 |
608 | if __name__ == '__main__':
609 | stats = []
610 | nr_graduated = 0
611 |
612 | for i in range(args.runs):
613 | graduated, test_meters = main(i)
614 | logger.info('run {}'.format(i + 1))
615 |
616 | if test_meters is not None:
617 | for j, meters in enumerate(test_meters):
618 | if len(stats) <= j:
619 | stats.append(GroupMeters())
620 | stats[j].update(
621 | number=meters.avg['number'], test_succ=meters.avg['succ'])
622 |
623 | for meters in stats:
624 | logger.info('number {}, test_succ {}'.format(meters.avg['number'],
625 | meters.avg['test_succ']))
626 |
627 | if not args.test_only:
628 | nr_graduated += int(graduated)
629 | logger.info('graduate_ratio {}'.format(nr_graduated / (i + 1)))
630 | if graduated:
631 | for j, meters in enumerate(test_meters):
632 | stats[j].update(grad_test_succ=meters.avg['succ'])
633 | if nr_graduated > 0:
634 | for meters in stats:
635 | logger.info('number {}, grad_test_succ {}'.format(
636 | meters.avg['number'], meters.avg['grad_test_succ']))
637 |
--------------------------------------------------------------------------------
/vis/README.md:
--------------------------------------------------------------------------------
1 | # Neural Logic Machine Visualization
2 |
3 | This directory contains the web-based visualizer for the sorting, shortest path, and blocks world tasks.
4 |
5 | It includes an example for each task, generated by the training/testing code.
6 |
7 | ```
8 | $ cd ROOT_DIR_OF_PROJECT
9 | $ python2 -m SimpleHTTPServer PORT   # default: 8000
10 | ```
11 |
12 | Then open `localhost:PORT` in a browser, open the corresponding HTML file, and click the "run" button.
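Python 2's `SimpleHTTPServer` was replaced in Python 3 by the `http.server` module. On systems without Python 2, the following is a drop-in alternative (a sketch assuming a Python 3 interpreter named `python3` is on the PATH):

```shell
# Serve the repository root so the visualizer pages can fetch their JSON data.
# Press Ctrl-C to stop the server.
cd ROOT_DIR_OF_PROJECT
python3 -m http.server 8000
```

Then browse to `localhost:8000` exactly as with the Python 2 server.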
13 |
14 | To visualize another play, edit the value of "DATA_URL" in the HTML file.
15 |
16 | HTML files: `sort.html`, `path.html`, `blocksworld.html`
17 |
--------------------------------------------------------------------------------
/vis/blocksworld.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
106 |
107 |
108 |
109 |
110 |
111 |