├── .gitignore ├── .gitmodules ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── _assets └── framework.png ├── difflogic ├── __init__.py ├── cli.py ├── dataset │ ├── __init__.py │ ├── graph │ │ ├── __init__.py │ │ ├── dataset.py │ │ └── family.py │ └── utils.py ├── envs │ ├── __init__.py │ ├── algorithmic │ │ ├── __init__.py │ │ ├── quickaccess.py │ │ └── sort_envs.py │ ├── blocksworld │ │ ├── __init__.py │ │ ├── block.py │ │ ├── envs.py │ │ ├── quickaccess.py │ │ └── represent.py │ ├── graph │ │ ├── __init__.py │ │ ├── graph.py │ │ ├── graph_env.py │ │ └── quickaccess.py │ └── utils.py ├── nn │ ├── __init__.py │ ├── baselines │ │ ├── __init__.py │ │ ├── lstm.py │ │ └── memory_net.py │ ├── neural_logic │ │ ├── __init__.py │ │ ├── layer.py │ │ └── modules │ │ │ ├── __init__.py │ │ │ ├── _utils.py │ │ │ ├── dimension.py │ │ │ ├── input_transform.py │ │ │ └── neural_logic.py │ └── rl │ │ ├── __init__.py │ │ └── reinforce.py ├── thutils.py ├── tqdm_utils.py └── train │ ├── __init__.py │ └── train.py ├── models ├── blocksworld.pth ├── path.pth └── sort.pth ├── requirements.txt ├── scripts ├── blocksworld │ ├── README.md │ └── learn_policy.py └── graph │ ├── README.md │ ├── learn_graph_tasks.py │ └── learn_policy.py ├── third_party └── js_lib │ ├── d3.v4.js │ └── jquery-3.3.1.js └── vis ├── README.md ├── blocksworld.html ├── blocksworld.json ├── path.html ├── path.json ├── sort.html └── sort.json /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before 
PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask stuff: 58 | instance/ 59 | .webassets-cache 60 | 61 | # Scrapy stuff: 62 | .scrapy 63 | 64 | # Sphinx documentation 65 | docs/_build/ 66 | 67 | # PyBuilder 68 | target/ 69 | 70 | # Jupyter Notebook 71 | .ipynb_checkpoints 72 | 73 | # pyenv 74 | .python-version 75 | 76 | # celery beat schedule file 77 | celerybeat-schedule 78 | 79 | # SageMath parsed files 80 | *.sage.py 81 | 82 | # dotenv 83 | .env 84 | 85 | # virtualenv 86 | .venv 87 | venv/ 88 | ENV/ 89 | 90 | # Spyder project settings 91 | .spyderproject 92 | .spyproject 93 | 94 | # Rope project settings 95 | .ropeproject 96 | 97 | # mkdocs documentation 98 | /site 99 | 100 | # mypy 101 | .mypy_cache/ 102 | 103 | # sublime workspace 104 | *.sublime-workspace 105 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/Jacinle"] 2 | path = third_party/Jacinle 3 | url = https://github.com/vacancy/Jacinle 4 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. 
You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to https://cla.developers.google.com/ to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google.com/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Neural Logic Machines 2 | PyTorch implementation for the Neural Logic Machines (NLM). **Please note that this is not an officially supported Google product.** 3 | 4 |
5 | 6 |
7 | 8 | Neural Logic Machine (NLM) is a neural-symbolic architecture for both inductive learning and logic reasoning. NLMs use tensors to represent logic predicates. This is done by grounding the predicate as 9 | True or False over a fixed set of objects. Based on the tensor representation, rules are implemented 10 | as neural operators that can be applied over the premise tensors and generate conclusion tensors. 11 | 12 | **[Neural Logic Machines](https://arxiv.org/pdf/1904.11694.pdf)** 13 |
14 | [Honghua Dong](http://dhh1995.github.io)\*, 15 | [Jiayuan Mao](http://jiayuanm.com)\*, 16 | [Tian Lin](https://www.linkedin.com/in/tianl), 17 | [Chong Wang](https://chongw.github.io/), 18 | [Lihong Li](https://lihongli.github.io/), and 19 | [Denny Zhou](https://dennyzhou.github.io/) 20 |
21 | (\*: indicates equal contribution.) 22 |
23 | In International Conference on Learning Representations (ICLR) 2019 24 |
25 | [[Paper]](https://arxiv.org/pdf/1904.11694.pdf) 26 | [[Project Page]](https://sites.google.com/view/neural-logic-machines) 27 | 28 | ``` 29 | @inproceedings{ 30 | dong2018neural, 31 | title = {Neural Logic Machines}, 32 | author = {Honghua Dong and Jiayuan Mao and Tian Lin and Chong Wang and Lihong Li and Denny Zhou}, 33 | booktitle = {International Conference on Learning Representations}, 34 | year = {2019}, 35 | url = {https://openreview.net/forum?id=B1xY-hRctX}, 36 | } 37 | ``` 38 | 39 | ## Prerequisites 40 | * Python 3 41 | * PyTorch 0.4.0 42 | * [Jacinle](https://github.com/vacancy/Jacinle). We use the version [ed90c3a](https://github.com/vacancy/Jacinle/tree/ed90c3a70a133eb9c6c2f4ea2cc3d907de7ffd57) for this repo. 43 | * Other required python packages specified by `requirements.txt`. See the Installation. 44 | 45 | ## Installation 46 | 47 | Clone this repository: 48 | 49 | ``` 50 | git clone https://github.com/google/neural-logic-machines --recursive 51 | ``` 52 | 53 | Install [Jacinle](https://github.com/vacancy/Jacinle) included as a submodule. You need to add the bin path to your global `PATH` environment variable: 54 | 55 | ``` 56 | export PATH=<path_to_this_repo>/third_party/Jacinle/bin:$PATH 57 | ``` 58 | 59 | Create a conda environment for NLM, and install the requirements. This includes the required python packages 60 | from both Jacinle and NLM. Most of the required packages have been included in the built-in `anaconda` package: 61 | 62 | ``` 63 | conda create -n nlm anaconda 64 | conda install pytorch torchvision -c pytorch 65 | ``` 66 | 67 | ## Usage 68 | 69 | This repo contains 10 graph-related reasoning tasks (using supervised learning) 70 | and 3 decision-making tasks (using reinforcement learning). 71 | 72 | We also provide pre-trained models for 3 decision-making tasks in the [models](models) directory. 73 | 74 | Take the [Blocks World](scripts/blocksworld) task as an example: 
75 | 76 | ``` shell 77 | # To train the model: 78 | $ jac-run scripts/blocksworld/learn_policy.py --task final 79 | # To test the model: 80 | $ jac-run scripts/blocksworld/learn_policy.py --task final --test-only --load models/blocksworld.pth 81 | # add [--test-epoch-size T] to control the number of testing cases. 82 | # E.g. use T=20 for a quick testing, usually takes ~2min on CPUs. 83 | # Sample output of testing for number=10 and number=50: 84 | > Evaluation: 85 | length = 12.500000 86 | number = 10.000000 87 | score = 0.885000 88 | succ = 1.000000 89 | > Evaluation: 90 | length = 85.800000 91 | number = 50.000000 92 | score = 0.152000 93 | succ = 1.000000 94 | ``` 95 | 96 | Please refer to the [graph](scripts/graph) directory for training/inference details of other tasks. 97 | 98 | ### Useful Command-line options 99 | - `jac-crun GPU_ID FILE --use-gpu GPU_ID` instead of `jac-run FILE` to enable using gpu with id `GPU_ID`. 100 | - `--model {nlm, memnet}` [default: `nlm`]: choose `memnet` to use [Memory Networks](https://arxiv.org/abs/1503.08895) as baseline. 101 | - `--runs N`: take `N` runs. 102 | - `--dump-dir DUMP_DIR`: place to dump logs/summaries/checkpoints/plays. 103 | - `--dump-play`: dump plays for visualization in json format, can be visualized by our [html visualizer](vis). (not applied to [graph tasks](scripts/graph/learn_graph_tasks.py)) 104 | - `--test-number-begin B --test-number-step S --test-number-end E`: \ 105 | defines the range of the sizes of the test instances. 106 | - `--test-epoch-size SIZE`: number of test instances. 107 | 108 | For a complete list of command-line options, see `jac-run FILE -h` (e.g. `jac-run scripts/blocksworld/learn_policy.py -h`). 
109 | -------------------------------------------------------------------------------- /_assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/_assets/framework.png -------------------------------------------------------------------------------- /difflogic/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /difflogic/cli.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Command line interface, print or format args as string.""" 17 | 18 | from jacinle.utils.printing import kvformat 19 | from jacinle.utils.printing import kvprint 20 | 21 | 22 | def print_args(args): 23 | kvprint(args.__dict__) 24 | 25 | 26 | def format_args(args): 27 | return kvformat(args.__dict__) 28 | -------------------------------------------------------------------------------- /difflogic/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | -------------------------------------------------------------------------------- /difflogic/dataset/graph/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .dataset import * 18 | -------------------------------------------------------------------------------- /difflogic/dataset/graph/dataset.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Implement datasets classes for graph and family tree tasks.""" 17 | 18 | import numpy as np 19 | 20 | from torch.utils.data.dataset import Dataset 21 | from torchvision import datasets 22 | 23 | import jacinle.random as random 24 | 25 | from .family import randomly_generate_family 26 | from ...envs.graph import get_random_graph_generator 27 | 28 | __all__ = [ 29 | 'GraphOutDegreeDataset', 'GraphConnectivityDataset', 'GraphAdjacentDataset', 30 | 'FamilyTreeDataset' 31 | ] 32 | 33 | 34 | class GraphDatasetBase(Dataset): 35 | """Base dataset class for graphs. 36 | 37 | Args: 38 | epoch_size: The number of batches for each epoch. 
39 | nmin: The minimal number of nodes in the graph. 40 | pmin: The lower bound of the parameter p of the graph generator. 41 | nmax: The maximal number of nodes in the graph, 42 | the same as $nmin in default. 43 | pmax: The upper bound of the parameter p of the graph generator, 44 | the same as $pmin in default. 45 | directed: Generator directed graph if directed=True. 46 | gen_method: Controlling the graph generation method. 47 | If gen_method='dnc', use the similar way as in DNC paper. 48 | Else using Erdos-Renyi algorithm (each edge exists with prob). 49 | """ 50 | 51 | def __init__(self, 52 | epoch_size, 53 | nmin, 54 | pmin, 55 | nmax=None, 56 | pmax=None, 57 | directed=False, 58 | gen_method='dnc'): 59 | self._epoch_size = epoch_size 60 | self._nmin = nmin 61 | self._nmax = nmin if nmax is None else nmax 62 | assert self._nmin <= self._nmax 63 | self._pmin = pmin 64 | self._pmax = pmin if pmax is None else pmax 65 | assert self._pmin <= self._pmax 66 | self._directed = directed 67 | self._gen_method = gen_method 68 | 69 | def _gen_graph(self, item): 70 | n = self._nmin + item % (self._nmax - self._nmin + 1) 71 | p = self._pmin + random.rand() * (self._pmax - self._pmin) 72 | gen = get_random_graph_generator(self._gen_method) 73 | return gen(n, p, directed=self._directed) 74 | 75 | def __len__(self): 76 | return self._epoch_size 77 | 78 | 79 | class GraphOutDegreeDataset(GraphDatasetBase): 80 | """The dataset for out-degree task in graphs.""" 81 | 82 | def __init__(self, 83 | degree, 84 | epoch_size, 85 | nmin, 86 | pmin, 87 | nmax=None, 88 | pmax=None, 89 | directed=False, 90 | gen_method='dnc'): 91 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method) 92 | self._degree = degree 93 | 94 | def __getitem__(self, item): 95 | graph = self._gen_graph(item) 96 | # The goal is to predict whether out-degree(x) == self._degree for all x. 
97 | return dict( 98 | n=graph.nr_nodes, 99 | relations=np.expand_dims(graph.get_edges(), axis=-1), 100 | target=(graph.get_out_degree() == self._degree).astype('float'), 101 | ) 102 | 103 | 104 | class GraphConnectivityDataset(GraphDatasetBase): 105 | """The dataset for connectivity task in graphs.""" 106 | 107 | def __init__(self, 108 | dist_limit, 109 | epoch_size, 110 | nmin, 111 | pmin, 112 | nmax=None, 113 | pmax=None, 114 | directed=False, 115 | gen_method='dnc'): 116 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method) 117 | self._dist_limit = dist_limit 118 | 119 | def __getitem__(self, item): 120 | graph = self._gen_graph(item) 121 | # The goal is to predict whether (x, y) are connected within a limited steps 122 | # I.e. dist(x, y) <= self._dist_limit for all x, y. 123 | return dict( 124 | n=graph.nr_nodes, 125 | relations=np.expand_dims(graph.get_edges(), axis=-1), 126 | target=graph.get_connectivity(self._dist_limit, exclude_self=True), 127 | ) 128 | 129 | 130 | class GraphAdjacentDataset(GraphDatasetBase): 131 | """The dataset for adjacent task in graphs.""" 132 | 133 | def __init__(self, 134 | nr_colors, 135 | epoch_size, 136 | nmin, 137 | pmin, 138 | nmax=None, 139 | pmax=None, 140 | directed=False, 141 | gen_method='dnc', 142 | is_train=True, 143 | is_mnist_colors=False, 144 | mnist_dir='../data'): 145 | 146 | super().__init__(epoch_size, nmin, pmin, nmax, pmax, directed, gen_method) 147 | self._nr_colors = nr_colors 148 | self._is_mnist_colors = is_mnist_colors 149 | # When taking MNIST digits as inputs, fetch MNIST dataset. 
class FamilyTreeDataset(Dataset):
  """The dataset for family tree tasks."""

  # Maps a task name to the Family method that computes its target relation.
  _TASK_TO_METHOD = {
      'has-father': 'has_father',
      'has-daughter': 'has_daughter',
      'has-sister': 'has_sister',
      'parents': 'get_parents',
      'grandparents': 'get_grandparents',
      'uncle': 'get_uncle',
      'maternal-great-uncle': 'get_maternal_great_uncle',
  }

  def __init__(self,
               task,
               epoch_size,
               nmin,
               nmax=None,
               p_marriage=0.8,
               balance_sample=False):
    """Initialize the dataset.

    Args:
      task: name of the relation to predict (a key of _TASK_TO_METHOD).
      epoch_size: number of samples constituting one epoch.
      nmin: minimum number of people in a generated family.
      nmax: maximum number of people; defaults to nmin when None.
      p_marriage: probability of a marriage at each generation step.
      balance_sample: if True, emit per-pair samples balanced between
        positive and negative pairs instead of whole-family samples.
    """
    super().__init__()
    self._task = task
    self._epoch_size = epoch_size
    self._nmin = nmin
    self._nmax = nmin if nmax is None else nmax
    assert self._nmin <= self._nmax
    self._p_marriage = p_marriage
    self._balance_sample = balance_sample
    self._data = []

  def _gen_family(self, item):
    """Generate a random family whose size is derived from the item index."""
    size = self._nmin + item % (self._nmax - self._nmin + 1)
    return randomly_generate_family(size, self._p_marriage)

  def __getitem__(self, item):
    while not self._data:
      family = self._gen_family(item)
      # Drop the husband/wife channels; keep father/mother/son/daughter.
      relations = family.relations[:, :, 2:]
      assert self._task in self._TASK_TO_METHOD, \
          '{} is not supported.'.format(self._task)
      target = getattr(family, self._TASK_TO_METHOD[self._task])()

      if not self._balance_sample:
        return dict(n=family.nr_people, relations=relations, target=target)

      # In balance_sample case, the data format is different. Not used.
      def positions_of(mask):
        return list(np.vstack(np.where(mask)).T)

      def push(pair, label):
        states = np.zeros((family.nr_people, 2))
        states[pair[0], 0] = states[pair[1], 1] = 1
        self._data.append(
            dict(n=family.nr_people,
                 relations=relations,
                 states=states,
                 target=label))

      positive = positions_of(target == 1)
      if not positive:
        continue
      negative = positions_of(target == 0)
      np.random.shuffle(negative)
      for pair in positive:
        push(pair, 1)
      for pair in negative[:len(positive)]:
        push(pair, 0)

    return self._data.pop()

  def __len__(self):
    return self._epoch_size
"""Implement family tree generator and family tree class."""

import jacinle.random as random
import numpy as np

__all__ = ['Family', 'randomly_generate_family']

# Channel layout of the relation tensor: rel[i, j, c] == 1 iff person j
# stands in relation c to person i.
_HUSBAND, _WIFE, _FATHER, _MOTHER, _SON, _DAUGHTER = range(6)


class Family(object):
  """Family tree supporting queries about relations between its members.

  Args:
    nr_people: The number of people in the family tree.
    relations: A 0/1 matrix of shape [nr_people, nr_people, 6];
      relations[i, j, c] == 1 iff person j stands in relation c to person i,
      where the channels c are, in order: husband, wife, father, mother,
      son, daughter.
  """

  def __init__(self, nr_people, relations):
    self._n = nr_people
    self._relations = relations

  def mul(self, x, y):
    """Compose two 0/1 relation matrices; clip so the result stays 0/1."""
    return np.clip(np.matmul(x, y), 0, 1)

  @property
  def nr_people(self):
    return self._n

  @property
  def relations(self):
    return self._relations

  @property
  def father(self):
    return self._relations[:, :, _FATHER]

  @property
  def mother(self):
    return self._relations[:, :, _MOTHER]

  @property
  def son(self):
    return self._relations[:, :, _SON]

  @property
  def daughter(self):
    return self._relations[:, :, _DAUGHTER]

  def has_father(self):
    """Per-person indicator: 1 iff the person has a father in the tree."""
    return self.father.max(axis=1)

  def has_daughter(self):
    """Per-person indicator: 1 iff the person has a daughter in the tree."""
    return self.daughter.max(axis=1)

  def has_sister(self):
    """Per-person indicator: 1 iff the person's father has another daughter."""
    nr_daughters_of_father = np.matmul(self.father, self.daughter.sum(axis=1))
    herself = np.clip(self.daughter.sum(axis=0), 0, 1)
    # Subtract the person herself so a daughter is not counted as her own
    # sister (the naive father->daughter composition would do that).
    return (nr_daughters_of_father - herself > 0).astype('float')

  def get_parents(self):
    """Relation matrix for parent (father or mother)."""
    return np.clip(self.father + self.mother, 0, 1)

  def get_grandfather(self):
    """Relation matrix for grandfather: parent composed with father."""
    return self.mul(self.get_parents(), self.father)

  def get_grandmother(self):
    """Relation matrix for grandmother: parent composed with mother."""
    return self.mul(self.get_parents(), self.mother)

  def get_grandparents(self):
    """Relation matrix for grandparent: parent composed with parent."""
    parents = self.get_parents()
    return self.mul(parents, parents)

  def get_uncle(self):
    """Relation matrix for uncle: a grandparent's son who is not the father."""
    return np.clip(
        self.mul(self.get_grandparents(), self.son) - self.father, 0, 1)

  def get_maternal_great_uncle(self):
    """Relation matrix for maternal great-uncle: maternal grandmother's son."""
    return self.mul(self.mul(self.get_grandmother(), self.mother), self.son)


def randomly_generate_family(n, p_marriage=0.8, verbose=False):
  """Randomly generate a family tree of $n people.

  Simulates a family growing along a timeline: each step introduces one new
  person with a random gender and randomly sampled parents (possibly none,
  meaning the parents are outside the tree). Unmarried men and women are kept
  in two pools; with probability $p_marriage one person is drawn from each
  pool and, if not too closely related, married. Person indices are randomly
  permuted up front.

  Args:
    n: The number of people in the family tree.
    p_marriage: The probability that a marriage happens at each step.
    verbose: Print the marriage and child-born events if True.

  Returns:
    A Family instance of $n people.
  """
  assert n > 0
  ids = list(random.permutation(n))

  single_m = []  # unmarried men
  single_w = []  # unmarried women
  couples = [None]  # None encodes "parents not in the tree".
  # The relation channels are: husband, wife, father, mother, son, daughter.
  rel = np.zeros((n, n, 6))
  fathers = [None] * n
  mothers = [None] * n

  def add_couple(man, woman):
    """Record the marriage of (man, woman)."""
    couples.append((man, woman))
    rel[woman, man, _HUSBAND] = 1
    rel[man, woman, _WIFE] = 1
    if verbose:
      print('couple', man, woman)

  def add_child(parents, child, gender):
    """Record a child of $parents with the given gender (0=son, 1=daughter)."""
    father, mother = parents
    fathers[child] = father
    mothers[child] = mother
    rel[child, father, _FATHER] = 1
    rel[child, mother, _MOTHER] = 1
    child_channel = _SON if gender == 0 else _DAUGHTER
    rel[father, child, child_channel] = 1
    rel[mother, child, child_channel] = 1
    if verbose:
      print('child', father, mother, child, gender)

  def check_relations(man, woman):
    """Return True iff (man, woman) are distant enough to marry."""
    # Anyone whose father is unknown is assumed unrelated.
    if fathers[man] is None or fathers[woman] is None:
      return True
    # Reject siblings.
    if fathers[man] == fathers[woman]:
      return False

    def same_parent(x, y):
      return fathers[x] is not None and fathers[y] is not None and fathers[
          x] == fathers[y]

    # Reject uncle/aunt relations and first cousins.
    for x in [fathers[man], mothers[man]]:
      for y in [fathers[woman], mothers[woman]]:
        if same_parent(man, y) or same_parent(woman, x) or same_parent(x, y):
          return False
    return True

  while ids:
    person = ids.pop()
    gender = random.randint(2)
    parents = random.choice(couples)
    (single_m if gender == 0 else single_w).append(person)
    if parents is not None:
      add_child(parents, person, gender)

    if random.rand() < p_marriage and len(single_m) > 0 and len(single_w) > 0:
      mi = random.randint(len(single_m))
      wi = random.randint(len(single_w))
      man, woman = single_m[mi], single_w[wi]
      if check_relations(man, woman):
        add_couple(man, woman)
        del single_m[mi]
        del single_w[wi]

  return Family(n, rel)
"""Utility functions for customer datasets."""

import collections
import copy

import jacinle.random as random
import numpy as np

__all__ = [
    'ValidActionDataset',
    'RandomlyIterDataset',
]


class ValidActionDataset(object):
  """Collect (state, action, valid) records and sample balanced batches.

  Records are bucketed by the number of objects in the case, and within each
  bucket split into invalid (label 0) and valid (label 1) deques. Each deque
  holds at most $capacity records; once full, the oldest record is dropped.

  Args:
    capacity: Upper bound on the number of records per (bucket, label) deque.
    maxn: The maximum number of objects in the collected cases.
  """

  def __init__(self, capacity=5000, maxn=30):
    super().__init__()
    self.largest_n = 0
    self.maxn = maxn
    # data[n][label] is a bounded deque of (state, action) pairs.
    self.data = [[collections.deque(maxlen=capacity) for _ in range(2)]
                 for _ in range(maxn + 1)]

  def append(self, n, state, action, valid):
    """Record that taking $action in $state (a case of $n objects) is $valid."""
    assert n <= self.maxn
    self.data[n][int(valid)].append((state, action))
    self.largest_n = max(self.largest_n, n)

  def _sample(self, data, num, label):
    """Draw $num records (with replacement) from $data, all labelled $label."""
    states, actions = [], []
    for _ in range(num):
      state, action = data[random.randint(len(data))]
      states.append(state)
      actions.append([action])
    return np.array(states), np.array(actions), np.ones((num,)) * label

  def sample_batch(self, batch_size, n=None):
    """Sample a (roughly) half-positive, half-negative batch for $n objects."""
    # Default to the bucket with the largest number of objects seen so far.
    if n is None:
      n = self.largest_n
    data = self.data[n]
    # The pos/neg counts differ by one when batch_size is odd.
    num = batch_size // 2
    # Fall back to all-positive records when there is no negative one, and
    # symmetrically to all-negative when there is no positive one.
    c = 1 - int(len(data[0]) > 0)
    states1, actions1, labels1 = self._sample(data[c], num, c)
    c = int(len(data[1]) > 0)
    states2, actions2, labels2 = self._sample(data[c], batch_size - num, c)
    return (np.vstack([states1, states2]),
            np.vstack([actions1, actions2]).squeeze(axis=-1),
            np.concatenate([labels1, labels2], axis=0))


class RandomlyIterDataset(object):
  """Collect data items and iterate over them in a random order."""

  def __init__(self):
    super().__init__()
    self.data = []
    self.ind = 0

  def append(self, data):
    """Add one item to the dataset."""
    self.data.append(data)

  @property
  def size(self):
    return len(self.data)

  def reset(self):
    """Restart iteration from the beginning (a fresh shuffle will occur)."""
    self.ind = 0

  def get(self):
    """Return the next item; the data is reshuffled when a new pass starts."""
    if self.ind == 0:
      random.shuffle(self.data)
    item = self.data[self.ind]
    self.ind += 1
    if self.ind == self.size:
      self.ind = 0
    return copy.deepcopy(item)
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | -------------------------------------------------------------------------------- /difflogic/envs/algorithmic/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .sort_envs import * 18 | from .quickaccess import * 19 | -------------------------------------------------------------------------------- /difflogic/envs/algorithmic/quickaccess.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
"""Quick access for algorithmic environments."""

from jaclearn.rl.proxy import LimitLengthProxy

from .sort_envs import ListSortingEnv
from ..utils import get_action_mapping_sorting
from ..utils import MapActionProxy

__all__ = ['get_sort_env', 'make']


def get_sort_env(n, exclude_self=True):
  """Build a list-sorting environment over $n numbers.

  The raw environment is wrapped with an episode-length limit of 2n steps and
  with a mapping from flat action ids to (i, j) swap pairs.

  Args:
    n: The number of numbers to sort.
    exclude_self: Exclude no-op (i, i) swaps from the action mapping.

  Returns:
    The wrapped environment.
  """
  env = ListSortingEnv(n)
  env = LimitLengthProxy(env, n * 2)
  pair_mapping = get_action_mapping_sorting(n, exclude_self=exclude_self)
  return MapActionProxy(env, pair_mapping)


def make(task, *args, **kwargs):
  """Environment factory; only the 'sort' task is supported."""
  if task != 'sort':
    raise ValueError('Unknown task: {}.'.format(task))
  return get_sort_env(*args, **kwargs)
"""The environment class for sorting tasks."""

import numpy as np

import jacinle.random as random
from jacinle.utils.meta import notnone_property
from jaclearn.rl.env import SimpleRLEnvBase


class ListSortingEnv(SimpleRLEnvBase):
  """Environment for sorting a random permutation by swapping pairs.

  Args:
    nr_numbers: The number of numbers in the array.
  """

  def __init__(self, nr_numbers):
    super().__init__()
    self._nr_numbers = nr_numbers
    self._array = None

  @notnone_property
  def array(self):
    return self._array

  @property
  def nr_numbers(self):
    return self._nr_numbers

  def get_state(self):
    """Encode pairwise <, ==, > relations of both values and positions.

    Returns an array of shape [n, n, 6]: three channels comparing the values
    at each pair of positions, and three comparing the positions themselves.
    """
    vx, vy = np.meshgrid(self.array, self.array)
    value_rel = np.stack([vx < vy, vx == vy, vx > vy], axis=-1).astype('float')
    idx = np.arange(self._nr_numbers)
    px, py = np.meshgrid(idx, idx)
    position_rel = np.stack([px < py, px == py, px > py],
                            axis=-1).astype('float')
    return np.concatenate([value_rel, position_rel], axis=-1)

  def _calculate_optimal(self):
    """Optimal number of swaps: n minus the number of permutation cycles."""
    a = self._array
    visited = [False] * len(a)
    nr_cycles = 0
    for i, x in enumerate(a):
      if not visited[i]:
        # Walk the cycle that starts at position i.
        visited[i] = True
        j = x
        while not visited[j]:
          visited[j] = True
          j = a[j]
        assert i == j
        nr_cycles += 1
    return len(a) - nr_cycles

  def _restart(self):
    """Sample a fresh permutation and compute its optimal swap count."""
    self._array = random.permutation(self._nr_numbers)
    self._set_current_state(self.get_state())
    self.optimal = self._calculate_optimal()

  def _action(self, action):
    """Swap positions (i, j); reward 1 and terminate once the array is sorted."""
    a = self._array
    i, j = action
    a[i], a[j] = a[j], a[i]
    self._set_current_state(self.get_state())
    is_sorted = all(a[k] == k for k in range(self._nr_numbers))
    return (1, True) if is_sorted else (0, False)
"""Implement random blocksworld generator and BlocksWorld class."""

import jacinle.random as random

__all__ = ['Block', 'BlocksWorld', 'randomly_generate_world']


class Block(object):
  """A single block in the blocks world, stored as a node of a tree."""

  def __init__(self, index, father=None):
    self.index = index
    self.father = father
    self.children = []

  @property
  def is_ground(self):
    """The ground is the unique block without a father."""
    return self.father is None

  @property
  def placeable(self):
    """A block can receive another iff it is the ground or has nothing on top."""
    return True if self.is_ground else len(self.children) == 0

  @property
  def moveable(self):
    """A block can be picked up iff it is not the ground and has nothing on top."""
    return False if self.is_ground else len(self.children) == 0

  def remove_from_father(self):
    """Detach this block from its current supporting block."""
    assert self in self.father.children
    self.father.children.remove(self)
    self.father = None

  def add_to(self, other):
    """Place this block on top of $other."""
    self.father = other
    other.children.append(self)


class BlockStorage(object):
  """An ordered view over a list of blocks.

  Args:
    blocks: The list of Block instances.
    random_order: Optional permutation mapping external index -> raw index;
      when None, the raw order is kept.
  """

  def __init__(self, blocks, random_order=None):
    super().__init__()
    self._blocks = blocks
    self.set_random_order(random_order)

  def __getitem__(self, item):
    if self._random_order is None:
      return self._blocks[item]
    return self._blocks[self._random_order[item]]

  def __len__(self):
    return len(self._blocks)

  @property
  def raw(self):
    return self._blocks.copy()

  @property
  def random_order(self):
    return self._random_order

  def set_random_order(self, random_order):
    """Install a permutation (or None) and cache its inverse."""
    self._random_order = random_order
    if random_order is None:
      self._inv_random_order = None
    else:
      self._inv_random_order = sorted(
          range(len(random_order)), key=lambda i: random_order[i])

  def index(self, i):
    """Map an external index to the raw index."""
    return i if self._random_order is None else self._random_order[i]

  def inv_index(self, i):
    """Map a raw index to the external index."""
    return i if self._random_order is None else self._inv_random_order[i]

  def permute(self, array):
    """Reorder a raw-indexed array into the external order."""
    if self._random_order is None:
      return array
    return [array[j] for j in self._random_order]


class BlocksWorld(object):
  """The blocks world: a collection of blocks supporting queries and moves."""

  def __init__(self, blocks, random_order=None):
    super().__init__()
    self.blocks = BlockStorage(blocks, random_order)

  @property
  def size(self):
    return len(self.blocks)

  def move(self, x, y):
    """Move block x on top of block y when legal; silently no-op otherwise."""
    if x != y and self.moveable(x, y):
      self.blocks[x].remove_from_father()
      self.blocks[x].add_to(self.blocks[y])

  def moveable(self, x, y):
    """Whether block x can be picked up and placed on block y."""
    return self.blocks[x].moveable and self.blocks[y].placeable


def randomly_generate_world(nr_blocks, random_order=False, one_stack=False):
  """Randomly generate a blocks world with $nr_blocks blocks plus the ground.

  Works like random tree generation: blocks are added one by one, and each new
  block is stacked onto a randomly chosen block that can currently receive one.

  Args:
    nr_blocks: The number of blocks (the ground adds one extra object).
    random_order: If truthy, install a randomly sampled index permutation.
      NOTE(review): any truthy value (including a passed-in sequence) triggers
      a fresh random permutation — a provided order is not installed as-is;
      confirm against callers before relying on that.
    one_stack: If True, build a single tower: every support that receives a
      child is removed from the candidates, so each new block lands on the
      previous one.

  Returns:
    A randomly generated BlocksWorld instance.
  """
  ground = Block(0, None)
  blocks = [ground]
  candidates = [ground]

  for i in range(1, nr_blocks + 1):
    support = random.choice_list(candidates)
    new_block = Block(i)
    new_block.add_to(support)
    # A non-ground support (or any support in one_stack mode) is now occupied.
    if not support.placeable or one_stack:
      candidates.remove(support)
    blocks.append(new_block)
    candidates.append(new_block)

  order = random.permutation(len(blocks)) if random_order else None
  return BlocksWorld(blocks, random_order=order)
"""The environment class for blocks world tasks."""

import numpy as np

import jacinle.random as random
from jaclearn.rl.env import SimpleRLEnvBase

from .block import randomly_generate_world
from .represent import get_coordinates
from .represent import decorate

__all__ = ['FinalBlocksWorldEnv']


class BlocksWorldEnv(SimpleRLEnvBase):
  """The Base BlocksWorld environment.

  Subclasses must provide _get_result(); this base class only knows how to
  generate a random world and expose its (optionally decorated) coordinates.

  Args:
    nr_blocks: The number of blocks.
    random_order: randomly permute the indexes of the blocks. This
      option prevents the models from memorizing the configurations.
    decorate: if True, the coordinates in the states will also include the
      world index (default: 0) and the block index (starting from 0).
    prob_unchange: The probability that an action is not effective.
    prob_fall: The probability that an action will make the object currently
      moving fall on the ground.
  """

  def __init__(self,
               nr_blocks,
               random_order=False,
               decorate=False,
               prob_unchange=0.0,
               prob_fall=0.0):
    super().__init__()
    self.nr_blocks = nr_blocks
    # Total object count: all blocks plus the ground.
    self.nr_objects = nr_blocks + 1
    self.random_order = random_order
    self.decorate = decorate
    self.prob_unchange = prob_unchange
    self.prob_fall = prob_fall

  def _restart(self):
    """Generate a fresh random world and publish its state."""
    self.world = randomly_generate_world(
        self.nr_blocks, random_order=self.random_order)
    self._set_current_state(self._get_decorated_states())
    self.is_over = False
    # NOTE(review): _get_result is only defined on subclasses (e.g.
    # FinalBlocksWorldEnv), so this base class is effectively abstract.
    self.cached_result = self._get_result()

  def _get_decorated_states(self, world_id=0):
    """Return per-block coordinates, optionally with world/block indices."""
    state = get_coordinates(self.world)
    if self.decorate:
      state = decorate(state, self.nr_objects, world_id)
    return state


class FinalBlocksWorldEnv(BlocksWorldEnv):
  """The BlocksWorld environment for the final task.

  The agent observes a start world and a final (target) world, and must move
  blocks in the start world until its configuration matches the target.
  """

  def __init__(self,
               nr_blocks,
               random_order=False,
               shape_only=False,
               fix_ground=False,
               prob_unchange=0.0,
               prob_fall=0.0):
    # The base-class `decorate` flag is always True for this task.
    super().__init__(nr_blocks, random_order, True, prob_unchange, prob_fall)
    # shape_only: compare only the silhouette of the two worlds, ignoring
    # which block index sits where.
    self.shape_only = shape_only
    # fix_ground: keep the ground at external index 0 when permuting.
    self.fix_ground = fix_ground

  def _restart(self):
    """Generate the start/target pair and publish the initial state."""
    self.start_world = randomly_generate_world(
        self.nr_blocks, random_order=False)
    self.final_world = randomly_generate_world(
        self.nr_blocks, random_order=False)
    self.world = self.start_world
    if self.random_order:
      n = self.world.size
      # Ground is fixed as index 0 if fix_ground is True
      ground_ind = 0 if self.fix_ground else random.randint(n)

      def get_order():
        """Build a permutation that places the ground (raw index 0) at
        external index ground_ind; the other blocks (raw 1..n-1) are randomly
        permuted over the remaining positions."""
        raw_order = random.permutation(n - 1)
        order = []
        for i in range(n - 1):
          if i == ground_ind:
            order.append(0)
          order.append(raw_order[i] + 1)
        if ground_ind == n - 1:
          order.append(0)
        return order

      # Both worlds get independent permutations (the ground lands at the
      # same external index in each).
      self.start_world.blocks.set_random_order(get_order())
      self.final_world.blocks.set_random_order(get_order())

    self._prepare_worlds()
    self.start_state = decorate(
        self._get_coordinates(self.start_world), self.nr_objects, 0)
    self.final_state = decorate(
        self._get_coordinates(self.final_world), self.nr_objects, 1)

    self.is_over = False
    self.cached_result = self._get_result()

  def _prepare_worlds(self):
    """Hook for subclasses to post-process the generated worlds."""
    pass

  def _action(self, action):
    """Move block x onto block y; reward 1 and terminate once worlds match."""
    assert self.start_world is not None, 'you need to call restart() first'

    if self.is_over:
      return 0, True
    r, is_over = self.cached_result
    if is_over:
      self.is_over = True
      return r, is_over

    x, y = action
    assert 0 <= x <= self.nr_blocks and 0 <= y <= self.nr_blocks

    # With probability prob_unchange the action is a no-op; with probability
    # prob_fall the moved block drops onto the ground instead of y.
    p = random.rand()
    if p >= self.prob_unchange:
      if p < self.prob_unchange + self.prob_fall:
        y = self.start_world.blocks.inv_index(0)  # fall to ground
      self.start_world.move(x, y)
      self.start_state = decorate(
          self._get_coordinates(self.start_world), self.nr_objects, 0)
    r, is_over = self._get_result()
    if is_over:
      self.is_over = True
    return r, is_over

  def _get_current_state(self):
    """Stack the start-world and target-world states vertically."""
    assert self.start_world is not None, 'Should call restart() first.'
    return np.vstack([self.start_state, self.final_state])

  def _get_result(self):
    """Return (reward, done): (1, True) iff the start world matches the target."""
    sorted_start_state = self._get_coordinates(self.start_world, sort=True)
    sorted_final_state = self._get_coordinates(self.final_world, sort=True)
    if (sorted_start_state == sorted_final_state).all():
      return 1, True
    else:
      return 0, False

  def _get_coordinates(self, world, sort=False):
    # If shape_only=True, only the shape of the blocks need to be the same.
    # If shape_only=False, the index of the blocks should also match, so the
    # block indices are attached (via decorate) before sorting.
    coordinates = get_coordinates(world, absolute=not self.shape_only)
    if sort:
      if not self.shape_only:
        coordinates = decorate(coordinates, self.nr_objects, 0)
      coordinates = np.array(sorted(list(map(tuple, coordinates))))
    return coordinates
"""Quick access for blocksworld environments."""

from jaclearn.rl.proxy import LimitLengthProxy

from .envs import FinalBlocksWorldEnv
from ..utils import get_action_mapping_blocksworld
from ..utils import MapActionProxy

__all__ = ['get_final_env', 'make']


def get_final_env(nr_blocks,
                  random_order=False,
                  exclude_self=True,
                  shape_only=False,
                  fix_ground=False,
                  limit_length=None):
  """Build the blocksworld environment for the final task.

  The raw environment is wrapped with an episode-length limit (4 * nr_blocks
  unless overridden by $limit_length) and with a mapping from flat action ids
  to (x, y) move pairs.
  """
  env = FinalBlocksWorldEnv(
      nr_blocks,
      random_order=random_order,
      shape_only=shape_only,
      fix_ground=fix_ground)
  env = LimitLengthProxy(env, limit_length or nr_blocks * 4)
  pair_mapping = get_action_mapping_blocksworld(
      nr_blocks, exclude_self=exclude_self)
  return MapActionProxy(env, pair_mapping)


def make(task, *args, **kwargs):
  """Environment factory; only the 'final' task is supported."""
  if task != 'final':
    raise ValueError('Unknown task: {}.'.format(task))
  return get_final_env(*args, **kwargs)
def get_world_string(world):
  """Render the blocks world as a human-readable indented tree.

  Each block is printed as one line; children are indented two extra
  spaces per depth level under their parent.
  """
  # Map a block's internal index to its position in the (possibly permuted)
  # block list, so the printed ids match the external ordering.
  index_mapping = {b.index: i for i, b in enumerate(world.blocks)}

  lines = []

  def dfs(block, depth):
    lines.append(
        '{}Block #{}: (IsGround={}, Moveable={}, Placeable={})'.format(
            ' ' * (depth * 2), index_mapping[block.index], block.is_ground,
            block.moveable, block.placeable))
    for child in block.children:
      dfs(child, depth + 1)

  # The first raw block is the root (the ground).
  dfs(world.blocks.raw[0], 0)
  return '\n'.join(lines) + '\n'
def decorate(state, nr_objects, world_id=None):
  """Append world index and object index information to state.

  Args:
    state: A 2-D array of per-object features, one row per object.
    nr_objects: Number of rows in `state`.
    world_id: Optional scalar; when given, a constant column with this
        value is prepended before the object-index column.

  Returns:
    A numpy array: [world_id column (optional), object index column, state].
  """
  columns = []
  if world_id is not None:
    # Constant column tagging which world each row belongs to.
    # dtype=float matches the float column produced by np.ones() * world_id.
    columns.append(np.full((nr_objects, 1), world_id, dtype=float))
  # Per-row object index, as a column vector.
  columns.append(np.arange(nr_objects)[:, np.newaxis])
  columns.append(state)
  return np.hstack(columns)
16 | 17 | from .graph import * 18 | from .graph_env import * 19 | from .quickaccess import * 20 | -------------------------------------------------------------------------------- /difflogic/envs/graph/graph.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Implement random graph generators and Graph class.""" 17 | 18 | import copy 19 | import numpy as np 20 | 21 | import jacinle.random as random 22 | 23 | __all__ = ['Graph', 'randomly_generate_graph_er', 'randomly_generate_graph_dnc', 24 | 'get_random_graph_generator'] 25 | 26 | 27 | class Graph(object): 28 | """Store a graph using adjacency matrix. 29 | 30 | Args: 31 | nr_nodes: The number of nodes in the graph. 32 | edges: The adjacency matrix of the graph. 
33 | """ 34 | 35 | def __init__(self, nr_nodes, edges, coordinates=None): 36 | edges = edges.astype('int32') 37 | assert edges.min() >= 0 and edges.max() <= 1 38 | self._nr_nodes = nr_nodes 39 | self._edges = edges 40 | self._coordinates = coordinates 41 | self._shortest = None 42 | self.extra_info = {} 43 | 44 | @property 45 | def nr_nodes(self): 46 | return self._nr_nodes 47 | 48 | def get_edges(self): 49 | return copy.copy(self._edges) 50 | 51 | def get_coordinates(self): 52 | return self._coordinates 53 | 54 | def get_relations(self): 55 | """Return edges and identity matrix.""" 56 | return np.stack([self.get_edges(), np.eye(self.nr_nodes)], axis=-1) 57 | 58 | def has_edge(self, x, y): 59 | return self._edges[x, y] == 1 60 | 61 | def get_out_degree(self): 62 | """Return the out degree of each node.""" 63 | return np.sum(self._edges, axis=1) 64 | 65 | def get_shortest(self): 66 | """Return the length of shortest path between nodes.""" 67 | if self._shortest is not None: 68 | return self._shortest 69 | 70 | n = self.nr_nodes 71 | edges = self.get_edges() 72 | 73 | # n + 1 indicates unreachable. 74 | shortest = np.ones((n, n)) * (n + 1) 75 | shortest[np.where(edges == 1)] = 1 76 | # Make sure that shortest[x, x] = 0 77 | shortest -= shortest * np.eye(n) 78 | shortest = shortest.astype('int32') 79 | 80 | # Floyd Algorithm 81 | for k in range(n): 82 | for i in range(n): 83 | for j in range(n): 84 | if i != j: 85 | shortest[i, j] = min(shortest[i, j], 86 | shortest[i, k] + shortest[k, j]) 87 | self._shortest = shortest 88 | return self._shortest 89 | 90 | def get_connectivity(self, k=None, exclude_self=True): 91 | """Calculate the k-connectivity. 92 | 93 | Args: 94 | k: The limited steps. unlimited if k=None or k<0. 95 | exclude_self: remove connectivity[x, x] if exclude_self=True. 96 | Returns: 97 | A numpy.ndarray representing the k-connectivity for each pair of nodes. 
98 | """ 99 | shortest = self.get_shortest() 100 | if k is None or k < 0: 101 | k = self.nr_nodes 102 | k = min(k, self.nr_nodes) 103 | conn = (shortest <= k).astype('int32') 104 | if exclude_self: 105 | n = self.nr_nodes 106 | inds = np.where(~np.eye(n, dtype=bool)) 107 | conn = conn[inds] 108 | conn.resize(n, n - 1) 109 | return conn 110 | 111 | 112 | def randomly_generate_graph_er(n, p, directed=False): 113 | """Randomly generate a graph by sampling the existence of each edge. 114 | 115 | Each edge between nodes has the probability $p (directed) or 116 | 1 - (1-$p)^2 (undirected) to exist. 117 | 118 | Args: 119 | n: The number of nodes in the graph. 120 | p: the probability that a edge doesn't exist in directed graph. 121 | directed: Directed or Undirected graph. Default: False (undirected) 122 | 123 | Returns: 124 | A Graph class representing randomly generated graph. 125 | """ 126 | edges = (random.rand(n, n) < p).astype('float') 127 | edges -= edges * np.eye(n) 128 | if not directed: 129 | edges = np.maximum(edges, edges.T) 130 | return Graph(n, edges) 131 | 132 | 133 | def randomly_generate_graph_dnc(n, p=None, directed=False): 134 | """Random graph generation method as in DNC. 135 | 136 | As described in Differentiable neural computers (DNC), 137 | (https://www.nature.com/articles/nature20101.epdf?author_access_token=ImTXBI8aWbYxYQ51Plys8NRgN0jAjWel9jnR3ZoTv0MggmpDmwljGswxVdeocYSurJ3hxupzWuRNeGvvXnoO8o4jTJcnAyhGuZzXJ1GEaD-Z7E6X_a9R-xqJ9TfJWBqz) 138 | Sample $n nodes in a unit square. Then sample out-degree (m) of each nodes, 139 | connect to $m nearest neighbors (Euclidean distance) in the unit square. 140 | 141 | Args: 142 | n: The number of nodes in the graph. 143 | p: Control the sampling of the out-degree. 144 | If p=None, the default range is [1, n // 3]. 145 | If p is float, the range is [1, int(n * p)]. 146 | If p is int, the range is [1, p]. 147 | If p is tuple. the range is [p[0], p[1]]. 148 | directed: Directed or Undirected graph. 
def get_random_graph_generator(name):
  """Look up a graph generator function by name.

  'dnc' selects the DNC-style geometric generator; any other name falls
  back to the Erdos-Renyi generator.
  """
  generators = {'dnc': randomly_generate_graph_dnc}
  return generators.get(name, randomly_generate_graph_er)
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """The environment class for graph tasks.""" 17 | 18 | import numpy as np 19 | 20 | import jacinle.random as random 21 | from jacinle.utils.meta import notnone_property 22 | from jaclearn.rl.env import SimpleRLEnvBase 23 | 24 | from .graph import get_random_graph_generator 25 | 26 | __all__ = ['GraphEnvBase', 'PathGraphEnv'] 27 | 28 | 29 | class GraphEnvBase(SimpleRLEnvBase): 30 | """The base class for Graph Environment. 31 | 32 | Args: 33 | nr_nodes: The number of nodes in the graph. 34 | pmin: The lower bound of the parameter controlling the graph generation. 35 | pmax: The upper bound of the parameter controlling the graph generation, 36 | the same as $pmin in default. 37 | directed: Generator directed graph if directed=True. 38 | gen_method: Controlling the graph generation method. 39 | If gen_method='dnc', use the similar way as in DNC paper. 40 | Else using Erdos-Renyi algorithm (each edge exists with prob). 
41 | """ 42 | 43 | def __init__(self, 44 | nr_nodes, 45 | pmin, 46 | pmax=None, 47 | directed=False, 48 | gen_method='dnc'): 49 | super().__init__() 50 | self._nr_nodes = nr_nodes 51 | self._pmin = pmin 52 | self._pmax = pmin if pmax is None else pmax 53 | self._directed = directed 54 | self._gen_method = gen_method 55 | self._graph = None 56 | 57 | @notnone_property 58 | def graph(self): 59 | return self._graph 60 | 61 | @property 62 | def nr_nodes(self): 63 | return self._nr_nodes 64 | 65 | def _restart(self): 66 | """Restart the environment.""" 67 | self._gen_graph() 68 | 69 | def _gen_graph(self): 70 | """generate the graph by specified method.""" 71 | n = self._nr_nodes 72 | p = self._pmin + random.rand() * (self._pmax - self._pmin) 73 | assert self._gen_method in ['edge', 'dnc'] 74 | gen = get_random_graph_generator(self._gen_method) 75 | self._graph = gen(n, p, self._directed) 76 | 77 | 78 | class PathGraphEnv(GraphEnvBase): 79 | """Env for Finding a path from starting node to the destination.""" 80 | 81 | def __init__(self, 82 | nr_nodes, 83 | dist_range, 84 | pmin, 85 | pmax=None, 86 | directed=False, 87 | gen_method='dnc'): 88 | super().__init__(nr_nodes, pmin, pmax, directed, gen_method) 89 | self._dist_range = dist_range 90 | 91 | @property 92 | def dist(self): 93 | return self._dist 94 | 95 | def _restart(self): 96 | super()._restart() 97 | self._dist = self._sample_dist() 98 | self._task = None 99 | while True: 100 | self._task = self._gen() 101 | if self._task is not None: 102 | break 103 | # Generate another graph if fail to find two nodes with desired distance. 
104 | self._gen_graph() 105 | self._current = self._task[0] 106 | self._set_current_state(self._task) 107 | self._steps = 0 108 | 109 | def _sample_dist(self): 110 | """Sample the distance between the starting node and the destination.""" 111 | lower, upper = self._dist_range 112 | upper = min(upper, self._nr_nodes - 1) 113 | return random.randint(upper - lower + 1) + lower 114 | 115 | def _gen(self): 116 | """Sample the starting node and the destination according to the distance.""" 117 | dist_matrix = self._graph.get_shortest() 118 | st, ed = np.where(dist_matrix == self.dist) 119 | if len(st) == 0: 120 | return None 121 | ind = random.randint(len(st)) 122 | return st[ind], ed[ind] 123 | 124 | def _action(self, target): 125 | """Move to the target node from current node if has_edge(current -> target).""" 126 | if self._current == self._task[1]: 127 | return 1, True 128 | if self._graph.has_edge(self._current, target): 129 | self._current = target 130 | self._set_current_state((self._current, self._task[1])) 131 | if self._current == self._task[1]: 132 | return 1, True 133 | self._steps += 1 134 | if self._steps >= self.dist: 135 | return 0, True 136 | return 0, False 137 | -------------------------------------------------------------------------------- /difflogic/envs/graph/quickaccess.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
def make(task, *args, **kwargs):
  """Instantiate a graph environment by task name.

  Args:
    task: Name of the task; only 'path' is supported.
    *args: Positional arguments forwarded to the environment factory.
    **kwargs: Keyword arguments forwarded to the environment factory.

  Returns:
    The environment returned by `get_path_env`.

  Raises:
    ValueError: If `task` is not a known task name.
  """
  # Guard clause: reject unknown tasks up front.
  if task != 'path':
    raise ValueError('Unknown task: {}.'.format(task))
  return get_path_env(*args, **kwargs)
16 | """Utility functions for customer datasets.""" 17 | 18 | from jaclearn.rl.env import ProxyRLEnvBase 19 | from jaclearn.rl.space import DiscreteActionSpace 20 | 21 | __all__ = ['MapActionProxy', 'get_action_mapping', 'get_action_mapping_graph', 22 | 'get_action_mapping_sorting', 'get_action_mapping_blocksworld'] 23 | 24 | 25 | class MapActionProxy(ProxyRLEnvBase): 26 | """RL Env proxy to map actions using provided mapping function.""" 27 | 28 | def __init__(self, other, mapping): 29 | super().__init__(other) 30 | self._mapping = mapping 31 | 32 | @property 33 | def mapping(self): 34 | return self._mapping 35 | 36 | def map_action(self, action): 37 | assert action < len(self._mapping) 38 | return self._mapping[action] 39 | 40 | def _get_action_space(self): 41 | return DiscreteActionSpace(len(self._mapping)) 42 | 43 | def _action(self, action): 44 | return self.proxy.action(self.map_action(action)) 45 | 46 | 47 | def get_action_mapping(n, exclude_self=True): 48 | """In a matrix view, this a mapping from 1d-index to 2d-coordinate.""" 49 | mapping = [ 50 | (i, j) for i in range(n) for j in range(n) if (i != j or not exclude_self) 51 | ] 52 | return mapping 53 | 54 | get_action_mapping_graph = get_action_mapping 55 | get_action_mapping_sorting = get_action_mapping 56 | 57 | 58 | def get_action_mapping_blocksworld(nr_blocks, exclude_self=True): 59 | return get_action_mapping_graph(nr_blocks + 1, exclude_self) 60 | -------------------------------------------------------------------------------- /difflogic/nn/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # This file is part of Jacinle. 17 | -------------------------------------------------------------------------------- /difflogic/nn/baselines/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .lstm import * 18 | from .memory_net import * 19 | -------------------------------------------------------------------------------- /difflogic/nn/baselines/lstm.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """The LSTM baseline.""" 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | 22 | import jacinle.random as random 23 | from difflogic.nn.neural_logic.modules._utils import meshgrid 24 | from jactorch.functional.shape import broadcast 25 | 26 | __all__ = ['LSTMBaseline'] 27 | 28 | 29 | class LSTMBaseline(nn.Module): 30 | """LSTM baseline model.""" 31 | def __init__(self, 32 | input_dim, 33 | feature_dim, 34 | num_layers=2, 35 | hidden_size=512, 36 | code_length=8): 37 | super().__init__() 38 | current_dim = input_dim + code_length * 2 39 | self.feature_dim = feature_dim 40 | assert feature_dim == 1 or feature_dim == 2, ('only support attributes or ' 41 | 'relations') 42 | self.num_layers = num_layers 43 | self.hidden_size = hidden_size 44 | self.code_length = code_length 45 | self.lstm = nn.LSTM( 46 | current_dim, 47 | hidden_size, 48 | num_layers, 49 | batch_first=True, 50 | bidirectional=True) 51 | 52 | def forward(self, relations, attributes=None): 53 | batch_size, nr = relations.size()[:2] 54 | assert nr == relations.size(2) 55 | 56 | id_shape = list(relations.size()[:-1]) 57 | ids = [ 58 | random.permutation(2**self.code_length - 1)[:nr] + 1 59 | for i in range(batch_size) 60 | ] 61 | ids = np.vstack(ids) 62 | binary_ids = self.binarize_code(ids) 63 | zeros = torch.tensor( 64 | np.zeros(binary_ids.shape), 65 | dtype=relations.dtype, 66 | device=relations.device) 67 | binary_ids = torch.tensor( 68 | binary_ids, dtype=relations.dtype, device=relations.device) 69 | binary_ids2 = 
torch.cat(meshgrid(binary_ids, dim=1), dim=-1) 70 | 71 | if attributes is None: 72 | rels = [binary_ids2, relations] 73 | else: 74 | padding = torch.zeros( 75 | *binary_ids2.size()[:-1], 76 | attributes.size(-1), 77 | dtype=relations.dtype, 78 | device=relations.device) 79 | rels = [binary_ids2, padding, relations] 80 | rels = torch.cat(rels, dim=-1) 81 | input_seq = rels.view(batch_size, -1, rels.size(-1)) 82 | if attributes is not None: 83 | assert nr == attributes.size(1) 84 | padding = torch.zeros( 85 | *binary_ids.size()[:-1], 86 | relations.size(-1), 87 | dtype=relations.dtype, 88 | device=relations.device) 89 | attributes = torch.cat([binary_ids, zeros, attributes, padding], dim=-1) 90 | input_seq = torch.cat([input_seq, attributes], dim=1) 91 | 92 | h0 = torch.zeros( 93 | self.num_layers * 2, 94 | batch_size, 95 | self.hidden_size, 96 | dtype=relations.dtype, 97 | device=relations.device) 98 | c0 = torch.zeros( 99 | self.num_layers * 2, 100 | batch_size, 101 | self.hidden_size, 102 | dtype=relations.dtype, 103 | device=relations.device) 104 | out, _ = self.lstm(input_seq, (h0, c0)) 105 | out = out[:, -1] 106 | 107 | if self.feature_dim == 1: 108 | expanded_feature = broadcast(out.unsqueeze(dim=1), 1, nr) 109 | return torch.cat([binary_ids, expanded_feature], dim=-1) 110 | else: 111 | expanded_feature = broadcast(out.unsqueeze(dim=1), 1, nr) 112 | expanded_feature = broadcast(expanded_feature.unsqueeze(dim=1), 1, nr) 113 | return torch.cat([binary_ids2, expanded_feature], dim=-1) 114 | 115 | def binarize_code(self, x): 116 | m = self.code_length 117 | code = np.zeros((x.shape + (m,))) 118 | for i in range(m)[::-1]: 119 | code[:, :, i] = (x >= 2**i).astype('float') 120 | x = x - code[:, :, i] * 2**i 121 | return code 122 | 123 | def get_output_dim(self): 124 | return self.hidden_size * 2 + self.code_length * self.feature_dim 125 | -------------------------------------------------------------------------------- /difflogic/nn/baselines/memory_net.py: 
-------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import functools 18 | import numpy as np 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | 23 | import jacinle.random as random 24 | from difflogic.nn.neural_logic.modules._utils import meshgrid 25 | from jactorch.quickstart.models import MLPModel 26 | 27 | __all__ = ['MemoryNet'] 28 | 29 | 30 | class MemoryNet(nn.Module): 31 | 32 | def __init__(self, input_dim, feature_dim, queries, hidden_dim, key_dim, 33 | value_dim, id_dim): 34 | super().__init__() 35 | self.feature_dim = feature_dim 36 | assert feature_dim == 1 or feature_dim == 2, \ 37 | 'only support attributes or relations' 38 | self.queries = queries 39 | self.hidden_dim = hidden_dim 40 | self.key_dim = key_dim 41 | self.value_dim = value_dim 42 | self.id_dim = id_dim 43 | 44 | current_dim = id_dim * 2 + input_dim 45 | self.key_embed = MLPModel(current_dim, key_dim, []) 46 | self.value_embed = MLPModel(current_dim, value_dim, []) 47 | self.query_embed = MLPModel(id_dim * 2, key_dim, []) 48 | self.to_query = MLPModel(hidden_dim, key_dim, []) 49 | self.lstm_cell = nn.LSTMCell(value_dim, hidden_dim) 50 | 51 | def forward(self, relations, attributes=None): 52 | batch_size, nr = relations.size()[:2] 53 | assert nr == relations.size(2) 
54 | create_zeros = functools.partial( 55 | torch.zeros, dtype=relations.dtype, device=relations.device) 56 | 57 | id_shape = list(relations.size()[:-1]) 58 | ids = [random.permutation(2 ** self.id_dim - 1)[:nr] + 1 \ 59 | for i in range(batch_size)] 60 | ids = np.vstack(ids) 61 | binary_ids = self.binarize_code(ids) 62 | zeros = create_zeros(binary_ids.shape) 63 | binary_ids = torch.tensor( 64 | binary_ids, dtype=relations.dtype, device=relations.device) 65 | binary_ids2 = torch.cat(meshgrid(binary_ids, dim=1), dim=-1) 66 | padded_binary_ids = torch.cat([binary_ids, zeros], dim=-1) 67 | 68 | def embed(embed, x): 69 | input_size = x.size()[:-1] 70 | input_channel = x.size(-1) 71 | f = x.view(-1, input_channel) 72 | f = embed(f) 73 | return f.view(*input_size, -1) 74 | 75 | if attributes is None: 76 | rels = [binary_ids2, relations] 77 | else: 78 | padding = create_zeros(*binary_ids2.size()[:-1], attributes.size(-1)) 79 | rels = [binary_ids2, padding, relations] 80 | rels = torch.cat(rels, dim=-1) 81 | memory = rels.view(batch_size, -1, rels.size(-1)) 82 | if attributes is not None: 83 | assert nr == attributes.size(1) 84 | padding = create_zeros(*padded_binary_ids.size()[:-1], relations.size(-1)) 85 | attributes = torch.cat([padded_binary_ids, attributes, padding], dim=-1) 86 | memory = torch.cat([memory, attributes], dim=1) 87 | keys = embed(self.key_embed, memory).transpose(1, 2) 88 | values = embed(self.value_embed, memory) 89 | 90 | query = padded_binary_ids if self.feature_dim == 1 else binary_ids2 91 | nr_items = nr**self.feature_dim 92 | query = embed(self.query_embed, query).view(batch_size, nr_items, -1) 93 | 94 | h0 = create_zeros(batch_size * nr_items, self.hidden_dim) 95 | c0 = create_zeros(batch_size * nr_items, self.hidden_dim) 96 | for i in range(self.queries): 97 | attention = F.softmax(torch.bmm(query, keys), dim=-1) 98 | value = torch.bmm(attention, values) 99 | value = value.view(-1, value.size(-1)) 100 | 101 | h0, c0 = self.lstm_cell(value, (h0, 
c0)) 102 | query = self.to_query(h0).view(batch_size, nr_items, self.key_dim) 103 | 104 | if self.feature_dim == 1: 105 | out = h0.view(batch_size, nr, self.hidden_dim) 106 | else: 107 | out = h0.view(batch_size, nr, nr, self.hidden_dim) 108 | return out 109 | 110 | def binarize_code(self, x): 111 | m = self.id_dim 112 | code = np.zeros((x.shape + (m,))) 113 | for i in range(m)[::-1]: 114 | code[:, :, i] = (x >= 2**i).astype('float') 115 | x = x - code[:, :, i] * 2**i 116 | return code 117 | 118 | def get_output_dim(self): 119 | return self.hidden_dim 120 | 121 | __hyperparams__ = ('queries', 'hidden_dim', 'key_dim', 'value_dim', 'id_dim') 122 | 123 | __hyperparam_defaults__ = { 124 | 'queries': 4, 125 | 'hidden_dim': 64, 126 | 'key_dim': 16, 127 | 'value_dim': 32, 128 | 'id_dim': 8, 129 | } 130 | 131 | @classmethod 132 | def make_memnet_parser(cls, parser, defaults, prefix=None): 133 | for k, v in cls.__hyperparam_defaults__.items(): 134 | defaults.setdefault(k, v) 135 | 136 | if prefix is None: 137 | prefix = '--' 138 | else: 139 | prefix = '--' + str(prefix) + '-' 140 | 141 | parser.add_argument( 142 | prefix + 'queries', 143 | type=int, 144 | default=defaults['queries'], 145 | metavar='N', 146 | help='number of queries') 147 | parser.add_argument( 148 | prefix + 'hidden-dim', 149 | type=int, 150 | default=defaults['hidden_dim'], 151 | metavar='N', 152 | help='hidden dimension of LSTM cell') 153 | parser.add_argument( 154 | prefix + 'key-dim', 155 | type=int, 156 | default=defaults['key_dim'], 157 | metavar='N', 158 | help='dimension of key vector') 159 | parser.add_argument( 160 | prefix + 'value-dim', 161 | type=int, 162 | default=defaults['value_dim'], 163 | metavar='N', 164 | help='dimension of value vector') 165 | parser.add_argument( 166 | prefix + 'id-dim', 167 | type=int, 168 | default=defaults['id_dim'], 169 | metavar='N', 170 | help='dimension of id vector') 171 | 172 | @classmethod 173 | def from_args(cls, input_dim, feature_dim, args, prefix=None, 
**kwargs): 174 | if prefix is None: 175 | prefix = '' 176 | else: 177 | prefix = str(prefix) + '_' 178 | 179 | init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__} 180 | init_params.update(kwargs) 181 | 182 | return cls(input_dim, feature_dim, **init_params) 183 | -------------------------------------------------------------------------------- /difflogic/nn/neural_logic/__init__.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .modules.input_transform import * 18 | from .modules.neural_logic import * 19 | 20 | from .layer import * 21 | -------------------------------------------------------------------------------- /difflogic/nn/neural_logic/layer.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implement Neural Logic Layers and Machines."""

import torch
import torch.nn as nn

from jacinle.logging import get_logger

from .modules.dimension import Expander, Reducer, Permutation
from .modules.neural_logic import LogicInference

__all__ = ['LogicLayer', 'LogicMachine']

logger = get_logger(__file__)


def _get_tuple_n(x, n, tp):
  """Get a length-n list of type tp, broadcasting a single tp value."""
  assert tp is not list
  if isinstance(x, tp):
    x = [x,] * n
  assert len(x) == n, 'Parameters should be {} or list of N elements.'.format(
      tp)
  for i in x:
    assert isinstance(i, tp), 'Elements of list should be {}.'.format(tp)
  return x


class LogicLayer(nn.Module):
  """Logic Layers do one-step differentiable logic deduction.

  The predicates are grouped by their number of variables. The inter-group
  deduction is done by expansion/reduction, the intra-group deduction is done
  by the logic model.

  Args:
    breadth: The breadth (max predicate arity) of the logic layer.
    input_dims: the number of input channels of each input group, should be
        consistent with the inputs. Use dims=0 and input=None to indicate no
        input of that group.
    output_dims: the number of output channels of each group; a single value
        is broadcast to all groups.
    logic_hidden_dim: The hidden dim of the logic model.
    exclude_self: Do not allow multiple occurrences of the same variable
        when True.
    residual: Use residual connections when True.
  """

  def __init__(
      self,
      breadth,
      input_dims,
      output_dims,
      logic_hidden_dim,
      exclude_self=True,
      residual=False,
  ):
    super().__init__()
    assert breadth > 0, 'Does not support breadth <= 0.'
    if breadth > 3:
      # logger.warn is the deprecated alias of logger.warning.
      logger.warning(
          'Using LogicLayer with breadth > 3 may cause speed and memory issue.')

    self.max_order = breadth
    self.residual = residual

    input_dims = _get_tuple_n(input_dims, self.max_order + 1, int)
    output_dims = _get_tuple_n(output_dims, self.max_order + 1, int)

    self.logic, self.dim_perms, self.dim_expanders, self.dim_reducers = [
        nn.ModuleList() for _ in range(4)
    ]
    for i in range(self.max_order + 1):
      # Collect current_dim from group i-1, i and i+1.
      current_dim = input_dims[i]
      if i > 0:
        # Expand group i-1 (add one free variable) into group i.
        expander = Expander(i - 1)
        self.dim_expanders.append(expander)
        current_dim += expander.get_output_dim(input_dims[i - 1])
      else:
        self.dim_expanders.append(None)

      if i + 1 < self.max_order + 1:
        # Reduce group i+1 (quantify out one variable) into group i.
        reducer = Reducer(i + 1, exclude_self)
        self.dim_reducers.append(reducer)
        current_dim += reducer.get_output_dim(input_dims[i + 1])
      else:
        self.dim_reducers.append(None)

      if current_dim == 0:
        # No information flows into this group: disable it entirely.
        self.dim_perms.append(None)
        self.logic.append(None)
        output_dims[i] = 0
      else:
        perm = Permutation(i)
        self.dim_perms.append(perm)
        current_dim = perm.get_output_dim(current_dim)
        self.logic.append(
            LogicInference(current_dim, output_dims[i], logic_hidden_dim))

    self.input_dims = input_dims
    self.output_dims = output_dims

    if self.residual:
      # Residual: the inputs are concatenated onto the outputs in forward().
      for i in range(len(input_dims)):
        self.output_dims[i] += input_dims[i]

  def forward(self, inputs):
    """Deduce one step; inputs/outputs are lists of per-arity tensors."""
    assert len(inputs) == self.max_order + 1
    outputs = []
    for i in range(self.max_order + 1):
      # Collect input f from group i-1, i and i+1.
      f = []
      if i > 0 and self.input_dims[i - 1] > 0:
        # For i == 1 the expander input is nullary; the number of objects
        # cannot be inferred from it, so read it off the unary input.
        n = inputs[i].size(1) if i == 1 else None
        f.append(self.dim_expanders[i](inputs[i - 1], n))
      if i < len(inputs) and self.input_dims[i] > 0:
        f.append(inputs[i])
      if i + 1 < len(inputs) and self.input_dims[i + 1] > 0:
        f.append(self.dim_reducers[i](inputs[i + 1]))
      if len(f) == 0:
        output = None
      else:
        f = torch.cat(f, dim=-1)
        f = self.dim_perms[i](f)
        output = self.logic[i](f)
      if self.residual and self.input_dims[i] > 0:
        output = torch.cat([inputs[i], output], dim=-1)
      outputs.append(output)
    return outputs

  __hyperparams__ = (
      'breadth',
      'input_dims',
      'output_dims',
      'logic_hidden_dim',
      'exclude_self',
      'residual',
  )

  __hyperparam_defaults__ = {
      'exclude_self': True,
      'residual': False,
  }

  @classmethod
  def make_nlm_parser(cls, parser, defaults, prefix=None):
    """Register this layer's hyperparameters on an argparse-style parser."""
    for k, v in cls.__hyperparam_defaults__.items():
      defaults.setdefault(k, v)

    if prefix is None:
      prefix = '--'
    else:
      prefix = '--' + str(prefix) + '-'

    parser.add_argument(
        prefix + 'breadth',
        # BUG FIX: was type='int'. argparse's `type` must be a callable (or a
        # name registered on the parser); only 'bool' is registered by the
        # Jacinle parser. LogicMachine.make_nlm_parser below already uses
        # type=int, so this also restores consistency.
        type=int,
        default=defaults['breadth'],
        metavar='N',
        help='breadth of the logic layer')
    parser.add_argument(
        prefix + 'logic-hidden-dim',
        type=int,
        nargs='+',
        default=defaults['logic_hidden_dim'],
        metavar='N',
        help='hidden dim of the logic model')
    parser.add_argument(
        prefix + 'exclude-self',
        type='bool',  # string type registered by the Jacinle parser
        default=defaults['exclude_self'],
        metavar='B',
        help='not allow multiple occurrence of same variable')
    parser.add_argument(
        prefix + 'residual',
        type='bool',
        default=defaults['residual'],
        metavar='B',
        help='use residual connections')

  @classmethod
  def from_args(cls, input_dims, output_dims, args, prefix=None, **kwargs):
    """Instantiate a LogicLayer from args parsed by make_nlm_parser."""
    if prefix is None:
      prefix = ''
    else:
      prefix = str(prefix) + '_'

    setattr(args, prefix + 'input_dims', input_dims)
    setattr(args, prefix + 'output_dims', output_dims)
    init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
    init_params.update(kwargs)

    return cls(**init_params)


class LogicMachine(nn.Module):
  """Neural Logic Machine consists of multiple logic layers."""

  def __init__(
      self,
      depth,
      breadth,
      input_dims,
      output_dims,
      logic_hidden_dim,
      exclude_self=True,
      residual=False,
      io_residual=False,
      recursion=False,
      connections=None,
  ):
    super().__init__()
    self.depth = depth
    self.breadth = breadth
    self.residual = residual
    self.io_residual = io_residual
    self.recursion = recursion
    self.connections = connections

    assert not (self.residual and self.io_residual), \
        'Only one type of residual connection is allowed at the same time.'

    # Element-wise addition for vectors (in place on x).
    def add_(x, y):
      for i in range(len(y)):
        x[i] += y[i]
      return x

    self.layers = nn.ModuleList()
    current_dims = input_dims
    total_output_dims = [0 for _ in range(self.breadth + 1)
                        ]  # for IO residual only
    for i in range(depth):
      # IO residual is unused.
      if i > 0 and io_residual:
        add_(current_dims, input_dims)
      # Not support output_dims as list or list[list] yet.
      layer = LogicLayer(breadth, current_dims, output_dims, logic_hidden_dim,
                         exclude_self, residual)
      current_dims = layer.output_dims
      current_dims = self._mask(current_dims, i, 0)
      if io_residual:
        add_(total_output_dims, current_dims)
      self.layers.append(layer)

    if io_residual:
      self.output_dims = total_output_dims
    else:
      self.output_dims = current_dims

  def _mask(self, a, i, masked_value):
    """Mask out group-entries of layer i as specified by self.connections.

    For debug usage.
    """
    if self.connections is not None:
      assert i < len(self.connections)
      mask = self.connections[i]
      if mask is not None:
        assert len(mask) == len(a)
        a = [x if y else masked_value for x, y in zip(a, mask)]
    return a

  def forward(self, inputs, depth=None):
    """Run the logic layers; `depth` optionally overrides inference depth."""
    outputs = [None for _ in range(self.breadth + 1)]
    # NOTE(review): f aliases the caller's `inputs` list; the io_residual
    # branch below mutates it in place. io_residual is documented as unused —
    # confirm before relying on that path.
    f = inputs

    # depth: the actual depth used for inference.
    if depth is None:
      depth = self.depth
    if not self.recursion:
      depth = min(depth, self.depth)

    def merge(x, y):
      if x is None:
        return y
      if y is None:
        return x
      return torch.cat([x, y], dim=-1)

    layer = None
    last_layer = None
    for i in range(depth):
      if i > 0 and self.io_residual:
        for j, inp in enumerate(inputs):
          f[j] = merge(f[j], inp)
      # To enable recursion, use scroll variables layer/last_layer
      # for weight sharing of period 2, i.e. 0,1,2,1,2,1,2,...
      if self.recursion and i >= 3:
        assert not self.residual
        layer, last_layer = last_layer, layer
      else:
        last_layer = layer
        layer = self.layers[i]

      f = layer(f)
      f = self._mask(f, i, None)
      if self.io_residual:
        for j, out in enumerate(f):
          outputs[j] = merge(outputs[j], out)
    if not self.io_residual:
      outputs = f
    return outputs

  __hyperparams__ = (
      'depth',
      'breadth',
      'input_dims',
      'output_dims',
      'logic_hidden_dim',
      'exclude_self',
      'io_residual',
      'residual',
      'recursion',
  )

  __hyperparam_defaults__ = {
      'exclude_self': True,
      'io_residual': False,
      'residual': False,
      'recursion': False,
  }

  @classmethod
  def make_nlm_parser(cls, parser, defaults, prefix=None):
    """Register this machine's hyperparameters on an argparse-style parser."""
    for k, v in cls.__hyperparam_defaults__.items():
      defaults.setdefault(k, v)

    if prefix is None:
      prefix = '--'
    else:
      prefix = '--' + str(prefix) + '-'

    parser.add_argument(
        prefix + 'depth',
        type=int,
        default=defaults['depth'],
        metavar='N',
        help='depth of the logic machine')
    parser.add_argument(
        prefix + 'breadth',
        type=int,
        default=defaults['breadth'],
        metavar='N',
        help='breadth of the logic machine')
    parser.add_argument(
        prefix + 'logic-hidden-dim',
        type=int,
        nargs='+',
        default=defaults['logic_hidden_dim'],
        metavar='N',
        help='hidden dim of the logic model')
    parser.add_argument(
        prefix + 'exclude-self',
        type='bool',  # string type registered by the Jacinle parser
        default=defaults['exclude_self'],
        metavar='B',
        help='not allow multiple occurrence of same variable')
    parser.add_argument(
        prefix + 'io-residual',
        type='bool',
        default=defaults['io_residual'],
        metavar='B',
        help='use input/output-only residual connections')
    parser.add_argument(
        prefix + 'residual',
        type='bool',
        default=defaults['residual'],
        metavar='B',
        help='use residual connections')
    parser.add_argument(
        prefix + 'recursion',
        type='bool',
        default=defaults['recursion'],
        metavar='B',
        help='use recursion weight sharing')

  @classmethod
  def from_args(cls, input_dims, output_dims, args, prefix=None, **kwargs):
    """Instantiate a LogicMachine from args parsed by make_nlm_parser."""
    if prefix is None:
      prefix = ''
    else:
      prefix = str(prefix) + '_'

    setattr(args, prefix + 'input_dims', input_dims)
    setattr(args, prefix + 'output_dims', output_dims)
    init_params = {k: getattr(args, prefix + k) for k in cls.__hyperparams__}
    init_params.update(kwargs)

    return cls(**init_params)

# ---------------------------------------------------------------------------
# difflogic/nn/neural_logic/modules/__init__.py
# (Apache License 2.0 header, identical to the other files in this package.)
# This file is part of DifferentiableLogic-PyTorch.
# ---------------------------------------------------------------------------
# difflogic/nn/neural_logic/modules/_utils.py
#! /usr/bin/env python3
#
# Copyright 2018 Google LLC
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for tensor masking."""

import torch

# NOTE(review): torch.autograd is unused in this module; kept to preserve the
# file's import surface.
import torch.autograd as ag
from jactorch.functional import meshgrid, meshgrid_exclude_self

__all__ = ['meshgrid', 'meshgrid_exclude_self', 'exclude_mask', 'mask_value']


def exclude_mask(inputs, cnt=2, dim=1):
  """Produce an exclusive mask.

  Specifically, for cnt=2, given an array a[i, j] of n * n, it produces
  a mask with size n * n where a[i, j] = 1 if and only if (i != j).

  Args:
    inputs: The tensor to be masked.
    cnt: The operation is performed over [dim, dim + cnt) axes.
    dim: The starting dimension for the exclusive mask.

  Returns:
    A float mask of the same shape as inputs that is 1 exactly where the
    cnt coordinates are mutually distinct.
  """
  assert cnt > 0
  if dim < 0:
    dim += inputs.dim()
  n = inputs.size(dim)
  for i in range(1, cnt):
    # All cnt axes must have the same length.
    assert n == inputs.size(dim + i)

  rng = torch.arange(0, n, dtype=torch.long, device=inputs.device)
  q = []
  for i in range(cnt):
    # Build a cnt-dimensional index tensor varying only along axis i.
    p = rng
    for j in range(cnt):
      if i != j:
        p = p.unsqueeze(j)
    p = p.expand((n,) * cnt)
    q.append(p)
  mask = q[0] == q[0]  # all-True mask of shape (n,) * cnt
  # Mutually exclusive: zero out entries where any two coordinates agree.
  # Checking each unordered pair once suffices (the original looped over all
  # ordered pairs, testing every pair twice).
  for i in range(cnt):
    for j in range(i + 1, cnt):
      mask = mask * (q[i] != q[j])
  # Pad with singleton axes so the mask broadcasts over inputs' other dims.
  for _ in range(dim):
    mask.unsqueeze_(0)
  for _ in range(inputs.dim() - dim - cnt):
    mask.unsqueeze_(-1)

  return mask.expand(inputs.size()).float()


def mask_value(inputs, mask, value):
  """Return inputs where mask == 1, and `value` where mask == 0."""
  assert inputs.size() == mask.size()
  return inputs * mask + value * (1 - mask)

# ---------------------------------------------------------------------------
# difflogic/nn/neural_logic/modules/dimension.py
# (Apache License 2.0 header, identical to the other files in this package.)
# ---------------------------------------------------------------------------
16 | """The dimension related operations.""" 17 | 18 | import itertools 19 | 20 | import torch 21 | import torch.nn as nn 22 | 23 | from jactorch.functional import broadcast 24 | 25 | from ._utils import exclude_mask, mask_value 26 | 27 | __all__ = ['Expander', 'Reducer', 'Permutation'] 28 | 29 | 30 | class Expander(nn.Module): 31 | """Capture a free variable into predicates, implemented by broadcast.""" 32 | 33 | def __init__(self, dim): 34 | super().__init__() 35 | self.dim = dim 36 | 37 | def forward(self, inputs, n=None): 38 | if self.dim == 0: 39 | assert n is not None 40 | elif n is None: 41 | n = inputs.size(self.dim) 42 | dim = self.dim + 1 43 | return broadcast(inputs.unsqueeze(dim), dim, n) 44 | 45 | def get_output_dim(self, input_dim): 46 | return input_dim 47 | 48 | 49 | class Reducer(nn.Module): 50 | """Reduce out a variable via quantifiers (exists/forall), implemented by max/min-pooling.""" 51 | 52 | def __init__(self, dim, exclude_self=True, exists=True): 53 | super().__init__() 54 | self.dim = dim 55 | self.exclude_self = exclude_self 56 | self.exists = exists 57 | 58 | def forward(self, inputs): 59 | shape = inputs.size() 60 | inp0, inp1 = inputs, inputs 61 | if self.exclude_self: 62 | mask = exclude_mask(inputs, cnt=self.dim, dim=-1 - self.dim) 63 | inp0 = mask_value(inputs, mask, 0.0) 64 | inp1 = mask_value(inputs, mask, 1.0) 65 | 66 | if self.exists: 67 | shape = shape[:-2] + (shape[-1] * 2,) 68 | exists = torch.max(inp0, dim=-2)[0] 69 | forall = torch.min(inp1, dim=-2)[0] 70 | return torch.stack((exists, forall), dim=-1).view(shape) 71 | 72 | shape = shape[:-2] + (shape[-1],) 73 | return torch.max(inp0, dim=-2)[0].view(shape) 74 | 75 | def get_output_dim(self, input_dim): 76 | if self.exists: 77 | return input_dim * 2 78 | return input_dim 79 | 80 | 81 | class Permutation(nn.Module): 82 | """Create r! 
new predicates by permuting the axies for r-arity predicates.""" 83 | 84 | def __init__(self, dim): 85 | super().__init__() 86 | self.dim = dim 87 | 88 | def forward(self, inputs): 89 | if self.dim <= 1: 90 | return inputs 91 | nr_dims = len(inputs.size()) 92 | # Assume the last dim is channel. 93 | index = tuple(range(nr_dims - 1)) 94 | start_dim = nr_dims - 1 - self.dim 95 | assert start_dim > 0 96 | res = [] 97 | for i in itertools.permutations(index[start_dim:]): 98 | p = index[:start_dim] + i + (nr_dims - 1,) 99 | res.append(inputs.permute(p)) 100 | return torch.cat(res, dim=-1) 101 | 102 | def get_output_dim(self, input_dim): 103 | mul = 1 104 | for i in range(self.dim): 105 | mul *= i + 1 106 | return input_dim * mul 107 | -------------------------------------------------------------------------------- /difflogic/nn/neural_logic/modules/input_transform.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Implement transformation for input tensors.""" 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from jacinle.utils.enum import JacEnum 22 | 23 | from ._utils import meshgrid, meshgrid_exclude_self 24 | 25 | __all__ = ['InputTransformMethod', 'InputTransform'] 26 | 27 | 28 | class InputTransformMethod(JacEnum): 29 | CONCAT = 'concat' 30 | DIFF = 'diff' 31 | CMP = 'cmp' 32 | 33 | 34 | class InputTransform(nn.Module): 35 | """Transform the unary predicates to binary predicates by operations.""" 36 | 37 | def __init__(self, method, exclude_self=True): 38 | super().__init__() 39 | self.method = InputTransformMethod.from_string(method) 40 | self.exclude_self = exclude_self 41 | 42 | def forward(self, inputs): 43 | assert inputs.dim() == 3 44 | 45 | x, y = meshgrid(inputs, dim=1) 46 | 47 | if self.method is InputTransformMethod.CONCAT: 48 | combined = torch.cat((x, y), dim=3) 49 | elif self.method is InputTransformMethod.DIFF: 50 | combined = x - y 51 | elif self.method is InputTransformMethod.CMP: 52 | combined = torch.cat([x < y, x == y, x > y], dim=3) 53 | else: 54 | raise ValueError('Unknown input transform method: {}.'.format( 55 | self.method)) 56 | 57 | if self.exclude_self: 58 | combined = meshgrid_exclude_self(combined, dim=1) 59 | return combined.float() 60 | 61 | def get_output_dim(self, input_dim): 62 | if self.method is InputTransformMethod.CONCAT: 63 | return input_dim * 2 64 | elif self.method is InputTransformMethod.DIFF: 65 | return input_dim 66 | elif self.method is InputTransformMethod.CMP: 67 | return input_dim * 3 68 | else: 69 | raise ValueError('Unknown input transform method: {}.'.format( 70 | self.method)) 71 | 72 | def __repr__(self): 73 | return '{name}({method}, exclude_self={exclude_self})'.format( 74 | name=self.__class__.__name__, **self.__dict__) 75 | -------------------------------------------------------------------------------- /difflogic/nn/neural_logic/modules/neural_logic.py: 
# ---------------------------------------------------------------------------
#! /usr/bin/env python3
#
# Copyright 2018 Google LLC
# Licensed under the Apache License, Version 2.0; see the repository LICENSE.
"""MLP-based implementation for logic and logits inference."""

import torch.nn as nn

from jactorch.quickstart.models import MLPModel

__all__ = ['LogicInference', 'LogitsInference']


class InferenceBase(nn.Module):
  """MLP model with shared parameters among other axes except the channel axis."""

  def __init__(self, input_dim, output_dim, hidden_dim):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.hidden_dim = hidden_dim
    # Wrapped in nn.Sequential so subclasses can append extra modules
    # (e.g. LogicInference appends a Sigmoid below).
    self.layer = nn.Sequential(MLPModel(input_dim, output_dim, hidden_dim))

  def forward(self, inputs):
    # Flatten all leading axes so the MLP is applied position-wise,
    # then restore the original leading shape.
    input_size = inputs.size()[:-1]
    input_channel = inputs.size(-1)

    f = inputs.view(-1, input_channel)
    f = self.layer(f)
    f = f.view(*input_size, -1)
    return f

  def get_output_dim(self, input_dim):
    return self.output_dim


class LogicInference(InferenceBase):
  """MLP layer with sigmoid activation (outputs in [0, 1])."""

  def __init__(self, input_dim, output_dim, hidden_dim):
    super().__init__(input_dim, output_dim, hidden_dim)
    self.layer.add_module(str(len(self.layer)), nn.Sigmoid())


class LogitsInference(InferenceBase):
  """InferenceBase with no output activation, i.e. raw logits."""

# ---------------------------------------------------------------------------
# difflogic/nn/rl/__init__.py
# (Apache License 2.0 header, identical to the other files in this package.)
from .reinforce import *
# ---------------------------------------------------------------------------
# difflogic/nn/rl/reinforce.py
# (Apache License 2.0 header, identical to the other files in this package.)
16 | """Implement REINFORCE loss.""" 17 | 18 | import torch.nn as nn 19 | 20 | __all__ = ['REINFORCELoss'] 21 | 22 | 23 | class REINFORCELoss(nn.Module): 24 | """Implement the loss function for REINFORCE algorithm.""" 25 | 26 | def __init__(self, entropy_beta=None): 27 | super().__init__() 28 | self.nll = nn.NLLLoss(reduce=False) 29 | self.entropy_beta = entropy_beta 30 | 31 | def forward(self, policy, action, discount_reward, entropy_beta=None): 32 | monitors = dict() 33 | entropy = -(policy * policy.log()).sum(dim=1).mean() 34 | nll = self.nll(policy, action) 35 | loss = (nll * discount_reward).mean() 36 | if entropy_beta is None: 37 | entropy_beta = self.entropy_beta 38 | if entropy_beta is not None: 39 | monitors['reinforce_loss'] = loss 40 | monitors['entropy_loss'] = -entropy * entropy_beta 41 | loss -= entropy * entropy_beta 42 | monitors['entropy'] = entropy 43 | return loss, monitors 44 | -------------------------------------------------------------------------------- /difflogic/thutils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Utility functions for PyTorch.""" 17 | 18 | import torch 19 | import torch.nn.functional as F 20 | 21 | from jactorch.utils.meta import as_float 22 | from jactorch.utils.meta import as_tensor 23 | 24 | __all__ = [ 25 | 'binary_accuracy', 'rms', 'monitor_saturation', 'monitor_paramrms', 26 | 'monitor_gradrms' 27 | ] 28 | 29 | 30 | def binary_accuracy(label, raw_pred, eps=1e-20, return_float=True): 31 | """get accuracy for binary classification problem.""" 32 | pred = as_tensor(raw_pred).squeeze(-1) 33 | pred = (pred > 0.5).float() 34 | label = as_tensor(label).float() 35 | # The $acc is micro accuracy = the correct ones / total 36 | acc = label.eq(pred).float() 37 | 38 | # The $balanced_accuracy is macro accuracy, with class-wide balance. 39 | nr_total = torch.ones( 40 | label.size(), dtype=label.dtype, device=label.device).sum(dim=-1) 41 | nr_pos = label.sum(dim=-1) 42 | nr_neg = nr_total - nr_pos 43 | pos_cnt = (acc * label).sum(dim=-1) 44 | neg_cnt = acc.sum(dim=-1) - pos_cnt 45 | balanced_acc = ((pos_cnt + eps) / (nr_pos + eps) + (neg_cnt + eps) / 46 | (nr_neg + eps)) / 2.0 47 | 48 | # $sat means the saturation rate of the predication, 49 | # measure how close the predections are to 0 or 1. 
50 | sat = 1 - (raw_pred - pred).abs() 51 | if return_float: 52 | acc = as_float(acc.mean()) 53 | balanced_acc = as_float(balanced_acc.mean()) 54 | sat_mean = as_float(sat.mean()) 55 | sat_min = as_float(sat.min()) 56 | else: 57 | sat_mean = sat.mean(dim=-1) 58 | sat_min = sat.min(dim=-1)[0] 59 | 60 | return { 61 | 'accuracy': acc, 62 | 'balanced_accuracy': balanced_acc, 63 | 'satuation/mean': sat_mean, 64 | 'satuation/min': sat_min, 65 | } 66 | 67 | 68 | def rms(p): 69 | """Root mean square function.""" 70 | return as_float((as_tensor(p)**2).mean()**0.5) 71 | 72 | 73 | def monitor_saturation(model): 74 | """Monitor the saturation rate.""" 75 | monitors = {} 76 | for name, p in model.named_parameters(): 77 | p = F.sigmoid(p) 78 | sat = 1 - (p - (p > 0.5).float()).abs() 79 | monitors['sat/' + name] = sat 80 | return monitors 81 | 82 | 83 | def monitor_paramrms(model): 84 | """Monitor the rms of the parameters.""" 85 | monitors = {} 86 | for name, p in model.named_parameters(): 87 | monitors['paramrms/' + name] = rms(p) 88 | return monitors 89 | 90 | 91 | def monitor_gradrms(model): 92 | """Monitor the rms of the gradients of the parameters.""" 93 | monitors = {} 94 | for name, p in model.named_parameters(): 95 | if p.grad is not None: 96 | monitors['gradrms/' + name] = rms(p.grad) / max(rms(p), 1e-8) 97 | return monitors 98 | -------------------------------------------------------------------------------- /difflogic/tqdm_utils.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The utility functions for tqdm."""

from jacinle.utils.tqdm import tqdm_pbar

__all__ = ['tqdm_for']


def tqdm_for(total, func):
  """Run func(i) for each i in range(total) under a progress bar.

  The string returned by func is shown as the bar's description.
  Breaking out of the loop early is not supported.
  """
  with tqdm_pbar(total=total) as pbar:
    for idx in range(total):
      pbar.set_description(func(idx))
      pbar.update()

# ---------------------------------------------------------------------------
# difflogic/train/__init__.py
# (Apache License 2.0 header, identical to the other files in this package.)
16 | 17 | from .train import * 18 | -------------------------------------------------------------------------------- /models/blocksworld.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/blocksworld.pth -------------------------------------------------------------------------------- /models/path.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/path.pth -------------------------------------------------------------------------------- /models/sort.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google/neural-logic-machines/3f8a8966c54d13d2658c77c03793a9a98a283e22/models/sort.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tqdm 3 | -------------------------------------------------------------------------------- /scripts/blocksworld/README.md: -------------------------------------------------------------------------------- 1 | # Blocks World 2 | 3 | ## Command 4 | Please run these commands from the root directory of the project. 5 | 6 | ``` shell 7 | # Train 8 | $ jac-run scripts/blocksworld/learn_policy.py --task final 9 | # Test 10 | $ jac-run scripts/blocksworld/learn_policy.py --task final --test-only --load CHECKPOINT 11 | # add [--test-epoch-size T] to control the number of testing cases. 
12 | ``` 13 | -------------------------------------------------------------------------------- /scripts/graph/README.md: -------------------------------------------------------------------------------- 1 | # Graph Tasks, Shortest Path and Sorting 2 | 3 | A set of graph-related reasoning tasks and a sorting task. 4 | 5 | ## Graph tasks 6 | 7 | ### Family tree tasks 8 | ``` shell 9 | # Train: add --train-number 20 --test-number-begin 20 --test-number-step 20 --test-number-end 100 10 | $ jac-run scripts/graph/learn_graph_tasks.py --task has-father 11 | $ jac-run scripts/graph/learn_graph_tasks.py --task has-sister 12 | $ jac-run scripts/graph/learn_graph_tasks.py --task grandparents --epochs 100 --early-stop 1e-7 13 | $ jac-run scripts/graph/learn_graph_tasks.py --task uncle --epochs 200 --early-stop 1e-8 14 | $ jac-run scripts/graph/learn_graph_tasks.py --task maternal-great-uncle --epochs 20 --epoch-size 2500 --early-stop 1e-8 15 | # Test 16 | $ jac-run scripts/graph/learn_graph_tasks.py --task TASK --test-only --load $CHECKPOINT 17 | ``` 18 | We use `loss < thresh` as the criterion for qualifying models. 
19 | 20 | ### General graph tasks 21 | ``` shell 22 | # Train 23 | # AdjacentToRed 24 | $ jac-run scripts/graph/learn_graph_tasks.py --task adjacent --gen-graph-colors 4 25 | # 4-Connectivity 26 | $ jac-run scripts/graph/learn_graph_tasks.py --task connectivity 27 | # 6-Connectivity 28 | $ jac-run scripts/graph/learn_graph_tasks.py --task connectivity --connectivity-dist-limit 6 --early-stop 1e-6 \ 29 | --nlm-depth 8 --nlm-residual True --gen-graph-pmin 0.1 --gen-graph-pmax 0.3 --gen-graph-method dnc 30 | # 1-Outdegree 31 | $ jac-run scripts/graph/learn_graph_tasks.py --task outdegree --outdegree-n 1 32 | # 2-Outdegree 33 | $ jac-run scripts/graph/learn_graph_tasks.py --task outdegree --outdegree-n 2 \ 34 | --nlm-depth 5 --nlm-breadth 4 --nlm-residual True 35 | 36 | # Test 37 | $ jac-run scripts/graph/learn_graph_tasks.py --task TASK --test-only --load $CHECKPOINT \ 38 | --nlm-depth DEPTH --nlm-breadth BREADTH 39 | ``` 40 | 41 | ### MNIST Input 42 | We modified the `AdjacentToRed` task and replaced the indicator of colors with MNIST digits. The NLM is integrated with LeNet and optimized jointly. 43 | 44 | ``` shell 45 | $ jac-run scripts/graph/learn_graph_tasks.py --task adjacent-mnist \ 46 | --nlm-depth 2 --nlm-breadth 2 --nlm-attributes 16 --nlm-residual True --gen-graph-colors 10 47 | ``` 48 | 49 | ## Shortest path 50 | 51 | Below is a proper set of parameters for this task. 52 | ``` shell 53 | # Train 54 | $ jac-run scripts/graph/learn_policy.py --task path 55 | # Test 56 | $ jac-run scripts/graph/learn_policy.py --task path --test-only --load $CHECKPOINT 57 | ``` 58 | For all available arguments see `jac-run scripts/graph/learn_policy.py -h`. 59 | 60 | ## Sorting 61 | 62 | Below is a proper set of parameters for this task. 
63 | ``` shell 64 | # Train 65 | $ jac-run scripts/graph/learn_policy.py --task sort --nlm-depth 3 --nlm-breadth 2 \ 66 | --curriculum-graduate 10 --entropy-beta 0.01 \ 67 | --mining-epoch-size 200 --mining-dataset-size 20 --mining-interval 2 68 | # Test 69 | $ jac-run scripts/graph/learn_policy.py --task sort --test-only \ 70 | --nlm-depth 3 --nlm-breadth 2 --load $CHECKPOINT 71 | ``` 72 | -------------------------------------------------------------------------------- /scripts/graph/learn_graph_tasks.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """The script for family tree or general graphs experiments.""" 17 | 18 | import copy 19 | import collections 20 | import functools 21 | import os 22 | import json 23 | 24 | import numpy as np 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | import jacinle.random as random 30 | import jacinle.io as io 31 | import jactorch.nn as jacnn 32 | 33 | from difflogic.cli import format_args 34 | from difflogic.dataset.graph import GraphOutDegreeDataset, \ 35 | GraphConnectivityDataset, GraphAdjacentDataset, FamilyTreeDataset 36 | from difflogic.nn.baselines import MemoryNet 37 | from difflogic.nn.neural_logic import LogicMachine, LogicInference, LogitsInference 38 | from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self 39 | from difflogic.nn.rl.reinforce import REINFORCELoss 40 | from difflogic.thutils import binary_accuracy 41 | from difflogic.train import TrainerBase 42 | 43 | from jacinle.cli.argument import JacArgumentParser 44 | from jacinle.logging import get_logger, set_output_file 45 | from jacinle.utils.container import GView 46 | from jacinle.utils.meter import GroupMeters 47 | from jactorch.data.dataloader import JacDataLoader 48 | from jactorch.optim.accum_grad import AccumGrad 49 | from jactorch.optim.quickaccess import get_optimizer 50 | from jactorch.train.env import TrainerEnv 51 | from jactorch.utils.meta import as_cuda, as_numpy, as_tensor 52 | 53 | TASKS = [ 54 | 'outdegree', 'connectivity', 'adjacent', 'adjacent-mnist', 'has-father', 55 | 'has-sister', 'grandparents', 'uncle', 'maternal-great-uncle' 56 | ] 57 | 58 | parser = JacArgumentParser() 59 | 60 | parser.add_argument( 61 | '--model', 62 | default='nlm', 63 | choices=['nlm', 'memnet'], 64 | help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks') 65 | 66 | # NLM parameters, works when model is 'nlm' 67 | nlm_group = parser.add_argument_group('Neural Logic Machines') 68 | LogicMachine.make_nlm_parser( 69 | 
nlm_group, { 70 | 'depth': 4, 71 | 'breadth': 3, 72 | 'exclude_self': True, 73 | 'logic_hidden_dim': [] 74 | }, 75 | prefix='nlm') 76 | nlm_group.add_argument( 77 | '--nlm-attributes', 78 | type=int, 79 | default=8, 80 | metavar='N', 81 | help='number of output attributes in each group of each layer of the LogicMachine' 82 | ) 83 | 84 | # MemNN parameters, works when model is 'memnet' 85 | memnet_group = parser.add_argument_group('Memory Networks') 86 | MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet') 87 | 88 | # task related 89 | task_group = parser.add_argument_group('Task') 90 | task_group.add_argument( 91 | '--task', required=True, choices=TASKS, help='tasks choices') 92 | task_group.add_argument( 93 | '--train-number', 94 | type=int, 95 | default=10, 96 | metavar='N', 97 | help='size of training instances') 98 | task_group.add_argument( 99 | '--adjacent-pred-colors', type=int, default=4, metavar='N') 100 | task_group.add_argument('--outdegree-n', type=int, default=2, metavar='N') 101 | task_group.add_argument( 102 | '--connectivity-dist-limit', type=int, default=4, metavar='N') 103 | 104 | data_gen_group = parser.add_argument_group('Data Generation') 105 | data_gen_group.add_argument( 106 | '--gen-graph-method', 107 | default='edge', 108 | choices=['dnc', 'edge'], 109 | help='method use to generate random graph') 110 | data_gen_group.add_argument( 111 | '--gen-graph-pmin', 112 | type=float, 113 | default=0.0, 114 | metavar='F', 115 | help='control parameter p reflecting the graph sparsity') 116 | data_gen_group.add_argument( 117 | '--gen-graph-pmax', 118 | type=float, 119 | default=0.3, 120 | metavar='F', 121 | help='control parameter p reflecting the graph sparsity') 122 | data_gen_group.add_argument( 123 | '--gen-graph-colors', 124 | type=int, 125 | default=4, 126 | metavar='N', 127 | help='number of colors in adjacent task') 128 | data_gen_group.add_argument( 129 | '--gen-directed', action='store_true', help='directed graph') 130 | 131 | 
train_group = parser.add_argument_group('Train') 132 | train_group.add_argument( 133 | '--seed', 134 | type=int, 135 | default=None, 136 | metavar='SEED', 137 | help='seed of jacinle.random') 138 | train_group.add_argument( 139 | '--use-gpu', action='store_true', help='use GPU or not') 140 | train_group.add_argument( 141 | '--optimizer', 142 | default='AdamW', 143 | choices=['SGD', 'Adam', 'AdamW'], 144 | help='optimizer choices') 145 | train_group.add_argument( 146 | '--lr', 147 | type=float, 148 | default=0.005, 149 | metavar='F', 150 | help='initial learning rate') 151 | train_group.add_argument( 152 | '--lr-decay', 153 | type=float, 154 | default=1.0, 155 | metavar='F', 156 | help='exponential decay of learning rate per lesson') 157 | train_group.add_argument( 158 | '--accum-grad', 159 | type=int, 160 | default=1, 161 | metavar='N', 162 | help='accumulated gradient for batches (default: 1)') 163 | train_group.add_argument( 164 | '--ohem-size', 165 | type=int, 166 | default=0, 167 | metavar='N', 168 | help='size of online hard negative mining') 169 | train_group.add_argument( 170 | '--batch-size', 171 | type=int, 172 | default=4, 173 | metavar='N', 174 | help='batch size for training') 175 | train_group.add_argument( 176 | '--test-batch-size', 177 | type=int, 178 | default=4, 179 | metavar='N', 180 | help='batch size for testing') 181 | train_group.add_argument( 182 | '--early-stop-loss-thresh', 183 | type=float, 184 | default=1e-5, 185 | metavar='F', 186 | help='threshold of loss for early stop') 187 | 188 | # Note that nr_examples_per_epoch = epoch_size * batch_size 189 | TrainerBase.make_trainer_parser( 190 | parser, { 191 | 'epochs': 50, 192 | 'epoch_size': 250, 193 | 'test_epoch_size': 250, 194 | 'test_number_begin': 10, 195 | 'test_number_step': 10, 196 | 'test_number_end': 50, 197 | }) 198 | 199 | io_group = parser.add_argument_group('Input/Output') 200 | io_group.add_argument( 201 | '--dump-dir', type=str, default=None, metavar='DIR', help='dump dir') 
202 | io_group.add_argument( 203 | '--load-checkpoint', 204 | type=str, 205 | default=None, 206 | metavar='FILE', 207 | help='load parameters from checkpoint') 208 | 209 | schedule_group = parser.add_argument_group('Schedule') 210 | schedule_group.add_argument( 211 | '--runs', type=int, default=1, metavar='N', help='number of runs') 212 | schedule_group.add_argument( 213 | '--save-interval', 214 | type=int, 215 | default=10, 216 | metavar='N', 217 | help='the interval(number of epochs) to save checkpoint') 218 | schedule_group.add_argument( 219 | '--test-interval', 220 | type=int, 221 | default=None, 222 | metavar='N', 223 | help='the interval(number of epochs) to do test') 224 | schedule_group.add_argument( 225 | '--test-only', action='store_true', help='test-only mode') 226 | 227 | logger = get_logger(__file__) 228 | 229 | args = parser.parse_args() 230 | 231 | args.use_gpu = args.use_gpu and torch.cuda.is_available() 232 | 233 | if args.dump_dir is not None: 234 | io.mkdir(args.dump_dir) 235 | args.log_file = os.path.join(args.dump_dir, 'log.log') 236 | set_output_file(args.log_file) 237 | else: 238 | args.checkpoints_dir = None 239 | args.summary_file = None 240 | 241 | if args.seed is not None: 242 | import jacinle.random as random 243 | random.reset_global_seed(args.seed) 244 | 245 | args.task_is_outdegree = args.task in ['outdegree'] 246 | args.task_is_connectivity = args.task in ['connectivity'] 247 | args.task_is_adjacent = args.task in ['adjacent', 'adjacent-mnist'] 248 | args.task_is_family_tree = args.task in [ 249 | 'has-father', 'has-sister', 'grandparents', 'uncle', 'maternal-great-uncle' 250 | ] 251 | args.task_is_mnist_input = args.task in ['adjacent-mnist'] 252 | args.task_is_1d_output = args.task in [ 253 | 'outdegree', 'adjacent', 'adjacent-mnist', 'has-father', 'has-sister' 254 | ] 255 | 256 | 257 | class LeNet(nn.Module): 258 | 259 | def __init__(self): 260 | super().__init__() 261 | self.conv1 = jacnn.Conv2dLayer( 262 | 1, 10, kernel_size=5, 
batch_norm=True, activation='relu') 263 | self.conv2 = jacnn.Conv2dLayer( 264 | 10, 265 | 20, 266 | kernel_size=5, 267 | batch_norm=True, 268 | dropout=False, 269 | activation='relu') 270 | self.fc1 = nn.Linear(320, 50) 271 | self.fc2 = nn.Linear(50, 10) 272 | 273 | def forward(self, x): 274 | x = F.max_pool2d(self.conv1(x), 2) 275 | x = F.max_pool2d(self.conv2(x), 2) 276 | x = x.view(-1, 320) 277 | x = F.relu(self.fc1(x)) 278 | x = self.fc2(x) 279 | return x 280 | 281 | 282 | class Model(nn.Module): 283 | """The model for family tree or general graphs path tasks.""" 284 | 285 | def __init__(self): 286 | super().__init__() 287 | 288 | # inputs 289 | input_dim = 4 if args.task_is_family_tree else 1 290 | self.feature_axis = 1 if args.task_is_1d_output else 2 291 | 292 | # features 293 | if args.model == 'nlm': 294 | input_dims = [0 for _ in range(args.nlm_breadth + 1)] 295 | if args.task_is_adjacent: 296 | input_dims[1] = args.gen_graph_colors 297 | if args.task_is_mnist_input: 298 | self.lenet = LeNet() 299 | input_dims[2] = input_dim 300 | 301 | self.features = LogicMachine.from_args( 302 | input_dims, args.nlm_attributes, args, prefix='nlm') 303 | output_dim = self.features.output_dims[self.feature_axis] 304 | 305 | elif args.model == 'memnet': 306 | if args.task_is_adjacent: 307 | input_dim += args.gen_graph_colors 308 | self.feature = MemoryNet.from_args( 309 | input_dim, self.feature_axis, args, prefix='memnet') 310 | output_dim = self.feature.get_output_dim() 311 | 312 | # target 313 | target_dim = args.adjacent_pred_colors if args.task_is_adjacent else 1 314 | self.pred = LogicInference(output_dim, target_dim, []) 315 | 316 | # losses 317 | if args.ohem_size > 0: 318 | from jactorch.nn.losses import BinaryCrossEntropyLossWithProbs as BCELoss 319 | self.loss = BCELoss(average='none') 320 | else: 321 | self.loss = nn.BCELoss() 322 | 323 | def forward(self, feed_dict): 324 | feed_dict = GView(feed_dict) 325 | 326 | # properties 327 | if args.task_is_adjacent: 
328 | states = feed_dict.states.float() 329 | else: 330 | states = None 331 | 332 | # relations 333 | relations = feed_dict.relations.float() 334 | batch_size, nr = relations.size()[:2] 335 | 336 | if args.model == 'nlm': 337 | if args.task_is_adjacent and args.task_is_mnist_input: 338 | states_shape = states.size() 339 | states = states.view((-1,) + states_shape[2:]) 340 | states = self.lenet(states) 341 | states = states.view(states_shape[:2] + (-1,)) 342 | states = F.sigmoid(states) 343 | 344 | inp = [None for _ in range(args.nlm_breadth + 1)] 345 | inp[1] = states 346 | inp[2] = relations 347 | 348 | depth = None 349 | if args.nlm_recursion: 350 | depth = 1 351 | while 2**depth + 1 < nr: 352 | depth += 1 353 | depth = depth * 2 + 1 354 | feature = self.features(inp, depth=depth)[self.feature_axis] 355 | elif args.model == 'memnet': 356 | feature = self.feature(relations, states) 357 | if args.task_is_adjacent and args.task_is_mnist_input: 358 | raise NotImplementedError() 359 | 360 | pred = self.pred(feature) 361 | if not args.task_is_adjacent: 362 | pred = pred.squeeze(-1) 363 | if args.task_is_connectivity: 364 | pred = meshgrid_exclude_self(pred) # exclude self-cycle 365 | 366 | if self.training: 367 | monitors = dict() 368 | target = feed_dict.target.float() 369 | 370 | if args.task_is_adjacent: 371 | target = target[:, :, :args.adjacent_pred_colors] 372 | 373 | monitors.update(binary_accuracy(target, pred, return_float=False)) 374 | 375 | loss = self.loss(pred, target) 376 | # ohem loss is unused. 
377 | if args.ohem_size > 0: 378 | loss = loss.view(-1).topk(args.ohem_size)[0].mean() 379 | return loss, monitors, dict(pred=pred) 380 | else: 381 | return dict(pred=pred) 382 | 383 | 384 | def make_dataset(n, epoch_size, is_train): 385 | pmin, pmax = args.gen_graph_pmin, args.gen_graph_pmax 386 | if args.task_is_outdegree: 387 | return GraphOutDegreeDataset( 388 | args.outdegree_n, 389 | epoch_size, 390 | n, 391 | pmin=pmin, 392 | pmax=pmax, 393 | directed=args.gen_directed, 394 | gen_method=args.gen_graph_method) 395 | elif args.task_is_connectivity: 396 | nmin, nmax = n, n 397 | if is_train and args.nlm_recursion: 398 | nmin = 2 399 | return GraphConnectivityDataset( 400 | args.connectivity_dist_limit, 401 | epoch_size, 402 | nmin, 403 | pmin, 404 | nmax, 405 | pmax, 406 | directed=args.gen_directed, 407 | gen_method=args.gen_graph_method) 408 | elif args.task_is_adjacent: 409 | return GraphAdjacentDataset( 410 | args.gen_graph_colors, 411 | epoch_size, 412 | n, 413 | pmin=pmin, 414 | pmax=pmax, 415 | directed=args.gen_directed, 416 | gen_method=args.gen_graph_method, 417 | is_train=is_train, 418 | is_mnist_colors=args.task_is_mnist_input) 419 | else: 420 | return FamilyTreeDataset(args.task, epoch_size, n, p_marriage=1.0) 421 | 422 | 423 | class MyTrainer(TrainerBase): 424 | def save_checkpoint(self, name): 425 | if args.checkpoints_dir is not None: 426 | checkpoint_file = os.path.join(args.checkpoints_dir, 427 | 'checkpoint_{}.pth'.format(name)) 428 | super().save_checkpoint(checkpoint_file) 429 | 430 | def _dump_meters(self, meters, mode): 431 | if args.summary_file is not None: 432 | meters_kv = meters._canonize_values('avg') 433 | meters_kv['mode'] = mode 434 | meters_kv['epoch'] = self.current_epoch 435 | with open(args.summary_file, 'a') as f: 436 | f.write(io.dumps_json(meters_kv)) 437 | f.write('\n') 438 | 439 | data_iterator = {} 440 | 441 | def _prepare_dataset(self, epoch_size, mode): 442 | assert mode in ['train', 'test'] 443 | if mode == 'train': 
444 | batch_size = args.batch_size 445 | number = args.train_number 446 | else: 447 | batch_size = args.test_batch_size 448 | number = self.test_number 449 | 450 | # The actual number of instances in an epoch is epoch_size * batch_size. 451 | dataset = make_dataset(number, epoch_size * batch_size, mode == 'train') 452 | dataloader = JacDataLoader( 453 | dataset, 454 | shuffle=True, 455 | batch_size=batch_size, 456 | num_workers=min(epoch_size, 4)) 457 | self.data_iterator[mode] = dataloader.__iter__() 458 | 459 | def _get_data(self, index, meters, mode): 460 | feed_dict = self.data_iterator[mode].next() 461 | meters.update(number=feed_dict['n'].data.numpy().mean()) 462 | if args.use_gpu: 463 | feed_dict = as_cuda(feed_dict) 464 | return feed_dict 465 | 466 | def _get_result(self, index, meters, mode): 467 | feed_dict = self._get_data(index, meters, mode) 468 | output_dict = self.model(feed_dict) 469 | 470 | target = feed_dict['target'] 471 | if args.task_is_adjacent: 472 | target = target[:, :, :args.adjacent_pred_colors] 473 | result = binary_accuracy(target, output_dict['pred']) 474 | succ = result['accuracy'] == 1.0 475 | 476 | meters.update(succ=succ) 477 | meters.update(result, n=target.size(0)) 478 | message = '> {} iter={iter}, accuracy={accuracy:.4f}, \ 479 | balance_acc={balanced_accuracy:.4f}'.format( 480 | mode, iter=index, **meters.val) 481 | return message, dict(succ=succ, feed_dict=feed_dict) 482 | 483 | def _get_train_data(self, index, meters): 484 | return self._get_data(index, meters, mode='train') 485 | 486 | def _train_epoch(self, epoch_size): 487 | meters = super()._train_epoch(epoch_size) 488 | 489 | i = self.current_epoch 490 | if args.save_interval is not None and i % args.save_interval == 0: 491 | self.save_checkpoint(str(i)) 492 | if args.test_interval is not None and i % args.test_interval == 0: 493 | self.test() 494 | return meters 495 | 496 | def _early_stop(self, meters): 497 | return meters.avg['loss'] < args.early_stop_loss_thresh 498 
| 499 | 500 | def main(run_id): 501 | if args.dump_dir is not None: 502 | if args.runs > 1: 503 | args.current_dump_dir = os.path.join(args.dump_dir, 504 | 'run_{}'.format(run_id)) 505 | io.mkdir(args.current_dump_dir) 506 | else: 507 | args.current_dump_dir = args.dump_dir 508 | 509 | args.summary_file = os.path.join(args.current_dump_dir, 'summary.json') 510 | args.checkpoints_dir = os.path.join(args.current_dump_dir, 'checkpoints') 511 | io.mkdir(args.checkpoints_dir) 512 | 513 | logger.info(format_args(args)) 514 | 515 | model = Model() 516 | if args.use_gpu: 517 | model.cuda() 518 | optimizer = get_optimizer(args.optimizer, model, args.lr) 519 | if args.accum_grad > 1: 520 | optimizer = AccumGrad(optimizer, args.accum_grad) 521 | trainer = MyTrainer.from_args(model, optimizer, args) 522 | 523 | if args.load_checkpoint is not None: 524 | trainer.load_checkpoint(args.load_checkpoint) 525 | 526 | if args.test_only: 527 | return None, trainer.test() 528 | 529 | final_meters = trainer.train() 530 | trainer.save_checkpoint('last') 531 | 532 | return trainer.early_stopped, trainer.test() 533 | 534 | 535 | if __name__ == '__main__': 536 | stats = [] 537 | nr_graduated = 0 538 | 539 | for i in range(args.runs): 540 | graduated, test_meters = main(i) 541 | logger.info('run {}'.format(i + 1)) 542 | 543 | if test_meters is not None: 544 | for j, meters in enumerate(test_meters): 545 | if len(stats) <= j: 546 | stats.append(GroupMeters()) 547 | stats[j].update( 548 | number=meters.avg['number'], test_acc=meters.avg['accuracy']) 549 | 550 | for meters in stats: 551 | logger.info('number {}, test_acc {}'.format(meters.avg['number'], 552 | meters.avg['test_acc'])) 553 | 554 | if not args.test_only: 555 | nr_graduated += int(graduated) 556 | logger.info('graduate_ratio {}'.format(nr_graduated / (i + 1))) 557 | if graduated: 558 | for j, meters in enumerate(test_meters): 559 | stats[j].update(grad_test_acc=meters.avg['accuracy']) 560 | if nr_graduated > 0: 561 | for meters in 
stats: 562 | logger.info('number {}, grad_test_acc {}'.format( 563 | meters.avg['number'], meters.avg['grad_test_acc'])) 564 | -------------------------------------------------------------------------------- /scripts/graph/learn_policy.py: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env python3 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """The script for sorting or shortest path experiments.""" 17 | 18 | import collections 19 | import copy 20 | import functools 21 | import json 22 | import os 23 | 24 | import numpy as np 25 | import torch 26 | import torch.nn as nn 27 | import torch.nn.functional as F 28 | 29 | import jacinle.random as random 30 | import jacinle.io as io 31 | 32 | from difflogic.cli import format_args 33 | from difflogic.nn.baselines import MemoryNet 34 | from difflogic.nn.neural_logic import LogicMachine, LogitsInference 35 | from difflogic.nn.neural_logic.modules._utils import meshgrid_exclude_self 36 | from difflogic.nn.rl.reinforce import REINFORCELoss 37 | from difflogic.train import MiningTrainerBase 38 | 39 | from jacinle.cli.argument import JacArgumentParser 40 | from jacinle.logging import get_logger, set_output_file 41 | from jacinle.utils.container import GView 42 | from jacinle.utils.meter import GroupMeters 43 | from jactorch.optim.accum_grad import AccumGrad 44 | from jactorch.optim.quickaccess 
import get_optimizer 45 | from jactorch.utils.meta import as_cuda, as_numpy, as_tensor 46 | 47 | parser = JacArgumentParser() 48 | 49 | parser.add_argument( 50 | '--model', 51 | default='nlm', 52 | choices=['nlm', 'memnet'], 53 | help='model choices, nlm: Neural Logic Machine, memnet: Memory Networks') 54 | 55 | # NLM parameters, works when model is 'nlm'. 56 | nlm_group = parser.add_argument_group('Neural Logic Machines') 57 | LogicMachine.make_nlm_parser( 58 | nlm_group, { 59 | 'depth': 5, 60 | 'breadth': 3, 61 | 'residual': True, 62 | 'exclude_self': True, 63 | 'logic_hidden_dim': [] 64 | }, 65 | prefix='nlm') 66 | nlm_group.add_argument( 67 | '--nlm-attributes', 68 | type=int, 69 | default=8, 70 | metavar='N', 71 | help='number of output attributes in each group of each layer of the LogicMachine' 72 | ) 73 | 74 | # MemNN parameters, works when model is 'memnet'. 75 | memnet_group = parser.add_argument_group('Memory Networks') 76 | MemoryNet.make_memnet_parser(memnet_group, {}, prefix='memnet') 77 | 78 | parser.add_argument( 79 | '--task', required=True, choices=['sort', 'path'], help='tasks choices') 80 | 81 | data_gen_group = parser.add_argument_group('Data Generation') 82 | data_gen_group.add_argument( 83 | '--gen-method', 84 | default='dnc', 85 | choices=['dnc', 'edge'], 86 | help='method use to generate random graph') 87 | data_gen_group.add_argument( 88 | '--gen-graph-pmin', 89 | type=float, 90 | default=0.3, 91 | metavar='F', 92 | help='control parameter p reflecting the graph sparsity') 93 | data_gen_group.add_argument( 94 | '--gen-graph-pmax', 95 | type=float, 96 | default=0.3, 97 | metavar='F', 98 | help='control parameter p reflecting the graph sparsity') 99 | data_gen_group.add_argument( 100 | '--gen-max-len', 101 | type=int, 102 | default=5, 103 | metavar='N', 104 | help='maximum length of shortest path during training') 105 | data_gen_group.add_argument( 106 | '--gen-test-len', 107 | type=int, 108 | default=4, 109 | metavar='N', 110 | help='length 
of shortest path during testing') 111 | data_gen_group.add_argument( 112 | '--gen-directed', action='store_true', help='directed graph') 113 | 114 | MiningTrainerBase.make_trainer_parser( 115 | parser, { 116 | 'epochs': 400, 117 | 'epoch_size': 100, 118 | 'test_epoch_size': 1000, 119 | 'test_number_begin': 10, 120 | 'test_number_step': 10, 121 | 'test_number_end': 50, 122 | 'curriculum_start': 3, 123 | 'curriculum_step': 1, 124 | 'curriculum_graduate': 12, 125 | 'curriculum_thresh_relax': 0.005, 126 | 'sample_array_capacity': 3, 127 | 'enable_mining': True, 128 | 'mining_interval': 6, 129 | 'mining_epoch_size': 3000, 130 | 'mining_dataset_size': 300, 131 | 'inherit_neg_data': True, 132 | 'prob_pos_data': 0.5 133 | }) 134 | 135 | train_group = parser.add_argument_group('Train') 136 | train_group.add_argument('--seed', type=int, default=None, metavar='SEED') 137 | train_group.add_argument( 138 | '--use-gpu', action='store_true', help='use GPU or not') 139 | train_group.add_argument( 140 | '--optimizer', 141 | default='AdamW', 142 | choices=['SGD', 'Adam', 'AdamW'], 143 | help='optimizer choices') 144 | train_group.add_argument( 145 | '--lr', 146 | type=float, 147 | default=0.005, 148 | metavar='F', 149 | help='initial learning rate') 150 | train_group.add_argument( 151 | '--lr-decay', 152 | type=float, 153 | default=0.9, 154 | metavar='F', 155 | help='exponential decay of learning rate per lesson') 156 | train_group.add_argument( 157 | '--accum-grad', 158 | type=int, 159 | default=1, 160 | metavar='N', 161 | help='accumulated gradient (default: 1)') 162 | train_group.add_argument( 163 | '--candidate-relax', 164 | type=int, 165 | default=0, 166 | metavar='N', 167 | help='number of thresh relaxation for candidate') 168 | 169 | rl_group = parser.add_argument_group('Reinforcement Learning') 170 | rl_group.add_argument( 171 | '--gamma', 172 | type=float, 173 | default=0.99, 174 | metavar='F', 175 | help='discount factor for accumulated reward function in reinforcement 
learning' 176 | ) 177 | rl_group.add_argument( 178 | '--penalty', 179 | type=float, 180 | default=-0.01, 181 | metavar='F', 182 | help='a small penalty each step') 183 | rl_group.add_argument( 184 | '--entropy-beta', 185 | type=float, 186 | default=0.1, 187 | metavar='F', 188 | help='entropy loss scaling factor') 189 | rl_group.add_argument( 190 | '--entropy-beta-decay', 191 | type=float, 192 | default=0.8, 193 | metavar='F', 194 | help='entropy beta exponential decay factor') 195 | 196 | io_group = parser.add_argument_group('Input/Output') 197 | io_group.add_argument( 198 | '--dump-dir', default=None, metavar='DIR', help='dump dir') 199 | io_group.add_argument( 200 | '--dump-play', 201 | action='store_true', 202 | help='dump the trajectory of the plays for visualization') 203 | io_group.add_argument( 204 | '--dump-fail-only', action='store_true', help='dump failure cases only') 205 | io_group.add_argument( 206 | '--load-checkpoint', 207 | default=None, 208 | metavar='FILE', 209 | help='load parameters from checkpoint') 210 | 211 | schedule_group = parser.add_argument_group('Schedule') 212 | schedule_group.add_argument( 213 | '--runs', type=int, default=1, metavar='N', help='number of runs') 214 | schedule_group.add_argument( 215 | '--early-drop-epochs', 216 | type=int, 217 | default=40, 218 | metavar='N', 219 | help='epochs could spend for each lesson, early drop') 220 | schedule_group.add_argument( 221 | '--save-interval', 222 | type=int, 223 | default=10, 224 | metavar='N', 225 | help='the interval(number of epochs) to save checkpoint') 226 | schedule_group.add_argument( 227 | '--test-interval', 228 | type=int, 229 | default=None, 230 | metavar='N', 231 | help='the interval(number of epochs) to do test') 232 | schedule_group.add_argument( 233 | '--test-only', action='store_true', help='test-only mode') 234 | schedule_group.add_argument( 235 | '--test-not-graduated', 236 | action='store_true', 237 | help='test not graduated models also') 238 | 239 | args = 
class Model(nn.Module):
  """The model for sorting or shortest path tasks.

  Produces a policy over actions from the symbolic state/relation tensors,
  using either a Neural Logic Machine (--model nlm) or a Memory Network
  (--model memnet) as the feature extractor.
  """

  def __init__(self):
    super().__init__()

    # Axis of the features the policy head reads: unary (per-node) features
    # for the path task, binary (per-pair) features for the sorting task.
    self.feature_axis = 1 if args.is_path_task else 2
    if args.model == 'memnet':
      current_dim = 4 if args.is_path_task else 6
      self.feature = MemoryNet.from_args(
          current_dim, self.feature_axis, args, prefix='memnet')
      current_dim = self.feature.get_output_dim()
    else:
      input_dims = [0] * (args.nlm_breadth + 1)
      if args.is_path_task:
        input_dims[1] = 2  # per-node source/destination markers
        input_dims[2] = 2  # edge relation and its transpose
      elif args.is_sort_task:
        input_dims[2] = 6  # pairwise comparison features

      self.features = LogicMachine.from_args(
          input_dims, args.nlm_attributes, args, prefix='nlm')
      if args.is_path_task:
        current_dim = self.features.output_dims[1]
      elif args.is_sort_task:
        # Was `args.task == 'sort'`; use the precomputed flag for
        # consistency with the rest of the module (same condition).
        current_dim = self.features.output_dims[2]

    self.pred = LogitsInference(current_dim, 1, [])
    self.loss = REINFORCELoss()
    self.pred_loss = nn.BCELoss()

  def forward(self, feed_dict):
    """Compute the action policy (eval) or the REINFORCE loss (train).

    Args:
      feed_dict: dict-like with 'states' (and 'relations' for the path
        task); in training mode also 'actions', 'discount_rewards' and
        'entropy_beta'.

    Returns:
      (loss, monitors, {}) in training mode, otherwise
      dict(policy=..., logits=...).
    """
    feed_dict = GView(feed_dict)
    states = None
    if args.is_path_task:
      states = feed_dict.states.float()
      relations = feed_dict.relations.float()
    elif args.is_sort_task:
      # For sorting, the pairwise comparison matrix plays the role of
      # the (binary) relations; there are no unary inputs.
      relations = feed_dict.states.float()

    def get_features(states, relations, depth=None):
      # Assemble the NLM input list: index = arity of the predicate group.
      inp = [None] * (args.nlm_breadth + 1)
      inp[1] = states
      inp[2] = relations
      features = self.features(inp, depth=depth)
      return features

    if args.model == 'memnet':
      f = self.feature(relations, states)
    else:
      f = get_features(states, relations)[self.feature_axis]
      if self.feature_axis == 2:  # sorting task: pairs (i, i) are invalid
        f = meshgrid_exclude_self(f)

    logits = self.pred(f).squeeze(dim=-1).view(relations.size(0), -1)
    # Set minimal value to avoid loss to be nan.
    policy = F.softmax(logits, dim=-1).clamp(min=1e-20)

    if self.training:
      loss, monitors = self.loss(policy, feed_dict.actions,
                                 feed_dict.discount_rewards,
                                 feed_dict.entropy_beta)
      return loss, monitors, dict()
    else:
      return dict(policy=policy, logits=logits)
def run_episode(env,
                model,
                number,
                play_name='',
                dump=False,
                eval_only=False,
                use_argmax=False,
                need_restart=False,
                entropy_beta=0.0):
  """Run one episode using the model with $number nodes/numbers.

  Args:
    env: the (wrapped) environment; must expose current_state, action(),
      restart() and the task-specific attributes on env.unwrapped.
    model: policy model; returns dict with 'policy'/'logits' in eval mode.
    number: number of nodes/numbers for this episode. NOTE: not read in this
      body (the env is already built for this size); kept in the signature
      for caller compatibility.
    play_name: name prefix for the dumped visualization json file.
    dump: if True (together with --dump-play), record the play to json.
    eval_only: disable gradient tracking during the forward passes.
    use_argmax: act greedily instead of sampling from the policy.
    need_restart: restart the environment before playing.
    entropy_beta: entropy coefficient forwarded to the model's feed dict.

  Returns:
    Tuple (succ, score, traj, length, optimal).
  """
  is_over = False
  traj = collections.defaultdict(list)
  score = 0
  moves = []
  # If dump_play=True, store the states and actions in a json file
  # for visualization.
  dump_play = args.dump_play and dump

  if need_restart:
    env.restart()

  if args.is_path_task:
    optimal = env.unwrapped.dist
    relation = env.unwrapped.graph.get_edges()
    # Stack the adjacency matrix with its transpose so both edge
    # directions are available as input predicates.
    relation = np.stack([relation, relation.T], axis=-1)
    st, ed = env.current_state
    nodes_trajectory = [int(st)]
    destination = int(ed)
    policies = []
  elif args.is_sort_task:
    optimal = env.unwrapped.optimal
    array = [str(i) for i in env.unwrapped.array]

  while not is_over:
    if args.is_path_task:
      st, ed = env.current_state
      # One-hot markers: column 0 flags the current node, column 1 the goal.
      state = np.zeros((relation.shape[0], 2))
      state[st, 0] = 1
      state[ed, 1] = 1
      feed_dict = dict(states=np.array([state]), relations=np.array([relation]))
    elif args.is_sort_task:
      state = env.current_state
      feed_dict = dict(states=np.array([state]))
    feed_dict['entropy_beta'] = as_tensor(entropy_beta).float()
    feed_dict = as_tensor(feed_dict)
    if args.use_gpu:
      feed_dict = as_cuda(feed_dict)

    with torch.set_grad_enabled(not eval_only):
      output_dict = model(feed_dict)

    policy = output_dict['policy']
    p = as_numpy(policy.data[0])
    action = p.argmax() if use_argmax else random.choice(len(p), p=p)
    reward, is_over = env.action(action)

    # collect moves information
    if dump_play:
      if args.is_path_task:
        moves.append(int(action))
        nodes_trajectory.append(int(env.current_state[0]))
        logits = as_numpy(output_dict['logits'].data[0])
        # Keep the 10 highest-probability actions for visualization.
        tops = np.argsort(p)[-10:][::-1]
        tops = list(
            map(lambda x: (int(x), float(p[x]), float(logits[x])), tops))
        policies.append(tops)
      elif args.is_sort_task:  # tasks are mutually exclusive (--task choices)
        # Need to ensure that env.utils.MapActionProxy is the outermost class.
        mapped_x, mapped_y = env.mapping[action]
        moves.append([mapped_x, mapped_y])

    # For now, assume reward=1 only when succeed, otherwise reward=0.
    # Manipulate the reward and get success information according to reward.
    if reward == 0 and args.penalty is not None:
      reward = args.penalty
    succ = 1 if is_over and reward > 0.99 else 0

    score += reward
    traj['states'].append(state)
    if args.is_path_task:
      traj['relations'].append(relation)
    traj['rewards'].append(reward)
    traj['actions'].append(action)

  # dump json file storing information of playing
  if dump_play and not (args.dump_fail_only and succ):
    if args.is_path_task:
      num = env.unwrapped.nr_nodes
      graph = relation[:, :, 0].tolist()
      coordinates = env.unwrapped.graph.get_coordinates().tolist()
      json_str = json.dumps(
          dict(
              graph=graph,
              coordinates=coordinates,
              policies=policies,
              destination=destination,
              current=nodes_trajectory,
              moves=moves))
    elif args.is_sort_task:
      num = env.unwrapped.nr_numbers
      json_str = json.dumps(dict(array=array, moves=moves))
    dump_file = os.path.join(args.current_dump_dir,
                             '{}_size{}.json'.format(play_name, num))
    with open(dump_file, 'w') as f:
      f.write(json_str)

  length = len(traj['rewards'])
  return succ, score, traj, length, optimal
class MyTrainer(MiningTrainerBase):
  """Curriculum/mining trainer specialized for the sort and path RL tasks."""

  def save_checkpoint(self, name):
    """Save a checkpoint named checkpoint_<name>.pth under checkpoints_dir."""
    if args.checkpoints_dir is not None:
      checkpoint_file = os.path.join(args.checkpoints_dir,
                                     'checkpoint_{}.pth'.format(name))
      super().save_checkpoint(checkpoint_file)

  def _dump_meters(self, meters, mode):
    # Append one JSON line (averaged meter values + mode/epoch) per call.
    if args.summary_file is not None:
      meters_kv = meters._canonize_values('avg')
      meters_kv['mode'] = mode
      meters_kv['epoch'] = self.current_epoch
      with open(args.summary_file, 'a') as f:
        f.write(io.dumps_json(meters_kv))
        f.write('\n')

  def _prepare_dataset(self, epoch_size, mode):
    # Episodes are generated on the fly by the environment; nothing to do.
    pass

  def _get_player(self, number, mode):
    """Create and restart an environment with $number nodes/numbers."""
    if args.is_path_task:
      # Fix the path length during testing; sample it during training.
      test_len = args.gen_test_len
      dist_range = (test_len, test_len) if mode == 'test' \
          else (1, args.gen_max_len)
      player = make_env(args.task, number, dist_range=dist_range)
    else:
      player = make_env(args.task, number)
    player.restart()
    return player

  def _get_result_given_player(self, index, meters, number, player, mode):
    """Play one episode; return a feed dict (train) or a log message (other)."""
    assert mode in ['train', 'test', 'mining', 'inherit']
    params = dict(
        eval_only=True,
        number=number,
        play_name='{}_epoch{}_episode{}'.format(mode, self.current_epoch,
                                                index))
    backup = None
    if mode == 'train':
      params['eval_only'] = False
      params['entropy_beta'] = self.entropy_beta
      meters.update(lr=self.lr, entropy_beta=self.entropy_beta)
    elif mode == 'test':
      params['dump'] = True
      params['use_argmax'] = True
    else:
      # Mining/inherit modes may mutate the player; keep a copy to return.
      backup = copy.deepcopy(player)
      params['use_argmax'] = self.is_candidate
    succ, score, traj, length, optimal = \
        run_episode(player, self.model, **params)
    meters.update(
        number=number, succ=succ, score=score, length=length, optimal=optimal)

    if mode == 'train':
      feed_dict = make_data(traj, args.gamma)
      feed_dict['entropy_beta'] = as_tensor(self.entropy_beta).float()

      if args.use_gpu:
        feed_dict = as_cuda(feed_dict)
      return feed_dict
    else:
      # Implicit string concatenation instead of a backslash continuation
      # inside the literal, which spliced raw source text into the message.
      message = ('> {} iter={iter}, number={number}, succ={succ}, '
                 'score={score:.4f}, length={length}, optimal={optimal}'
                ).format(mode, iter=index, **meters.val)
      return message, dict(succ=succ, number=number, backup=backup)

  def _extract_info(self, extra):
    return extra['succ'], extra['number'], extra['backup']

  def _get_accuracy(self, meters):
    # Success rate is the curriculum's notion of accuracy.
    return meters.avg['succ']

  def _get_threshold(self):
    # NOTE(review): the relaxation is applied only for NON-candidates
    # (candidate_relax == 0 when self.is_candidate) — confirm this
    # inversion is intended.
    candidate_relax = 0 if self.is_candidate else args.candidate_relax
    return super()._get_threshold() - \
        self.curriculum_thresh_relax * candidate_relax

  def _upgrade_lesson(self):
    super()._upgrade_lesson()
    # Adjust lr & entropy_beta w.r.t different lesson progressively.
    self.lr *= args.lr_decay
    self.entropy_beta *= args.entropy_beta_decay
    self.set_learning_rate(self.lr)

  def _train_epoch(self, epoch_size):
    meters = super()._train_epoch(epoch_size)

    i = self.current_epoch
    if args.save_interval is not None and i % args.save_interval == 0:
      self.save_checkpoint(str(i))
    if args.test_interval is not None and i % args.test_interval == 0:
      self.test()

    return meters

  def _early_stop(self, meters):
    # Drop a lesson that has consumed more than its epoch budget.
    t = args.early_drop_epochs
    if t is not None and self.current_epoch > t * (self.nr_upgrades + 1):
      return True
    return super()._early_stop(meters)

  def train(self):
    # Seed the per-lesson schedules from the CLI values, then train.
    self.lr = args.lr
    self.entropy_beta = args.entropy_beta
    return super().train()
def main(run_id):
  """Execute a single run: build, (optionally) train, and test.

  Args:
    run_id: zero-based index of this run, used to name the per-run
      dump directory when --runs > 1.

  Returns:
    Tuple (graduated, test_meters). graduated is None in --test-only mode;
    test_meters is None when the model did not graduate and
    --test-not-graduated is unset.
  """
  if args.dump_dir is not None:
    if args.runs > 1:
      run_dump_dir = os.path.join(args.dump_dir, 'run_{}'.format(run_id))
      io.mkdir(run_dump_dir)
    else:
      run_dump_dir = args.dump_dir
    args.current_dump_dir = run_dump_dir
    args.checkpoints_dir = os.path.join(run_dump_dir, 'checkpoints')
    io.mkdir(args.checkpoints_dir)
    args.summary_file = os.path.join(run_dump_dir, 'summary.json')

  logger.info(format_args(args))

  # Build model and optimization stack.
  model = Model()
  if args.use_gpu:
    model.cuda()
  optimizer = get_optimizer(args.optimizer, model, args.lr)
  if args.accum_grad > 1:
    optimizer = AccumGrad(optimizer, args.accum_grad)

  trainer = MyTrainer.from_args(model, optimizer, args)
  if args.load_checkpoint is not None:
    trainer.load_checkpoint(args.load_checkpoint)

  if args.test_only:
    trainer.current_epoch = 0
    return None, trainer.test()

  graduated = trainer.train()
  trainer.save_checkpoint('last')

  # Only evaluate graduated models unless explicitly requested otherwise.
  if graduated or args.test_not_graduated:
    test_meters = trainer.test()
  else:
    test_meters = None
  return graduated, test_meters
6 | 7 | ``` 8 | $ cd ROOT_DIR_OF_PEOJECT 9 | $ python2 -m SimpleHTTPServer -p PORT (default 8000) 10 | Open localhost:8000 in browser, and open the corresponding html file 11 | click "run" button. 12 | ``` 13 | 14 | To visualize another play, edit the value of "DATA_URL" in the html file. 15 | 16 | htmls: [sort, path, blocksworld] 17 | -------------------------------------------------------------------------------- /vis/blocksworld.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 106 | 107 | 108 | 109 | 110 | 111 |
112 | 115 |
116 |
World 0 (Source)
117 |
118 | 119 |
120 | 121 |
122 |
123 |
World 1 (Target)
124 |
125 | 126 |
127 |
128 |
129 | 130 |
131 | Step: 132 | ? 133 |
134 |
135 |
136 | Move (?, ?) 137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 | 145 |
146 |
Inputs: Comparison Matrix
147 |
148 |
149 | 150 |
World 0 : World 0
151 |
152 |
153 | 154 |
World 0 : World 1
155 |
156 |
157 | 158 |
World 1 : World 1
159 |
160 |
161 |
162 |
163 | 164 | 166 | 167 | 168 | 169 | 438 | 439 | 440 | 441 | -------------------------------------------------------------------------------- /vis/path.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | Shortest path 5 | 102 | 103 | 104 | 105 | 106 |
107 | 108 | 109 |
110 |

Shortest Path

111 |
    112 |
113 |
114 |
115 | 116 | 117 |
118 | 119 |
120 | 121 | 122 | 123 | 426 | 427 | 428 | 429 | -------------------------------------------------------------------------------- /vis/path.json: -------------------------------------------------------------------------------- 1 | {"coordinates": [[0.12631716862509446, 0.6855076118490626], [0.5544261529837252, 0.30608487473501045], [0.5005915892542612, 0.9580584987550406], [0.01512028143186761, 0.42942676493276843], [0.47788297550968006, 0.9849410778207275], [0.9101169513951646, 0.5094584487376643], [0.839880593710889, 0.8820075046113616], [0.4047856477323437, 0.9986184147793395], [0.20806824905298538, 0.27702184560356324], [0.2305353902874855, 0.3449516438448228], [0.9751901145157092, 0.306683855099085], [0.26724241676244476, 0.47029149231329936], [0.2645407204872413, 0.058618971772641415], [0.9940535448570998, 0.4819384171630231], [0.7654864949943191, 0.7115964727441639], [0.9421460493475816, 0.8465873996796955], [0.8407562899174361, 0.28182365781207885], [0.2536753198806271, 0.7403949433437086], [0.2564644415907945, 0.9685134384147592], [0.6958467262803021, 0.4043946043679537]], "destination": 3, "policies": [[[15, 0.5494989156723022, -15.328779220581055], [6, 0.4505009651184082, -15.527425765991211], [13, 4.3655653314544907e-08, -31.676963806152344], [5, 4.3655653314544907e-08, -31.676963806152344], [3, 1.5620519133274606e-14, -46.52022171020508], [11, 4.69135391765221e-16, -50.02566909790039], [18, 4.655698065737541e-16, -50.03329849243164], [1, 4.654881208339032e-16, -50.03347396850586], [2, 4.538940397310048e-16, -50.05870056152344], [17, 4.387526640543441e-16, -50.09262466430664]], [[2, 1.0, 13.086091995239258], [6, 6.388004850836713e-18, -26.50601577758789], [19, 9.999999682655225e-21, -50.26811599731445], [18, 9.999999682655225e-21, -50.065242767333984], [1, 9.999999682655225e-21, -50.033851623535156], [3, 9.999999682655225e-21, -46.67695617675781], [4, 9.999999682655225e-21, -50.15426254272461], [5, 9.999999682655225e-21, -35.77608108520508], [7, 
9.999999682655225e-21, -50.32582473754883], [8, 9.999999682655225e-21, -50.102718353271484]], [[17, 1.0, 28.310367584228516], [19, 9.999999682655225e-21, -50.25307846069336], [8, 9.999999682655225e-21, -50.10295867919922], [1, 9.999999682655225e-21, -50.034000396728516], [2, 9.999999682655225e-21, -52.38613510131836], [3, 9.999999682655225e-21, -46.76729965209961], [4, 9.999999682655225e-21, -30.387378692626953], [5, 9.999999682655225e-21, -50.26861572265625], [6, 9.999999682655225e-21, -37.818519592285156], [7, 9.999999682655225e-21, -39.61301803588867]], [[3, 1.0, 55.797306060791016], [19, 9.999999682655225e-21, -50.253299713134766], [18, 9.999999682655225e-21, -38.05373001098633], [1, 9.999999682655225e-21, -50.06608963012695], [2, 9.999999682655225e-21, -38.03934860229492], [4, 9.999999682655225e-21, -38.8900032043457], [5, 9.999999682655225e-21, -50.25330352783203], [6, 9.999999682655225e-21, -50.268497467041016], [7, 9.999999682655225e-21, -50.3636589050293], [8, 9.999999682655225e-21, -50.12850570678711]]], "current": [14, 15, 2, 17, 3], "moves": [15, 2, 17, 3], "graph": [[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1], [0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0], [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 1], [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0], [0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], [0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1], [1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1], [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 
0, 0, 1, 0, 1, 0, 0, 0, 0], [0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1], [1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0], [1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], [0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0]]} -------------------------------------------------------------------------------- /vis/sort.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 60 | 61 | 62 | 63 | 64 |
65 | 78 |
79 | 80 |
81 | 82 |
83 |
84 | Step: 85 | ? 86 |
87 |
88 |
89 | Swap (?, ?) 90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | 98 |
99 |
Inputs: Comparison Matrix
100 |
101 |
102 | 103 | 104 | 105 | 106 | 291 | 292 | 293 | 294 | -------------------------------------------------------------------------------- /vis/sort.json: -------------------------------------------------------------------------------- 1 | { 2 | "array": [ 3 | "4", 4 | "14", 5 | "19", 6 | "13", 7 | "6", 8 | "17", 9 | "9", 10 | "12", 11 | "5", 12 | "8", 13 | "1", 14 | "15", 15 | "18", 16 | "2", 17 | "16", 18 | "0", 19 | "7", 20 | "10", 21 | "11", 22 | "3" 23 | ], 24 | "moves": [ 25 | [ 26 | 0, 27 | 19 28 | ], 29 | [ 30 | 0, 31 | 15 32 | ], 33 | [ 34 | 2, 35 | 19 36 | ], 37 | [ 38 | 1, 39 | 10 40 | ], 41 | [ 42 | 2, 43 | 13 44 | ], 45 | [ 46 | 3, 47 | 15 48 | ], 49 | [ 50 | 4, 51 | 13 52 | ], 53 | [ 54 | 5, 55 | 8 56 | ], 57 | [ 58 | 6, 59 | 16 60 | ], 61 | [ 62 | 7, 63 | 13 64 | ], 65 | [ 66 | 6, 67 | 7 68 | ], 69 | [ 70 | 8, 71 | 9 72 | ], 73 | [ 74 | 9, 75 | 16 76 | ], 77 | [ 78 | 10, 79 | 17 80 | ], 81 | [ 82 | 11, 83 | 18 84 | ], 85 | [ 86 | 12, 87 | 13 88 | ], 89 | [ 90 | 13, 91 | 15 92 | ], 93 | [ 94 | 14, 95 | 17 96 | ], 97 | [ 98 | 15, 99 | 18 100 | ], 101 | [ 102 | 16, 103 | 17 104 | ] 105 | ] 106 | } 107 | 108 | --------------------------------------------------------------------------------