├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── LICENSE.txt ├── README.md ├── check.sh ├── cxx ├── .gitignore ├── Makefile ├── assets │ ├── .gitignore │ ├── animals.binary.obs │ ├── animals.binary.schema │ ├── animals.unary.obs │ ├── animals.unary.schema │ ├── nations.binary.obs │ ├── nations.binary.schema │ ├── nations.unary.obs │ ├── nations.unary.schema │ ├── two_relations.obs │ └── two_relations.schema ├── cxxopts.hpp ├── globals.hh ├── hirm.cc ├── hirm.hh ├── tests │ ├── test_hirm_animals.cc │ ├── test_irm_two_relations.cc │ ├── test_misc.cc │ └── test_util_math.cc ├── util_hash.hh ├── util_io.cc ├── util_io.hh ├── util_math.cc └── util_math.hh ├── examples ├── .gitignore ├── animals_binary_irm.py ├── animals_unary_hirm.py ├── animals_unary_irm.py ├── assets │ └── .gitignore ├── datasets │ ├── 50animalbindat.csv │ ├── 50animalbindat.mat │ ├── README_DATASETS │ ├── alyawarradata.mat │ ├── animals.binary.obs │ ├── animals.binary.schema │ ├── animals.unary.obs │ ├── animals.unary.schema │ ├── convert_animals.py │ ├── convert_nations.py │ ├── dnations.mat │ ├── irmdata.tar.gz │ ├── nations.binary.obs │ ├── nations.binary.schema │ ├── nations.unary.obs │ ├── nations.unary.schema │ └── uml.mat ├── nations_binary_irm.py ├── nations_unary_hirm.py ├── three_relations.py ├── three_relations_plot.py ├── two_clusters_binary_irm.py ├── two_clusters_unary_irm.py ├── two_relations.py ├── two_relations_anti.py └── two_relations_hirm.py ├── pythenv.sh ├── setup.py ├── src ├── __init__.py ├── hirm.py ├── util_io.py ├── util_math.py └── util_plot.py └── tests ├── .gitignore ├── __init__.py ├── test_basic.py └── util_test.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: 
https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.8] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Test Python and C++ 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install . 30 | ./check.sh 31 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .cache/ 3 | .coverage 4 | __pycache__/ 5 | build/ 6 | dist/ 7 | htmlcov/ 8 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Infinite Relational Model 2 | 3 | [![Actions Status](https://github.com/probcomp/hierarchical-irm/workflows/Python%20package/badge.svg)](https://github.com/probcomp/hierarchical-irm/actions) 4 | [![pypi](https://img.shields.io/pypi/v/hirm.svg)](https://pypi.org/project/hirm/) 5 | 6 | This repository contains implementations of the Hierarchical Infinite 7 | Relational Model (HIRM), a Bayesian method for automatic structure discovery in 8 | relational data. The method is described in: 9 | 10 | Hierarchical Infinite Relational Model. Saad, Feras A. and Mansinghka, Vikash K. 11 | In: Proc. 37th UAI, 2021. 12 | 13 | 14 | ## Installation (Python) 15 | 16 | This software is tested on Ubuntu 18.04+ and requires a Python 3.6+ 17 | environment. The library can be installed from the PyPI repository using 18 | 19 | $ python -m pip install hirm 20 | 21 | The library will be available as a module named `hirm`. 
22 | 23 | The test suite can be invoked via 24 | 25 | $ python -m pytest --pyargs hirm 26 | 27 | ## Running Python examples 28 | 29 | The [examples/](examples) are run using the Python (slower) backend, and in 30 | several cases for fewer MCMC iterations (e.g., <=20) than are needed for 31 | chains to converge. To invoke all the examples, first clone this repository 32 | then run 33 | 34 | $ ./check.sh examples 35 | 36 | The outputs and plots are written to [examples/assets](examples/assets). 37 | To run a specific example 38 | 39 | $ cd examples 40 | $ python two_relations.py 41 | 42 | ## Installation (C++) 43 | 44 | First obtain a GNU C++ compiler, version 7.5.0 or higher. 45 | The binary can be installed by first cloning this repository and then writing 46 | 47 | $ cd cxx 48 | $ make hirm.out 49 | 50 | The test suite can be invoked via 51 | 52 | $ make tests 53 | 54 | A command-line interface to the HIRM is provided under `cxx/hirm.out`. 55 | 56 | For an example of using the C++ library, refer to 57 | [`cxx/tests/test_hirm_animals.cc`](cxx/tests/test_hirm_animals.cc). 58 | 59 | ## Usage: Command Line Interface 60 | 61 | First build the C++ code as described above and then run the binary in 62 | `cxx/hirm.out`. It shows the following synopsis 63 | 64 | ``` 65 | $ ./hirm.out --help 66 | Run a hierarchical infinite relational model. 67 | Usage: 68 | hirm.out [OPTION...] 
69 | 70 | --help show help message 71 | --mode arg options are {irm, hirm} (default: hirm) 72 | --seed arg random seed (default: 10) 73 | --iters arg number of inference iterations (default: 10) 74 | --verbose report results to terminal 75 | --timeout arg number of seconds of inference (default: 0) 76 | --load arg path to .[h]irm file with initial clusters (default: "") 77 | ``` 78 | 79 | We will explain the usage by way of the following example 80 | 81 | $ cd cxx 82 | $ ./hirm.out assets/animals.unary 83 | setting seed to 10 84 | loading schema from assets/animals.unary.schema 85 | loading observations from assets/animals.unary.obs 86 | selected model is HIRM 87 | incorporating observations 88 | inferring 10 iters; timeout 0 89 | saving to assets/animals.unary.10.hirm 90 | 91 | In this example we have specified `` to be `assets/animals.unary`. 92 | It is required for there to be two input files on disk: 93 | 1. Schema file: of the form `.schema`. 94 | 2. Observation file: of the form `.obs`. 95 | 96 | The output file is `assets/animals.unary.10.hirm`. 97 | 98 | We next describe the input and output files. 99 | 100 | #### Schema file 101 | 102 | The schema file `assets/animals.unary.schema` specifies the signature of 103 | the relations in the system: 104 | 105 | ``` 106 | $ cat assets/animals.unary.schema 107 | bernoulli black animal 108 | bernoulli white animal 109 | bernoulli blue animal 110 | bernoulli brown animal 111 | bernoulli gray animal 112 | bernoulli orange animal 113 | bernoulli red animal 114 | bernoulli yellow animal 115 | bernoulli patches animal 116 | bernoulli spots animal 117 | ... 118 | ``` 119 | 120 | Each line specifies the signature of a relation in the system: 121 | 122 | - The first entry is the observation type 123 | (only `bernoulli` is supported at the moment). 124 | - The second entry is the name of the relation (e.g., `black`); all the 125 | relations names must be unique. 
126 | - The third entry is the domain of the relation (in this example, the only 127 | domain is `animal`). 128 | 129 | Thus, for this schema, we have a list of unary relations that each specify 130 | whether an `animal` has a given attribute. 131 | 132 | Note that, in general a given relational system can be encoded in multiple 133 | ways. See `assets/animals.binary.schema` for an encoding of this system using 134 | a single higher-order relation with signature: `bernoulli has feature animal`. 135 | 136 | #### Observation file 137 | 138 | The observation file `assets/animals.unary.obs` specifies realizations of the relations 139 | 140 | ``` 141 | $ cat assets/animals.unary.obs 142 | 0 black antelope 143 | 1 black grizzlybear 144 | 1 black killerwhale 145 | 0 black beaver 146 | 1 black dalmatian 147 | 0 black persiancat 148 | 1 black horse 149 | 1 black germanshepherd 150 | 0 black bluewhale 151 | 1 black siamesecat 152 | ... 153 | ``` 154 | 155 | Each line specifies a single observation: 156 | 157 | - The first entry is 0 or 1 158 | - The second entry is the relation name (there must be a corresponding 159 | relation with the same name in the schema file) 160 | - The third entry and afterwards are the names of domain entities; e.g, 161 | `antelope`, `grizzlybear`, etc., are entities in the `animals` domain. 162 | The number of domain entities must correspond to the arity of the 163 | relation from the schema file. Since all the relations in this example 164 | are unary, there is only one entity after each relation name. 165 | 166 | Thus, for this observation file, we have observations `black(antelope) = 0`, 167 | `black(grizzlybear) = 1`, and so on. 168 | 169 | #### Output file 170 | 171 | The output file `assets/animals.unary.10.hirm` specifies the learned 172 | clusterings of relations and domain entities. The output file is comprised 173 | of multiple sections, each delimited by a single blank line. 
174 | 175 | ``` 176 | $ cat assets/animals.unary.10.hirm 177 | 0 oldworld black insects skimmer chewteeth agility bulbous fast lean orange inactive slow stripes tail red active 178 | 1 quadrapedal paws strainteeth pads meatteeth hooves longneck ocean coastal hunter hairless smart group nocturnal meat buckteeth plankton plains timid horns hibernate forager ground grazer furry fields brown solitary stalker toughskin water arctic blue smelly claws swims vegetation fish flippers walks 179 | 5 mountains jungle forest bipedal cave desert fierce nestspot tree tusks yellow hands scavenger flys 180 | 6 muscle longleg domestic tunnels newworld bush big gray spots strong weak patches white hops small 181 | 182 | irm=0 183 | animal 0 giraffe seal horse bat rabbit chimpanzee killerwhale dalmatian mole chihuahua zebra deer lion mouse raccoon dolphin collie bobcat tiger siamesecat germanshepherd otter weasel spidermonkey beaver leopard antelope gorilla fox hamster squirrel wolf rat 184 | animal 1 skunk persiancat giantpanda polarbear moose pig buffalo elephant cow sheep grizzlybear ox humpbackwhale walrus rhinoceros bluewhale hippopotamus 185 | 186 | irm=1 187 | animal 0 mouse rabbit zebra moose antelope horse buffalo deer ox cow gorilla pig rhinoceros chimpanzee giraffe sheep spidermonkey elephant 188 | animal 1 collie germanshepherd siamesecat giantpanda chihuahua lion raccoon squirrel grizzlybear dalmatian rat persiancat weasel leopard skunk bobcat mole tiger hamster fox wolf 189 | animal 3 otter walrus humpbackwhale killerwhale bluewhale dolphin seal 190 | animal 4 polarbear bat 191 | animal 5 hippopotamus beaver 192 | 193 | irm=5 194 | animal 0 antelope germanshepherd elephant hippopotamus tiger rhinoceros zebra giraffe killerwhale sheep humpbackwhale mole hamster persiancat horse siamesecat chihuahua cow dolphin walrus collie polarbear mouse pig deer moose skunk bluewhale buffalo dalmatian rat beaver ox fox seal rabbit wolf weasel otter 195 | animal 1 squirrel raccoon 
giantpanda gorilla lion bat spidermonkey chimpanzee grizzlybear bobcat leopard 196 | 197 | irm=6 198 | animal 0 horse killerwhale spidermonkey deer giraffe germanshepherd rhinoceros leopard moose fox wolf buffalo dolphin bluewhale grizzlybear chimpanzee walrus lion bobcat zebra beaver elephant ox antelope gorilla hippopotamus humpbackwhale polarbear tiger 199 | animal 1 collie squirrel raccoon chihuahua sheep hamster rabbit rat mouse skunk persiancat weasel mole bat otter siamesecat 200 | animal 2 dalmatian giantpanda cow pig 201 | animal 3 seal 202 | ``` 203 | 204 | The first section in the file specifies the clustering of the relations. 205 | Each line specifies a relation cluster, for example: 206 | 207 | ``` 208 | 0 oldworld black insects skimmer chewteeth agility bulbous fast lean orange inactive slow stripes tail red active 209 | ``` 210 | 211 | Here, the first entry is a unique integer code for the cluster index and the 212 | remaining entries are names of relations that belong to this cluster. 213 | We see that there are four relation clusters with indexes `[0, 1, 5, 6]`. 
214 | 215 | All the remaining sections in the file start with `irm=x`, where `x` is an 216 | integer code from the first section, for example: 217 | 218 | ``` 219 | irm=6 220 | animal 0 horse killerwhale spidermonkey deer giraffe germanshepherd rhinoceros leopard moose fox wolf buffalo dolphin bluewhale grizzlybear chimpanzee walrus lion bobcat zebra beaver elephant ox antelope gorilla hippopotamus humpbackwhale polarbear tiger 221 | animal 1 collie squirrel raccoon chihuahua sheep hamster rabbit rat mouse skunk persiancat weasel mole bat otter siamesecat 222 | animal 2 dalmatian giantpanda cow pig 223 | animal 3 seal 224 | ``` 225 | 226 | Each subsequent line in the `irm=6` section specifies a cluster for a given 227 | domain, for example 228 | 229 | ``` 230 | animal 2 dalmatian giantpanda cow pig 231 | ``` 232 | 233 | Here, the first entry is the name of the domain, the second entry is a 234 | unique integer for the cluster index, and the remaining entries are names 235 | of entities within the domain that belong to this cluster. Recall that the 236 | schema file `assets/animals.unary.schema` has only one domain, so all the 237 | lines in the `irm` section start with `animal`. 238 | 239 | ## Citation 240 | 241 | To cite this work, please use the following BibTeX. 242 | 243 | ```bibtex 244 | @inproceedings{saad2021hirm, 245 | title = {Hierarchical Infinite Relational Model}, 246 | author = {Saad, Feras A. and Mansinghka, Vikash K.}, 247 | booktitle = {UAI 2021: Proceedings of the 37th Conference on Uncertainty in Artificial Intelligence}, 248 | fseries = {Proceedings of Machine Learning Research}, 249 | year = 2021, 250 | location = {Online}, 251 | publisher = {AUAI Press}, 252 | address = {Arlington, VA, USA}, 253 | } 254 | ``` 255 | 256 | ## License 257 | 258 | Copyright (c) 2021 MIT Probabilistic Computing Project 259 | 260 | Licensed under the Apache License, Version 2.0 (the "License"); 261 | you may not use this file except in compliance with the License. 
262 | You may obtain a copy of the License at 263 | 264 | http://www.apache.org/licenses/LICENSE-2.0 265 | 266 | Unless required by applicable law or agreed to in writing, software 267 | distributed under the License is distributed on an "AS IS" BASIS, 268 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 269 | See the License for the specific language governing permissions and 270 | limitations under the License. 271 | -------------------------------------------------------------------------------- /check.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright 2021 MIT Probabilistic Computing Project 4 | # Apache License, Version 2.0, refer to LICENSE.txt 5 | 6 | set -Ceux 7 | 8 | : ${PYTHON:=python} 9 | 10 | root=$(cd -- "$(dirname -- "$0")" && pwd) 11 | 12 | ( 13 | set -Ceu 14 | cd -- "${root}" 15 | rm -rf build 16 | "$PYTHON" setup.py build 17 | if [ $# -eq 0 ]; then 18 | # (Default) Run tests/ 19 | ./pythenv.sh "$PYTHON" -m pytest --pyargs hirm 20 | cd cxx && make tests 21 | elif [ ${1} = 'coverage' ]; then 22 | # Generate coverage report. 23 | ./pythenv.sh coverage run --source=build/ -m pytest --pyargs hirm 24 | coverage html 25 | coverage report 26 | elif [ ${1} = 'examples' ]; then 27 | # Run the .py files under examples/ 28 | cd examples 29 | for x in *.py; do 30 | MPLBACKEND=agg python "${x}" || continue 31 | done 32 | elif [ ${1} = 'release' ]; then 33 | # Make a release to pypi 34 | rm -rf dist 35 | "$PYTHON" setup.py sdist bdist_wheel 36 | twine upload --repository pypi dist/* 37 | elif [ ${1} = 'tag' ]; then 38 | # Make a tagged release, e.g., ./check.sh tag 2.0.0 39 | status="$(git diff --stat && git diff --staged)" 40 | [ -z "${status}" ] || (echo 'fatal: tag dirty' && exit 1) 41 | tag="${2}" 42 | sed -i "s/__version__ = .*/__version__ = '${tag}'/g" -- src/__init__.py 43 | git add -- src/__init__.py 44 | git commit -m "Pin version ${tag}." 
45 | git tag -a -m v"${tag}" v"${tag}" 46 | else 47 | # If args are specified delegate control to user. 48 | ./pythenv.sh "$PYTHON" -m pytest "$@" 49 | fi 50 | ) 51 | -------------------------------------------------------------------------------- /cxx/.gitignore: -------------------------------------------------------------------------------- 1 | *.hirm 2 | *.irm 3 | *.o 4 | *.out 5 | *.prof 6 | -------------------------------------------------------------------------------- /cxx/Makefile: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | GCC_VERSION := $(shell g++ -dumpversion) 5 | MIN_VERSION := 7 6 | 7 | ifeq ($(shell test $(GCC_VERSION) -lt $(MIN_VERSION); echo $$?),0) 8 | $(error ERROR: g++ version $(MIN_VERSION) or higher is required, found $(GCC_VERSION)) 9 | endif 10 | 11 | CXX=g++ 12 | 13 | ifdef HIRMDEBUG 14 | CXXFLAGS= -O0 -pg -g -std=c++17 15 | else 16 | CXXFLAGS= -O3 -std=c++17 17 | endif 18 | 19 | INCDIR=$(shell pwd) 20 | INCFLAGS=-I$(INCDIR) 21 | 22 | # https://www.gnu.org/software/make/manual/html_node/Catalogue-of-Rules.html 23 | LHEADER = $(wildcard *.hh) 24 | LSOURCE = $(filter-out hirm.cc,$(wildcard *.cc)) 25 | LOBJECT = $(LSOURCE:.cc=.o) 26 | 27 | TEST_DIR=tests 28 | LTEST = $(wildcard $(TEST_DIR)/*.cc) 29 | 30 | # Prevent LOBJECT from being removed. 
31 | # https://stackoverflow.com/a/29114706/ 32 | .SECONDARY: $(LOBJECT) 33 | 34 | %.o : %.cc $(LHEADER) 35 | $(CXX) -c $(CXXFLAGS) -o $@ $< 36 | 37 | %.out : %.cc $(LOBJECT) 38 | $(CXX) $(CXXFLAGS) $(INCFLAGS) -o $@ $(LOBJECT) $< 39 | 40 | .PHONY: tests 41 | tests: $(LTEST:.cc=.out) hirm.out 42 | ./tests/test_misc.out 43 | ./tests/test_hirm_animals.out 44 | ./tests/test_irm_two_relations.out 45 | ./hirm.out --mode=irm --iters=5 assets/animals.binary 46 | ./hirm.out --seed=1 --iters=5 assets/animals.unary 47 | ./hirm.out --iters=5 --load=assets/animals.unary.1.hirm assets/animals.unary 48 | 49 | # Make sure -pg is not in CXXFLAGS 50 | # since valgrind finds leaks from gmon 51 | .PHONY: leak-check clean 52 | leak-check: $(TEST_DIR)/test_misc.out 53 | ifdef HIRMDEBUG 54 | $(error ERROR: Cannot run Valgrind with HIRMDEBUG) 55 | endif 56 | valgrind \ 57 | --leak-check=full \ 58 | --show-leak-kinds=all \ 59 | --track-origins=yes \ 60 | --verbose \ 61 | ./$< 62 | clean: 63 | rm -rf **.o *.out *.gch 64 | rm -rf tests/*.out 65 | rm -rf assets/*.irm 66 | rm -rf assets/*.hirm 67 | -------------------------------------------------------------------------------- /cxx/assets/.gitignore: -------------------------------------------------------------------------------- 1 | *.irm 2 | *.hirm 3 | -------------------------------------------------------------------------------- /cxx/assets/animals.binary.schema: -------------------------------------------------------------------------------- 1 | bernoulli has feature animal 2 | -------------------------------------------------------------------------------- /cxx/assets/animals.unary.schema: -------------------------------------------------------------------------------- 1 | bernoulli black animal 2 | bernoulli white animal 3 | bernoulli blue animal 4 | bernoulli brown animal 5 | bernoulli gray animal 6 | bernoulli orange animal 7 | bernoulli red animal 8 | bernoulli yellow animal 9 | bernoulli patches animal 10 | bernoulli spots animal 
11 | bernoulli stripes animal 12 | bernoulli furry animal 13 | bernoulli hairless animal 14 | bernoulli toughskin animal 15 | bernoulli big animal 16 | bernoulli small animal 17 | bernoulli bulbous animal 18 | bernoulli lean animal 19 | bernoulli flippers animal 20 | bernoulli hands animal 21 | bernoulli hooves animal 22 | bernoulli pads animal 23 | bernoulli paws animal 24 | bernoulli longleg animal 25 | bernoulli longneck animal 26 | bernoulli tail animal 27 | bernoulli chewteeth animal 28 | bernoulli meatteeth animal 29 | bernoulli buckteeth animal 30 | bernoulli strainteeth animal 31 | bernoulli horns animal 32 | bernoulli claws animal 33 | bernoulli tusks animal 34 | bernoulli smelly animal 35 | bernoulli flys animal 36 | bernoulli hops animal 37 | bernoulli swims animal 38 | bernoulli tunnels animal 39 | bernoulli walks animal 40 | bernoulli fast animal 41 | bernoulli slow animal 42 | bernoulli strong animal 43 | bernoulli weak animal 44 | bernoulli muscle animal 45 | bernoulli bipedal animal 46 | bernoulli quadrapedal animal 47 | bernoulli active animal 48 | bernoulli inactive animal 49 | bernoulli nocturnal animal 50 | bernoulli hibernate animal 51 | bernoulli agility animal 52 | bernoulli fish animal 53 | bernoulli meat animal 54 | bernoulli plankton animal 55 | bernoulli vegetation animal 56 | bernoulli insects animal 57 | bernoulli forager animal 58 | bernoulli grazer animal 59 | bernoulli hunter animal 60 | bernoulli scavenger animal 61 | bernoulli skimmer animal 62 | bernoulli stalker animal 63 | bernoulli newworld animal 64 | bernoulli oldworld animal 65 | bernoulli arctic animal 66 | bernoulli coastal animal 67 | bernoulli desert animal 68 | bernoulli bush animal 69 | bernoulli plains animal 70 | bernoulli forest animal 71 | bernoulli fields animal 72 | bernoulli jungle animal 73 | bernoulli mountains animal 74 | bernoulli ocean animal 75 | bernoulli ground animal 76 | bernoulli water animal 77 | bernoulli tree animal 78 | bernoulli cave animal 79 | 
bernoulli fierce animal 80 | bernoulli timid animal 81 | bernoulli smart animal 82 | bernoulli group animal 83 | bernoulli solitary animal 84 | bernoulli nestspot animal 85 | bernoulli domestic animal 86 | -------------------------------------------------------------------------------- /cxx/assets/nations.binary.schema: -------------------------------------------------------------------------------- 1 | bernoulli has feature country 2 | bernoulli applies predicate country country 3 | -------------------------------------------------------------------------------- /cxx/assets/nations.unary.schema: -------------------------------------------------------------------------------- 1 | bernoulli telephone country 2 | bernoulli agriculturalpop country 3 | bernoulli energyconsume country 4 | bernoulli illiterates country 5 | bernoulli GNP country 6 | bernoulli popxenergabs country 7 | bernoulli incomeabs country 8 | bernoulli popabs country 9 | bernoulli unassessment country 10 | bernoulli defenseexpabs country 11 | bernoulli englishtitles country 12 | bernoulli blocmembership0 country 13 | bernoulli usaidreceived country 14 | bernoulli freedomofopposition0 country 15 | bernoulli IFCandIBRD country 16 | bernoulli threats country 17 | bernoulli accusations country 18 | bernoulli killedforeignviolence country 19 | bernoulli militaryaction country 20 | bernoulli protests country 21 | bernoulli killeddomesticviolence country 22 | bernoulli riots country 23 | bernoulli purges country 24 | bernoulli demonstrations country 25 | bernoulli catholics country 26 | bernoulli airdistance country 27 | bernoulli medicinengo country 28 | bernoulli diplomatexpelled country 29 | bernoulli divorces country 30 | bernoulli popn/land country 31 | bernoulli arable country 32 | bernoulli area country 33 | bernoulli roadlength country 34 | bernoulli railroadlength country 35 | bernoulli religions country 36 | bernoulli immigrants/migrants country 37 | bernoulli rainfall country 38 | bernoulli 
largestrelgn country 39 | bernoulli runningwater country 40 | bernoulli foreigncollegestud country 41 | bernoulli neutralblock country 42 | bernoulli age country 43 | bernoulli religioustitles country 44 | bernoulli emigrants country 45 | bernoulli seabornegoods country 46 | bernoulli lawngos country 47 | bernoulli unemployed country 48 | bernoulli export country 49 | bernoulli languages country 50 | bernoulli largestlang country 51 | bernoulli ethnicgrps country 52 | bernoulli economicaidtaken country 53 | bernoulli techassistancetaken country 54 | bernoulli goveducationspend country 55 | bernoulli femaleworkers country 56 | bernoulli exports country 57 | bernoulli foreignmail country 58 | bernoulli imports country 59 | bernoulli caloriesconsumed country 60 | bernoulli protein country 61 | bernoulli russiantitles country 62 | bernoulli militarypersonnel country 63 | bernoulli investments country 64 | bernoulli politicalparties country 65 | bernoulli artsculturengo country 66 | bernoulli communistparty country 67 | bernoulli govspending country 68 | bernoulli monarchy country 69 | bernoulli primaryschool country 70 | bernoulli govchangelegal0 country 71 | bernoulli legitgov0 country 72 | bernoulli largestethnic country 73 | bernoulli assassinations country 74 | bernoulli majgovcrisis country 75 | bernoulli unpaymentdelinq country 76 | bernoulli balancepayments country 77 | bernoulli balanceinvestments country 78 | bernoulli systemstyle0 country 79 | bernoulli constitutional0 country 80 | bernoulli electoralsystem0 country 81 | bernoulli noncommunist country 82 | bernoulli politicalleadership0 country 83 | bernoulli horizontalpower0 country 84 | bernoulli military0 country 85 | bernoulli bureaucracy0 country 86 | bernoulli censorship0 country 87 | bernoulli geographyx country 88 | bernoulli geographyy country 89 | bernoulli geographyz country 90 | bernoulli blocmembership1 country 91 | bernoulli blocmembership2 country 92 | bernoulli freedomofopposition1 country 93 
| bernoulli freedomofopposition2 country 94 | bernoulli govchangelegal1 country 95 | bernoulli govchangelegal2 country 96 | bernoulli legitgov1 country 97 | bernoulli systemstyle1 country 98 | bernoulli systemstyle2 country 99 | bernoulli constitutional1 country 100 | bernoulli constitutional2 country 101 | bernoulli electoralsystem1 country 102 | bernoulli electoralsystem2 country 103 | bernoulli politicalleadership1 country 104 | bernoulli politicalleadership2 country 105 | bernoulli horizontalpower2 country 106 | bernoulli military1 country 107 | bernoulli military2 country 108 | bernoulli bureaucracy1 country 109 | bernoulli bureaucracy2 country 110 | bernoulli censorship1 country 111 | bernoulli censorship2 country 112 | bernoulli economicaid country country 113 | bernoulli releconomicaid country country 114 | bernoulli treaties country country 115 | bernoulli reltreaties country country 116 | bernoulli officialvisits country country 117 | bernoulli conferences country country 118 | bernoulli exportbooks country country 119 | bernoulli relexportbooks country country 120 | bernoulli booktranslations country country 121 | bernoulli relbooktranslations country country 122 | bernoulli warning country country 123 | bernoulli violentactions country country 124 | bernoulli militaryactions country country 125 | bernoulli duration country country 126 | bernoulli negativebehavior country country 127 | bernoulli severdiplomatic country country 128 | bernoulli expeldiplomats country country 129 | bernoulli boycottembargo country country 130 | bernoulli aidenemy country country 131 | bernoulli negativecomm country country 132 | bernoulli accusation country country 133 | bernoulli protests_rel country country 134 | bernoulli unoffialacts country country 135 | bernoulli attackembassy country country 136 | bernoulli nonviolentbehavior country country 137 | bernoulli weightedunvote country country 138 | bernoulli unweightedunvote country country 139 | bernoulli tourism country 
country 140 | bernoulli reltourism country country 141 | bernoulli tourism3 country country 142 | bernoulli emigrants_rel country country 143 | bernoulli relemigrants country country 144 | bernoulli emigrants3 country country 145 | bernoulli students country country 146 | bernoulli relstudents country country 147 | bernoulli exports_rel country country 148 | bernoulli relexports country country 149 | bernoulli exports3 country country 150 | bernoulli intergovorgs country country 151 | bernoulli relintergovorgs country country 152 | bernoulli ngo country country 153 | bernoulli relngo country country 154 | bernoulli intergovorgs3 country country 155 | bernoulli ngoorgs3 country country 156 | bernoulli embassy country country 157 | bernoulli reldiplomacy country country 158 | bernoulli timesincewar country country 159 | bernoulli timesinceally country country 160 | bernoulli lostterritory country country 161 | bernoulli dependent country country 162 | bernoulli independence country country 163 | bernoulli commonbloc0 country country 164 | bernoulli blockpositionindex country country 165 | bernoulli militaryalliance country country 166 | bernoulli commonbloc1 country country 167 | bernoulli commonbloc2 country country 168 | -------------------------------------------------------------------------------- /cxx/assets/two_relations.schema: -------------------------------------------------------------------------------- 1 | bernoulli R1 D1 D2 2 | bernoulli R2 D1 D2 3 | -------------------------------------------------------------------------------- /cxx/globals.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #pragma once 5 | 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | #include 16 | 17 | using std::map; 18 | using std::string; 19 | 
using std::tuple; 20 | using std::vector; 21 | 22 | #define uset std::unordered_set 23 | #define umap std::unordered_map 24 | 25 | // https://stackoverflow.com/q/2241327/ 26 | typedef std::mt19937 PRNG; 27 | 28 | typedef map> T_schema; 29 | 30 | extern const double INF; 31 | -------------------------------------------------------------------------------- /cxx/hirm.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "cxxopts.hpp" 9 | #include "globals.hh" 10 | #include "hirm.hh" 11 | #include "util_io.hh" 12 | 13 | #define GET_ELAPSED(t) double(clock() - t) / CLOCKS_PER_SEC 14 | 15 | #define CHECK_TIMEOUT(\ 16 | timeout, \ 17 | t_begin) \ 18 | if (timeout) { \ 19 | auto elapsed = GET_ELAPSED(t_begin); \ 20 | if (timeout < elapsed) { \ 21 | printf("timeout after %1.2fs \n", elapsed); \ 22 | break; \ 23 | } \ 24 | } 25 | 26 | #define REPORT_SCORE(\ 27 | var_verbose, \ 28 | var_t, \ 29 | var_t_total, \ 30 | var_model) \ 31 | if (var_verbose) { \ 32 | auto t_delta = GET_ELAPSED(var_t); \ 33 | var_t_total += t_delta; \ 34 | double x = var_model->logp_score(); \ 35 | printf("%f %f\n", var_t_total, x); \ 36 | fflush(stdout); \ 37 | } 38 | 39 | void inference_irm(IRM * irm, int iters, int timeout, bool verbose) { 40 | clock_t t_begin = clock(); 41 | double t_total = 0; 42 | for (int i = 0; i < iters; i++) { 43 | CHECK_TIMEOUT(timeout, t_begin); 44 | // TRANSITION ASSIGNMENTS. 45 | for (const auto &[d, domain] : irm->domains) { 46 | for (auto item : domain->items) { 47 | clock_t t = clock(); 48 | irm->transition_cluster_assignment_item(d, item); 49 | REPORT_SCORE(verbose, t, t_total, irm); 50 | } 51 | } 52 | // TRANSITION ALPHA. 
53 | for (auto const &[d, domain] : irm->domains) { 54 | clock_t t = clock(); 55 | domain->crp.transition_alpha(); 56 | REPORT_SCORE(verbose, t, t_total, irm); 57 | } 58 | } 59 | } 60 | 61 | void inference_hirm(HIRM * hirm, int iters, int timeout, bool verbose) { 62 | clock_t t_begin = clock(); 63 | double t_total = 0; 64 | for (int i = 0; i < iters; i++) { 65 | CHECK_TIMEOUT(timeout, t_begin); 66 | // TRANSITION RELATIONS. 67 | for (const auto &[r, rc] : hirm->relation_to_code) { 68 | clock_t t = clock(); 69 | hirm->transition_cluster_assignment_relation(r); 70 | REPORT_SCORE(verbose, t, t_total, hirm); 71 | } 72 | // TRANSITION IRMs. 73 | for (const auto &[t, irm] : hirm->irms) { 74 | // TRANSITION ASSIGNMENTS. 75 | for (const auto &[d, domain] : irm->domains) { 76 | for (auto item : domain->items) { 77 | clock_t t = clock(); 78 | irm->transition_cluster_assignment_item(d, item); 79 | REPORT_SCORE(verbose, t, t_total, irm); 80 | } 81 | } 82 | // TRANSITION ALPHA. 83 | for (auto const &[d, domain] : irm->domains) { 84 | clock_t t = clock(); 85 | domain->crp.transition_alpha(); 86 | REPORT_SCORE(verbose, t, t_total, irm); 87 | } 88 | } 89 | } 90 | } 91 | 92 | int main(int argc, char **argv) { 93 | 94 | cxxopts::Options options("hirm", "Run a hierarchical infinite relational model."); 95 | options.add_options() 96 | ("help", "show help message") 97 | ("mode", "options are {irm, hirm}", cxxopts::value()->default_value("hirm")) 98 | ("seed", "random seed", cxxopts::value()->default_value("10")) 99 | ("iters", "number of inference iterations", cxxopts::value()->default_value("10")) 100 | ("verbose", "report results to terminal", cxxopts::value()->default_value("false")) 101 | ("timeout", "number of seconds of inference", cxxopts::value()->default_value("0")) 102 | ("load", "path to .[h]irm file with initial clusters", cxxopts::value()->default_value("")) 103 | ("path", "base name of the .schema file", cxxopts::value()) 104 | ("rest", "rest", 
cxxopts::value>()->default_value({})); 105 | options.parse_positional({"path", "rest"}); 106 | options.positional_help(""); 107 | 108 | auto result = options.parse(argc, argv); 109 | if (result.count("help")) { 110 | std::cout << options.help() << std::endl; 111 | return 0; 112 | } 113 | if (result.count("path") == 0) { 114 | std::cout << options.help() << std::endl; 115 | return 1; 116 | } 117 | 118 | string path_base = result["path"].as(); 119 | int seed = result["seed"].as(); 120 | int iters = result["iters"].as(); 121 | int timeout = result["timeout"].as(); 122 | bool verbose = result["verbose"].as(); 123 | string path_clusters = result["load"].as(); 124 | string mode = result["mode"].as(); 125 | 126 | if (mode != "hirm" && mode != "irm") { 127 | std::cout << options.help() << std::endl; 128 | std::cout << "unknown mode " << mode << std::endl; 129 | return 1; 130 | } 131 | 132 | string path_obs = path_base + ".obs"; 133 | string path_schema = path_base + ".schema"; 134 | string path_save = path_base + "." 
+ std::to_string(seed); 135 | 136 | printf("setting seed to %d\n", seed); 137 | PRNG prng (seed); 138 | 139 | std::cout << "loading schema from " << path_schema << std::endl; 140 | auto schema = load_schema(path_schema); 141 | 142 | std::cout << "loading observations from " << path_obs << std::endl; 143 | auto observations = load_observations(path_obs); 144 | auto encoding = encode_observations(schema, observations); 145 | 146 | if (mode == "irm") { 147 | std::cout << "selected model is IRM" << std::endl; 148 | IRM * irm; 149 | // Load 150 | if (path_clusters.empty()) { 151 | irm = new IRM(schema, &prng); 152 | std::cout << "incorporating observations" << std::endl; 153 | incorporate_observations(*irm, encoding, observations); 154 | } else { 155 | irm = new IRM({}, &prng); 156 | std::cout << "loading clusters from " << path_clusters << std::endl; 157 | from_txt(irm, path_schema, path_obs, path_clusters); 158 | } 159 | // Infer 160 | std::cout << "inferring " << iters << " iters; timeout " << timeout << std::endl; 161 | inference_irm(irm, iters, timeout, verbose); 162 | // Save 163 | path_save += ".irm"; 164 | std::cout << "saving to " << path_save << std::endl; 165 | to_txt(path_save, *irm, encoding); 166 | // Free 167 | free(irm); 168 | return 0; 169 | } 170 | 171 | if (mode == "hirm") { 172 | std::cout << "selected model is HIRM" << std::endl; 173 | HIRM * hirm; 174 | // Load 175 | if (path_clusters.empty()) { 176 | hirm = new HIRM(schema, &prng); 177 | std::cout << "incorporating observations" << std::endl; 178 | incorporate_observations(*hirm, encoding, observations); 179 | } else { 180 | hirm = new HIRM({}, &prng); 181 | std::cout << "loading clusters from " << path_clusters << std::endl; 182 | from_txt(hirm, path_schema, path_obs, path_clusters); 183 | } 184 | // Infer 185 | std::cout << "inferring " << iters << " iters; timeout " << timeout << std::endl; 186 | inference_hirm(hirm, iters, timeout, verbose); 187 | // Save 188 | path_save += ".hirm"; 189 | 
std::cout << "saving to " << path_save << std::endl; 190 | to_txt(path_save, *hirm, encoding); 191 | // Free 192 | free(hirm); 193 | return 0; 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /cxx/tests/test_hirm_animals.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | 7 | #include "hirm.hh" 8 | #include "util_hash.hh" 9 | #include "util_io.hh" 10 | #include "util_math.hh" 11 | 12 | int main(int argc, char **argv) { 13 | 14 | srand(1); 15 | PRNG prng (1); 16 | 17 | string path_base = "assets/animals.unary"; 18 | auto path_schema = path_base + ".schema"; 19 | auto path_obs = path_base + ".obs"; 20 | 21 | printf("== HIRM == \n"); 22 | auto schema_unary = load_schema(path_schema); 23 | auto observations_unary = load_observations(path_obs); 24 | auto encoding_unary = encode_observations(schema_unary, observations_unary); 25 | 26 | HIRM hirm (schema_unary, &prng); 27 | incorporate_observations(hirm, encoding_unary, observations_unary); 28 | int n_obs_unary = 0; 29 | for (const auto &[z, irm] : hirm.irms) { 30 | for (const auto &[r, relation] : irm->relations) { 31 | n_obs_unary += relation->data.size(); 32 | } 33 | } 34 | assert(n_obs_unary == observations_unary.size()); 35 | 36 | hirm.transition_cluster_assignments_all(); 37 | hirm.transition_cluster_assignments_all(); 38 | hirm.set_cluster_assignment_gibbs("solitary", 120); 39 | hirm.set_cluster_assignment_gibbs("water", 741); 40 | for (int i = 0; i < 20; i++) { 41 | hirm.transition_cluster_assignments_all(); 42 | for (const auto &[t, irm] : hirm.irms) { 43 | irm->transition_cluster_assignments_all(); 44 | for (const auto &[d, domain] : irm->domains) { 45 | domain->crp.transition_alpha(); 46 | } 47 | } 48 | hirm.crp.transition_alpha(); 49 | printf("%d %f [", i, hirm.logp_score()); 50 
| for (const auto &[t, customers] : hirm.crp.tables) { 51 | printf("%ld ", customers.size()); 52 | } 53 | printf("]\n"); 54 | } 55 | 56 | // TODO: Removing the relation causes solitary to have no observations, 57 | // which causes the serialization test. Instead, we need a 58 | // to_txt_dataset([relation | irm | hirm]) which writes the latest 59 | // dataset to disk and is used upon reloading the data. 60 | // hirm.remove_relation("solitary"); 61 | // hirm.transition_cluster_assignments_all(); 62 | // hirm.add_relation("solitary", {"animal"}); 63 | // hirm.transition_cluster_assignments_all(); 64 | 65 | string path_clusters = path_base + ".hirm"; 66 | to_txt(path_clusters, hirm, encoding_unary); 67 | 68 | auto &enc = std::get<0>(encoding_unary); 69 | 70 | // Marginally normalized. 71 | int persiancat = enc["animal"]["persiancat"]; 72 | auto p0_black_persiancat = hirm.logp({{"black", {persiancat}, 0.}}); 73 | auto p1_black_persiancat = hirm.logp({{"black", {persiancat}, 1.}}); 74 | assert(abs(logsumexp({p0_black_persiancat, p1_black_persiancat})) < 1e-10); 75 | 76 | // Marginally normalized. 77 | int sheep = enc["animal"]["sheep"]; 78 | auto p0_solitary_sheep = hirm.logp({{"solitary", {sheep}, 0.}}); 79 | auto p1_solitary_sheep = hirm.logp({{"solitary", {sheep}, 1.}}); 80 | assert(abs(logsumexp({p0_solitary_sheep, p1_solitary_sheep})) < 1e-10); 81 | 82 | // Jointly normalized. 
83 | auto p00_black_persiancat_solitary_sheep = hirm.logp( 84 | {{"black", {persiancat}, 0.}, {"solitary", {sheep}, 0.}}); 85 | auto p01_black_persiancat_solitary_sheep = hirm.logp( 86 | {{"black", {persiancat}, 0.}, {"solitary", {sheep}, 1.}}); 87 | auto p10_black_persiancat_solitary_sheep = hirm.logp( 88 | {{"black", {persiancat}, 1.}, {"solitary", {sheep}, 0.}}); 89 | auto p11_black_persiancat_solitary_sheep = hirm.logp( 90 | {{"black", {persiancat}, 1.}, {"solitary", {sheep}, 1.}}); 91 | auto Z = logsumexp({ 92 | p00_black_persiancat_solitary_sheep, 93 | p01_black_persiancat_solitary_sheep, 94 | p10_black_persiancat_solitary_sheep, 95 | p11_black_persiancat_solitary_sheep, 96 | }); 97 | assert(abs(Z) < 1e-10); 98 | 99 | // Independence 100 | assert(abs(p00_black_persiancat_solitary_sheep - (p0_black_persiancat + p0_solitary_sheep)) < 1e-8); 101 | assert(abs(p01_black_persiancat_solitary_sheep - (p0_black_persiancat + p1_solitary_sheep)) < 1e-8); 102 | assert(abs(p10_black_persiancat_solitary_sheep - (p1_black_persiancat + p0_solitary_sheep)) < 1e-8); 103 | assert(abs(p11_black_persiancat_solitary_sheep - (p1_black_persiancat + p1_solitary_sheep)) < 1e-8); 104 | 105 | // Load the clusters. 106 | HIRM hirx ({}, &prng); 107 | from_txt(&hirx, path_schema, path_obs, path_clusters); 108 | 109 | assert(hirm.irms.size() == hirx.irms.size()); 110 | // Check IRMs agree. 111 | for (const auto &[table, irm] : hirm.irms) { 112 | auto irx = hirx.irms.at(table); 113 | // Check log scores agree. 114 | for (const auto &[d, dm] : irm->domains) { 115 | auto dx = irx->domains.at(d); 116 | dx->crp.alpha = dm->crp.alpha; 117 | } 118 | assert(abs(irx->logp_score() - irm->logp_score()) < 1e-8); 119 | // Check domains agree. 
120 | for (const auto &[d, dm] : irm->domains) { 121 | auto dx = irx->domains.at(d); 122 | assert(dm->items == dx->items); 123 | assert(dm->crp.assignments == dx->crp.assignments); 124 | assert(dm->crp.tables == dx->crp.tables); 125 | assert(dm->crp.N == dx->crp.N); 126 | assert(dm->crp.alpha == dx->crp.alpha); 127 | } 128 | // Check relations agree. 129 | for (const auto &[r, rm] : irm->relations) { 130 | auto rx = irx->relations.at(r); 131 | assert(rm->data == rx->data); 132 | assert(rm->data_r == rx->data_r); 133 | assert(rm->clusters.size() == rx->clusters.size()); 134 | for (const auto &[z, clusterm] : rm->clusters) { 135 | auto clusterx = rx->clusters.at(z); 136 | assert(clusterm->N == clusterx->N); 137 | } 138 | } 139 | } 140 | hirx.crp.alpha = hirm.crp.alpha; 141 | assert(abs(hirx.logp_score() - hirm.logp_score()) < 1e-8); 142 | } 143 | -------------------------------------------------------------------------------- /cxx/tests/test_irm_two_relations.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include "globals.hh" 15 | #include "hirm.hh" 16 | #include "util_io.hh" 17 | #include "util_math.hh" 18 | 19 | int main(int argc, char **argv) { 20 | string path_base = "assets/two_relations"; 21 | int seed = 1; 22 | int iters = 2; 23 | 24 | PRNG prng (seed); 25 | 26 | string path_schema = path_base + ".schema"; 27 | std::cout << "loading schema from " << path_schema << std::endl; 28 | auto schema = load_schema(path_schema); 29 | for (auto const &[relation, domains] : schema) { 30 | printf("relation: %s, ", relation.c_str()); 31 | printf("domains: "); 32 | for (auto const &domain : domains) { 33 | printf("%s ", domain.c_str()); 34 | } 35 | printf("\n"); 36 | } 37 | 38 | string 
path_obs = path_base + ".obs"; 39 | std::cout << "loading observations from " << path_obs << std::endl; 40 | auto observations = load_observations(path_obs); 41 | T_encoding encoding = encode_observations(schema, observations); 42 | 43 | IRM irm (schema, &prng); 44 | incorporate_observations(irm, encoding, observations); 45 | printf("running for %d iterations\n", iters); 46 | for (int i = 0; i < iters; i++) { 47 | irm.transition_cluster_assignments_all(); 48 | for (auto const &[d, domain] : irm.domains) { 49 | domain->crp.transition_alpha(); 50 | } 51 | double x = irm.logp_score(); 52 | printf("iter %d, score %f\n", i, x); 53 | } 54 | 55 | string path_clusters = path_base + ".irm"; 56 | std::cout << "writing clusters to " << path_clusters << std::endl; 57 | to_txt(path_clusters, irm, encoding); 58 | 59 | map> expected_p0 { 60 | {0, { {0, 1}, {10, 1}, {100, .5} } }, 61 | {10, { {0, 0}, {10, 0}, {100, .5} } }, 62 | {100, { {0, .66}, {10, .66}, {100, .5} } }, 63 | }; 64 | 65 | vector> indexes {{0, 10, 100}, {0, 10, 100}}; 66 | for (const auto &l : product(indexes)) { 67 | assert(l.size() == 2); 68 | auto x1 = l.at(0); 69 | auto x2 = l.at(1); 70 | auto p0 = irm.relations.at("R1")->logp({x1, x2}, 0); 71 | auto p0_irm = irm.logp({{"R1", {x1, x2}, 0}}); 72 | assert(abs(p0 - p0_irm) < 1e-10); 73 | auto p1 = irm.relations.at("R1")->logp({x1, x2}, 1); 74 | auto Z = logsumexp({p0, p1}); 75 | assert(abs(Z) < 1e-10); 76 | assert(abs(exp(p0) - expected_p0[x1].at(x2)) < .1); 77 | } 78 | 79 | for (const auto &l : vector> {{0, 10, 100}, {110, 10, 100}}) { 80 | auto x1 = l.at(0); 81 | auto x2 = l.at(1); 82 | auto x3 = l.at(2); 83 | auto p00 = irm.logp({{"R1", {x1, x2}, 0}, {"R1", {x1, x3}, 0}}); 84 | auto p01 = irm.logp({{"R1", {x1, x2}, 0}, {"R1", {x1, x3}, 1}}); 85 | auto p10 = irm.logp({{"R1", {x1, x2}, 1}, {"R1", {x1, x3}, 0}}); 86 | auto p11 = irm.logp({{"R1", {x1, x2}, 1}, {"R1", {x1, x3}, 1}}); 87 | auto Z = logsumexp({p00, p01, p10, p11}); 88 | assert(abs(Z) < 1e-10); 89 | } 
90 | 91 | IRM irx ({}, &prng); 92 | from_txt(&irx, path_schema, path_obs, path_clusters); 93 | // Check log scores agree. 94 | for (const auto &d : {"D1", "D2"}) { 95 | auto dm = irm.domains.at(d); 96 | auto dx = irx.domains.at(d); 97 | dx->crp.alpha = dm->crp.alpha; 98 | } 99 | assert(abs(irx.logp_score() - irm.logp_score()) < 1e-8); 100 | // Check domains agree. 101 | for (const auto &d : {"D1", "D2"}) { 102 | auto dm = irm.domains.at(d); 103 | auto dx = irx.domains.at(d); 104 | assert(dm->items == dx->items); 105 | assert(dm->crp.assignments == dx->crp.assignments); 106 | assert(dm->crp.tables == dx->crp.tables); 107 | assert(dm->crp.N == dx->crp.N); 108 | assert(dm->crp.alpha == dx->crp.alpha); 109 | } 110 | // Check relations agree. 111 | for (const auto &r : {"R1", "R2"}) { 112 | auto rm = irm.relations.at(r); 113 | auto rx = irx.relations.at(r); 114 | assert(rm->data == rx->data); 115 | assert(rm->data_r == rx->data_r); 116 | assert(rm->clusters.size() == rx->clusters.size()); 117 | for (const auto &[z, clusterm] : rm->clusters) { 118 | auto clusterx = rx->clusters.at(z); 119 | assert(clusterm->N == clusterx->N); 120 | } 121 | } 122 | } 123 | -------------------------------------------------------------------------------- /cxx/tests/test_misc.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | #include "hirm.hh" 17 | #include "util_hash.hh" 18 | #include "util_io.hh" 19 | #include "util_math.hh" 20 | 21 | int main(int argc, char **argv) { 22 | 23 | srand(1); 24 | PRNG prng (1); 25 | 26 | BetaBernoulli bb (&prng); 27 | bb.incorporate(1); 28 | bb.incorporate(1); 29 | printf("%f\n", exp(bb.logp(1))); 30 | for (int i = 0; i < 100; i++) { 31 | printf("%1.f ", 
bb.sample()); 32 | } 33 | printf("\n"); 34 | 35 | CRP crp (&prng); 36 | crp.alpha = 1.5; 37 | printf("starting crp\n"); 38 | T_item foo = 1, food = 2, sultan = 3, ali = 4; 39 | printf("%f\n", crp.logp_score()); 40 | crp.incorporate(foo, 1); 41 | printf("%f\n", crp.logp_score()); 42 | crp.incorporate(food, 1); 43 | printf("%f\n", crp.logp_score()); 44 | crp.incorporate(sultan, 12); 45 | printf("%f\n", crp.logp_score()); 46 | crp.incorporate(ali, 0); 47 | std::cout << "tables count 10 " << crp.tables.count(10) << std::endl; 48 | for (auto const &i : crp.tables[0]) { 49 | std::cout << i << " "; 50 | } 51 | for (auto const &i : crp.tables[1]) { 52 | std::cout << i << " "; 53 | } 54 | printf("\n"); 55 | std::cout << "assignments ali? " << crp.assignments.count(ali) << std::endl; 56 | crp.unincorporate(ali); 57 | std::cout << "assignments ali? " << crp.assignments.count(ali) << std::endl; 58 | printf("%f %d\n", crp.logp_score(), crp.assignments[-1]); 59 | 60 | printf("=== tables_weights\n"); 61 | auto tables_weights = crp.tables_weights(); 62 | for (auto &tw : tables_weights) { 63 | printf("table %d weight %f\n", tw.first, tw.second); 64 | } 65 | 66 | printf("=== tables_weights_gibbs\n"); 67 | auto tables_weights_gibbs = crp.tables_weights_gibbs(1); 68 | for (auto &tw : tables_weights_gibbs) { 69 | printf("table %d weight %f\n", tw.first, tw.second); 70 | } 71 | printf("==== tables_weights_gibbs_singleton\n"); 72 | auto tables_weights_gibbs_singleton = crp.tables_weights_gibbs(12); 73 | for (auto &tw : tables_weights_gibbs_singleton) { 74 | printf("table %d weight %f\n", tw.first, tw.second); 75 | } 76 | printf("==== log probability\n"); 77 | printf("%f\n", crp.logp(0)); 78 | 79 | printf("=== DOMAIN === \n"); 80 | Domain d ("foo", &prng); 81 | string relation1 = "ali"; 82 | string relation2 = "mubarak"; 83 | T_item salman = 1; 84 | T_item mansour = 2; 85 | d.incorporate(salman); 86 | for (auto &item : d.items) { 87 | printf("item %d: ", item); 88 | } 89 | 
d.set_cluster_assignment_gibbs(salman, 12); 90 | d.incorporate(salman); 91 | d.incorporate(mansour, 5); 92 | for (auto &item : d.items) { 93 | printf("item %d: ", item); 94 | } 95 | // d.unincorporate(salman); 96 | for (auto &item : d.items) { 97 | printf("item %d: ", item); 98 | } 99 | // d.unincorporate(relation2, salman); 100 | // assert (d.items.size() == 0); 101 | // d.items[01].insert("foo"); 102 | 103 | umap> m; 104 | m[1].insert(10); 105 | m[1] = uset(); 106 | for (auto &ir: m) { 107 | printf("%d\n", ir.first); 108 | for (auto &x : ir.second) { 109 | printf("%d\n", x); 110 | } 111 | } 112 | 113 | printf("== RELATION == \n"); 114 | Domain D1 ("D1", &prng); 115 | Domain D2 ("D2", &prng); 116 | Domain D3 ("D3", &prng); 117 | D1.incorporate(0); 118 | D2.incorporate(1); 119 | D3.incorporate(3); 120 | Relation R1 ("R1", {&D1, &D2, &D3}, &prng); 121 | printf("arity %ld\n", R1.domains.size()); 122 | R1.incorporate({0, 1, 3}, 1); 123 | R1.incorporate({1, 1, 3}, 1); 124 | R1.incorporate({3, 1, 3}, 1); 125 | R1.incorporate({4, 1, 3}, 1); 126 | R1.incorporate({5, 1, 3}, 1); 127 | R1.incorporate({0, 1, 4}, 0); 128 | R1.incorporate({0, 1, 6}, 1); 129 | auto z1 = R1.get_cluster_assignment({0, 1, 3}); 130 | for (int x : z1) { 131 | printf("%d,", x); 132 | } 133 | auto z2 = R1.get_cluster_assignment_gibbs({0, 1, 3}, D2, 1, 191); 134 | printf("\n"); 135 | for (int x : z2) { 136 | printf("%d,", x); 137 | } 138 | printf("\n"); 139 | 140 | double lpg = R1.logp_gibbs_approx(D1, 0, 1); 141 | printf("logp gibbs %f\n", lpg); 142 | lpg = R1.logp_gibbs_approx(D1, 0, 0); 143 | printf("logp gibbs %f\n", lpg); 144 | lpg = R1.logp_gibbs_approx(D1, 0, 10); 145 | printf("logp gibbs %f\n", lpg); 146 | 147 | printf("calling set_cluster_assignment_gibbs\n"); 148 | R1.set_cluster_assignment_gibbs(D1, 0, 1); 149 | printf("new cluster %d\n", D1.get_cluster_assignment(0)); 150 | D1.set_cluster_assignment_gibbs(0, 1); 151 | 152 | printf("%lu\n", R1.data.size()); 153 | // R1.unincorporate({0, 1, 
3}); 154 | printf("%lu\n", R1.data.size()); 155 | 156 | printf("== HASHING UTIL == \n"); 157 | std::unordered_map, int, VectorIntHash> map_int; 158 | map_int[{1, 2}] = 7; 159 | printf("%d\n", map_int.at({1,2})); 160 | std::unordered_map, int, VectorStringHash> map_str; 161 | map_str[{"1", "2", "3"}] = 7; 162 | printf("%d\n", map_str.at({"1","2", "3"})); 163 | 164 | 165 | printf("===== IRM ====\n"); 166 | map> schema1 { 167 | {"R1", {"D1", "D1"}}, 168 | {"R2", {"D1", "D2"}}, 169 | {"R3", {"D3", "D1"}}, 170 | }; 171 | IRM irm(schema1, &prng); 172 | 173 | for (auto const &kv : irm.domains) { 174 | printf("%s %s; ", kv.first.c_str(), kv.second->name.c_str()); 175 | for (auto const r : irm.domain_to_relations.at(kv.first)) { 176 | printf("%s ", r.c_str()); 177 | } 178 | printf("\n"); 179 | } 180 | for (auto const &kv : irm.relations) { 181 | printf("%s ", kv.first.c_str()); 182 | for (auto const d : kv.second->domains) { 183 | printf("%s ", d->name.c_str()); 184 | } 185 | printf("\n"); 186 | } 187 | 188 | printf("==== READING IO ===== \n"); 189 | auto schema = load_schema("assets/animals.binary.schema"); 190 | for (auto const &i : schema) { 191 | printf("relation: %s\n", i.first.c_str()); 192 | printf("domains: "); 193 | for (auto const &j : i.second) { 194 | printf("%s ", j.c_str()); 195 | } 196 | printf("\n"); 197 | } 198 | 199 | IRM irm3 (schema, &prng); 200 | auto observations = load_observations("assets/animals.binary.obs"); 201 | auto encoding = encode_observations(schema, observations); 202 | auto item_to_code = std::get<0>(encoding); 203 | for (auto const &i : observations) { 204 | auto relation = std::get<0>(i); 205 | auto value = std::get<2>(i); 206 | auto item = std::get<1>(i); 207 | printf("incorporating %s ", relation.c_str()); 208 | printf("%1.f ", value); 209 | int counter = 0; 210 | T_items items_code; 211 | for (auto const &item : std::get<1>(i)) { 212 | auto domain = schema.at(relation)[counter]; 213 | counter += 1; 214 | auto code = 
item_to_code.at(domain).at(item); 215 | printf("%s(%d) ", item.c_str(), code); 216 | items_code.push_back(code); 217 | } 218 | printf("\n"); 219 | irm3.incorporate(relation, items_code, value); 220 | } 221 | 222 | for (int i = 0; i < 4; i++) { 223 | irm3.transition_cluster_assignments({"animal", "feature"}); 224 | irm3.transition_cluster_assignments_all(); 225 | for (auto const &[d, domain]: irm3.domains) { 226 | domain->crp.transition_alpha(); 227 | } 228 | double x = irm3.logp_score(); 229 | printf("iter %d, score %f\n", i, x); 230 | } 231 | 232 | string path_clusters = "assets/animals.binary.irm"; 233 | to_txt(path_clusters, irm3, encoding); 234 | 235 | auto rel = irm3.relations.at("has"); 236 | auto &enc = std::get<0>(encoding); 237 | auto lp0 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 0); 238 | auto lp1 = rel->logp({enc["animal"]["tail"], enc["animal"]["bat"]}, 1); 239 | auto lp_01 = logsumexp({lp0, lp1}); 240 | assert(abs(lp_01) < 1e-5); 241 | printf("log prob of has(tail, bat)=0 is %1.2f\n", lp0); 242 | printf("log prob of has(tail, bat)=1 is %1.2f\n", lp1); 243 | printf("logsumexp is %1.2f\n", lp_01); 244 | 245 | IRM irm4 ({}, &prng); 246 | from_txt(&irm4, 247 | "assets/animals.binary.schema", 248 | "assets/animals.binary.obs", 249 | path_clusters); 250 | irm4.domains.at("animal")->crp.alpha = irm3.domains.at("animal")->crp.alpha; 251 | irm4.domains.at("feature")->crp.alpha = irm3.domains.at("feature")->crp.alpha; 252 | assert(abs(irm3.logp_score() - irm4.logp_score()) < 1e-8); 253 | for (const auto &d : {"animal", "feature"}) { 254 | auto d3 = irm3.domains.at(d); 255 | auto d4 = irm4.domains.at(d); 256 | assert(d3->items == d4->items); 257 | assert(d3->crp.assignments == d4->crp.assignments); 258 | assert(d3->crp.tables == d4->crp.tables); 259 | assert(d3->crp.N == d4->crp.N); 260 | assert(d3->crp.alpha == d4->crp.alpha); 261 | } 262 | for (const auto &r : {"has"}) { 263 | auto r3 = irm3.relations.at(r); 264 | auto r4 = 
irm4.relations.at(r); 265 | assert(r3->data == r4->data); 266 | assert(r3->data_r == r4->data_r); 267 | assert(r3->clusters.size() == r4->clusters.size()); 268 | for (const auto &[z, cluster3] : r3->clusters) { 269 | auto cluster4 = r4->clusters.at(z); 270 | assert(cluster3->N == cluster4->N); 271 | } 272 | } 273 | } 274 | -------------------------------------------------------------------------------- /cxx/tests/test_util_math.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | #include 7 | 8 | #include "util_math.hh" 9 | 10 | int main(int argc, char **argv) { 11 | vector> x {{1}, {2, 3}, {1, 10, 11}}; 12 | 13 | auto cartesian = product(x); 14 | assert(cartesian.size() == 6); 15 | assert((cartesian.at(0) == vector{1, 2, 1})); 16 | assert((cartesian.at(1) == vector{1, 2, 10})); 17 | assert((cartesian.at(2) == vector{1, 2, 11})); 18 | assert((cartesian.at(3) == vector{1, 3, 1})); 19 | assert((cartesian.at(4) == vector{1, 3, 10})); 20 | assert((cartesian.at(5) == vector{1, 3, 11})); 21 | 22 | x.push_back({}); 23 | cartesian = product(x); 24 | assert(cartesian.size() == 0); 25 | } 26 | -------------------------------------------------------------------------------- /cxx/util_hash.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | // Hash functions for std:: 5 | // https://stackoverflow.com/a/27216842 6 | 7 | #pragma once 8 | 9 | #include 10 | #include 11 | 12 | using std::vector; 13 | using std::string; 14 | 15 | struct VectorIntHash { 16 | int operator()(const vector &V) const { 17 | int hash = V.size(); 18 | for(auto &i : V) { 19 | hash ^= i + 0x9e3779b9 + (hash << 6) + (hash >> 2); 20 | } 21 | return hash; 22 | } 23 | }; 24 | 25 | 
struct VectorStringHash { 26 | int operator()(const vector &V) const { 27 | int hash = V.size(); 28 | for(auto &s : V) { 29 | hash ^= std::hash{}(s) + 0x9e3779b9 + (hash << 6) + (hash >> 2); 30 | } 31 | return hash; 32 | } 33 | }; 34 | -------------------------------------------------------------------------------- /cxx/util_io.cc: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "util_io.hh" 10 | 11 | T_schema load_schema(const string &path) { 12 | std::ifstream fp (path, std::ifstream::in); 13 | assert(fp.good()); 14 | 15 | map> schema; 16 | string line; 17 | while (std::getline(fp, line)) { 18 | std::istringstream stream (line); 19 | 20 | string dist; 21 | string relname; 22 | vector domains; 23 | 24 | stream >> dist; 25 | stream >> relname; 26 | for (string w; stream >> w; ) { 27 | domains.push_back(w); 28 | } 29 | assert(domains.size() > 0); 30 | schema[relname] = domains; 31 | } 32 | fp.close(); 33 | return schema; 34 | } 35 | 36 | T_observations load_observations(const string &path) { 37 | std::ifstream fp (path, std::ifstream::in); 38 | assert(fp.good()); 39 | 40 | vector, double>> observations; 41 | string line; 42 | while (std::getline(fp, line)) { 43 | std::istringstream stream (line); 44 | 45 | double value; 46 | string relname; 47 | vector items; 48 | 49 | stream >> value; 50 | stream >> relname; 51 | for (string w; stream >> w; ) { 52 | items.push_back(w); 53 | } 54 | assert(items.size() > 0); 55 | auto entry= std::make_tuple(relname, items, value); 56 | observations.push_back(entry); 57 | } 58 | fp.close(); 59 | return observations; 60 | } 61 | 62 | // Assumes that T_item is integer. 63 | T_encoding encode_observations(const T_schema &schema, 64 | const T_observations &observations) { 65 | // Counter and encoding maps. 
66 | map domain_item_counter; 67 | T_encoding_f item_to_code; 68 | T_encoding_r code_to_item; 69 | // Create a counter of items for each domain. 70 | for (const auto &[r, domains]: schema) { 71 | for (const auto &domain : domains) { 72 | domain_item_counter[domain] = 0; 73 | item_to_code[domain] = map(); 74 | code_to_item[domain] = map(); 75 | } 76 | } 77 | // Create the codes for each item. 78 | for (const auto &i : observations) { 79 | auto relation = std::get<0>(i); 80 | auto items = std::get<1>(i); 81 | int counter = 0; 82 | for (const auto &item : items) { 83 | // Obtain domain that item belongs to. 84 | auto domain = schema.at(relation).at(counter); 85 | counter += 1; 86 | // Compute its code, if necessary. 87 | if (item_to_code.at(domain).count(item) == 0) { 88 | int code = domain_item_counter[domain]; 89 | item_to_code[domain][item] = code; 90 | code_to_item[domain][code] = item; 91 | domain_item_counter[domain]++; 92 | } 93 | } 94 | } 95 | return std::make_pair(item_to_code, code_to_item); 96 | } 97 | 98 | void incorporate_observations(IRM &irm, const T_encoding &encoding, 99 | const T_observations &observations) { 100 | auto item_to_code = std::get<0>(encoding); 101 | for (const auto &[relation, items, value] : observations) { 102 | int counter = 0; 103 | T_items items_e; 104 | for (const auto &item : items) { 105 | auto domain = irm.schema.at(relation)[counter]; 106 | counter += 1; 107 | int code = item_to_code.at(domain).at(item); 108 | items_e.push_back(code); 109 | } 110 | irm.incorporate(relation, items_e, value); 111 | } 112 | } 113 | 114 | void incorporate_observations(HIRM &hirm, const T_encoding &encoding, 115 | const T_observations &observations) { 116 | int j = 0; 117 | auto item_to_code = std::get<0>(encoding); 118 | for (const auto &[relation, items, value] : observations) { 119 | int counter = 0; 120 | T_items items_e; 121 | for (const auto &item : items) { 122 | auto domain = hirm.schema.at(relation)[counter]; 123 | counter += 1; 124 | int 
code = item_to_code.at(domain).at(item); 125 | items_e.push_back(code); 126 | } 127 | hirm.incorporate(relation, items_e, value); 128 | } 129 | } 130 | 131 | void to_txt(std::ostream &fp, const IRM &irm, const T_encoding &encoding) { 132 | auto code_to_item = std::get<1>(encoding); 133 | for (const auto &[d, domain]: irm.domains) { 134 | auto i0 = domain->crp.tables.begin(); 135 | auto i1 = domain->crp.tables.end(); 136 | map> tables (i0, i1); 137 | for (const auto &[table, items] : tables) { 138 | fp << domain->name << " "; 139 | fp << table << " "; 140 | int i = 1; 141 | for (const auto &item : items) { 142 | fp << code_to_item.at(domain->name).at(item); 143 | if (i++ < items.size()) { 144 | fp << " "; 145 | } 146 | } 147 | fp << "\n"; 148 | } 149 | } 150 | } 151 | 152 | void to_txt(std::ostream &fp, const HIRM &hirm, const T_encoding &encoding){ 153 | // Write the relation clusters. 154 | auto i0 = hirm.crp.tables.begin(); 155 | auto i1 = hirm.crp.tables.end(); 156 | map> tables (i0, i1); 157 | for (const auto &[table, rcs] : tables) { 158 | fp << table << " "; 159 | int i = 1; 160 | for (const auto rc : rcs) { 161 | fp << hirm.code_to_relation.at(rc); 162 | if (i ++ < rcs.size()) { 163 | fp << " "; 164 | } 165 | } 166 | fp << "\n"; 167 | } 168 | fp << "\n"; 169 | // Write the IRMs. 
170 | int j = 0; 171 | for (const auto &[table, rcs] : tables) { 172 | const auto &irm = hirm.irms.at(table); 173 | fp << "irm=" << table << "\n"; 174 | to_txt(fp, *irm, encoding); 175 | if (j < tables.size() - 1) { 176 | fp << "\n"; 177 | j += 1; 178 | } 179 | } 180 | } 181 | 182 | void to_txt(const string &path, const IRM &irm, const T_encoding &encoding) { 183 | std::ofstream fp (path); 184 | assert(fp.good()); 185 | to_txt(fp, irm, encoding); 186 | fp.close(); 187 | } 188 | 189 | void to_txt(const string &path, const HIRM &hirm, const T_encoding &encoding) { 190 | std::ofstream fp (path); 191 | assert(fp.good()); 192 | to_txt(fp, hirm, encoding); 193 | fp.close(); 194 | } 195 | 196 | map>> 197 | load_clusters_irm(const string &path) { 198 | std::ifstream fp (path, std::ifstream::in); 199 | assert(fp.good()); 200 | 201 | map>> clusters; 202 | string line; 203 | while (std::getline(fp, line)) { 204 | std::istringstream stream (line); 205 | 206 | string domain; 207 | int table; 208 | vector items; 209 | 210 | stream >> domain; 211 | stream >> table; 212 | for (string w; stream >> w; ) { 213 | items.push_back(w); 214 | } 215 | assert(items.size() > 0); 216 | assert(clusters[domain].count(table) == 0); 217 | clusters[domain][table] = items; 218 | } 219 | fp.close(); 220 | return clusters; 221 | } 222 | 223 | 224 | int isnumeric(const std::string & s) { 225 | for (char c : s) { if (!isdigit(c)) { return false; } } 226 | return !s.empty() && true; 227 | } 228 | 229 | 230 | tuple< 231 | map>, // x[table] = {relation list} 232 | map>>> // x[table][domain][table] = {item list} 233 | > 234 | load_clusters_hirm(const string &path) { 235 | std::ifstream fp (path, std::ifstream::in); 236 | assert(fp.good()); 237 | 238 | map> relations; 239 | map>>> irms; 240 | 241 | string line; 242 | int irmc = 0; 243 | 244 | while (std::getline(fp, line)) { 245 | std::istringstream stream (line); 246 | 247 | string first; 248 | stream >> first; 249 | 250 | // Parse a relation cluster. 
251 | if (isnumeric(first)) { 252 | int table = std::stoi(first); 253 | vector<string> items; 254 | for (string item; stream >> item; ) { 255 | items.push_back(item); 256 | } 257 | assert(items.size() > 0); 258 | assert(relations.count(table) == 0); 259 | relations[table] = items; 260 | continue; 261 | } 262 | 263 | // Skip a new line. 264 | if (first.size() == 0) { 265 | irmc = -1; 266 | continue; 267 | } 268 | 269 | // Parse an irm= line. 270 | if (first.rfind("irm=", 0) == 0) { 271 | assert(irmc == -1); // BUGFIX: was `irmc = -1` (assignment), which always passed and clobbered irmc. 272 | assert(first.size() > 4); 273 | auto x = first.substr(4); 274 | irmc = std::stoi(x); 275 | assert(irms.count(irmc) == 0); 276 | irms[irmc] = {}; 277 | continue; 278 | } 279 | 280 | // Parse a domain cluster. 281 | assert(irmc > -1); 282 | assert(irms.count(irmc) == 1); 283 | string second; 284 | stream >> second; 285 | assert(second.size() > 0); 286 | assert(isnumeric(second)); 287 | auto &domain = first; 288 | auto table = std::stoi(second); 289 | vector<string> items; 290 | for (string item; stream >> item; ) { 291 | items.push_back(item); 292 | } 293 | assert(items.size() > 0); 294 | if (irms.at(irmc).count(domain) == 0) { 295 | irms.at(irmc)[domain] = {}; 296 | } 297 | assert(irms.at(irmc).at(domain).count(table) == 0); 298 | irms.at(irmc).at(domain)[table] = items; 299 | } 300 | 301 | assert(relations.size() == irms.size()); 302 | for (const auto &[t, rs] : relations) { 303 | assert(irms.count(t) == 1); 304 | } 305 | fp.close(); 306 | return std::make_pair(relations, irms); 307 | } 308 | 309 | void from_txt(IRM * const irm, 310 | const string &path_schema, 311 | const string &path_obs, 312 | const string &path_clusters) { 313 | // Load the data. 314 | auto schema = load_schema(path_schema); 315 | auto observations = load_observations(path_obs); 316 | auto encoding = encode_observations(schema, observations); 317 | auto clusters = load_clusters_irm(path_clusters); 318 | // Add the relations.
319 | assert(irm->schema.size() == 0); 320 | assert(irm->domains.size() == 0); 321 | assert(irm->relations.size() == 0); 322 | assert(irm->domain_to_relations.size() == 0); 323 | for (const auto &[r, ds] : schema) { 324 | irm->add_relation(r, ds); 325 | } 326 | // Add the domain entities with fixed clustering. 327 | T_encoding_f item_to_code = std::get<0>(encoding); 328 | for (const auto &[domain, tables] : clusters) { 329 | assert(irm->domains.at(domain)->items.size() == 0); 330 | for (const auto &[table, items] : tables) { 331 | assert(0 <= table); 332 | for (const auto &item : items) { 333 | auto code = item_to_code.at(domain).at(item); 334 | irm->domains.at(domain)->incorporate(code, table); 335 | } 336 | } 337 | } 338 | // Add the observations. 339 | incorporate_observations(*irm, encoding, observations); 340 | } 341 | 342 | void from_txt(HIRM * const hirm, 343 | const string &path_schema, 344 | const string &path_obs, 345 | const string &path_clusters) { 346 | auto schema = load_schema(path_schema); 347 | auto observations = load_observations(path_obs); 348 | auto encoding = encode_observations(schema, observations); 349 | auto [relations, irms] = load_clusters_hirm(path_clusters); 350 | // Add the relations. 351 | assert(hirm->schema.size() == 0); 352 | assert(hirm->irms.size() == 0); 353 | assert(hirm->relation_to_code.size() == 0); 354 | assert(hirm->code_to_relation.size() == 0); 355 | for (const auto &[r, ds] : schema) { 356 | hirm->add_relation(r, ds); 357 | assert(hirm->irms.size() == hirm->crp.tables.size()); 358 | hirm->set_cluster_assignment_gibbs(r, -1); 359 | } 360 | // Add each IRM. 361 | for (const auto &[table, rs] : relations) { 362 | assert(hirm->irms.size() == hirm->crp.tables.size()); 363 | // Add relations to the IRM. 
364 | for (const auto &r : rs) { 365 | assert(hirm->irms.size() == hirm->crp.tables.size()); 366 | auto table_current = hirm->relation_to_table(r); 367 | if (table_current != table) { 368 | assert(hirm->irms.size() == hirm->crp.tables.size()); 369 | hirm->set_cluster_assignment_gibbs(r, table); 370 | } 371 | } 372 | // Add the domain entities with fixed clustering to this IRM. 373 | // TODO: Duplicated code with from_txt(IRM) 374 | auto irm = hirm->irms.at(table); 375 | auto clusters = irms.at(table); 376 | assert(irm->relations.size() == rs.size()); 377 | T_encoding_f item_to_code = std::get<0>(encoding); 378 | for (const auto &[domain, tables] : clusters) { 379 | assert(irm->domains.at(domain)->items.size() == 0); 380 | for (const auto &[t, items] : tables) { 381 | assert(0 <= t); 382 | for (const auto &item : items) { 383 | auto code = item_to_code.at(domain).at(item); 384 | irm->domains.at(domain)->incorporate(code, t); 385 | } 386 | } 387 | } 388 | } 389 | assert(hirm->irms.count(-1) == 0); 390 | // Add the observations. 
391 | incorporate_observations(*hirm, encoding, observations); 392 | } 393 | -------------------------------------------------------------------------------- /cxx/util_io.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2020 2 | // See LICENSE.txt 3 | 4 | #pragma once 5 | 6 | #include "globals.hh" 7 | #include "hirm.hh" 8 | 9 | typedef map> T_encoding_f; 10 | typedef map> T_encoding_r; 11 | typedef tuple T_encoding; 12 | 13 | typedef tuple, double> T_observation; 14 | typedef vector T_observations; 15 | 16 | typedef umap T_assignment; 17 | typedef umap T_assignments; 18 | 19 | // disk IO 20 | T_schema load_schema(const string &path); 21 | T_observations load_observations(const string &path); 22 | T_encoding encode_observations(const T_schema &schema, const T_observations &observations); 23 | 24 | void incorporate_observations(IRM &irm, const T_encoding &encoding, 25 | const T_observations &observations); 26 | void incorporate_observations(HIRM &hirm, const T_encoding &encoding, 27 | const T_observations &observations); 28 | 29 | void to_txt(const string &path, const IRM &irm, const T_encoding &encoding); 30 | void to_txt(const string &path, const HIRM &irm, const T_encoding &encoding); 31 | void to_txt(std::ostream &fp, const IRM &irm, const T_encoding &encoding); 32 | void to_txt(std::ostream &fp, const HIRM &irm, const T_encoding &encoding); 33 | 34 | map>> load_clusters_irm(const string &path); 35 | tuple< 36 | map>, // x[table] = {relation list} 37 | map>>> // x[table][domain][table] = {item list} 38 | > 39 | load_clusters_hirm(const string &path); 40 | 41 | void from_txt(IRM * const irm, const string &path_schema, const string &path_obs, const string &path_clusters); 42 | void from_txt(HIRM * const irm, const string &path_schema, const string &path_obs, const string &path_clusters); 43 | -------------------------------------------------------------------------------- /cxx/util_math.cc: 
-------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #include 5 | 6 | #include "util_math.hh" 7 | 8 | using std::vector; 9 | 10 | const double INF = std::numeric_limits::infinity(); 11 | 12 | // http://matlab.izmiran.ru/help/techdoc/ref/betaln.html 13 | double lbeta(int z, int w) { 14 | return lgamma(z) + lgamma(w) - lgamma(z + w); 15 | } 16 | 17 | vector linspace(double start, double stop, int num, bool endpoint) { 18 | double step = (stop - start) / (num - endpoint); 19 | vector v; 20 | for (int i = 0; i < num; i++) { 21 | v.push_back(start + step * i); 22 | } 23 | return v; 24 | } 25 | 26 | vector log_linspace(double start, double stop, int num, bool endpoint) { 27 | auto v = linspace(log(start), log(stop), num, endpoint); 28 | for (int i = 0; i < v.size(); i++) { 29 | v[i] = exp(v[i]); 30 | } 31 | return v; 32 | } 33 | 34 | vector log_normalize(const std::vector &weights){ 35 | double Z = logsumexp(weights); 36 | vector result(weights.size()); 37 | for (int i = 0; i < weights.size(); i++) { 38 | result[i] = weights[i] - Z; 39 | } 40 | return result; 41 | } 42 | 43 | double logsumexp(const vector &weights) { 44 | double m = *std::max_element(weights.begin(), weights.end()); 45 | double s = 0; 46 | for (auto w : weights) { 47 | s += exp(w - m); 48 | } 49 | return log(s) + m; 50 | } 51 | 52 | int choice(const std::vector &weights, PRNG *prng) { 53 | std::discrete_distribution dist(weights.begin(), weights.end()); 54 | int idx = dist(*prng); 55 | return idx; 56 | } 57 | 58 | int log_choice(const std::vector &weights, PRNG *prng) { 59 | vector log_weights_norm = log_normalize(weights); 60 | vector weights_norm; 61 | for (double w : log_weights_norm) { 62 | weights_norm.push_back(exp(w)); 63 | } 64 | return choice(weights_norm, prng); 65 | } 66 | 67 | vector> product(const vector> &lists) { 68 | // 
https://rosettacode.org/wiki/Cartesian_product_of_two_or_more_lists#C.2B.2B 69 | vector> result; 70 | for (const auto &l : lists) { 71 | if (l.size() == 0) { 72 | return result; 73 | } 74 | } 75 | for (const auto &e : lists[0]) { 76 | result.push_back({e}); 77 | } 78 | for (size_t i = 1; i < lists.size(); ++i) { 79 | vector> temp; 80 | for (auto &e : result) { 81 | for (auto f : lists[i]) { 82 | auto e_tmp = e; 83 | e_tmp.push_back(f); 84 | temp.push_back(e_tmp); 85 | } 86 | } 87 | result = temp; 88 | } 89 | return result; 90 | } 91 | -------------------------------------------------------------------------------- /cxx/util_math.hh: -------------------------------------------------------------------------------- 1 | // Copyright 2021 MIT Probabilistic Computing Project 2 | // Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | #pragma once 5 | 6 | #include "globals.hh" 7 | 8 | double lbeta(int z, int w); 9 | 10 | vector linspace(double start, double stop, int num, bool endpoint); 11 | vector log_linspace(double start, double stop, int num, bool endpoint); 12 | vector log_normalize(const vector &weights); 13 | double logsumexp(const vector &weights); 14 | 15 | int choice(const vector &weights, PRNG *prng); 16 | int log_choice(const vector &weights, PRNG *prng); 17 | 18 | vector> product(const vector> &lists); 19 | -------------------------------------------------------------------------------- /examples/.gitignore: -------------------------------------------------------------------------------- 1 | *.hirm 2 | *.irm 3 | -------------------------------------------------------------------------------- /examples/animals_binary_irm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | from pprint import pprint 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | from hirm import IRM 
12 | from hirm.util_io import load_schema 13 | from hirm.util_io import load_observations 14 | from hirm.util_io import to_txt_irm 15 | from hirm.util_plot import plot_binary_relation 16 | 17 | dirname = os.path.dirname(os.path.abspath(__file__)) 18 | path_schema = os.path.join(dirname, 'datasets', 'animals.binary.schema') 19 | path_obs = os.path.join(dirname, 'datasets', 'animals.binary.obs') 20 | schema = load_schema(path_schema) 21 | data = load_observations(path_obs) 22 | 23 | prng = random.Random(12) 24 | irm = IRM(schema, prng=prng) 25 | for relation, items, value in data: 26 | irm.incorporate(relation, items, value) 27 | 28 | for i in range(20): 29 | irm.transition_cluster_assignments() 30 | print(i, irm.logp_score()) 31 | pprint(irm.domains['animal'].crp.tables) 32 | pprint(irm.domains['feature'].crp.tables) 33 | 34 | fig, ax = plot_binary_relation(irm.relations['has'], transpose=True) 35 | plt.show() 36 | fig.set_tight_layout(True) 37 | path_figure = os.path.join('assets', 'animals.binary.irm.png') 38 | fig.savefig(path_figure) 39 | 40 | path_clusters = os.path.join('assets', 'animals.binary.irm') 41 | to_txt_irm(path_clusters, irm) 42 | print(path_clusters) 43 | -------------------------------------------------------------------------------- /examples/animals_unary_hirm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | import matplotlib.pyplot as plt 8 | 9 | from hirm import HIRM 10 | from hirm.util_io import load_schema 11 | from hirm.util_io import load_observations 12 | from hirm.util_io import to_txt_hirm 13 | 14 | from hirm.util_plot import plot_hirm_crosscat 15 | 16 | dirname = os.path.dirname(os.path.abspath(__file__)) 17 | path_schema = os.path.join(dirname, 'datasets', 'animals.unary.schema') 18 | path_obs = os.path.join(dirname, 'datasets', 
'animals.unary.obs') 19 | schema = load_schema(path_schema) 20 | data = load_observations(path_obs) 21 | features = [r for r in schema if len(schema[r]) == 1] 22 | 23 | prng = random.Random(12) 24 | hirm = HIRM(schema, prng=prng) 25 | for relation, items, value in data: 26 | print(relation, items, value) 27 | hirm.incorporate(relation, items, value) 28 | 29 | fig, ax = plot_hirm_crosscat(hirm, features) 30 | fig.set_tight_layout(True) 31 | 32 | print(hirm.logp_score()) 33 | for i in range(10): 34 | hirm.transition_cluster_assignments() 35 | for irm in hirm.irms.values(): 36 | irm.transition_cluster_assignments() 37 | print(i, hirm.logp_score(), [len(c) for c in hirm.crp.tables.values()]) 38 | 39 | fig, ax = plot_hirm_crosscat(hirm, features) 40 | plt.show() 41 | fig.set_tight_layout(True) 42 | 43 | path_fig = os.path.join('assets', 'animals.unary.hirm.png') 44 | fig.savefig(path_fig) 45 | print(path_fig) 46 | 47 | path_clusters = os.path.join('assets', 'animals.unary.hirm') 48 | to_txt_hirm(path_clusters, hirm) 49 | print(path_clusters) 50 | -------------------------------------------------------------------------------- /examples/animals_unary_irm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | from pprint import pprint 8 | 9 | import matplotlib.pyplot as plt 10 | 11 | from hirm import IRM 12 | from hirm.util_io import load_schema 13 | from hirm.util_io import load_observations 14 | from hirm.util_io import to_txt_irm 15 | from hirm.util_plot import plot_unary_relations 16 | 17 | dirname = os.path.dirname(os.path.abspath(__file__)) 18 | path_schema = os.path.join(dirname, 'datasets', 'animals.unary.schema') 19 | path_obs = os.path.join(dirname, 'datasets', 'animals.unary.obs') 20 | schema = load_schema(path_schema) 21 | data = load_observations(path_obs) 22 | 23 | prng = 
random.Random(13412) 24 | irm = IRM(schema, prng=prng) 25 | for relation, items, value in data: 26 | irm.incorporate(relation, items, value) 27 | 28 | for i in range(20): 29 | irm.transition_cluster_assignments() 30 | print(i, irm.logp_score()) 31 | pprint(irm.domains['animal'].crp.tables) 32 | 33 | fig, ax = plot_unary_relations(list(irm.relations.values())) 34 | plt.show() 35 | fig.set_tight_layout(True) 36 | 37 | path_fig = os.path.join('assets', 'animals.unary.irm.png') 38 | fig.savefig(path_fig) 39 | print(path_fig) 40 | 41 | path_clusters = os.path.join('assets', 'animals.unary.irm') 42 | to_txt_irm(path_clusters, irm) 43 | print(path_clusters) 44 | -------------------------------------------------------------------------------- /examples/assets/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | -------------------------------------------------------------------------------- /examples/datasets/50animalbindat.csv: -------------------------------------------------------------------------------- 1 | ,black,white,blue,brown,gray,orange,red,yellow,patches,spots,stripes,furry,hairless,toughskin,big,small,bulbous,lean,flippers,hands,hooves,pads,paws,longleg,longneck,tail,chewteeth,meatteeth,buckteeth,strainteeth,horns,claws,tusks,smelly,flys,hops,swims,tunnels,walks,fast,slow,strong,weak,muscle,bipedal,quadrapedal,active,inactive,nocturnal,hibernate,agility,fish,meat,plankton,vegetation,insects,forager,grazer,hunter,scavenger,skimmer,stalker,newworld,oldworld,arctic,coastal,desert,bush,plains,forest,fields,jungle,mountains,ocean,ground,water,tree,cave,fierce,timid,smart,group,solitary,nestspot,domestic 2 | antelope,0,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,1,0,0,1,0,0,1,0,1,1,0,0,0,1,0,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,0,1,0,1,0,1,0,1,0,0,0,0,1,0,1,0,0,0 3 | grizzly 
bear,1,0,0,1,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,0,0,1,0,0,0,1,1,0,0,0,1,0,0,0,0,0,0,1,1,1,1,0,1,1,1,1,1,1,1,0,1,1,0,0,0,1,0,1,0,0,1,1,0,0,0,0,0,0,1,0,0,1,0,1,0,0,1,1,0,1,0,1,0,0 4 | killer whale,1,1,0,0,0,0,0,0,1,1,0,0,1,1,1,0,1,1,1,0,0,0,0,0,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,1,0,1,0,0,0,0,1,0,0,0,1,1,1,1,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,0,0,1,0,1,0,0,1,0,1,1,0,0,0 5 | beaver,0,0,0,1,0,0,0,0,0,0,0,1,0,1,0,1,1,0,0,0,0,1,1,0,0,1,1,0,1,1,0,1,0,0,0,0,1,0,0,1,0,1,0,1,0,1,1,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,1,0,0,0,1,1,1,1,1,0 6 | dalmatian,1,1,0,0,0,0,0,0,1,1,0,1,1,0,1,0,0,1,0,0,0,0,1,1,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,1,1,0,1 7 | persian cat,0,1,1,0,1,0,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,1,1,1,0,1,0,0,1,0,1,0,0,1,1,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,1 8 | horse,1,1,0,1,1,0,0,0,1,0,0,1,0,1,1,0,0,1,0,0,1,0,0,1,1,1,1,0,1,0,0,0,0,1,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,1,0,0,0,0,1,1,1,0,0,1 9 | german shepherd,1,0,0,1,1,0,0,0,1,0,0,1,0,0,1,0,0,1,0,0,0,1,1,1,0,1,1,1,0,0,0,1,0,1,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,1,0,1,0,1 10 | blue whale,0,0,1,0,1,0,0,0,0,1,0,0,1,1,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,1,1,1,0,1,0,0,0,1,0,0,0,1,0,1,0,0,0,0,0,0,1,0,1,1,1,0,0,0,0,0,0,0,0,1,0,1,0,0,0,1,1,1,1,0,0 11 | siamese cat,1,1,0,1,1,0,0,0,1,0,0,1,0,0,0,1,0,1,0,0,0,1,1,1,0,1,1,1,0,0,0,1,0,0,0,0,0,0,1,1,0,0,1,1,0,1,1,1,1,0,1,1,1,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,1,0,1 12 | skunk,1,1,0,0,0,0,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,0,1,1,1,0,1,0,0,1,1,0,1,1,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,1,0,0,0,1,0,0,0,0,1,0,0,1,0,0 13 | 
mole,1,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,1,0,0,0,1,0,1,0,0,1,0,0,0,0,0,1,1,1,1,0,1,0,0,1,1,0,1,1,1,0,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,0,1,1,1,0,0,0,1,0,0,0,0,1,0,0,1,1,0 14 | tiger,1,1,0,0,0,1,0,0,0,0,1,1,0,0,1,0,0,1,0,0,0,1,1,0,0,1,1,1,1,0,0,1,0,1,0,0,0,0,1,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0,0,0,0,0,1,0,0,1,0,1,0,0,0,1,1,1,0,1,0,0,1,0,0,0,1,0,1,1,1,1,0 15 | hippopotamus,0,0,0,0,1,0,0,0,0,0,0,0,1,1,1,0,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,0,1,0,1,0,1,1,0,1,0,1,0,1,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,1,1,0,0,0,1,0,0,1,0,0 16 | leopard,1,0,0,1,0,0,0,1,1,1,0,1,0,0,1,0,0,1,0,0,0,0,1,1,0,1,0,1,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,1,0,1,1,1,0,0,0,1,0,1,0,0,1,1,1,0,0,0,1,1,0,0,1,1,0,1,0,1,0,1,0,1,0,1,1,0 17 | moose,0,0,0,1,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,1,0,0,1,1,1,1,0,0,0,1,0,0,1,0,0,0,0,1,1,1,1,0,1,0,1,0,1,0,0,0,0,0,0,1,0,1,1,0,0,0,0,1,1,1,0,0,0,1,1,1,0,1,0,1,0,0,0,0,1,0,1,1,0,0 18 | spider monkey,1,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,1,0,0,1,1,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,1,1,1,0,0,0,1,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,0,0,1,0,1,0,0,0,0,1,0,0,1,1,1,0,1,0 19 | humpback whale,1,0,1,0,1,0,0,0,0,0,0,0,1,1,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,1,1,1,0,1,0,0,0,1,0,0,0,1,0,1,0,0,0,0,0,0,1,0,1,1,1,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1,1,1,0,0,0 20 | elephant,0,0,0,0,1,0,0,0,0,0,0,0,1,1,1,0,1,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,1,1,0,0,0,0,1,0,1,1,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,0,1,1,1,0,0,0 21 | gorilla,1,0,0,1,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,1,0,0,0,1,0,0,1,1,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,1,1,1,1,0,0,0,1,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,0,0,0,1,0,1,0,0,1,0,1,0,1,0,1,1,0,1,0 22 | ox,1,1,0,1,1,0,0,0,0,0,0,1,1,1,1,0,1,0,0,0,1,0,0,0,0,1,1,0,0,0,1,0,0,1,0,0,0,0,1,0,1,1,0,1,0,1,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,1,0,0,0,1,1,0,1,0,0,0,1,0,0,0,0,1,0,0,1,0,1 23 | 
fox,0,0,0,1,0,1,1,0,0,0,0,1,0,0,0,1,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,1,1,1,1,1,0,0,0,1,0,1,0,0,1,1,1,0,0,0,0,1,1,1,0,0,0,1,0,0,0,1,0,1,0,1,1,0 24 | sheep,1,1,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,1,0,1,0,0,1,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,1,0,0,0,0,1,0,1,0,1,0,1,0,0,0,0,1,0,1,0,0,1 25 | seal,1,1,0,1,1,0,0,0,0,1,0,0,1,1,1,1,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,0,1,1,1,0,0,0,0,1,1,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1,1,1,0,0,1 26 | chimpanzee,1,0,0,1,0,0,0,0,0,0,0,1,0,1,1,1,0,1,0,1,0,0,0,1,0,1,1,1,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,1,1,1,1,0,0,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1,0,0,0,1,0,1,0,1,1,0,1,0,1,0,1,1,1,1,0,1,1 27 | hamster,1,1,0,1,1,0,0,0,1,0,0,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,1,0,0,1,0,1,0,1,0,1,1,1,0,0,1,0,0,1,1,1,1,1,1,0,0,0,1,0,1,1,0,0,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,1,0,0,1,1,1 28 | squirrel,0,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,1,1,0,0,1,1,0,1,0,0,1,0,0,0,1,0,0,1,1,0,0,0,0,1,1,1,0,0,1,1,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,1,0,1,0,0,1,0,0,1,1,0 29 | rhinoceros,0,0,0,0,1,0,0,0,0,0,0,0,1,1,1,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,1,0,1,1,0,0,0,1,0,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,0,0,1,0,0,1,1,0,0 30 | rabbit,1,1,0,1,1,0,0,0,1,0,0,1,0,0,0,1,1,0,0,0,0,0,1,0,0,1,1,0,1,0,0,0,0,0,0,1,0,0,0,1,0,0,1,0,0,1,1,1,0,0,1,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,1,1,1,1,0,0,0,1,0,0,0,0,1,0,1,0,1,1 31 | bat,1,0,0,1,1,0,0,0,0,0,0,1,1,1,0,1,0,1,0,0,0,0,0,0,0,0,1,1,0,0,0,1,0,1,1,0,0,0,0,1,0,0,1,1,1,0,1,1,1,1,1,0,1,0,1,1,1,0,1,1,0,0,1,1,0,0,0,0,0,1,0,1,1,0,0,0,1,1,1,0,1,1,0,1,0 32 | giraffe,0,0,0,1,0,1,0,1,1,1,0,0,0,0,1,0,0,1,0,0,1,0,0,1,1,1,1,0,0,0,1,0,0,1,0,0,0,0,1,1,1,1,0,1,0,1,1,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,0,0,0,1,0,1,0,0,0 33 | wolf,1,1,0,1,1,0,0,0,0,0,0,1,0,0,1,0,0,1,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0,0,0,1,0,1,1,0,1,1,1,1,0,0,0,1,1,0,0,1,0,1,0,0,1,1,0,1,1,1,0,0 
34 | chihuahua,1,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,0,0,1,1,0,0,1,0,1,0,0,0,1,0,1,0,0,0,0,1,1,0,0,1,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,1,1,1,0,1,0,1 35 | rat,1,1,0,1,1,0,0,0,0,0,0,1,0,0,0,1,1,1,0,0,0,0,1,0,0,1,0,1,1,0,0,1,0,1,0,0,0,1,1,1,0,0,0,0,0,1,1,0,1,1,1,0,1,0,0,1,1,0,1,1,0,0,1,1,0,0,0,0,1,1,1,0,0,0,1,0,0,0,1,0,1,0,1,1,0 36 | weasel,1,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,1,0,0,0,0,1,0,0,1,1,1,0,0,0,1,0,1,0,0,0,1,1,1,0,0,0,1,0,1,1,0,1,0,1,0,1,0,0,0,1,0,1,0,0,0,1,1,0,0,0,0,0,1,1,0,0,0,1,0,0,0,1,0,1,0,1,0,0 37 | otter,1,0,0,1,0,0,0,0,0,0,0,1,0,0,0,1,1,1,1,0,0,0,1,0,0,1,1,1,1,0,0,1,0,0,0,0,1,0,0,1,0,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,1,1,1,0,0,0,1,0,0,0,1,0,1,0,0,0,1,1,0,1,1,0 38 | buffalo,1,0,0,1,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,1,0,0,0,0,1,1,1,1,0,1,0,1,0,1,0,0,0,0,0,0,1,0,1,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,1,0,0,0 39 | zebra,1,1,0,0,0,0,0,0,0,0,1,1,0,1,1,0,0,1,0,0,1,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,0,0,1,0,1,1,0,0,0,0,0,1,0,0,0,1,1,0,1,0,0,0,1,0,0,0,0,1,1,1,0,0,0 40 | giant panda,1,1,0,0,0,0,0,0,1,1,0,1,0,0,1,0,1,0,0,0,0,1,1,0,0,0,1,0,1,0,0,1,0,0,0,0,0,0,1,0,1,1,0,0,1,1,0,1,0,0,0,1,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,1,0,1,0,1,0,0,1,0,1,0,0,1,1,1,1,1,1 41 | deer,0,0,0,1,0,0,0,0,1,1,0,1,0,0,1,0,0,1,0,0,1,0,0,1,1,1,1,0,0,0,1,0,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,0,1,0,0,0,1,0,1,1,0,0,0,0,1,1,0,0,0,0,1,1,1,0,1,0,1,0,0,0,0,1,1,1,0,1,0 42 | bobcat,0,0,0,1,0,1,0,1,0,1,0,1,0,0,0,1,0,1,0,0,0,1,1,0,0,1,0,1,0,0,0,1,0,0,0,0,0,0,1,1,0,1,0,1,0,1,1,0,0,1,1,0,1,0,0,0,1,0,1,0,0,1,1,1,0,0,0,0,1,1,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0 43 | pig,1,1,0,1,1,0,0,0,1,1,0,0,1,1,1,0,1,0,0,0,1,0,0,0,0,1,1,0,0,0,0,0,0,1,0,0,0,0,1,0,1,1,0,0,0,1,0,1,0,0,0,0,0,0,1,0,1,0,0,0,0,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,1,1,1,1,0,0,1 44 | lion,0,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,1,1,0,0,0,1,1,0,0,1,0,1,0,0,0,1,0,1,0,0,0,0,1,1,0,1,0,1,0,1,1,1,0,0,1,0,1,0,0,0,1,0,1,0,0,1,0,1,0,0,1,1,0,1,0,1,1,0,1,0,0,0,1,0,1,1,0,1,0 
45 | mouse,0,1,0,1,1,0,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,1,0,0,1,1,0,1,0,0,0,0,1,0,0,0,1,1,1,0,0,1,0,0,1,1,0,1,1,1,0,0,0,1,0,0,1,0,1,0,0,1,1,0,0,0,0,1,1,1,0,0,0,1,0,0,0,0,1,0,1,0,1,1 46 | polar bear,0,1,0,0,0,0,0,0,0,0,0,1,0,1,1,0,1,0,0,0,0,1,1,0,0,0,1,1,0,0,0,1,0,1,0,0,1,0,1,1,1,1,0,0,1,1,1,1,0,1,1,1,1,0,0,0,1,0,1,1,0,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,0,0,1,0,0,0,1,0,0 47 | collie,0,1,0,1,0,0,0,0,1,0,0,1,0,0,1,1,0,1,0,0,0,0,1,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,1,1,0,0,0,1,0,1,0,0,0,0,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,1,0,1,0,1 48 | walrus,0,0,0,1,1,0,0,0,0,0,0,0,1,1,1,0,1,0,1,0,0,0,0,0,0,0,1,1,1,1,0,0,1,1,0,0,1,0,0,1,1,1,0,0,0,0,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0,0,1,0,1,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1,1,1,0,1,0 49 | raccoon,1,1,0,0,1,0,0,0,1,1,1,1,0,0,0,1,0,0,0,0,0,1,1,0,0,1,1,1,0,0,0,1,0,0,0,0,0,0,1,1,0,0,0,0,0,1,1,0,1,1,1,1,1,0,1,0,1,0,0,1,0,0,1,1,0,0,0,0,0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,1,0 50 | cow,1,1,0,1,0,0,0,0,1,1,0,1,0,1,1,0,1,0,0,0,1,0,0,0,0,1,1,0,0,0,1,0,0,1,0,0,0,0,1,0,1,1,0,0,0,1,1,1,0,0,0,0,0,0,1,0,0,1,0,0,0,0,1,1,0,0,0,0,1,0,1,0,0,0,1,0,0,0,0,1,0,1,0,0,1 51 | dolphin,0,1,1,0,1,0,0,0,0,0,0,0,1,1,1,0,0,1,1,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,1,0,0,1,0,1,0,1,0,0,1,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,1,1,0,1,0,0,0,0,0,0,0,1,0,1,0,0,0,1,1,1,0,0,1 52 | -------------------------------------------------------------------------------- /examples/datasets/50animalbindat.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probsys/hierarchical-irm/6e987386ec4b0a36b69824f265cda20fdaa55222/examples/datasets/50animalbindat.mat -------------------------------------------------------------------------------- /examples/datasets/README_DATASETS: -------------------------------------------------------------------------------- 1 | Data analyzed by 2 | 3 | Kemp, C., Tenenbaum, J. B., Griffiths, T. L., Yamada, T. & Ueda, N. (2006). 4 | Learning systems of concepts with an infinite relational model. 
AAAI 2006 5 | 6 | http://charleskemp.com/papers/KempTGYU06.pdf 7 | 8 | 1) 50animalbindat: feature ratings primarily collected by Osherson et al. 9 | 10 | 2) UMLS data: data from a biomedical ontology prepared by McCray et al. 11 | 12 | 3) alyawarradata: kinship terms collected by Denham. 13 | atts: kinship sections for each person 14 | datass: kinship sections (one of n coding) 15 | features: column 4 is gender, column 5 is age, column 9 is kinship section. 16 | 17 | The other features are described by Denham at 18 | 19 | https://www.kinsources.net/kidarep/dataset-49-alyawarra-1971-au01.xhtml 20 | 21 | More information about the data is at: 22 | 23 | http://www.culturalsciences.info/AlyaWeb/index.htm 24 | http://www.culturalsciences.info/GCBS/index.htm 25 | http://onlinelibrary.wiley.com/doi/10.1111/amet.1979.6.issue-1/issuetoc 26 | 27 | 28 | 4) dnations.mat: International relations data from the Dimensionality of 29 | Nations project (Rummel). We've thresholded each continuous 30 | variable at its mean and used one-of-n coding for the 31 | categorical variables. 
32 | 33 | -------------------------------------------------------------------------------- /examples/datasets/alyawarradata.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probsys/hierarchical-irm/6e987386ec4b0a36b69824f265cda20fdaa55222/examples/datasets/alyawarradata.mat -------------------------------------------------------------------------------- /examples/datasets/animals.binary.schema: -------------------------------------------------------------------------------- 1 | bernoulli has feature animal 2 | -------------------------------------------------------------------------------- /examples/datasets/animals.unary.schema: -------------------------------------------------------------------------------- 1 | bernoulli black animal 2 | bernoulli white animal 3 | bernoulli blue animal 4 | bernoulli brown animal 5 | bernoulli gray animal 6 | bernoulli orange animal 7 | bernoulli red animal 8 | bernoulli yellow animal 9 | bernoulli patches animal 10 | bernoulli spots animal 11 | bernoulli stripes animal 12 | bernoulli furry animal 13 | bernoulli hairless animal 14 | bernoulli toughskin animal 15 | bernoulli big animal 16 | bernoulli small animal 17 | bernoulli bulbous animal 18 | bernoulli lean animal 19 | bernoulli flippers animal 20 | bernoulli hands animal 21 | bernoulli hooves animal 22 | bernoulli pads animal 23 | bernoulli paws animal 24 | bernoulli longleg animal 25 | bernoulli longneck animal 26 | bernoulli tail animal 27 | bernoulli chewteeth animal 28 | bernoulli meatteeth animal 29 | bernoulli buckteeth animal 30 | bernoulli strainteeth animal 31 | bernoulli horns animal 32 | bernoulli claws animal 33 | bernoulli tusks animal 34 | bernoulli smelly animal 35 | bernoulli flys animal 36 | bernoulli hops animal 37 | bernoulli swims animal 38 | bernoulli tunnels animal 39 | bernoulli walks animal 40 | bernoulli fast animal 41 | bernoulli slow animal 42 | bernoulli strong animal 43 | 
# Copyright 2021 MIT Probabilistic Computing Project
# Apache License, Version 2.0, refer to LICENSE.txt

"""Convert the Osherson 50-animals matrix into HIRM schema/observation files.

Reads 50animalbindat.mat (must be in the current directory) and writes two
encodings of the same binary data:
  - animals.binary.schema / animals.binary.obs: one ternary relation
    has: Animals x Features -> {0,1}
  - animals.unary.schema / animals.unary.obs: one unary relation per feature.
"""

from scipy.io import loadmat

# Animals as a single binary relation
# has: Animals x Features -> {0,1}
x = loadmat('50animalbindat.mat')
# MATLAB cell arrays come back as nested object arrays; unwrap to plain strings.
features = [y[0][0] for y in x['features'].T]
animals = [y[0][0] for y in x['names'].T]
data = x['data']
with open('animals.binary.schema', 'w') as f:
    f.write('bernoulli has feature animal\n')
with open('animals.binary.obs', 'w') as f:
    for i, animal in enumerate(animals):
        for j, feature in enumerate(features):
            value = int(data[i,j])
            # Item names must be single tokens in the obs format.
            a = animal.replace(' ', '')
            f.write('%d has %s %s\n' % (value, feature, a))

# Animals as one unary relation per feature.
with open('animals.unary.schema', 'w') as f:
    for feature in features:
        f.write('bernoulli %s animal\n' % (feature,))
with open('animals.unary.obs', 'w') as f:
    for j, feature in enumerate(features):
        for i, animal in enumerate(animals):
            # Cast to int for consistency with the binary writer above
            # (data holds numpy scalars).
            value = int(data[i,j])
            a = animal.replace(' ', '')
            f.write('%d %s %s\n' % (value, feature, a))
# Load the Dimensionality-of-Nations matrices: A is countries x attributes,
# R is countries x countries x predicates; NaN marks a missing observation.
x = loadmat('dnations.mat')
attnames = [y[0][0] for y in x['attnames'].T]
relnnames = [y[0][0] for y in x['relnnames'].T]
countrynames = [y[0][0] for y in x['countrynames'].T]
# Disambiguate predicate names that collide with attribute names.
for idx, name in enumerate(relnnames):
    if name in attnames:
        relnnames[idx] += '_rel'

# Encoding 1: two fixed relations ("has" over features, "applies" over
# predicate x country x country).
with open('nations.binary.schema', 'w') as f:
    f.write('bernoulli has feature country\n')
    f.write('bernoulli applies predicate country country\n')

with open('nations.binary.obs', 'w') as f:
    for i, country in enumerate(countrynames):
        for j, feature in enumerate(attnames):
            value = x['A'][i,j]
            # Skip missing entries.
            if math.isnan(value):
                continue
            f.write('%d has %s %s\n' % (value, feature, country))
    for k, predicate in enumerate(relnnames):
        for i, country0 in enumerate(countrynames):
            for j, country1 in enumerate(countrynames):
                value = x['R'][i,j,k]
                if math.isnan(value):
                    continue
                f.write('%d applies %s %s %s\n' %
                    (value, predicate, country0, country1))

# Encoding 2: one relation per attribute and per predicate.
with open('nations.unary.schema', 'w') as f:
    for feature in attnames:
        f.write('bernoulli %s country\n' % (feature,))
    for predicate in relnnames:
        f.write('bernoulli %s country country\n' % (predicate,))

with open('nations.unary.obs', 'w') as f:
    for j, feature in enumerate(attnames):
        for i, country in enumerate(countrynames):
            value = x['A'][i,j]
            if math.isnan(value):
                continue
            f.write('%d %s %s\n' % (value, feature, country))
    for k, predicate in enumerate(relnnames):
        for i, country0 in enumerate(countrynames):
            for j, country1 in enumerate(countrynames):
                value = x['R'][i,j,k]
                if math.isnan(value):
                    continue
                f.write('%d %s %s %s\n' %
                    (value, predicate, country0, country1))
energyconsume country 4 | bernoulli illiterates country 5 | bernoulli GNP country 6 | bernoulli popxenergabs country 7 | bernoulli incomeabs country 8 | bernoulli popabs country 9 | bernoulli unassessment country 10 | bernoulli defenseexpabs country 11 | bernoulli englishtitles country 12 | bernoulli blocmembership0 country 13 | bernoulli usaidreceived country 14 | bernoulli freedomofopposition0 country 15 | bernoulli IFCandIBRD country 16 | bernoulli threats country 17 | bernoulli accusations country 18 | bernoulli killedforeignviolence country 19 | bernoulli militaryaction country 20 | bernoulli protests country 21 | bernoulli killeddomesticviolence country 22 | bernoulli riots country 23 | bernoulli purges country 24 | bernoulli demonstrations country 25 | bernoulli catholics country 26 | bernoulli airdistance country 27 | bernoulli medicinengo country 28 | bernoulli diplomatexpelled country 29 | bernoulli divorces country 30 | bernoulli popn/land country 31 | bernoulli arable country 32 | bernoulli area country 33 | bernoulli roadlength country 34 | bernoulli railroadlength country 35 | bernoulli religions country 36 | bernoulli immigrants/migrants country 37 | bernoulli rainfall country 38 | bernoulli largestrelgn country 39 | bernoulli runningwater country 40 | bernoulli foreigncollegestud country 41 | bernoulli neutralblock country 42 | bernoulli age country 43 | bernoulli religioustitles country 44 | bernoulli emigrants country 45 | bernoulli seabornegoods country 46 | bernoulli lawngos country 47 | bernoulli unemployed country 48 | bernoulli export country 49 | bernoulli languages country 50 | bernoulli largestlang country 51 | bernoulli ethnicgrps country 52 | bernoulli economicaidtaken country 53 | bernoulli techassistancetaken country 54 | bernoulli goveducationspend country 55 | bernoulli femaleworkers country 56 | bernoulli exports country 57 | bernoulli foreignmail country 58 | bernoulli imports country 59 | bernoulli caloriesconsumed country 60 | 
bernoulli protein country 61 | bernoulli russiantitles country 62 | bernoulli militarypersonnel country 63 | bernoulli investments country 64 | bernoulli politicalparties country 65 | bernoulli artsculturengo country 66 | bernoulli communistparty country 67 | bernoulli govspending country 68 | bernoulli monarchy country 69 | bernoulli primaryschool country 70 | bernoulli govchangelegal0 country 71 | bernoulli legitgov0 country 72 | bernoulli largestethnic country 73 | bernoulli assassinations country 74 | bernoulli majgovcrisis country 75 | bernoulli unpaymentdelinq country 76 | bernoulli balancepayments country 77 | bernoulli balanceinvestments country 78 | bernoulli systemstyle0 country 79 | bernoulli constitutional0 country 80 | bernoulli electoralsystem0 country 81 | bernoulli noncommunist country 82 | bernoulli politicalleadership0 country 83 | bernoulli horizontalpower0 country 84 | bernoulli military0 country 85 | bernoulli bureaucracy0 country 86 | bernoulli censorship0 country 87 | bernoulli geographyx country 88 | bernoulli geographyy country 89 | bernoulli geographyz country 90 | bernoulli blocmembership1 country 91 | bernoulli blocmembership2 country 92 | bernoulli freedomofopposition1 country 93 | bernoulli freedomofopposition2 country 94 | bernoulli govchangelegal1 country 95 | bernoulli govchangelegal2 country 96 | bernoulli legitgov1 country 97 | bernoulli systemstyle1 country 98 | bernoulli systemstyle2 country 99 | bernoulli constitutional1 country 100 | bernoulli constitutional2 country 101 | bernoulli electoralsystem1 country 102 | bernoulli electoralsystem2 country 103 | bernoulli politicalleadership1 country 104 | bernoulli politicalleadership2 country 105 | bernoulli horizontalpower2 country 106 | bernoulli military1 country 107 | bernoulli military2 country 108 | bernoulli bureaucracy1 country 109 | bernoulli bureaucracy2 country 110 | bernoulli censorship1 country 111 | bernoulli censorship2 country 112 | bernoulli economicaid country 
country 113 | bernoulli releconomicaid country country 114 | bernoulli treaties country country 115 | bernoulli reltreaties country country 116 | bernoulli officialvisits country country 117 | bernoulli conferences country country 118 | bernoulli exportbooks country country 119 | bernoulli relexportbooks country country 120 | bernoulli booktranslations country country 121 | bernoulli relbooktranslations country country 122 | bernoulli warning country country 123 | bernoulli violentactions country country 124 | bernoulli militaryactions country country 125 | bernoulli duration country country 126 | bernoulli negativebehavior country country 127 | bernoulli severdiplomatic country country 128 | bernoulli expeldiplomats country country 129 | bernoulli boycottembargo country country 130 | bernoulli aidenemy country country 131 | bernoulli negativecomm country country 132 | bernoulli accusation country country 133 | bernoulli protests_rel country country 134 | bernoulli unoffialacts country country 135 | bernoulli attackembassy country country 136 | bernoulli nonviolentbehavior country country 137 | bernoulli weightedunvote country country 138 | bernoulli unweightedunvote country country 139 | bernoulli tourism country country 140 | bernoulli reltourism country country 141 | bernoulli tourism3 country country 142 | bernoulli emigrants_rel country country 143 | bernoulli relemigrants country country 144 | bernoulli emigrants3 country country 145 | bernoulli students country country 146 | bernoulli relstudents country country 147 | bernoulli exports_rel country country 148 | bernoulli relexports country country 149 | bernoulli exports3 country country 150 | bernoulli intergovorgs country country 151 | bernoulli relintergovorgs country country 152 | bernoulli ngo country country 153 | bernoulli relngo country country 154 | bernoulli intergovorgs3 country country 155 | bernoulli ngoorgs3 country country 156 | bernoulli embassy country country 157 | bernoulli reldiplomacy 
country country 158 | bernoulli timesincewar country country 159 | bernoulli timesinceally country country 160 | bernoulli lostterritory country country 161 | bernoulli dependent country country 162 | bernoulli independence country country 163 | bernoulli commonbloc0 country country 164 | bernoulli blockpositionindex country country 165 | bernoulli militaryalliance country country 166 | bernoulli commonbloc1 country country 167 | bernoulli commonbloc2 country country 168 | -------------------------------------------------------------------------------- /examples/datasets/uml.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/probsys/hierarchical-irm/6e987386ec4b0a36b69824f265cda20fdaa55222/examples/datasets/uml.mat -------------------------------------------------------------------------------- /examples/nations_binary_irm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import json 5 | import os 6 | import random 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | from pprint import pprint 11 | 12 | from hirm import IRM 13 | from hirm.util_io import load_observations 14 | from hirm.util_io import load_schema 15 | from hirm.util_io import to_dict_IRM 16 | from hirm.util_io import to_txt_irm 17 | from hirm.util_plot import plot_binary_relation 18 | from hirm.util_plot import plot_ternary_relation 19 | 20 | dirname = os.path.dirname(os.path.abspath(__file__)) 21 | path_schema = os.path.join(dirname, 'datasets', 'nations.binary.schema') 22 | path_obs = os.path.join(dirname, 'datasets', 'nations.binary.obs') 23 | schema = load_schema(path_schema) 24 | data = load_observations(path_obs) 25 | 26 | prng = random.Random(12) 27 | irm = IRM(schema, prng=prng) 28 | for relation, items, value in data: 29 | irm.incorporate(relation, items, value) 30 | 31 | 
for i in range(10): 32 | irm.transition_cluster_assignments() 33 | print(i, irm.logp_score()) 34 | 35 | pprint(irm.domains['country'].crp.tables) 36 | pprint(irm.domains['feature'].crp.tables) 37 | pprint(irm.domains['predicate'].crp.tables) 38 | 39 | fig, ax = plot_binary_relation(irm.relations['has'], transpose=True) 40 | fig.set_tight_layout(True) 41 | fig.set_size_inches((20, 10)) 42 | path_features = os.path.join('assets', 'nations.binary.irm.features.png') 43 | fig.savefig(path_features) 44 | print(path_features) 45 | for predicate in irm.domains['predicate'].items: 46 | fig, ax = plot_ternary_relation(irm.relations['applies'], predicate) 47 | fname = os.path.join('assets', 'nations.binary.irm.%s.png' % (predicate,)) 48 | fig.set_tight_layout(True) 49 | fig.savefig(fname) 50 | print(fname) 51 | plt.close(fig) 52 | 53 | d = to_dict_IRM(irm) 54 | path_json = os.path.join('assets', 'nations.binary.irm.json') 55 | with open(path_json, 'w') as f: 56 | json.dump(d, f, indent=4) 57 | print(path_json) 58 | 59 | path_clusters = os.path.join('assets', 'nations.binary.irm') 60 | to_txt_irm(path_clusters, irm) 61 | print(path_clusters) 62 | -------------------------------------------------------------------------------- /examples/nations_unary_hirm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import json 5 | import os 6 | import random 7 | 8 | import matplotlib.pyplot as plt 9 | 10 | from pprint import pprint 11 | 12 | from hirm import HIRM 13 | from hirm.util_io import load_observations 14 | from hirm.util_io import load_schema 15 | from hirm.util_io import to_dict_HIRM 16 | from hirm.util_io import to_txt_hirm 17 | from hirm.util_plot import plot_binary_relation 18 | from hirm.util_plot import plot_hirm_crosscat 19 | 20 | dirname = os.path.dirname(os.path.abspath(__file__)) 21 | path_schema = 
os.path.join(dirname, 'datasets', 'nations.unary.schema') 22 | path_obs = os.path.join(dirname, 'datasets', 'nations.unary.obs') 23 | schema = load_schema(path_schema) 24 | data = load_observations(path_obs) 25 | features = [r for r in schema if len(schema[r]) == 1] 26 | predicates = [r for r in schema if len(schema[r]) == 2] 27 | 28 | prng = random.Random(12) 29 | hirm = HIRM(schema, prng=prng) 30 | for relation, items, value in data: 31 | hirm.incorporate(relation, items, value) 32 | 33 | print(hirm.logp_score()) 34 | for i in range(10): 35 | hirm.transition_cluster_assignments() 36 | for irm in hirm.irms.values(): 37 | irm.transition_cluster_assignments() 38 | print(i, hirm.logp_score(), [len(c) for c in hirm.crp.tables.values()]) 39 | 40 | pprint(hirm.crp.tables) 41 | 42 | fig, ax = plot_hirm_crosscat(hirm, features) 43 | fig.set_size_inches((30, 10)) 44 | fig.set_tight_layout(True) 45 | path_features = os.path.join('assets', 'nations.unary.hirm.features.png') 46 | fig.savefig(path_features) 47 | print(path_features) 48 | for r in predicates: 49 | irm = hirm.relation_to_irm(r) 50 | fig, ax = plot_binary_relation(irm.relations[r]) 51 | fname = os.path.join('assets', 'nations.unary.hirm.%s.png' % (r,)) 52 | fig.set_tight_layout(True) 53 | fig.savefig(fname) 54 | print(fname) 55 | plt.close(fig) 56 | 57 | d = to_dict_HIRM(hirm) 58 | path_json = os.path.join('assets', 'nations.unary.hirm.json') 59 | with open(path_json, 'w') as f: 60 | json.dump(d, f, indent=4) 61 | print(path_json) 62 | 63 | path_clusters = os.path.join('assets', 'nations.unary.hirm') 64 | to_txt_hirm(path_clusters, hirm) 65 | print(path_clusters) 66 | -------------------------------------------------------------------------------- /examples/three_relations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | 
import matplotlib.pyplot as plt
import numpy as np

from hirm import HIRM
from hirm.util_io import to_txt_hirm
from three_relations_plot import xlabels
from three_relations_plot import ylabels
from three_relations_plot import plot_hirm_clusters


def _env_flag(name, default=True):
    """Parse environment variable *name* as a boolean flag.

    Unset -> *default*; '', '0', 'false', 'no' (case-insensitive) -> False;
    anything else -> True.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in ('', '0', 'false', 'no')


# NOTE(review): this module-level ITERS is unused -- __main__ re-reads ITERS
# with a default of '20'; kept so any external reader of the name still works.
ITERS = int(os.environ.get('ITERS', '100'))
# BUG FIX: os.environ.get('NOISY', True) previously returned the raw
# environment *string* whenever the variable was set, so NOISY=0 or
# NOISY=False was still truthy and the flags could never be disabled from
# the environment.  Same for SHUFFLE.
NOISY = _env_flag('NOISY')
SHUFFLE = _env_flag('SHUFFLE')
noisy_str = 'noisy' if NOISY else 'clean'
prng = random.Random(1)
def flip(p):
    """Bernoulli(p) draw when NOISY; deterministic threshold p > 0.5 otherwise."""
    return prng.random() < p if NOISY else p > 0.5

P_LO = .1
P_HI = .9

# ===== Synthetic data generation.
schema = {
    'R1': ('D1', 'D1'),
    'R2': ('D1', 'D2'),
    'R3': ('D1', 'D3'),
}
n_items = {'D1': 300, 'D2': 300, 'D3': 300}

# R1: D1 x D1, two contiguous blocks, dense across blocks and sparse within.
items_D1_R1 = [
    list(range(0, 200)),
    list(range(200, 300)),
]
data_r1_d10_d10 = [((i, j), flip(P_LO)) for i in items_D1_R1[0] for j in items_D1_R1[0]]
data_r1_d10_d11 = [((i, j), flip(P_HI)) for i in items_D1_R1[0] for j in items_D1_R1[1]]
data_r1_d11_d10 = [((i, j), flip(P_HI)) for i in items_D1_R1[1] for j in items_D1_R1[0]]
data_r1_d11_d11 = [((i, j), flip(P_LO)) for i in items_D1_R1[1] for j in items_D1_R1[1]]
data_r1 = data_r1_d10_d10 + data_r1_d10_d11 + data_r1_d11_d10 + data_r1_d11_d11

# R2: D1 split even/odd; D2 split into four interleaved blocks.
items_D1_R2 = [
    list(range(0, 300))[::2],
    list(range(0, 300))[1::2]
]
items_D2_R2 = [
    list(range(0, 150))[::2],
    list(range(0, 150))[1::2],
    list(range(150, 300))[::2],
    list(range(150, 300))[1::2]
]
data_r2_d10_d20 = [((i, j), flip(P_LO)) for i in items_D1_R2[0] for j in items_D2_R2[0]]
data_r2_d10_d21 = [((i, j), flip(P_HI)) for i in items_D1_R2[0] for j in items_D2_R2[1]]
data_r2_d10_d22 = [((i, j), flip(P_LO)) for i in items_D1_R2[0] for j in items_D2_R2[2]]
data_r2_d10_d23 = [((i, j), flip(P_HI)) for i in items_D1_R2[0] for j in items_D2_R2[3]]
data_r2_d11_d20 = [((i, j), flip(P_HI)) for i in items_D1_R2[1] for j in items_D2_R2[0]]
data_r2_d11_d21 = [((i, j), flip(P_LO)) for i in items_D1_R2[1] for j in items_D2_R2[1]]
# NOTE(review): the next two lines index items_D2_R2[3] under the name
# "..._d22" and items_D2_R2[2] under "..._d23".  Left as-is deliberately:
# with this assignment all four D2 blocks have distinct column patterns
# ([LO,HI], [HI,LO], [LO,LO], [HI,HI]); only the variable names mislead.
data_r2_d11_d22 = [((i, j), flip(P_HI)) for i in items_D1_R2[1] for j in items_D2_R2[3]]
data_r2_d11_d23 = [((i, j), flip(P_LO)) for i in items_D1_R2[1] for j in items_D2_R2[2]]
data_r2 \
    = data_r2_d10_d20 + data_r2_d10_d21 + data_r2_d10_d22 + data_r2_d10_d23 \
    + data_r2_d11_d20 + data_r2_d11_d21 + data_r2_d11_d22 + data_r2_d11_d23

# R3: D1 blocks as in R1; D3 split into three contiguous blocks.
items_D1_R3 = items_D1_R1
items_D3_R3 = [
    list(range(0, 100)),
    list(range(100, 200)),
    list(range(200, 300)),
]

data_r3_d10_d30 = [((i, j), flip(P_HI)) for i in items_D1_R3[0] for j in items_D3_R3[0]]
data_r3_d10_d31 = [((i, j), flip(P_LO)) for i in items_D1_R3[0] for j in items_D3_R3[1]]
data_r3_d10_d32 = [((i, j), flip(P_HI)) for i in items_D1_R3[0] for j in items_D3_R3[2]]
data_r3_d11_d30 = [((i, j), flip(P_LO)) for i in items_D1_R3[1] for j in items_D3_R3[0]]
data_r3_d11_d31 = [((i, j), flip(P_HI)) for i in items_D1_R3[1] for j in items_D3_R3[1]]
data_r3_d11_d32 = [((i, j), flip(P_HI)) for i in items_D1_R3[1] for j in items_D3_R3[2]]
data_r3 \
    = data_r3_d10_d30 + data_r3_d10_d31 + data_r3_d10_d32 \
    + data_r3_d11_d30 + data_r3_d11_d31 + data_r3_d11_d32

# Write schema to disk.
path_schema = os.path.join('assets', 'three_relations.schema')
with open(path_schema, 'w') as f:
    f.write('bernoulli R1 D1 D1\n')
    f.write('bernoulli R2 D1 D2\n')
    f.write('bernoulli R3 D1 D3\n')
print(path_schema)
# Write observations to disk.
path_obs = os.path.join('assets', 'three_relations.obs')
with open(path_obs, 'w') as f:
    for ((i, j), value) in data_r1:
        f.write('%d R1 %d %d\n' % (value, i, j))
    for ((i, j), value) in data_r2:
        f.write('%d R2 %d %d\n' % (value, i, j))
    for ((i, j), value) in data_r3:
        f.write('%d R3 %d %d\n' % (value, i, j))
print(path_obs)

# Plot the synthetic data (R3 transposed; rows/columns optionally shuffled
# so the block structure is not visible in the raw plot).
fig, axes = plt.subplots(nrows=2, ncols=2)
axes[0,0].set_axis_off()
for relation, data, ax in [
        ('R1', data_r1, axes[1,1]),
        ('R2', data_r2, axes[1,0]),
        ('R3', data_r3, axes[0,1]),
        ]:
    nr = n_items[schema[relation][0]]
    nc = n_items[schema[relation][1]]
    X = np.zeros((nr, nc))
    for (i, j), v in data:
        X[i,j] = v
    if relation == 'R3':
        X = X.T
    if SHUFFLE:
        if relation == 'R1':
            n = n_items['D1']
            pi = prng.sample(list(range(n)), k=n)
            X = np.asarray([
                [X[pi[r], pi[c]] for c in range(n)]
                for r in range(n)
            ])
        elif relation == 'R2':
            nr = n_items['D1']
            nc = n_items['D2']
            pir = prng.sample(list(range(nr)), k=nr)
            pic = prng.sample(list(range(nc)), k=nc)
            X = np.asarray([
                [X[pir[r], pic[c]] for c in range(nc)]
                for r in range(nr)
            ])
        elif relation == 'R3':
            # elif (was a separate `if`): R3 is exclusive of R1/R2 anyway.
            nr = n_items['D1']
            nc = n_items['D3']
            pir = prng.sample(list(range(nr)), k=nr)
            pic = prng.sample(list(range(nc)), k=nc)
            X = np.asarray([
                [X[pir[r], pic[c]] for c in range(nc)]
                for r in range(nr)
            ])
    ax.imshow(X, cmap='Greys')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.set_xlabel(xlabels[relation])
    ax.set_ylabel(ylabels[relation], rotation=0, labelpad=10)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.text(.05, .95, '$%s_%s$' % (relation[0], relation[1]),
        ha='left', va='top',
        transform=ax.transAxes,
        bbox={'facecolor': 'red', 'alpha': 1, 'edgecolor':'k'})
figname = os.path.join('assets', 'three_relations.%s.data.png' % (noisy_str,))
fig.set_size_inches((3.5, 3.5))
fig.subplots_adjust(wspace=.1, hspace=.1)
fig.savefig(figname)
print(figname)

# ===== Make an HIRM for three relations and learn partition.
def learn_hirm(seed, steps):
    """Fit an HIRM to the three synthetic relations.

    Runs *steps* sweeps, each alternating a relation-partition move with
    cluster-assignment moves inside every component IRM; prints the log
    score per sweep and returns the fitted HIRM.
    """
    hirm = HIRM(schema, prng=random.Random(seed))
    hirm.seed = seed
    for relation, data in [
            ('R1', data_r1),
            ('R2', data_r2),
            ('R3', data_r3),
            ]:
        for (i, j), v in data:
            hirm.incorporate(relation, (i, j), v)
    print(hirm.logp_score())
    for i in range(steps):
        hirm.transition_cluster_assignments()
        for irm in hirm.irms.values():
            irm.transition_cluster_assignments()
        print(i, [len(c) for c in hirm.crp.tables.values()], hirm.logp_score())
    return hirm

if __name__ == '__main__':
    seed = int(os.environ.get('SEED', '0'))
    iters = int(os.environ.get('ITERS', '20'))
    print('running with seed %d for %d iters' % (seed, iters))
    hirm = learn_hirm(seed, iters)
    path_clusters = os.path.join('assets', 'three_relations.%s.%d.hirm' % (noisy_str, seed,))
    to_txt_hirm(path_clusters, hirm)
    print(path_clusters)
    figname = '%s.png' % (path_clusters)
    plot_hirm_clusters(hirm, figname)
plot_hirm_clusters(hirm, figname): 16 | fig, axes = plt.subplots(nrows=2, ncols=2) 17 | axes[0,0].set_axis_off() 18 | 19 | bbox = {'facecolor': 'red', 'alpha': 1, 'edgecolor':'k'} 20 | 21 | irm_R1 = hirm.relation_to_irm('R1') 22 | plot_binary_relation(irm_R1.relations['R1'], ax=axes[1,1]) 23 | score1 = irm_R1.relations['R1'].logp_score() 24 | axes[1,1].xaxis.set_label_position('top') 25 | axes[1,1].set_xlabel(xlabels['R1']) 26 | axes[1,1].set_ylabel(ylabels['R1'], rotation=0, labelpad=10) 27 | axes[1,1].set_xticks([]) 28 | axes[1,1].set_yticks([]) 29 | axes[1,1].text(.05, .95, '$R_1$', ha='left', va='top', 30 | transform=axes[1,1].transAxes, bbox=bbox) 31 | 32 | irm_R2 = hirm.relation_to_irm('R2') 33 | plot_binary_relation(irm_R2.relations['R2'], ax=axes[1,0]) 34 | score2 = irm_R2.relations['R2'].logp_score() 35 | axes[1,0].xaxis.set_label_position('top') 36 | axes[1,0].set_xlabel(xlabels['R2']) 37 | axes[1,0].set_ylabel(ylabels['R2'], rotation=0, labelpad=10) 38 | axes[1,0].set_xticks([]) 39 | axes[1,0].set_yticks([]) 40 | axes[1,0].text(.05, .95, '$R_2$', ha='left', va='top', 41 | transform=axes[1,0].transAxes, bbox=bbox) 42 | 43 | irm_R3 = hirm.relation_to_irm('R3') 44 | plot_binary_relation(irm_R3.relations['R3'], ax=axes[0,1], transpose=1) 45 | score3 = irm_R3.relations['R3'].logp_score() 46 | axes[0,1].xaxis.set_label_position('top') 47 | axes[0,1].set_xlabel(xlabels['R3']) 48 | axes[0,1].set_ylabel(ylabels['R3'], rotation=0, labelpad=10) 49 | axes[0,1].set_xticks([]) 50 | axes[0,1].set_yticks([]) 51 | axes[0,1].text(.05, .95, '$R_3$', ha='left', va='top', 52 | transform=axes[0,1].transAxes, bbox=bbox) 53 | 54 | print(score1, score2, score3) 55 | 56 | fig.set_size_inches((3.5, 3.5)) 57 | fig.subplots_adjust(wspace=.1, hspace=.1) 58 | fig.savefig(figname) 59 | print(figname) 60 | 61 | if __name__ == '__main__': 62 | path_clusters = sys.argv[1] 63 | path_schema = os.path.join('assets', 'three_relations.schema') 64 | path_obs = os.path.join('assets', 
# Copyright 2021 MIT Probabilistic Computing Project
# Apache License, Version 2.0, refer to LICENSE.txt

"""Fit a single-relation IRM to the synthetic two-clusters dataset."""

import os
import random

from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np

from hirm import IRM
from hirm.tests.util_test import make_two_clusters
from hirm.util_io import to_txt_irm
from hirm.util_plot import plot_binary_relation

schema, items_D1, items_D2, data = make_two_clusters()

# Write schema to disk.
path_schema = os.path.join('assets', 'two_clusters.binary.schema')
with open(path_schema, 'w') as f:
    f.write('bernoulli R1 D1 D2\n')
print(path_schema)
# Write observations to disk.
path_obs = os.path.join('assets', 'two_clusters.binary.obs')
with open(path_obs, 'w') as f:
    for ((row, col), value) in data:
        f.write('%d R1 %d %d\n' % (value, row, col))
print(path_obs)

# Plot the synthetic data as a 30x40 binary matrix.
X = np.zeros((30, 40))
for (row, col), value in data:
    X[row,col] = value
fig, ax = plt.subplots()
ax.imshow(X, cmap='Greys')
ax.xaxis.tick_top()
ax.set_xticks(np.arange(X.shape[1]))
ax.set_yticks(np.arange(X.shape[0]))
ax.set_title('Raw Data')

# Build an IRM and load every observation.
irm = IRM(schema, prng=random.Random(1))
for (row, col), value in data:
    irm.incorporate('R1', (row, col), value)

# Plot the prior.
fig, ax = plot_binary_relation(irm.relations['R1'])
ax.set_title('Prior Sample')

# Run inference: 20 Gibbs sweeps over the cluster assignments.
for _ in range(20):
    irm.transition_cluster_assignments()
pprint(irm.domains['D1'].crp.tables)
pprint(irm.domains['D2'].crp.tables)

# Write the results.
path_clusters = os.path.join('assets', 'two_clusters.binary.irm')
to_txt_irm(path_clusters, irm)
print(path_clusters)

# Plot the posterior.
fig, ax = plot_binary_relation(irm.relations['R1'])
ax.set_title('Posterior Sample')
plt.show()

path_figure = os.path.join('assets', 'two_clusters.binary.irm.png')
fig.set_tight_layout(True)
fig.savefig(path_figure)
print(path_figure)
36 | X = np.zeros((30, 40)) 37 | for (i, j), v in data: 38 | X[i,j] = v 39 | fig, ax = plt.subplots() 40 | ax.imshow(X, cmap='Greys') 41 | ax.xaxis.tick_top() 42 | ax.set_xticks(np.arange(X.shape[1])) 43 | ax.set_yticks(np.arange(X.shape[0])) 44 | ax.set_title('Raw Data') 45 | 46 | # Make an IRM. 47 | irm = IRM(schema, prng=random.Random(1)) 48 | for (i, j), v in data: 49 | irm.incorporate('Feature-%02d' % (j,), (i,), v) 50 | 51 | # Plot the prior. 52 | fig, ax = plot_unary_relations(list(irm.relations.values())) 53 | ax.set_title('Prior Sample') 54 | 55 | # Run inference. 56 | for i in range(20): 57 | irm.transition_cluster_assignments() 58 | pprint(irm.domains['D1'].crp.tables) 59 | 60 | # Write the results. 61 | path_clusters = os.path.join('assets', 'two_clusters.unary.irm') 62 | to_txt_irm(path_clusters, irm) 63 | print(path_clusters) 64 | 65 | # Plot the posterior. 66 | fig, ax = plot_unary_relations(list(irm.relations.values())) 67 | ax.set_title('Posterior Sample') 68 | plt.show() 69 | 70 | path_figure = os.path.join('assets', 'two_clusters.unary.irm.png') 71 | fig.set_tight_layout(True) 72 | fig.savefig(path_figure) 73 | print(path_figure) 74 | -------------------------------------------------------------------------------- /examples/two_relations.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | from pprint import pprint 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | 12 | from hirm import IRM 13 | from hirm.util_plot import plot_binary_relation 14 | 15 | prng = random.Random(1) 16 | 17 | # ===== Synthetic data generation. 
# ===== Synthetic data for R1.
#
# D1 is partitioned into two blocks (rows 0-9 plus 20-29, versus rows
# 10-19) and D2 into two blocks (cols 0-19 versus cols 20-39).  R1 is 0 on
# the "diagonal" block pairs and 1 on the "off-diagonal" pairs.
items_D1_R1 = [
    list(range(0, 10)) + list(range(20, 30)),
    list(range(10, 20)),
]
items_D2_R1 = [
    list(range(0, 20)),
    list(range(20, 40)),
]

def _block(rows, cols, value):
    """Return one ((row, col), value) observation for every pair in rows x cols."""
    return [((i, j), value) for i in rows for j in cols]

data_r1 = (
    _block(items_D1_R1[0], items_D2_R1[0], 0)
    + _block(items_D1_R1[0], items_D2_R1[1], 1)
    + _block(items_D1_R1[1], items_D2_R1[0], 1)
    + _block(items_D1_R1[1], items_D2_R1[1], 0)
)

# ===== Synthetic data for R2.
#
# D1 is partitioned by parity of the row index; D2 is partitioned into
# four blocks by (range, parity) of the column index.  Values alternate
# 0/1 across the column blocks.
items_D1_R2 = [
    list(range(0, 30))[::2],
    list(range(0, 30))[1::2],
]
items_D2_R2 = [
    list(range(0, 20))[::2],
    list(range(0, 20))[1::2],
    list(range(20, 40))[::2],
    list(range(20, 40))[1::2],
]
# NOTE(review): for the second row block the last two column blocks are
# visited in swapped order ([3] before [2]), so column blocks 2 and 3 get
# the same values in both row blocks.  The same swap appears in
# two_relations_hirm.py; preserved exactly here -- confirm it is intended.
data_r2 = (
    _block(items_D1_R2[0], items_D2_R2[0], 0)
    + _block(items_D1_R2[0], items_D2_R2[1], 1)
    + _block(items_D1_R2[0], items_D2_R2[2], 0)
    + _block(items_D1_R2[0], items_D2_R2[3], 1)
    + _block(items_D1_R2[1], items_D2_R2[0], 1)
    + _block(items_D1_R2[1], items_D2_R2[1], 0)
    + _block(items_D1_R2[1], items_D2_R2[3], 1)
    + _block(items_D1_R2[1], items_D2_R2[2], 0)
)

# Axis labels used by the plotting sections below.
xlabels = {'R1': 'D2', 'R2': 'D2'}
ylabels = {'R1': 'D1', 'R2': ''}

# Write schema to disk.
58 | path_schema = os.path.join('assets', 'two_relations.schema') 59 | with open(path_schema, 'w') as f: 60 | f.write('bernoulli R1 D1 D2\n') 61 | f.write('bernoulli R2 D1 D2\n') 62 | print(path_schema) 63 | # Write observations to disk. 64 | path_obs = os.path.join('assets', 'two_relations.obs') 65 | with open(path_obs, 'w') as f: 66 | for ((i, j), value) in data_r1: 67 | f.write('%d R1 %d %d\n' % (value, i, j)) 68 | for ((i, j), value) in data_r2: 69 | f.write('%d R2 %d %d\n' % (value, i, j)) 70 | print(path_obs) 71 | 72 | # Plot the synthetic data. 73 | fig, axes = plt.subplots(ncols=2) 74 | for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]: 75 | X = np.zeros((30, 40)) 76 | for (i, j), v in data: 77 | X[i,j] = v 78 | nr = 30 79 | nc = 40 80 | pir = prng.sample(list(range(nr)), k=nr) 81 | pic = prng.sample(list(range(nc)), k=nc) 82 | X = np.asarray([ 83 | [X[pir[r], pic[c]] for c in range(nc)] 84 | for r in range(nr) 85 | ]) 86 | ax.imshow(X, cmap='Greys') 87 | ax.xaxis.tick_top() 88 | ax.xaxis.set_label_position('top') 89 | ax.set_xlabel(xlabels[relation]) 90 | ax.set_ylabel(ylabels[relation]) 91 | ax.set_xticks([]) 92 | ax.set_yticks([]) 93 | ax.text(.05, .95, relation, 94 | ha='left', va='top', 95 | transform=ax.transAxes, 96 | bbox={'facecolor': 'red', 'alpha': 1, 'edgecolor':'k'}) 97 | figname = os.path.join('assets', 'two_relations.data.png') 98 | fig.set_size_inches((4,2)) 99 | fig.set_tight_layout(True) 100 | fig.savefig(figname) 101 | print(figname) 102 | 103 | # ===== Make an IRM for both relations (using seed that underfits). 104 | schema = {'R1': ('D1', 'D2'), 'R2': ('D1', 'D2')} 105 | irm = IRM(schema, prng=random.Random(1)) 106 | for relation, data in [ 107 | ('R1', data_r1), 108 | ('R2', data_r2) 109 | ]: 110 | for (i, j), v in data: 111 | irm.incorporate(relation, (i, j), v) 112 | 113 | # Run inference. 
114 | for i in range(100): 115 | irm.transition_cluster_assignments() 116 | pprint(irm.domains['D1'].crp.tables) 117 | pprint(irm.domains['D2'].crp.tables) 118 | pprint(irm.logp_score()) 119 | 120 | # Plot the posterior. 121 | fig, axes = plt.subplots(ncols=2) 122 | for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]: 123 | plot_binary_relation(irm.relations[relation], ax=ax) 124 | score = irm.relations[relation].logp_score() 125 | ax.xaxis.set_label_position('top') 126 | ax.set_xlabel(xlabels[relation]) 127 | ax.set_ylabel(ylabels[relation]) 128 | ax.set_xticks([]) 129 | ax.set_yticks([]) 130 | ax.text(.5, -.1, 'log score = %1.2f' % (score,), 131 | ha='center', va='top', color='k', 132 | transform=ax.transAxes) 133 | figname = os.path.join('assets', 'two_relations.underfit.png') 134 | fig.set_size_inches((4,2)) 135 | fig.set_tight_layout(True) 136 | fig.savefig(figname) 137 | print(figname) 138 | 139 | 140 | # ===== Make an IRM for both relations (using seed that overfits). 141 | schema = {'R1': ('D1', 'D2'), 'R2': ('D1', 'D2')} 142 | irm = IRM(schema, prng=random.Random(10)) 143 | for relation, data in [ 144 | ('R1', data_r1), 145 | ('R2', data_r2) 146 | ]: 147 | for (i, j), v in data: 148 | irm.incorporate(relation, (i, j), v) 149 | 150 | # Run inference. 151 | for i in range(200): 152 | irm.transition_cluster_assignments() 153 | pprint(irm.domains['D1'].crp.tables) 154 | pprint(irm.domains['D2'].crp.tables) 155 | pprint(irm.logp_score()) 156 | 157 | # Plot the posterior. 
fig, axes = plt.subplots(ncols=2)
for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]:
    # Render the learned block structure for this relation and annotate
    # the panel with its per-relation log score.
    plot_binary_relation(irm.relations[relation], ax=ax)
    score = irm.relations[relation].logp_score()
    ax.xaxis.set_label_position('top')
    ax.set_xlabel(xlabels[relation])
    ax.set_ylabel(ylabels[relation])
    ax.set_xticks([])
    ax.set_yticks([])
    ax.text(.5, -.1, 'log score = %1.2f' % (score,),
        ha='center', va='top', color='k',
        transform=ax.transAxes)
figname = os.path.join('assets', 'two_relations.overfit.png')
fig.set_size_inches((4,2))
fig.set_tight_layout(True)
fig.savefig(figname)
print(figname)


# ===== Make IRM for each relation separately.
# Two independent IRMs: each relation gets its own copies of the D1/D2
# domains, so the row/column clusterings are no longer shared between R1
# and R2 (contrast with the joint IRM fit above).
irm1 = IRM({'R1': ('D1', 'D2')}, prng=random.Random(1))
irm2 = IRM({'R2': ('D1', 'D2')}, prng=random.Random(10))
for (i, j), v in data_r1:
    irm1.incorporate('R1', (i, j), v)
for (i, j), v in data_r2:
    irm2.incorporate('R2', (i, j), v)

# Run inference.
for i in range(100):
    irm1.transition_cluster_assignments()
    irm2.transition_cluster_assignments()
    # NOTE(review): only irm1's D1 partition and irm2's D2 partition are
    # printed (irm1.D2 and irm2.D1 are not) -- presumably an abbreviation;
    # confirm the asymmetry is intentional.
    pprint(irm1.domains['D1'].crp.tables)
    pprint(irm2.domains['D2'].crp.tables)
    pprint(irm1.logp_score())
    pprint(irm2.logp_score())

# Plot the posterior.
194 | fig, axes = plt.subplots(ncols=2) 195 | plot_binary_relation(irm1.relations['R1'], ax=axes[0]) 196 | score1 = irm1.relations['R1'].logp_score() 197 | axes[0].xaxis.set_label_position('top') 198 | axes[0].set_xlabel(xlabels['R1']) 199 | axes[0].set_ylabel(ylabels['R1']) 200 | axes[0].set_xticks([]) 201 | axes[0].set_yticks([]) 202 | axes[0].text(.5, -.1, 'log score = %1.2f' % (score1,), 203 | ha='center', va='top', color='k', transform=axes[0].transAxes) 204 | 205 | plot_binary_relation(irm2.relations['R2'], ax=axes[1]) 206 | score2 = irm2.relations['R2'].logp_score() 207 | axes[1].xaxis.set_label_position('top') 208 | axes[1].set_xlabel(xlabels['R2']) 209 | axes[1].set_ylabel(ylabels['R2']) 210 | axes[1].set_xticks([]) 211 | axes[1].set_yticks([]) 212 | axes[1].text(.5, -.1, 'log score = %1.2f' % (score2,), 213 | ha='center', va='top', color='k', transform=axes[1].transAxes) 214 | 215 | figname = os.path.join('assets', 'two_relations.separate.png') 216 | fig.set_size_inches((4,2)) 217 | fig.set_tight_layout(True) 218 | fig.savefig(figname) 219 | print(figname) 220 | plt.show() 221 | 222 | # from hirm.util_math import log_normalize 223 | # p1, p2 = np.exp(log_normalize([score, score1 + score2])) 224 | # print((p1, p2)) 225 | 226 | -------------------------------------------------------------------------------- /examples/two_relations_anti.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | NOISY=1 11 | prng = random.Random(1) 12 | def flip(p): 13 | return prng.random() < p if NOISY else p > 0.5 14 | 15 | P_LO = .1 16 | P_HI = .65 17 | 18 | # ===== Synthetic data generation. 
# Two latent blocks over a single domain D1: items 0-299 and items 300-399.
items_D1 = [
    list(range(0, 300)),
    list(range(300, 400)),
]
# R1: sparse (flip(P_LO)) within each block, dense (flip(P_HI)) across
# blocks.  With NOISY=1, flip(p) is a Bernoulli(p) draw from the seeded
# prng; otherwise it degenerates to the deterministic value p > 0.5.
data_r1 \
    = [((i, j), flip(P_LO)) for i in items_D1[0] for j in items_D1[0]] \
    + [((i, j), flip(P_HI)) for i in items_D1[0] for j in items_D1[1]] \
    + [((i, j), flip(P_HI)) for i in items_D1[1] for j in items_D1[0]] \
    + [((i, j), flip(P_LO)) for i in items_D1[1] for j in items_D1[1]] \

# R2: the complementary ("anti") pattern -- dense within each block,
# sparse across blocks.
data_r2 \
    = [((i, j), flip(P_HI)) for i in items_D1[0] for j in items_D1[0]] \
    + [((i, j), flip(P_LO)) for i in items_D1[0] for j in items_D1[1]] \
    + [((i, j), flip(P_LO)) for i in items_D1[1] for j in items_D1[0]] \
    + [((i, j), flip(P_HI)) for i in items_D1[1] for j in items_D1[1]] \

# Axis labels for the data plot below (both axes range over D1).
xlabels = {'R1': '$D_1$', 'R2': '$D_1$'}
ylabels = {'R1': '$D_1$', 'R2': ''}

# Plot the synthetic data.
fig, axes = plt.subplots(ncols=2)
for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]:
    # Size the image from the largest item index so both panels align.
    n = max(max(z) for z in items_D1)
    X = np.zeros((n+1, n+1))
    for (i, j), v in data:
        X[i,j] = v
    ax.imshow(X, cmap='Greys')
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position('top')
    ax.set_xlabel(xlabels[relation])
    ax.set_ylabel(ylabels[relation], rotation=0, labelpad=10)
    ax.set_xticks([])
    ax.set_yticks([])
    # Panel tag, e.g. '$R_1$' built from the relation name 'R1'.
    ax.text(.05, .95, '$%s_%s$' % (relation[0], relation[1]),
        ha='left', va='top',
        transform=ax.transAxes,
        bbox={'facecolor': 'red', 'alpha': 1, 'edgecolor':'k'})

figname = os.path.join('assets', 'two_relations_anti.data.png')
fig.set_size_inches((3,1.5))
fig.set_tight_layout(True)
fig.savefig(figname)
print(figname)

# TODO: Compare output for clustering R1 and R2 using:
# - IRM, with a higher-order encoding R': D1 x D1 x R -> {0, 1}
# - HIRM, with a direct encoding of R1 and R2.
66 | -------------------------------------------------------------------------------- /examples/two_relations_hirm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import os 5 | import random 6 | 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | 10 | from hirm import HIRM 11 | from hirm.util_plot import plot_binary_relation 12 | 13 | prng = random.Random(1) 14 | NOISY = os.environ.get('NOISY', None) 15 | noisy_str = 'noisy' if NOISY else 'clean' 16 | def flip(p): 17 | return prng.random() < p if NOISY else p > 0.5 18 | 19 | items_D1_R1 = [ 20 | list(range(0, 10)) + list(range(20,30)), 21 | list(range(10, 20)), 22 | ] 23 | items_D2_R1 = [ 24 | list(range(0, 20)), 25 | list(range(20, 40)), 26 | ] 27 | data_r1_d10_d20 = [((i, j), flip(.2)) for i in items_D1_R1[0] for j in items_D2_R1[0]] 28 | data_r1_d10_d21 = [((i, j), flip(.7)) for i in items_D1_R1[0] for j in items_D2_R1[1]] 29 | data_r1_d11_d20 = [((i, j), flip(.8)) for i in items_D1_R1[1] for j in items_D2_R1[0]] 30 | data_r1_d11_d21 = [((i, j), flip(.15)) for i in items_D1_R1[1] for j in items_D2_R1[1]] 31 | data_r1 = data_r1_d10_d20 + data_r1_d10_d21 + data_r1_d11_d20 + data_r1_d11_d21 32 | 33 | items_D1_R2 = [ 34 | list(range(0, 30))[::2], 35 | list(range(0, 30))[1::2] 36 | ] 37 | items_D2_R2 = [ 38 | list(range(0, 20))[::2], 39 | list(range(0, 20))[1::2], 40 | list(range(20, 40))[::2], 41 | list(range(20, 40))[1::2] 42 | ] 43 | data_r2_d10_d20 = [((i, j), flip(.1)) for i in items_D1_R2[0] for j in items_D2_R2[0]] 44 | data_r2_d10_d21 = [((i, j), flip(.2)) for i in items_D1_R2[0] for j in items_D2_R2[1]] 45 | data_r2_d10_d22 = [((i, j), flip(.15)) for i in items_D1_R2[0] for j in items_D2_R2[2]] 46 | data_r2_d10_d23 = [((i, j), flip(.8)) for i in items_D1_R2[0] for j in items_D2_R2[3]] 47 | data_r2_d11_d20 = [((i, j), flip(.8)) for i in 
items_D1_R2[1] for j in items_D2_R2[0]] 48 | data_r2_d11_d21 = [((i, j), flip(.3)) for i in items_D1_R2[1] for j in items_D2_R2[1]] 49 | data_r2_d11_d22 = [((i, j), flip(.9)) for i in items_D1_R2[1] for j in items_D2_R2[3]] 50 | data_r2_d11_d23 = [((i, j), flip(.1)) for i in items_D1_R2[1] for j in items_D2_R2[2]] 51 | data_r2 \ 52 | = data_r2_d10_d20 + data_r2_d10_d21 + data_r2_d10_d22 + data_r2_d10_d23 \ 53 | + data_r2_d11_d20 + data_r2_d11_d21 + data_r2_d11_d22 + data_r2_d11_d23 54 | 55 | # Plot the synthetic data. 56 | fig, axes = plt.subplots(ncols=2) 57 | for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]: 58 | X = np.zeros((30, 40)) 59 | for (i, j), v in data: 60 | X[i,j] = v 61 | ax.imshow(X, cmap='Greys') 62 | ax.xaxis.tick_top() 63 | ax.set_xticks(np.arange(X.shape[1])) 64 | ax.set_yticks(np.arange(X.shape[0])) 65 | ax.set_title('Raw Data %s' % (relation,)) 66 | figname = os.path.join('assets', 'two_relations_hirm.%s.data.png' % (noisy_str)) 67 | fig.set_size_inches((4,2)) 68 | fig.set_tight_layout(True) 69 | fig.savefig(figname) 70 | print(figname) 71 | 72 | # ===== Make an HIRM for both relations. 73 | # Using NOISY=0; seed=108; iters=100 learns cross product. 74 | schema = {'R1': ('D1', 'D2'), 'R2': ('D1', 'D2')} 75 | hirm = HIRM(schema, prng=random.Random(108)) 76 | for relation, data in [ 77 | ('R1', data_r1), 78 | ('R2', data_r2) 79 | ]: 80 | for (i, j), v in data: 81 | hirm.incorporate(relation, (i, j), v) 82 | print(hirm.crp.assignments) 83 | 84 | # Run inference. 85 | iters = 100 86 | hirm.set_cluster_assignment_gibbs('R1', 100) 87 | for i in range(iters): 88 | hirm.transition_cluster_assignments() 89 | hirm.transition_crp_alpha() 90 | for irm in hirm.irms.values(): 91 | irm.transition_cluster_assignments() 92 | irm.transition_crp_alphas() 93 | print(hirm.crp.assignments) 94 | print(hirm.logp_score()) 95 | 96 | # Plot the posterior. 
97 | fig, axes = plt.subplots(ncols=2) 98 | for relation, data, ax in [('R1', data_r1, axes[0]), ('R2', data_r2, axes[1])]: 99 | irm = hirm.relation_to_irm(relation) 100 | plot_binary_relation(irm.relations[relation], ax=ax) 101 | score = irm.relations[relation].logp_score() 102 | ax.set_title('Posterior Sample %s, score %1.2f' % (relation, score,)) 103 | 104 | figname = os.path.join('assets', 'two_relations_hirm.%s.png' % (noisy_str)) 105 | fig.set_size_inches((4,2)) 106 | fig.set_tight_layout(True) 107 | fig.savefig(figname) 108 | print(figname) 109 | 110 | plt.show() 111 | -------------------------------------------------------------------------------- /pythenv.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # Copyright 2021 MIT Probabilistic Computing Project 4 | # Apache License, Version 2.0, refer to LICENSE.txt 5 | 6 | set -Ceu 7 | 8 | : ${PYTHON:=python} 9 | root=`cd -- "$(dirname -- "$0")" && pwd` 10 | platform=$("${PYTHON}" -c 'import distutils.util as u; print(u.get_platform())') 11 | version=$("${PYTHON}" -c 'import sys; print(sys.version[0:3])') 12 | 13 | # The lib directory varies depending on 14 | # 15 | # (a) whether there are extension modules (here, no); and 16 | # (b) whether some Debian maintainer decided to patch the local Python 17 | # to behave as though there were. 18 | # 19 | # But there's no obvious way to just ask distutils what the name will 20 | # be. There's no harm in naming a pathname that doesn't exist, other 21 | # than a handful of microseconds of runtime, so we'll add both. 
22 | libdir="${root}/build/lib" 23 | plat_libdir="${libdir}.${platform}-${version}" 24 | export PYTHONPATH="${libdir}:${plat_libdir}${PYTHONPATH:+:${PYTHONPATH}}" 25 | 26 | bindir="${root}/build/scripts-${version}" 27 | export PATH="${bindir}${PATH:+:${PATH}}" 28 | 29 | exec "$@" 30 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # See LICENSE.txt 3 | 4 | import os 5 | import re 6 | from setuptools import setup 7 | 8 | # Specify the requirements. 9 | requirements = { 10 | 'src' : [ 11 | 'scipy==1.6.*', 12 | ], 13 | 'tests' : [ 14 | 'pytest==5.2.*' 15 | ], 16 | 'examples' : [ 17 | 'matplotlib==3.4.*', 18 | 'numpy==1.20.*', 19 | ] 20 | } 21 | requirements['all'] = [r for v in requirements.values() for r in v] 22 | 23 | # Determine the version (hardcoded). 24 | dirname = os.path.dirname(os.path.realpath(__file__)) 25 | vre = re.compile('__version__ = \'(.*?)\'') 26 | m = open(os.path.join(dirname, 'src', '__init__.py')).read() 27 | __version__ = vre.findall(m)[0] 28 | 29 | setup( 30 | name='hirm', 31 | version=__version__, 32 | description='Hierarchical Infinite Relational Model', 33 | long_description=open('README.md').read(), 34 | long_description_content_type='text/markdown', 35 | license='Apache-2.0', 36 | classifiers=[ 37 | 'Development Status :: 2 - Pre-Alpha', 38 | 'Intended Audience :: Science/Research', 39 | 'License :: OSI Approved :: Apache Software License', 40 | ], 41 | packages=[ 42 | 'hirm', 43 | 'hirm.tests', 44 | ], 45 | package_dir={ 46 | 'hirm' : 'src', 47 | 'hirm.tests' : 'tests', 48 | }, 49 | install_requires=requirements['all'], 50 | extras_require=requirements, 51 | python_requires='>=3.6', 52 | ) 53 | -------------------------------------------------------------------------------- /src/__init__.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | __version__ = '0.1.3' 5 | 6 | from .hirm import IRM 7 | from .hirm import HIRM 8 | -------------------------------------------------------------------------------- /src/hirm.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import itertools 5 | import math 6 | import random 7 | 8 | from scipy.special import betaln 9 | from scipy.special import gammaln 10 | 11 | from .util_math import log_choices 12 | from .util_math import log_linspace 13 | from .util_math import logsumexp 14 | 15 | INF = float('inf') 16 | 17 | class BetaBernoulli: 18 | def __init__(self, alpha=1, beta=1, prng=None): 19 | self.alpha = alpha # Hyperparameter 20 | self.beta = beta # Hyperparameter 21 | self.N = 0 # Number of incorporated observations 22 | self.s = 0 # Sum of incorporated observations 23 | self.prng = prng or random 24 | def incorporate(self, x): 25 | assert x in [0, 1] 26 | self.N += 1 27 | self.s += x 28 | def unincorporate(self, x): 29 | assert x in [0, 1] 30 | self.N -= 1 31 | self.s -= x 32 | assert 0 <= self.N 33 | assert 0 <= self.s 34 | def logp(self, x): 35 | log_denom = math.log(self.N + self.alpha + self.beta) 36 | if x == 1: 37 | return math.log(self.s + self.alpha) - log_denom 38 | if x == 0: 39 | return math.log(self.N - self.s + self.beta) - log_denom 40 | return -INF 41 | # assert False, 'Bad value %s' % (repr(x),) 42 | def logp_score(self): 43 | n = betaln(self.s + self.alpha, self.N - self.s + self.beta) 44 | d = betaln(self.alpha, self.beta) 45 | return n - d 46 | def sample(self): 47 | p = math.exp(self.logp(1)) 48 | return self.prng.choices([0, 1], [p, 1-p])[0] 49 | def __repr__(self): 50 | return 
'BetaBernoulli(alpha=%f, beta=%f, N=%d, s=%d)' \ 51 | % (self.alpha, self.beta, self.N, self.s) 52 | def __str__(self): 53 | return repr(self) 54 | 55 | class CRP: 56 | def __init__(self, alpha=1, prng=None): 57 | self.alpha = alpha # Concentration parameter. 58 | self.N = 0 # Number of customers 59 | self.tables = {} # Map from table to set of items 60 | self.assignments = {} # Map from item to assigned table 61 | self.prng = prng or random 62 | def incorporate(self, item, table): 63 | assert item not in self.assignments 64 | if table not in self.tables: 65 | self.tables[table] = set() 66 | self.tables[table].add(item) 67 | self.assignments[item] = table 68 | self.N += 1 69 | def unincorporate(self, item): 70 | table = self.assignments[item] 71 | self.tables[table].remove(item) 72 | if not self.tables[table]: 73 | del self.tables[table] 74 | del self.assignments[item] 75 | self.N -= 1 76 | def sample(self): 77 | crp_dist = self.tables_weights() 78 | tables = list(crp_dist.keys()) 79 | weights = crp_dist.values() 80 | return self.prng.choices(tables, weights=weights)[0] 81 | def logp(self, table): 82 | dist = self.tables_weights() 83 | if table not in dist: 84 | return -float('inf') 85 | numer = dist[table] 86 | denom = self.N + self.alpha 87 | return math.log(numer) - math.log(denom) 88 | def logp_score(self): 89 | # http://gershmanlab.webfactional.com/pubs/GershmanBlei12.pdf#page=4 (eq 8) 90 | counts = [len(self.tables[t]) for t in self.tables] 91 | return len(self.tables) * math.log(self.alpha) \ 92 | + sum(gammaln(counts)) \ 93 | + gammaln(self.alpha) \ 94 | - gammaln(self.N + self.alpha) 95 | def tables_weights(self): 96 | if self.N == 0: 97 | return {0: 1} 98 | crp_dist = {t : len(self.tables[t]) for t in self.tables} 99 | crp_dist[max(self.tables) + 1] = self.alpha 100 | return crp_dist 101 | def tables_weights_gibbs(self, table): 102 | assert 0 < self.N 103 | crp_dist = self.tables_weights() 104 | crp_dist[table] -= 1 105 | if crp_dist[table] == 0: 106 | 
crp_dist[table] = self.alpha 107 | del crp_dist[max(crp_dist)] 108 | return crp_dist 109 | def transition_alpha(self): 110 | grid = log_linspace(1/self.N, self.N+1, 30) 111 | log_weights = [] 112 | for g in grid: 113 | self.alpha = g 114 | lp_g = self.logp_score() 115 | log_weights.append(lp_g) 116 | self.alpha = log_choices(grid, log_weights, prng=self.prng)[0] 117 | 118 | def __repr__(self): 119 | return 'CRP(alpha=%r, N=%r, tables=%r, assignments=%r)' \ 120 | % (self.alpha, self.N, self.tables, self.assignments) 121 | def __str__(self): 122 | return repr(self) 123 | 124 | class Domain: 125 | def __init__(self, name, prng=None): 126 | self.name = name # Human-readable string name 127 | self.items = set() # Set of items 128 | self.crp = CRP(prng=prng) # Clustering model for items 129 | self.prng = self.crp.prng 130 | def incorporate(self, item, table=None): 131 | if item in self.items: 132 | assert table is None 133 | if item not in self.items: 134 | self.items.add(item) 135 | t = self.crp.sample() if table is None else table 136 | self.crp.incorporate(item, t) 137 | def unincorporate(self, item): 138 | raise NotImplementedError() 139 | # assert item in self.items 140 | # self.items[item].remove(relation) 141 | # if not self.items[item]: 142 | # self.crp.unincorporate(item) 143 | # del self.items[item] 144 | def get_cluster_assignment(self, item): 145 | assert item in self.items 146 | return self.crp.assignments[item] 147 | def set_cluster_assignment_gibbs(self, item, table): 148 | assert item in self.items 149 | assert self.crp.assignments[item] != table 150 | self.crp.unincorporate(item) 151 | self.crp.incorporate(item, table) 152 | def tables_weights(self): 153 | return self.crp.tables_weights() 154 | def tables_weights_gibbs(self, item): 155 | assert item in self.items 156 | table = self.get_cluster_assignment(item) 157 | return self.crp.tables_weights_gibbs(table) 158 | 159 | def __repr__(self): 160 | return 'Domain(name=%r)' % (self.name,) 161 | def 
__str__(self): 162 | return repr(self) 163 | 164 | class Relation: 165 | def __init__(self, name, domains, prng=None): 166 | self.name = name # Name of relation 167 | self.domains = tuple(domains) # Domains it is defined over 168 | self.aux = BetaBernoulli # TODO: Generalize 169 | self.clusters = {} # Map from cluster id to BetaBernoulli 170 | self.data = {} # Map from items to observed value 171 | self.data_r = {domain.name : {} for domain in self.domains} 172 | # Map from domain name to reverse map 173 | # from item to set of incorporated 174 | # items that include that item 175 | self.prng = prng or random 176 | def incorporate(self, items, value): 177 | assert items not in self.data 178 | self.data[items] = value 179 | assert len(items) == len(self.domains) 180 | for domain, item in zip(self.domains, items): 181 | domain.incorporate(item) 182 | if item not in self.data_r[domain.name]: 183 | self.data_r[domain.name][item] = set() 184 | self.data_r[domain.name][item].add(items) 185 | cluster = self.get_cluster_assignment(items) 186 | if cluster not in self.clusters: 187 | self.clusters[cluster] = BetaBernoulli() 188 | self.clusters[cluster].incorporate(value) 189 | def unincorporate(self, items): 190 | raise NotImplementedError() 191 | # x = self.data[items] 192 | # z = self.get_cluster_assignment(items) 193 | # self.clusters[z].unincorporate(x) 194 | # if self.clusters[z].N == 0: 195 | # del self.clusters[z] 196 | # for domain, item in zip(self.domains, items): 197 | # if item in self.data_r[domain.name]: 198 | # self.data_r[domain.name][item].discard(items) 199 | # if len(self.data_r[domain.name][item]) == 0: 200 | # del self.data_r[domain.name][item] 201 | # domain.unincorporate(self.name, item) 202 | # del self.data[items] 203 | def get_cluster_assignment(self, items): 204 | return tuple((domain.get_cluster_assignment(item)) 205 | for domain, item in zip(self.domains, items)) 206 | def get_cluster_assignment_gibbs(self, items, domain, item, table): 207 | z = 
[] 208 | assert len(items) == len(self.domains) 209 | hits = 0 210 | for domain_i, item_i in zip(self.domains, items): 211 | if (domain_i.name == domain.name) and (item_i == item): 212 | t = table 213 | hits += 1 214 | else: 215 | t = domain_i.get_cluster_assignment(item_i) 216 | z.append(t) 217 | assert hits 218 | return tuple(z) 219 | 220 | # Implementation of approximate Gibbs data probabilities (faster). 221 | def logp_gibbs_approx_current(self, domain, item): 222 | """Return approximate proposal probability for current table.""" 223 | logp = 0 224 | for items in self.data_r[domain.name][item]: 225 | x = self.data[items] 226 | z = self.get_cluster_assignment(items) 227 | self.clusters[z].unincorporate(x) 228 | lp = self.clusters[z].logp(x) 229 | self.clusters[z].incorporate(x) 230 | logp += lp 231 | return logp 232 | def logp_gibbs_approx_variant(self, domain, item, table): 233 | """Return approximate proposal probability for non-current table.""" 234 | logp = 0 235 | for items in self.data_r[domain.name][item]: 236 | x = self.data[items] 237 | z = self.get_cluster_assignment_gibbs(items, domain, item, table) 238 | cluster = self.clusters.get(z, self.aux()) 239 | lp = cluster.logp(x) 240 | logp += lp 241 | return logp 242 | def logp_gibbs_approx(self, domain, item, table): 243 | """Return approximate proposal probability of domain.item at table.""" 244 | table_current = domain.get_cluster_assignment(item) 245 | if table_current == table: 246 | logp = self.logp_gibbs_approx_current(domain, item) 247 | else: 248 | logp = self.logp_gibbs_approx_variant(domain, item, table) 249 | return logp 250 | 251 | # Implementation of exact Gibbs data probabilities. 
252 | def get_cluster_to_items_list(self, domain, item): 253 | """Return mapping from cluster to all items in that cluster 254 | that have domain.item in at least one dimension.""" 255 | cluster_to_items_list = {} 256 | for items in self.data_r[domain.name][item]: 257 | cluster = self.get_cluster_assignment(items) 258 | if cluster not in cluster_to_items_list: 259 | cluster_to_items_list[cluster] = [] 260 | cluster_to_items_list[cluster].append(items) 261 | return cluster_to_items_list 262 | def logp_gibbs_exact_current(self, items_list): 263 | """Return exact proposal proposal probability for current table.""" 264 | z = self.get_cluster_assignment(items_list[0]) 265 | cluster = self.clusters[z] 266 | logp0 = cluster.logp_score() 267 | for items in items_list: 268 | x = self.data[items] 269 | # assert z == self.get_cluster_assignment(items) 270 | cluster.unincorporate(x) 271 | logp1 = cluster.logp_score() 272 | for items in items_list: 273 | x = self.data[items] 274 | cluster.incorporate(x) 275 | assert cluster.logp_score() == logp0 276 | return logp0 - logp1 277 | def logp_gibbs_exact_variant(self, domain, item, table, items_list): 278 | """Return exact proposal proposal probability for non-current table.""" 279 | z = self.get_cluster_assignment_gibbs(items_list[0], domain, item, table) 280 | cluster = self.clusters.get(z, self.aux()) 281 | logp0 = cluster.logp_score() 282 | for items in items_list: 283 | # assert z == self.get_cluster_assignment_gibbs(items, domain, item, table) 284 | x = self.data[items] 285 | cluster.incorporate(x) 286 | logp1 = cluster.logp_score() 287 | for items in items_list: 288 | # TODO: Skip this loop in case of cluster aux. 
    def logp_gibbs_exact(self, domain, item, tables):
        """Return exact proposal probabilities of domain.item at each table.

        Returns a list of log probabilities aligned with `tables`. For
        the current table the score is a leave-out marginal; for any
        other table it is the marginal gain from hypothetically moving
        domain.item there.
        """
        # assert tables crp_dist = domain.tables_weights_gibbs(item)
        cluster_to_items_list = self.get_cluster_to_items_list(domain, item)
        table_current = domain.get_cluster_assignment(item)
        logps = []
        for table in tables:
            lp = 0
            # Sum contributions from each affected cluster independently.
            for items_list in cluster_to_items_list.values():
                if table == table_current:
                    lp_cluster = self.logp_gibbs_exact_current(items_list)
                else:
                    lp_cluster = self.logp_gibbs_exact_variant(
                        domain, item, table, items_list)
                lp += lp_cluster
            logps.append(lp)
        return logps
    def logp(self, items, value):
        """Return the posterior predictive log probability that
        relation(items) == value, marginalizing over the cluster
        assignments of any fresh (never-incorporated) items."""
        assert len(self.domains) == len(items)
        # TODO: Replace with call logp_observations.
        # XXX Formally, the following assertion is needed for this
        # algorithm to be correct: we should have only one fresh item per
        # domain. Otherwise, the CRP table probabilities are coupled in
        # the predictive. However, we will assume that we have a
        # "truncated" version of the DPMM with only one auxiliary cluster,
        # where each fresh item has the same probability of belonging to the
        # cluster independently of previous fresh items from that domain.
        # domain_to_item = {}
        # for domain, item in zip(self.domains, items):
        #     assert domain.name not in domain_to_item or domain_to_item[domain.name] == item
        tabl_list = []  # per-dimension candidate tables
        wght_list = []  # per-dimension log weight of each candidate
        indx_list = []  # per-dimension indexes into the two lists above
        for domain, item in zip(self.domains, items):
            if item in domain.items:
                # Known item: its table is fixed (log weight 0).
                t_list = [domain.get_cluster_assignment(item)]
                w_list = [0]
                i_list = [0]
            else:
                # Fresh item: consider every table the domain offers,
                # weighted by the domain's table weights (normalized by
                # log(1 + N)).
                tables_weights = domain.tables_weights()
                Z = math.log(1 + domain.crp.N)
                t_list = tuple(tables_weights.keys())
                w_list = tuple(math.log(x) - Z for x in tables_weights.values())
                i_list = tuple(range(len(tables_weights)))
            tabl_list.append(t_list)
            wght_list.append(w_list)
            indx_list.append(i_list)
        logps = []
        # Enumerate all joint table assignments across the dimensions.
        for indexes in itertools.product(*indx_list):
            z = tuple(tabl_list[i][j] for i, j in enumerate(indexes))
            w = tuple(wght_list[i][j] for i, j in enumerate(indexes))
            # Unseen joint clusters are scored with an auxiliary component.
            cluster = self.clusters.get(z, self.aux())
            logp_data = cluster.logp(value)
            logp_clst = sum(w)
            logps.append(logp_clst + logp_data)
        return logsumexp(logps)
    def logp_score(self):
        """Return the total log score of all clusters in this relation."""
        return sum(cluster.logp_score() for cluster in self.clusters.values())
365 | z_new = self.get_cluster_assignment_gibbs(items, domain, item, table) 366 | assert z_new not in self.clusters or self.clusters[z_new].N > 0 367 | if z_new not in self.clusters: 368 | self.clusters[z_new] = self.aux() 369 | self.clusters[z_new].incorporate(x) 370 | def has_observation(self, domain, item): 371 | return item in self.data_r[domain.name] 372 | 373 | def __repr__(self): 374 | return 'Relation(name=%s, domains=%r)' % (self.name, self.domains,) 375 | def __str__(self): 376 | return repr(self) 377 | 378 | class IRM: 379 | def __init__(self, schema, prng=None): 380 | self.schema = {} 381 | self.domains = {} 382 | self.relations = {} 383 | self.domain_to_relations = {} 384 | self.prng = prng or random 385 | for (relation, domains) in schema.items(): 386 | self.add_relation(relation, domains) 387 | def incorporate(self, r, items, value): 388 | self.relations[r].incorporate(items, value) 389 | def unincorporate(self, r, items): 390 | raise NotImplementedError() 391 | # self.relations[r].unincorporate(items) 392 | def transition_cluster_assignments(self, domains=None): 393 | if domains is None: 394 | domains = list(self.domains) 395 | self.prng.shuffle(domains) 396 | for d in domains: 397 | items = list(self.domains[d].items) 398 | self.prng.shuffle(items) 399 | for item in self.domains[d].items: 400 | self.transition_cluster_assignment_item(d, item) 401 | def transition_cluster_assignment_item(self, d, item): 402 | domain = self.domains[d] 403 | relations = [self.relations[r] for r in self.domain_to_relations[d]] 404 | crp_dist = domain.tables_weights_gibbs(item) 405 | # Compute probability of each table. 
    def transition_cluster_assignment_item(self, d, item):
        """Gibbs-resample the table of `item` in domain `d`.

        The proposal weight of each candidate table combines the CRP
        weight with the exact data likelihood from every relation that
        has observations involving the item.
        """
        domain = self.domains[d]
        relations = [self.relations[r] for r in self.domain_to_relations[d]]
        crp_dist = domain.tables_weights_gibbs(item)
        # Compute probability of each table.
        tables = crp_dist.keys()
        logps = [math.log(crp_dist[t]) for t in tables]
        for relation in relations:
            if relation.has_observation(domain, item):
                lp_relation = relation.logp_gibbs_exact(domain, item, tables)
                assert len(lp_relation) == len(tables)
                logps = [x + y for x, y in zip(logps, lp_relation)]
        # Sample new table.
        choice = log_choices(list(crp_dist), logps, prng=self.prng)[0]
        if choice != domain.get_cluster_assignment(item):
            # Update the relations first: their reassignment logic reads
            # the domain's *current* assignment, so the domain must be
            # updated last.
            for relation in relations:
                if relation.has_observation(domain, item):
                    relation.set_cluster_assignment_gibbs(domain, item, choice)
            # Update the domain.
            domain.set_cluster_assignment_gibbs(item, choice)
self.domains[d].items 451 | if len(self.domain_to_relations[d]) == 0: 452 | del self.domain_to_relations[d] 453 | del self.domains[d] 454 | del self.relations[r] 455 | del self.schema[r] 456 | 457 | class HIRM: 458 | def __init__(self, schema, prng=None): 459 | self.crp = CRP(prng=prng) 460 | self.schema = {} 461 | self.irms = {} 462 | self.prng = prng or random 463 | for relation, domains in schema.items(): 464 | self.add_relation(relation, domains) 465 | def incorporate(self, r, items, value): 466 | irm = self.relation_to_irm(r) 467 | irm.incorporate(r, items, value) 468 | def unincorporate(self, r, items): 469 | irm = self.relation_to_irm(r) 470 | irm.unincorporate(r, items) 471 | def relation_to_table(self, r): 472 | return self.crp.assignments[r] 473 | def relation_to_irm(self, r): 474 | table = self.crp.assignments[r] 475 | return self.irms[table] 476 | def relation(self, r): 477 | irm = self.relation_to_irm(r) 478 | return irm.relations[r] 479 | def transition_cluster_assignments(self): 480 | for r in list(self.crp.assignments): 481 | self.transition_cluster_assignment_relation(r) 482 | def transition_cluster_assignment_relation(self, r): 483 | table_current = self.crp.assignments[r] 484 | relation = self.irms[table_current].relations[r] 485 | signature = (r, [d.name for d in relation.domains]) 486 | crp_dist = self.crp.tables_weights_gibbs(table_current) 487 | (table_aux, irm_aux) = (None, None) 488 | logps = [] 489 | # Compute probabilities of each table. 490 | for table in crp_dist: 491 | irm = self.irms.get(table, None) 492 | if irm is None: 493 | irm = IRM({}, prng=self.prng) 494 | assert (table_aux, irm_aux) == (None, None) 495 | (table_aux, irm_aux) = (table, irm) 496 | if table != table_current: 497 | irm.add_relation(signature[0], signature[1]) 498 | for items, value in relation.data.items(): 499 | irm.incorporate(r, items, value) 500 | lp_table = irm.relations[r].logp_score() 501 | logps.append(lp_table) 502 | # Sample new table. 
    def transition_cluster_assignment_relation(self, r):
        """Gibbs-resample which IRM relation r belongs to.

        Each candidate table's weight is its CRP weight times the
        marginal likelihood of r's data under that table's IRM; a fresh
        (auxiliary) IRM represents the new-table option.
        """
        table_current = self.crp.assignments[r]
        relation = self.irms[table_current].relations[r]
        signature = (r, [d.name for d in relation.domains])
        crp_dist = self.crp.tables_weights_gibbs(table_current)
        (table_aux, irm_aux) = (None, None)
        logps = []
        # Compute probabilities of each table.
        for table in crp_dist:
            irm = self.irms.get(table, None)
            if irm is None:
                # Auxiliary (empty) IRM for the new-table proposal; at most
                # one table in crp_dist may lack an existing IRM.
                irm = IRM({}, prng=self.prng)
                assert (table_aux, irm_aux) == (None, None)
                (table_aux, irm_aux) = (table, irm)
            if table != table_current:
                # Temporarily install r (with its data) in this IRM so it
                # can be scored; removed again below unless chosen.
                irm.add_relation(signature[0], signature[1])
                for items, value in relation.data.items():
                    irm.incorporate(r, items, value)
            lp_table = irm.relations[r].logp_score()
            logps.append(lp_table)
        # Sample new table.
        log_weights = [math.log(crp_dist[t]) + l for t, l in zip(crp_dist, logps)]
        choice = log_choices(list(crp_dist), log_weights, prng=self.prng)[0]
        # Remove relation from all other tables (including the temporary
        # copies added while scoring above).
        for table in self.crp.tables:
            if table != choice:
                self.irms[table].remove_relation(r)
                # Only the current table can end up empty here.
                if len(self.irms[table].relations) == 0:
                    assert len(self.crp.tables[table]) == 1
                    assert table == table_current
                    del self.irms[table]
        # Add auxiliary table if necessary.
        if choice == table_aux:
            self.irms[choice] = irm_aux
        # Update the CRP.
        self.crp.unincorporate(r)
        self.crp.incorporate(r, choice)
        assert set(self.irms) == set(self.crp.tables)
537 | self.crp.unincorporate(r) 538 | self.crp.incorporate(r, table) 539 | assert set(self.irms) == set(self.crp.tables) 540 | def transition_crp_alpha(self): 541 | self.crp.transition_alpha() 542 | def add_relation(self, r, domains): 543 | assert r not in self.schema 544 | self.schema[r] = domains 545 | table = self.crp.sample() 546 | self.crp.incorporate(r, table) 547 | if table in self.irms: 548 | self.irms[table].add_relation(r, domains) 549 | else: 550 | irm = IRM({r : domains}, prng=self.prng) 551 | self.irms[table] = irm 552 | def remove_relation(self, r): 553 | del self.schema[r] 554 | table = self.crp.assignments[r] 555 | self.crp.unincorporate(r) 556 | self.irms[table].remove_relation(r) 557 | if len(self.irms[table].relations) == 0: 558 | del self.irms[table] 559 | def logp(self, observations): 560 | obs_dict = {} 561 | for (relation, items, value) in observations: 562 | table = self.crp.assignments[relation] 563 | if table not in obs_dict: 564 | obs_dict[table] = [] 565 | obs_dict[table].append((relation, items, value)) 566 | logps = (self.irms[t].logp(obs_dict[t]) for t in obs_dict) 567 | return sum(logps) 568 | def logp_score(self): 569 | logp_score_crp = self.crp.logp_score() 570 | logp_score_irms = [irm.logp_score() for irm in self.irms.values()] 571 | return logp_score_crp + sum(logp_score_irms) 572 | 573 | def logp_observations(observations): 574 | """Observations is a list of (relation, items, value) tuples.""" 575 | # Compute all cluster combinations. 
def logp_observations(observations):
    """Return the joint log probability of observations, a list of
    (relation, items, value) tuples, marginalizing over the cluster
    assignments of any fresh items.

    Each (relation, items) pair may appear at most once.
    """
    # Compute all cluster combinations.
    item_universe = set()     # all distinct (domain.name, item) pairs
    index_universe = []       # per pair: indexes into its candidate tables
    weight_universe = []      # per pair: log weight of each candidate table
    cluster_universe = {}     # (domain.name, item) -> (position, tables)
    seen = set()
    for relation, items, value in observations:
        assert (relation.name, items) not in seen
        seen.add((relation.name, items))
        assert len(items) == len(relation.domains)
        for domain, item in zip(relation.domains, items):
            # Each (domain, item) pair is registered exactly once, even if
            # it appears in several observations.
            if (domain.name, item) in item_universe:
                assert (domain.name, item) in cluster_universe
                continue
            if item in domain.items:
                # Known item: its table is fixed (log weight 0).
                t_list = (domain.get_cluster_assignment(item),)
                w_list = (0,)
                i_list = (0,)
            else:
                # Fresh item: weight candidate tables by the domain's
                # table weights, normalized by log(1 + N).
                tables_weights = domain.tables_weights()
                t_list = tuple(tables_weights.keys())
                Z = math.log(1 + domain.crp.N)
                w_list = tuple(math.log(x) - Z for x in tables_weights.values())
                i_list = tuple(range(len(tables_weights)))
            item_universe.add((domain.name, item))
            index_universe.append(i_list)
            weight_universe.append(w_list)
            loc = len(index_universe) - 1  # location of (domain.name, item)
                                           # within the index universe
            cluster_universe[(domain.name, item)] = (loc, t_list)
    assert len(item_universe) == len(index_universe)
    assert len(item_universe) == len(weight_universe)
    assert len(item_universe) == len(cluster_universe)
    # Compute data probabilities given each cluster combinations.
    # TODO: This implementation can be made more efficient by factoring
    # out relations that do not have any overlapping items.
    logps = []
    for indexes in itertools.product(*index_universe):
        logp_indexes = 0
        # Compute weight of cluster assignments.
        weight = [weight_universe[i][j] for i, j in enumerate(indexes)]
        logp_indexes += sum(weight)
        # Compute weight of data given cluster assignments.
        for relation, items, value in observations:
            z = []
            for domain, item in zip(relation.domains, items):
                loc, t_list = cluster_universe[(domain.name, item)]
                t = t_list[indexes[loc]]
                z.append(t)
            # Unseen joint clusters are scored with an auxiliary component.
            cluster = relation.clusters.get(tuple(z), relation.aux())
            logp_indexes += cluster.logp(value)
        # Add to global list of logps.
        logps.append(logp_indexes)
    return logsumexp(logps)
def intify(x):
    """Convert string x to int when it is a plain decimal string, else
    return x unchanged.

    Uses str.isdecimal() rather than the original str.isnumeric():
    isnumeric() also accepts characters such as '½' or '²' that int()
    rejects, which previously raised ValueError here.
    """
    if x.isdecimal():
        return int(x)
    return x

def load_schema(path):
    """Load a schema from path.

    Each line is '<dist> <feature> <domain> ...'; returns a dict mapping
    feature name -> tuple of domain names.
    """
    signatures = {}
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split(' ')
            assert 3 <= len(parts)
            dist = parts[0]
            # Only the Bernoulli observation model is supported.
            assert dist == 'bernoulli'
            feature = parts[1]
            domains = tuple(parts[2:])
            signatures[feature] = domains
    return signatures

def load_observations(path):
    """Load a dataset from path.

    Each line is '<value> <relation> <item> ...'; returns a list of
    (relation, items, value) tuples with value parsed as float.
    """
    data = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split(' ')
            assert 3 <= len(parts)
            value = float(parts[0])
            relation = parts[1]
            # Renamed loop variable: the original generator reused `x`,
            # shadowing the observation value.
            items = tuple(intify(part) for part in parts[2:])
            data.append((relation, items, value))
    return data
def load_clusters_hirm(path):
    """Load HIRM clusters from path.

    Returns (relations, irms): `relations` maps IRM table id -> tuple of
    relation names; `irms` maps table id -> {domain: {table: items}}.

    File format (the inverse of to_txt_hirm): a header block of
    '<table> <relation> ...' lines, a blank separator, then per-IRM
    sections introduced by 'irm=<table>' containing
    '<domain> <table> <item> ...' lines.
    """
    irms = {}
    relations = {}
    current_irm = 0
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split(' ')
            # Header line: '<table> <relation> ...' (relation partition).
            if parts[0].isnumeric():
                assert 2 <= len(parts)
                table = int(parts[0])
                items = tuple(parts[1:])
                assert table not in relations
                relations[table] = items
                continue
            # Blank line: separator before the next 'irm=' section.
            if len(parts) == 1 and not parts[0]:
                current_irm = None
                continue
            # Section header: 'irm=<table>'.
            if len(parts) == 1 and parts[0].startswith('irm='):
                assert current_irm is None
                current_irm = int(parts[0].split('=')[1])
                assert current_irm not in irms
                irms[current_irm] = {}
                continue
            # Cluster line inside a section: '<domain> <table> <item> ...'.
            if 2 <= len(parts):
                assert current_irm is not None
                assert current_irm in irms
                domain = parts[0]
                table = int(parts[1])
                items = tuple(intify(x) for x in parts[2:])
                if domain not in irms[current_irm]:
                    irms[current_irm][domain] = {}
                assert table not in irms[current_irm][domain]
                irms[current_irm][domain][table] = items
                continue
            assert False, 'Failed to process line'
    # Every table in the relation partition must have an IRM section.
    assert set(relations) == set(irms)
    return relations, irms
def to_dict_BetaBernoulli(x):
    """Serialize a BetaBernoulli component to a JSON-compatible dict."""
    return {'alpha': x.alpha, 'beta': x.beta, 'N': x.N, 's': x.s}
def from_dict_BetaBernoulli(d, prng=None):
    """Inverse of to_dict_BetaBernoulli."""
    x = hirm.BetaBernoulli(alpha=d['alpha'], beta=d['beta'], prng=prng)
    x.N = d['N']
    x.s = d['s']
    return x

def to_dict_CRP(x):
    """Serialize a CRP to a JSON-compatible dict.

    Table/customer keys are repr()-encoded because JSON object keys must
    be strings; from_dict_CRP decodes them with ast.literal_eval.
    """
    return {
        'alpha': x.alpha,
        'N': x.N,
        'tables': {repr(t): list(v) for t, v in x.tables.items()},
        'assignments': {repr(t): v for t, v in x.assignments.items()}
    }
def from_dict_CRP(d, prng=None):
    """Inverse of to_dict_CRP."""
    x = hirm.CRP(d['alpha'], prng=prng)
    x.N = d['N']
    x.tables = {ast.literal_eval(t): set(v) for t, v in d['tables'].items()}
    x.assignments = {ast.literal_eval(t): v for t, v in d['assignments'].items()}
    return x

def to_dict_Domain(x):
    """Serialize a Domain to a JSON-compatible dict."""
    return {
        'name': x.name,
        'items': list(x.items),
        'crp': to_dict_CRP(x.crp)
    }
def from_dict_Domain(d, prng=None):
    """Inverse of to_dict_Domain."""
    x = hirm.Domain(d['name'], prng=prng)
    x.items = set(d['items'])
    # BUG FIX: propagate prng to the nested CRP. The original called
    # from_dict_CRP(d['crp']) without prng, so a restored Domain's CRP
    # ignored the caller's seeded generator (from_dict_IRM/from_dict_HIRM
    # both pass prng through).
    x.crp = from_dict_CRP(d['crp'], prng=prng)
    return x
def from_dict_Relation(d, prng=None):
    """Inverse of to_dict_Relation.

    Note: `domains` remains a list of names here; from_dict_IRM resolves
    them into Domain objects afterwards.
    """
    rel = hirm.Relation(d['name'], [], prng=prng)
    rel.domains = d['domains']
    rel.clusters = {}
    for key, cluster in d['clusters'].items():
        rel.clusters[ast.literal_eval(key)] = from_dict_BetaBernoulli(
            cluster, prng=prng)
    rel.data = {}
    for key, value in d['data'].items():
        rel.data[ast.literal_eval(key)] = value
    rel.data_r = {}
    for key, inner in d['data_r'].items():
        decoded = {}
        for key1, tuples in inner.items():
            decoded[ast.literal_eval(key1)] = set(tuple(t) for t in tuples)
        rel.data_r[ast.literal_eval(key)] = decoded
    return rel

def to_dict_IRM(x):
    """Serialize an IRM to a JSON-compatible dict."""
    return {
        'schema': x.schema,
        'domains': {k: to_dict_Domain(v) for k, v in x.domains.items()},
        'relations': {k: to_dict_Relation(v) for k, v in x.relations.items()},
        'domain_to_relations': {k: list(v) for k, v in x.domain_to_relations.items()}
    }
def to_txt_irm(path, irm):
    """Write IRM cluster assignments to `path`, one
    '<domain> <table> <item> ...' line per table, tables in sorted order."""
    with open(path, 'w') as f:
        for domain in irm.domains.values():
            for table in sorted(domain.crp.tables):
                customers = domain.crp.tables[table]
                customers_str = ' '.join(str(c) for c in customers)
                # Single write; the original split the '\n' into a second
                # call. Output is byte-identical.
                f.write('%s %d %s\n' % (domain.name, table, customers_str))

def to_txt_hirm(path, hirm):
    """Write HIRM cluster assignments to `path`: the relation partition
    ('<table> <relation> ...'), a blank line, then one 'irm=<table>'
    section per IRM with its domain clusters. Inverse of
    load_clusters_hirm."""
    with open(path, 'w') as f:
        tables = sorted(hirm.crp.tables)
        for table in tables:
            customers = hirm.crp.tables[table]
            customers_str = ' '.join(str(c) for c in customers)
            f.write('%d %s\n' % (table, customers_str))
        f.write('\n')
        # enumerate replaces the original's manual `j` counter.
        for j, table in enumerate(tables):
            f.write('irm=%d\n' % (table,))
            irm = hirm.irms[table]
            for domain in irm.domains.values():
                # Distinct inner-loop name: the original shadowed `table`
                # here (harmless since the outer value is re-bound each
                # pass, but confusing).
                for dom_table, customers in domain.crp.tables.items():
                    customers_str = ' '.join(str(c) for c in customers)
                    f.write('%s %d %s\n' % (domain.name, dom_table, customers_str))
            # Blank separator between sections, but not after the last.
            if j != len(hirm.irms) - 1:
                f.write('\n')
item in items: 235 | irm.domains[domain].incorporate(item, table=table) 236 | for (relation, items, x) in observations: 237 | irm.incorporate(relation, items, x) 238 | return irm 239 | 240 | def from_txt_hirm(path_schema, path_obs, path_clusters): 241 | schema = load_schema(path_schema) 242 | observations = load_observations(path_obs) 243 | relations, irms = load_clusters_hirm(path_clusters) 244 | hirmm = hirm.HIRM(schema) 245 | for table in relations: 246 | for relation in relations[table]: 247 | if hirmm.crp.assignments[relation] != table: 248 | hirmm.set_cluster_assignment_gibbs(relation, table) 249 | irm = hirmm.irms[table] 250 | for domain, tables in irms[table].items(): 251 | for t, items in tables.items(): 252 | for item in items: 253 | irm.domains[domain].incorporate(item, table=t) 254 | for (relation, items, x) in observations: 255 | hirmm.incorporate(relation, items, x) 256 | return hirmm 257 | -------------------------------------------------------------------------------- /src/util_math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 MIT Probabilistic Computing Project 2 | # Apache License, Version 2.0, refer to LICENSE.txt 3 | 4 | import math 5 | import random 6 | 7 | inf = float('inf') 8 | 9 | def linspace(start, stop, num=50, endpoint=True): 10 | """linspace from a to b with n entries.""" 11 | step = (stop - start) / (num - endpoint) 12 | return [start + step*i for i in range(num)] 13 | 14 | def log_linspace(a, b, n): 15 | """linspace from a to b with n entries over log scale.""" 16 | points = linspace(math.log(a), math.log(b), num=n) 17 | return [math.exp(x) for x in points] 18 | 19 | def log_normalize(log_weights): 20 | """Return log of the sum of exponentials of input list divided by sum.""" 21 | Z = logsumexp(log_weights) 22 | return [x - Z for x in log_weights] 23 | 24 | def log_choices(population, log_weights, k=1, prng=None): 25 | """Draw from a population given a list of log 
def logsumexp(array):
    """Return log(sum(exp(a) for a in array)), computed stably."""
    if not array:
        return float('-inf')
    m = max(array)
    # If the max is infinite and all entries share its sign (and none is
    # NaN), the answer is simply that infinity: all +inf -> +inf, all
    # -inf (i.e. all-zero addends) -> -inf. Mixed ±inf or any NaN falls
    # through so the arithmetic below produces NaN, per IEEE semantics.
    has_nan = any(math.isnan(a) for a in array)
    if math.isinf(m) and not has_nan and min(array) != -m:
        return m
    # m is the maximum, so a - m <= 0 and exp(a - m) cannot overflow.
    return m + math.log(sum(math.exp(a - m) for a in array))
def get_fig_ax(ax=None):
    """Return (figure, axes), creating a fresh pair when ax is None."""
    if ax is None:
        _, ax = plt.subplots()
    return ax.get_figure(), ax

def sort_crp_tables(tables):
    """Order items by cluster: largest clusters first (ties broken by the
    smallest member), returning (flat item list, cumulative dividers)."""
    order = sorted(
        tables,
        key=lambda t: (len(tables[t]), min(tables[t])),
        reverse=True)
    items = []
    sizes = []
    for t in order:
        items.extend(tables[t])
        sizes.append(len(tables[t]))
    return (items, np.cumsum(sizes))

def sort_data_binary_relation(data, tables_rows, tables_cols):
    """Sort rows and columns of a binary relation by cluster size.

    Missing entries become NaN (rendered gray by imshow).
    """
    (items_rows, dividers_rows) = sort_crp_tables(tables_rows)
    (items_cols, dividers_cols) = sort_crp_tables(tables_cols)
    rows = []
    for i in items_rows:
        rows.append([data.get((i, j), np.nan) for j in items_cols])
    X = np.asarray(rows)
    return (X, (items_rows, items_cols), (dividers_rows, dividers_cols))
def plot_data_matrix_sorted(X, items, dividers, transpose=None, ax=None):
    """Plot a clustered 2D matrix, optionally transposed."""
    # Adapted from https://matplotlib.org/3.1.0/gallery/images_contours_and_fields/image_annotated_heatmap.html
    fig, ax = get_fig_ax(ax)
    if transpose:
        X = X.T
        items = items[::-1]
        dividers = dividers[::-1]
    items_rows, items_cols = items
    dividers_rows, dividers_cols = dividers
    imshow(X, items_rows, items_cols, dividers_rows, dividers_cols, ax)
    return fig, ax

def imshow(X, items_rows, items_cols, dividers_rows, dividers_cols, ax):
    """Main function for rendering an IRM matrix.

    NaN entries are drawn gray; red lines mark cluster boundaries.
    """
    cmap = copy.copy(plt.get_cmap('Greys'))
    cmap.set_bad(color='gray')
    # Use aspect='auto'
    # https://stackoverflow.com/q/44654421/1405543
    ax.imshow(X, cmap=cmap)
    # Set ticks, with column labels along the top edge.
    ax.xaxis.tick_top()
    ax.set_xticks(np.arange(X.shape[1]))
    ax.set_yticks(np.arange(X.shape[0]))
    ax.set_yticklabels(items_rows)
    ax.set_xticklabels(items_cols, rotation=90)
    # Red separators between clusters (the trailing divider is skipped).
    if len(dividers_rows) > 1:
        for boundary in dividers_rows[:-1]:
            ax.axhline(boundary - .5, color='r', linewidth=2)
    if len(dividers_cols) > 1:
        for boundary in dividers_cols[:-1]:
            ax.axvline(boundary - .5, color='r', linewidth=2)
    # Make a thin grid on the minor axis.
    # ax.set_xticks(np.arange(X.shape[1]+1)-.5, minor=True)
    # ax.set_yticks(np.arange(X.shape[0]+1)-.5, minor=True)
    # ax.grid(which='minor', color='green', linestyle='-', linewidth=1)
    # ax.tick_params(which='minor', top=False, left=False)
def sort_binary_relation(relation):
    """Sort rows and columns of a binary relation by cluster size."""
    assert len(relation.domains) == 2
    tables_rows = relation.domains[0].crp.tables
    tables_cols = relation.domains[1].crp.tables
    data = relation.data
    return sort_data_binary_relation(data, tables_rows, tables_cols)

def sort_ternary_relation(relation, predicate):
    """Sort rows and columns of a ternary relation (curried at first
    value = predicate) by cluster size."""
    assert len(relation.domains) == 3
    tables_rows = relation.domains[1].crp.tables
    tables_cols = relation.domains[2].crp.tables
    data = relation.data
    return sort_data_ternary_relation(data, predicate, tables_rows, tables_cols)

def sort_unary_relations(relations):
    """Sort items of a shared domain across several unary relations.

    Returns (X, (row items, row dividers), (column labels, [])), where
    columns are the relation names.
    """
    domain = relations[0].domains[0]
    assert all(len(relation.domains) == 1 for relation in relations)
    assert all(relation.domains[0] is domain for relation in relations)
    items_rows, dividers_rows = sort_crp_tables(domain.crp.tables)
    items_cols = [relation.name for relation in relations]
    X = np.asarray([
        [relation.data.get((i,), np.nan) for relation in relations]
        for i in items_rows
    ])
    return X, (items_rows, dividers_rows), (items_cols, [])

def plot_binary_relation(relation, transpose=None, ax=None):
    """Plot the clustered matrix of a binary relation.

    (Docstring fixed: the original was a copy-paste of the ternary
    variant's docstring.)
    """
    X, items, dividers = sort_binary_relation(relation)
    return plot_data_matrix_sorted(X, items, dividers, transpose=transpose, ax=ax)
def plot_ternary_relation(relation, predicate, transpose=None, ax=None):
    """Plot matrix for a ternary relation, curried at first value = predicate."""
    X, items, dividers = sort_ternary_relation(relation, predicate)
    return plot_data_matrix_sorted(X, items, dividers, transpose=transpose, ax=ax)

def plot_unary_relations(relations, ax=None):
    """Plot partition of unary 'relations' learned by IRM, ala DPMM."""
    fig, ax = get_fig_ax(ax)
    X, (items_rows, dividers_rows), (items_cols, dividers_cols) \
        = sort_unary_relations(relations)
    imshow(X, items_rows, items_cols, dividers_rows, dividers_cols, ax)
    return fig, ax

def plot_hirm_crosscat(hirm, relations):
    """Plot partition of unary 'relations' learned by HIRM, ala CrossCat.

    Draws one subplot per HIRM table; every relation must be unary over
    the same domain.
    """
    domain = hirm.relation(relations[0]).domains[0]
    for r in relations:
        assert len(hirm.relation(r).domains) == 1
        assert hirm.relation(r).domains[0].name == domain.name
    tables = set([hirm.crp.assignments[relation] for relation in relations])
    fig, axes = plt.subplots(ncols=len(tables))
    # atleast_1d: with a single subplot, plt.subplots returns a bare Axes
    # rather than an array.
    axes = np.atleast_1d(axes)
    for table, ax in zip(tables, axes):
        # Only the relations assigned to this table appear in its subplot.
        relations_table = [
            hirm.relation(r) for r in relations
            if hirm.crp.assignments[r] == table
        ]
        X, (items_rows, dividers_rows), (items_cols, dividers_cols) \
            = sort_unary_relations(relations_table)
        imshow(X, items_rows, items_cols, dividers_rows, dividers_cols, ax)
    for ax in axes:
        ax.set_aspect('auto')
    return fig, axes
# Copyright 2021 MIT Probabilistic Computing Project
# Apache License, Version 2.0, refer to LICENSE.txt

import itertools
import json
import math
import os
import random
import tempfile

import numpy as np

from hirm import IRM
from hirm.util_io import from_dict_IRM
from hirm.util_io import from_txt_irm
from hirm.util_io import to_dict_IRM
from hirm.util_io import to_txt_irm
from hirm.util_math import logsumexp
from hirm.tests.util_test import make_two_clusters

def test_irm_add_relation():
    """Relations can be added, observed, inferred on, and removed cleanly."""
    schema = {
        'Flippers' : ['Animal'], # Feature
        'Strain Teeth' : ['Animal'], # Feature
        'Swims' : ['Animal'], # Feature
        'Arctic' : ['Animal'], # Feature
        'Hunts' : ['Animal', 'Animal'], # Likes relation
    }
    random.seed(1)
    # Start from an empty schema and add relations incrementally.
    model = IRM({})
    for relation, domains in schema.items():
        model.add_relation(relation, domains)
    model.incorporate('Arctic', ('Bear',), 1)
    model.incorporate('Hunts', ('Bear', 'Bear'), 0)
    model.incorporate('Hunts', ('Bear', 'Fish'), 1)
    model.transition_cluster_assignments()
    # model.unincorporate('Hunts', ('Bear', 'Bear'))
    # model.unincorporate('Hunts', ('Bear', 'Fish'))
    for relation in schema:
        model.remove_relation(relation)
    # Removing every relation must leave the IRM completely empty.
    assert not model.relations
    assert not model.domains
    assert not model.domain_to_relations

def test_irm_two_clusters():
    """IRM recovers the two-cluster structure of the synthetic dataset."""
    schema, items_D1, items_D2, data = make_two_clusters()
    irm = IRM(schema, prng=random.Random(1))
    for (i, j), v in data:
        irm.incorporate('R1', (i, j), v)
    # Run inference.
    for i in range(20):
        irm.transition_cluster_assignments()
    # The seeded chain should find exactly the planted partitions.
    assert len(irm.domains['D1'].crp.tables) == 2
    assert set(items_D1[0]) in irm.domains['D1'].crp.tables.values()
    assert set(items_D1[1]) in irm.domains['D1'].crp.tables.values()
    assert set(items_D2[0]) in irm.domains['D2'].crp.tables.values()
    assert set(items_D2[1]) in irm.domains['D2'].crp.tables.values()
    # Check probabilities.
    # Keys are (D1 item, D2 item); 100 is an unseen item in both domains,
    # hence the .5 (uniform) and .66 (prior-weighted) predictive values.
    expected_p0 = {
        (0, 0) : 1.,
        (0, 10) : 1.,
        (0, 100) : .5,
        (10, 0) : 0.,
        (10, 10) : 0.,
        (10, 100) : .5,
        (100, 0) : .66,
        (100, 10) : .66,
        (100, 100) : .5
    }
    for x1, x2 in itertools.product([0, 10, 100], [0, 10, 100]):
        p0 = irm.relations['R1'].logp((x1, x2), 0)
        p0_irm = irm.logp((('R1', (x1, x2), 0),))
        # Relation-level and IRM-level single-observation logp must agree.
        assert np.allclose(p0, p0_irm)
        p1 = irm.relations['R1'].logp((x1, x2), 1)
        # p(v=0) + p(v=1) must sum to one.
        assert np.allclose(logsumexp([p0, p1]), 0)
        assert abs(math.exp(p0) - expected_p0[(x1, x2)]) < 0.1
    for (x1, x2, x3) in [(0, 10, 100), (110, 10, 100)]:
        # Joint probabilities over two observations must also normalize.
        p00 = irm.logp([
            ('R1', (x1, x2), 0),
            ('R1', (x1, x3), 0)
        ])
        p01 = irm.logp([
            ('R1', (x1, x2), 0),
            ('R1', (x1, x3), 1)
        ])
        p10 = irm.logp([
            ('R1', (x1, x2), 1),
            ('R1', (x1, x3), 0)
        ])
        p11 = irm.logp([
            ('R1', (x1, x2), 1),
            ('R1', (x1, x3), 1)
        ])
        assert np.allclose(logsumexp([p00, p01, p10, p11]), 0)

def check_irms_agree(irm, x):
    """Assert deserialized IRM `x` matches `irm` and still infers correctly."""
    schema, items_D1, items_D2, data = make_two_clusters()
    # Latent state (assignments, tables, items) and data must round-trip.
    for d in ['D1', 'D2']:
        assert x.domains[d].crp.assignments == irm.domains[d].crp.assignments
        assert x.domains[d].crp.tables == irm.domains[d].crp.tables
        assert x.domains[d].items == irm.domains[d].items
    assert x.relations['R1'].data == irm.relations['R1'].data
    assert x.relations['R1'].data_r == irm.relations['R1'].data_r
    # Run inference.
    for i in range(20):
        x.transition_cluster_assignments()
    # The restored model should converge to the planted partitions too.
    assert len(x.domains['D1'].crp.tables) == 2
    assert set(items_D1[0]) in x.domains['D1'].crp.tables.values()
    assert set(items_D1[1]) in x.domains['D1'].crp.tables.values()
    assert set(items_D2[0]) in x.domains['D2'].crp.tables.values()
    assert set(items_D2[1]) in x.domains['D2'].crp.tables.values()

def test_irm_two_clusters_serialize_json_dict():
    """Round-trip the prior IRM through dict and JSON serialization."""
    schema, items_D1, items_D2, data = make_two_clusters()
    prng = random.Random(1)
    irm = IRM(schema, prng=prng)
    for (i, j), v in data:
        irm.incorporate('R1', (i, j), v)
    # Serialize the prior IRM to dict and JSON.
    d1 = to_dict_IRM(irm)
    d2 = json.loads(json.dumps(d1))
    irm1 = from_dict_IRM(d1, prng=prng)
    irm2 = from_dict_IRM(d2, prng=prng)
    for x in [irm1, irm2]:
        check_irms_agree(irm, x)

# NOTE(review): 'serliaze' is a typo for 'serialize'; kept as-is since the
# test name is part of the discoverable test suite.
def test_irm_two_clusters_serliaze_txt():
    """Round-trip the prior IRM through the schema/obs/irm text formats."""
    schema, items_D1, items_D2, data = make_two_clusters()
    prng = random.Random(1)
    irm = IRM(schema, prng=prng)
    for (i, j), v in data:
        irm.incorporate('R1', (i, j), v)
    # delete=False so the files can be re-opened by path before removal below.
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        path_schema = f.name
        f.write('bernoulli R1 D1 D2\n')
    with tempfile.NamedTemporaryFile(mode='w', delete=False) as f:
        path_obs = f.name
        for (i, j), v in data:
            f.write('%d R1 %d %d\n' % (v, i, j))
    with tempfile.NamedTemporaryFile(delete=False) as f:
        path_irm = f.name
    to_txt_irm(path_irm, irm)
    irm1 = from_txt_irm(path_schema, path_obs, path_irm)
    check_irms_agree(irm, irm1)
    os.remove(path_schema)
    os.remove(path_obs)
    os.remove(path_irm)
# Apache License, Version 2.0, refer to LICENSE.txt

def make_two_clusters():
    """Build a synthetic two-cluster binary relation R1 : D1 x D2.

    Returns (schema, items_D1, items_D2, data): the schema dict, the two
    planted clusters of each domain, and observations ((i, j), v) where
    v is 1 exactly when i and j come from clusters with different labels
    (an XOR / checkerboard pattern).
    """
    schema = {
        'R1': ('D1', 'D2')
    }
    # D1 cluster 0 is deliberately non-contiguous (0-9 and 20-29).
    items_D1 = [
        [*range(0, 10), *range(20, 30)],
        [*range(10, 20)],
    ]
    items_D2 = [
        [*range(0, 20)],
        [*range(20, 40)],
    ]
    data = []
    # Enumerate cluster pairs in row-major order: (0,0), (0,1), (1,0), (1,1).
    for label_row, rows in enumerate(items_D1):
        for label_col, cols in enumerate(items_D2):
            value = int(label_row != label_col)  # XOR of the cluster labels
            data.extend(((i, j), value) for i in rows for j in cols)
    return schema, items_D1, items_D2, data