├── .gitattributes
├── .gitignore
├── .vscode
│   └── settings.json
├── CHANGELOG.md
├── FAQ.md
├── LICENSE.txt
├── QBN.md
├── README.md
├── REDIS.md
├── USAGE.md
├── python
│   └── plotmarginals.py
└── rust
    ├── Cargo.lock
    ├── Cargo.toml
    ├── check.sh
    ├── explorer.sh
    ├── makebackup.sh
    ├── plot.sh
    ├── src
    │   ├── baseline
    │   │   ├── mod.rs
    │   │   └── model.rs
    │   ├── bin
    │   │   ├── explorer_server.rs
    │   │   ├── list_entities.rs
    │   │   ├── plot.rs
    │   │   └── train.rs
    │   ├── common
    │   │   ├── graph.rs
    │   │   ├── interface.rs
    │   │   ├── logging.rs
    │   │   ├── mod.rs
    │   │   ├── model.rs
    │   │   ├── proposition_db.rs
    │   │   ├── redis.rs
    │   │   ├── resources.rs
    │   │   ├── setup.rs
    │   │   ├── test.rs
    │   │   └── train.rs
    │   ├── explorer
    │   │   ├── assets
    │   │   │   ├── animation.css
    │   │   │   ├── app.html
    │   │   │   ├── slides.html
    │   │   │   └── style.css
    │   │   ├── diagram_utils.rs
    │   │   ├── mod.rs
    │   │   ├── render_utils.rs
    │   │   └── routes
    │   │       ├── animation_route.rs
    │   │       ├── experiment_route.rs
    │   │       ├── factors_route.rs
    │   │       ├── index_route.rs
    │   │       ├── marginals_route.rs
    │   │       ├── mod.rs
    │   │       ├── network_route.rs
    │   │       └── weights_route.rs
    │   ├── inference
    │   │   ├── graph.rs
    │   │   ├── inference.rs
    │   │   ├── lambda.rs
    │   │   ├── mod.rs
    │   │   ├── pi.rs
    │   │   ├── rounds.rs
    │   │   └── table.rs
    │   ├── lib.rs
    │   ├── model
    │   │   ├── choose.rs
    │   │   ├── config.rs
    │   │   ├── creators.rs
    │   │   ├── exponential.rs
    │   │   ├── mod.rs
    │   │   ├── objects.rs
    │   │   ├── ops.rs
    │   │   └── weights.rs
    │   └── scenarios
    │       ├── dating_simple.rs
    │       ├── dating_triangle.rs
    │       ├── factory.rs
    │       ├── helpers.rs
    │       ├── long_and.rs
    │       ├── long_chain.rs
    │       ├── mid_chain.rs
    │       ├── mod.rs
    │       ├── one_var.rs
    │       └── two_var.rs
    ├── static
    │   └── images
    │       └── domains
    │           ├── Man.png
    │           └── Woman.png
    ├── train.sh
    └── unittests.sh

--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.pdf filter=lfs diff=lfs merge=lfs -text

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
node_modules
package-lock.json
.DS_Store

texput.log

*.aux
*.bbl
*.blg
*.log
*.out

backups

output.png
loss_graph.png

output.md

temp

target

output

--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
{
    "rust-analyzer.linkedProjects": [
        "./rust/Cargo.toml"
    ],
    "rust-analyzer.showUnlinkedFileNotification": false
}

--------------------------------------------------------------------------------
/CHANGELOG.md:
--------------------------------------------------------------------------------
# CHANGELOG

## animation0111

Make it so that "/animation/<experiment_name>/<test_scenario>/" will:
1) load up a scenario
2) remember the inferences for each round
3) graph the marginals

## implications0108
Render the "implications".

Make there be "libraries" so that each part can be built from smaller parts (compositionality).
## itdomain0107
Figure out how to:
* print out all of the information for each "entity" in a domain

This involves:
* make sure we register a domain
* figure out how to keep a list of all entities per domain
* actually make a binary to iterate through all the entities
* run each of the tests like this

## namespaceit0107
This PR is about setting up all of the objects in the ontology or universe to live in:
* consistent namespaces -- so that all of the namespaces can run together
* iterating over objects -- define and make iterable all the things we surface to the user

### Namespaces
We want each experiment to exist in its own persistent "namespace".

Maybe this can work by prepending the "experiment" name in front.

Then, we have a list of "experiment names" that gets you all of these other lists.

### Iteration
We have to be able to iterate over:

* DOMAIN names
* VERBs
* entities in domains
* specific propositions "that we know"

## cleanup0107
* just clean up the code
* run through all the scenarios and see if they all still work

--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
# Frequently Asked Questions (FAQ)

## Did you Literally Build AGI?
The QBN as I have presented it is trained on **artificial data**.

It will be AGI when the QBN is trained on **real web-scale data**.

Right now, the QBN only "thinks about" very simple worlds that I encoded by hand.
But, if we assume that the LLM has "world knowledge", then the only remaining problem for full AGI is to transfer the knowledge from the LLM to the QBN.

That, I claim, would be full AGI. Right now, I repeat, the QBN is trained on "toy universes" that I made up programmatically.

## Is it Trivial to Transfer Knowledge from LLM to QBN?
No. This is not trivial. It will require that the LLM be re-written to generate a **tree-structured** analysis of a sentence, mapping the **surface form** of the sentence to its **logical form**.

This **logical form** is **latent**--meaning we can't observe it, and neither can actual people (this is why misunderstandings arise).

So, the following new abilities need to be developed before "full AGI" exists:
1. parse to logical forms, which are:
    a. latent (not observed)
    b. structured (recursively tree-structured)
2. concretize the continuous knowledge of the LLM into the discrete knowledge of the QBN

## Does the QBN Help us Understand the LLM?
Yes, I believe so. The QBN uses "semantic roles", which might explain why the "key-value" nature of the attention mechanism can learn world knowledge:
that is, the **key-value** knowledge of the LLM is actually learning the **semantic role** knowledge of linguistics.
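To make "tree-structured logical form" and "semantic role" concrete, here is a minimal, hypothetical sketch. The type and role names below are illustrative only -- the repository's real types live in `rust/src/model/objects.rs` and differ in detail:

```
// Illustrative only: a latent logical form as a tree of role-labeled parts.
// "Jack knows that Jill is exciting" might map to:
//   know(subject: jack, object: exciting(subject: jill))
#[derive(Debug)]
enum LogicalForm {
    // A constant naming an entity, e.g. "jack".
    Entity(String),
    // A predicate with a relation name and role-labeled arguments.
    Predicate {
        relation: String,
        roles: Vec<(String, LogicalForm)>, // (semantic role, filler) pairs
    },
}

fn main() {
    let jill_is_exciting = LogicalForm::Predicate {
        relation: "exciting".to_string(),
        roles: vec![("subject".to_string(), LogicalForm::Entity("jill".to_string()))],
    };
    let form = LogicalForm::Predicate {
        relation: "know".to_string(),
        roles: vec![
            ("subject".to_string(), LogicalForm::Entity("jack".to_string())),
            ("object".to_string(), jill_is_exciting),
        ],
    };
    println!("{:?}", form);
}
```

The (role, filler) pairs are exactly the "key-value" shape referred to above: the conjecture is that attention's key-value lookups approximate this role-to-filler matching.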
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.

"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:

(a) You must give any other recipients of the Work or Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

Copyright 2024 Greg Coppola

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

--------------------------------------------------------------------------------
/QBN.md:
--------------------------------------------------------------------------------
# QBN Model Overview

## Background Reading
This software package introduces the **Quantified Bayesian Network** (**QBN**).
The QBN generalizes:
1. Traditional (generative) Bayesian Networks
    - **Bayesian Networks** are graphical models representing probabilistic relationships among variables. They use directed acyclic graphs to encode joint probability distributions, allowing for efficient reasoning and inference in complex systems. Bayesian Networks are widely used in fields like machine learning, data analysis, and artificial intelligence for tasks like prediction, anomaly detection, and decision making.
    - Learn more:
        - [Bayesian Networks and their Applications](https://www.sciencedirect.com/topics/computer-science/bayesian-network)
2. First-Order Logic
    - **First-Order Logic** (FOL), also known as predicate logic or first-order predicate calculus, is a collection of formal systems used in mathematics, philosophy, linguistics, and computer science. It provides a framework for expressing statements with quantifiers and variables, allowing for the formulation of hypotheses about objects and their relationships. FOL is fundamental in formal systems and theorem proving, and is foundational in artificial intelligence for knowledge representation and reasoning.
    - Learn more:
        - [First-Order Logic: Basics](https://plato.stanford.edu/entries/logic-classical/)
        - [Understanding First-Order Logic](https://www.britannica.com/topic/formal-logic/Higher-order-and-modal-logic)

## How Does the QBN Avoid Hallucinations?

The QBN addresses the issue of hallucinations – which, in the context of machine learning models, refers to generating misleading or incorrect information – through a multifaceted approach. This is achieved by integrating aspects of first-order logic with the principles of traditional Bayesian Networks. Here's a breakdown of how it accomplishes this:

### Using Logic
- **Integration with First-Order Logic:** By generalizing first-order logic, the QBN can handle more complex and nuanced relationships between entities and their attributes. This logic-based approach enables the QBN to more accurately infer relationships and dependencies, reducing the likelihood of generating nonsensical or factually incorrect statements.
- **Structured Reasoning:** The logical structure inherent in QBNs facilitates a more disciplined reasoning process. Unlike models that rely solely on statistical patterns, the QBN's logic-based framework helps maintain consistency and coherence in its outputs, aligning more closely with established logical norms and reducing errors that arise from purely data-driven inferences.
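As a purely schematic illustration of how the two formalisms combine, consider a quantified rule of the kind used in the dating scenarios shipped with this package (the exact functional form the QBN uses is defined in the arXiv paper and in `rust/src/model`; the predicates below are illustrative):

$$\forall x \, \forall y : \; \mathrm{lonely}(x) \wedge \mathrm{exciting}(y) \rightarrow \mathrm{date}(x, y)$$

Grounding this single weighted template over concrete entities $a$ and $b$ yields one Bayesian-network factor per instantiated pair:

$$P\big(\mathrm{date}(a, b) \mid \mathrm{lonely}(a), \mathrm{exciting}(b)\big)$$

In this way the quantifiers of first-order logic let one rule parameterize many proposition-level factors at once, while the Bayesian side attaches calibrated probabilities to each grounding.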
### Understanding the Argument
- **Explanatory Capabilities:** QBNs are designed not only to make predictions or inferences but also to provide explanations for their outputs. This is crucial for understanding the 'why' behind a decision or inference, lending greater transparency and reliability to the model.
- **Handling Uncertainty with Bayesian Principles:** Bayesian Networks excel at dealing with uncertainty. By incorporating these principles, QBNs can weigh evidence and consider various hypotheses, leading to more robust and well-supported conclusions.

### Awareness of Its Limitations
- **Acknowledging Unknowns:** One of the key strengths of the QBN is its built-in mechanism for recognizing the limits of its knowledge. This acknowledgment of uncertainty and unknown factors prevents overconfidence in its outputs, a common cause of hallucinations in other models.
- **Continuous Learning and Adaptation:** The QBN framework allows for continuous updating and learning from new data, ensuring that the model remains relevant and its knowledge base evolves over time, further reducing the risk of outdated or incorrect information leading to hallucinations.

### Combining Logic and Causality
- **Causal Reasoning:** The QBN extends beyond mere correlational data analysis by incorporating causal reasoning, drawing from ideas in classical Bayesian Networks. This enables the QBN to construct more realistic generative models of the world, leading to outputs that are not only statistically sound but also logically and causally coherent.
- **Complex Generalization of First-Order Logic and Bayesian Principles:** The sophisticated interplay between the logical structure of FOL and the probabilistic reasoning of Bayesian Networks allows the QBN to navigate complex scenarios with a balanced approach, harnessing the strengths of both logical rigor and probabilistic flexibility.

In summary, the QBN's ability to avoid hallucinations stems from its sophisticated integration of logical reasoning, causal inference, and an intrinsic understanding of its own limitations. This combination leads to more reliable, transparent, and accurate outputs, especially in complex and uncertain environments.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# BAYES STAR ([bayesstar.com](https://bayesstar.com))

This software package provides a reference implementation of a:
* Logical Bayesian Network

## Impact
This **BAYES STAR** software package implements this [arXiv paper](https://arxiv.org/abs/2402.06557), and provides a *unified model* of *probabilistic* and *logical* reasoning.

## Usage
For instructions on how to use the software, see [USAGE.md](USAGE.md).

## Documentation

### arXiv Paper
There is an [arXiv paper](https://arxiv.org/abs/2402.06557) now, and I will add more links to the material "published to Bitcoin".
As noted below, there is also a lot of multi-media content on my Twitter at [@coppola_ai](https://twitter.com/coppola_ai).
## References
Here are some references that inspired this paper: [BibTeX File](paper/bibtex.bib).

### Time-Stamping of Ideas on Bitcoin Chain
For time-stamping, I have put all of my work on the Bitcoin Blockchain ([address1](https://ordinals.hiro.so/inscriptions?a=bc1pjlpr5nzl6cmljtyz0a3gng98y3r5hs8z68gw55vg4ccjptvj9msq5gqrc5), [address2](https://ordinals.hiro.so/inscriptions?a=bc1pvd4selnseakwz5eljgj4d99mka25mk8pp3k7v7hc6uxw8txy6lgsf7lmtg)).
This provides cryptographically secure time-stamping and an immutable record of each idea, whose integrity is guaranteed by the value of the Bitcoin chain (almost $1 trillion).

Find me online at:
* twitter: [@coppola_ai](https://twitter.com/coppola_ai)

## License

This project is licensed under the Apache License 2.0 - see the [LICENSE.txt](LICENSE.txt) file for details.

--------------------------------------------------------------------------------
/REDIS.md:
--------------------------------------------------------------------------------
# Installing Redis on Your System

**Do not run this software if you already have a Redis database on 'localhost' because it will get cleared.**

Redis is an advanced key-value store, known for its flexibility, performance, and wide language support. This guide will walk you through the installation process for Redis on various operating systems.

Remember, you can always ask your favorite "chat bot" if you get stuck.

## Table of Contents
1. [Prerequisites](#prerequisites)
2. [Installing on Linux](#installing-on-linux)
3. [Installing on Windows](#installing-on-windows)
4. [Installing on macOS](#installing-on-macos)
5. [Verifying the Installation](#verifying-the-installation)
6. [Next Steps](#next-steps)

## Prerequisites
- Basic knowledge of command line operations.
- Administrative or root access on your system.

## Installing on Linux
### Debian/Ubuntu
1. Update your package list:
   ```
   sudo apt-get update
   ```
2. Install Redis:
   ```
   sudo apt-get install redis-server
   ```
3. Start Redis:
   ```
   sudo service redis-server start
   ```

### CentOS/RedHat
1. Add the EPEL repository:
   ```
   sudo yum install epel-release
   ```
2. Install Redis:
   ```
   sudo yum install redis
   ```
3. Start Redis:
   ```
   sudo systemctl start redis
   ```

## Installing on Windows
Redis does not natively support Windows. However, you can use the Windows Subsystem for Linux (WSL) or a Windows-compatible version of Redis.
1. [Enable WSL](https://docs.microsoft.com/en-us/windows/wsl/install) on Windows 10/11.
2. Follow the Linux installation steps within WSL.

## Installing on macOS
1. Install Homebrew if it's not already installed:
   ```
   /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/master/install.sh)"
   ```
2. Install Redis:
   ```
   brew install redis
   ```
3. Start Redis:
   ```
   brew services start redis
   ```

## Verifying the Installation
After installation, you can verify that Redis is running correctly:
```
redis-cli ping
```
If Redis is running, it will return:
```
PONG
```
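Since this project talks to Redis through the `redis` crate (see `rust/Cargo.toml`), you can also sanity-check the connection from Rust. A minimal sketch, assuming the default `localhost` instance:

```
fn main() -> redis::RedisResult<()> {
    // Same connection string the project itself uses (see rust/src/common/redis.rs).
    let client = redis::Client::open("redis://127.0.0.1/")?;
    let mut con = client.get_connection()?;
    // PING returns PONG when the server is reachable.
    let pong: String = redis::cmd("PING").query(&mut con)?;
    println!("{}", pong);
    Ok(())
}
```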
## Next Steps
Now that Redis is installed, you can start using it in your projects. Check out the official [Redis documentation](https://redis.io/documentation) for more information on how to use Redis.

--------------------------------------------------------------------------------
/USAGE.md:
--------------------------------------------------------------------------------
# Bayes-Star Usage Instructions

This document provides instructions on how to set up and run the Bayes-Star project.

# Warning if You Already Have Redis
This software is currently set to clear the Redis database on localhost when it starts.
I will work on a better UI.
In the meantime:
* **Do not run this software if you already have a Redis database on 'localhost' because it will get cleared.**

# Reminder to "Use a Chat Bot"
**At any time, if you get stuck, just ask your favorite "chat bot".**

Keeping docs up to date has always been hard, and I can only test in my own environment, but your favorite chat bot can explain things to you if you know how to ask. I used [ChatGPT](https://chat.openai.com/) in the creation of this project.

# System Overview
The dependencies are:
* **Rust**
    * The inference and training code is written in [Rust](https://www.rust-lang.org/).
* **REDIS**
    * This is an in-memory data store (similar to **MEMCACHE**) where the data and theories are stored as **strings**.
    * You can use any store and any serialization method.
* **python3**
    * This is **optional**, because I wrote my "eval" in python3.
    * But it is trivial, and you can start over in any framework. I'm not that current on the latest data analysis tools.

# Redis for the Data Store
**NOTE**: Training will *wipe out* your **REDIS** store on *localhost*, so **STOP** right now if you have **REDIS** on *localhost*.

See [REDIS.md](REDIS.md).

# Rust for the Model

The main program is written in [Rust](https://www.rust-lang.org/).

There is right now some analysis code written in [python3](https://www.python.org/). But, if you want to use a different language for analysis, you don't have to use python.

### Installing Rust on Your System

See [Rust](https://www.rust-lang.org/) or ask your favorite chat bot.

## Run
### Training

**NOTE**: Training will *wipe out* your **REDIS** store on *localhost*, so **STOP** right now if you have **REDIS** on *localhost*.

From the `rust` directory:

```
./train.sh dating_simple
```

### Plotting Convergence
Plot the convergence of the marginals for an observed variable, using the string-valued test-scenario key defined in `rust/src/bin/plot.rs`.

From the `rust` directory:

```
./plot.sh $OUTPUT_DIRECTORY dating_simple they_date 10
```
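The `--marginal_output_file` produced by the plot binary is what `python/plotmarginals.py` reads: one JSON object per line (one line per inference timepoint), each with an `entries` list of `[condition, probability]` pairs. A hypothetical line (the condition strings here are illustrative) looks like:

```
{"entries": [["date(jack, jill)", 0.62], ["lonely(jack)", 1.0]]}
```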
--------------------------------------------------------------------------------
/python/plotmarginals.py:
--------------------------------------------------------------------------------
import json
import matplotlib.pyplot as plt
import sys

def read_and_process_file(file_path, max_lines=None):
    data = {}
    with open(file_path, 'r') as file:
        for i, line in enumerate(file):
            if max_lines is not None and i >= max_lines:
                break  # Stop reading if max_lines is reached
            print(f"time point {i}")
            json_line = json.loads(line)
            for entry in json_line['entries']:
                condition, probability = entry
                # Skip the bookkeeping "exist" propositions; plot everything else.
                if not "exist" in condition:
                    print(f"\"{condition}\" {probability}")
                    if condition not in data:
                        data[condition] = []
                    data[condition].append(probability)
    return data

def plot_data(data, out_file_path):
    plt.figure(figsize=(10, 6))
    for condition, probabilities in data.items():
        plt.plot(probabilities, label=condition)
    plt.xlabel('Timepoint')
    plt.ylabel('Probability')
    plt.title('Probability of Conditions Over Time')
    plt.legend()
    plt.savefig(out_file_path)  # Save the plot to a file

def main():
    if len(sys.argv) < 2:
        print("Usage: python plotmarginals.py <input_file> [max_lines]")
        sys.exit(1)
    input_path = sys.argv[1]
    # max_lines is optional; read the whole file when it is not given.
    max_lines = int(sys.argv[2]) if len(sys.argv) > 2 else None
    out_path = f"{input_path}_plot_{max_lines}.png"
    data = read_and_process_file(input_path, max_lines)
    plot_data(data, out_path)
    print(f"Plot saved to {out_path}")

if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
[package]
name = "bayes-star"
version = "0.1.0"
edition = "2021"

[dependencies]
redis = "0.23.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.8.4"
log = "0.4"
env_logger = "0.10.1"
clap = "3.0"
once_cell = "1.7"
colored = "1.7.0"
rocket = "0.4.11"
rocket_contrib = { version = "0.4.11", features = ["json"] }
walkdir = "2.3"

--------------------------------------------------------------------------------
/rust/check.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Run cargo check
cargo check

# Check the exit status of the last command (cargo check)
if [ $? -eq 0 ]; then
    # Check if the first argument is "test"
    if [ "$1" == "test" ]; then
        # If cargo check was successful and the first argument is "test", run cargo check for tests
        cargo check --tests
    else
        echo "Skipping test checks, as the first argument is not 'test'."
    fi
else
    # If cargo check failed, exit the script
    echo "cargo check failed, aborting script."
    exit 1
fi
18 | exit 1 19 | fi 20 | 21 | -------------------------------------------------------------------------------- /rust/explorer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export ROCKET_ENV=development 3 | 4 | RUST_BACKTRACE=1 cargo run --bin explorer_server -- --scenario_name dating_simple -------------------------------------------------------------------------------- /rust/makebackup.sh: -------------------------------------------------------------------------------- 1 | cd .. 2 | git archive --format=zip -o "backups/bayes-star_$(git rev-parse --abbrev-ref HEAD)_$(date "+%Y-%b-%d-%H-%M").zip" HEAD 3 | 4 | 5 | -------------------------------------------------------------------------------- /rust/plot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | ROOT_DATA=$1 4 | SCENARIO_NAME=$2 5 | TEST_SCENARIO=$3 6 | NUM_ITERATIONS_TO_PLOT=$4 7 | 8 | if [ -z "$ROOT_DATA" ] || [ -z "$SCENARIO_NAME" ] || [ -z "$TEST_SCENARIO" ] || [ -z "$NUM_ITERATIONS_TO_PLOT" ]; then 9 | echo "Error: ROOT_DATA directory, scenario name, and test scenario must be provided." 10 | echo "usage: ./plot.sh " 11 | exit 1 12 | fi 13 | 14 | # Compute MARGINAL_OUTPUT_FILE based on SCENARIO_NAME and TEST_SCENARIO 15 | MARGINAL_OUTPUT_FILE="${ROOT_DATA}/${SCENARIO_NAME}_${TEST_SCENARIO}" 16 | 17 | # Proceed with the original command using the variables and add the --marginal_output_file argument 18 | RUST_BACKTRACE=1 RUST_LOG=info cargo run --bin plot -- --print_training_loss --entities_per_domain=1024 --test_example=0 --scenario_name=$SCENARIO_NAME --test_scenario=$TEST_SCENARIO --marginal_output_file=$MARGINAL_OUTPUT_FILE 19 | 20 | # TODO: replace with javascript 21 | # python3 ../python/plotmarginals.py $MARGINAL_OUTPUT_FILE $NUM_ITERATIONS_TO_PLOT -------------------------------------------------------------------------------- /rust/src/baseline/mod.rs: -------------------------------------------------------------------------------- 1 | mod model; 2 | -------------------------------------------------------------------------------- /rust/src/baseline/model.rs: -------------------------------------------------------------------------------- 1 | use std::{collections::HashMap, error::Error}; 2 | use crate::model::objects::PredicateGroup; 3 | 4 | pub struct MonolithicBayes { 5 | underlying:HashMap, 6 | } 7 | 8 | 9 | impl MonolithicBayes { 10 | pub fn new() -> Result> { 11 | Ok(MonolithicBayes{ underlying: HashMap::new() }) 12 | } 13 | } -------------------------------------------------------------------------------- /rust/src/bin/explorer_server.rs: -------------------------------------------------------------------------------- 1 | #![feature(decl_macro)] 2 | #[macro_use] 3 | extern crate rocket; 4 | 5 | use bayes_star::{ 6 | common::{ 7 | resources::ResourceContext, 8 | setup::{parse_configuration_options, CommandLineOptions}, 9 | }, 10 | explorer::routes::{animation_route::internal_animation, experiment_route::internal_experiment, factors_route::internal_factors, index_route::internal_index, marginals_route::internal_marginals, network_route::internal_network, weights_route::internal_weights}, 11 | }; 12 | use rocket::response::content::Html; 13 | use rocket::State; 14 | use rocket_contrib::serve::StaticFiles; 15 | 16 | pub struct WebContext { 17 | namespace: ResourceContext, 18 | } 19 | 20 | impl WebContext { 21 | pub fn new(config: CommandLineOptions) -> Self { 22 | let namespace = 
#[get("/")]
fn home(_context: State<WebContext>) -> Html<String> {
    internal_index()
}

#[get("/experiment/<experiment_name>")]
fn experiment(experiment_name: String, context: State<WebContext>) -> Html<String> {
    internal_experiment(&experiment_name, &context.namespace)
}

#[get("/network/<experiment_name>")]
fn network(experiment_name: String, context: State<WebContext>) -> Html<String> {
    internal_network(&experiment_name, &context.namespace)
}

#[get("/weights/<experiment_name>")]
fn weights(experiment_name: String, context: State<WebContext>) -> Html<String> {
    internal_weights(&experiment_name, &context.namespace)
}

#[get("/marginals/<experiment_name>/<test_scenario>")]
fn marginals(experiment_name: String, test_scenario: String, context: State<WebContext>) -> Html<String> {
    internal_marginals(&experiment_name, &test_scenario, &context.namespace)
}

#[get("/factors/<experiment_name>")]
fn factors(experiment_name: String, context: State<WebContext>) -> Html<String> {
    internal_factors(&experiment_name, &context.namespace)
}

#[get("/animation/<experiment_name>/<test_scenario>")]
fn animation(experiment_name: String, test_scenario: String, context: State<WebContext>) -> Html<String> {
    internal_animation(&experiment_name, &test_scenario, &context.namespace)
}

fn main() {
    let config = parse_configuration_options();
    rocket::ignite()
        .manage(WebContext::new(config))
        .mount("/", routes![home, experiment, network, weights, marginals, factors, animation])
        .mount("/static", StaticFiles::from("static"))
        .launch();
}

--------------------------------------------------------------------------------
/rust/src/bin/list_entities.rs:
--------------------------------------------------------------------------------
use bayes_star::common::{graph::InferenceGraph, resources::ResourceContext, setup::parse_configuration_options};
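// Diagnostic binary: dumps every domain (with its entities), every relation,
// and every implication stored under the scenario namespace given on the
// command line.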
fn main() {
    let config: bayes_star::common::setup::CommandLineOptions = parse_configuration_options();
    let resources = ResourceContext::new(&config).unwrap();
    let mut connection = resources.connection.lock().unwrap();
    let graph = InferenceGraph::new_shared(config.scenario_name.clone()).unwrap();
    //
    // Domains.
    let all_domains = graph.get_all_domains(&mut connection).unwrap();
    println!("all_domains {:?}", &all_domains);
    for domain in &all_domains {
        let elements = graph.get_entities_in_domain(&mut connection, domain).unwrap();
        println!("elements: {:?}", &elements);
    }
    //
    // Relations.
    let all_relations = graph.get_all_relations(&mut connection).unwrap();
    println!("all_relations {:?}", &all_relations);
    for relation in &all_relations {
        println!("relation {:?}", relation);
    }
    //
    // Implications.
    let all_implications = graph.get_all_implications(&mut connection).unwrap();
    println!("all_implications {:?}", &all_implications);
    for implication in &all_implications {
        println!("implication {:?}", implication);
    }

    println!("main finishes");
}

--------------------------------------------------------------------------------
/rust/src/bin/plot.rs:
--------------------------------------------------------------------------------
use bayes_star::common::resources::ResourceContext;
use bayes_star::common::setup::parse_configuration_options;
use bayes_star::inference::rounds::run_inference_rounds;

extern crate log;

fn main() {
    let config = parse_configuration_options();
    let resources = ResourceContext::new(&config).expect("Couldn't create resources.");
    let test_scenario = config.test_scenario.expect("no test_scenario in config");
    let mut connection = resources.connection.lock().unwrap();
    let marginal_tables = run_inference_rounds(&mut connection, &config.scenario_name, &test_scenario)
        .expect("Testing failed.");
    for marginal_table in &marginal_tables {
        println!("table {:?}", marginal_table);
    }
    println!("main finishes");
}

--------------------------------------------------------------------------------
/rust/src/bin/train.rs:
--------------------------------------------------------------------------------
use std::borrow::Borrow;

use bayes_star::common::setup::parse_configuration_options;
use bayes_star::common::{resources::ResourceContext, train::setup_and_train};
use bayes_star::scenarios::factory::ScenarioMakerFactory;

#[macro_use]
extern crate log;

fn main() {
    let config = parse_configuration_options();
    let resources = ResourceContext::new(&config).expect("Couldn't create resources.");
    let scenario_maker = ScenarioMakerFactory::new_shared(&config.scenario_name).unwrap();
    setup_and_train(&resources, scenario_maker.borrow(), &config.scenario_name).expect("Error in training.");
    trace!("program done");
}

--------------------------------------------------------------------------------
/rust/src/common/graph.rs:
--------------------------------------------------------------------------------
use super::{
    interface::{PredictStatistics, TrainStatistics},
    redis::{set_value, RedisManager},
    resources::ResourceContext,
};
use crate::{
    common::{
        interface::BeliefTable,
        redis::{get_value, is_member, set_add, set_members},
    },
    model::{
        self,
        choose::{
            extract_existence_factor_for_predicate, extract_existence_factor_for_proposition,
        },
        exponential::ExponentialModel,
        objects::{
            Domain, Entity, ImplicationFactor, Predicate, PredicateGroup, Proposition,
            PropositionGroup, Relation,
        },
    },
    print_blue,
};
use redis::{Commands, Connection};
use serde::{Deserialize, Serialize};
use std::{
    cell::RefCell,
    error::Error,
    rc::Rc,
    sync::{Arc, Mutex},
};
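/// Stores the first-order structure of a scenario -- domains, entities,
/// relations, targets, and implication factors -- as namespaced keys in Redis.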
pub struct InferenceGraph {
    pub namespace: String,
}

impl InferenceGraph {
    pub fn new_mutable(namespace: String) -> Result<Box<Self>, Box<dyn Error>> {
        Ok(Box::new(InferenceGraph { namespace }))
    }

    pub fn new_shared(namespace: String) -> Result<Arc<Self>, Box<dyn Error>> {
        Ok(Arc::new(InferenceGraph { namespace }))
    }

    pub fn new_literal(
        redis_connection: Arc<Mutex<Connection>>,
        namespace: String,
    ) -> Result<Self, Box<dyn Error>> {
        Ok(InferenceGraph { namespace })
    }

    pub fn register_experiment(
        &mut self,
        connection: &mut Connection,
        experiment_name: &str,
    ) -> Result<(), Box<dyn Error>> {
        set_add(
            connection,
            &self.namespace,
            &Self::experiment_set_name(),
            experiment_name,
        )?;
        Ok(())
    }

    pub fn get_all_experiments(
        &self,
        connection: &mut Connection,
    ) -> Result<Vec<String>, Box<dyn Error>> {
        let set_members: Vec<String> =
            set_members(connection, &self.namespace, &Self::experiment_set_name())?;
        Ok(set_members)
    }

    pub fn register_relation(
        &mut self,
        connection: &mut Connection,
        relation: &Relation,
    ) -> Result<(), Box<dyn Error>> {
        let record = serialize_record(relation)?;
        set_add(
            connection,
            &self.namespace,
            &Self::relation_set_name(),
            &record,
        )?;
        Ok(())
    }

    pub fn check_relation(
        &mut self,
        connection: &mut Connection,
        relation: &Relation,
    ) -> Result<(), Box<dyn Error>> {
        // TODO: implement this
        Ok(())
    }

    pub fn get_all_relations(
        &self,
        connection: &mut Connection,
    ) -> Result<Vec<Relation>, Box<dyn Error>> {
        let set_members: Vec<String> =
            set_members(connection, &self.namespace, &Self::relation_set_name())?;
        set_members
            .into_iter()
            .map(|record| serde_json::from_str(&record).map_err(|e| Box::new(e) as Box<dyn Error>))
            .collect()
    }

    pub fn register_domain(
        &mut self,
        connection: &mut Connection,
        domain: &String,
    ) -> Result<(), Box<dyn Error>> {
        set_add(connection, &self.namespace, "domains", domain)?;
        Ok(())
    }

    pub fn check_domain(
        &self,
        connection: &mut Connection,
        domain: &String,
    ) -> Result<(), Box<dyn Error>> {
        let result = is_member(connection, &self.namespace, "domains", domain)?;
        assert!(result);
        Ok(())
    }

    pub fn get_all_domains(
        &self,
        connection: &mut Connection,
    ) -> Result<Vec<String>, Box<dyn Error>> {
        let result = set_members(connection, &self.namespace, "domains")?;
        Ok(result)
    }

    pub fn register_target(
        &mut self,
        connection: &mut Connection,
        target: &Proposition,
    ) -> Result<(), Box<dyn Error>> {
        let record = serialize_record(target)?;
        set_value(
            connection,
            &self.namespace,
            &Self::target_key_name(),
            &record,
        )?;
        Ok(())
    }

    pub fn get_target(&self, connection: &mut Connection) -> Result<Proposition, Box<dyn Error>> {
        let record = get_value(connection, &self.namespace, &Self::target_key_name())?.unwrap();
        serde_json::from_str(&record).map_err(|e| Box::new(e) as Box<dyn Error>)
    }

    pub fn store_entity(
        &mut self,
        connection: &mut Connection,
        entity: &Entity,
    ) -> Result<(), Box<dyn Error>> {
        trace!(
            "Storing entity in domain '{}': {}",
            entity.domain,
            entity.name
        );
        self.check_domain(connection, &entity.domain)?;
        // NOTE: this is a "set" named after the "domain", with each "entity.name" inside of it.
        set_add(
            connection,
            &self.namespace,
            &entity.domain.to_string(),
            &entity.name,
        )?;
        Ok(())
    }

    pub fn get_entities_in_domain(
        &self,
        connection: &mut Connection,
        domain: &String,
    ) -> Result<Vec<Entity>, Box<dyn Error>> {
        let domain_string = domain.to_string();
        let names: Vec<String> = set_members(connection, &self.namespace, &domain_string)?;
        Ok(names
            .into_iter()
            .map(|name| Entity {
                domain: domain.clone(),
                name,
            })
            .collect())
    }
    fn predicate_backward_set_name(predicate: &Predicate) -> String {
        format!("predicate_backward:{}", predicate.hash_string())
    }

    fn implication_seq_name() -> String {
        "implications".to_string()
    }

    fn relation_set_name() -> String {
        "relations".to_string()
    }

    fn experiment_set_name() -> String {
        "experiments".to_string()
    }

    fn target_key_name() -> String {
        "target".to_string()
    }

    fn store_implication(
        &mut self,
        connection: &mut Connection,
        implication: &ImplicationFactor,
    ) -> Result<(), Box<dyn Error>> {
        let record = serialize_record(implication)?;
        set_add(
            connection,
            &self.namespace,
            &Self::implication_seq_name(),
            &record,
        )?;
        Ok(())
    }

    // TODO: I feel like this should not be public.
    pub fn ensure_existence_backlinks_for_proposition(
        &mut self,
        connection: &mut Connection,
        proposition: &Proposition,
    ) -> Result<(), Box<dyn Error>> {
        let implication = extract_existence_factor_for_proposition(proposition)?;
        self.store_predicate_implication(connection, &implication)?;
        Ok(())
    }

    fn store_predicate_backward_link(
        &mut self,
        connection: &mut Connection,
        inference: &ImplicationFactor,
    ) -> Result<(), Box<dyn Error>> {
        let conclusion = &inference.conclusion;
        let record = serialize_record(inference)?;
        set_add(
            connection,
            &self.namespace,
            &Self::predicate_backward_set_name(conclusion),
            &record,
        )?;
        Ok(())
    }

    pub fn store_predicate_implication(
        &mut self,
        connection: &mut Connection,
        implication: &ImplicationFactor,
    ) -> Result<(), Box<dyn Error>> {
        self.store_implication(connection, implication)?;
        self.store_predicate_backward_link(connection, implication)?;
        Ok(())
    }

    pub fn store_predicate_implications(
        &mut self,
        connection: &mut Connection,
        implications: &Vec<ImplicationFactor>,
    ) -> Result<(), Box<dyn Error>> {
        for implication in implications {
            self.store_predicate_implication(connection, implication)?;
        }
        Ok(())
    }

    pub fn get_all_implications(
        &self,
        connection: &mut Connection,
    ) -> Result<Vec<ImplicationFactor>, Box<dyn Error>> {
        let set_members: Vec<String> =
            set_members(connection, &self.namespace, &Self::implication_seq_name())?;

        set_members
            .into_iter()
            .map(|record| {
                trace!("Deserializing record: {}", record); // Log each record before deserialization
                serde_json::from_str::<ImplicationFactor>(&record).map_err(|e| {
                    trace!("Failed to deserialize record: {}, Error: {}", record, e); // Log if deserialization fails
                    Box::new(e) as Box<dyn Error>
                })
            })
            .collect()
    }

    pub fn predicate_backward_links(
        &self,
        connection: &mut Connection,
        conclusion: &Predicate,
    ) -> Result<Vec<ImplicationFactor>, Box<dyn Error>> {
        let set_members: Vec<String> = set_members(
            connection,
            &self.namespace,
            &Self::predicate_backward_set_name(conclusion),
        )?;
        set_members
            .into_iter()
            .map(|record| serde_json::from_str(&record).map_err(|e| Box::new(e) as Box<dyn Error>))
            .collect()
    }
}

pub fn serialize_record<T>(obj: &T) -> Result<String, Box<dyn Error>>
where
    T: Serialize,
{
    serde_json::to_string(obj).map_err(|e| Box::new(e) as Box<dyn Error>)
}

fn deserialize_record<'a, T>(record: &'a str) -> Result<T, Box<dyn Error>>
where
    T: Deserialize<'a>,
{
    serde_json::from_str(record).map_err(|e| Box::new(e) as Box<dyn Error>)
}

--------------------------------------------------------------------------------
/rust/src/common/interface.rs:
--------------------------------------------------------------------------------
use std::error::Error;

use redis::Connection;

use crate::model::objects::{PredicateGroup, ImplicationFactor, Predicate, Proposition};

use super::{graph::InferenceGraph, model::InferenceModel, train::TrainingPlan, redis::RedisManager, resources::ResourceContext};

pub struct TrainStatistics {
    pub loss: f64,
}

pub struct PredictStatistics {
    pub probability: f64,
}
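/// A table of believed/observed probabilities for ground propositions.
/// Implementations include a Redis-backed table, an always-empty table, and an
/// in-memory map (see `common/proposition_db.rs`).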
pub trait BeliefTable {
    fn get_proposition_probability(
        &self,
        context: &mut Connection,
        proposition: &Proposition,
    ) -> Result<Option<f64>, Box<dyn Error>>;

    fn store_proposition_probability(
        &self,
        context: &mut Connection,
        proposition: &Proposition,
        probability: f64,
    ) -> Result<(), Box<dyn Error>>;

    fn store_proposition_boolean(
        &self,
        context: &mut Connection,
        proposition: &Proposition,
        observation: bool,
    ) -> Result<(), Box<dyn Error>> {
        if observation {
            self.store_proposition_probability(context, proposition, 1.0)?;
        } else {
            self.store_proposition_probability(context, proposition, 0.0)?;
        }
        Ok(())
    }
}

pub trait ScenarioMaker {
    fn setup_scenario(
        &self,
        redis: &ResourceContext,
    ) -> Result<(), Box<dyn Error>>;
}

--------------------------------------------------------------------------------
/rust/src/common/logging.rs:
--------------------------------------------------------------------------------

#[macro_export]
macro_rules! print_red {
    ($($arg:tt)*) => {
        use colored::*;
        println!("{}", format!($($arg)*).red());
    };
}
#[macro_export]
macro_rules! print_green {
    ($($arg:tt)*) => {
        use colored::*;
        println!("{}", format!($($arg)*).green());
    };
}
#[macro_export]
macro_rules! print_yellow {
    ($($arg:tt)*) => {
        use colored::*;
        println!("{}", format!($($arg)*).yellow());
    };
}
#[macro_export]
macro_rules! print_blue {
    ($($arg:tt)*) => {
        use colored::*;
        println!("{}", format!($($arg)*).blue());
    };
}

--------------------------------------------------------------------------------
/rust/src/common/mod.rs:
--------------------------------------------------------------------------------
pub mod redis;
pub mod interface;
pub mod model;
pub mod graph;
pub mod proposition_db;
pub mod train;
pub mod resources;
pub mod setup;
pub mod test;
pub mod logging;

--------------------------------------------------------------------------------
/rust/src/common/model.rs:
--------------------------------------------------------------------------------
use crate::{
    common::interface::BeliefTable,
    inference::graph::PropositionFactor,
    model::{
        self,
        exponential::ExponentialModel,
        objects::{Domain, Entity, ImplicationFactor, Predicate, PredicateGroup, Proposition},
    },
};
use redis::{Commands, Connection};
use std::{cell::RefCell, collections::HashMap, error::Error, rc::Rc, sync::Arc};

use super::{
    graph::InferenceGraph,
    interface::{PredictStatistics, TrainStatistics},
    proposition_db::RedisBeliefTable,
    redis::RedisManager,
    resources::ResourceContext,
};

pub struct InferenceModel {
    pub graph: Arc<InferenceGraph>,
    pub model: Arc<dyn FactorModel>,
}

impl InferenceModel {
    pub fn new_shared(namespace: String) -> Result<Arc<Self>, Box<dyn Error>> {
        let graph = InferenceGraph::new_shared(namespace.clone())?;
        let model = ExponentialModel::new_shared(namespace.clone())?;
        Ok(Arc::new(InferenceModel { graph, model }))
    }
}

#[derive(Debug)]
pub struct FactorContext {
    pub factor: Vec<PropositionFactor>,
    pub probabilities: Vec<f64>,
}
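/// The learned component of the model: given a factor (an implication plus the
/// current probabilities of its premises), either train on an observed outcome
/// or predict the probability of the conclusion. `ExponentialModel` is the
/// implementation constructed by `InferenceModel::new_shared` above.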
pub trait FactorModel {
    fn initialize_connection(
        &mut self,
        connection: &mut Connection,
        implication: &ImplicationFactor,
    ) -> Result<(), Box<dyn Error>>;

    fn train(
        &mut self,
        connection: &mut Connection,
        factor: &FactorContext,
        probability: f64,
    ) -> Result<TrainStatistics, Box<dyn Error>>;

    fn predict(
        &self,
        connection: &mut Connection,
        factor: &FactorContext,
    ) -> Result<PredictStatistics, Box<dyn Error>>;
}

--------------------------------------------------------------------------------
/rust/src/common/proposition_db.rs:
--------------------------------------------------------------------------------
use crate::{
    common::{interface::BeliefTable, redis::map_insert},
    inference::table::PropositionNode,
    model::{
        self,
        exponential::ExponentialModel,
        objects::{
            Domain, Entity, Predicate, ImplicationFactor, PredicateGroup, Proposition,
            existence_predicate_name,
        },
    },
};
use redis::{Commands, Connection};
use std::{cell::RefCell, collections::HashMap, error::Error, io::Empty, rc::Rc, sync::{Arc, Mutex}};

use super::{
    graph::InferenceGraph,
    interface::{PredictStatistics, TrainStatistics},
    redis::{map_get, RedisManager}, resources::ResourceContext,
};

pub struct RedisBeliefTable {
    namespace: String,
}

impl RedisBeliefTable {
    pub fn new_mutable(namespace: String) -> Result<Box<Self>, Box<dyn Error>> {
        Ok(Box::new(RedisBeliefTable { namespace }))
    }
    pub fn new_shared(namespace: String) -> Result<Rc<Self>, Box<dyn Error>> {
        Ok(Rc::new(RedisBeliefTable { namespace }))
    }
    pub const PROBABILITIES_KEY: &'static str = "probabilities";
}

impl BeliefTable for RedisBeliefTable {
    // Return Some if the probability exists in the table, or else None.
    fn get_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
    ) -> Result<Option<f64>, Box<dyn Error>> {
        if proposition.predicate.relation.relation_name == existence_predicate_name() {
            return Ok(Some(1f64));
        }
        let hash_string = proposition.predicate.hash_string();
        let probability_record = map_get(
            connection,
            &self.namespace,
            Self::PROBABILITIES_KEY,
            &hash_string,
        )?
        .expect("should be there");
        let probability = probability_record
            .parse::<f64>()
            .map_err(|e| Box::new(e) as Box<dyn Error>)?;
        Ok(Some(probability))
    }

    fn store_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
        probability: f64,
    ) -> Result<(), Box<dyn Error>> {
        trace!("GraphicalModel::store_proposition_probability - Start. Input proposition: {:?}, probability: {}", proposition, probability);
        let hash_string = proposition.predicate.hash_string();
        map_insert(
            connection,
            &self.namespace,
            Self::PROBABILITIES_KEY,
            &hash_string,
            &probability.to_string(),
        )?;
        Ok(())
    }
}

pub struct EmptyBeliefTable;

impl EmptyBeliefTable {
    pub fn new_shared(_namespace: &str) -> Result<Arc<Self>, Box<dyn Error>> {
        Ok(Arc::new(EmptyBeliefTable {}))
    }
}

impl BeliefTable for EmptyBeliefTable {
    // Return Some if the probability exists in the table, or else None.
    fn get_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
    ) -> Result<Option<f64>, Box<dyn Error>> {
        if proposition.predicate.relation.relation_name == existence_predicate_name() {
            return Ok(Some(1f64));
        }
        Ok(None)
    }

    fn store_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
        probability: f64,
    ) -> Result<(), Box<dyn Error>> {
        panic!("Shouldn't call this.")
    }
}
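// In-memory belief table keyed by PropositionNode; `clear` retracts a single
// stored probability. Unlike the Redis-backed table, nothing here persists.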
pub struct HashMapBeliefTable {
    evidence: RefCell<HashMap<PropositionNode, f64>>,
}

impl HashMapBeliefTable {
    pub fn new() -> Arc<Self> {
        Arc::new(HashMapBeliefTable {
            evidence: RefCell::new(HashMap::new()),
        })
    }

    pub fn clear(&self, node: &PropositionNode) -> () {
        self.evidence.borrow_mut().remove(node);
    }
}

impl BeliefTable for HashMapBeliefTable {
    fn get_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
    ) -> Result<Option<f64>, Box<dyn Error>> {
        if proposition.predicate.relation.relation_name == existence_predicate_name() {
            return Ok(Some(1f64));
        }
        let node = PropositionNode::from_single(proposition);
        let map = self.evidence.borrow();
        let result = map.get(&node);
        Ok(result.copied())
    }

    fn store_proposition_probability(
        &self,
        connection: &mut Connection,
        proposition: &Proposition,
        probability: f64,
    ) -> Result<(), Box<dyn Error>> {
        let node = PropositionNode::from_single(proposition);
        // Use `borrow_mut` to get a mutable reference to the HashMap
        self.evidence.borrow_mut().insert(node, probability);
        Ok(())
    }
}

--------------------------------------------------------------------------------
/rust/src/common/redis.rs:
--------------------------------------------------------------------------------
use redis::Commands;
use redis::Connection;
use std::cell::RefCell;
use std::error::Error;
use std::sync::Arc;
use std::sync::Mutex;

pub struct RedisManager {
    client: redis::Client,
}

impl RedisManager {
    pub fn new() -> Result<Self, Box<dyn Error>> {
        let client =
            redis::Client::open("redis://127.0.0.1/").expect("Could not connect to Redis."); // Replace with your Redis server URL
        let redis_client = RedisManager { client };
        Ok(redis_client)
    }

    pub fn get_connection(&self) -> Result<RefCell<Connection>, Box<dyn Error>> {
        let connection = self
            .client
            .get_connection()
            .expect("Couldn't get connection.");
        let refcell = RefCell::new(connection);
        Ok(refcell)
    }

    pub fn get_mutex_guarded_connection(&self) -> Result<Mutex<Connection>, Box<dyn Error>> {
        let connection = self
            .client
            .get_connection()
            .expect("Couldn't get connection.");
        let refcell = Mutex::new(connection);
        Ok(refcell)
    }

    pub fn get_arc_mutex_guarded_connection(&self) -> Result<Arc<Mutex<Connection>>, Box<dyn Error>> {
        let connection = self
            .client
            .get_connection()
            .expect("Couldn't get connection.");
        let refcell = Arc::new(Mutex::new(connection));
        Ok(refcell)
    }
}
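// Every key is prefixed "bayes-star:{namespace}:{key}" (see below), so several
// scenarios (namespaces) can coexist in one Redis instance.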
--------------------------------------------------------------------------------
/rust/src/common/redis.rs:
--------------------------------------------------------------------------------
1 | use redis::Commands;
2 | use redis::Connection;
3 | use std::cell::RefCell;
4 | use std::error::Error;
5 | use std::sync::Arc;
6 | use std::sync::Mutex;
7 |
8 | pub struct RedisManager {
9 |     client: redis::Client,
10 | }
11 |
12 | impl RedisManager {
13 |     pub fn new() -> Result<RedisManager, Box<dyn Error>> {
14 |         let client =
15 |             redis::Client::open("redis://127.0.0.1/").expect("Could not connect to Redis."); // Replace with your Redis server URL
16 |         let redis_client = RedisManager { client };
17 |         Ok(redis_client)
18 |     }
19 |
20 |     pub fn get_connection(&self) -> Result<RefCell<Connection>, Box<dyn Error>> {
21 |         let connection = self
22 |             .client
23 |             .get_connection()
24 |             .expect("Couldn't get connection.");
25 |         let refcell = RefCell::new(connection);
26 |         Ok(refcell)
27 |     }
28 |
29 |     pub fn get_mutex_guarded_connection(&self) -> Result<Mutex<Connection>, Box<dyn Error>> {
30 |         let connection = self
31 |             .client
32 |             .get_connection()
33 |             .expect("Couldn't get connection.");
34 |         let mutex = Mutex::new(connection);
35 |         Ok(mutex)
36 |     }
37 |
38 |     pub fn get_arc_mutex_guarded_connection(&self) -> Result<Arc<Mutex<Connection>>, Box<dyn Error>> {
39 |         let connection = self
40 |             .client
41 |             .get_connection()
42 |             .expect("Couldn't get connection.");
43 |         let arc = Arc::new(Mutex::new(connection));
44 |         Ok(arc)
45 |     }
46 | }
47 |
48 | fn namespace_qualified_key(namespace: &str, key: &str) -> String {
49 |     format!("bayes-star:{namespace}:{key}")
50 | }
51 |
52 | pub fn set_value(
53 |     conn: &mut Connection,
54 |     namespace: &str,
55 |     key: &str,
56 |     value: &str,
57 | ) -> Result<(), Box<dyn Error>> {
58 |     let nskey = &namespace_qualified_key(namespace, key);
59 |     conn.set(nskey, value)?;
60 |     Ok(())
61 | }
62 |
63 | pub fn get_value(
64 |     conn: &mut Connection,
65 |     namespace: &str,
66 |     key: &str,
67 | ) -> Result<Option<String>, Box<dyn Error>> {
68 |     let nskey = &namespace_qualified_key(namespace, key);
69 |     let value: Option<String> = conn.get(nskey)?;
70 |     trace!("nskey: {nskey}, value: {:?}", &value);
71 |     Ok(value)
72 | }
73 |
74 | pub fn map_insert(
75 |     conn: &mut Connection,
76 |     namespace: &str,
77 |     key: &str,
78 |     field: &str,
79 |     value: &str,
80 | ) -> Result<(), Box<dyn Error>> {
81 |     let nskey = &namespace_qualified_key(namespace, key);
82 |     conn.hset(nskey, field, value)?;
83 |     Ok(())
84 | }
85 |
86 | pub fn map_get(
87 |     conn: &mut Connection,
88 |     namespace: &str,
89 |     key: &str,
90 |     field: &str,
91 | ) -> Result<Option<String>, Box<dyn Error>> {
92 |     let nskey = &namespace_qualified_key(namespace, key);
93 |     let value: Option<String> = conn.hget(nskey, field)?;
94 |     Ok(value)
95 | }
96 |
97 | pub fn set_add(conn: &mut Connection, namespace: &str, key: &str, member: &str) -> Result<bool, Box<dyn Error>> {
98 |     let nskey = &namespace_qualified_key(namespace, key);
99 |     let added: bool = conn.sadd(nskey, member)?;
100 |     Ok(added)
101 | }
102 |
103 | pub fn set_members(conn: &mut Connection, namespace: &str, key: &str) -> Result<Vec<String>, Box<dyn Error>> {
104 |     let nskey = &namespace_qualified_key(namespace, key);
105 |     let members: Vec<String> = conn.smembers(nskey)?;
106 |     Ok(members)
107 | }
108 |
109 | pub fn is_member(conn: &mut Connection, namespace: &str, key: &str, member: &str) -> Result<bool, Box<dyn Error>> {
110 |     let nskey = &namespace_qualified_key(namespace, key);
111 |     let is_member: bool = conn.sismember(nskey, member)?;
112 |     Ok(is_member)
113 | }
114 |
115 | pub fn seq_push(conn: &mut Connection, namespace: &str, key: &str, value: &str) -> Result<i64, Box<dyn Error>> {
116 |     let nskey = &namespace_qualified_key(namespace, key);
117 |     let length: i64 = conn.rpush(nskey, value)?;
118 |     Ok(length)
119 | }
120 |
121 | // pub fn seq_pop(conn: &mut Connection, key: &str) -> Result<Option<String>, Box<dyn Error>> {
122 | //     let value: Option<String> = conn.lpop(key, None)?;
123 | //     Ok(value)
124 | // }
125 |
126 | pub fn seq_get_all(conn: &mut Connection, namespace: &str, key: &str) -> Result<Vec<String>, Box<dyn Error>> {
127 |     let nskey = &namespace_qualified_key(namespace, key);
128 |     let elements: Vec<String> = conn.lrange(nskey, 0, -1)?;
129 |     Ok(elements)
130 | }
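// Illustrative sketch (hypothetical function, not called anywhere): every
// helper above routes through namespace_qualified_key, so the same key in two
// namespaces maps to two distinct Redis keys ("bayes-star:a:status" vs.
// "bayes-star:b:status"). Assumes a Redis server running on 127.0.0.1.
#[allow(dead_code)]
fn demo_namespacing() -> Result<(), Box<dyn Error>> {
    let manager = RedisManager::new()?;
    let cell = manager.get_connection()?;
    let mut connection = cell.borrow_mut();
    set_value(&mut connection, "a", "status", "trained")?;
    set_value(&mut connection, "b", "status", "untrained")?;
    // The namespace prefix keeps the two writes from colliding.
    assert_eq!(
        get_value(&mut connection, "a", "status")?.as_deref(),
        Some("trained")
    );
    Ok(())
}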
--------------------------------------------------------------------------------
/rust/src/common/resources.rs:
--------------------------------------------------------------------------------
1 | use std::{error::Error, sync::{Arc, Mutex}};
2 | use super::{redis::RedisManager, setup::CommandLineOptions};
3 |
4 | pub struct ResourceContext {
5 |     pub connection: Arc<Mutex<redis::Connection>>,
6 | }
7 |
8 | impl ResourceContext {
9 |     pub fn new(options: &CommandLineOptions) -> Result<ResourceContext, Box<dyn Error>> {
10 |         let manager = RedisManager::new()?;
11 |         let connection = manager.get_arc_mutex_guarded_connection()?;
12 |         Ok(ResourceContext {
13 |             connection,
14 |         })
15 |     }
16 | }
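// Illustrative sketch (hypothetical function, with a made-up "demo" namespace):
// the shared connection is meant to be locked once per unit of work, as the
// route handlers and do_training below do, not held for the life of the context.
#[allow(dead_code)]
fn demo_locking(options: &CommandLineOptions) -> Result<(), Box<dyn Error>> {
    let resources = ResourceContext::new(options)?;
    {
        // The guard derefs to a redis::Connection for the helpers in redis.rs.
        let mut connection = resources.connection.lock().unwrap();
        super::redis::set_value(&mut connection, "demo", "key", "value")?;
    } // Guard dropped here; other threads may now lock the connection.
    Ok(())
}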
--------------------------------------------------------------------------------
/rust/src/common/setup.rs:
--------------------------------------------------------------------------------
1 | use crate::common::resources::ResourceContext;
2 | use clap::{App, Arg};
3 | use env_logger::{Builder, Env};
4 | use serde::Deserialize;
5 | use std::{io::Write, path::Path};
6 |
7 | /// These options define the inputs from the user.
8 | /// Everything is held as plain owned data types, so this struct can be passed around freely.
9 | #[derive(Deserialize, Clone, Debug)]
10 | pub struct CommandLineOptions {
11 |     pub scenario_name: String,
12 |     pub test_scenario: Option<String>,
13 |     pub entities_per_domain: i32,
14 |     pub print_training_loss: bool,
15 |     pub test_example: Option<u32>,
16 |     pub marginal_output_file: Option<String>,
17 | }
18 |
19 | fn check_file_does_not_exist(file_name: &str) {
20 |     if Path::new(file_name).exists() {
21 |         panic!("File '{}' already exists!", file_name);
22 |     }
23 | }
24 |
25 | pub fn parse_configuration_options() -> CommandLineOptions {
26 |     Builder::from_env(Env::default().default_filter_or("info"))
27 |         .format(|buf, record| {
28 |             let file = record.file().unwrap_or("unknown");
29 |             let line = record.line().unwrap_or(0);
30 |             writeln!(
31 |                 buf,
32 |                 "{} [{}:{}] {}",
33 |                 record.level(),
34 |                 file,
35 |                 line,
36 |                 record.args()
37 |             )
38 |         })
39 |         .init();
40 |     let matches = App::new("BAYES STAR")
41 |         .version("1.0")
42 |         .author("Greg Coppola")
43 |         .about("Efficient combination of First-Order Logic and Bayesian Networks.")
44 |         .arg(
45 |             Arg::with_name("entities_per_domain")
46 |                 .long("entities_per_domain")
47 |                 .value_name("NUMBER")
48 |                 .help("Sets the number of entities per domain")
49 |                 .takes_value(true)
50 |                 .default_value("1024"),
51 |         )
52 |         .arg(
53 |             Arg::with_name("print_training_loss")
54 |                 .long("print_training_loss")
55 |                 .help("Enables printing of training loss")
56 |                 .takes_value(false), // No value is expected, presence of flag sets it to true
57 |         )
58 |         .arg(
59 |             Arg::with_name("test_example")
60 |                 .long("test_example")
61 |                 .value_name("NUMBER")
62 |                 .help("Sets the test example number (optional)")
63 |                 .takes_value(true), // This argument is optional and takes a value
64 |         )
65 |         .arg(
66 |             Arg::with_name("scenario_name")
67 |                 .long("scenario_name")
68 |                 .value_name("STRING")
69 |                 .help("Sets the scenario name")
70 |                 .takes_value(true)
71 |                 .required(true), // Mark this argument as required
72 |         )
73 |         .arg(
74 |             Arg::with_name("test_scenario")
75 |                 .long("test_scenario")
76 |                 .value_name("STRING")
77 |                 .help("Test Scenario name")
78 |                 .takes_value(true)
79 |                 .required(false), // This argument is optional
80 |         )
81 |         .arg(
82 |             Arg::with_name("marginal_output_file")
83 |                 .long("marginal_output_file")
84 |                 .value_name("FILE")
85 |                 .help("Sets the file name for marginal output (optional)")
86 |                 .takes_value(true), // This argument is optional and takes a string value
87 |         )
88 |         .get_matches();
89 |     let entities_per_domain: i32 = matches
90 |         .value_of("entities_per_domain")
91 |         .unwrap() // safe because we have a default value
92 |         .parse()
93 |         .expect("entities_per_domain needs to be an integer");
94 |     let print_training_loss = matches.is_present("print_training_loss");
95 |     let test_example: Option<u32> = matches.value_of("test_example").map(|v| {
96 |         v.parse()
97 |             .expect("test_example needs to be a positive integer or omitted")
98 |     });
99 |     let marginal_output_file = matches.value_of("marginal_output_file").map(String::from);
100 |     let scenario_name: String = matches
101 |         .value_of("scenario_name")
102 |         .expect("scenario_name is required") // As it's required, unwrap directly
103 |         .to_string();
104 |     let test_scenario = matches.value_of("test_scenario").map(String::from);
105 |
106 |     CommandLineOptions {
107 |         scenario_name,
108 |         test_scenario,
109 |         entities_per_domain,
110 |         print_training_loss,
111 |         test_example,
112 |         marginal_output_file,
113 |     }
114 | }
--------------------------------------------------------------------------------
/rust/src/common/test.rs:
--------------------------------------------------------------------------------
1 | use std::{collections::HashMap, error::Error, io, rc::Rc, sync::Arc};
2 |
3 | use colored::Colorize;
4 | use redis::Connection;
5 |
6 | use crate::{
7 |     common::{
8 |         graph::InferenceGraph,
9 |         model::InferenceModel,
10 |         proposition_db::{EmptyBeliefTable, HashMapBeliefTable, RedisBeliefTable},
11 |         train::TrainingPlan,
12 |     },
13 |     inference::{
14 |         graph::PropositionGraph,
15 |         inference::Inferencer,
16 |         table::{self, PropositionNode},
17 |     },
18 |     model::{exponential::ExponentialModel, objects::Proposition},
19 |     print_blue, print_green, print_red, print_yellow,
20 | };
21 |
22 | use super::{interface::BeliefTable, resources::ResourceContext, setup::CommandLineOptions};
23 |
24 | pub struct ReplState {
25 |     pub inferencer: Box<Inferencer>,
26 |     pub fact_memory: Arc<HashMapBeliefTable>,
27 |     /// Set by `print_ordering` the last time it serialized an ordering.
28 | pub question_index: HashMap, 29 | pub proposition_index: HashMap, 30 | } 31 | 32 | impl ReplState { 33 | pub fn new(mut inferencer: Box) -> ReplState { 34 | let fact_memory = HashMapBeliefTable::new(); 35 | inferencer.fact_memory = fact_memory.clone(); 36 | let proposition_index = make_proposition_map(&inferencer.proposition_graph); 37 | ReplState { 38 | inferencer, 39 | fact_memory, 40 | question_index: HashMap::new(), 41 | proposition_index, 42 | } 43 | } 44 | 45 | pub fn set_pairs_by_name( 46 | &mut self, 47 | connection: &mut Connection, 48 | pairs: &Vec<(&str, f64)>, 49 | ) -> Option { 50 | assert!(pairs.len() <= 1); 51 | for pair in pairs { 52 | let key = pair.0.to_string(); 53 | trace!("key {key}"); 54 | let node = self.proposition_index.get(&key).unwrap(); 55 | let prop = node.extract_single(); 56 | trace!("setting {} to {}", &key, pair.1); 57 | self.fact_memory 58 | .store_proposition_probability(connection, &prop, pair.1) 59 | .unwrap(); 60 | self.inferencer 61 | .do_fan_out_from_node(connection, &node) 62 | .unwrap(); 63 | return Some(node.clone()); 64 | } 65 | None 66 | } 67 | } 68 | 69 | fn make_proposition_map(graph: &PropositionGraph) -> HashMap { 70 | let bfs = graph.get_bfs_order(); 71 | let mut result = HashMap::new(); 72 | for (index, node) in bfs.iter().enumerate() { 73 | let name = node.debug_string(); 74 | trace!("name_key: {}", &name); 75 | result.insert(name, node.clone()); 76 | } 77 | result 78 | } 79 | -------------------------------------------------------------------------------- /rust/src/common/train.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | common::{ 3 | interface::BeliefTable, 4 | redis::{seq_get_all, seq_push}, 5 | }, 6 | model::{ 7 | self, 8 | exponential::ExponentialModel, 9 | objects::{ 10 | Domain, Entity, ImplicationFactor, Predicate, PredicateGroup, Proposition, 11 | PropositionGroup, 12 | }, 13 | }, 14 | print_yellow, 15 | }; 16 | use redis::{Commands, Connection}; 17 | use serde::Deserialize; 18 | use std::{ 19 | cell::RefCell, 20 | error::Error, 21 | sync::{Arc, Mutex}, 22 | }; 23 | 24 | use super::graph::InferenceGraph; 25 | use super::interface::ScenarioMaker; 26 | use super::model::FactorModel; 27 | use super::resources::ResourceContext; 28 | use super::{ 29 | interface::{PredictStatistics, TrainStatistics}, 30 | model::FactorContext, 31 | redis::RedisManager, 32 | }; 33 | use crate::common::model::InferenceModel; 34 | use crate::common::proposition_db::RedisBeliefTable; 35 | use crate::model::choose::extract_backimplications_from_proposition; 36 | use std::borrow::BorrowMut; 37 | 38 | pub struct TrainingPlan { 39 | namespace: String, 40 | } 41 | 42 | impl TrainingPlan { 43 | pub fn new(namespace: String) -> Result> { 44 | Ok(TrainingPlan { namespace }) 45 | } 46 | 47 | pub fn add_proposition_to_queue( 48 | &mut self, 49 | connection: &mut Connection, 50 | queue_name: &String, 51 | proposition: &Proposition, 52 | ) -> Result<(), Box> { 53 | trace!( 54 | "GraphicalModel::add_to_training_queue - Start. 
Input proposition: {:?}", 55 | proposition 56 | ); 57 | let serialized_proposition = match serde_json::to_string(proposition) { 58 | Ok(record) => record, 59 | Err(e) => { 60 | trace!( 61 | "GraphicalModel::add_to_training_queue - Error serializing proposition: {}", 62 | e 63 | ); 64 | return Err(Box::new(e)); 65 | } 66 | }; 67 | trace!( 68 | "GraphicalModel::add_to_training_queue - Serialized proposition: {}", 69 | &serialized_proposition 70 | ); 71 | seq_push( 72 | connection, 73 | &self.namespace, 74 | &queue_name, 75 | &serialized_proposition, 76 | )?; 77 | trace!("GraphicalModel::add_to_training_queue - Proposition added to training queue successfully"); 78 | Ok(()) 79 | } 80 | 81 | pub fn maybe_add_to_training( 82 | &mut self, 83 | connection: &mut Connection, 84 | is_training: bool, 85 | proposition: &Proposition, 86 | ) -> Result<(), Box> { 87 | if is_training { 88 | self.add_proposition_to_queue(connection, &"training_queue".to_string(), &proposition) 89 | } else { 90 | Ok(()) 91 | } 92 | } 93 | 94 | pub fn maybe_add_to_test( 95 | &mut self, 96 | connection: &mut Connection, 97 | is_test: bool, 98 | proposition: &Proposition, 99 | ) -> Result<(), Box> { 100 | if is_test { 101 | self.add_proposition_to_queue(connection, &"test_queue".to_string(), &proposition) 102 | } else { 103 | Ok(()) 104 | } 105 | } 106 | 107 | fn get_propositions_from_queue( 108 | &self, 109 | connection: &mut Connection, 110 | seq_name: &String, 111 | ) -> Result, Box> { 112 | trace!( 113 | "GraphicalModel::get_propositions_from_queue - Start. Queue name: {}", 114 | seq_name 115 | ); 116 | let records = seq_get_all(connection, &self.namespace, &seq_name)?; 117 | let mut result = vec![]; 118 | for record in &records { 119 | let proposition = deserialize_record(record)?; 120 | result.push(proposition); 121 | } 122 | trace!("GraphicalModel::get_propositions_from_queue - Retrieved and deserialized propositions successfully"); 123 | Ok(result) 124 | } 125 | 126 | pub fn get_training_questions( 127 | &self, 128 | connection: &mut Connection, 129 | ) -> Result, Box> { 130 | let training_queue_name = String::from("training_queue"); 131 | self.get_propositions_from_queue(connection, &training_queue_name) 132 | } 133 | 134 | pub fn get_test_questions( 135 | &self, 136 | connection: &mut Connection, 137 | ) -> Result, Box> { 138 | let test_queue_name = String::from("test_queue"); 139 | self.get_propositions_from_queue(connection, &test_queue_name) 140 | } 141 | } 142 | 143 | fn deserialize_record<'a, T>(record: &'a str) -> Result> 144 | where 145 | T: Deserialize<'a>, 146 | { 147 | serde_json::from_str(record).map_err(|e| Box::new(e) as Box) 148 | } 149 | 150 | // Probabilities are either 0 or 1, so assume independent, i.e., just boolean combine them as AND. 151 | fn extract_group_probability_for_training( 152 | connection: &mut Connection, 153 | proposition_db: &Box, 154 | premise: &PropositionGroup, 155 | ) -> Result> { 156 | let mut product = 1f64; 157 | for term in &premise.terms { 158 | let part = proposition_db 159 | .get_proposition_probability(connection, term)? 
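            // NOTE: the `?` above only propagates database errors; if a premise
            // term has no stored probability at all, the `unwrap` on the next
            // line panics. Training therefore assumes a fully-observed training
            // set, consistent with the comment above this function (premise
            // probabilities are 0 or 1 during training).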
160 |             .unwrap();
161 |         product *= part;
162 |     }
163 |     Ok(product)
164 | }
165 |
166 | fn extract_factor_for_proposition_for_training(
167 |     connection: &mut Connection,
168 |     proposition_db: &Box<dyn BeliefTable>,
169 |     graph: &InferenceGraph,
170 |     conclusion: Proposition,
171 | ) -> Result<FactorContext, Box<dyn Error>> {
172 |     let factors = extract_backimplications_from_proposition(connection, graph, &conclusion)?;
173 |     let mut probabilities = vec![];
174 |     for factor in &factors {
175 |         let probability =
176 |             extract_group_probability_for_training(connection, proposition_db, &factor.premise)?;
177 |         probabilities.push(probability);
178 |     }
179 |     let result = FactorContext {
180 |         factor: factors,
181 |         probabilities,
182 |     };
183 |     Ok(result)
184 | }
185 |
186 | pub fn do_training(resources: &ResourceContext, namespace: String) -> Result<(), Box<dyn Error>> {
187 |     let mut connection = resources.connection.lock().unwrap();
188 |     let graph = InferenceGraph::new_mutable(namespace.clone())?;
189 |     let proposition_db = RedisBeliefTable::new_mutable(namespace.clone())?;
190 |     let plan = TrainingPlan::new(namespace.clone())?;
191 |     let mut factor_model = ExponentialModel::new_mutable(namespace.clone())?;
192 |     trace!("do_training - Getting all implications");
193 |     let implications = graph.get_all_implications(&mut connection)?;
194 |     for implication in implications {
195 |         print_yellow!("do_training - Processing implication: {:?}", implication);
196 |         factor_model.initialize_connection(&mut connection, &implication)?;
197 |     }
198 |     trace!("do_training - Getting all propositions");
199 |     let training_questions = plan.get_training_questions(&mut connection)?;
200 |     trace!(
201 |         "do_training - Processing propositions: {}",
202 |         training_questions.len()
203 |     );
204 |     let mut examples_processed = 0;
205 |     for proposition in &training_questions {
206 |         trace!("do_training - Processing proposition: {:?}", proposition);
207 |         let factor = extract_factor_for_proposition_for_training(
208 |             &mut connection,
209 |             &proposition_db,
210 |             &graph,
211 |             proposition.clone(),
212 |         )?;
213 |         trace!("do_training - Backimplications: {:?}", &factor);
214 |         let probability_opt =
215 |             proposition_db.get_proposition_probability(&mut connection, proposition)?;
216 |         let probability = probability_opt.expect("Probability should exist.");
217 |         let _stats = factor_model.train(&mut connection, &factor, probability)?;
218 |         examples_processed += 1;
219 |     }
220 |     trace!(
221 |         "do_training - Training complete: examples processed {}",
222 |         examples_processed
223 |     );
224 |     Ok(())
225 | }
226 |
227 | pub fn setup_and_train(
228 |     resources: &ResourceContext,
229 |     scenario_maker: &dyn ScenarioMaker,
230 |     namespace: &str,
231 | ) -> Result<(), Box<dyn Error>> {
232 |     let model_spec = "dummy_model_spec".to_string();
233 |     let result = scenario_maker.setup_scenario(resources);
234 |     trace!("scenario result: {:?}", result);
235 |     let train_result = do_training(resources, namespace.to_string());
236 |     trace!("train result: {:?}", train_result);
237 |     Ok(())
238 | }
239 |
--------------------------------------------------------------------------------
/rust/src/explorer/assets/animation.css:
--------------------------------------------------------------------------------
1 | .card-container {
2 |     position: relative;
3 |     width: 100%;
4 |     text-align: center;
5 |     overflow: hidden;
6 | }
7 |
8 | .animation-card {
9 |     display: none;
10 |     /* Hide all cards initially */
11 |     width: 100%;
12 |     background-color: #fff;
13 |     border: 1px solid #ddd;
14 |     display: flex;
15 |     justify-content: center;
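    /* NOTE: the later "display: flex" (line 14) wins over "display: none"
       (line 9), so this rule alone does not hide the cards; given the arrows
       and index display styled below, the navigation script in slides.html
       presumably toggles each card's display as the user pages through. */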
16 | align-items: center; 17 | font-size: 24px; 18 | color: #333; 19 | transition: transform 0.3s ease; 20 | } 21 | 22 | .arrow { 23 | position: absolute; 24 | top: 50%; 25 | transform: translateY(-50%); 26 | font-size: 24px; 27 | background-color: rgba(0, 0, 0, 0.5); 28 | color: white; 29 | border: none; 30 | padding: 10px; 31 | cursor: pointer; 32 | } 33 | 34 | #left-arrow { 35 | left: 10px; 36 | } 37 | 38 | #right-arrow { 39 | right: 10px; 40 | } 41 | 42 | #index-display { 43 | margin: 10px; 44 | font-size: 128px; 45 | background-color: black; 46 | color: white; 47 | border: 12px red solid; 48 | } -------------------------------------------------------------------------------- /rust/src/explorer/assets/app.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Bayes Star 7 | 10 | 11 | 12 |
13 |

BAYES STAR explorer

14 |
15 |
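<!-- Note: {body_html} below is a literal placeholder, not live markup.
     render_app_body() in render_utils.rs loads this template and substitutes
     each page's rendered HTML for {body_html}; the collated stylesheet text
     replaces the "/* css here */" marker in the style block of this file. -->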
16 | {body_html} 17 |
18 | 19 | 20 | -------------------------------------------------------------------------------- /rust/src/explorer/assets/slides.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | Animation 8 | 11 | 12 | 13 | 14 |
Card 1 of 10
15 |
16 | {body_html} 17 | 18 | 19 | 20 |
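<!-- The inline script of this template (lines 22-55) lost its contents in this
     dump. A minimal sketch of the card navigation it presumably implements,
     based only on the ids used above and the classes in animation.css; this is
     a reconstruction, not the original code:

     const cards = document.querySelectorAll('.animation-card');
     let index = 0;
     function show(i) {
         cards.forEach((card, j) => {
             card.style.display = j === i ? 'flex' : 'none';
         });
         document.getElementById('index-display').textContent =
             'Card ' + (i + 1) + ' of ' + cards.length;
     }
     document.getElementById('left-arrow').onclick =
         () => show(index = (index + cards.length - 1) % cards.length);
     document.getElementById('right-arrow').onclick =
         () => show(index = (index + 1) % cards.length);
     show(0);
-->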
21 | 22 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /rust/src/explorer/assets/style.css: -------------------------------------------------------------------------------- 1 | body, html { 2 | margin: 0; 3 | padding: 0; 4 | height: 100%; 5 | } 6 | 7 | header { 8 | background-color: black; 9 | width: 100%; 10 | padding: 20px 0; 11 | text-align: center; 12 | } 13 | 14 | #product-name { 15 | color: white; 16 | font-size: 48px; /* You can adjust the size as needed */ 17 | margin: 0; 18 | } 19 | 20 | #sub-title { 21 | color: yellow; 22 | font-size: 24px; /* You can adjust the size as needed */ 23 | } 24 | 25 | main { 26 | background-color: white; 27 | padding: 20px; 28 | height: calc(100% - 60px); /* Adjust the height based on the header padding */ 29 | } 30 | 31 | 32 | .section_header { 33 | background-color: red; 34 | width: 100%; 35 | padding: 20px 10px; 36 | text-align: left; 37 | color: white; 38 | font-size: 24px; 39 | } 40 | 41 | .experiment_name { 42 | font-size: 54px; 43 | } 44 | 45 | .row_element { 46 | font-size: 24px; 47 | } 48 | 49 | .domain_icon { 50 | width: 200px; 51 | } 52 | 53 | .domain_label { 54 | font-size: 48px; 55 | } 56 | 57 | .relation_name { 58 | font-size: 54px; 59 | background-color: blue; 60 | color: white; 61 | } 62 | 63 | .role_name { 64 | font-size: 48px; 65 | background-color: yellow; 66 | color: black; 67 | } 68 | 69 | .and_separator { 70 | background-color: black; 71 | color: white; 72 | font-size: 80px; 73 | } 74 | 75 | .implication_box { 76 | border: 10px solid black; 77 | display: flex; 78 | flex-direction: row; 79 | } 80 | 81 | .implication_divider { 82 | background-color: red; 83 | font-size: 80px; 84 | color: white; 85 | } 86 | 87 | 88 | .network_cell { 89 | background-color: green; 90 | text-align: center; 91 | } 92 | 93 | .network_column { 94 | background-color: orange; 95 | } 96 | 97 | .network_row { 98 | background-color: plum; 99 | display: flex; 100 | justify-content: center; 101 | align-items: center; 102 | flex-direction: row; 103 | border: 5px solid red; 104 | } 105 | 106 | .proof_box { 107 | display: flex; 108 | justify-content: center; 109 | align-items: center; 110 | flex-direction: column; 111 | } 112 | 113 | .weight_box { 114 | background-color: yellow; 115 | } 116 | 117 | .weight_box_row { 118 | display: flex; 119 | justify-content: center; 120 | align-items: center; 121 | flex-direction: row; 122 | border: 5px solid paleturquoise; 123 | } 124 | 125 | .weight_box_cell { 126 | margin: 20px; 127 | padding: 20px; 128 | } 129 | 130 | .positive_weight { 131 | background-color: darkgreen; 132 | color: white; 133 | } 134 | 135 | .negative_weight { 136 | background-color: darkred; 137 | color: white; 138 | } 139 | 140 | .neutral_weight { 141 | background-color: black; 142 | color: white; 143 | } 144 | 145 | .factor_box { 146 | background-color: teal; 147 | color: white; 148 | font-weight: bold; 149 | border: 10px orange dotted; 150 | } 151 | 152 | .factor_parent_box { 153 | display: flex; 154 | justify-content: center; 155 | align-items: center; 156 | flex-direction: row; 157 | border: 5px solid paleturquoise; 158 | } 159 | 160 | .marginal { 161 | background-color: orangered; 162 | color: white; 163 | font-size: 32px; 164 | } -------------------------------------------------------------------------------- /rust/src/explorer/diagram_utils.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | inference::{graph::PropositionFactor, 
inference::MarginalTable}, 3 | model::objects::{ 4 | Argument, ImplicationFactor, Predicate, PredicateGroup, Proposition, PropositionGroup, 5 | Relation, 6 | }, 7 | }; 8 | 9 | fn diagram_domain(domain: &str) -> String { 10 | format!( 11 | r#" 12 | 13 | {domain} 14 | 15 | 16 | "# 17 | ) 18 | } 19 | 20 | fn diagram_argument(arg: &Argument) -> String { 21 | match arg { 22 | Argument::Constant(const_arg) => { 23 | format!( 24 | "
Constant Argument:
Domain: {}
Entity ID: {}
", 25 | const_arg.domain, const_arg.entity_id 26 | ) 27 | } 28 | Argument::Variable(var_arg) => diagram_domain(&var_arg.domain), 29 | } 30 | } 31 | 32 | fn diagram_relation(relation: &Relation) -> String { 33 | let mut argument_part = "".to_string(); 34 | for argument in &relation.types { 35 | argument_part += &format!( 36 | "{domain}", 37 | domain = &argument.domain 38 | ); 39 | } 40 | format!( 41 | r#" 42 | 43 | 44 | {relation_name} 45 | 46 | {argument_part} 47 | 48 | "#, 49 | relation_name = &relation.relation_name 50 | ) 51 | } 52 | 53 | pub fn diagram_proposition( 54 | proposition: &Proposition, 55 | marginal_table: Option<&MarginalTable>, 56 | ) -> String { 57 | let score_part = match marginal_table { 58 | Some(table) => { 59 | let marginal = table.get_marginal(proposition).unwrap(); 60 | let color = if marginal < 0.5 { 61 | format!( 62 | "rgb({}, {}, 0)", 63 | (255.0 * (1.0 - marginal * 2.0)) as u8, 64 | (255.0 * marginal * 2.0) as u8 65 | ) 66 | } else { 67 | format!( 68 | "rgb(0, {}, {})", 69 | (255.0 * (marginal - 0.5) * 2.0) as u8, 70 | (255.0 * (1.0 - (marginal - 0.5) * 2.0)) as u8 71 | ) 72 | }; 73 | format!( 74 | "{}", 75 | color, marginal 76 | ) 77 | } 78 | None => "".to_string(), 79 | }; 80 | format!( 81 | r#" 82 | 83 | 84 | {predicate_part} 85 | 86 | {score_part} 87 | 88 | "#, 89 | predicate_part = &diagram_predicate(&proposition.predicate), 90 | ) 91 | } 92 | 93 | pub fn diagram_predicate(predicate: &Predicate) -> String { 94 | let mut argument_buffer = "".to_string(); 95 | for argument in &predicate.roles { 96 | let argument_part = diagram_argument(&argument.argument); 97 | argument_buffer += &format!( 98 | "{role_name}{argument_part}", 99 | role_name = &argument.role_name 100 | ); 101 | } 102 | format!( 103 | r#" 104 | 105 | 106 | {relation_name} 107 | 108 | {argument_buffer} 109 | 110 | "#, 111 | relation_name = &predicate.relation.relation_name 112 | ) 113 | } 114 | 115 | fn diagram_predicate_group(group: &PredicateGroup) -> String { 116 | let mut parts = vec![]; 117 | for predicate in &group.terms { 118 | parts.push(diagram_predicate(predicate)); 119 | } 120 | let separator = ""; // Customize as needed 121 | let joined_parts = parts.join(separator); 122 | format!("
{}
", joined_parts) 123 | } 124 | 125 | pub fn diagram_implication(relation: &ImplicationFactor) -> String { 126 | format!( 127 | r#" 128 |
129 |
130 | {predicate_group_part} 131 |
132 |
133 | ==> 134 |
135 |
136 | {conclusion_part} 137 |
138 |
139 | "#, 140 | predicate_group_part = diagram_predicate_group(&relation.premise), 141 | conclusion_part = diagram_predicate(&relation.conclusion), 142 | ) 143 | } 144 | 145 | pub fn diagram_proposition_factor( 146 | relation: &PropositionFactor, 147 | marginal_table: Option<&MarginalTable>, 148 | ) -> String { 149 | format!( 150 | r#" 151 |
152 |
153 | {predicate_group_part} 154 |
155 |
156 | ==> 157 |
158 |
159 | {conclusion_part} 160 |
161 |
162 | "#, 163 | predicate_group_part = diagram_proposition_group(&relation.premise), 164 | conclusion_part = diagram_proposition(&relation.conclusion, marginal_table), 165 | ) 166 | } 167 | 168 | pub fn diagram_proposition_group(group: &PropositionGroup) -> String { 169 | let mut parts = vec![]; 170 | for predicate in &group.terms { 171 | parts.push(diagram_predicate(&predicate.predicate)); 172 | } 173 | let separator = ""; // Customize as needed 174 | let joined_parts = parts.join(separator); 175 | format!("
{}
", joined_parts) 176 | } 177 | 178 | // pub fn diagram_proposition_group(proposition_group: &PropositionGroup) -> String { 179 | // let parts: Vec = proposition_group 180 | // .terms 181 | // .iter() 182 | // .map(|f| "".to_string()) 183 | // .collect(); 184 | // format!(r#" 185 | //
186 | // {proposition_group_part} 187 | //
188 | // "#, 189 | // proposition_group_part = parts.join(""), 190 | // ) 191 | // } 192 | -------------------------------------------------------------------------------- /rust/src/explorer/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod diagram_utils; 2 | pub mod render_utils; 3 | pub mod routes; -------------------------------------------------------------------------------- /rust/src/explorer/render_utils.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | use std::error::Error; 3 | use std::fs; 4 | use std::fs::File; 5 | use std::io::{self, Read}; 6 | use std::path::{Path, PathBuf}; 7 | use walkdir::WalkDir; 8 | 9 | fn collect_files_with_extension(dir: &Path, extension: &str) -> Vec { 10 | WalkDir::new(dir) 11 | .into_iter() 12 | .filter_map(|e| e.ok()) 13 | .filter_map(|entry| { 14 | let path = entry.path().to_path_buf(); 15 | if path.is_file() && path.extension().and_then(|ext| ext.to_str()) == Some(extension) { 16 | Some(path) 17 | } else { 18 | None 19 | } 20 | }) 21 | .collect() 22 | } 23 | 24 | fn concatenate_file_contents(files: Vec) -> Result { 25 | let mut contents = String::new(); 26 | for file in files { 27 | trace!("reading file: {:?}", &file); 28 | let file_contents = fs::read_to_string(file)?; 29 | contents.push_str(&file_contents); 30 | contents.push_str("\n\n"); 31 | } 32 | Ok(contents) 33 | } 34 | 35 | pub fn read_all_css(dir_path: &Path) -> String { 36 | collate_files_generic(dir_path, "css").unwrap() 37 | } 38 | 39 | pub fn read_all_js(dir_path: &Path) -> String { 40 | collate_files_generic(dir_path, "js").unwrap() 41 | } 42 | 43 | fn collate_files_generic(dir_path: &Path, extension: &str) -> Result { 44 | let files = collect_files_with_extension(dir_path, extension); 45 | let contents = concatenate_file_contents(files)?; 46 | Ok(contents) 47 | } 48 | 49 | pub fn read_file_contents>(path: P) -> io::Result { 50 | let mut file = File::open(path)?; 51 | let mut contents = String::new(); 52 | file.read_to_string(&mut contents)?; 53 | Ok(contents) 54 | } 55 | 56 | pub fn do_replaces(base: &String, subs: &HashMap) -> String { 57 | let mut buffer = base.clone(); 58 | for (key, value) in subs { 59 | buffer = buffer.replace(key, value); 60 | } 61 | buffer 62 | } 63 | 64 | pub fn render_component(body_path: &str, subs: &HashMap) -> String { 65 | trace!("body_path {body_path}"); 66 | let raw_body = read_file_contents(body_path).unwrap(); 67 | let new_body = do_replaces(&raw_body, subs); 68 | new_body 69 | } 70 | 71 | pub fn render_against_custom_body(body_html: &str, body_path: &str) -> Result> { 72 | let raw_body = read_file_contents(body_path).unwrap(); 73 | let mut subs = HashMap::new(); 74 | subs.insert("{body_html}".to_string(), body_html.to_string()); 75 | let html_root = Path::new("."); 76 | subs.insert("/* css here */".to_string(), read_all_css(html_root)); 77 | let new_body = do_replaces(&raw_body, &subs); 78 | Ok(new_body) 79 | } 80 | 81 | pub fn render_app_body(body_html: &str) -> Result> { 82 | let body_path = "src/explorer/assets/app.html"; 83 | render_against_custom_body(body_html, body_path) 84 | } 85 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/animation_route.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use redis::Connection; 4 | use rocket::response::content::Html; 5 | 6 | use crate::{ 7 | common::{ 
8 | graph::InferenceGraph, model::InferenceModel, proposition_db::EmptyBeliefTable, 9 | resources::ResourceContext, 10 | }, 11 | explorer::{ 12 | diagram_utils::{diagram_predicate, diagram_proposition, diagram_proposition_factor}, 13 | render_utils::{render_against_custom_body, render_app_body}, 14 | }, 15 | inference::{ 16 | graph::PropositionGraph, 17 | inference::{Inferencer, MarginalTable}, 18 | rounds::run_inference_rounds, 19 | table::PropositionNode, 20 | }, 21 | model::{ 22 | choose::extract_backimplications_from_proposition, 23 | objects::{Proposition, PropositionGroup}, 24 | }, 25 | }; 26 | 27 | fn backwards_print_group_with_marginal_table( 28 | connection: &mut Connection, 29 | inferencer: &Inferencer, 30 | target: &PropositionGroup, 31 | table: &MarginalTable, 32 | ) -> Result> { 33 | let proposition_node = PropositionNode::from_group(&target); 34 | let backlinks = inferencer 35 | .proposition_graph 36 | .get_all_backward(&proposition_node); 37 | let mut buffer = "".to_string(); 38 | for backlink in &backlinks { 39 | let single = backlink.extract_single(); 40 | let part = 41 | backwards_print_single_with_marginal_table(connection, inferencer, &single, table)?; 42 | buffer += ∂ 43 | } 44 | Ok(buffer) 45 | } 46 | 47 | fn backwards_print_single_with_marginal_table( 48 | connection: &mut Connection, 49 | inferencer: &Inferencer, 50 | target: &Proposition, 51 | table: &MarginalTable, 52 | ) -> Result> { 53 | let proposition_node = PropositionNode::from_single(&target); 54 | let backlinks = inferencer 55 | .proposition_graph 56 | .get_all_backward(&proposition_node); 57 | let mut buffer = "".to_string(); 58 | buffer += &format!(r#"
"#,); 59 | buffer += &format!(r#"
"#,); 60 | for backlink in &backlinks { 61 | let group = backlink.extract_group(); 62 | let part = 63 | backwards_print_group_with_marginal_table(connection, inferencer, &group, table)?; 64 | buffer += ∂ 65 | } 66 | buffer += &format!(r#"
"#,); // network_row 67 | let backimplications = 68 | extract_backimplications_from_proposition(connection, &inferencer.model.graph, target) 69 | .unwrap(); 70 | buffer += &format!(r#"
"#,); 71 | for backimplication in &backimplications { 72 | buffer += &format!( 73 | r#" 74 | 75 | {implication_part} 76 | 77 | "#, 78 | implication_part = diagram_proposition_factor(backimplication, Some(table)) 79 | ); 80 | } 81 | buffer += &format!(r#"
"#,); // network_row 82 | buffer += &format!( 83 | r#" 84 |
85 | {target_part} 86 |
87 | "#, 88 | target_part = diagram_proposition(target, Some(&table)) 89 | ); 90 | buffer += &format!(r#"
"#,); // "proof_box" 91 | Ok(buffer) 92 | } 93 | 94 | fn safe_network_animations( 95 | connection: &mut Connection, 96 | namespace: &str, 97 | marginal_tables: &Vec, 98 | ) -> Result> { 99 | let graph = InferenceGraph::new_shared(namespace.to_string())?; 100 | let target = graph.get_target(connection)?; 101 | let proposition_graph = PropositionGraph::new_shared(connection, &graph, target)?; 102 | proposition_graph.visualize(); 103 | let model = InferenceModel::new_shared(namespace.to_string()).unwrap(); 104 | let fact_memory = EmptyBeliefTable::new_shared(namespace)?; 105 | let inferencer = 106 | Inferencer::new_mutable(model.clone(), proposition_graph.clone(), fact_memory)?; 107 | let mut result = "".to_string(); 108 | for table in marginal_tables { 109 | result += &format!(r#"
"#,); 110 | result += &backwards_print_single_with_marginal_table( 111 | connection, 112 | &inferencer, 113 | &inferencer.proposition_graph.target, 114 | table, 115 | )?; 116 | result += &format!(r#"
"#,); // "animation-card" 117 | } 118 | Ok(result) 119 | } 120 | 121 | pub fn internal_animation( 122 | experiment_name: &str, 123 | test_scenario: &str, 124 | resource_context: &ResourceContext, 125 | ) -> Html { 126 | let mut connection = resource_context.connection.lock().unwrap(); 127 | let marginal_tables = run_inference_rounds(&mut connection, experiment_name, test_scenario) 128 | .expect("Testing failed."); 129 | let body_html = 130 | safe_network_animations(&mut connection, experiment_name, &marginal_tables).unwrap(); 131 | // let result = render_app_body(&body_html); 132 | let body_path = "src/explorer/assets/slides.html"; 133 | let result = render_against_custom_body(&body_html, &body_path); 134 | Html(result.unwrap()) 135 | } 136 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/experiment_route.rs: -------------------------------------------------------------------------------- 1 | use redis::Connection; 2 | use rocket::response::content::Html; 3 | 4 | use crate::{ 5 | common::{graph::InferenceGraph, redis::seq_push, resources::ResourceContext}, 6 | explorer::{diagram_utils::diagram_implication, render_utils::render_app_body}, 7 | }; 8 | 9 | fn render_domain_part(connection: &mut Connection, graph: &InferenceGraph) -> String { 10 | let mut buffer = format!( 11 | r#" 12 |
13 | Domains 14 |
15 | "# 16 | ); 17 | let all_domains = graph.get_all_domains(connection).unwrap(); 18 | println!("all_domains {:?}", &all_domains); 19 | for domain in &all_domains { 20 | let elements = graph.get_entities_in_domain(connection, domain).unwrap(); 21 | println!("elements: {:?}", &elements); 22 | buffer += &format!( 23 | r#" 24 |
25 | {domain} 26 | 27 |
28 | "#, 29 | ) 30 | } 31 | buffer 32 | } 33 | 34 | fn render_relation_part(connection: &mut Connection, graph: &InferenceGraph) -> String { 35 | let mut buffer = format!( 36 | r#" 37 |
38 | Relations 39 |
40 | "# 41 | ); 42 | let all_relations = graph.get_all_relations(connection).unwrap(); 43 | println!("all_relations {:?}", &all_relations); 44 | for relation in &all_relations { 45 | println!("relation {:?}", relation); 46 | buffer += &format!(r#"
"#); 47 | buffer += &format!( 48 | r#" {relation_name}"#, 49 | relation_name = &relation.relation_name 50 | ); 51 | for argument_type in &relation.types { 52 | buffer += &format!( 53 | r#" 54 | {domain_name} 55 | 56 | "#, 57 | domain_name = argument_type.domain 58 | ); 59 | } 60 | buffer += &format!(r#"
"#) 61 | } 62 | buffer 63 | } 64 | 65 | fn render_implication_part(connection: &mut Connection, graph: &InferenceGraph) -> String { 66 | let mut buffer = format!( 67 | r#" 68 |
69 | Implication Factors 70 |
71 | "# 72 | ); 73 | let all_relations = graph.get_all_implications(connection).unwrap(); 74 | println!("all_relations {:?}", &all_relations); 75 | for relation in &all_relations { 76 | buffer += &diagram_implication(relation); 77 | } 78 | buffer 79 | } 80 | 81 | fn render_experiment_parts(connection: &mut Connection, graph: &InferenceGraph) -> String { 82 | format!( 83 | r#" 84 | {domain_part} 85 | {relation_part} 86 | {implication_part} 87 | "#, 88 | domain_part = render_domain_part(connection, graph), 89 | relation_part = render_relation_part(connection, graph), 90 | implication_part = render_implication_part(connection, graph), 91 | ) 92 | } 93 | 94 | fn render_experiment_name(experiment_name: &str) -> String { 95 | format!( 96 | r#" 97 |
98 | Experiment 99 |
100 |
101 | {experiment_name} 102 |
103 | "# 104 | ) 105 | } 106 | 107 | pub fn internal_experiment(experiment_name: &str, resources: &ResourceContext) -> Html { 108 | let mut connection = resources.connection.lock().unwrap(); 109 | let graph = InferenceGraph::new_mutable(experiment_name.to_string()).unwrap(); 110 | // let graph = InferenceGraph::new_mutable(redis_connection, namespace) 111 | let body_html = format!( 112 | r#" 113 | {name_part} 114 | {main_part} 115 | "#, 116 | name_part = render_experiment_name(experiment_name), 117 | main_part = render_experiment_parts(&mut connection, &graph), 118 | ); 119 | let result = render_app_body(&body_html); 120 | Html(result.unwrap()) 121 | } 122 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/factors_route.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use redis::Connection; 4 | use rocket::response::content::Html; 5 | 6 | use crate::{ 7 | common::{ 8 | graph::InferenceGraph, model::InferenceModel, proposition_db::EmptyBeliefTable, 9 | resources::ResourceContext, 10 | }, 11 | explorer::{ 12 | diagram_utils::{ 13 | diagram_implication, diagram_predicate, diagram_proposition, diagram_proposition_group, 14 | }, 15 | render_utils::render_app_body, 16 | }, 17 | inference::{ 18 | graph::PropositionGraph, 19 | inference::{compute_each_combination, compute_factor_probability_table, Inferencer}, 20 | table::{FactorProbabilityTable, PropositionNode, VariableAssignment}, 21 | }, 22 | model::objects::Proposition, 23 | }; 24 | 25 | pub fn diagram_variable_assignment(assignment: &VariableAssignment) -> String { 26 | let mut html = 27 | String::from(""); 28 | let sorted_keys: Vec<_> = assignment.assignment_map.iter().collect(); 29 | for (key, value) in sorted_keys { 30 | let row = format!("", key, value); 31 | html.push_str(&row); 32 | } 33 | html.push_str("
PropositionNode Value
{:?}{}
"); 34 | html 35 | } 36 | 37 | pub fn diagram_factor_table(table: &FactorProbabilityTable) -> String { 38 | let mut html = 39 | String::from(""); 40 | for (pair, probability) in &table.pairs { 41 | let assignment_html = diagram_variable_assignment(pair); 42 | let row = format!( 43 | "", 44 | assignment_html, probability 45 | ); 46 | html.push_str(&row); 47 | } 48 | html.push_str("
VariableAssignment Probability
{}{}
"); 49 | html 50 | } 51 | 52 | fn graph_full_factor(inferencer: &Inferencer, target: &Proposition) -> String { 53 | let node = &PropositionNode::from_single(target); 54 | let mut buffer = "".to_string(); 55 | buffer += &format!("
"); 56 | buffer += &diagram_proposition(target, None); 57 | let parent_nodes = inferencer.proposition_graph.get_all_backward(node); 58 | buffer += &format!("
"); 59 | for parent_node in &parent_nodes { 60 | let proposition = parent_node.extract_group(); 61 | buffer += &diagram_proposition_group(&proposition); 62 | } 63 | buffer += &format!("
"); 64 | buffer += &format!("
"); 65 | buffer 66 | } 67 | 68 | fn compute_factor_probability_table_and_graph( 69 | connection: &mut Connection, 70 | inferencer: &Inferencer, 71 | node: &PropositionNode, 72 | ) -> Result> { 73 | let table = compute_factor_probability_table(connection, inferencer, node)?; 74 | let html = diagram_factor_table(&table); 75 | Ok(html) 76 | } 77 | 78 | fn iterate_through_factors( 79 | scenario_name: &str, 80 | resource_context: &ResourceContext, 81 | ) -> Result> { 82 | let model = InferenceModel::new_shared(scenario_name.to_string()).unwrap(); 83 | let fact_memory = EmptyBeliefTable::new_shared(scenario_name)?; 84 | let mut connection = resource_context.connection.lock().unwrap(); 85 | let target = model.graph.get_target(&mut connection)?; 86 | let proposition_graph = PropositionGraph::new_shared(&mut connection, &model.graph, target)?; 87 | let inferencer = 88 | Inferencer::new_mutable(model.clone(), proposition_graph.clone(), fact_memory)?; 89 | let mut buffer = "".to_string(); 90 | for single_node in &inferencer.bfs_order { 91 | if single_node.is_single() { 92 | let proposition = single_node.extract_single(); 93 | buffer += &graph_full_factor(&inferencer, &proposition); 94 | buffer += &compute_factor_probability_table_and_graph( 95 | &mut connection, 96 | &inferencer, 97 | single_node, 98 | )? 99 | } 100 | } 101 | Ok(buffer) 102 | } 103 | 104 | pub fn internal_factors(experiment_name: &str, resource_context: &ResourceContext) -> Html { 105 | let graph = InferenceGraph::new_mutable(experiment_name.to_string()).unwrap(); 106 | let body_html = iterate_through_factors(experiment_name, resource_context).unwrap(); 107 | let result = render_app_body(&body_html); 108 | Html(result.unwrap()) 109 | } 110 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/index_route.rs: -------------------------------------------------------------------------------- 1 | use rocket::response::content::Html; 2 | 3 | use crate::explorer::render_utils::render_app_body; 4 | 5 | 6 | pub fn internal_index() -> Html { 7 | let result = render_app_body(""); 8 | Html(result.unwrap()) 9 | } -------------------------------------------------------------------------------- /rust/src/explorer/routes/marginals_route.rs: -------------------------------------------------------------------------------- 1 | use rocket::response::content::Html; 2 | 3 | use crate::{common::resources::ResourceContext, explorer::render_utils::render_app_body, inference::rounds::run_inference_rounds}; 4 | 5 | 6 | pub fn internal_marginals(experiment_name: &str, test_scenario: &str, resource_context: &ResourceContext) -> Html { 7 | let mut connection = resource_context.connection.lock().unwrap(); 8 | let marginal_tables = run_inference_rounds(&mut connection, experiment_name, test_scenario) 9 | .expect("Testing failed."); 10 | 11 | let mut body_html = "".to_string(); 12 | body_html += &format!("
"); 13 | for marginal_table in &marginal_tables { 14 | let html_part = marginal_table.render_marginal_table(); 15 | body_html += &html_part; 16 | } 17 | body_html += &format!("
"); 18 | let result = render_app_body(&body_html); 19 | Html(result.unwrap()) 20 | } 21 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod animation_route; 2 | pub mod experiment_route; 3 | pub mod factors_route; 4 | pub mod index_route; 5 | pub mod marginals_route; 6 | pub mod network_route; 7 | pub mod weights_route; -------------------------------------------------------------------------------- /rust/src/explorer/routes/network_route.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, rc::Rc}; 2 | 3 | use redis::Connection; 4 | use rocket::response::content::Html; 5 | 6 | use crate::{ 7 | common::{ 8 | graph::InferenceGraph, model::InferenceModel, proposition_db::EmptyBeliefTable, 9 | resources::ResourceContext, setup::CommandLineOptions, train::TrainingPlan, 10 | }, 11 | explorer::{ 12 | diagram_utils::{diagram_implication, diagram_predicate, diagram_proposition_factor}, 13 | render_utils::render_app_body, 14 | }, 15 | inference::{graph::PropositionGraph, inference::Inferencer, table::PropositionNode}, 16 | model::{ 17 | choose::extract_backimplications_from_proposition, 18 | objects::{Proposition, PropositionGroup}, 19 | }, 20 | }; 21 | 22 | fn backwards_print_group( 23 | connection: &mut Connection, 24 | inferencer: &Inferencer, 25 | target: &PropositionGroup, 26 | ) -> Result> { 27 | let proposition_node = PropositionNode::from_group(&target); 28 | let backlinks = inferencer 29 | .proposition_graph 30 | .get_all_backward(&proposition_node); 31 | let mut buffer = "".to_string(); 32 | for backlink in &backlinks { 33 | let single = backlink.extract_single(); 34 | let part = backwards_print_single(connection, inferencer, &single)?; 35 | buffer += ∂ 36 | } 37 | Ok(buffer) 38 | } 39 | 40 | fn backwards_print_single( 41 | connection: &mut Connection, 42 | inferencer: &Inferencer, 43 | target: &Proposition, 44 | ) -> Result> { 45 | let proposition_node = PropositionNode::from_single(&target); 46 | let backlinks = inferencer 47 | .proposition_graph 48 | .get_all_backward(&proposition_node); 49 | let mut buffer = "".to_string(); 50 | buffer += &format!( r#"
"#,); 51 | buffer += &format!( r#"
"#,); 52 | for backlink in &backlinks { 53 | let group = backlink.extract_group(); 54 | let part = backwards_print_group(connection, inferencer, &group)?; 55 | buffer += ∂ 56 | } 57 | buffer += &format!( r#"
"#,); 58 | let backimplications = 59 | extract_backimplications_from_proposition(connection, &inferencer.model.graph, target) 60 | .unwrap(); 61 | buffer += &format!( r#"
"#,); 62 | for backimplication in &backimplications { 63 | buffer += &format!( 64 | r#" 65 | 66 | {implication_part} 67 | 68 | "#, 69 | implication_part = diagram_proposition_factor(backimplication, None) 70 | ); 71 | } 72 | buffer += &format!( r#"
"#,); 73 | buffer += &format!( 74 | r#" 75 |
76 | {target_part} 77 |
78 | "#, 79 | target_part = diagram_predicate(&target.predicate) 80 | ); 81 | buffer += &format!( r#"
"#,); // "proof_box" 82 | Ok(buffer) 83 | } 84 | 85 | fn render_network(bundle: &ResourceContext, namespace: &str) -> Result> { 86 | let graph = InferenceGraph::new_shared(namespace.to_string())?; 87 | let mut connection = bundle.connection.lock().unwrap(); 88 | let target = graph.get_target(&mut connection)?; 89 | let proposition_graph = PropositionGraph::new_shared(&mut connection, &graph, target)?; 90 | proposition_graph.visualize(); 91 | let model = InferenceModel::new_shared(namespace.to_string()).unwrap(); 92 | let fact_memory = EmptyBeliefTable::new_shared(namespace)?; 93 | let inferencer = 94 | Inferencer::new_mutable(model.clone(), proposition_graph.clone(), fact_memory)?; 95 | let result = backwards_print_single( 96 | &mut connection, 97 | &inferencer, 98 | &inferencer.proposition_graph.target, 99 | )?; 100 | Ok(result) 101 | } 102 | 103 | pub fn internal_network(experiment_name: &str, namespace: &ResourceContext) -> Html { 104 | let network = render_network(namespace, experiment_name).unwrap(); 105 | let body_html = format!( 106 | r#" 107 | {network} 108 | "#, 109 | ); 110 | let result = render_app_body(&body_html); 111 | Html(result.unwrap()) 112 | } 113 | -------------------------------------------------------------------------------- /rust/src/explorer/routes/weights_route.rs: -------------------------------------------------------------------------------- 1 | use redis::Connection; 2 | use rocket::response::content::Html; 3 | 4 | use crate::{ 5 | common::{graph::InferenceGraph, resources::ResourceContext}, 6 | explorer::{diagram_utils::diagram_implication, render_utils::render_app_body}, 7 | model::{ 8 | objects::ImplicationFactor, 9 | weights::{negative_feature, positive_feature, ExponentialWeights, CLASS_LABELS}, 10 | }, 11 | }; 12 | 13 | fn render_one_weight_box( 14 | connection: &mut Connection, 15 | graph: &InferenceGraph, 16 | factor: &ImplicationFactor, 17 | ) -> String { 18 | let weights = ExponentialWeights::new(graph.namespace.clone()).unwrap(); 19 | let feature = factor.unique_key(); 20 | let mut buffer = "".to_string(); 21 | buffer += &format!("
"); 22 | buffer += &format!( 23 | r#" 24 |
25 |
26 |
27 |
28 | false 29 |
30 |
31 | true 32 |
33 |
34 | "# 35 | ); 36 | for class_label in CLASS_LABELS { 37 | let posf = positive_feature(&feature, class_label); 38 | let negf = negative_feature(&feature, class_label); 39 | let posf_count = weights.read_single_weight(connection, &posf).unwrap(); 40 | let negf_count = weights.read_single_weight(connection, &negf).unwrap(); 41 | let posf_css = if posf_count > 0.1f64 { 42 | "positive_weight".to_string() 43 | } else if posf_count < -0.1f64 { 44 | "negative_weight".to_string() 45 | } else { 46 | "neutral_weight".to_string() 47 | }; 48 | let negf_css = if negf_count > 0.1f64 { 49 | "positive_weight".to_string() 50 | } else if negf_count < -0.1f64 { 51 | "negative_weight".to_string() 52 | } else { 53 | "neutral_weight".to_string() 54 | }; 55 | buffer += &format!( 56 | r#" 57 |
58 |
59 | {class_label} 60 |
61 |
62 | {negf_count} 63 |
64 |
65 | {posf_count} 66 |
67 |
68 | "# 69 | ); 70 | } 71 | buffer += &format!("
"); 72 | buffer 73 | } 74 | 75 | fn render_weights_part(connection: &mut Connection, graph: &InferenceGraph) -> String { 76 | let mut buffer = format!( 77 | r#" 78 |
79 | Implication Factors 80 |
81 | "# 82 | ); 83 | let all_relations = graph.get_all_implications(connection).unwrap(); 84 | println!("all_relations {:?}", &all_relations); 85 | for relation in &all_relations { 86 | buffer += &diagram_implication(relation); 87 | buffer += &render_one_weight_box(connection, graph, relation); 88 | } 89 | buffer 90 | } 91 | 92 | pub fn internal_weights(experiment_name: &str, resources: &ResourceContext) -> Html { 93 | let mut connection = resources.connection.lock().unwrap(); 94 | let graph = InferenceGraph::new_mutable(experiment_name.to_string()).unwrap(); 95 | let body_html = render_weights_part(&mut connection, &graph); 96 | let result = render_app_body(&body_html); 97 | Html(result.unwrap()) 98 | } 99 | -------------------------------------------------------------------------------- /rust/src/inference/graph.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::{HashMap, HashSet, VecDeque}, 3 | error::Error, 4 | rc::Rc, sync::Arc, 5 | }; 6 | 7 | use env_logger::init; 8 | use redis::Connection; 9 | use serde::{Deserialize, Serialize}; 10 | 11 | use crate::{ 12 | common::{graph::InferenceGraph, redis::RedisManager}, 13 | model::{ 14 | choose::{compute_search_predicates, extract_backimplications_from_proposition}, 15 | objects::{GroupRoleMap, ImplicationFactor, Proposition, PropositionGroup}, 16 | }, print_yellow, 17 | }; 18 | 19 | use super::table::{GenericNodeType, PropositionNode}; 20 | 21 | #[derive(Serialize, Deserialize, Debug, Clone)] 22 | pub struct PropositionFactor { 23 | pub premise: PropositionGroup, 24 | pub conclusion: Proposition, 25 | pub inference: ImplicationFactor, 26 | } 27 | 28 | impl PropositionFactor { 29 | pub fn debug_string(&self) -> String { 30 | format!( 31 | "{} -> {}", 32 | self.premise.hash_string(), 33 | self.conclusion.hash_string() 34 | ) 35 | } 36 | } 37 | 38 | /// This class does NOT store a link to any database. 39 | /// It is EXPENSIVE to copy, though.. should just be moved. 40 | pub struct PropositionGraph { 41 | pub single_forward: HashMap>, 42 | pub single_backward: HashMap>, 43 | pub group_forward: HashMap>, 44 | pub inference_used: HashMap<(PropositionGroup, Proposition), ImplicationFactor>, 45 | pub roots: HashSet, 46 | pub all_nodes: HashSet, 47 | pub target: Proposition, 48 | } 49 | 50 | fn initialize_visit_single( 51 | connection: &mut Connection, 52 | predicate_graph: &InferenceGraph, 53 | graph: &mut PropositionGraph, 54 | single: &Proposition, 55 | ) -> Result<(), Box> { 56 | trace!( 57 | "\x1b[32mInitializing visit for proposition: {:?}\x1b[0m", 58 | single.hash_string() 59 | ); 60 | graph 61 | .all_nodes 62 | .insert(PropositionNode::from_single(single)); 63 | let inference_factors = 64 | extract_backimplications_from_proposition(connection, predicate_graph, single)?; 65 | trace!( 66 | "\x1b[33mInference factors count: {}\x1b[0m", 67 | inference_factors.len() 68 | ); 69 | 70 | if inference_factors.is_empty() { 71 | trace!("\x1b[34mNo inference factors. 
Adding to roots.\x1b[0m"); 72 | graph.roots.insert(single.clone()); 73 | } else { 74 | for inference_factor in &inference_factors { 75 | trace!( 76 | "\x1b[36mProcessing inference factor: {:?}\x1b[0m", 77 | inference_factor.debug_string() 78 | ); 79 | let inference_used_key = (inference_factor.premise.clone(), inference_factor.conclusion.clone()); 80 | graph.inference_used.insert(inference_used_key, inference_factor.inference.clone()); 81 | 82 | trace!( 83 | "\x1b[36mUpdating single_backward for conclusion: {:?}\x1b[0m", 84 | inference_factor.conclusion.hash_string() 85 | ); 86 | graph 87 | .single_backward 88 | .entry(inference_factor.conclusion.clone()) 89 | .or_insert_with(HashSet::new) 90 | .insert(inference_factor.premise.clone()); 91 | 92 | trace!( 93 | "\x1b[36mUpdating group_forward for premise: {:?}\x1b[0m", 94 | inference_factor.premise.hash_string() 95 | ); 96 | graph 97 | .group_forward 98 | .entry(inference_factor.premise.clone()) 99 | .or_insert_with(HashSet::new) 100 | .insert(inference_factor.conclusion.clone()); 101 | 102 | graph 103 | .all_nodes 104 | .insert(PropositionNode::from_group(&inference_factor.premise)); 105 | 106 | for term in &inference_factor.premise.terms { 107 | trace!("\x1b[35mProcessing term: {:?}\x1b[0m", term.hash_string()); 108 | graph 109 | .single_forward 110 | .entry(term.clone()) 111 | .or_insert_with(HashSet::new) 112 | .insert(inference_factor.premise.clone()); 113 | trace!( 114 | "\x1b[35mRecursively initializing visit for term: {:?}\x1b[0m", 115 | term.hash_string() 116 | ); 117 | initialize_visit_single(connection, predicate_graph, graph, term)?; 118 | } 119 | } 120 | } 121 | trace!( 122 | "\x1b[32mFinished initializing visit for proposition: {:?}\x1b[0m", 123 | single.hash_string() 124 | ); 125 | Ok(()) 126 | } 127 | 128 | impl PropositionGraph { 129 | pub fn new_shared( 130 | connection: &mut Connection, 131 | predicate_graph: &InferenceGraph, 132 | target: Proposition, 133 | ) -> Result, Box> { 134 | let mut graph = PropositionGraph { 135 | single_forward: HashMap::new(), 136 | single_backward: HashMap::new(), 137 | group_forward: HashMap::new(), 138 | inference_used: HashMap::new(), 139 | roots: HashSet::new(), 140 | all_nodes: HashSet::new(), 141 | target: target.clone(), 142 | }; 143 | initialize_visit_single(connection, predicate_graph, &mut graph, &target)?; 144 | Ok(Arc::new(graph)) 145 | } 146 | 147 | pub fn get_inference_used(&self, premise:&PropositionGroup, conclusion: &Proposition) -> ImplicationFactor { 148 | let key = (premise.clone(), conclusion.clone()); 149 | self.inference_used 150 | .get(&key).unwrap().clone() 151 | } 152 | 153 | pub fn get_single_forward(&self, key: &Proposition) -> HashSet { 154 | self.single_forward 155 | .get(key) 156 | .cloned() 157 | .unwrap_or_else(HashSet::new) 158 | } 159 | 160 | pub fn get_single_backward(&self, key: &Proposition) -> HashSet { 161 | self.single_backward 162 | .get(key) 163 | .cloned() 164 | .unwrap_or_else(HashSet::new) 165 | } 166 | 167 | pub fn get_group_forward(&self, key: &PropositionGroup) -> HashSet { 168 | self.group_forward.get(key).unwrap().clone() 169 | } 170 | 171 | pub fn get_group_backward(&self, key: &PropositionGroup) -> Vec { 172 | key.terms.clone() 173 | } 174 | 175 | pub fn get_all_backward(&self, node: &PropositionNode) -> Vec { 176 | trace!("get_all_backward called for node: {:?}", node.debug_string()); 177 | let mut r = vec![]; 178 | match &node.node { 179 | GenericNodeType::Single(proposition) => { 180 | trace!("Processing as Single: {:?}", 
proposition.debug_string()); 181 | let initial = self.get_single_backward(proposition); 182 | trace!("Initial singles: {}", initial.len()); 183 | for group in &initial { 184 | trace!("Adding group from initial singles: {:?}", group.debug_string()); 185 | r.push(PropositionNode::from_group(group)); 186 | } 187 | } 188 | GenericNodeType::Group(group) => { 189 | trace!("Processing as Group: {:?}", group.debug_string()); 190 | let initial = self.get_group_backward(group); 191 | trace!("Initial groups: {}", initial.len()); 192 | for single in &initial { 193 | trace!("Adding single from initial groups: {:?}", single.debug_string()); 194 | r.push(PropositionNode::from_single(single)); 195 | } 196 | } 197 | } 198 | trace!("Resulting vector: {:?}", r); 199 | r 200 | } 201 | 202 | pub fn get_all_forward(&self, node: &PropositionNode) -> Vec { 203 | trace!("get_all_backward called for node: {:?}", node.debug_string()); 204 | let mut r = vec![]; 205 | match &node.node { 206 | GenericNodeType::Single(proposition) => { 207 | trace!("Processing as Single: {:?}", proposition.debug_string()); 208 | let initial = self.get_single_forward(proposition); 209 | trace!("Initial singles: {}", initial.len()); 210 | for group in &initial { 211 | trace!("Adding group from initial singles: {:?}", group.debug_string()); 212 | r.push(PropositionNode::from_group(group)); 213 | } 214 | } 215 | GenericNodeType::Group(group) => { 216 | trace!("Processing as Group: {:?}", group.debug_string()); 217 | let initial = self.get_group_forward(group); 218 | trace!("Initial groups: {}", initial.len()); 219 | for single in &initial { 220 | trace!("Adding single from initial groups: {:?}", single.debug_string()); 221 | r.push(PropositionNode::from_single(single)); 222 | } 223 | } 224 | } 225 | trace!("Resulting vector: {:?}", r); 226 | r 227 | } 228 | 229 | pub fn get_roots(&self) -> HashSet { 230 | self.roots.clone() 231 | } 232 | 233 | pub fn get_bfs_order(&self) -> Vec { 234 | create_bfs_order(&self) 235 | } 236 | } 237 | 238 | impl PropositionGraph { 239 | pub fn visualize(&self) { 240 | trace!("Single Forward:"); 241 | for (key, value) in self.single_forward.iter() { 242 | trace!(" {:?}: {:?}", key, value); 243 | } 244 | 245 | trace!("Single Backward:"); 246 | for (key, value) in self.single_backward.iter() { 247 | trace!(" {:?}: {:?}", key, value); 248 | } 249 | 250 | trace!("Group Forward:"); 251 | for (key, value) in self.group_forward.iter() { 252 | trace!(" {:?}: {:?}", key, value); 253 | } 254 | 255 | trace!("Inference Used:"); 256 | for (key, value) in self.inference_used.iter() { 257 | trace!(" ({:?}, {:?}): {:?}", key.0, key.1, value); 258 | } 259 | 260 | trace!("Roots: {:?}", self.roots); 261 | trace!("All Nodes: {:?}", self.all_nodes); 262 | } 263 | } 264 | 265 | fn reverse_prune_duplicates(raw_order: &Vec<(i32, PropositionNode)>) -> Vec { 266 | let mut seen = HashSet::new(); 267 | let mut result = vec![]; 268 | for (depth, node) in raw_order.iter().rev() { 269 | if !seen.contains(node) { 270 | result.push(node.clone()); 271 | } 272 | seen.insert(node); 273 | } 274 | result.reverse(); 275 | result 276 | } 277 | 278 | fn create_bfs_order(proposition_graph: &PropositionGraph) -> Vec { 279 | let mut queue = VecDeque::new(); 280 | let mut buffer = vec![]; 281 | for root in &proposition_graph.roots { 282 | queue.push_back((0, PropositionNode::from_single(&root))); 283 | } 284 | while let Some((depth, node)) = queue.pop_front() { 285 | buffer.push((depth, node.clone())); 286 | let forward = 
proposition_graph.get_all_forward(&node); 287 | for child in &forward { 288 | queue.push_back((depth + 1, child.clone())); 289 | } 290 | } 291 | let result = reverse_prune_duplicates(&buffer); 292 | result 293 | } 294 | -------------------------------------------------------------------------------- /rust/src/inference/lambda.rs: -------------------------------------------------------------------------------- 1 | use redis::Connection; 2 | 3 | use super::{ 4 | inference::{compute_each_combination, groups_from_backlinks, Inferencer}, 5 | table::{GenericNodeType, PropositionNode}, 6 | }; 7 | use crate::{model::weights::CLASS_LABELS, print_blue, print_green, print_red, print_yellow}; 8 | use std::error::Error; 9 | 10 | impl Inferencer { 11 | pub fn initialize_lambda(&mut self) -> Result<(), Box> { 12 | trace!("initialize_lambda: proposition"); 13 | for node in &self.proposition_graph.all_nodes { 14 | trace!("initializing: {}", node.debug_string()); 15 | for outcome in CLASS_LABELS { 16 | self.data.set_lambda_value(node, outcome, 1f64); 17 | } 18 | for parent in &self.proposition_graph.get_all_backward(node) { 19 | trace!( 20 | "initializing lambda link from {} to {}", 21 | node.debug_string(), 22 | parent.debug_string() 23 | ); 24 | for outcome in CLASS_LABELS { 25 | self.data.set_lambda_message(node, parent, outcome, 1f64); 26 | } 27 | } 28 | } 29 | Ok(()) 30 | } 31 | 32 | pub fn do_lambda_traversal( 33 | &mut self, 34 | connection: &mut Connection, 35 | ) -> Result<(), Box> { 36 | let mut bfs_order = self.bfs_order.clone(); 37 | bfs_order.reverse(); 38 | trace!("send_lambda_messages bfs_order: {:?}", &bfs_order); 39 | for node in &bfs_order { 40 | trace!("send pi bfs selects {:?}", node); 41 | self.lambda_visit_node(connection, node)?; 42 | } 43 | Ok(()) 44 | } 45 | 46 | pub fn lambda_visit_node( 47 | &mut self, 48 | connection: &mut Connection, 49 | from_node: &PropositionNode, 50 | ) -> Result<(), Box> { 51 | self.lambda_send_messages(connection, from_node)?; 52 | let is_observed = self.is_observed(connection, from_node)?; 53 | trace!( 54 | "lambda_visit_node {:?} is_observed {}", 55 | from_node, 56 | is_observed 57 | ); 58 | if is_observed { 59 | self.lambda_set_from_evidence(connection, from_node)?; 60 | } else { 61 | self.lambda_compute_value(connection, &from_node)?; 62 | } 63 | Ok(()) 64 | } 65 | 66 | pub fn lambda_set_from_evidence( 67 | &mut self, 68 | connection: &mut Connection, 69 | node: &PropositionNode, 70 | ) -> Result<(), Box> { 71 | let as_single = node.extract_single(); 72 | let probability = self 73 | .fact_memory 74 | .get_proposition_probability(connection, &as_single)? 
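// A node is only routed here after is_observed() returns true (see
// lambda_visit_node), so the stored probability is assumed to be present.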
75 |             .unwrap();
76 |         trace!("set from evidence {:?} {}", node, probability);
77 |         self.data.set_lambda_value(node, 1, probability);
78 |         self.data.set_lambda_value(node, 0, 1f64 - probability);
79 |         Ok(())
80 |     }
81 |
82 |     pub fn lambda_compute_value(
83 |         &mut self,
84 |         connection: &mut Connection,
85 |         node: &PropositionNode,
86 |     ) -> Result<(), Box<dyn Error>> {
87 |         let is_observed = self.is_observed(connection, node)?;
88 |         assert!(!is_observed);
89 |         let children = self.proposition_graph.get_all_forward(node);
90 |         for class_label in &CLASS_LABELS {
91 |             let mut product = 1f64;
92 |             for (_child_index, child_node) in children.iter().enumerate() {
93 |                 let child_lambda = self
94 |                     .data
95 |                     .get_lambda_message(&child_node, node, *class_label)
96 |                     .unwrap();
97 |                 product *= child_lambda;
98 |             }
99 |             self.data.set_lambda_value(&node, *class_label, product);
100 |         }
101 |         Ok(())
102 |     }
103 |
104 |     pub fn lambda_send_messages(
105 |         &mut self,
106 |         connection: &mut Connection,
107 |         node: &PropositionNode,
108 |     ) -> Result<(), Box<dyn Error>> {
109 |         let parent_nodes = self.proposition_graph.get_all_backward(node);
110 |         trace!(
111 |             "lambda_send_generic for node {:?} with parents {:?}",
112 |             node,
113 |             &parent_nodes
114 |         );
115 |         let all_combinations = compute_each_combination(&parent_nodes);
116 |         let lambda_true = self.data.get_lambda_value(node, 1).unwrap();
117 |         let lambda_false = self.data.get_lambda_value(node, 0).unwrap();
118 |         for (to_index, to_parent) in parent_nodes.iter().enumerate() {
119 |             trace!("to_index {} to_parent {:?}", to_index, to_parent);
120 |             let mut sum_true = 0f64;
121 |             let mut sum_false = 0f64;
122 |             for combination in &all_combinations {
123 |                 let mut pi_product = 1f64;
124 |                 for (other_index, other_parent) in parent_nodes.iter().enumerate() {
125 |                     if other_index != to_index {
126 |                         let class_bool = combination.get(other_parent).unwrap();
127 |                         let class_label = if *class_bool { 1 } else { 0 };
128 |                         let this_pi = self
129 |                             .data
130 |                             .get_pi_message(&other_parent, node, class_label)
131 |                             .unwrap();
132 |                         trace!(
133 |                             "using pi message parent {:?}, node {:?}, label {}: pi={}",
134 |                             &other_parent,
135 |                             node,
136 |                             class_label,
137 |                             this_pi
138 |                         );
139 |                         pi_product *= this_pi;
140 |                     }
141 |                 }
142 |                 let probability_true =
143 |                     self.score_factor_assignment(connection, &parent_nodes, combination, node)?;
144 |                 let probability_false = 1f64 - probability_true;
145 |                 trace!(
146 |                     "probability {} for {:?} on assignment {:?}",
147 |                     probability_true,
148 |                     node,
149 |                     combination
150 |                 );
151 |                 let parent_assignment = combination.get(to_parent).unwrap();
152 |                 let true_factor = probability_true * pi_product * lambda_true;
153 |                 let false_factor = probability_false * pi_product * lambda_false;
154 |                 if *parent_assignment {
155 |                     sum_true += true_factor + false_factor;
156 |                 } else {
157 |                     sum_false += true_factor + false_factor;
158 |                 }
159 |             }
160 |             trace!(
161 |                 "final 1 lambda message {} from {:?} to {:?}",
162 |                 sum_true,
163 |                 node,
164 |                 to_parent
165 |             );
166 |             trace!(
167 |                 "final 0 lambda message {} from {:?} to {:?}",
168 |                 sum_false,
169 |                 node,
170 |                 to_parent
171 |             );
172 |             self.data.set_lambda_message(node, to_parent, 1, sum_true);
173 |             self.data.set_lambda_message(node, to_parent, 0, sum_false);
174 |         }
175 |         Ok(())
176 |     }
177 | }
--------------------------------------------------------------------------------
/rust/src/inference/mod.rs:
--------------------------------------------------------------------------------
1 | pub mod table;
2 | pub mod inference;
3 | pub mod graph;
4 | pub mod pi;
5 | pub mod lambda;
6 | pub mod rounds;
--------------------------------------------------------------------------------
/rust/src/inference/pi.rs:
--------------------------------------------------------------------------------
1 | use redis::Connection;
2 |
3 | use super::{
4 |     inference::{compute_each_combination, groups_from_backlinks, Inferencer},
5 |     table::{GenericNodeType, PropositionNode},
6 | };
7 | use crate::{
8 |     model::{objects::existence_predicate_name, weights::CLASS_LABELS},
9 |     print_blue, print_green, print_red,
10 | };
11 | use std::error::Error;
12 |
13 | impl Inferencer {
14 |     pub fn do_pi_traversal(&mut self, connection: &mut Connection) -> Result<(), Box<dyn Error>> {
15 |         let bfs_order = self.bfs_order.clone();
16 |         for node in &bfs_order {
17 |             self.pi_visit_node(connection, node)?;
18 |         }
19 |         Ok(())
20 |     }
21 |
22 |     pub fn pi_visit_node(
23 |         &mut self,
24 |         connection: &mut Connection,
25 |         from_node: &PropositionNode,
26 |     ) -> Result<(), Box<dyn Error>> {
27 |         if !self.is_root(from_node) {
28 |             let is_observed = self.is_observed(connection, from_node)?;
29 |             if is_observed {
30 |                 self.pi_set_from_evidence(connection, from_node)?;
31 |             } else {
32 |                 self.pi_compute_value(connection, &from_node)?;
33 |             }
34 |         } else {
35 |             self.pi_compute_root(from_node)?;
36 |         }
37 |         self.pi_send_messages(from_node)?;
38 |         Ok(())
39 |     }
40 |
41 |     fn pi_compute_root(&mut self, node: &PropositionNode) -> Result<(), Box<dyn Error>> {
42 |         let root = node.extract_single();
43 |         self.data
44 |             .set_pi_value(&PropositionNode::from_single(&root), 1, 1.0f64);
45 |         self.data
46 |             .set_pi_value(&PropositionNode::from_single(&root), 0, 0.0f64);
47 |         Ok(())
48 |     }
49 |
50 |     pub fn pi_set_from_evidence(
51 |         &mut self,
52 |         connection: &mut Connection,
53 |         node: &PropositionNode,
54 |     ) -> Result<(), Box<dyn Error>> {
55 |         let as_single = node.extract_single();
56 |         let probability = self
57 |             .fact_memory
58 |             .get_proposition_probability(connection, &as_single)?
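// Clamp pi to the observed evidence: pi(true) = p and pi(false) = 1 - p,
// mirroring lambda_set_from_evidence in lambda.rs.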
59 | .unwrap(); 60 | self.data.set_pi_value(node, 1, probability); 61 | self.data.set_pi_value(node, 0, 1f64 - probability); 62 | Ok(()) 63 | } 64 | 65 | pub fn pi_compute_value( 66 | &mut self, 67 | connection: &mut Connection, 68 | node: &PropositionNode, 69 | ) -> Result<(), Box> { 70 | let is_observed = self.is_observed(connection, node)?; 71 | assert!(!is_observed); 72 | let parent_nodes = self.proposition_graph.get_all_backward(node); 73 | let all_combinations = compute_each_combination(&parent_nodes); 74 | let mut sum_true = 0f64; 75 | let mut sum_false = 0f64; 76 | for combination in &all_combinations { 77 | let mut product = 1f64; 78 | for (index, parent_node) in parent_nodes.iter().enumerate() { 79 | let boolean_outcome = combination.get(parent_node).unwrap(); 80 | let usize_outcome = if *boolean_outcome { 1 } else { 0 }; 81 | let pi_x_z = self 82 | .data 83 | .get_pi_message(parent_node, node, usize_outcome) 84 | .unwrap(); 85 | trace!( 86 | "getting pi message parent_node {:?}, node {:?}, usize_outcome {}, pi_x_z {}", 87 | &parent_node, 88 | &node, 89 | usize_outcome, 90 | pi_x_z, 91 | ); 92 | product *= pi_x_z; 93 | } 94 | let true_marginal = self.score_factor_assignment(connection, &parent_nodes, combination, node)?; 95 | let false_marginal = 1f64 - true_marginal; 96 | sum_true += true_marginal * product; 97 | sum_false += false_marginal * product; 98 | } 99 | self.data.set_pi_value(node, 1, sum_true); 100 | self.data.set_pi_value(node, 0, sum_false); 101 | Ok(()) 102 | } 103 | 104 | pub fn pi_send_messages(&mut self, node: &PropositionNode) -> Result<(), Box> { 105 | let forward_groups = self.proposition_graph.get_all_forward(node); 106 | for (this_index, to_node) in forward_groups.iter().enumerate() { 107 | for class_label in &CLASS_LABELS { 108 | let mut lambda_part = 1f64; 109 | for (other_index, other_child) in forward_groups.iter().enumerate() { 110 | if other_index != this_index { 111 | let this_lambda = self 112 | .data 113 | .get_lambda_message(&other_child, node, *class_label) 114 | .unwrap(); 115 | lambda_part *= this_lambda; 116 | } 117 | } 118 | let pi_part = self.data.get_pi_value(&node, *class_label).unwrap(); 119 | let message = pi_part * lambda_part; 120 | self.data 121 | .set_pi_message(&node, &to_node, *class_label, message); 122 | } 123 | } 124 | Ok(()) 125 | } 126 | } 127 | -------------------------------------------------------------------------------- /rust/src/inference/rounds.rs: -------------------------------------------------------------------------------- 1 | use std::error::Error; 2 | 3 | use redis::Connection; 4 | 5 | use crate::common::{model::InferenceModel, proposition_db::EmptyBeliefTable, resources::ResourceContext, test::ReplState}; 6 | 7 | use super::{graph::PropositionGraph, inference::{Inferencer, MarginalTable}, table::PropositionNode}; 8 | 9 | fn setup_test_scenario( 10 | connection: &mut Connection, 11 | scenario_name: &str, 12 | test_scenario: &str, 13 | repl_state: &mut ReplState, 14 | ) -> Result, Box> { 15 | let pairs = match (scenario_name, test_scenario) { 16 | ("dating_simple", "prior") => vec![], 17 | ("dating_simple", "jack_lonely") => vec![("lonely[sub=test_Man0]", 1f64)], 18 | ("dating_simple", "they_date") => vec![("date[obj=test_Woman0,sub=test_Man0]", 1f64)], 19 | ("dating_simple", "jack_likes") => vec![("like[obj=test_Woman0,sub=test_Man0]", 1f64)], 20 | ("dating_simple", "jill_likes") => vec![("like[obj=test_Man0,sub=test_Woman0]", 1f64)], 21 | ("dating_triangle", "prior") => vec![("charming[sub=test_Man0]", 1f64)], 
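// Each pair maps a proposition hash string to an observed probability; a
// hypothetical pair like ("rich[sub=test_Man0]", 1f64) would pin that
// proposition to true before inference begins.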
22 | ("dating_triangle", "charming") => vec![("charming[sub=test_Man0]", 1f64)], 23 | ("dating_triangle", "baller") => vec![("baller[sub=test_Man0]", 1f64)], 24 | ("long_chain", "prior") => vec![], 25 | ("long_chain", "set_0_1") => vec![("alpha0[sub=test_Man0]", 1f64)], 26 | ("long_chain", "set_n_1") => vec![("alpha10[sub=test_Man0]", 1f64)], 27 | ("mid_chain", "set_0_1") => vec![("alpha0[sub=test_Man0]", 1f64)], 28 | ("mid_chain", "set_n_1") => vec![("alpha4[sub=test_Man0]", 1f64)], 29 | _ => panic!("Case name not recognized"), 30 | }; 31 | let r = repl_state.set_pairs_by_name(connection, &pairs); 32 | Ok(r) 33 | } 34 | 35 | pub fn run_inference_rounds( 36 | connection: &mut Connection, 37 | scenario_name: &str, 38 | test_scenario: &str, 39 | ) -> Result, Box> { 40 | let model = InferenceModel::new_shared(scenario_name.to_string()).unwrap(); 41 | let fact_memory = EmptyBeliefTable::new_shared(scenario_name)?; 42 | let target = model.graph.get_target(connection)?; 43 | let proposition_graph = PropositionGraph::new_shared(connection, &model.graph, target)?; 44 | proposition_graph.visualize(); 45 | let mut inferencer = 46 | Inferencer::new_mutable(model.clone(), proposition_graph.clone(), fact_memory)?; 47 | inferencer.initialize_chart(connection)?; 48 | let mut repl = ReplState::new(inferencer); 49 | let mut buffer = vec![]; 50 | buffer.push(repl.inferencer.log_table_to_file()?); 51 | let evidence_node = setup_test_scenario(connection, scenario_name, test_scenario, &mut repl)?; 52 | if evidence_node.is_some() { 53 | for _i in 0..50 { 54 | repl.inferencer 55 | .do_fan_out_from_node(connection, &evidence_node.clone().unwrap())?; 56 | buffer.push(repl.inferencer.log_table_to_file()?); 57 | } 58 | } else { 59 | for _i in 0..50 { 60 | repl.inferencer 61 | .do_full_forward_and_backward(connection)?; 62 | buffer.push(repl.inferencer.log_table_to_file()?); 63 | } 64 | } 65 | Ok(buffer) 66 | } 67 | -------------------------------------------------------------------------------- /rust/src/inference/table.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | common::{graph::serialize_record, interface::BeliefTable}, 3 | model::{ 4 | objects::{Predicate, PredicateGroup, Proposition, PropositionGroup}, 5 | weights::CLASS_LABELS, 6 | }, 7 | print_green, print_yellow, 8 | }; 9 | use redis::Connection; 10 | use serde::{Deserialize, Serialize}; 11 | use std::{collections::HashMap, error::Error, rc::Rc}; 12 | 13 | use colored::*; 14 | use std::collections::hash_map::DefaultHasher; 15 | use std::fmt; 16 | use std::hash::{Hash, Hasher}; 17 | 18 | #[derive(Debug, PartialEq, Eq, Hash, Clone)] 19 | pub enum GenericNodeType { 20 | Single(Proposition), 21 | Group(PropositionGroup), 22 | } 23 | 24 | #[derive(PartialEq, Eq, Clone)] 25 | pub struct PropositionNode { 26 | pub node: GenericNodeType, 27 | pub underlying_hash: u64, 28 | } 29 | 30 | fn hash_proposition(proposition: &Proposition) -> u64 { 31 | let mut hasher = DefaultHasher::new(); 32 | proposition.hash(&mut hasher); 33 | hasher.finish() // This returns the hash as u64 34 | } 35 | 36 | fn hash_group(group: &PropositionGroup) -> u64 { 37 | let mut hasher = DefaultHasher::new(); 38 | group.hash(&mut hasher); 39 | hasher.finish() // This returns the hash as u64 40 | } 41 | 42 | impl Hash for PropositionNode { 43 | fn hash(&self, state: &mut H) { 44 | self.underlying_hash.hash(state); 45 | } 46 | } 47 | 48 | impl PropositionNode { 49 | pub fn from_single(proposition: &Proposition) -> PropositionNode { 50 | let 
underlying_hash = hash_proposition(proposition); 51 | PropositionNode { 52 | node: GenericNodeType::Single(proposition.clone()), 53 | underlying_hash, 54 | } 55 | } 56 | 57 | pub fn from_group(group: &PropositionGroup) -> PropositionNode { 58 | let underlying_hash = hash_group(group); 59 | trace!("got hash {} {:?}", underlying_hash, group); 60 | PropositionNode { 61 | node: GenericNodeType::Group(group.clone()), 62 | underlying_hash, 63 | } 64 | } 65 | 66 | pub fn debug_string(&self) -> String { 67 | let string_part = match &self.node { 68 | GenericNodeType::Single(proposition) => proposition.debug_string(), 69 | GenericNodeType::Group(group) => group.debug_string(), 70 | }; 71 | format!("{}", string_part) 72 | } 73 | 74 | pub fn is_single(&self) -> bool { 75 | matches!(self.node, GenericNodeType::Single(_)) 76 | } 77 | 78 | pub fn is_group(&self) -> bool { 79 | matches!(self.node, GenericNodeType::Group(_)) 80 | } 81 | 82 | pub fn extract_single(&self) -> Proposition { 83 | match &self.node { 84 | GenericNodeType::Single(proposition) => proposition.clone(), 85 | _ => panic!("This is not a single."), 86 | } 87 | } 88 | 89 | pub fn extract_group(&self) -> PropositionGroup { 90 | match &self.node { 91 | GenericNodeType::Group(group) => group.clone(), 92 | _ => panic!("This is not a group."), 93 | } 94 | } 95 | } 96 | impl fmt::Debug for PropositionNode { 97 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 98 | write!(f, "{}", self.debug_string()) 99 | } 100 | } 101 | 102 | #[derive(Debug, Clone)] 103 | 104 | pub struct HashMapBeliefTable { 105 | pi_values: HashMap<(PropositionNode, usize), f64>, 106 | lambda_values: HashMap<(PropositionNode, usize), f64>, 107 | pi_messages: HashMap<(PropositionNode, PropositionNode, usize), f64>, 108 | lambda_messages: HashMap<(PropositionNode, PropositionNode, usize), f64>, 109 | bfs_order: Vec, 110 | } 111 | 112 | fn print_sorted_map( 113 | map: &HashMap<(PropositionNode, usize), f64>, 114 | bfs_order: &Vec, 115 | ) { 116 | for proposition in bfs_order { 117 | let key = (proposition.clone(), 1); 118 | let prob_true = map.get(&key).unwrap(); 119 | let prob_false = 1.0 - prob_true; 120 | let formatted_prob_true = format!("{:.8}", prob_true); 121 | let formatted_prob_false = format!("{:.8}", prob_false); 122 | println!( 123 | "{:<12} {:<12} {}", 124 | formatted_prob_true.green(), 125 | formatted_prob_false.red(), 126 | proposition.debug_string() 127 | ); 128 | } 129 | } 130 | 131 | fn print_sorted_messages( 132 | map: &HashMap<(PropositionNode, PropositionNode, usize), f64>, 133 | bfs_order: &Vec, 134 | ) { 135 | for from in bfs_order { 136 | for to in bfs_order { 137 | let key = (from.clone(), to.clone(), 1); 138 | if let Some(&prob_true) = map.get(&key) { 139 | let prob_false = 1.0 - prob_true; 140 | let formatted_prob_true = format!("{:.8}", prob_true); 141 | let formatted_prob_false = format!("{:.8}", prob_false); 142 | println!( 143 | "{:<12} {:<12} {:<20} {}", 144 | formatted_prob_true.green(), 145 | formatted_prob_false.red(), 146 | from.debug_string(), 147 | to.debug_string() 148 | ); 149 | } 150 | } 151 | } 152 | } 153 | 154 | impl HashMapBeliefTable { 155 | pub fn print_table(&self, table_name: &String) { 156 | match table_name.as_str() { 157 | "pv" => { 158 | println!("PI VALUES"); 159 | print_sorted_map(&self.pi_values, &self.bfs_order); 160 | } 161 | "lv" => { 162 | println!("LAMBDA VALUES"); 163 | print_sorted_map(&self.lambda_values, &self.bfs_order); 164 | } 165 | "pm" => { 166 | println!("PI MESSAGES"); 167 | 
print_sorted_messages(&self.pi_messages, &self.bfs_order); 168 | } 169 | "lm" => { 170 | println!("LAMBDA MESSAGES"); 171 | print_sorted_messages(&self.lambda_messages, &self.bfs_order); 172 | } 173 | _ => println!("Table not recognized."), 174 | }; 175 | } 176 | } 177 | 178 | impl HashMapBeliefTable { 179 | // Constructor to create a new instance 180 | pub fn new(bfs_order: Vec) -> Self { 181 | HashMapBeliefTable { 182 | pi_values: HashMap::new(), 183 | lambda_values: HashMap::new(), 184 | pi_messages: HashMap::new(), 185 | lambda_messages: HashMap::new(), 186 | bfs_order, 187 | } 188 | } 189 | 190 | // Getter for pi values 191 | pub fn get_pi_value(&self, node: &PropositionNode, outcome: usize) -> Option { 192 | let key = (node.clone(), outcome); 193 | self.pi_values.get(&key).cloned() 194 | } 195 | 196 | // Setter for pi values 197 | pub fn set_pi_value(&mut self, node: &PropositionNode, outcome: usize, value: f64) { 198 | let key = (node.clone(), outcome); 199 | self.pi_values.insert(key, value); 200 | } 201 | 202 | // Getter for lambda values 203 | pub fn get_lambda_value(&self, node: &PropositionNode, outcome: usize) -> Option { 204 | let key = (node.clone(), outcome); 205 | self.lambda_values.get(&key).cloned() 206 | } 207 | 208 | // Setter for lambda values 209 | pub fn set_lambda_value(&mut self, node: &PropositionNode, outcome: usize, value: f64) { 210 | let key = (node.clone(), outcome); 211 | self.lambda_values.insert(key, value); 212 | } 213 | 214 | // Getter for pi messages 215 | pub fn get_pi_message( 216 | &self, 217 | from: &PropositionNode, 218 | to: &PropositionNode, 219 | outcome: usize, 220 | ) -> Option { 221 | let key = (from.clone(), to.clone(), outcome); 222 | self.pi_messages.get(&key).cloned() 223 | } 224 | 225 | // Setter for pi messages 226 | pub fn set_pi_message( 227 | &mut self, 228 | from: &PropositionNode, 229 | to: &PropositionNode, 230 | outcome: usize, 231 | value: f64, 232 | ) { 233 | let key = (from.clone(), to.clone(), outcome); 234 | self.pi_messages.insert(key, value); 235 | } 236 | 237 | // Getter for lambda messages 238 | pub fn get_lambda_message( 239 | &self, 240 | from: &PropositionNode, 241 | to: &PropositionNode, 242 | outcome: usize, 243 | ) -> Option { 244 | let key = (from.clone(), to.clone(), outcome); 245 | self.lambda_messages.get(&key).cloned() 246 | } 247 | 248 | // Setter for lambda messages 249 | pub fn set_lambda_message( 250 | &mut self, 251 | from: &PropositionNode, 252 | to: &PropositionNode, 253 | outcome: usize, 254 | value: f64, 255 | ) { 256 | let key = (from.clone(), to.clone(), outcome); 257 | self.lambda_messages.insert(key, value); 258 | } 259 | } 260 | 261 | pub struct VariableAssignment { 262 | pub assignment_map: HashMap, 263 | } 264 | 265 | impl VariableAssignment { 266 | pub fn new(assignment_map: HashMap) -> VariableAssignment { 267 | VariableAssignment { assignment_map } 268 | } 269 | } 270 | 271 | pub struct FactorProbabilityTable { 272 | pub pairs: Vec<(VariableAssignment, f64)>, 273 | } 274 | 275 | impl FactorProbabilityTable { 276 | pub fn new(pairs: Vec<(VariableAssignment, f64)>) -> FactorProbabilityTable { 277 | FactorProbabilityTable { pairs } 278 | } 279 | } 280 | -------------------------------------------------------------------------------- /rust/src/lib.rs: -------------------------------------------------------------------------------- 1 | #![allow(unused_imports)] 2 | #![allow(unused_variables)] 3 | #![allow(dead_code)] 4 | 5 | pub mod model; 6 | pub mod explorer; 7 | pub mod scenarios; 8 | pub mod 
inference; 9 | pub mod common; 10 | pub mod baseline; 11 | 12 | #[macro_use] 13 | extern crate log; 14 | -------------------------------------------------------------------------------- /rust/src/model/choose.rs: -------------------------------------------------------------------------------- 1 | use redis::Connection; 2 | 3 | use super::objects::{ImplicationFactor, Proposition}; 4 | use super::ops::{convert_to_proposition, convert_to_quantified, extract_premise_role_map}; 5 | use crate::common::graph::InferenceGraph; 6 | use crate::common::model::{FactorContext, InferenceModel}; 7 | use crate::inference::graph::PropositionFactor; 8 | use crate::model::objects::{GroupRoleMap, PropositionGroup, RoleMap, existence_predicate_name}; 9 | use crate::{ 10 | common::interface::BeliefTable, 11 | model::objects::{Predicate, PredicateGroup}, 12 | }; 13 | use crate::{print_green, print_red}; 14 | use std::collections::{HashMap, HashSet}; 15 | use std::{borrow::Borrow, error::Error}; 16 | 17 | fn combine(input_array: &[usize], k: usize) -> Vec> { 18 | let mut result = vec![]; 19 | let mut temp_vec = vec![]; 20 | fn run( 21 | input_array: &[usize], 22 | k: usize, 23 | start: usize, 24 | temp_vec: &mut Vec, 25 | result: &mut Vec>, 26 | ) { 27 | if temp_vec.len() == k { 28 | result.push(temp_vec.clone()); 29 | return; 30 | } 31 | for i in start..input_array.len() { 32 | temp_vec.push(input_array[i]); 33 | run(input_array, k, i + 1, temp_vec, result); 34 | temp_vec.pop(); 35 | } 36 | } 37 | run(input_array, k, 0, &mut temp_vec, &mut result); 38 | result 39 | } 40 | 41 | fn compute_choose_configurations(n: usize, k: usize) -> Vec> { 42 | let input_array: Vec = (0..n).collect(); 43 | combine(&input_array, k) 44 | } 45 | 46 | fn extract_roles_from_indices(roles: &[String], indices: &[usize]) -> Vec { 47 | let index_set: std::collections::HashSet = indices.iter().cloned().collect(); 48 | roles 49 | .iter() 50 | .enumerate() 51 | .filter_map(|(i, role)| { 52 | if index_set.contains(&i) { 53 | Some(role.clone()) 54 | } else { 55 | None 56 | } 57 | }) 58 | .collect() 59 | } 60 | 61 | pub fn compute_search_predicates( 62 | proposition: &Proposition, 63 | ) -> Result, Box> { 64 | let num_roles = proposition.predicate.roles().len(); 65 | let configurations1 = compute_choose_configurations(num_roles, 1); 66 | let configurations2 = compute_choose_configurations(num_roles, 2); 67 | let roles = proposition.predicate.role_names(); 68 | let mut result = Vec::new(); 69 | for configuration in configurations1.into_iter().chain(configurations2) { 70 | let quantified_roles = extract_roles_from_indices(&roles, &configuration); 71 | let quantified = convert_to_quantified(proposition, &quantified_roles); 72 | result.push(quantified); 73 | } 74 | Ok(result) 75 | } 76 | 77 | pub fn extract_backimplications_from_proposition( 78 | connection: &mut Connection, 79 | graph: &InferenceGraph, 80 | conclusion: &Proposition, 81 | ) -> Result, Box> { 82 | trace!( 83 | "Computing backimplications for proposition {:?}", 84 | conclusion 85 | ); 86 | let search_keys = compute_search_predicates(conclusion)?; 87 | trace!("Computed search_keys {:?}", &search_keys); 88 | let mut backimplications = Vec::new(); 89 | for predicate in &search_keys { 90 | trace!("Processing search_key {:?}", &predicate.hash_string()); 91 | let implications = graph.predicate_backward_links(connection, &predicate)?; 92 | trace!("Found implications {:?}", &implications); 93 | for implication in &implications { 94 | let mut terms = Vec::new(); 95 | for (index, proposition) 
in implication.premise.terms.iter().enumerate() { 96 | trace!("Processing term {}: {:?}", index, proposition); 97 | let extracted_mapping = 98 | extract_premise_role_map(&conclusion, &implication.role_maps.role_maps[index]); 99 | trace!( 100 | "Extracted mapping for term {}: {:?}", 101 | index, 102 | &extracted_mapping 103 | ); 104 | let extracted_proposition = 105 | convert_to_proposition(&proposition, &extracted_mapping)?; 106 | trace!( 107 | "Converted to proposition for term {}: {:?}", 108 | index, 109 | extracted_proposition 110 | ); 111 | terms.push(extracted_proposition); 112 | } 113 | backimplications.push(PropositionFactor { 114 | premise: PropositionGroup { terms }, 115 | conclusion: conclusion.clone(), 116 | inference: implication.clone(), 117 | }); 118 | } 119 | } 120 | trace!("Returning backimplications {:?}", &backimplications); 121 | debug!( 122 | "Completed computing backimplications, total count: {}", 123 | backimplications.len() 124 | ); 125 | Ok(backimplications) 126 | } 127 | 128 | pub fn extract_existence_factor_for_predicate( 129 | conclusion: &Predicate, 130 | ) -> Result> { 131 | let mut new_roles = vec![]; 132 | let mut mapping = HashMap::new(); 133 | for old_role in &conclusion.roles() { 134 | new_roles.push(old_role.convert_to_quantified()); 135 | mapping.insert(old_role.role_name.clone(), old_role.role_name.clone()); 136 | } 137 | let premise = Predicate::new_from_just_name(existence_predicate_name(), new_roles); 138 | let role_map = RoleMap::new(mapping); 139 | let premise_group = PredicateGroup::new(vec![premise]); 140 | let mapping_group = GroupRoleMap::new(vec![role_map]); 141 | let factor = ImplicationFactor { 142 | premise: premise_group, 143 | role_maps: mapping_group, 144 | conclusion: conclusion.clone(), 145 | }; 146 | trace!("extracted existence predicate {:?}", &factor); 147 | Ok(factor) 148 | } 149 | 150 | pub fn extract_existence_factor_for_proposition( 151 | basis: &Proposition, 152 | ) -> Result> { 153 | let mut new_roles = vec![]; 154 | let mut mapping = HashMap::new(); 155 | for old_role in &basis.predicate.roles() { 156 | new_roles.push(old_role.convert_to_quantified()); 157 | mapping.insert(old_role.role_name.clone(), old_role.role_name.clone()); 158 | } 159 | let premise = Predicate::new_from_just_name(existence_predicate_name(), new_roles.clone()); 160 | let role_map = RoleMap::new(mapping); 161 | let premise_group = PredicateGroup::new(vec![premise]); 162 | let mapping_group = GroupRoleMap::new(vec![role_map]); 163 | let conclusion = Predicate::new_from_relation(basis.predicate.relation.clone(), new_roles.clone()); 164 | let factor = ImplicationFactor { 165 | premise: premise_group, 166 | role_maps: mapping_group, 167 | conclusion, 168 | }; 169 | trace!("extracted existence predicate {:?}", &factor); 170 | Ok(factor) 171 | } 172 | -------------------------------------------------------------------------------- /rust/src/model/config.rs: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gregorycoppola/bayes-star/123e619d0126a0031db09746252e5a752d3aa9f1/rust/src/model/config.rs -------------------------------------------------------------------------------- /rust/src/model/creators.rs: -------------------------------------------------------------------------------- 1 | use crate::model::objects::*; 2 | 3 | // Import the necessary structs and enums 4 | use crate::model::objects::{ 5 | ConstantArgument, LabeledArgument, Predicate, ImplicationFactor, VariableArgument, 6 | }; 7 | 8 
| pub fn conjunction(terms: Vec) -> PredicateGroup { 9 | PredicateGroup { terms } 10 | } 11 | 12 | pub fn implication( 13 | premise: PredicateGroup, 14 | conclusion: Predicate, 15 | role_maps: Vec, 16 | ) -> ImplicationFactor { 17 | let role_maps = GroupRoleMap { role_maps }; 18 | ImplicationFactor { 19 | premise, 20 | conclusion, 21 | role_maps, 22 | } 23 | } 24 | 25 | pub fn variable_argument(domain: String) -> VariableArgument { 26 | VariableArgument { 27 | domain 28 | } 29 | } 30 | 31 | pub fn relation(relation_name: String, roles: Vec) -> Relation { 32 | Relation::new(relation_name, roles) 33 | } 34 | 35 | pub fn proposition(relation: Relation, roles: Vec) -> Proposition { 36 | Proposition::from(Predicate::new_from_relation(relation, roles)) 37 | } 38 | 39 | pub fn predicate(relation: Relation, roles: Vec) -> Predicate { 40 | Predicate::new_from_relation(relation, roles) 41 | } 42 | 43 | // Function to create a FilledRole 44 | pub fn role(role_name: String, argument: Argument) -> LabeledArgument { 45 | // Assuming logger.noop is a logging function, you can implement similar functionality in Rust if needed. 46 | // For this example, it's omitted. 47 | LabeledArgument { 48 | role_name, 49 | argument, 50 | } 51 | } 52 | 53 | // Function to create a VariableArgument 54 | pub fn variable(domain: String) -> Argument { 55 | Argument::Variable(VariableArgument { domain }) 56 | } 57 | 58 | // Function to create a ConstantArgument 59 | pub fn constant(domain: String, entity_id: String) -> Argument { 60 | Argument::Constant(ConstantArgument { domain, entity_id }) 61 | } 62 | 63 | // Helper functions for specific roles 64 | pub fn sub(argument: Argument) -> LabeledArgument { 65 | role("sub".to_string(), argument) 66 | } 67 | 68 | pub fn obj(argument: Argument) -> LabeledArgument { 69 | role("obj".to_string(), argument) 70 | } 71 | -------------------------------------------------------------------------------- /rust/src/model/exponential.rs: -------------------------------------------------------------------------------- 1 | use super::choose::extract_backimplications_from_proposition; 2 | use super::objects::ImplicationFactor; 3 | use super::weights::{negative_feature, positive_feature, ExponentialWeights}; 4 | use crate::common::interface::{BeliefTable, PredictStatistics, TrainStatistics}; 5 | use crate::common::model::InferenceModel; 6 | use crate::common::model::{FactorContext, FactorModel}; 7 | use crate::common::redis::RedisManager; 8 | use crate::common::resources::ResourceContext; 9 | use crate::common::setup::CommandLineOptions; 10 | use crate::model::objects::Predicate; 11 | use crate::model::weights::CLASS_LABELS; 12 | use crate::{print_blue, print_yellow}; 13 | use redis::Connection; 14 | use std::cell::RefCell; 15 | use std::collections::HashMap; 16 | use std::error::Error; 17 | use std::rc::Rc; 18 | use std::sync::Arc; 19 | pub struct ExponentialModel { 20 | print_training_loss: bool, 21 | weights: ExponentialWeights, 22 | } 23 | 24 | impl ExponentialModel { 25 | pub fn new_mutable(namespace: String) -> Result, Box> { 26 | let weights = ExponentialWeights::new(namespace.clone())?; 27 | Ok(Box::new(ExponentialModel { 28 | print_training_loss: false, 29 | weights, 30 | })) 31 | } 32 | pub fn new_shared(namespace: String) -> Result, Box> { 33 | let weights = ExponentialWeights::new(namespace.clone())?; 34 | Ok(Arc::new(ExponentialModel { 35 | print_training_loss: false, 36 | weights, 37 | })) 38 | } 39 | } 40 | 41 | fn dot_product(dict1: &HashMap, dict2: &HashMap) -> f64 { 42 | let mut 
result = 0.0; 43 | for (key, &v1) in dict1 { 44 | if let Some(&v2) = dict2.get(key) { 45 | let product = v1 * v2; 46 | trace!( 47 | "dot_product: key {}, v1 {}, v2 {}, product {}", 48 | key, 49 | v1, 50 | v2, 51 | product 52 | ); 53 | result += product; 54 | } 55 | // In case of null (None), we skip the key as per the original JavaScript logic. 56 | } 57 | result 58 | } 59 | 60 | pub fn compute_potential(weights: &HashMap, features: &HashMap) -> f64 { 61 | let dot = dot_product(weights, features); 62 | dot.exp() 63 | } 64 | 65 | pub fn features_from_factor( 66 | factor: &FactorContext, 67 | ) -> Result>, Box> { 68 | let mut vec_result = vec![]; 69 | for class_label in CLASS_LABELS { 70 | let mut result = HashMap::new(); 71 | for (i, premise) in factor.factor.iter().enumerate() { 72 | debug!("Processing backimplication {}", i); 73 | let feature = premise.inference.unique_key(); 74 | debug!("Generated unique key for feature: {}", feature); 75 | let probability = factor.probabilities[i]; 76 | debug!( 77 | "Conjunction probability for backimplication {}: {}", 78 | i, probability 79 | ); 80 | let posf = positive_feature(&feature, class_label); 81 | let negf = negative_feature(&feature, class_label); 82 | result.insert(posf.clone(), probability); 83 | result.insert(negf.clone(), 1.0 - probability); 84 | debug!( 85 | "Inserted features for backimplication {}: positive - {}, negative - {}", 86 | i, posf, negf 87 | ); 88 | } 89 | vec_result.push(result); 90 | } 91 | trace!("features_from_backimplications completed successfully"); 92 | Ok(vec_result) 93 | } 94 | 95 | pub fn compute_expected_features( 96 | probability: f64, 97 | features: &HashMap, 98 | ) -> HashMap { 99 | let mut result = HashMap::new(); 100 | for (key, &value) in features { 101 | result.insert(key.clone(), value * probability); 102 | } 103 | result 104 | } 105 | 106 | const LEARNING_RATE: f64 = 0.05; 107 | 108 | pub fn do_sgd_update( 109 | weights: &HashMap, 110 | gold_features: &HashMap, 111 | expected_features: &HashMap, 112 | print_training_loss: bool, 113 | ) -> HashMap { 114 | let mut new_weights = HashMap::new(); 115 | for (feature, &wv) in weights { 116 | let gv = gold_features.get(feature).unwrap_or(&0.0); 117 | let ev = expected_features.get(feature).unwrap_or(&0.0); 118 | let new_weight = wv + LEARNING_RATE * (gv - ev); 119 | let loss = (gv - ev).abs(); 120 | if print_training_loss { 121 | trace!( 122 | "feature: {}, gv: {}, ev: {}, loss: {}, old_weight: {}, new_weight: {}", 123 | feature, 124 | gv, 125 | ev, 126 | loss, 127 | wv, 128 | new_weight 129 | ); 130 | } 131 | new_weights.insert(feature.clone(), new_weight); 132 | } 133 | new_weights 134 | } 135 | 136 | impl FactorModel for ExponentialModel { 137 | fn initialize_connection( 138 | &mut self, 139 | connection: &mut Connection, 140 | implication: &ImplicationFactor, 141 | ) -> Result<(), Box> { 142 | self.weights.initialize_weights(connection, implication)?; 143 | Ok(()) 144 | } 145 | 146 | fn train( 147 | &mut self, 148 | connection: &mut Connection, 149 | factor: &FactorContext, 150 | gold_probability: f64, 151 | ) -> Result> { 152 | trace!("train_on_example - Getting features from backimplications"); 153 | let features = match features_from_factor(factor) { 154 | Ok(f) => f, 155 | Err(e) => { 156 | trace!( 157 | "train_on_example - Error in features_from_backimplications: {:?}", 158 | e 159 | ); 160 | return Err(e); 161 | } 162 | }; 163 | let mut weight_vectors = vec![]; 164 | let mut potentials = vec![]; 165 | for class_label in CLASS_LABELS { 166 | for 
(feature, weight) in &features[class_label] { 167 | trace!("feature {:?} {}", feature, weight); 168 | } 169 | trace!( 170 | "train_on_example - Reading weights for class {}", 171 | class_label 172 | ); 173 | let weight_vector = match self.weights.read_weight_vector( 174 | connection, 175 | &features[class_label].keys().cloned().collect::>(), 176 | ) { 177 | Ok(w) => w, 178 | Err(e) => { 179 | trace!("train_on_example - Error in read_weights: {:?}", e); 180 | return Err(e); 181 | } 182 | }; 183 | trace!("train_on_example - Computing probability"); 184 | let potential = compute_potential(&weight_vector, &features[class_label]); 185 | trace!("train_on_example - Computed probability: {}", potential); 186 | potentials.push(potential); 187 | weight_vectors.push(weight_vector); 188 | } 189 | let normalization = potentials[0] + potentials[1]; 190 | for class_label in CLASS_LABELS { 191 | let probability = potentials[class_label] / normalization; 192 | trace!("train_on_example - Computing expected features"); 193 | let this_true_prob = if class_label == 0 { 194 | 1f64 - gold_probability 195 | } else { 196 | gold_probability 197 | }; 198 | let gold = compute_expected_features(this_true_prob, &features[class_label]); 199 | let expected = compute_expected_features(probability, &features[class_label]); 200 | trace!("train_on_example - Performing SGD update"); 201 | let new_weight = do_sgd_update( 202 | &weight_vectors[class_label], 203 | &gold, 204 | &expected, 205 | self.print_training_loss, 206 | ); 207 | trace!("train_on_example - Saving new weights"); 208 | self.weights.save_weight_vector(connection, &new_weight)?; 209 | } 210 | trace!("train_on_example - End"); 211 | Ok(TrainStatistics { loss: 1f64 }) 212 | } 213 | fn predict( 214 | &self, 215 | connection: &mut Connection, 216 | factor: &FactorContext, 217 | ) -> Result> { 218 | let features = match features_from_factor(factor) { 219 | Ok(f) => f, 220 | Err(e) => { 221 | trace!( 222 | "inference_probability - Error in features_from_backimplications: {:?}", 223 | e 224 | ); 225 | return Err(e); 226 | } 227 | }; 228 | let mut potentials = vec![]; 229 | for class_label in CLASS_LABELS { 230 | let this_features = &features[class_label]; 231 | for (feature, weight) in this_features.iter() { 232 | trace!("feature {:?} {}", &feature, weight); 233 | } 234 | trace!("inference_probability - Reading weights"); 235 | let weight_vector = match self.weights.read_weight_vector( 236 | connection, 237 | &this_features.keys().cloned().collect::>(), 238 | ) { 239 | Ok(w) => w, 240 | Err(e) => { 241 | trace!("inference_probability - Error in read_weights: {:?}", e); 242 | return Err(e); 243 | } 244 | }; 245 | for (feature, weight) in weight_vector.iter() { 246 | trace!("weight {:?} {}", &feature, weight); 247 | } 248 | let potential = compute_potential(&weight_vector, &this_features); 249 | trace!("potential for {} {} {:?}", class_label, potential, &factor); 250 | potentials.push(potential); 251 | } 252 | let normalization = potentials[0] + potentials[1]; 253 | let probability = potentials[1] / normalization; 254 | trace!( 255 | "dot_product: normalization {}, marginal {}", 256 | normalization, 257 | probability 258 | ); 259 | Ok(PredictStatistics { probability }) 260 | } 261 | } 262 | -------------------------------------------------------------------------------- /rust/src/model/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod objects; 2 | pub mod creators; 3 | pub mod choose; 4 | pub mod ops; 5 | pub mod 
weights; 6 | pub mod exponential; 7 | pub mod config; -------------------------------------------------------------------------------- /rust/src/model/ops.rs: -------------------------------------------------------------------------------- 1 | use crate::model::objects::{LabeledArgument, Predicate, RoleMap}; 2 | use std::{collections::HashMap, error::Error}; 3 | 4 | use super::objects::{Argument, Proposition}; 5 | 6 | pub fn convert_to_quantified(proposition: &Proposition, roles: &[String]) -> Predicate { 7 | let role_set: std::collections::HashSet = roles.iter().cloned().collect(); 8 | let result: Vec = proposition 9 | .predicate 10 | .roles() 11 | .iter() 12 | .map(|crole| { 13 | if role_set.contains(&crole.role_name) { 14 | crole.convert_to_quantified() 15 | } else { 16 | crole.clone() 17 | } 18 | }) 19 | .collect(); 20 | 21 | Predicate::new_from_relation(proposition.predicate.relation.clone(), result) 22 | } 23 | 24 | pub fn convert_to_proposition( 25 | predicate: &Predicate, 26 | role_map: &HashMap, 27 | ) -> Result> { 28 | debug!( 29 | "Converting to proposition: {:?}, role_map {:?}", 30 | predicate, &role_map 31 | ); 32 | let mut result_roles = Vec::new(); 33 | for role in &predicate.roles() { 34 | debug!("Processing role: {:?}", role); 35 | if role.argument.is_variable() { 36 | debug!("Role is a variable, attempting substitution."); 37 | match role_map.get(&role.role_name) { 38 | Some(substitute) => { 39 | debug!( 40 | "Substitution found for role: {}, substitute: {:?}", 41 | role.role_name, substitute 42 | ); 43 | let new_role = role.do_substitution(substitute.clone()); // Assuming this method exists in FilledRole 44 | debug!("New role after substitution: {:?}", new_role); 45 | 46 | assert!( 47 | new_role.argument.is_constant(), 48 | "After substitution, arg must be a constant in new_role: {:?}", 49 | new_role 50 | ); 51 | result_roles.push(new_role); 52 | } 53 | None => { 54 | error!("Substitution not found for role: {}", role.role_name); 55 | return Err( 56 | format!("Substitution not found for role: {}", role.role_name).into(), 57 | ); 58 | } 59 | } 60 | } else { 61 | debug!("Role is not a variable, pushing as is."); 62 | result_roles.push(role.clone()); 63 | } 64 | } 65 | debug!("Conversion to proposition completed successfully."); 66 | let function = predicate.relation.clone(); 67 | Ok(Proposition { 68 | predicate: Predicate::new_from_relation(function, result_roles), 69 | }) 70 | } 71 | 72 | pub fn extract_premise_role_map( 73 | proposition: &Proposition, 74 | role_map: &RoleMap, 75 | ) -> HashMap { 76 | debug!( 77 | "Extracting premise role map for proposition: {:?}", 78 | proposition 79 | ); 80 | let mut result = HashMap::new(); 81 | for crole in &proposition.predicate.roles() { 82 | assert!( 83 | crole.argument.is_constant(), 84 | "crole must be a constant {:?}", 85 | &crole 86 | ); 87 | let role_name = &crole.role_name; 88 | trace!("Processing role: {:?}", crole); 89 | if let Some(premise_role_name) = role_map.get(role_name) { 90 | trace!("Mapping found: {} -> {}", role_name, premise_role_name); 91 | result.insert(premise_role_name.clone(), crole.argument.clone()); 92 | } else { 93 | trace!("No mapping found for role: {}", role_name); 94 | } 95 | } 96 | debug!("Extraction complete, result: {:?}", result); 97 | result 98 | } 99 | -------------------------------------------------------------------------------- /rust/src/model/weights.rs: -------------------------------------------------------------------------------- 1 | use crate::{ 2 | common::{ 3 | redis::{map_get, 
map_insert}, 4 | resources::ResourceContext, 5 | }, 6 | model::objects::ImplicationFactor, 7 | }; 8 | use rand::Rng; 9 | use redis::{Commands, Connection}; 10 | use std::{cell::RefCell, error::Error}; 11 | use std::{ 12 | collections::HashMap, 13 | sync::{Arc, Mutex}, 14 | }; 15 | 16 | pub const CLASS_LABELS: [usize; 2] = [0, 1]; 17 | 18 | fn random_weight() -> f64 { 19 | let mut rng = rand::thread_rng(); 20 | (rng.gen::() - rng.gen::()) / 5.0 21 | } 22 | 23 | fn sign_char(value: usize) -> String { 24 | if value == 0 { 25 | '-'.to_string() 26 | } else { 27 | "+".to_string() 28 | } 29 | } 30 | 31 | pub fn positive_feature(feature: &str, class_label: usize) -> String { 32 | format!("+>{} {}", sign_char(class_label), feature) 33 | } 34 | 35 | pub fn negative_feature(feature: &str, class_label: usize) -> String { 36 | format!("->{} {}", sign_char(class_label), feature) 37 | } 38 | 39 | pub struct ExponentialWeights { 40 | namespace: String, 41 | } 42 | 43 | impl ExponentialWeights { 44 | pub fn new(namespace: String) -> Result> { 45 | Ok(ExponentialWeights { namespace }) 46 | } 47 | } 48 | 49 | impl ExponentialWeights { 50 | pub const WEIGHTS_KEY: &'static str = "weights"; 51 | 52 | pub fn initialize_weights( 53 | &mut self, 54 | connection: &mut Connection, 55 | implication: &ImplicationFactor, 56 | ) -> Result<(), Box> { 57 | trace!("initialize_weights - Start: {:?}", implication); 58 | let feature = implication.unique_key(); 59 | trace!("initialize_weights - Unique key: {}", feature); 60 | for class_label in CLASS_LABELS { 61 | let posf = positive_feature(&feature, class_label); 62 | let negf = negative_feature(&feature, class_label); 63 | trace!( 64 | "initialize_weights - Positive feature: {}, Negative feature: {}", 65 | posf, 66 | negf 67 | ); 68 | let weight1 = random_weight(); 69 | let weight2 = random_weight(); 70 | trace!( 71 | "initialize_weights - Generated weights: {}, {}", 72 | weight1, 73 | weight2 74 | ); 75 | map_insert( 76 | connection, 77 | &self.namespace, 78 | Self::WEIGHTS_KEY, 79 | &posf, 80 | &weight1.to_string(), 81 | )?; 82 | map_insert( 83 | connection, 84 | &self.namespace, 85 | Self::WEIGHTS_KEY, 86 | &negf, 87 | &weight2.to_string(), 88 | )?; 89 | } 90 | trace!("initialize_weights - End"); 91 | Ok(()) 92 | } 93 | 94 | pub fn read_single_weight( 95 | &self, 96 | connection: &mut Connection, 97 | feature: &str, 98 | ) -> Result> { 99 | trace!("read_weights - Start"); 100 | trace!("read_weights - Reading weight for feature: {}", feature); 101 | let weight_record = map_get(connection, &self.namespace, Self::WEIGHTS_KEY, &feature)? 102 | .unwrap_or("0.0".to_string()); 103 | // .expect("should be there"); 104 | let weight = weight_record.parse::().map_err(|e| { 105 | trace!("read_weights - Error parsing weight: {:?}", e); 106 | Box::new(e) as Box 107 | })?; 108 | trace!("read_weights - End"); 109 | Ok(weight) 110 | } 111 | 112 | pub fn read_weight_vector( 113 | &self, 114 | connection: &mut Connection, 115 | features: &[String], 116 | ) -> Result, Box> { 117 | trace!("read_weights - Start"); 118 | let mut weights = HashMap::new(); 119 | for feature in features { 120 | trace!("read_weights - Reading weight for feature: {}", feature); 121 | let weight_record = map_get(connection, &self.namespace, Self::WEIGHTS_KEY, &feature)? 
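// Unlike read_single_weight above, which falls back to 0.0 for a missing
// feature, a weight missing here is treated as a hard error.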
122 | .expect("should be there"); 123 | let weight = weight_record.parse::().map_err(|e| { 124 | trace!("read_weights - Error parsing weight: {:?}", e); 125 | Box::new(e) as Box 126 | })?; 127 | weights.insert(feature.clone(), weight); 128 | } 129 | trace!("read_weights - End"); 130 | Ok(weights) 131 | } 132 | 133 | pub fn save_weight_vector( 134 | &mut self, 135 | connection: &mut Connection, 136 | weights: &HashMap, 137 | ) -> Result<(), Box> { 138 | trace!("save_weights - Start"); 139 | for (feature, &value) in weights { 140 | trace!( 141 | "save_weights - Saving weight for feature {}: {}", 142 | feature, 143 | value 144 | ); 145 | map_insert( 146 | connection, 147 | &self.namespace, 148 | Self::WEIGHTS_KEY, 149 | &feature, 150 | &value.to_string(), 151 | )?; 152 | } 153 | trace!("save_weights - End"); 154 | Ok(()) 155 | } 156 | } 157 | -------------------------------------------------------------------------------- /rust/src/scenarios/dating_triangle.rs: -------------------------------------------------------------------------------- 1 | use crate::common::proposition_db::RedisBeliefTable; 2 | use crate::common::graph::InferenceGraph; 3 | use crate::common::interface::BeliefTable; 4 | use crate::common::model::InferenceModel; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, NamespaceBundle}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::creators::predicate; 9 | use crate::scenarios::helpers::weighted_cointoss; 10 | use crate::{ 11 | common::interface::ScenarioMaker, 12 | model::{ 13 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 14 | objects::{Domain, Entity, RoleMap}, 15 | }, 16 | }; 17 | use std::{collections::HashMap, error::Error}; 18 | 19 | pub struct EligibilityTriangle {} 20 | 21 | impl ScenarioMaker for EligibilityTriangle { 22 | fn setup_scenario( 23 | &self, 24 | resources: &NamespaceBundle, 25 | ) -> Result<(), Box> { 26 | let mut graph = InferenceGraph::new_mutable(resources)?; 27 | let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 28 | let mut plan = TrainingPlan::new(&resources)?; 29 | let config = &resources.config; 30 | let total_members_each_class = config.entities_per_domain; 31 | let jack_domain = Domain::MAN.to_string(); 32 | for i in 0..total_members_each_class { 33 | let is_test = i == 0; 34 | let is_training = !is_test; 35 | let prefix = if is_test { "test" } else { "train" }; 36 | let name = format!("{}_{:?}{}", &prefix, Domain::MAN, i); 37 | let domain = Domain::MAN.to_string(); 38 | let jack_entity = Entity { 39 | domain, 40 | name: name.clone(), 41 | }; 42 | graph.store_entity(&jack_entity)?; 43 | let jack = constant(jack_entity.domain, jack_entity.name.clone()); 44 | let p_jack_charming = weighted_cointoss(0.3f64); 45 | let jack_charming = proposition("charming".to_string(), vec![sub(jack.clone())]); 46 | proposition_db.store_proposition_boolean(&jack_charming, p_jack_charming)?; 47 | plan.maybe_add_to_training(is_training, &jack_charming)?; 48 | graph.ensure_existence_backlinks_for_proposition(&jack_charming)?; 49 | let p_jack_rich: bool = if p_jack_charming { 50 | weighted_cointoss(0.7f64) 51 | } else { 52 | weighted_cointoss(0.2f64) 53 | }; 54 | let jack_rich = proposition("rich".to_string(), vec![sub(jack.clone())]); 55 | proposition_db.store_proposition_boolean(&jack_rich, p_jack_rich)?; 56 | plan.maybe_add_to_training(is_training, &jack_rich)?; 57 | let p_jack_baller = p_jack_charming && p_jack_rich; 58 | let jack_baller = 
proposition("baller".to_string(), vec![sub(jack.clone())]); 59 | proposition_db.store_proposition_boolean(&jack_baller, p_jack_baller)?; 60 | plan.maybe_add_to_training(is_training, &jack_baller)?; 61 | plan.maybe_add_to_test(is_test, &jack_baller)?; 62 | } 63 | 64 | let xjack = variable(Domain::MAN.to_string()); 65 | let implications = vec![ 66 | implication( 67 | conjunction(vec![predicate("charming".to_string(), vec![ 68 | sub(xjack.clone()), 69 | ])]), 70 | predicate("rich".to_string(), 71 | vec![ 72 | sub(xjack.clone()), 73 | ]), 74 | vec![RoleMap::new(HashMap::from([( 75 | "sub".to_string(), 76 | "sub".to_string(), 77 | )]))], 78 | ), 79 | implication( 80 | conjunction(vec![ 81 | predicate("rich".to_string(), 82 | vec![ 83 | sub(xjack.clone()), 84 | ]), 85 | predicate("charming".to_string(), vec![ 86 | sub(xjack.clone()), 87 | ]), 88 | ]), 89 | predicate("baller".to_string(), 90 | vec![ 91 | sub(xjack.clone()), 92 | ]), 93 | vec![ 94 | RoleMap::new(HashMap::from([ 95 | ("sub".to_string(), "sub".to_string()), 96 | ])), 97 | RoleMap::new(HashMap::from([ 98 | ("sub".to_string(), "sub".to_string()), 99 | ])), 100 | ], 101 | ), 102 | ]; 103 | graph.store_predicate_implications(&implications)?; 104 | Ok(()) 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /rust/src/scenarios/factory.rs: -------------------------------------------------------------------------------- 1 | use std::{error::Error, rc::Rc}; 2 | 3 | use crate::common::{interface::ScenarioMaker, resources::ResourceContext}; 4 | 5 | use super::{dating_simple::SimpleDating, one_var::OneVariable}; 6 | 7 | pub struct ScenarioMakerFactory; 8 | 9 | impl ScenarioMakerFactory { 10 | pub fn new_shared(namespace: &str) -> Result, Box> { 11 | match namespace { 12 | "dating_simple" => Ok(Rc::new(SimpleDating {})), 13 | // "dating_triangle" => Ok(Rc::new(EligibilityTriangle {})), 14 | "one_var" => Ok(Rc::new(OneVariable {})), 15 | // "long_chain" => Ok(Rc::new(long_chain::Scenario {})), 16 | // "mid_chain" => Ok(Rc::new(mid_chain::Scenario {})), 17 | // "long_and" => Ok(Rc::new(long_and::Scenario {})), 18 | // "two_var" => Ok(Rc::new(TwoVariable {})), 19 | _ => Err("Unknown ScenarioMaker type".into()), 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /rust/src/scenarios/helpers.rs: -------------------------------------------------------------------------------- 1 | 2 | use rand::Rng; 3 | pub fn weighted_cointoss(threshold: f64) -> bool { 4 | let mut rng = rand::thread_rng(); // Get a random number generator 5 | if rng.gen::() < threshold { 6 | true 7 | } else { 8 | false 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /rust/src/scenarios/long_and.rs: -------------------------------------------------------------------------------- 1 | use crate::common::graph::InferenceGraph; 2 | use crate::common::interface::BeliefTable; 3 | use crate::common::model::InferenceModel; 4 | use crate::common::proposition_db::RedisBeliefTable; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, NamespaceBundle}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::choose::extract_existence_factor_for_proposition; 9 | use crate::model::creators::predicate; 10 | use crate::{ 11 | common::interface::ScenarioMaker, 12 | model::{ 13 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 14 | objects::{Domain, Entity, RoleMap}, 15 | }, 
16 | }; 17 | use crate::{print_red, print_yellow}; 18 | use rand::Rng; // Import Rng trait 19 | use std::{collections::HashMap, error::Error}; 20 | 21 | use super::helpers::weighted_cointoss; 22 | 23 | pub struct Scenario {} 24 | 25 | const LINK_HEIGHT: u32 = 10; 26 | 27 | impl ScenarioMaker for Scenario { 28 | fn setup_scenario(&self, resources: &NamespaceBundle) -> Result<(), Box<dyn Error>> { 29 | let mut graph = InferenceGraph::new_mutable(resources)?; 30 | let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 31 | let mut plan = TrainingPlan::new(&resources)?; 32 | let config = &resources.config; 33 | let total_members_each_class = config.entities_per_domain; 34 | let domain = Domain::MAN.to_string(); 35 | for i in 0..total_members_each_class { 36 | let is_test = i == 0; 37 | let is_training = !is_test; 38 | let prefix = if is_test { "test" } else { "train" }; 39 | let name = format!("{}_{:?}{}", &prefix, domain, i); 40 | let jack_entity = Entity { 41 | domain: domain.clone(), 42 | name: name.clone(), 43 | }; 44 | graph.store_entity(&jack_entity)?; 45 | 46 | let p_jack_alpha = weighted_cointoss(0.3f64); 47 | let p_jack_beta = weighted_cointoss(0.3f64); 48 | let p_jack_gamma = p_jack_alpha && p_jack_beta; 49 | let jack = constant(jack_entity.domain, jack_entity.name.clone()); 50 | for level in 0..LINK_HEIGHT { 51 | let function = format!("alpha{}", level); 52 | let jack_alpha = proposition(function, vec![sub(jack.clone())]); 53 | if level == 0 { 54 | graph.ensure_existence_backlinks_for_proposition(&jack_alpha)?; 55 | } 56 | proposition_db.store_proposition_boolean(&jack_alpha, p_jack_alpha)?; 57 | plan.maybe_add_to_training(is_training, &jack_alpha)?; 58 | } 59 | for level in 0..LINK_HEIGHT { 60 | let function = format!("beta{}", level); 61 | let jack_beta = proposition(function, vec![sub(jack.clone())]); 62 | if level == 0 { 63 | graph.ensure_existence_backlinks_for_proposition(&jack_beta)?; 64 | } 65 | proposition_db.store_proposition_boolean(&jack_beta, p_jack_beta)?; 66 | plan.maybe_add_to_training(is_training, &jack_beta)?; 67 | } 68 | { 69 | let function = "gamma".to_string(); 70 | let jack_gamma = proposition(function, vec![sub(jack.clone())]); 71 | proposition_db.store_proposition_boolean(&jack_gamma, p_jack_gamma)?; 72 | plan.maybe_add_to_training(is_training, &jack_gamma)?; 73 | plan.maybe_add_to_test(is_test, &jack_gamma)?; 74 | } 75 | } 76 | let xjack = variable(Domain::MAN.to_string()); 77 | let mut implications = vec![]; 78 | let channel_names = ["alpha", "beta"]; 79 | for channel_name in channel_names { 80 | for level in 0..(LINK_HEIGHT - 1) { 81 | let fn1 = format!("{}{}", channel_name, level); 82 | let fn2 = format!("{}{}", channel_name, level + 1); 83 | implications.push(implication( 84 | conjunction(vec![predicate(fn1, vec![sub(xjack.clone())])]), 85 | predicate(fn2, vec![sub(xjack.clone())]), 86 | vec![RoleMap::new(HashMap::from([( 87 | "sub".to_string(), 88 | "sub".to_string(), 89 | )]))], 90 | )); 91 | } 92 | } 93 | implications.push(implication( 94 | conjunction(vec![ 95 | predicate(format!("{}{}", "alpha", LINK_HEIGHT - 1), vec![sub(xjack.clone())]), 96 | predicate(format!("{}{}", "beta", LINK_HEIGHT - 1), vec![sub(xjack.clone())]), 97 | ]), 98 | predicate("gamma".to_string(), vec![sub(xjack.clone())]), 99 | vec![ 100 | RoleMap::new(HashMap::from([( 101 | "sub".to_string(), 102 | "sub".to_string(), 103 | )])), 104 | RoleMap::new(HashMap::from([( 105 | "sub".to_string(), 106 | "sub".to_string(), 107 | )])) 108 | ], 109 | )); 110 |
graph.store_predicate_implications(&implications)?; 111 | Ok(()) 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /rust/src/scenarios/long_chain.rs: -------------------------------------------------------------------------------- 1 | use crate::common::graph::InferenceGraph; 2 | use crate::common::interface::BeliefTable; 3 | use crate::common::model::InferenceModel; 4 | use crate::common::proposition_db::RedisBeliefTable; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, NamespaceBundle}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::choose::extract_existence_factor_for_proposition; 9 | use crate::model::creators::predicate; 10 | use crate::{ 11 | common::interface::ScenarioMaker, 12 | model::{ 13 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 14 | objects::{Domain, Entity, RoleMap}, 15 | }, 16 | }; 17 | use crate::{print_red, print_yellow}; 18 | use rand::Rng; // Import Rng trait 19 | use std::{collections::HashMap, error::Error}; 20 | 21 | use super::helpers::weighted_cointoss; 22 | 23 | pub struct Scenario {} 24 | 25 | const LINK_HEIGHT: u32 = 11; 26 | 27 | impl ScenarioMaker for Scenario { 28 | fn setup_scenario(&self, resources: &NamespaceBundle) -> Result<(), Box<dyn Error>> { 29 | let mut graph = InferenceGraph::new_mutable(resources)?; 30 | let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 31 | let mut plan = TrainingPlan::new(&resources)?; 32 | let config = &resources.config; 33 | let total_members_each_class = config.entities_per_domain; 34 | let domain = Domain::MAN.to_string(); 35 | for i in 0..total_members_each_class { 36 | let is_test = i == 0; 37 | let is_training = !is_test; 38 | let prefix = if is_test { "test" } else { "train" }; 39 | let name = format!("{}_{:?}{}", &prefix, domain, i); 40 | let jack_entity = Entity { 41 | domain: domain.clone(), 42 | name: name.clone(), 43 | }; 44 | graph.store_entity(&jack_entity)?; 45 | 46 | let p_jack_alpha = weighted_cointoss(0.5f64); 47 | for level in 0..LINK_HEIGHT { 48 | let jack = constant(jack_entity.domain.clone(), jack_entity.name.clone()); 49 | let function = format!("alpha{}", level); 50 | let jack_alpha = proposition(function, vec![sub(jack)]); 51 | if level == 0 { 52 | graph.ensure_existence_backlinks_for_proposition(&jack_alpha)?; 53 | } 54 | proposition_db.store_proposition_boolean(&jack_alpha, p_jack_alpha)?; 55 | plan.maybe_add_to_training(is_training, &jack_alpha)?; 56 | 57 | if level == LINK_HEIGHT - 1 { 58 | plan.maybe_add_to_test(is_test, &jack_alpha)?; 59 | } 60 | } 61 | } 62 | let xjack = variable(Domain::MAN.to_string()); 63 | let mut implications = vec![]; 64 | for level in 0..(LINK_HEIGHT - 1) { 65 | let fn1 = format!("alpha{}", level); 66 | let fn2 = format!("alpha{}", level + 1); 67 | implications.push(implication( 68 | conjunction(vec![predicate( 69 | fn1, 70 | vec![sub(xjack.clone())], 71 | )]), 72 | predicate(fn2, vec![sub(xjack.clone())]), 73 | vec![RoleMap::new(HashMap::from([( 74 | "sub".to_string(), 75 | "sub".to_string(), 76 | )]))], 77 | )); 78 | } 79 | graph.store_predicate_implications(&implications)?; 80 | Ok(()) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /rust/src/scenarios/mid_chain.rs: -------------------------------------------------------------------------------- 1 | use crate::common::graph::InferenceGraph; 2 | use crate::common::interface::BeliefTable; 3 | use
crate::common::model::InferenceModel; 4 | use crate::common::proposition_db::RedisBeliefTable; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, NamespaceBundle}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::choose::extract_existence_factor_for_proposition; 9 | use crate::model::creators::predicate; 10 | use crate::{ 11 | common::interface::ScenarioMaker, 12 | model::{ 13 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 14 | objects::{Domain, Entity, RoleMap}, 15 | }, 16 | }; 17 | use crate::{print_red, print_yellow}; 18 | use rand::Rng; // Import Rng trait 19 | use std::{collections::HashMap, error::Error}; 20 | 21 | use super::helpers::weighted_cointoss; 22 | 23 | pub struct Scenario {} 24 | 25 | const LINK_HEIGHT: u32 = 5; 26 | 27 | impl ScenarioMaker for Scenario { 28 | fn setup_scenario(&self, resources: &NamespaceBundle) -> Result<(), Box<dyn Error>> { 29 | let mut graph = InferenceGraph::new_mutable(resources)?; 30 | let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 31 | let mut plan = TrainingPlan::new(&resources)?; 32 | let config = &resources.config; 33 | let total_members_each_class = config.entities_per_domain; 34 | let domain = Domain::MAN.to_string(); 35 | for i in 0..total_members_each_class { 36 | let is_test = i == 0; 37 | let is_training = !is_test; 38 | let prefix = if is_test { "test" } else { "train" }; 39 | let name = format!("{}_{:?}{}", &prefix, domain, i); 40 | let jack_entity = Entity { 41 | domain: domain.clone(), 42 | name: name.clone(), 43 | }; 44 | graph.store_entity(&jack_entity)?; 45 | 46 | let p_jack_alpha = weighted_cointoss(0.3f64); 47 | for level in 0..LINK_HEIGHT { 48 | let jack = constant(jack_entity.domain.clone(), jack_entity.name.clone()); 49 | let function = format!("alpha{}", level); 50 | let jack_alpha = proposition(function, vec![sub(jack)]); 51 | if level == 0 { 52 | graph.ensure_existence_backlinks_for_proposition(&jack_alpha)?; 53 | } 54 | proposition_db.store_proposition_boolean(&jack_alpha, p_jack_alpha)?; 55 | plan.maybe_add_to_training(is_training, &jack_alpha)?; 56 | 57 | if level == LINK_HEIGHT - 1 { 58 | plan.maybe_add_to_test(is_test, &jack_alpha)?; 59 | } 60 | } 61 | } 62 | let xjack = variable(Domain::MAN.to_string()); 63 | let mut implications = vec![]; 64 | for level in 0..(LINK_HEIGHT - 1) { 65 | let fn1 = format!("alpha{}", level); 66 | let fn2 = format!("alpha{}", level + 1); 67 | implications.push(implication( 68 | conjunction(vec![predicate( 69 | fn1, 70 | vec![sub(xjack.clone())], 71 | )]), 72 | predicate(fn2, vec![sub(xjack.clone())]), 73 | vec![RoleMap::new(HashMap::from([( 74 | "sub".to_string(), 75 | "sub".to_string(), 76 | )]))], 77 | )); 78 | } 79 | graph.store_predicate_implications(&implications)?; 80 | Ok(()) 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /rust/src/scenarios/mod.rs: -------------------------------------------------------------------------------- 1 | pub mod factory; 2 | // scenarios 3 | pub mod dating_simple; 4 | // pub mod dating_triangle; 5 | pub mod one_var; 6 | // pub mod two_var; 7 | // pub mod helpers; 8 | // pub mod long_chain; 9 | // pub mod long_and; 10 | // pub mod mid_chain; -------------------------------------------------------------------------------- /rust/src/scenarios/one_var.rs: -------------------------------------------------------------------------------- 1 | use crate::common::graph::InferenceGraph; 2 | use
crate::common::interface::BeliefTable; 3 | use crate::common::model::InferenceModel; 4 | use crate::common::proposition_db::RedisBeliefTable; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, ResourceContext}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::choose::extract_existence_factor_for_proposition; 9 | use crate::model::creators::{predicate, relation, variable_argument}; 10 | use crate::{ 11 | common::interface::ScenarioMaker, 12 | model::{ 13 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 14 | objects::{Domain, Entity, RoleMap}, 15 | }, 16 | }; 17 | use crate::{print_red, print_yellow}; 18 | use rand::Rng; // Import Rng trait 19 | use std::{collections::HashMap, error::Error}; 20 | fn cointoss() -> f64 { 21 | let mut rng = rand::thread_rng(); // Get a random number generator 22 | if rng.gen::<f64>() < 0.5 { 23 | 1.0 24 | } else { 25 | 0.0 26 | } 27 | } 28 | 29 | fn weighted_cointoss(threshold: f64) -> f64 { 30 | let mut rng = rand::thread_rng(); // Get a random number generator 31 | if rng.gen::<f64>() < threshold { 32 | 1.0 33 | } else { 34 | 0.0 35 | } 36 | } 37 | 38 | pub struct OneVariable {} 39 | 40 | impl ScenarioMaker for OneVariable { 41 | fn setup_scenario(&self, resources: &ResourceContext) -> Result<(), Box<dyn Error>> { 42 | // let mut graph = InferenceGraph::new_mutable(resources.connection.clone(), resources.namespace.clone())?; 43 | // let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 44 | // let mut plan = TrainingPlan::new(&resources)?; 45 | // let total_members_each_class = 1024; 46 | // let jack_domain = Domain::MAN.to_string(); 47 | // graph.register_domain(&jack_domain)?; 48 | // let jack_relation = relation( 49 | // "exciting".to_string(), 50 | // vec![variable_argument(jack_domain.clone())], 51 | // ); 52 | // graph.register_relation(&jack_relation)?; 53 | // for i in 0..total_members_each_class { 54 | // let is_test = i % 10 == 9; 55 | // let is_training = !is_test; 56 | // let domain = Domain::MAN.to_string(); 57 | // let prefix = if is_test { "test" } else { "train" }; 58 | // let name = format!("{}_{:?}{}", &prefix, domain, i); 59 | // let jack_entity = Entity { 60 | // domain: domain.clone(), 61 | // name: name.clone(), 62 | // }; 63 | // graph.store_entity(&jack_entity)?; 64 | // let p_jack_exciting = weighted_cointoss(0.3f64); 65 | // { 66 | // let jack = constant(jack_entity.domain, jack_entity.name.clone()); 67 | // let jack_exciting = proposition(jack_relation.clone(), vec![sub(jack)]); 68 | // graph.ensure_existence_backlinks_for_proposition(&jack_exciting)?; 69 | // proposition_db.store_proposition_probability(&jack_exciting, p_jack_exciting)?; 70 | // plan.maybe_add_to_training(is_training, &jack_exciting)?; 71 | // plan.maybe_add_to_test(is_test, &jack_exciting)?; 72 | // } 73 | // } 74 | // Ok(()) 75 | panic!("one_var scenario is currently disabled; its body is commented out above") 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /rust/src/scenarios/two_var.rs: -------------------------------------------------------------------------------- 1 | use crate::common::proposition_db::RedisBeliefTable; 2 | use crate::common::graph::InferenceGraph; 3 | use crate::common::interface::BeliefTable; 4 | use crate::common::model::InferenceModel; 5 | use crate::common::redis::RedisManager; 6 | use crate::common::resources::{self, NamespaceBundle}; 7 | use crate::common::train::TrainingPlan; 8 | use crate::model::choose::extract_existence_factor_for_proposition; 9 | use
crate::model::creators::predicate; 10 | use crate::{print_red, print_yellow}; 11 | use crate::{ 12 | common::interface::ScenarioMaker, 13 | model::{ 14 | creators::{conjunction, constant, implication, obj, proposition, sub, variable}, 15 | objects::{Domain, Entity, RoleMap}, 16 | }, 17 | }; 18 | use rand::Rng; // Import Rng trait 19 | use std::{collections::HashMap, error::Error}; 20 | 21 | use super::helpers::weighted_cointoss; 22 | 23 | pub struct TwoVariable {} 24 | 25 | impl ScenarioMaker for TwoVariable { 26 | fn setup_scenario( 27 | &self, 28 | resources: &NamespaceBundle, 29 | ) -> Result<(), Box<dyn Error>> { 30 | let mut graph = InferenceGraph::new_mutable(resources)?; 31 | let proposition_db = RedisBeliefTable::new_mutable(&resources)?; 32 | let mut plan = TrainingPlan::new(&resources)?; 33 | let config = &resources.config; 34 | let total_members_each_class = config.entities_per_domain; 35 | let jack_domain = Domain::MAN.to_string(); 36 | let jacks: Vec<Entity> = graph.get_entities_in_domain(&jack_domain)?; 37 | let mut propositions = vec![]; 38 | for i in 0..total_members_each_class { 39 | let is_test = i == 0; 40 | let is_training = !is_test; 41 | let mut domain_entity_map: HashMap<String, Entity> = HashMap::new(); 42 | for domain in [Domain::MAN.to_string()].iter() { 43 | let prefix = if is_test { "test" } else { "train" }; 44 | let name = format!("{}_{:?}{}", &prefix, domain, i); 45 | let entity = Entity { 46 | domain: domain.clone(), 47 | name: name.clone(), 48 | }; 49 | graph.store_entity(&entity)?; 50 | domain_entity_map.insert(domain.to_string(), entity); 51 | } 52 | let jack_entity = &domain_entity_map[&Domain::MAN.to_string()]; 53 | let p_jack_exciting = weighted_cointoss(0.3f64); 54 | { 55 | let jack = constant(jack_entity.domain.clone(), jack_entity.name.clone()); 56 | let jack_exciting = proposition("exciting".to_string(), vec![sub(jack)]); 57 | graph.ensure_existence_backlinks_for_proposition(&jack_exciting)?; 58 | proposition_db.store_proposition_boolean(&jack_exciting, p_jack_exciting)?; 59 | plan.maybe_add_to_training(is_training, &jack_exciting)?; 60 | propositions.push(jack_exciting.clone()); 61 | } 62 | { 63 | let jack = constant(jack_entity.domain.clone(), jack_entity.name.clone()); 64 | let jack_rich = proposition("rich".to_string(), vec![sub(jack)]); 65 | graph.ensure_existence_backlinks_for_proposition(&jack_rich)?; 66 | proposition_db.store_proposition_boolean(&jack_rich, p_jack_exciting)?; // "rich" reuses the "exciting" outcome, so the two propositions always agree 67 | plan.maybe_add_to_training(is_training, &jack_rich)?; 68 | propositions.push(jack_rich.clone()); 69 | plan.maybe_add_to_test(is_test, &jack_rich)?; 70 | } 71 | } 72 | let xjack = variable(Domain::MAN.to_string()); 73 | let implications = vec![ 74 | implication( 75 | conjunction(vec![predicate("exciting".to_string(), vec![ 76 | sub(xjack.clone()), 77 | ])]), 78 | predicate("rich".to_string(), 79 | vec![ 80 | sub(xjack.clone()), 81 | ]), 82 | vec![RoleMap::new(HashMap::from([( 83 | "sub".to_string(), 84 | "sub".to_string(), 85 | )]))], 86 | ), 87 | ]; 88 | for implication in implications.iter() { 89 | trace!("Storing implication: {:?}", implication); 90 | graph.store_predicate_implication(implication)?; 91 | } 92 | Ok(()) 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /rust/static/images/domains/Man.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:3fbabdedb8b48d9d24c50ca3ce5e715136e2ac996034956fd5a8e61881824c57 3 | size 28626 4 |
-------------------------------------------------------------------------------- /rust/static/images/domains/Woman.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:39130b0713dbc7e22e8e10a607e521edeb5471699421c40c6381e15b99ad5a06 3 | size 78148 4 | -------------------------------------------------------------------------------- /rust/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Assign variable names for better readability 4 | SCENARIO_NAME=$1 5 | 6 | # Check if the scenario name is provided 7 | if [ -z "$SCENARIO_NAME" ] ; then 8 | echo "usage: ./train.sh <scenario_name>" 9 | exit 1 10 | fi 11 | 12 | # Run the trainer with backtraces and info-level logging. 13 | RUST_BACKTRACE=1 RUST_LOG=info cargo run --bin train -- --print_training_loss --entities_per_domain=4096 --scenario_name="$SCENARIO_NAME" 14 | -------------------------------------------------------------------------------- /rust/unittests.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Check if an argument (test name) is provided 4 | if [ "$#" -eq 1 ]; then 5 | # Run only the specified test 6 | cargo test "$1" -- --test-threads=1 --nocapture 7 | else 8 | # Run all tests 9 | cargo test -- --test-threads=1 --nocapture 10 | fi 11 | 12 | --------------------------------------------------------------------------------