├── .gitignore ├── CHANGELOG ├── README.md ├── docs ├── hypothesis_space.md └── simple_learning_system.py ├── loreleai ├── __init__.py ├── language │ ├── __init__.py │ ├── datalog │ │ └── __init__.py │ ├── kanren │ │ ├── __init__.py │ │ └── kanren_utils.py │ ├── lp │ │ └── __init__.py │ └── utils.py ├── learning │ ├── __init__.py │ ├── abstract_learners.py │ ├── eval_functions.py │ ├── hypothesis_space.py │ ├── language_filtering.py │ ├── language_manipulation.py │ ├── learners │ │ ├── __init__.py │ │ ├── aleph.py │ │ └── breadth_first_learner.py │ ├── task.py │ └── utilities.py └── reasoning │ ├── __init__.py │ └── lp │ ├── __init__.py │ ├── datalog │ └── __init__.py │ ├── kanren │ └── __init__.py │ └── prolog │ └── __init__.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py ├── aleph_tests.py ├── common_tasks.py ├── eval_functions_tests.py ├── language_test.py ├── solver_datalog_test.py ├── solver_kanren_test.py ├── solver_prolog_gnu_test.py ├── solver_prolog_swipy_test.py └── solver_prolog_xsb_test.py /.gitignore: -------------------------------------------------------------------------------- 1 | venv/ 2 | \.pytest_cache/ 3 | nvenv/ 4 | z3-z3-4.8.7/ 5 | *.pytest_cache/ 6 | 7 | # Python recommended .gitignore at e931ef7 8 | # https://github.com/github/gitignore/blob/e931ef7f3e7d8f7aa0e784c14bd291ad4448b1ab/Python.gitignore 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow
107 | __pypackages__/
108 | 
109 | # Celery stuff
110 | celerybeat-schedule
111 | celerybeat.pid
112 | 
113 | # SageMath parsed files
114 | *.sage.py
115 | 
116 | # Environments
117 | .env
118 | .venv
119 | env/
120 | venv/
121 | ENV/
122 | env.bak/
123 | venv.bak/
124 | 
125 | # Spyder project settings
126 | .spyderproject
127 | .spyproject
128 | 
129 | # Rope project settings
130 | .ropeproject
131 | 
132 | # mkdocs documentation
133 | /site
134 | 
135 | # mypy
136 | .mypy_cache/
137 | .dmypy.json
138 | dmypy.json
139 | 
140 | # Pyre type checker
141 | .pyre/
142 | 
143 | # pytype static type analyzer
144 | .pytype/
145 | 
146 | # Cython debug symbols
147 | cython_debug/
148 | 
149 | # static files generated from Django application using `collectstatic`
150 | media
151 | static
152 | 
-------------------------------------------------------------------------------- /CHANGELOG: --------------------------------------------------------------------------------
1 | # 0.1.4
2 | - added cache support in hypothesis space
3 | - hypothesis space returns recursions with Predicate as head_constructor
4 | - fixed recursion generation
5 | - when accessing recursion in hypothesis space, the partner pointer is always reset to the root (so that get_successors works)
6 | 
7 | # 0.1.3
8 | - added pair from pylo
9 | 
10 | # 0.1.2
11 | - added bottom clause based refinement
12 | 
13 | # 0.1.1
14 | - updated to Pylo 0.3.2
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # loreleai
2 | `loreleai` aims to be Keras for **Lo**gical **re**asoning and **le**arning in AI.
3 | It provides a unified language for expressing logical theories and connects it to various backends (Prolog, Answer Set Programming, Datalog, ...) to reason with the provided theories.
4 | 
5 | **THIS IS STILL WORK IN PROGRESS, EXPECT CHANGES!**
6 | 
7 | # Installation
8 | 
9 | `loreleai` depends on [pylo](https://github.com/sebdumancic/pylo2) to interface with Prolog engines.
10 | Follow the instructions to install `pylo` [here](https://github.com/sebdumancic/pylo2).
11 | 
12 | If you will be using a Datalog engine, follow the instructions to install [z3](https://github.com/Z3Prover/z3).
13 | 
14 | Then clone this repository and run
15 | ```shell script
16 | pip install .
17 | ```
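
To check that the installation works, you can try a minimal example (a sketch only; it uses nothing beyond the `c_pred` constructor shown in the Quick start below):
```python
from loreleai.language.lp import c_pred

parent = c_pred("parent", 2)  # a binary predicate
print(parent("a", "b"))       # should print the atom parent(a,b)
```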
18 | 
19 | 
20 | 
21 | # Quick start
22 | `loreleai` allows you to easily specify your knowledge and ask queries about it
23 | 
24 | **More details:** for more details on the usage of logic and Prolog engines, check the instructions for [pylo](https://github.com/sebdumancic/pylo2)
25 | 
26 | ```python
27 | from loreleai.language.lp import c_const, c_var, c_pred
28 | from loreleai.reasoning.lp.datalog import MuZ
29 | 
30 | p1 = c_const("p1")              # create a constant with the name 'p1'
31 | p2 = c_const("p2")
32 | p3 = c_const("p3")
33 | 
34 | parent = c_pred("parent", 2)    # create a predicate/relation 'parent'
35 | grandparent = c_pred("grandparent", 2)
36 | 
37 | f1 = parent(p1, p2)             # create the fact 'parent(p1, p2)'
38 | f2 = parent(p2, p3)
39 | 
40 | V1 = c_var("X")                 # create a variable named 'X'
41 | V2 = c_var("Y")
42 | V3 = c_var("Z")
43 | 
44 | # create a clause defining the grandparent relation
45 | cl = (grandparent(V1, V3) <= parent(V1, V2) & parent(V2, V3))
46 | 
47 | solver = MuZ()                  # instantiate the solver
48 | # (muZ is Z3's datalog engine)
49 | solver.assert_fact(f1)          # assert a fact
50 | solver.assert_fact(f2)
51 | solver.assert_rule(cl)          # assert a rule
52 | 
53 | solver.has_solution(grandparent(p1, p3))            # ask whether there is a solution to a query
54 | solver.query(parent(V1, V2))                        # ask for all solutions
55 | solver.query(grandparent(p1, V1), max_solutions=1)  # ask for a single solution
56 | ```
57 | 
58 | Alternatively, `loreleai` provides shortcuts for defining facts
59 | ```python
60 | from loreleai.language.lp import c_pred
61 | 
62 | parent = c_pred("parent", 2)    # create a predicate/relation 'parent'
63 | grandparent = c_pred("grandparent", 2)
64 | 
65 | f1 = parent("p1", "p2")         # 'p1' and 'p2' are automatically parsed into a Constant
66 | f2 = parent("p2", "p3")
67 | 
68 | query_literal = grandparent("p1", "X")  # 'X' is automatically parsed into a Variable
69 | ```
70 | 
71 | Reasoning engines are located in `loreleai.reasoning.lp`, followed by the specific type of logic programming:
72 | - `prolog`: supported Prolog engines:
73 |   - `SWIProlog` for SWI Prolog
74 |   - `GNUProlog` for GNU Prolog
75 |   - `XSBProlog` for XSB Prolog
76 | - `datalog`: supported Datalog engines:
77 |   - `MuZ` for Z3's datalog engine
78 | - `kanren`: for relational programming
79 |   - `MiniKanren`
80 | 
81 | 
82 | # Supported reasoning engines
83 | 
84 | ### Prolog
85 | 
86 | Currently supported (via [pylo](https://github.com/sebdumancic/pylo2)):
87 | - [SWI Prolog](https://www.swi-prolog.org/)
88 | - [XSB Prolog](http://xsb.sourceforge.net/)
89 | - [GNU Prolog](http://www.gprolog.org/)
90 | 
91 | 
92 | ### Relational programming
93 | Prolog without side-effects (cut and so on)
94 | 
95 | Currently supported:
96 | - [miniKanren](https://github.com/pythological/kanren); seems to be actively maintained
97 | 
98 | 
99 | ### Datalog
100 | A subset of Prolog without functors/structures
101 | 
102 | Currently supported:
103 | - [muZ (Z3's datalog engine)](http://www.cs.tau.ac.il/~msagiv/courses/asv/z3py/fixedpoints-examples.htm)
104 | 
105 | Considering:
106 | - [pyDatalog](https://sites.google.com/site/pydatalog/home)
107 | 
108 | ### Deductive databases
109 | 
110 | Currently supported:
111 | - none yet
112 | 
113 | Considering:
114 | - [Grakn](https://grakn.ai/)
115 | 
116 | 
117 | ### Answer set programming
118 | 
119 | Currently supported:
120 | - none yet
121 | 
122 | Considering:
123 | - [asprin](https://github.com/potassco/asprin)
124 | - [clorm](https://github.com/potassco/clorm)
125 | - [asp-lite](https://github.com/lorenzleutgeb/asp-lite)
126 | 
- [hexlite](https://github.com/hexhex/hexlite)
127 | - [clyngor](https://github.com/aluriak/clyngor)
128 | 
129 | 
130 | # Roadmap
131 | 
132 | ### First direction: reasoning engines
133 | 
134 | - [x] integrate one solver for each of the representative categories
135 | - [ ] add support for external predicates (functionality specified in Python)
136 | - [x] SWI Prolog wrapper
137 | - [ ] include probabilistic engines (Problog, PSL, MLNs)
138 | - [ ] add parsers for each dialect
139 | - [ ] different ways of loading data (input language, CSV, ...)
140 | 
141 | 
142 | 
143 | ### Second direction: learning primitives
144 | 
145 | - add learning primitives such as search, hypothesis space generation
146 | - wrap state of the art learners (ACE, Metagol, Aleph)
147 | 
148 | 
149 | # Code structure
150 | 
151 | The *language* constructs are in the `loreleai/language` folder.
152 | There is a folder for each dialect of first-order logic.
153 | Currently there are _logic programming_ (`loreleai/language/lp`) and _relational programming_ (`loreleai/language/kanren`).
154 | The implementations of all shared concepts are in `loreleai/language/commons.py` and the idea is to use `__init__.py` files to provide the allowed constructs for each dialect.
155 | 
156 | 
157 | The *reasoning* constructs are in the `loreleai/reasoning` folder.
158 | The structure is the same as with language.
159 | Different dialects of logic programming are in the folder `loreleai/reasoning/lp`.
160 | 
161 | 
162 | The *learning* primitives are supposed to be in the `loreleai/learning` folder.
163 | 
164 | 
165 | 
166 | # Requirements
167 | 
168 | - pyswip
169 | - problog
170 | - ortools
171 | - minikanren
172 | - z3-solver
173 | - black
174 | 
175 | # Notes for different engines
176 | 
177 | ## SWI Prolog
178 | To use SWI Prolog, check the install instructions: https://github.com/yuce/pyswip/blob/master/INSTALL.md
179 | 
180 | ## Z3
181 | 
182 | Z3Py scripts stored in arbitrary directories can be executed if the 'build/python' directory is added to the PYTHONPATH environment variable and the 'build' directory is added to the DYLD_LIBRARY_PATH environment variable.
183 | 
-------------------------------------------------------------------------------- /docs/hypothesis_space.md: --------------------------------------------------------------------------------
1 | The central design principle in `loreleai` is to explicitly represent the hypothesis space (the space of all programs) and allow you to manipulate it.
2 | The hypothesis spaces are located in `loreleai.learning.hypothesis_space`; the currently implemented one is `TopDownHypothesisSpace`.
3 | This document gives a brief introduction to how to use them.
4 | 
5 | # Top down hypothesis space
6 | 
7 | This is a hypothesis space in which the programs are constructed from the simplest (shortest) to more complicated ones.
8 | 
9 | ## Constructing the hypothesis space
10 | 
11 | To create the hypothesis space, you need to provide the following ingredients:
12 | - **primitives:** functions that extend a given clause (otherwise known as refinement operators)
13 | - **head constructor:** instructions on how to construct the heads of the clauses
14 | - **expansion hooks:** functions used to eliminate useless extensions
15 | 
16 | 
17 | Below is a complete example; the rest of this document explains it part by part.
18 | ```python
19 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace
20 | from loreleai.learning.language_filtering import connected_clause, has_singleton_vars
21 | from loreleai.learning.language_manipulation import plain_extension
22 | from loreleai.language.lp import c_pred
23 | 
24 | grandparent = c_pred("grandparent", 2)
25 | father = c_pred("father", 2)
26 | mother = c_pred("mother", 2)
27 | 
28 | hs = TopDownHypothesisSpace(primitives=[lambda x: plain_extension(x, father, connected_clauses=False),
29 |                                         lambda x: plain_extension(x, mother, connected_clauses=False)],
30 |                             head_constructor=grandparent,
31 |                             expansion_hooks_keep=[lambda x, y: connected_clause(x, y)],
32 |                             expansion_hooks_reject=[lambda x, y: has_singleton_vars(x, y)]
33 |                             )
34 | ```
35 | 
36 | ### Primitives
37 | 
38 | To construct the hypothesis space, you need to provide *primitives* -- functions that take a clause or a body and refine it (add another literal to it).
39 | They are accepted as a list of functions that can be called with a single argument of type `Clause`, `Body` or `Procedure`.
40 | The function should return a list of extensions of the clause and be type consistent: if an object of the type `Body` is provided, then all extensions should also be of the type `Body`, and so on.
41 | `loreleai` provides some primitives, but you can also provide your own functions.
42 | 
43 | The best way to see how primitive functions should be implemented is to check the implementations `loreleai` provides in `loreleai.learning.language_manipulation`.
44 | Currently, only one such function is implemented: `plain_extension` extends the given body/clause/procedure by adding all possible literals with the provided predicate.
45 | It has the following signature:
46 | ```python
47 | def plain_extension(
48 |     clause: typing.Union[Clause, Body, Procedure],
49 |     predicate: Predicate,
50 |     connected_clauses: bool = True,
51 |     negated: bool = False,
52 | ) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]
53 | ```
54 | where:
55 | - `clause` is the body/clause/procedure to extend
56 | - `predicate` is the predicate to add to the body
57 | - `connected_clauses` is a flag indicating that extensions should result in connected clauses (no disjoint sets of variables)
58 | - `negated` is a flag indicating that the added literals should be negations
59 | 
60 | 
61 | ### Head constructor
62 | 
63 | The head constructor provides instructions on how to construct the heads of clauses.
64 | There are two options:
65 | - provide the exact predicate to be put in the head
66 | - provide a `FillerPredicate` (located in `loreleai.learning.utilities`): this means that every unique body will get a new predicate (this is useful for predicate invention)
67 | 
68 | `FillerPredicate` takes several arguments:
69 | - `prefix` specifies the name of the invented predicates (additionally suffixed by their index)
70 | - `arity` (optional): arity of the invented predicates
71 | - `min_arity` (optional): minimal arity of the invented predicates
72 | - `max_arity` (optional): maximal arity of the invented predicates
73 | 
74 | `FillerPredicate` can be configured in two ways, by:
75 | - specifying the `arity`: this introduces predicates of fixed arity
76 | - specifying `min_arity` and `max_arity`: this introduces predicates with arity between `min_arity` and `max_arity`
77 | 
78 | You have to specify exactly one of these options; a sketch of both follows.
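
For illustration, here is a minimal sketch of both configurations (the keyword names follow the argument list above; passing the prefix positionally is an assumption):
```python
from loreleai.learning.utilities import FillerPredicate

# invented predicates of fixed arity: e.g., invented_1/2, invented_2/2, ...
filler_fixed = FillerPredicate("invented", arity=2)

# invented predicates with arities ranging from 1 to 3
filler_ranged = FillerPredicate("invented", min_arity=1, max_arity=3)
```
Either object can then be passed as the `head_constructor` argument of `TopDownHypothesisSpace` instead of a concrete predicate.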
79 | 
80 | 
81 | ### Expansion hooks
82 | 
83 | Depending on the primitive functions used, many of the generated clauses will be useless (for example, clauses that are never true in the data).
84 | Expansion hooks can be used to eliminate such clauses: every time a hypothesis is expanded, the 'child' hypotheses are checked with the expansion hooks; if they 'fail' the test, they are eliminated (removed).
85 | 
86 | Expansion hooks are provided as a list of functions that can be called with a clause as an input.
87 | The functions need to have the following signature:
88 | ```python
89 | function_name(head: Atom, body: Body) -> bool
90 | ```
91 | where:
92 | - `head` is the head of a clause
93 | - `body` is the body of a clause
94 | 
95 | Expansion hooks come in two flavours:
96 | - `expansion_hooks_keep`: keeps all expanded hypotheses for which the expansion functions return `True`
97 | - `expansion_hooks_reject`: rejects all expanded hypotheses for which the expansion functions return `True`
98 | 
99 | `loreleai` implements several of these functions in `loreleai.learning.language_filtering` (a sketch of a custom hook follows the list):
100 | - `has_singleton_vars(head, body)`: returns `True` if a clause has a singleton variable (a variable that appears only once)
101 | - `max_var(head, body, max_count)`: returns `True` if the number of variables in the clause is less than or equal to `max_count`
102 | - `connected_body(head, body)`: returns `True` if the body of the clause is connected (variables cannot be partitioned into disjoint sets)
103 | - `connected_clause(head, body)`: returns `True` if the entire clause is connected
104 | - `negation_at_the_end(head, body)`: returns `True` if negative literals appear after positive literals
105 | - `max_pred_occurrences(head, body, pred, max_occurrence)`: returns `True` if predicate `pred` appears at most `max_occurrence` times in the body of the clause
106 | - `has_duplicated_literal(head, body)`: returns `True` if there are duplicated literals in the body
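
As an illustration, a custom hook is just a function with the signature above; this sketch keeps only clauses with at most three body literals (`get_literals` is the accessor used throughout `loreleai`):
```python
from loreleai.language.lp import Atom, Body

def at_most_three_literals(head: Atom, body: Body) -> bool:
    # keep the clause only if its body has at most three literals
    return len(body.get_literals()) <= 3
```
It would be passed to the constructor as `expansion_hooks_keep=[at_most_three_literals]`.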
107 | 
108 | 
109 | ### Other options
110 | - `recursive_procedures`: if set to `True`, the hypothesis space will also enumerate recursions
111 | 
112 | 
113 | ## Using the hypothesis space
114 | 
115 | The hypothesis space objects offer several methods to manipulate the hypothesis space.
116 | 
117 | **An important thing to understand** is that the unit component in the TopDownHypothesisSpace is `Body`, i.e., the body of the clause.
118 | If the hypothesis space is viewed as a graph, the nodes are bodies and possible heads are kept as 'attributes' of the corresponding node.
119 | This is important to keep in mind because the operations that retrieve candidates from the hypothesis space (e.g., `expand()`) return all candidates with the same body.
120 | Likewise, other operations that manipulate the hypothesis space can affect a single clause as well as all clauses with the same body.
121 | 
122 | 
123 | The `.expand(clause)` method expands/refines the given clause with all provided primitive functions.
124 | The clause can be either a `Body`, `Clause` or `Procedure` (specifically, `Recursion`).
125 | The method returns all possible expansions of the given clause (not a single one).
126 | 
127 | 
128 | The `.block(clause)` method blocks the expansion of the given clause, but keeps the clause itself.
129 | The clause can be either a `Body`, `Clause` or `Procedure` (specifically, `Recursion`).
130 | Every clause with the same body gets blocked.
131 | 
132 | The `.ignore(clause)` method ignores the clause in the hypothesis space: the clause can be further refined, but will not be returned as a viable candidate.
133 | The clause can be either a `Body`, `Clause` or `Procedure` (specifically, `Recursion`).
134 | If a `Clause` is provided, only that specific clause is ignored.
135 | If a `Body` is provided, every clause with that body is ignored.
136 | 
137 | 
138 | The `.remove(clause, remove_entire_body)` method removes the clause from the hypothesis space.
139 | The clause can be of type `Clause`, `Body` or `Procedure`.
140 | If a `Clause` is provided, that specific clause is removed.
141 | If a `Body` is provided, every clause with that body is removed.
142 | If `remove_entire_body` is set to `True`, one can provide a specific clause but all clauses with the same body will be removed.
143 | 
144 | The `.get_successors_of(clause)` method returns all successors of the clause.
145 | The clause can be either a `Clause` or a `Body`.
146 | The method returns all clauses that are obtained by extending/refining the body of the given clause.
147 | 
148 | The `.get_predecessor_of(clause)` method returns all predecessors of the clause.
149 | The clause can be either a `Clause` or a `Body`.
150 | 
151 | 
152 | 
153 | The hypothesis space contains individual clauses as the basic units.
154 | More complex programs (disjunctions and recursions) can be constructed by combining individual clauses.
155 | Hypothesis space objects in `loreleai` allow you to achieve this through the usage of *pointers*.
156 | A pointer simply holds a position in the hypothesis space.
157 | Multiple pointers can be created and all of them can be moved independently.
158 | Upon construction of the hypothesis space, the `main` pointer is created and assigned to the root node (the empty clause).
159 | 
160 | The `.register_pointer(name, init_value)` method registers a new pointer under the name `name`.
161 | If the initial value/position of the pointer `init_value` is not provided, it is set to the root node.
162 | `init_value` can be either a `Clause` or a `Body`; if a `Clause` is provided, it is automatically converted to a `Body`.
163 | 
164 | The `.get_current_candidate(pointer_name)` method returns all clauses that can be constructed from the body to which the pointer `pointer_name` is currently assigned.
165 | If `pointer_name` is not specified, it is assumed that the `main` pointer is in question.
166 | The initial position of any pointer is at the root of the TopDownHypothesisSpace.
167 | 
168 | The `.move_pointer_to(clause, pointer_name)` method moves the pointer `pointer_name` to the body of `clause`.
169 | If `pointer_name` is not specified, it is assumed to be the `main` one.
170 | `clause` can be either a `Clause` or a `Body`; if a `Clause` is provided, it is automatically converted to a `Body`.
171 | 
172 | The `.reset_pointer(pointer_name, init_val)` method resets the `pointer_name` pointer to the root, or to `init_val` if provided.
173 | `init_val` can be either a `Clause` or a `Body`; if a `Clause` is provided, it is automatically converted to a `Body`.
174 | 
175 | 
176 | Recursions are realised via pointers, if enabled (`recursive_procedures=True`).
177 | Every time an extension/refinement operation results in a recursive clause, a new pointer is created and associated with the hypothesis (every constructed recursive clause is blocked from further expansion).
178 | When the `.get_current_candidate` method requests a recursive candidate, the associated pointer traverses the entire hypothesis space in search of valid base cases.
179 | Then it returns all valid recursions.
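
To make the pointer API above concrete, here is a short walk-through (a sketch only; it reuses the `hs` object from the construction example at the top of this document):
```python
# candidates at the root, retrieved through the default 'main' pointer
initial = hs.get_current_candidate()

# refine one of them with all primitives and inspect its successors
hs.expand(initial[0])
successors = hs.get_successors_of(initial[0])

# a second, independent pointer
hs.register_pointer("probe")                # starts at the root by default
hs.move_pointer_to(successors[0], "probe")  # descend to a successor's body
print(hs.get_current_candidate("probe"))    # all clauses with that body
hs.reset_pointer("probe")                   # back to the root
```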
180 | 
181 | 
182 | 
183 | # Example: a simple learning system
184 | 
185 | The file `simple_learning_system.py` illustrates how to build a simple enumerative learner.
186 | 
187 | 
-------------------------------------------------------------------------------- /docs/simple_learning_system.py: --------------------------------------------------------------------------------
1 | import typing
2 | from abc import ABC, abstractmethod
3 | 
4 | from orderedset import OrderedSet
5 | 
6 | from loreleai.language.lp import c_pred, Clause, Procedure, Atom
7 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace
8 | from loreleai.learning.language_filtering import has_singleton_vars, has_duplicated_literal
9 | from loreleai.learning.language_manipulation import plain_extension
10 | from loreleai.learning.task import Task, Knowledge
11 | from loreleai.reasoning.lp.prolog import SWIProlog, Prolog
12 | 
13 | """
14 | This is an abstract learner class that defines a learner with configurable options.
15 | 
16 | It follows a very simple learning principle: it iteratively
17 | - searches for a single clause that covers most positive examples
18 | - adds it to the program
19 | - removes covered examples
20 | 
21 | It is implemented as a template learner - you still need to provide the following methods:
22 | - initialisation of the candidate pool (the data structure that keeps all candidates)
23 | - getting a single candidate from the candidate pool
24 | - adding candidate(s) to the pool
25 | - evaluating a single candidate
26 | - stopping the search for a single clause
27 | - processing expansion/refinements of clauses
28 | 
29 | The learner does not handle recursions correctly!
30 | """ 31 | class TemplateLearner(ABC): 32 | 33 | def __init__(self, solver_instance: Prolog): 34 | self._solver = solver_instance 35 | self._candidate_pool = [] 36 | 37 | def _assert_knowledge(self, knowledge: Knowledge): 38 | """ 39 | Assert knowledge into Prolog engine 40 | """ 41 | facts = knowledge.get_atoms() 42 | for f_ind in range(len(facts)): 43 | self._solver.assertz(facts[f_ind]) 44 | 45 | clauses = knowledge.get_clauses() 46 | for cl_ind in range(len(clauses)): 47 | self._solver.assertz(clauses[cl_ind]) 48 | 49 | def _execute_program(self, clause: Clause) -> typing.Sequence[Atom]: 50 | """ 51 | Evaluates a clause using the Prolog engine and background knowledge 52 | 53 | Returns a set of atoms that the clause covers 54 | """ 55 | if len(clause.get_body().get_literals()) == 0: 56 | return [] 57 | else: 58 | head_predicate = clause.get_head().get_predicate() 59 | head_variables = clause.get_head_variables() 60 | 61 | sols = self._solver.query(*clause.get_body().get_literals()) 62 | 63 | sols = [head_predicate(*[s[v] for v in head_variables]) for s in sols] 64 | 65 | return sols 66 | 67 | @abstractmethod 68 | def initialise_pool(self): 69 | """ 70 | Creates an empty pool of candidates 71 | """ 72 | raise NotImplementedError() 73 | 74 | @abstractmethod 75 | def get_from_pool(self) -> Clause: 76 | """ 77 | Gets a single clause from the pool 78 | """ 79 | raise NotImplementedError() 80 | 81 | @abstractmethod 82 | def put_into_pool(self, candidates: typing.Union[Clause, Procedure, typing.Sequence]) -> None: 83 | """ 84 | Inserts a clause/a set of clauses into the pool 85 | """ 86 | raise NotImplementedError() 87 | 88 | @abstractmethod 89 | def evaluate(self, examples: Task, clause: Clause) -> typing.Union[int, float]: 90 | """ 91 | Evaluates a clause of a task 92 | 93 | Returns a number (the higher the better) 94 | """ 95 | raise NotImplementedError() 96 | 97 | @abstractmethod 98 | def stop_inner_search(self, eval: typing.Union[int, float], examples: Task, clause: Clause) -> bool: 99 | """ 100 | Returns true if the search for a single clause should be stopped 101 | """ 102 | raise NotImplementedError() 103 | 104 | @abstractmethod 105 | def process_expansions(self, examples: Task, exps: typing.Sequence[Clause], hypothesis_space: TopDownHypothesisSpace) -> typing.Sequence[Clause]: 106 | """ 107 | Processes the expansions of a clause 108 | It can be used to eliminate useless expansions (e.g., the one that have no solution, ...) 
109 | 
110 |         Returns a filtered set of candidates
111 |         """
112 |         raise NotImplementedError()
113 | 
114 |     def _learn_one_clause(self, examples: Task, hypothesis_space: TopDownHypothesisSpace) -> Clause:
115 |         """
116 |         Learns a single clause
117 | 
118 |         Returns a clause
119 |         """
120 |         # reset the search space
121 |         hypothesis_space.reset_pointer()
122 | 
123 |         # empty the pool just in case
124 |         self.initialise_pool()
125 | 
126 |         # put initial candidates into the pool
127 |         self.put_into_pool(hypothesis_space.get_current_candidate())
128 |         current_cand = None
129 |         score = -100
130 | 
131 |         while current_cand is None or (len(self._candidate_pool) > 0 and not self.stop_inner_search(score, examples, current_cand)):
132 |             # get the first candidate from the pool
133 |             current_cand = self.get_from_pool()
134 | 
135 |             # expand the candidate
136 |             _ = hypothesis_space.expand(current_cand)
137 |             # this is important: the .expand() method returns candidates only the first time it is called;
138 |             # if the same node is expanded a second time, it returns an empty list,
139 |             # so it is safer to use the .get_successors_of method
140 |             exps = hypothesis_space.get_successors_of(current_cand)
141 |             exps = self.process_expansions(examples, exps, hypothesis_space)
142 |             # add into the pool
143 |             self.put_into_pool(exps)
144 | 
145 |             score = self.evaluate(examples, current_cand)
146 | 
147 |         return current_cand
148 | 
149 |     def learn(self, examples: Task, knowledge: Knowledge, hypothesis_space: TopDownHypothesisSpace):
150 |         """
151 |         General learning loop
152 |         """
153 | 
154 |         self._assert_knowledge(knowledge)
155 |         final_program = []
156 |         examples_to_use = examples
157 |         pos, _ = examples_to_use.get_examples()
158 | 
159 |         while len(final_program) == 0 or len(pos) > 0:
160 |             # learn a single clause
161 |             cl = self._learn_one_clause(examples_to_use, hypothesis_space)
162 |             final_program.append(cl)
163 | 
164 |             # update covered positive examples
165 |             covered = self._execute_program(cl)
166 | 
167 |             pos, neg = examples_to_use.get_examples()
168 |             pos = pos.difference(covered)
169 | 
170 |             examples_to_use = Task(pos, neg)
171 | 
172 |         return final_program
173 | 
174 | 
175 | """
176 | A simple breadth-first top-down learner: it extends the template learner by searching in a breadth-first fashion
177 | 
178 | It implements the abstract functions in the following way:
179 | - initialise_pool: creates an empty OrderedSet
180 | - put_into_pool: adds to the ordered set
181 | - get_from_pool: returns the first element in the ordered set
182 | - evaluate: returns the number of covered positive examples and 0 if any negative example is covered
183 | - stop inner search: stops if the provided score of a clause is greater than zero
184 | - process expansions: removes from the hypothesis space all clauses that have no solutions
185 | 
186 | The learner does not handle recursions correctly!
187 | """ 188 | 189 | 190 | class SimpleBreadthFirstLearner(TemplateLearner): 191 | 192 | def __init__(self, solver_instance: Prolog, max_body_literals=4): 193 | super().__init__(solver_instance) 194 | self._max_body_literals = max_body_literals 195 | 196 | def initialise_pool(self): 197 | self._candidate_pool = OrderedSet() 198 | 199 | def put_into_pool(self, candidates: typing.Union[Clause, Procedure, typing.Sequence]) -> None: 200 | if isinstance(candidates, Clause): 201 | self._candidate_pool.add(candidates) 202 | else: 203 | self._candidate_pool |= candidates 204 | 205 | def get_from_pool(self) -> Clause: 206 | return self._candidate_pool.pop(0) 207 | 208 | def evaluate(self, examples: Task, clause: Clause) -> typing.Union[int, float]: 209 | covered = self._execute_program(clause) 210 | 211 | pos, neg = examples.get_examples() 212 | 213 | covered_pos = pos.intersection(covered) 214 | covered_neg = neg.intersection(covered) 215 | 216 | if len(covered_neg) > 0: 217 | return 0 218 | else: 219 | return len(covered_pos) 220 | 221 | def stop_inner_search(self, eval: typing.Union[int, float], examples: Task, clause: Clause) -> bool: 222 | if eval > 0: 223 | return True 224 | else: 225 | return False 226 | 227 | def process_expansions(self, examples: Task, exps: typing.Sequence[Clause], hypothesis_space: TopDownHypothesisSpace) -> typing.Sequence[Clause]: 228 | # eliminate every clause with more body literals than allowed 229 | exps = [cl for cl in exps if len(cl) <= self._max_body_literals] 230 | 231 | # check if every clause has solutions 232 | exps = [(cl, self._solver.has_solution(*cl.get_body().get_literals())) for cl in exps] 233 | new_exps = [] 234 | 235 | for ind in range(len(exps)): 236 | if exps[ind][1]: 237 | # keep it if it has solutions 238 | new_exps.append(exps[ind][0]) 239 | else: 240 | # remove from hypothesis space if it does not 241 | hypothesis_space.remove(exps[ind][0]) 242 | 243 | return new_exps 244 | 245 | 246 | if __name__ == '__main__': 247 | # define the predicates 248 | father = c_pred("father", 2) 249 | mother = c_pred("mother", 2) 250 | grandparent = c_pred("grandparent", 2) 251 | 252 | # specify the background knowledge 253 | background = Knowledge(father("a", "b"), mother("a", "b"), mother("b", "c"), 254 | father("e", "f"), father("f", "g"), 255 | mother("h", "i"), mother("i", "j")) 256 | 257 | # positive examples 258 | pos = {grandparent("a", "c"), grandparent("e", "g"), grandparent("h", "j")} 259 | 260 | # negative examples 261 | neg = {grandparent("a", "b"), grandparent("a", "g"), grandparent("i", "j")} 262 | 263 | task = Task(positive_examples=pos, negative_examples=neg) 264 | 265 | # create Prolog instance 266 | prolog = SWIProlog() 267 | 268 | learner = SimpleBreadthFirstLearner(prolog, max_body_literals=3) 269 | 270 | # create the hypothesis space 271 | hs = TopDownHypothesisSpace(primitives=[lambda x: plain_extension(x, father, connected_clauses=True), 272 | lambda x: plain_extension(x, mother, connected_clauses=True)], 273 | head_constructor=grandparent, 274 | expansion_hooks_reject=[lambda x, y: has_singleton_vars(x, y), 275 | lambda x, y: has_duplicated_literal(x, y)]) 276 | 277 | program = learner.learn(task, background, hs) 278 | 279 | print(program) 280 | 281 | 282 | 283 | -------------------------------------------------------------------------------- /loreleai/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/sebdumancic/loreleai/3863c683b6acfdaa9c1e23a526110492e610fe74/loreleai/__init__.py -------------------------------------------------------------------------------- /loreleai/language/__init__.py: -------------------------------------------------------------------------------- 1 | from pylo.language.commons import KANREN_LOGPY 2 | from pylo.language.commons import MUZ 3 | 4 | __all__ = [ 5 | 'MUZ', 6 | 'KANREN_LOGPY' 7 | ] -------------------------------------------------------------------------------- /loreleai/language/datalog/__init__.py: -------------------------------------------------------------------------------- 1 | from pylo.language.datalog import ( 2 | Term, 3 | Constant, 4 | Variable, 5 | Structure, 6 | Predicate, 7 | Type, 8 | Not, 9 | Program, 10 | c_pred, 11 | c_const, 12 | c_id_to_const, 13 | c_var, 14 | c_literal, 15 | c_fresh_var, 16 | c_find_domain, 17 | Atom, 18 | Clause, 19 | Literal, 20 | Procedure, 21 | Disjunction, 22 | Recursion, 23 | Body, 24 | c_type 25 | ) 26 | 27 | 28 | __all__ = [ 29 | "Term", 30 | "Constant", 31 | "Variable", 32 | "Structure", 33 | "Predicate", 34 | "Not", 35 | "Type", 36 | "Program", 37 | "c_pred", 38 | "c_const", 39 | "c_id_to_const", 40 | "c_var", 41 | "c_literal", 42 | "c_fresh_var", 43 | "c_find_domain", 44 | "Clause", 45 | "Atom", 46 | "Literal", 47 | "Procedure", 48 | "Disjunction", 49 | "Recursion", 50 | "Body", 51 | "c_type" 52 | ] -------------------------------------------------------------------------------- /loreleai/language/kanren/__init__.py: -------------------------------------------------------------------------------- 1 | from pylo.language.kanren import ( 2 | Term, 3 | Constant, 4 | Variable, 5 | Structure, 6 | Predicate, 7 | Type, 8 | Not, 9 | Program, 10 | c_pred, 11 | c_const, 12 | c_id_to_const, 13 | c_var, 14 | c_literal, 15 | c_fresh_var, 16 | Clause, 17 | Atom, 18 | Literal, 19 | Procedure, 20 | Disjunction, 21 | Recursion, 22 | Body, 23 | c_type 24 | ) 25 | 26 | from .kanren_utils import construct_recursive_rule 27 | 28 | __all__ = [ 29 | "Term", 30 | "Constant", 31 | "Variable", 32 | "Structure", 33 | "Predicate", 34 | "Type", 35 | "Not", 36 | "Program", 37 | "c_pred", 38 | "c_const", 39 | "c_id_to_const", 40 | "c_var", 41 | "c_literal", 42 | "c_fresh_var", 43 | "Clause", 44 | "Atom", 45 | 'construct_recursive_rule', 46 | "Literal", 47 | "Procedure", 48 | "Disjunction", 49 | "Recursion", 50 | "Body", 51 | "c_type" 52 | ] 53 | -------------------------------------------------------------------------------- /loreleai/language/kanren/kanren_utils.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import Sequence 3 | 4 | import kanren 5 | 6 | from loreleai.language.lp import Clause, Variable 7 | 8 | 9 | def _turn_clause_to_interim_repr(clause: Clause, suffix: str = "_x"): 10 | head_vars = dict( 11 | [(v, ind) for ind, v in enumerate(clause.get_head().get_variables())] 12 | ) 13 | 14 | return [ 15 | tuple([a.get_predicate()] 16 | + [ 17 | head_vars[t] if isinstance(t, Variable) and t in head_vars else t 18 | for t in a.get_terms() 19 | ]) 20 | for a in clause.get_literals() 21 | ] 22 | 23 | 24 | def construct_recursive_rule(rules: Sequence[Clause]): 25 | base_cases = [x for x in rules if not x.is_recursive()] 26 | recursive_rules = [x for x in rules if x.is_recursive()] 27 | head_predicate = rules[0].get_head().get_predicate() 28 | 29 | # turn each clause into its body: [(predicate, arg1, arg2, ...)] 30 
| base_cases = [_turn_clause_to_interim_repr(x) for x in base_cases] 31 | 32 | recursive_rules = [_turn_clause_to_interim_repr(x) for x in recursive_rules] 33 | 34 | # variables that need to be created in the generic function 35 | bc_vars = [set(reduce(lambda x, y: x + y, [[v for v in c if isinstance(v, Variable)] for c in x], [])) for x in base_cases] 36 | rc_vars = [set(reduce(lambda x, y: x + y, [[v for v in c if isinstance(v, Variable)] for c in x], [])) for x in recursive_rules] 37 | 38 | def generic_recusive_predicate( 39 | *args, 40 | head=head_predicate, 41 | bases=base_cases, 42 | recursion=recursive_rules, 43 | bvars=bc_vars, 44 | rvars=rc_vars 45 | ): 46 | bvars = [ 47 | dict( 48 | [(c, kanren.var()) for c in v] 49 | + [(x, args[x]) for x in range(head.get_arity())] 50 | ) 51 | for v in bvars 52 | ] 53 | rvars = [ 54 | dict( 55 | [(c, kanren.var()) for c in v] 56 | + [(x, args[x]) for x in range(head.get_arity())] 57 | ) 58 | for v in rvars 59 | ] 60 | 61 | base_cases = [ 62 | [ 63 | a[0].as_kanren()( 64 | *[ 65 | bvars[ind][ar] if ar in bvars[ind] else ar.get_name() 66 | for ar in a[1:] 67 | ] 68 | ) 69 | for a in cl 70 | ] 71 | for ind, cl in enumerate(bases) 72 | ] 73 | 74 | recursion = [ 75 | [ 76 | a[0].as_kanren()( 77 | *[ 78 | rvars[ind][ar] if ar in rvars[ind] else ar.get_name() 79 | for ar in a[1:] 80 | ] 81 | ) 82 | if a[0] != head_predicate 83 | else kanren.core.Zzz( 84 | generic_recusive_predicate, 85 | *[ 86 | rvars[ind][ar] if ar in rvars[ind] else ar.get_name() 87 | for ar in a[1:] 88 | ] 89 | ) 90 | for a in cl 91 | ] 92 | for ind, cl in enumerate(recursion) 93 | ] 94 | 95 | all_for_conde = base_cases + recursion 96 | 97 | return kanren.conde(*all_for_conde) 98 | 99 | return generic_recusive_predicate 100 | -------------------------------------------------------------------------------- /loreleai/language/lp/__init__.py: -------------------------------------------------------------------------------- 1 | # from .lp import ClausalTheory, parse 2 | # from ..commons import ( 3 | # Term, 4 | # Constant, 5 | # Variable, 6 | # Structure, 7 | # Predicate, 8 | # Type, 9 | # Not, 10 | # Type, 11 | # Program, 12 | # c_pred, 13 | # c_const, 14 | # c_id_to_const, 15 | # c_var, 16 | # c_literal, 17 | # c_find_domain, 18 | # c_functor, 19 | # c_symbol, 20 | # Atom, 21 | # Not, 22 | # Clause, 23 | # c_fresh_var, 24 | # Literal, 25 | # Procedure, 26 | # Disjunction, 27 | # Recursion, 28 | # Functor, 29 | # global_context, 30 | # list_func, 31 | # List, 32 | # Context, 33 | # Body 34 | # ) 35 | 36 | from pylo.language.lp import Term, Constant, Variable, Structure, Predicate, Type, Program, c_pred, c_const, \ 37 | c_id_to_const, c_var, c_literal, c_find_domain, c_functor, c_symbol, Clause, Atom, Not, c_fresh_var, Literal, \ 38 | Procedure, Disjunction, Recursion, Functor, list_func, List, Body, c_type, Pair 39 | 40 | from ..utils import triplet 41 | 42 | __all__ = [ 43 | "Term", 44 | "Constant", 45 | "Variable", 46 | "Structure", 47 | "Predicate", 48 | "Type", 49 | "Program", 50 | "c_pred", 51 | "c_const", 52 | "c_id_to_const", 53 | "c_var", 54 | "c_literal", 55 | "c_find_domain", 56 | "c_functor", 57 | "c_symbol", 58 | "Clause", 59 | "Atom", 60 | "Not", 61 | "c_fresh_var", 62 | 'triplet', 63 | 'Literal', 64 | 'Procedure', 65 | "Disjunction", 66 | "Recursion", 67 | "Functor", 68 | "list_func", 69 | "List", 70 | "Body", 71 | "c_type", 72 | "Pair" 73 | ] 74 | -------------------------------------------------------------------------------- /loreleai/language/utils.py: 
-------------------------------------------------------------------------------- 1 | from typing import Union, Sequence, Dict 2 | 3 | import networkx as nx 4 | from pylo.language.lp import Predicate, Constant, Atom, Variable, c_fresh_var 5 | 6 | 7 | # MUZ = "muz" 8 | # LP = 1 9 | # FOL = 2 10 | # KANREN_LOGPY = "logpy" 11 | 12 | 13 | def triplet(subject: Union[str, Constant, Variable], 14 | relation: Predicate, 15 | object: Union[str, Constant, Variable]) -> Atom: 16 | """ 17 | Allows to specify a literal as a triplet 18 | 19 | Arguments: 20 | subject: name of the subject/head entity; '?' means variable 21 | relation: the relation between the entities 22 | object: name of th object/tail entity; '?' means variable 23 | 24 | Return: 25 | a literal 26 | """ 27 | assert relation.get_arity() == 2 28 | arg_types = relation.get_arg_types() 29 | return relation(c_fresh_var(arg_types[0]) if isinstance(subject, str) and subject == "?" else subject, 30 | c_fresh_var(arg_types[1]) if isinstance(object, str) and object == "?" else object) 31 | 32 | 33 | def nx_to_loreleai(graph: nx.Graph, relation_map: Dict[str, Predicate] = None) -> Sequence[Atom]: 34 | """ 35 | Converts a NetworkX graph into Loreleai representation 36 | 37 | To indicate the type of relations and nodes, the functions looks for a 'type' attribute 38 | 39 | Arguments: 40 | graph: NetworkX graph 41 | relation_map: maps from edge types to predicates 42 | """ 43 | 44 | literals = [] 45 | 46 | if relation_map is None: 47 | relation_map = {} 48 | 49 | for (u, v, t) in graph.edges.data('type', default=None): 50 | literals.append(relation_map[t](u, v)) 51 | 52 | return literals 53 | 54 | 55 | -------------------------------------------------------------------------------- /loreleai/learning/__init__.py: -------------------------------------------------------------------------------- 1 | from .hypothesis_space import HypothesisSpace, TopDownHypothesisSpace 2 | from .language_filtering import has_singleton_vars, has_duplicated_literal, max_var, max_pred_occurrences, \ 3 | connected_clause, connected_body, negation_at_the_end 4 | from .language_manipulation import plain_extension 5 | from .task import Knowledge, Interpretation, Task 6 | from .utilities import FillerPredicate, are_variables_connected 7 | from .abstract_learners import TemplateLearner 8 | -------------------------------------------------------------------------------- /loreleai/learning/abstract_learners.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | 4 | from loreleai.reasoning.lp import LPSolver 5 | from loreleai.learning.task import Knowledge, Task 6 | from loreleai.language.lp import Clause,Atom,Procedure 7 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace, HypothesisSpace 8 | from loreleai.learning.eval_functions import EvalFunction 9 | import datetime 10 | 11 | class Learner(ABC): 12 | """ 13 | Base class for all learners 14 | """ 15 | def __init__(self): 16 | self._learnresult = LearnResult() 17 | 18 | @abstractmethod 19 | def learn(self, examples: Task, knowledge: Knowledge, hypothesis_space: HypothesisSpace): 20 | raise NotImplementedError() 21 | 22 | """ 23 | This is an abstract learner class that defines a learner with the configurable options. 
24 | 25 | It follows a very simple learning principle: it iteratively 26 | - searches for a single clause that covers most positive examples 27 | - adds it to the program 28 | - removes covered examples 29 | 30 | It is implemented as a template learner - you still need to provide the following methods: 31 | - initialisation of the candidate pool (the data structure that keeps all candidates) 32 | - getting a single candidate from the candidate pool 33 | - adding candidate(s) to the pool 34 | - evaluating a single candidate 35 | - stopping the search for a single clause 36 | - processing expansion/refinements of clauses 37 | 38 | The learner does not handle recursions correctly! 39 | """ 40 | 41 | class LearnResult: 42 | """ 43 | The LearnResult class holds statistics about the learning process. 44 | It is implemented as a dict and supports indexing [] as usual. 45 | """ 46 | def __init__(self): 47 | self.info = dict() 48 | 49 | def __getitem__(self,key): 50 | return self.info[key] 51 | 52 | def __setitem__(self,key,val): 53 | self.info[key] = val 54 | 55 | def __repr__(self): 56 | max_keylength = max([len(str(k)) for k in self.info.keys()]) 57 | output_str = "Result of learning: \n" 58 | 59 | for key,val in self.info.items(): 60 | output_str += str(key) + ":" + " "*(max_keylength - len(str(key))+2) + str(val) + "\n" 61 | return output_str 62 | 63 | 64 | 65 | class TemplateLearner(Learner): 66 | 67 | def __init__(self, solver_instance: LPSolver, eval_fn: EvalFunction, do_print=False): 68 | self._solver = solver_instance 69 | self._candidate_pool = [] 70 | self._eval_fn = eval_fn 71 | self._print = do_print 72 | 73 | # Statistics about learning process 74 | self._prolog_queries = 0 75 | self._intermediate_coverage = [] # Coverage of examples after every iteration 76 | 77 | super().__init__() 78 | 79 | def _assert_knowledge(self, knowledge: Knowledge): 80 | """ 81 | Assert knowledge into Prolog engine 82 | """ 83 | facts = knowledge.get_atoms() 84 | for f_ind in range(len(facts)): 85 | # self._solver.assert_fact(facts[f_ind]) 86 | self._solver.assertz(facts[f_ind]) 87 | 88 | clauses = knowledge.get_clauses() 89 | for cl_ind in range(len(clauses)): 90 | # self._solver.assert_rule(clauses[cl_ind]) 91 | self._solver.assertz(clauses[cl_ind]) 92 | 93 | def _execute_program(self, clause: Clause, count_as_query=True) -> typing.Sequence[Atom]: 94 | """ 95 | Evaluates a clause using the Prolog engine and background knowledge 96 | 97 | Returns a set of atoms that the clause covers 98 | """ 99 | if len(clause.get_body().get_literals()) == 0: 100 | return [] 101 | else: 102 | head_predicate = clause.get_head().get_predicate() 103 | head_variables = clause.get_head_variables() 104 | 105 | sols = self._solver.query(*clause.get_body().get_literals()) 106 | 107 | self._prolog_queries += 1 if count_as_query else 0 108 | 109 | sols = [head_predicate(*[s[v] for v in head_variables]) for s in sols] 110 | 111 | return sols 112 | 113 | @abstractmethod 114 | def initialise_pool(self): 115 | """ 116 | Creates an empty pool of candidates 117 | """ 118 | raise NotImplementedError() 119 | 120 | @abstractmethod 121 | def get_from_pool(self) -> Clause: 122 | """ 123 | Gets a single clause from the pool 124 | """ 125 | raise NotImplementedError() 126 | 127 | @abstractmethod 128 | def put_into_pool(self, candidates: typing.Union[Clause, Procedure, typing.Sequence]) -> None: 129 | """ 130 | Inserts a clause/a set of clauses into the pool 131 | """ 132 | raise NotImplementedError() 133 | 134 | def evaluate(self, examples: 
Task, clause: Clause, hypothesis_space: HypothesisSpace) -> typing.Union[int, float]: 135 | """ 136 | Evaluates a clause by calling the Learner's eval_fn. 137 | Returns a number (the higher the better) 138 | """ 139 | # add_to_cache(node,key,val) 140 | # retrieve_from_cache(node,key) -> val or None 141 | # remove_from_cache(node,key) -> None 142 | 143 | # Cache holds sets of examples that were covered before 144 | covered = hypothesis_space.retrieve_from_cache(clause,"covered") 145 | 146 | # We have executed this clause before 147 | if covered is not None: 148 | # Note that _eval.fn.evaluate() will ignore clauses in `covered` 149 | # that are not in the current Task 150 | result = self._eval_fn.evaluate(clause,examples,covered) 151 | # print("No query here.") 152 | return result 153 | else: 154 | covered = self._execute_program(clause) 155 | # if 'None', i.e. trivial hypothesis, all clauses are covered 156 | if covered is None: 157 | pos,neg = examples.get_examples() 158 | covered = pos.union(neg) 159 | 160 | result = self._eval_fn.evaluate(clause,examples,covered) 161 | hypothesis_space.add_to_cache(clause,"covered",covered) 162 | return result 163 | 164 | @abstractmethod 165 | def stop_inner_search(self, eval: typing.Union[int, float], examples: Task, clause: Clause) -> bool: 166 | """ 167 | Returns true if the search for a single clause should be stopped 168 | """ 169 | raise NotImplementedError() 170 | 171 | @abstractmethod 172 | def process_expansions(self, examples: Task, exps: typing.Sequence[Clause], hypothesis_space: TopDownHypothesisSpace) -> typing.Sequence[Clause]: 173 | """ 174 | Processes the expansions of a clause 175 | It can be used to eliminate useless expansions (e.g., the one that have no solution, ...) 176 | 177 | Returns a filtered set of candidates 178 | """ 179 | raise NotImplementedError() 180 | 181 | def _learn_one_clause(self, examples: Task, hypothesis_space: TopDownHypothesisSpace) -> Clause: 182 | """ 183 | Learns a single clause 184 | 185 | Returns a clause 186 | """ 187 | # reset the search space 188 | hypothesis_space.reset_pointer() 189 | 190 | # empty the pool just in case 191 | self.initialise_pool() 192 | 193 | # put initial candidates into the pool 194 | self.put_into_pool(hypothesis_space.get_current_candidate()) 195 | current_cand = None 196 | score = -100 197 | 198 | while current_cand is None or (len(self._candidate_pool) > 0 and not self.stop_inner_search(score, examples, current_cand)): 199 | # get first candidate from the pool 200 | current_cand = self.get_from_pool() 201 | 202 | # expand the candidate 203 | _ = hypothesis_space.expand(current_cand) 204 | # this is important: .expand() method returns candidates only the first time it is called; 205 | # if the same node is expanded the second time, it returns the empty list 206 | # it is safer than to use the .get_successors_of method 207 | exps = hypothesis_space.get_successors_of(current_cand) 208 | exps = self.process_expansions(examples, exps, hypothesis_space) 209 | # add into pool 210 | self.put_into_pool(exps) 211 | 212 | score = self.evaluate(examples, current_cand, hypothesis_space) 213 | 214 | if self._print: 215 | print(f"- New clause: {current_cand}") 216 | print(f"- Candidates has value {round(score,2)} for metric '{self._eval_fn.name()}'") 217 | return current_cand 218 | 219 | 220 | def learn(self, examples: Task, knowledge: Knowledge, hypothesis_space: TopDownHypothesisSpace): 221 | """ 222 | General learning loop 223 | """ 224 | 225 | self._assert_knowledge(knowledge) 226 | 
final_program = []
227 |         examples_to_use = examples
228 |         pos, _ = examples_to_use.get_examples()
229 |         i = 0
230 |         start = datetime.datetime.now()
231 | 
232 | 
233 |         while len(final_program) == 0 or len(pos) > 0:
234 |             # learn a single clause
235 | 
236 |             if self._print:
237 |                 print(f"Iteration {i}")
238 |                 print("- Current program:")
239 |                 for program_clause in final_program:
240 |                     print("\t"+str(program_clause))
241 | 
242 |             cl = self._learn_one_clause(examples_to_use, hypothesis_space)
243 |             final_program.append(cl)
244 | 
245 |             # update covered positive examples
246 |             covered = self._execute_program(cl)
247 | 
248 |             # Find the intermediate quality of the program at this point and add it to the learn result (don't count these as Prolog queries)
249 |             c = set()
250 |             for cl in final_program:
251 |                 c = c.union(self._execute_program(cl, count_as_query=False))
252 |             pos_covered = len(c.intersection(examples._positive_examples))
253 |             neg_covered = len(c.intersection(examples._negative_examples))
254 |             self._intermediate_coverage.append((pos_covered, neg_covered))
255 | 
256 |             # Remove covered examples and start the next iteration
257 |             pos, neg = examples_to_use.get_examples()
258 |             pos = pos.difference(covered)
259 | 
260 |             examples_to_use = Task(pos, neg)
261 |             i += 1
262 | 
263 |         total_time = (datetime.datetime.now()-start).total_seconds()
264 |         if self._print:
265 |             print("Done! Search took {:.5f} seconds.".format(total_time))
266 | 
267 |         # Wrap results into the learn result and return
268 |         self._learnresult["final_program"] = final_program
269 |         self._learnresult["total_time"] = total_time
270 |         self._learnresult["num_iterations"] = i
271 |         self._learnresult["evalfn_evaluations"] = self._eval_fn._clauses_evaluated
272 |         self._learnresult["prolog_queries"] = self._prolog_queries
273 |         self._learnresult["intermediate_coverage"] = self._intermediate_coverage
274 | 
275 |         return self._learnresult
276 | 
-------------------------------------------------------------------------------- /loreleai/learning/eval_functions.py: --------------------------------------------------------------------------------
1 | import typing
2 | from abc import ABC, abstractmethod
3 | from typing import Sequence
4 | from loreleai.language.lp import Clause, Atom
5 | from loreleai.learning.task import Task
6 | import math
7 | 
8 | 
9 | class EvalFunction(ABC):
10 |     """
11 |     Abstract base class for an evaluation function
12 |     """
13 | 
14 |     def __init__(self, name):
15 |         self._clauses_evaluated = 0
16 |         self._name = name
17 | 
18 |     @abstractmethod
19 |     def evaluate(self, clause: Clause, examples: Task, covered: Sequence[Atom]):
20 |         """
21 |         Evaluates the quality of a clause on the given examples and
22 |         set of covered atoms
23 |         """
24 |         raise NotImplementedError()
25 | 
26 |     def name(self):
27 |         return self._name
28 | 
29 | 
30 | class Accuracy(EvalFunction):
31 |     """
32 |     Accuracy is defined as the number of positive examples covered,
33 |     divided by the number of positive and negative examples covered
34 |     """
35 | 
36 |     def __init__(self, return_upperbound=False):
37 |         super().__init__("Accuracy")
38 |         self._return_upperbound = return_upperbound
39 | 
40 |     def evaluate(self, clause: Clause, examples: Task, covered: Sequence[Atom]):
41 |         self._clauses_evaluated += 1
42 | 
43 |         pos, neg = examples.get_examples()
44 |         covered_pos = len(pos.intersection(covered))
45 |         covered_neg = len(neg.intersection(covered))
46 | 
47 |         if covered_pos + covered_neg == 0:
48 |             # nothing covered at all: accuracy is 0 (and upper bound 0 if requested)
49 |             return (0, 0) if self._return_upperbound else 0
50 |         accuracy = covered_pos / (covered_pos + covered_neg)
51 |         return (accuracy, 1) if self._return_upperbound else accuracy
52 | 
53 | 
54 | 
55 | 
56 | 
57 | class Compression(EvalFunction):
58 |     """
59 |     Compression is similar to coverage but favours shorter clauses
60 |     """
61 | 
62 |     def __init__(self, return_upperbound=False):
63 |         super().__init__("Compression")
64 |         self._return_upperbound = return_upperbound
65 | 
66 |     def evaluate(self, clause: Clause, examples: Task, covered: Sequence[Atom]):
67 |         self._clauses_evaluated += 1
68 | 
69 |         pos, neg = examples.get_examples()
70 |         covered_pos = len(pos.intersection(covered))
71 |         covered_neg = len(neg.intersection(covered))
72 |         clause_length = len(clause.get_literals())
73 |         if self._return_upperbound:
74 |             return (covered_pos - covered_neg - clause_length + 1), covered_pos
75 |         return covered_pos - covered_neg - clause_length + 1
76 | 
77 | 
78 | class Coverage(EvalFunction):
79 |     """
80 |     Coverage is defined as the difference between the number of positive
81 |     and negative examples covered.
82 |     """
83 | 
84 |     def __init__(self, return_upperbound=False):
85 |         """
86 |         Initializes the Coverage EvalFunction. When return_upperbound is True,
87 |         a tuple (coverage, upper_bound) will be returned upon evaluation, where upper_bound
88 |         gives the maximum coverage any clause extending the original clause can achieve
89 |         """
90 |         super().__init__("Coverage")
91 |         self._return_upperbound = return_upperbound
92 | 
93 |     def evaluate(self, clause: Clause, examples: Task, covered: Sequence[Atom]):
94 |         self._clauses_evaluated += 1
95 | 
96 |         pos, neg = examples.get_examples()
97 |         covered_pos = len(pos.intersection(covered))
98 |         covered_neg = len(neg.intersection(covered))
99 |         if self._return_upperbound:
100 |             return (covered_pos - covered_neg), covered_pos
101 |         return covered_pos - covered_neg
102 | 
103 | 
104 | class Entropy(EvalFunction):
105 |     """
106 |     Entropy is a measure of how well the clause divides
107 |     negative and positive examples into two distinct categories.
108 |     This implementation uses:
109 | 
110 |     -(p * log10(p) + (1-p) * log10(1-p)), with p = P/(P+N).
111 | P and N are respectively 112 | the number of positive and negative examples that are covered by the clause 113 | """ 114 | 115 | def __init__(self): 116 | super().__init__("Entropy") 117 | 118 | def evaluate(self, clause: Clause, examples: Task, covered: Sequence[Atom]): 119 | self._clauses_evaluated += 1 120 | 121 | pos, neg = examples.get_examples() 122 | covered_pos = len(pos.intersection(covered)) 123 | covered_neg = len(neg.intersection(covered)) 124 | if covered_pos + covered_neg == 0: 125 | return 0 126 | 127 | p = covered_pos / (covered_pos + covered_neg) 128 | 129 | # Perfect split, no entropy 130 | if p == 1 or p == 0: 131 | return 0 132 | return -(p * math.log10(p) + (1 - p) * math.log10(1 - p)) 133 | -------------------------------------------------------------------------------- /loreleai/learning/hypothesis_space.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | from functools import reduce 4 | from itertools import combinations_with_replacement, combinations 5 | 6 | import networkx as nx 7 | from networkx.classes.function import create_empty_copy 8 | from orderedset import OrderedSet 9 | 10 | from loreleai.language.lp import ( 11 | Predicate, 12 | Clause, 13 | Procedure, 14 | Atom, 15 | Body, 16 | Recursion, 17 | ) 18 | from loreleai.learning.utilities import FillerPredicate 19 | from loreleai.learning.language_manipulation import plain_extension 20 | 21 | 22 | class HypothesisSpace(ABC): 23 | def __init__( 24 | self, 25 | primitives: typing.Sequence, 26 | head_constructor: typing.Union[Predicate, FillerPredicate], 27 | connected_clauses: bool = True, 28 | recursive_procedures: bool = False, 29 | expansion_hooks_keep: typing.Sequence = (), 30 | expansion_hooks_reject: typing.Sequence = () 31 | ) -> None: 32 | self._primitives: typing.Sequence = primitives 33 | self._head_constructor: typing.Union[ 34 | Predicate, FillerPredicate 35 | ] = head_constructor 36 | # Predicate -> use this predicate in the head 37 | # FillerPredicate -> create new head predicate for each clause/procedure 38 | self._hypothesis_space = None 39 | self._connected_clauses = connected_clauses 40 | self._use_recursions = recursive_procedures 41 | self._recursive_expansion = lambda x: plain_extension(x, self._head_constructor, connected_clauses=True) if (self._use_recursions and isinstance(self._head_constructor, Predicate)) else None 42 | self._pointers: typing.Dict[str, Body] = {"main": None} 43 | self._expansion_hooks_keep = expansion_hooks_keep 44 | self._expansion_hooks_reject = expansion_hooks_reject 45 | 46 | @abstractmethod 47 | def initialise(self) -> None: 48 | """ 49 | Initialise the search space 50 | """ 51 | raise NotImplementedError() 52 | 53 | @abstractmethod 54 | def expand(self, node: typing.Union[Clause, Procedure]) -> None: 55 | """ 56 | Expand the node in the search space 57 | """ 58 | raise NotImplementedError() 59 | 60 | @abstractmethod 61 | def block(self, node: typing.Union[Clause, Procedure]) -> None: 62 | """ 63 | Block expansion of the node 64 | """ 65 | raise NotImplementedError() 66 | 67 | @abstractmethod 68 | def ignore(self, node: typing.Union[Clause, Procedure]) -> None: 69 | """ 70 | Ignores the node, but keeps extending it 71 | """ 72 | raise NotImplementedError() 73 | 74 | @abstractmethod 75 | def remove(self, node: typing.Union[Clause, Procedure], remove_entire_body: bool = False, 76 | not_if_other_parents: bool = True) -> None: 77 | """ 78 | Removes the clause from the 
hypothesis space
79 |         """
80 |         raise NotImplementedError()
81 | 
82 |     @abstractmethod
83 |     def get_current_candidate(self) -> typing.Union[Clause, Procedure]:
84 |         """
85 |         Get the current candidate
86 |         """
87 |         raise NotImplementedError()
88 | 
89 |     @abstractmethod
90 |     def get_successors_of(
91 |         self, node: typing.Union[Clause, Procedure]
92 |     ) -> typing.Sequence[typing.Union[Clause, Procedure]]:
93 |         """
94 |         Get all successors of the provided node
95 |         """
96 |         raise NotImplementedError()
97 | 
98 |     @abstractmethod
99 |     def get_predecessor_of(
100 |         self, node: typing.Union[Clause, Procedure]
101 |     ) -> typing.Union[Clause, Procedure]:
102 |         """
103 |         Returns the predecessor of the node
104 |         """
105 |         raise NotImplementedError()
106 | 
107 |     @abstractmethod
108 |     def move_pointer_to(
109 |         self, node: typing.Union[Clause, Procedure], pointer_name: str = "main"
110 |     ) -> None:
111 |         """
112 |         Moves the current candidate pointer to the given node
113 |         """
114 |         raise NotImplementedError()
115 | 
116 | 
117 | class TopDownHypothesisSpace(HypothesisSpace):
118 |     def __init__(
119 |         self,
120 |         primitives: typing.Sequence,
121 |         head_constructor: typing.Union[Predicate, FillerPredicate],
122 |         connected_clauses: bool = True,
123 |         recursive_procedures: bool = False,
124 |         repetitions_in_head_variables: int = 2,
125 |         expansion_hooks_keep: typing.Sequence = (),
126 |         expansion_hooks_reject: typing.Sequence = (),
127 |         constants=None,
128 |         initial_clause: typing.Union[Clause, Body] = None
129 |     ):
130 |         super().__init__(
131 |             primitives,
132 |             head_constructor,
133 |             recursive_procedures=recursive_procedures,
134 |             connected_clauses=connected_clauses,
135 |             expansion_hooks_keep=expansion_hooks_keep,
136 |             expansion_hooks_reject=expansion_hooks_reject
137 |         )
138 |         self._hypothesis_space = nx.DiGraph()
139 |         self._root_node: Body = None
140 |         self._repetition_vars_head = repetitions_in_head_variables
141 |         self._invented_predicate_count = 0
142 |         self._recursive_pointers_count = 0
143 |         self._constants = constants
144 |         self._recursive_pointer_prefix = "rec"
145 |         self.initialise(initial_clause)
146 | 
147 |     def initialise(self, initial_clause: typing.Union[Clause, Body]) -> None:
148 |         """
149 |         Initialises the search space. It is possible to provide an initial
150 |         clause to initialise the hypothesis space with (instead of the empty body :-).
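        For example, passing a Body with a single literal (or a Clause, whose body
        is taken) makes the search start from that partial body instead of from scratch.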
151 |         """
152 |         if isinstance(self._head_constructor, (Predicate, FillerPredicate)):
153 |             if isinstance(self._head_constructor, Predicate):
154 |                 # create possible heads
155 |                 head_variables = [chr(x) for x in range(ord("A"), ord("Z"))][
156 |                     : self._head_constructor.get_arity()
157 |                 ]
158 | 
159 |                 possible_heads = [
160 |                     self._head_constructor(*list(x))
161 |                     for x in combinations_with_replacement(head_variables, self._head_constructor.get_arity())
162 |                 ]
163 |             else:
164 |                 possible_heads = self._head_constructor.all_possible_atoms()
165 | 
166 |             # create empty clause or use initial clause
167 |             if initial_clause:
168 |                 clause = initial_clause if isinstance(initial_clause, Body) else initial_clause.get_body()
169 |             else:
170 |                 clause = Body()
171 |             if len(clause.get_literals()) > 0 and len(clause.get_variables()) < self._head_constructor.get_arity():
172 |                 raise AssertionError("Cannot provide an initial clause with fewer distinct variables than the head predicate!")
173 | 
174 |             init_head_dict = {"ignored": False, "blocked": False, "visited": False}
175 |             self._hypothesis_space.add_node(clause)
176 |             self._hypothesis_space.nodes[clause]["heads"] = dict([(x, init_head_dict.copy()) for x in possible_heads])
177 |             self._hypothesis_space.nodes[clause]["visited"] = False
178 | 
179 |             self._pointers["main"] = clause
180 |             self._root_node = clause
181 |         else:
182 |             raise Exception(f"Unknown head constructor ({self._head_constructor})")
183 | 
184 |     def _create_possible_heads(
185 |         self, body: Body, use_as_head_predicate: Predicate = None
186 |     ) -> typing.Sequence[Atom]:
187 |         """
188 |         Creates possible heads for a given body
189 | 
190 |         if the _head_constructor is a Predicate, it makes all possible combinations that match the types in the head
191 |         """
192 |         vars = body.get_variables()
193 | 
194 |         if isinstance(self._head_constructor, Predicate):
195 |             arg_types = self._head_constructor.get_arg_types()
196 | 
197 |             # matches_vars = []
198 |             # for i in range(len(arg_types)):
199 |             #     matches_vars[i] = []
200 |             #     for var_ind in range(len(vars)):
201 |             #         if arg_types[i] == vars[var_ind].get_type():
202 |             #             matches_vars[i].append(vars[var_ind])
203 |             #
204 |             # bases = [matches_vars[x] for x in range(self._head_constructor.get_arity())]
205 |             # heads = []
206 |             #
207 |             # for comb in product(*bases):
208 |             #     heads.append(Atom(self._head_constructor, list(comb)))
209 |             heads = []
210 |             for comb in combinations(vars, self._head_constructor.get_arity()):
211 |                 if [x.get_type() for x in comb] == arg_types:
212 |                     heads.append(Atom(self._head_constructor, list(comb)))
213 |             return heads
214 |         elif isinstance(self._head_constructor, FillerPredicate):
215 |             return self._head_constructor.new_from_body(
216 |                 body, use_as_head_predicate=use_as_head_predicate
217 |             )
218 |         else:
219 |             raise Exception(f"Unknown head constructor {self._head_constructor}")
220 | 
221 |     def _check_if_recursive(self, body: Body):
222 |         """
223 |         checks if the body forms a recursive clause:
224 |         - one of the predicates in the body is equal to the head predicate
225 |         - a predicate constructed by FillerPredicate is in the body
226 |         """
227 |         if isinstance(self._head_constructor, Predicate):
228 |             return self._head_constructor in body.get_predicates()
229 |         else:
230 |             return any(
231 |                 self._head_constructor.is_created_by(x)
232 |                 for x in body.get_predicates()
233 |             )
234 | 
241 |     def _insert_node(self, node: Body) -> bool:
242 |         """
243 |         Inserts a clause/procedure into the hypothesis space
244 | 
245 |         Returns True if successfully inserted (after applying hooks), otherwise returns False
246 |         """
247 |         recursive = self._check_if_recursive(node)
248 | 
249 |         if recursive and isinstance(self._head_constructor, FillerPredicate):
250 |             recursive_pred = list(
251 |                 filter(
252 |                     lambda x: self._head_constructor.is_created_by(x),
253 |                     node.get_predicates(),
254 |                 )
255 |             )[0]
256 |             possible_heads = self._create_possible_heads(
257 |                 node, use_as_head_predicate=recursive_pred
258 |             )
259 |         else:
260 |             possible_heads = self._create_possible_heads(node)
261 | 
262 |         # if expansion hooks are available, check if the heads pass
263 |         if self._expansion_hooks_keep:
264 |             possible_heads = [x for x in possible_heads if all([f(x, node) for f in self._expansion_hooks_keep])]
265 | 
266 |         # if rejection hooks are available, check if the heads fail
267 |         if self._expansion_hooks_reject:
268 |             possible_heads = [x for x in possible_heads if not any([f(x, node) for f in self._expansion_hooks_reject])]
269 | 
270 |         if possible_heads:
271 |             init_head_dict = {"ignored": False, "blocked": False, "visited": False, "cache": {}}
272 |             possible_heads = dict([(x, init_head_dict.copy()) for x in possible_heads])
273 | 
274 |             self._hypothesis_space.add_node(
275 |                 node, last_visited_from=None, heads=possible_heads
276 |             )
277 | 
278 |             if recursive:
279 |                 self._recursive_pointers_count += 1
280 |                 pointer_name = (
281 |                     f"{self._recursive_pointer_prefix}{self._recursive_pointers_count}"
282 |                 )
283 |                 self.register_pointer(pointer_name, self._root_node)
284 |                 self._hypothesis_space.nodes[node]["partner"] = pointer_name
285 |                 self._hypothesis_space.nodes[node]["blocked"] = True
286 | 
287 |             return True
288 |         else:
289 |             # print("No possible heads")
290 |             return False
291 | 
292 |     def _insert_edge(self, parent: Body, child: Body) -> None:
293 |         """
294 |         Inserts a directed edge between two clauses/procedures in the hypothesis space
295 |         """
296 |         # if child not in self._hypothesis_space.nodes:
297 |         #     self._insert_node(child)
298 |         self._hypothesis_space.add_edge(parent, child)
299 | 
300 |     def add_to_cache(self, node: typing.Union[Clause, Procedure], key, val) -> None:
301 |         """
302 |         Adds a key-value pair to the cache of a particular clause
303 |         (the body of the clause is the node, the head is one of the attributes)
304 | 
305 |         Arguments:
306 |             node [Clause, Procedure]: clause to associate the key-value pair with
307 |             key: key to use
308 |             val: value to store
309 |         """
310 |         if isinstance(node, Procedure):
311 |             raise NotImplementedError("no support for caching with procedures currently")
312 |         else:
313 |             head, body = node.get_head(), node.get_body()
314 |             if head in self._hypothesis_space.nodes[body]["heads"]:
315 |                 if "cache" not in self._hypothesis_space.nodes[body]["heads"][head]:
316 |                     self._hypothesis_space.nodes[body]["heads"][head]["cache"] = {}
317 |                 self._hypothesis_space.nodes[body]["heads"][head]["cache"][key] = val
318 | 
319 |     def retrieve_from_cache(self, node: typing.Union[Clause, Procedure], key):
320 |         """
321 |         Retrieves the value associated with the key from the cache of the clause
322 | 
323 |         Arguments:
324 |             node: clause of interest
325 |             key: key in the cache dict
326 | 
327 |         """
328 |         if isinstance(node, Procedure):
329 |             raise NotImplementedError("no support for caching with procedures currently")
330 |         else:
331 |             head, body = node.get_head(), node.get_body()
332 |             if head in self._hypothesis_space.nodes[body]["heads"] \
333 |                     and "cache" in self._hypothesis_space.nodes[body]["heads"][head] \
334 |                     and key in self._hypothesis_space.nodes[body]["heads"][head]["cache"]:
335 |                 return self._hypothesis_space.nodes[body]["heads"][head]["cache"][key]
336 |             else:
337 |                 return None
338 | 
339 |     def remove_from_cache(self, node: typing.Union[Clause, Procedure], key):
340 |         """
341 |         Removes the key from the cache associated with the clause
342 | 
343 |         Arguments:
344 |             node: clause of interest
345 |             key: key to delete
346 |         """
347 |         if isinstance(node, Procedure):
348 |             raise NotImplementedError("no support for caching with procedures currently")
349 |         else:
350 |             head, body = node.get_head(), node.get_body()
351 |             if head in self._hypothesis_space.nodes[body]["heads"]:
352 |                 del self._hypothesis_space.nodes[body]["heads"][head]["cache"][key]
353 | 
354 |     def register_pointer(self, name: str, init_value: Body = None):
355 |         """
356 |         Registers a new pointer. If init_value is None, assigns it to the root node
357 |         """
358 |         if name in self._pointers:
359 |             raise Exception(f"pointer {name} already exists!")
360 |         else:
361 |             self._pointers[name] = self._root_node if init_value is None else init_value
362 | 
363 |     def reset_pointer(self, name: str = "main", init_value: Body = None):
364 |         """
365 |         Resets the specified pointer to the root or the specified initial value
366 |         """
367 |         self._pointers[name] = self._root_node if init_value is None else init_value
368 | 
369 |     def _expand_body(self, node: Body) -> typing.Sequence[Body]:
370 |         """
371 |         Expands the provided node with the provided primitives (extensions become its children in a graph)
372 | 
373 |         returns the expanded constructs
374 |         """
375 |         expansions = OrderedSet()
376 | 
377 |         # Add the result of applying a primitive
378 |         for item in range(len(self._primitives)):
379 |             exp = self._primitives[item](node)
380 |             expansions = expansions.union(exp)
381 | 
382 |         # if recursions should be enumerated when FillerPredicate is used to construct the heads
383 |         if self._use_recursions:
384 |             if isinstance(self._head_constructor, FillerPredicate):
385 |                 recursive_cases = self._head_constructor.add_to_body(node)
386 |             elif isinstance(self._head_constructor, Predicate):
387 |                 recursive_cases = self._recursive_expansion(node)
388 |             else:
389 |                 raise Exception(f"Unknown head constructor ({type(self._head_constructor)})")
390 | 
391 |             for r_ind in range(len(recursive_cases)):
392 |                 expansions = expansions.union([recursive_cases[r_ind]])
393 | 
394 |         # add expansions to the hypothesis space
395 |         # if self._insert_node returns False, forget the expansion
396 |         expansions_to_consider = []
397 |         for exp_ind in range(len(expansions)):
398 |             r = self._insert_node(expansions[exp_ind])
399 |             if r:
400 |                 expansions_to_consider.append(expansions[exp_ind])
401 |             # else:
402 |             #     print("Rejected: {}".format(expansions[exp_ind]))
403 | 
404 |         expansions = expansions_to_consider
405 | 
406 |         # add edges
407 |         for ind in range(len(expansions)):
408 |             self._insert_edge(node, expansions[ind])
409 | 
410 |         return expansions
411 | 
412 |     def retrieve_clauses_from_body(
413 |         self, body: Body
414 |     ) -> typing.Sequence[typing.Union[Clause, Procedure]]:
415 |         """
416 |         Returns all possible clauses given the body
417 |         """
418 |         heads = self._hypothesis_space.nodes[body]["heads"]
419 | 
420 |         heads = [
421 |             x
422 |             for x in heads
423 |             if not heads[x]["ignored"]
424 |         ]
425 |         return [Clause(x, body) for x in heads]
426 | 
427 |     def expand(
428 |         self, node: typing.Union[Clause, Procedure]
429 |     ) -> typing.Sequence[typing.Union[Clause, Procedure]]:
430 |         """
431 |         Expands the provided node with the provided primitives (extensions become its children in a graph)
432 | 
433 |         returns the expanded constructs
434 |         if already expanded, returns an empty list
435 |         """
436 |         body = self._extract_body(node)
437 | 
438 |         if (
439 |             "partner" in self._hypothesis_space.nodes[body]
440 |             or self._hypothesis_space.nodes[body].get("blocked", False)
441 |         ):
442 |             # do not expand recursions or blocked nodes
443 |             return []
444 | 
445 |         # check if already expanded
446 |         expansions = list(self._hypothesis_space.successors(body))
447 | 
448 |         if len(expansions) == 0:
449 |             expansions = self._expand_body(body)
450 |         else:
451 |             return []
452 | 
453 |         return reduce(
454 |             lambda x, y: x + y,
455 |             [self.retrieve_clauses_from_body(x) if "partner" not in self._hypothesis_space.nodes[x] else self._get_recursions(x) for x in expansions],
456 |             [],
457 |         )
458 | 
459 |     def block(self, node: typing.Union[Clause, Procedure]) -> None:
460 |         """
461 |         Blocks the expansions of the body (but keeps it in the hypothesis space)
462 |         """
463 |         # TODO: make it possible to block only a specific clause
464 |         clause = (
465 |             node
466 |             if isinstance(node, Clause)
467 |             else [x for x in node.get_clauses() if x.is_recursive()][0]
468 |         )
469 |         body = clause.get_body()
470 | 
471 |         self._hypothesis_space.nodes[body]["blocked"] = True
472 | 
473 |     def remove(
474 |         self, node: typing.Union[Clause, Procedure],
475 |         remove_entire_body: bool = False,
476 |         not_if_other_parents: bool = True
477 |     ) -> None:
478 |         """
479 |         Removes the node from the hypothesis space (and all of its descendants)
480 |         """
481 | 
482 |         clause = (
483 |             node
484 |             if isinstance(node, Clause)
485 |             else [x for x in node.get_clauses() if x.is_recursive()][0]
486 |         )
487 | 
488 |         head = clause.get_head()
489 |         body = clause.get_body()
490 | 
491 |         children = list(self._hypothesis_space.successors(body))  # successors() yields an iterator
492 | 
493 |         if not_if_other_parents:
494 |             # do not remove children that have other parents
495 |             children = [x for x in children if len(list(self._hypothesis_space.predecessors(x))) <= 1]
496 | 
497 |         if remove_entire_body:
498 |             # remove entire body
499 |             self._hypothesis_space.remove_node(body)
500 | 
501 |             for ch_ind in range(len(children)):  # recurse into each child body
502 |                 self.remove(Clause(head, children[ch_ind]), remove_entire_body=remove_entire_body, not_if_other_parents=not_if_other_parents)
503 |         else:
504 |             # remove just the head
505 |             if head in self._hypothesis_space.nodes[body]["heads"]:
506 |                 del self._hypothesis_space.nodes[body]["heads"][head]
507 | 
508 |             if len(self._hypothesis_space.nodes[body]["heads"]) == 0:
509 |                 # if no heads are left, remove the entire node
510 |                 self._hypothesis_space.remove_node(body)
511 | 
512 |                 for ch_ind in range(len(children)):
513 |                     self.remove(Clause(head, children[ch_ind]), remove_entire_body=True)
514 |             else:
515 |                 # remove the same head from the children
516 |                 if len(children) > 0:
517 |                     for ch_ind in range(len(children)):
518 |                         self.remove(Clause(head, children[ch_ind]))
519 | 
520 |     def ignore(self, node: typing.Union[Clause, Procedure]) -> None:
521 |         """
522 |         Sets the node to be ignored.
That is, the node will be expanded, but not taken into account as a candidate
523 |         """
524 |         # TODO: make it possible to ignore the entire body
525 |         clause = (
526 |             node
527 |             if isinstance(node, Clause)
528 |             else [x for x in node.get_clauses() if x.is_recursive()][0]
529 |         )
530 | 
531 |         head = clause.get_head()
532 |         body = clause.get_body()
533 | 
534 |         self._hypothesis_space.nodes[body]["heads"][head]["ignored"] = True
535 | 
536 |     def _get_recursions(self, node: Body) -> typing.Sequence[Recursion]:
537 |         """
538 |         Prepares the valid recursions
539 |         """
540 |         pointer_name = self._hypothesis_space.nodes[node]["partner"]
541 |         init_pointer_value = self._pointers[pointer_name]
542 |         last_pointer_value = None
543 | 
544 |         valid_heads = list(self._hypothesis_space.nodes[node]["heads"].keys())
545 |         recursions = []
546 | 
547 |         # for each valid head
548 |         for h_ind in range(len(valid_heads)):
549 |             c_head: Atom = valid_heads[h_ind]
550 |             recursive_clause = Clause(c_head, node)
551 | 
552 |             frontier = [self._pointers[pointer_name]]
553 | 
554 |             while len(frontier) > 0:
555 |                 focus_node = frontier[0]
556 |                 frontier = frontier[1:]
557 | 
558 |                 # find matching heads
559 |                 focus_node_heads: typing.Sequence[Atom] = list(
560 |                     self._hypothesis_space.nodes[focus_node]["heads"].keys()
561 |                 )
562 |                 focus_node_heads = [
563 |                     x
564 |                     for x in focus_node_heads
565 |                     if x.get_predicate().get_arg_types()
566 |                     == c_head.get_predicate().get_arg_types()
567 |                 ]
568 | 
569 |                 # prepare recursion
570 |                 for bcl_ind in range(len(focus_node_heads)):
571 |                     if isinstance(self._head_constructor, Predicate):
572 |                         recursions.append(
573 |                             Recursion(
574 |                                 [
575 |                                     Clause(focus_node_heads[bcl_ind], focus_node),
576 |                                     recursive_clause,
577 |                                 ]
578 |                             )
579 |                         )
580 |                     else:
581 |                         # if the filler predicate is used to construct heads, make sure the same head predicate is used
582 |                         head_args = focus_node_heads[bcl_ind].get_arguments()
583 |                         recursions.append(
584 |                             Recursion(
585 |                                 [
586 |                                     Clause(
587 |                                         Atom(c_head.get_predicate(), head_args),
588 |                                         focus_node,
589 |                                     ),
590 |                                     recursive_clause,
591 |                                 ]
592 |                             )
593 |                         )
594 | 
595 |                 # extend the frontier - exclude recursive nodes
596 |                 to_add = [
597 |                     x
598 |                     for x in self._hypothesis_space.successors(focus_node)
599 |                     if "partner" not in self._hypothesis_space.nodes[x]
600 |                 ]
601 |                 frontier += to_add
602 |                 last_pointer_value = focus_node
603 | 
604 |             # reset the pointer value for next valid head
605 |             self.reset_pointer(pointer_name, init_pointer_value)
606 | 
607 |         # reset the pointer to the root by default
608 |         # (rather than the last explored clause, so that get_successors keeps working)
609 |         self.reset_pointer(pointer_name)  # , last_pointer_value)
610 | 
611 |         return recursions
612 | 
613 |     def get_current_candidate(
614 |         self, name: str = "main"
615 |     ) -> typing.Sequence[typing.Union[Clause, Procedure]]:
616 |         """
617 |         Get the current program candidate (the current pointer)
618 |         """
619 |         if "partner" in self._hypothesis_space.nodes[self._pointers[name]]:
620 |             # recursion
621 |             return self._get_recursions(self._pointers[name])
622 |         else:
623 |             return self.retrieve_clauses_from_body(self._pointers[name])
624 | 
625 |     def _extract_body(self, clause: typing.Union[Clause, Procedure]) -> Body:
626 |         if isinstance(clause, Clause):
627 |             return clause.get_body()
628 |         elif isinstance(clause, Recursion):
629 |             rec = clause.get_recursive_case()
630 |             if len(rec) == 1:
631 |                 return rec[0].get_body()
632 |             else:
633 |                 raise Exception(
634 |                     f"got more than one recursive case when extracting the body {clause}"
635 |                 )
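        # anything else (e.g. a Disjunction) has no single defining body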
636 |         else:
637 |             raise Exception(
638 |                 f"Don't know how to get a single body from {type(clause)} {clause}"
639 |             )
640 | 
641 |     def move_pointer_to(
642 |         self, node: typing.Union[Clause, Recursion, Body], pointer_name: str = "main"
643 |     ) -> None:
644 |         """
645 |         Moves the pointer to the given node
646 |         """
647 |         if isinstance(node, Body):
648 |             body = node
649 |         else:
650 |             body = self._extract_body(node)
651 | 
652 |         self._hypothesis_space.nodes[body]["last_visited_from"] = self._pointers[
653 |             pointer_name
654 |         ]
655 |         self._pointers[pointer_name] = body
656 | 
657 |     def get_predecessor_of(
658 |         self, node: typing.Union[Clause, Recursion, Body]
659 |     ) -> typing.Union[Clause, Recursion, Body, typing.Sequence[Clause]]:
660 |         """
661 |         Returns the predecessor of the node, i.e., the last position of the pointer before reaching the node
662 |         """
665 |         # TODO: make it possible to get all predecessors, not just the last visited from
666 |         if isinstance(node, Body):
667 |             return self._hypothesis_space.nodes[node]["last_visited_from"]
668 |         else:
669 |             if isinstance(node, Clause):
670 |                 head = node.get_head()
671 |                 body = node.get_body()
672 |             else:
673 |                 rec = node.get_recursive_case()
674 |                 if len(rec) > 1:
675 |                     raise Exception(
676 |                         "do not support recursions with more than 1 recursive case"
677 |                     )
678 |                 else:
679 |                     head = rec[0].get_head()
680 |                     body = rec[0].get_body()
681 | 
682 |             predecessor = self._hypothesis_space.nodes[body]["last_visited_from"]
683 |             if head in self._hypothesis_space.nodes[predecessor]["heads"]:
684 |                 return Clause(head, predecessor)
685 |             else:
686 |                 return self.retrieve_clauses_from_body(predecessor)
687 | 
688 |     def get_successors_of(
689 |         self, node: typing.Union[Clause, Recursion, Body]
690 |     ) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
691 |         """
692 |         Returns all successors of the node
693 |         """
694 |         if isinstance(node, Body):
695 |             return reduce(lambda x, y: x + y, [self.retrieve_clauses_from_body(x) if not self._check_if_recursive(x) else self._get_recursions(x) for x in self._hypothesis_space.successors(node)], [])
696 |         else:
697 |             body = self._extract_body(node)
698 |             return reduce(lambda x, y: x + y, [self.retrieve_clauses_from_body(x) if not self._check_if_recursive(x) else self._get_recursions(x) for x in self._hypothesis_space.successors(body)], [])
699 | 
700 |     def remove_all_edges(self):
701 |         self._hypothesis_space = create_empty_copy(self._hypothesis_space, with_data=True)
--------------------------------------------------------------------------------
/loreleai/learning/language_filtering.py:
--------------------------------------------------------------------------------
1 | from functools import reduce
2 | 
3 | from loreleai.language.lp import Body, Atom, Not, Predicate
4 | from .utilities import are_variables_connected
5 | 
6 | """
7 | Functions used to prune the search space
8 | """
9 | 
10 | 
11 | def has_singleton_vars(head: Atom, body: Body) -> bool:
12 |     """
13 |     Returns True if the clause has a singleton variable (one that appears only once)
14 |     """
15 |     if len(body) == 0:
16 |         return False
17 | 
18 |     vars = {}
19 |     head_vars = head.get_variables()
20 |     for ind in range(len(head_vars)):
21 |         if head_vars[ind] not in vars:
22 |             vars[head_vars[ind]] = head_vars.count(head_vars[ind])
23 | 
24 |     bvars = body.get_variables()
25 |     body_vars_flat = reduce(lambda x, y: x + y, [x.get_variables() for x in body.get_literals()], [])
26 |     for ind in range(len(bvars)):
27 |         if bvars[ind] in vars:
28 |             vars[bvars[ind]] += body_vars_flat.count(bvars[ind])
29 |         else:
30 |             vars[bvars[ind]] = body_vars_flat.count(bvars[ind])
31 | 
32 |     return any(v == 1 for v in vars.values())
33 | 
34 | 
35 | def max_var(head: Atom, body: Body, max_count: int) -> bool:
36 |     """
37 |     Return True if there are no more than max_count variables in the clause
38 |     """
39 |     vars = body.get_variables()
40 |     for v in head.get_variables():
41 |         if v not in vars:
42 |             vars += [v]
43 |     return len(vars) <= max_count
44 | 
45 | 
46 | def connected_body(head: Atom, body: Body) -> bool:
47 |     """
48 |     Returns True if the variables in the body cannot be partitioned into two non-overlapping sets
49 |     """
50 |     if len(body) == 0:
51 |         return True
52 |     return are_variables_connected([x.get_atom() if isinstance(x, Not) else x for x in body.get_literals()])
53 | 
54 | 
55 | def connected_clause(head: Atom, body: Body) -> bool:
56 |     """
57 |     Returns True if the variables in the clause cannot be partitioned into two non-overlapping sets
58 |     """
59 |     if len(body) == 0:
60 |         return True
61 |     return are_variables_connected([x.get_atom() if isinstance(x, Not) else x for x in body.get_literals() + [head]])
62 | 
63 | 
64 | def negation_at_the_end(head: Atom, body: Body) -> bool:
65 |     """
66 |     Returns True if negations appear only after all positive literals
67 |     """
68 |     pos_location = -1
69 |     neg_location = -1
70 |     lits = body.get_literals()
71 | 
72 |     for ind in range(len(lits)):
73 |         if isinstance(lits[ind], Atom):
74 |             pos_location = ind
75 |         elif neg_location < 0:
76 |             neg_location = ind
77 | 
78 |     return not (-1 < neg_location < pos_location)
79 | 
80 | 
81 | def max_pred_occurrences(head: Atom, body: Body, pred: Predicate, max_occurrence: int) -> bool:
82 |     """
83 |     Returns True if the predicate pred does not appear more than max_occurrence times in the clause
84 |     """
85 |     preds = [x for x in body.get_literals() if x.get_predicate() == pred]
86 | 
87 |     return len(preds) <= max_occurrence
88 | 
89 | 
90 | def has_duplicated_literal(head: Atom, body: Body) -> bool:
91 |     """
92 |     Returns True if there are duplicated literals in the body
93 |     """
94 |     return len(body) != len(set(body.get_literals()))
95 | 
--------------------------------------------------------------------------------
/loreleai/learning/language_manipulation.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from itertools import product, combinations_with_replacement
3 | 
4 | from loreleai.language.lp import (
5 |     Clause,
6 |     Procedure,
7 |     Predicate,
8 |     Variable,
9 |     Type,
10 |     c_var,
11 |     Disjunction,
12 |     Recursion,
13 |     Not,
14 |     Body,
15 |     Constant,
16 |     Atom
17 | )
18 | 
19 | INPUT_ARG = 1
20 | OUTPUT_ARG = 2
21 | CONSTANT_ARG = 3
22 | 
23 | 
24 | def new_variable(
25 |     existing_variables: typing.Set[Variable], type: Type = None
26 | ) -> Variable:
27 |     existing_variable_names = {x.get_name() for x in existing_variables}
28 |     potential_names = [
29 |         chr(x)
30 |         for x in range(ord("A"), ord("Z") + 1)
31 |         if chr(x) not in existing_variable_names
32 |     ]
33 |     if not potential_names:
34 |         # all 26 single-letter names are taken: fall back to two-letter names
35 |         potential_names = [
36 |             f"{chr(x)}{chr(y)}"
37 |             for x in range(ord("A"), ord("Z") + 1)
38 |             for y in range(ord("A"), ord("Z") + 1)
39 |             if f"{chr(x)}{chr(y)}" not in existing_variable_names
40 |         ]
41 | 
42 |     return c_var(potential_names[0], type)
43 | 
44 | 
45 | 
46 | """
47 | Plain extension adds literals to a clause/body with every possible
combination of variables 48 | (no language bias, all possible combinations) 49 | """ 50 | 51 | def _plain_extend_clause( 52 | clause: typing.Union[Clause, Body], 53 | predicate: Predicate, 54 | connected_clause: bool = True 55 | ) -> typing.Sequence[typing.Union[Clause, Body]]: 56 | """ 57 | Extends the clause with the predicate in every possible way (no bias) 58 | 59 | Arguments: 60 | clause: a clause to be extended 61 | predicate: a predicate to add to the clause 62 | """ 63 | if isinstance(clause, Body) and len(clause) == 0: 64 | head_variables = [chr(x) for x in range(ord("A"), ord("Z"))][: predicate.get_arity()] 65 | possible_heads = [ 66 | Body(predicate(*list(x))) 67 | for x in combinations_with_replacement(head_variables, predicate.get_arity()) 68 | ] 69 | 70 | return possible_heads 71 | 72 | clause_variables: typing.Sequence[Variable] = clause.get_variables() 73 | used_variables = {x for x in clause_variables} 74 | pred_argument_types: typing.Sequence[Type] = predicate.get_arg_types() 75 | 76 | argument_matches = {} 77 | new_variables = set() 78 | 79 | # create new variable for each argument of a predicate 80 | for arg_ind in range(len(pred_argument_types)): 81 | new_var = new_variable(used_variables, pred_argument_types[arg_ind]) 82 | argument_matches[arg_ind] = [new_var] 83 | used_variables.add(new_var) 84 | new_variables.add(new_var) 85 | 86 | # check for potential match with other variables 87 | for clv_ind in range(len(clause_variables)): 88 | for arg_ind in range(len(pred_argument_types)): 89 | if clause_variables[clv_ind].get_type() == pred_argument_types[arg_ind]: 90 | argument_matches[arg_ind].append(clause_variables[clv_ind]) 91 | 92 | # do cross product of matches 93 | base_sets = [argument_matches[x] for x in range(len(pred_argument_types))] 94 | candidates: typing.List[typing.Union[Clause, Body]] = [] 95 | 96 | for arg_combo in product(*base_sets): 97 | new_clause = None 98 | if connected_clause and not all( 99 | [True if x in new_variables else False for x in arg_combo] 100 | ): 101 | # check that the new literal is not disconnected from the rest of the clause 102 | new_clause = clause + predicate(*list(arg_combo)) 103 | elif not connected_clause: 104 | new_clause = clause + predicate(*list(arg_combo)) 105 | 106 | if new_clause is not None: 107 | candidates.append(new_clause) 108 | 109 | return candidates 110 | 111 | 112 | def _plain_extend_negation_clause( 113 | clause: typing.Union[Clause, Body], predicate: Predicate 114 | ) -> typing.Sequence[typing.Union[Clause, Body]]: 115 | """ 116 | Extends a clause with the negation of a predicate (no new variables allowed) 117 | """ 118 | if isinstance(clause, Body): 119 | suitable_vars = clause.get_variables() 120 | else: 121 | suitable_vars = clause.get_body_variables() 122 | pred_argument_types: typing.Sequence[Type] = predicate.get_arg_types() 123 | argument_matches = {} 124 | 125 | # check for potential match with other variables 126 | for clv_ind in range(len(suitable_vars)): 127 | for arg_ind in range(len(pred_argument_types)): 128 | if suitable_vars[clv_ind].get_type() == pred_argument_types[arg_ind]: 129 | if arg_ind not in argument_matches: 130 | argument_matches[arg_ind] = [] 131 | argument_matches[arg_ind].append(suitable_vars[clv_ind]) 132 | 133 | base_sets = [argument_matches[x] for x in range(len(pred_argument_types))] 134 | candidates: typing.List[typing.Union[Clause, Body]] = [] 135 | 136 | for arg_combo in product(*base_sets): 137 | new_clause = clause + Not(predicate(*list(arg_combo))) 138 | 
candidates.append(new_clause)
139 | 
140 |     return candidates
141 | 
142 | def _instantiate_var_clause(clause: Clause, constant: Constant):
143 |     """
144 |     Returns all clauses generated by substituting a `Variable` in `clause` with `constant`.
145 |     """
146 |     if isinstance(clause, Body):
147 |         suitable_vars = clause.get_variables()
148 |     else:
149 |         suitable_vars = clause.get_body_variables()
150 | 
151 |     candidates = []
152 |     for var in suitable_vars:
153 |         candidates.append(clause.substitute({var: constant}))
154 |     # if len(candidates) > 0:
155 |     #     print("Type of each is {}", type(candidates[0]))
156 |     return list(set(candidates))
157 | 
158 | 
159 | def variable_instantiation(
160 |     clause: typing.Union[Clause, Body, Procedure],
161 |     constant: Constant) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
162 |     """
163 |     Extends a clause by instantiation, replacing all occurrences
164 |     of a variable with a constant
165 |     """
166 |     if isinstance(clause, (Clause, Body)):
167 |         return _instantiate_var_clause(clause, constant)
168 |     else:
169 |         clauses = clause.get_clauses()
170 | 
171 |         # extend each clause individually
172 |         extensions = []
173 |         for cl_ind in range(len(clauses)):
174 |             clause_extensions = _instantiate_var_clause(clauses[cl_ind], constant)
175 |             for ext_cl_ind in range(len(clause_extensions)):
176 |                 cls = [
177 |                     clauses[x] if x != cl_ind else clause_extensions[ext_cl_ind]
178 |                     for x in range(len(clauses))
179 |                 ]
180 | 
181 |                 if isinstance(clause, Disjunction):
182 |                     extensions.append(Disjunction(cls))
183 |                 else:
184 |                     extensions.append(Recursion(cls))
186 |         return extensions
187 | 
188 | 
189 | def plain_extension(
190 |     clause: typing.Union[Clause, Body, Procedure],
191 |     predicate: Predicate,
192 |     connected_clauses: bool = True,
193 |     negated: bool = False,
194 | ) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
195 |     """
196 |     Extends a clause or a procedure without any bias. Only checks for variable type match.
197 |     Adds the predicate to the clause/procedure
198 |     """
199 |     if isinstance(clause, (Clause, Body)):
200 |         if negated:
201 |             return _plain_extend_negation_clause(clause, predicate)
202 |         else:
203 |             return _plain_extend_clause(
204 |                 clause, predicate, connected_clause=connected_clauses
205 |             )
206 |     else:
207 |         clauses = clause.get_clauses()
208 | 
209 |         # extend each clause individually
210 |         extensions = []
211 |         for cl_ind in range(len(clauses)):
212 |             clause_extensions = (
213 |                 _plain_extend_clause(
214 |                     clauses[cl_ind], predicate, connected_clause=connected_clauses
215 |                 )
216 |                 if not negated
217 |                 else _plain_extend_negation_clause(clauses[cl_ind], predicate)
218 |             )
219 |             for ext_cl_ind in range(len(clause_extensions)):
220 |                 cls = [
221 |                     clauses[x] if x != cl_ind else clause_extensions[ext_cl_ind]
222 |                     for x in range(len(clauses))
223 |                 ]
224 | 
225 |                 if isinstance(clause, Disjunction):
226 |                     extensions.append(Disjunction(cls))
227 |                 else:
228 |                     extensions.append(Recursion(cls))
229 | 
230 |         return extensions
231 | 
232 | def aleph_extension(
233 |     clause: typing.Union[Clause, Body, Procedure],
234 |     predicate: Predicate,
235 |     allowed_positions_dict,
236 |     constants: typing.Sequence[Constant],
237 |     allowed_reflexivity=()
238 | ) -> typing.Sequence[typing.Union[Clause, Body, Procedure]]:
239 |     """
240 |     Used to generate new clauses in the Aleph learner.
241 |     This extension is a mixture of plain clause extension
242 |     and smart instantiation of variables.
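    For example, a plain extension might add a literal next(X, Y); variable
    instantiation can then turn it into next(X, c) for a constant c taken
    from the bottom clause (next and c are illustrative names).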
243 | 
244 |     First, all plain extensions are generated. Then, the least Herbrand
245 |     model is constructed given the Task's background information. This model
246 |     contains all grounded facts that can be derived from the background information.
247 |     From this model we can derive position constraints, which constrain where
248 |     constants are allowed to appear in certain predicates, if at all.
249 | 
250 |     Next, variable instantiation is applied to all extended clauses. If a resulting
251 |     clause violates the position constraints, it is immediately pruned.
252 |     """
253 | 
254 |     # Plain extension
255 |     extensions = plain_extension(clause, predicate)
256 |     subst_extensions = set()
257 | 
258 |     # Variable instantiation
259 |     for ext in extensions:
260 |         for const in constants:
261 |             subst_extensions = subst_extensions.union(set(variable_instantiation(ext, const)))
262 |     # print(subst_extensions)
263 | 
264 |     # Prune clauses that violate position constraints
265 |     pruned_clauses1 = set()
266 |     for cl in subst_extensions:
267 |         if _valid_positions(cl, allowed_positions_dict, allowed_reflexivity=allowed_reflexivity):
268 |             pruned_clauses1.add(cl)
269 | 
270 |     # Prune clauses that have a fully grounded atom
271 |     pruned_clauses2 = set()
272 |     for cl in pruned_clauses1:
273 |         if _not_grounded(cl):
274 |             pruned_clauses2.add(cl)
275 | 
276 |     # print("Pruning {} -> {} -> {} -> {}".format(len(extensions), len(subst_extensions), len(pruned_clauses1), len(pruned_clauses2)))
277 |     return list(set(extensions).union(pruned_clauses2))
278 | 
279 | def _valid_positions(cl: Clause, allowed_positions_dict, allowed_reflexivity=()):
280 |     """
281 |     Returns True iff the clause `cl` respects the allowed
282 |     positions for constants as given by `allowed_positions_dict`,
283 |     and is not reflexive (e.g. next(X,X)) when that is disallowed
284 |     """
285 |     for atom in cl.get_literals():
286 |         pred = atom.get_predicate()
287 |         args = atom.get_arguments()
288 |         for i in range(len(args)):
289 |             arg = args[i]
290 |             # Constants must appear at the right places
291 |             if isinstance(arg, Constant):
292 |                 if i not in allowed_positions_dict[arg][pred]:
293 |                     return False
294 | 
295 |         # If all arguments are equal, this must be explicitly allowed
296 |         if len(args) > 0 and all(args[i] == args[0] for i in range(len(args))) and pred not in allowed_reflexivity:
297 |             return False
298 |     return True
299 | 
300 | def _not_grounded(cl):
301 |     for atom in cl.get_literals():
302 |         if len(atom.get_variables()) == 0:
303 |             return False
304 |     return True
305 | 
306 | 
307 | 
308 | 
309 | 
310 | 
311 | class BottomClauseExpansion:
312 | 
313 |     def __init__(self, clause: Clause, only_connected_clauses: bool = False):
314 |         self._clause: Clause = clause
315 |         self._variable_literal_dependency: typing.Dict[Variable, typing.List[typing.Union[Atom, Not]]] = {}
316 |         self._literal_order: typing.Dict[typing.Union[Atom, Not], int] = {}
317 |         self._only_connected_clauses = only_connected_clauses
318 | 
319 |         self._to_dependency_structure()
320 | 
321 |     def _to_dependency_structure(self):
322 |         body_lits = self._clause.get_literals()
323 | 
324 |         for lit_ind in range(len(body_lits)):
325 |             # store order of the literal
326 |             self._literal_order[body_lits[lit_ind]] = len(self._literal_order)
327 | 
328 |             # compute variable dependency
329 |             vars = body_lits[lit_ind].get_variables()
330 |             for v_ind in range(len(vars)):
331 |                 cv = vars[v_ind]
332 | 
333 |                 if cv not in self._variable_literal_dependency:
334 |                     self._variable_literal_dependency[cv] = []
335 | 
336 |                 self._variable_literal_dependency[cv].append(body_lits[lit_ind])
337 | 
338 |     def _expand_clause(self, clause: typing.Union[Clause, Body]) -> typing.Sequence[typing.Union[Clause, Body]]:
339 |         existing_vars = clause.get_variables()
340 |         used_literals = {x for x in clause.get_literals()}
341 |         last_literal_id = self._literal_order.get(clause.get_literals()[-1], -1) if len(clause) > 0 else -1
342 | 
343 |         expansions = []
344 | 
345 |         if self._only_connected_clauses:
346 |             for v_ind in range(len(existing_vars)):
347 |                 v = existing_vars[v_ind]
348 |                 lits_to_add = [
349 |                     x
350 |                     for x in self._variable_literal_dependency.get(v, [])
351 |                     if x not in used_literals and self._literal_order[x] > last_literal_id
352 |                 ]
353 | 
354 |                 for l_ind in range(len(lits_to_add)):
355 |                     expansions.append(clause + lits_to_add[l_ind])
356 |         else:
357 |             expansions = [
358 |                 clause + l
359 |                 for l in self._clause.get_literals()
360 |                 if l not in used_literals and self._literal_order[l] > last_literal_id]
361 | 
362 |         return expansions
363 | 
364 |     def expand(self, clause: typing.Union[Clause, Body, Procedure]) -> typing.Sequence[typing.Union[Body, Clause, Procedure]]:
365 |         """
366 |         Expands the clause/body/procedure by adding literals from the bottom clause
370 |         """
371 |         if isinstance(clause, (Body, Clause)):
372 |             return self._expand_clause(clause)
373 |         else:
374 |             clauses = clause.get_clauses()
375 | 
376 |             # extend each clause individually
377 |             extensions = []
378 |             for cl_ind in range(len(clauses)):
379 |                 clause_extensions = self._expand_clause(clauses[cl_ind])
380 | 
381 |                 for ext_cl_ind in range(len(clause_extensions)):
382 |                     cls = [
383 |                         clauses[x] if x != cl_ind else clause_extensions[ext_cl_ind]
384 |                         for x in range(len(clauses))
385 |                     ]
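                    # rebuild the procedure with the expanded clause swapped in,
                    # preserving the original type (Disjunction or Recursion)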
386 | 387 | if isinstance(clause, Disjunction): 388 | extensions.append(Disjunction(cls)) 389 | else: 390 | extensions.append(Recursion(cls)) 391 | 392 | return extensions 393 | -------------------------------------------------------------------------------- /loreleai/learning/learners/__init__.py: -------------------------------------------------------------------------------- 1 | from .breadth_first_learner import SimpleBreadthFirstLearner 2 | from .aleph import Aleph 3 | 4 | -------------------------------------------------------------------------------- /loreleai/learning/learners/aleph.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | from typing import Sequence, Dict, Tuple 4 | from orderedset import OrderedSet 5 | import datetime 6 | 7 | from loreleai.learning.abstract_learners import Learner, TemplateLearner, LearnResult 8 | from loreleai.learning import Task, Knowledge, HypothesisSpace, TopDownHypothesisSpace 9 | from loreleai.language.lp import ( 10 | Clause, 11 | Constant, 12 | c_type, 13 | Variable, 14 | Not, 15 | Atom, 16 | Procedure, 17 | Body 18 | ) 19 | from itertools import product, combinations_with_replacement 20 | from collections import Counter 21 | from loreleai.reasoning.lp import LPSolver 22 | from loreleai.learning.language_manipulation import plain_extension,aleph_extension 23 | from loreleai.learning.language_filtering import ( 24 | has_singleton_vars, 25 | has_duplicated_literal, 26 | ) 27 | from loreleai.learning.eval_functions import EvalFunction, Coverage 28 | from loreleai.learning.language_manipulation import variable_instantiation 29 | from loreleai.learning.utilities import ( 30 | compute_bottom_clause, 31 | find_allowed_positions, 32 | find_allowed_reflexivity, 33 | find_frequent_constants) 34 | 35 | 36 | class Aleph(TemplateLearner): 37 | """ 38 | Implements the Aleph learner in loreleai. See https://www.cs.ox.ac.uk/activities/programinduction/Aleph/aleph.html#SEC45. 39 | Aleph efficiently searches the hypothesis space by bounding the search from above (X :- true) and below (using the bottom clause), 40 | and by using mode declarations for predicates. It iteratively adds new clauses that maximize the evalfn. Searching for a new clause 41 | is done using a branch-and-bound algorithm, where clauses that are guaranteed to not lead to improvements are immediately pruned. 42 | 43 | Aleph currently only supports eval functions that can define an upper bound on the quality of a clause, such as Coverage 44 | and Compression. 45 | """ 46 | 47 | def __init__( 48 | self, 49 | solver: LPSolver, 50 | eval_fn: EvalFunction, 51 | max_body_literals=5, 52 | do_print=False, 53 | ): 54 | super().__init__(solver, eval_fn, do_print) 55 | self._max_body_literals = max_body_literals 56 | 57 | def learn( 58 | self, examples: Task, knowledge: Knowledge, hypothesis_space: HypothesisSpace, 59 | initial_clause: typing.Union[Body,Clause] = None, minimum_freq: int = 0 60 | ): 61 | """ 62 | To find a hypothesis, Aleph uses the following set covering approach: 63 | 1. Select a positive example to be generalised. If none exists, stop; otherwise proceed to the next step. 64 | 2. Construct the most specific clause (the bottom clause) (Muggleton, 1995) that entails the selected example 65 | and that is consistent with the mode declarations. 66 | 3. Search for a clause more general than the bottom clause and that has the best score. 67 | 4. 
Add the clause to the current hypothesis and remove all the examples made redundant by it.
68 |         Return to step 1.
69 |         (Description from Cropper and Dumancic)
70 |         """
71 | 
72 |         # Variables for learning statistics
73 |         start_time = datetime.datetime.now()
74 |         i = 0
75 |         stop = False
76 |         self._learnresult = LearnResult()  # Reset in case the learner is reused
77 |         self._prolog_queries = 0
78 |         self._intermediate_coverage = []
79 |         self._eval_fn._clauses_evaluated = 0
80 | 
81 |         # Assert all BK into engines
82 |         self._solver.retract_all()
83 |         self._assert_knowledge(knowledge)
84 | 
85 |         # Start with all examples
86 |         examples_to_use = examples
87 |         pos, _ = examples_to_use.get_examples()
88 | 
89 |         # List of clauses we're learning
90 |         prog = []
91 | 
92 |         # parameters for aleph_extension()
93 |         allowed_positions = find_allowed_positions(knowledge)
94 |         allowed_reflexivity = find_allowed_reflexivity(knowledge)
95 |         if minimum_freq > 0:
96 |             allowed_constants = find_frequent_constants(knowledge, minimum_freq)
97 |         else:
98 |             allowed_constants = None
99 | 
100 |         # Create HypothesisSpace: primitives will be different in each iteration
101 |         # (based on the chosen positive example)
102 |         hs = TopDownHypothesisSpace(
103 |             primitives=[],
104 |             head_constructor=list(pos)[0].get_predicate(),
105 |             expansion_hooks_reject=[
106 |                 lambda x, y: has_duplicated_literal(x, y),
107 |             ],
108 |             initial_clause=initial_clause
109 |         )
110 | 
111 |         while len(pos) > 0 and not stop:
112 |             i += 1
113 | 
114 |             # Pick an example from pos
115 |             pos_ex = Clause(list(pos)[0], [])
116 |             bk = knowledge.as_clauses()
117 |             bottom = compute_bottom_clause(bk, pos_ex)
118 |             if self._print:
119 |                 print("Next iteration: generalizing example {}".format(str(pos_ex)))
120 |                 # print("Bottom clause: " + str(bottom))
121 | 
122 |             # Predicates can only be picked from the body of the bottom clause
123 |             body_predicates = list(
124 |                 set(map(
125 |                     lambda l: l.get_predicate(),
126 |                     bottom.get_body().get_literals()))
127 |             )
128 | 
129 |             # Constants can only be picked from the literals in the bottom clause,
130 |             # and from constants that are frequent enough in the BK (if applicable)
131 |             if allowed_constants is None:
132 |                 allowed = lambda l: isinstance(l, Constant) or isinstance(l, int)
133 |             else:
134 |                 allowed = lambda l: (isinstance(l, Constant) and l in allowed_constants) or isinstance(l, int)
135 | 
136 |             constants = list(set(list(filter(
137 |                 allowed,
138 |                 bottom.get_body().get_arguments(),))))
139 |             if self._print:
140 |                 print("Constants in bottom clause: {}".format(constants))
141 |                 print("Predicates in bottom clause: {}".format(body_predicates))
142 | 
143 |             # IMPORTANT: bind the VALUES of pred, constants, etc. as default arguments;
144 |             # Python closures are late-binding, so every lambda would otherwise see only the last pred
145 |             extensions = [
146 |                 lambda x, a=pred, b=allowed_positions, c=constants, d=allowed_reflexivity: aleph_extension(x, a, b, c, d) for pred in body_predicates
147 |             ]
148 | 
149 |             # Update hypothesis space for this iteration
150 |             hs._primitives = extensions
151 |             hs.remove_all_edges()
152 | 
153 |             # Learn 1 clause and add it to the program
154 |             cl = self._learn_one_clause(examples_to_use, hs)
155 |             prog.append(cl)
156 |             if self._print:
157 |                 print("- New clause: " + str(cl))
158 | 
159 |             # update covered positive examples
160 |             covered = self._execute_program(cl)
161 |             if self._print:
162 |                 print(
163 |                     "Clause covers {} pos examples: {}".format(
164 |                         len(pos.intersection(covered)), pos.intersection(covered)
165 |                     )
166 |                 )
167 | 
168 |             # Find the intermediate quality of the program at this point and add it to the
learn result (don't count these as Prolog queries)
169 |             c = set()
170 |             for cl in prog:
171 |                 c = c.union(self._execute_program(cl, count_as_query=False))
172 |             pos_covered = len(c.intersection(examples._positive_examples))
173 |             neg_covered = len(c.intersection(examples._negative_examples))
174 |             self._intermediate_coverage.append((pos_covered, neg_covered))
175 | 
176 |             # Remove covered examples and start next iteration
177 |             pos, neg = examples_to_use.get_examples()
178 |             pos = pos.difference(covered)
179 |             examples_to_use = Task(pos, neg)
180 | 
181 |             if self._print:
182 |                 print("Finished iteration {}".format(i))
183 |                 # print("Current program: {}".format(str(prog)))
184 | 
185 |         # Wrap results into learnresult and return
186 |         self._learnresult["learner"] = "Aleph"
187 |         self._learnresult["total_time"] = (datetime.datetime.now() - start_time).total_seconds()
188 |         self._learnresult["final_program"] = prog
189 |         self._learnresult["num_iterations"] = i
190 |         self._learnresult["evalfn_evaluations"] = self._eval_fn._clauses_evaluated
191 |         self._learnresult["prolog_queries"] = self._prolog_queries
192 |         self._learnresult["intermediate_coverage"] = self._intermediate_coverage
193 | 
194 |         return self._learnresult
195 | 
196 |     def initialise_pool(self):
197 |         self._candidate_pool = OrderedSet()
198 | 
199 |     def put_into_pool(
200 |         self, candidates: Tuple[typing.Union[Clause, Procedure, typing.Sequence], float]
201 |     ) -> None:
202 |         if isinstance(candidates, tuple):
203 |             self._candidate_pool.add(candidates)
204 |         else:
205 |             self._candidate_pool |= candidates
206 | 
207 |     def prune_pool(self, minValue):
208 |         """
209 |         Removes from the pool all clauses whose upper bound on the value is < minValue
210 |         """
211 |         self._candidate_pool = OrderedSet(
212 |             [t for t in self._candidate_pool if t[2] >= minValue]
213 |         )
214 | 
215 |     def get_from_pool(self) -> Clause:
216 |         return self._candidate_pool.pop(0)
217 | 
218 |     def stop_inner_search(
219 |         self, eval: typing.Union[int, float], examples: Task, clause: Clause
220 |     ) -> bool:
221 |         raise NotImplementedError()
222 | 
223 |     def process_expansions(
224 |         self,
225 |         examples: Task,
226 |         exps: typing.Sequence[Clause],
227 |         hypothesis_space: TopDownHypothesisSpace,
228 |     ) -> typing.Sequence[Clause]:
229 |         # eliminate every clause with more body literals than allowed
230 |         exps = [cl for cl in exps if len(cl) <= self._max_body_literals]
231 | 
232 |         # check if every clause has solutions
233 |         exps = [
234 |             (cl, self._solver.has_solution(*cl.get_body().get_literals()))
235 |             for cl in exps
236 |         ]
237 |         new_exps = []
238 | 
239 |         for ind in range(len(exps)):
240 |             if exps[ind][1]:
241 |                 # keep it if it has solutions
242 |                 new_exps.append(exps[ind][0])
243 |                 # print(f"Not removed: {exps[ind][0]}")
244 |             else:
245 |                 # remove it from the hypothesis space if it does not
246 |                 hypothesis_space.remove(exps[ind][0])
247 |                 # print(f"Removed: {exps[ind][0]}")
248 | 
249 |         return new_exps
250 | 
251 |     def _execute_program(self, clause: Clause, count_as_query: bool = True) -> typing.Sequence[Atom]:
252 |         """
253 |         Evaluates a clause using the Prolog engine and background knowledge
254 | 
255 |         Returns a set of atoms that the clause covers
256 |         """
257 |         if len(clause.get_body().get_literals()) == 0:
258 |             # Covers all possible examples because of the trivial hypothesis
259 |             return None
260 |         else:
261 |             head_predicate = clause.get_head().get_predicate()
262 |             head_args = clause.get_head_arguments()
263 |             # print("{}({})".format(head_predicate, *head_args))
264 | 
265 |             sols = self._solver.query(*clause.get_body().get_literals())
266 |             self._prolog_queries += 1 if count_as_query else 0
267 | 
268 |             # Build a solution by substituting Variables with their found value
269 |             # and copying constants without change
270 |             sols = [head_predicate(*[s[v] if isinstance(v, Variable) else v for v in head_args]) for s in sols]
271 | 
272 |             return sols
273 | 
274 |     def _learn_one_clause(
275 |         self, examples: Task, hypothesis_space: TopDownHypothesisSpace
276 |     ) -> Clause:
277 |         """
278 |         Learns a single clause to add to the theory.
279 |         Algorithm from https://www.cs.ox.ac.uk/activities/programinduction/Aleph/aleph.html#SEC45
280 |         """
281 |         # reset the search space
282 |         hypothesis_space.reset_pointer()
283 | 
284 |         # empty the pool just in case
285 |         self.initialise_pool()
286 | 
287 |         # Add the first clauses into the pool (active)
288 |         initial_clauses = hypothesis_space.get_current_candidate()
289 |         # evaluate each clause only once; evaluate() returns (value, upper bound)
290 |         evals = {cl: self.evaluate(examples, cl, hypothesis_space) for cl in initial_clauses}
291 |         self.put_into_pool(
292 |             [(cl, evals[cl][0], evals[cl][1]) for cl in initial_clauses]
293 |         )
295 |         # print(self._candidate_pool)
296 |         currentbest = None
297 |         currentbestvalue = -99999
298 | 
299 |         i = 0
300 | 
301 |         while len(self._candidate_pool) > 0:
302 |             # Possible optimisation: pick according to the eval function (e.g. shorter clauses first when using compression)
303 |             k = self.get_from_pool()
304 |             if self._print:
305 |                 print("Expanding clause {}".format(k[0]))
306 |             # Generate children of k
307 |             new_clauses = hypothesis_space.expand(k[0])
308 | 
309 |             # Remove clauses that are too long...
310 |             new_clauses = self.process_expansions(
311 |                 examples, new_clauses, hypothesis_space
312 |             )
313 |             # Compute costs for these children, evaluating each clause only once
314 |             evals = {cl: self.evaluate(examples, cl, hypothesis_space) for cl in new_clauses}
315 |             value = {cl: evals[cl][0] for cl in new_clauses}
316 |             upperbound_value = {cl: evals[cl][1] for cl in new_clauses}
317 | 
318 | 
319 |             for c in new_clauses:
320 |                 # If the upper bound is too low, don't bother expanding
321 |                 if upperbound_value[c] <= currentbestvalue and c != currentbest:
322 |                     hypothesis_space.remove(c)
323 |                 else:
324 |                     if value[c] > currentbestvalue:
325 |                         currentbestvalue = value[c]
326 |                         currentbest = c
327 |                         len_before = len(self._candidate_pool)
328 |                         self.prune_pool(value[c])
329 |                         len_after = len(self._candidate_pool)
330 | 
331 |                         if self._print:
332 |                             print("Found new best: {}: {} {}".format(c, self._eval_fn.name(), value[c]))
333 |                             print("Pruning to upperbound {} >= {}: {} of {} clauses removed".format(self._eval_fn.name(), value[c], (len_before - len_after), len_before))
334 | 
335 |                     self.put_into_pool((c, value[c], upperbound_value[c]))
336 |                     if self._print:
337 |                         print("Put {} into pool, contains {} clauses".format(str(c), len(self._candidate_pool)))
338 | 
339 |             i += 1
340 | 
341 |         if self._print:
342 |             print("New clause: {} with score {}".format(currentbest, currentbestvalue))
343 |         return currentbest
--------------------------------------------------------------------------------
/loreleai/learning/learners/breadth_first_learner.py:
--------------------------------------------------------------------------------
1 | import typing
2 | from abc import ABC, abstractmethod
3 | 
4 | from loreleai.learning.abstract_learners import TemplateLearner
5 | from loreleai.reasoning.lp import LPSolver
6 | from loreleai.language.lp import Clause, Atom, Procedure
7 | from loreleai.learning.task import Task, Knowledge
8 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace
/loreleai/learning/learners/breadth_first_learner.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | 4 | from loreleai.learning.abstract_learners import TemplateLearner 5 | from loreleai.reasoning.lp import LPSolver 6 | from loreleai.language.lp import Clause,Atom,Procedure 7 | from loreleai.learning.task import Task, Knowledge 8 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace 9 | from loreleai.learning.eval_functions import EvalFunction 10 | 11 | from orderedset import OrderedSet 12 | 13 | 14 | 15 | """ 16 | A simple breadth-first top-down learner: it extends the template learner by searching in a breadth-first fashion 17 | 18 | It implements the abstract functions in the following way: 19 | - initialise_pool: creates an empty OrderedSet 20 | - put_into_pool: adds to the ordered set 21 | - get_from_pool: returns the first element of the ordered set 22 | - evaluate: returns the number of covered positive examples and 0 if any negative example is covered 23 | - stop inner search: stops if the provided score of a clause is bigger than zero 24 | - process expansions: removes from the hypothesis space all clauses that have no solutions 25 | 26 | The learner does not handle recursion correctly! 27 | """ 28 | class SimpleBreadthFirstLearner(TemplateLearner): 29 | 30 | def __init__(self, solver_instance: LPSolver, eval_fn: EvalFunction, max_body_literals=4,do_print=False): 31 | super().__init__(solver_instance,eval_fn,do_print=do_print) 32 | self._max_body_literals = max_body_literals 33 | 34 | def initialise_pool(self): 35 | self._candidate_pool = OrderedSet() 36 | 37 | def put_into_pool(self, candidates: typing.Union[Clause, Procedure, typing.Sequence]) -> None: 38 | if isinstance(candidates, Clause): 39 | self._candidate_pool.add(candidates) 40 | else: 41 | self._candidate_pool |= candidates 42 | 43 | def get_from_pool(self) -> Clause: 44 | return self._candidate_pool.pop(0) 45 | 46 | 47 | def stop_inner_search(self, eval: typing.Union[int, float], examples: Task, clause: Clause) -> bool: 48 | if eval > 0: 49 | return True 50 | else: 51 | return False 52 | 53 | def process_expansions(self, examples: Task, exps: typing.Sequence[Clause], hypothesis_space: TopDownHypothesisSpace) -> typing.Sequence[Clause]: 54 | # eliminate every clause with more body literals than allowed 55 | exps = [cl for cl in exps if len(cl) <= self._max_body_literals] 56 | 57 | # check if every clause has solutions 58 | exps = [(cl, self._solver.has_solution(*cl.get_body().get_literals())) for cl in exps] 59 | new_exps = [] 60 | 61 | for ind in range(len(exps)): 62 | if exps[ind][1]: 63 | # keep it if it has solutions 64 | new_exps.append(exps[ind][0]) 65 | else: 66 | # remove it from the hypothesis space if it does not 67 | hypothesis_space.remove(exps[ind][0]) 68 | 69 | return new_exps --------------------------------------------------------------------------------
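A hedged usage sketch of the breadth-first learner above; the `learn(task, knowledge, hypothesis_space)` call signature is assumed from how `TemplateLearner` subclasses are invoked elsewhere in this repo:

```python
from loreleai.language.lp import c_pred
from loreleai.learning.task import Task, Knowledge
from loreleai.learning.eval_functions import Coverage
from loreleai.learning.hypothesis_space import TopDownHypothesisSpace
from loreleai.learning.language_manipulation import plain_extension
from loreleai.learning.learners import SimpleBreadthFirstLearner
from loreleai.reasoning.lp.prolog import SWIProlog

father = c_pred("father", 2)
mother = c_pred("mother", 2)
grandparent = c_pred("grandparent", 2)

background = Knowledge(father("a", "b"), mother("b", "c"))
task = Task(positive_examples={grandparent("a", "c")},
            negative_examples={grandparent("c", "a")})

# candidates are refined by appending father/mother literals to the body
hs = TopDownHypothesisSpace(
    primitives=[lambda x: plain_extension(x, father),
                lambda x: plain_extension(x, mother)],
    head_constructor=grandparent,
)

learner = SimpleBreadthFirstLearner(SWIProlog(), Coverage(), max_body_literals=3)
program = learner.learn(task, background, hs)
```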
/loreleai/learning/task.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | from typing import Union, List, Sequence, Tuple, Set 3 | 4 | from loreleai.language.lp import Atom, Clause, Procedure, Program 5 | 6 | 7 | class Knowledge: 8 | 9 | def __init__(self, *pieces: Union[Atom, Clause, Procedure, Program]) -> None: 10 | self._knowledge_pieces: List = reduce( 11 | lambda x, y: x + [y] if isinstance(y, (Atom, Clause, Procedure)) else x + y.get_clauses(), 12 | pieces, 13 | []) 14 | 15 | def _add(self, other: Union[Atom, Clause, Procedure, Program]) -> None: 16 | if isinstance(other, (Atom, Clause, Procedure)): 17 | self._knowledge_pieces.append(other) 18 | else: 19 | self._knowledge_pieces += other.get_clauses() 20 | 21 | def add(self, *piece: Union[Atom, Clause, Procedure, Program]) -> None: 22 | for x in piece: self._add(x)  # iterate eagerly; a bare map() is lazy and would never add anything 23 | 24 | def get_all(self): 25 | return self._knowledge_pieces 26 | 27 | def get_clauses(self): 28 | return [x for x in self._knowledge_pieces if isinstance(x, (Clause, Procedure))] 29 | 30 | def get_atoms(self): 31 | return [x for x in self._knowledge_pieces if isinstance(x, Atom)] 32 | 33 | def as_clauses(self): 34 | l = [] 35 | for x in self.get_all(): 36 | if isinstance(x,Clause): 37 | l.append(x) 38 | elif isinstance(x,Atom): 39 | l.append(Clause(x,[])) 40 | elif isinstance(x,Procedure): 41 | for cl in x.get_clauses(): 42 | l.append(cl) 43 | elif isinstance(x,Program): 44 | for cl in x.get_clauses(): 45 | l.append(cl) 46 | else: 47 | raise AssertionError("Knowledge can only contain clauses, atoms, procedures or programs!") 48 | return l 49 | 50 | 51 | class Interpretation: 52 | 53 | def __init__(self, *literals: Atom) -> None: 54 | self._literals: Sequence[Atom] = literals 55 | 56 | def get_literals(self) -> Sequence[Atom]: 57 | return self._literals 58 | 59 | 60 | class Task: 61 | 62 | def __init__(self, positive_examples: Set[Atom] = None, negative_examples: Set[Atom] = None, interpretations: Sequence[Interpretation] = None): 63 | self._examples: Sequence[Interpretation] = interpretations 64 | self._positive_examples: Set[Atom] = positive_examples 65 | self._negative_examples: Set[Atom] = negative_examples 66 | 67 | def get_examples(self) -> Union[Sequence[Interpretation], Tuple[Set[Atom], Set[Atom]]]: 68 | if self._examples is not None: 69 | return self._examples 70 | else: 71 | return self._positive_examples, self._negative_examples 72 | --------------------------------------------------------------------------------
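A short sketch of how the containers above fit together (predicates and facts are illustrative):

```python
from loreleai.language.lp import c_pred
from loreleai.learning.task import Knowledge, Task

parent = c_pred("parent", 2)
ancestor = c_pred("ancestor", 2)

kb = Knowledge(parent("ann", "bob"))
kb.add(parent("bob", "carl"))   # relies on the eager add() above

task = Task(positive_examples={ancestor("ann", "carl")},
            negative_examples={ancestor("carl", "ann")})

pos, neg = task.get_examples()  # the two sets, since no interpretations were given
print(kb.as_clauses())          # facts wrapped as body-less clauses
```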
/loreleai/learning/utilities.py: -------------------------------------------------------------------------------- 1 | from itertools import product, combinations 2 | from collections import Counter 3 | 4 | import networkx as nx 5 | 6 | from loreleai.language.lp import Type, c_type, c_pred, Clause, Body, Atom, Predicate, c_var, Variable, Constant 7 | from loreleai.learning.task import Knowledge 8 | from typing import Sequence, Dict, Tuple 9 | 10 | 11 | class FillerPredicate: 12 | def __init__( 13 | self, 14 | prefix_name: str, 15 | arity: int = None, 16 | max_arity: int = None, 17 | min_arity: int = None, 18 | ): 19 | assert bool(max_arity) == bool(min_arity) and (min_arity is None or min_arity <= max_arity) 20 | assert bool(arity) != bool(max_arity) 21 | self._prefix_name = prefix_name 22 | self._arity = arity 23 | self._max_arity = max_arity 24 | self._min_arity = min_arity 25 | self._instance_counter = 0 26 | 27 | def new( 28 | self, arity: int = None, argument_types: Sequence[Type] = None 29 | ) -> Predicate: 30 | """ 31 | Creates a new predicate from the template 32 | 33 | if the FillerPredicate is constructed with arity set, it returns a new predicate with that arity 34 | if it is constructed with a min/max arity, the arity argument or the number of argument_types determines the arity 35 | """ 36 | assert ( 37 | arity is not None or self._arity is not None or argument_types is not None 38 | ) 39 | self._instance_counter += 1 40 | 41 | if argument_types is not None: 42 | assert ( 43 | (self._arity is not None and len(argument_types) == self._arity) 44 | or (self._min_arity is not None 45 | and self._min_arity <= len(argument_types) <= self._max_arity) 46 | ) 47 | return c_pred( 48 | f"{self._prefix_name}_{self._instance_counter}", 49 | self._arity if self._arity is not None else len(argument_types), 50 | argument_types, 51 | ) 52 | elif arity is not None: 53 | return c_pred(f"{self._prefix_name}_{self._instance_counter}", arity) 54 | else: 55 | return c_pred(f"{self._prefix_name}_{self._instance_counter}", self._arity) 56 | 57 | def _from_body_fixed_arity( 58 | self, 59 | body: Body, 60 | arity: int = None, 61 | arg_types: Sequence[Type] = None, 62 | use_as_head_predicate: Predicate = None, 63 | ) -> Sequence[Atom]: 64 | """ 65 | Creates a head atom given the body of the clause 66 | :param body: 67 | :param arity: (optional) desired arity if specified with min/max when constructing the FillerPredicate 68 | :param arg_types: (optional) argument types to use 69 | :return: 70 | """ 71 | assert arity is not None or arg_types is not None  # at least one is needed to determine the arity 72 | vars = body.get_variables() 73 | 74 | if use_as_head_predicate and arg_types is None: 75 | arg_types = use_as_head_predicate.get_arg_types() 76 | 77 | if arg_types is None: 78 | base = [vars] * arity 79 | else: 80 | matches = {} 81 | 82 | for t_ind in range(len(arg_types)): 83 | matches[t_ind] = [] 84 | for v_ind in range(len(vars)): 85 | if vars[v_ind].get_type() == arg_types[t_ind]: 86 | matches[t_ind].append(vars[v_ind]) 87 | 88 | base = [matches[x] for x in range(len(arg_types))] 89 | 90 | heads = [] 91 | for comb in product(*base): 92 | self._instance_counter += 1 93 | if use_as_head_predicate is not None: 94 | pred = use_as_head_predicate 95 | elif arg_types is None: 96 | pred = c_pred(f"{self._prefix_name}_{self._instance_counter}", arity) 97 | else: 98 | pred = c_pred( 99 | f"{self._prefix_name}_{self._instance_counter}", 100 | len(arg_types), 101 | arg_types, 102 | ) 103 | 104 | heads.append(Atom(pred, list(comb))) 105 | 106 | return heads 107 | 108 | def new_from_body( 109 | self, 110 | body: Body, 111 | arity: int = None, 112 | argument_types: Sequence[Type] = None, 113 | use_as_head_predicate: Predicate = None, 114 | ) -> Sequence[Atom]: 115 | """ 116 | Constructs all possible head atoms given a body of a clause 117 | 118 | If use_as_head_predicate is provided, it uses that. 119 | Otherwise, it checks whether the FillerPredicate is instantiated with a fixed arity, 120 | and finally whether the arity argument is provided 121 | """ 122 | if use_as_head_predicate: 123 | return self._from_body_fixed_arity(body, arity=use_as_head_predicate.get_arity(), arg_types=use_as_head_predicate.get_arg_types(), use_as_head_predicate=use_as_head_predicate) 124 | elif self._arity is not None: 125 | # a specific arity is provided 126 | return self._from_body_fixed_arity(body, self._arity, argument_types) 127 | else: 128 | # a min-max arity range is provided 129 | if arity is not None: 130 | # a specific arity is requested 131 | return self._from_body_fixed_arity(body, arity) 132 | else: 133 | heads = [] 134 | for i in range(self._min_arity, self._max_arity + 1): 135 | heads += self._from_body_fixed_arity(body, i) 136 | 137 | return heads 138 | 139 | def all_possible_atoms(self) -> Sequence[Atom]: 140 | """ 141 | Creates all possible argument configurations for the atom 142 | """ 143 | head_variables = [c_var(chr(x)) for x in range(ord("A"), ord("Z"))][ 144 | : self._arity 145 | ] 146 | if self._arity is not None: 147 | return [ 148 | Atom(self.new(), list(x)) 149 | for x in product(head_variables, repeat=self._arity) 150 | ] 151 | else: 152 | combos = [] 153 | for i in range(self._min_arity, self._max_arity + 1): 154 | combos += [ 155 | Atom(self.new(arity=i), list(x)) 156 | for x in product(head_variables, repeat=i) 157 | ] 158 | return combos 159 | 
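A sketch of how FillerPredicate mints invented predicates (the prefix and predicates below are illustrative):

```python
from loreleai.language.lp import c_pred, Body
from loreleai.learning.utilities import FillerPredicate

filler = FillerPredicate("inv", arity=2)

p1 = filler.new()                                  # inv_1 / 2
p2 = filler.new()                                  # inv_2 / 2
print(filler.is_created_by(p1))                    # True
print(filler.is_created_by(c_pred("parent", 2)))   # False

parent = c_pred("parent", 2)
body = Body(parent("X", "Y"))
heads = filler.new_from_body(body)                 # candidate head atoms over X, Y
print(len(heads))                                  # 4: (X,X), (X,Y), (Y,X), (Y,Y)
```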
160 | def is_created_by(self, predicate: Predicate) -> bool: 161 | """ 162 | checks if a given predicate is created by the FillerPredicate 163 | 164 | it does so by checking if the name of the predicate is in [name]_[number] format and the name is 165 | equal to self._prefix_name and number is <= self._instance_counter 166 | """ 167 | sp = predicate.get_name().rsplit("_", 1)  # rsplit so that prefixes containing '_' still parse 168 | if len(sp) != 2: 169 | return False 170 | else: 171 | if sp[0] == self._prefix_name and int(sp[1]) <= self._instance_counter: 172 | return True 173 | else: 174 | return False 175 | 176 | def _add_to_body_fixed_arity(self, body: Body, arity: int) -> Sequence[Body]: 177 | new_pred_stash = {} # arg_types tuple -> pred 178 | 179 | vars = body.get_variables() 180 | 181 | bodies = [] 182 | 183 | args = list(combinations(vars, arity)) 184 | for ind in range(len(args)): 185 | arg_types = tuple(x.get_type() for x in args[ind])  # must be a tuple: a generator would never match as a dict key 186 | 187 | if arg_types in new_pred_stash: 188 | pred = new_pred_stash[arg_types] 189 | else: 190 | self._instance_counter += 1 191 | pred = c_pred(f"{self._prefix_name}_{self._instance_counter}", arity, arg_types) 192 | new_pred_stash[arg_types] = pred 193 | 194 | bodies.append(body + pred(*args[ind])) 195 | 196 | return bodies 197 | 198 | def add_to_body(self, body: Body) -> Sequence[Body]: 199 | """ 200 | Adds the filler predicate to the body 201 | 202 | It is meant to be used to create a recursive clause 203 | """ 204 | 205 | if self._arity: 206 | return self._add_to_body_fixed_arity(body, self._arity) 207 | else: 208 | bodies = [] 209 | 210 | for ar in range(self._min_arity, self._max_arity + 1): 211 | bodies += self._add_to_body_fixed_arity(body, ar) 212 | 213 | return bodies 214 | 215 | 216 | def are_variables_connected(atoms: Sequence[Atom]): 217 | """ 218 | Checks whether the Variables in the clause are connected 219 | 220 | Args: 221 | atoms (Sequence[Atom]): atoms whose variables have to be checked 222 | 223 | """ 224 | g = nx.Graph() 225 | 226 | for atm in atoms: 227 | vrs = atm.get_variables() 228 | if len(vrs) == 1: 229 | g.add_node(vrs[0]) 230 | else: 231 | for cmb in combinations(vrs, 2): 232 | g.add_edge(cmb[0], cmb[1]) 233 | 234 | res = nx.is_connected(g) 235 | del g 236 | 237 | return res 238 | 239 | 240 | def compute_bottom_clause(theory: Sequence[Clause], c: Clause) -> Clause: 241 | """ 242 | Computes the bottom clause given a theory and a clause. 243 | Algorithm from (De Raedt, 2008) 244 | """ 245 | # 1. Find a skolemization substitution θ for c (w.r.t. B and c) 246 | _, theta = skolemize(c) 247 | 248 | # 2. Compute the least Herbrand model M of the theory plus the skolemized body literals body(c)θ added as facts 249 | body_facts = [ 250 | Clause(l.substitute(theta), []) for l in c.get_body().get_literals() 251 | ] 252 | m = herbrand_model(theory + body_facts) 253 | 254 | # 3. Deskolemize the clause head(cθ) <= M and return the result. 255 | theta_inv = {value: key for key, value in theta.items()} 256 | return Clause(c.get_head(), [l.get_head().substitute(theta_inv) for l in m]) 257 | 
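To make the three numbered steps concrete, a hedged walk-through on a toy theory (Clause construction follows the patterns used in this repo's tests):

```python
from loreleai.language.lp import c_pred, c_var, Clause
from loreleai.learning.utilities import compute_bottom_clause

parent = c_pred("parent", 2)
grandparent = c_pred("grandparent", 2)
X, Y = c_var("X"), c_var("Y")

# background theory given as body-less clauses (facts)
theory = [Clause(parent("ann", "bob"), []), Clause(parent("bob", "carl"), [])]

# the clause whose body we want to saturate
c = grandparent(X, Y) <= parent(X, Y)

# 1. X,Y are skolemized to constants sk0,sk1
# 2. the least Herbrand model of theory + {parent(sk0,sk1)} is computed
# 3. the model is deskolemized and becomes the bottom clause's body
print(compute_bottom_clause(theory, c))
```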
258 | def skolemize(clause: Clause) -> Tuple[Clause, Dict[Variable, Constant]]: 259 | # Find all variables in clause 260 | vars = clause.get_variables() 261 | 262 | # Map from X,Y,Z,... -> sk0,sk1,sk2,... 263 | subst = {vars[i]: Constant(f"sk{i}", c_type("thing")) for i in range(len(vars))} 264 | 265 | # Apply this substitution to create new clause without quantifiers 266 | return clause.substitute(subst), subst 267 | 268 | 269 | def herbrand_model(clauses: Sequence[Clause]) -> Sequence[Clause]: 270 | """ 271 | Computes a minimal Herbrand model of a theory 'clauses'. 272 | Algorithm from Logical and Relational Learning (De Raedt, 2008) 273 | """ 274 | i = 1 275 | m = {0: []} 276 | # Find a fact in the theory (i.e. no body literals) 277 | facts = list(filter(lambda c: len(c.get_body().get_literals()) == 0, clauses)) 278 | if len(facts) == 0: 279 | raise AssertionError( 280 | "Theory does not contain ground facts, which are necessary to compute a minimal Herbrand model!" 281 | ) 282 | # print("Finished iteration 0") 283 | 284 | # If all clauses are just facts, there is nothing to be done. 285 | if len(facts) == len(clauses): 286 | return clauses 287 | 288 | #BUG: doesn't work properly after pylo update... 289 | 290 | m[1] = list(facts) 291 | while Counter(m[i]) != Counter(m[i - 1]): 292 | model_constants = _flatten( 293 | [fact.get_head().get_arguments() for fact in m[i]] 294 | ) 295 | 296 | m[i + 1] = [] 297 | rules = list( 298 | filter(lambda c: len(c.get_body().get_literals()) > 0, clauses) 299 | ) 300 | 301 | for rule in rules: 302 | # if there is a substitution theta such that 303 | # all literals in rule._body are true in the previous model 304 | body = rule.get_body() 305 | body_vars = body.get_variables() 306 | # Build all substitutions body_vars -> model_constants 307 | substitutions = _all_maps(body_vars, model_constants) 308 | 309 | for theta in substitutions: 310 | # add_fact is True unless some body literal does not 311 | # occur in m[i] 312 | add_fact = True 313 | for body_lit in body.get_literals(): 314 | candidate = body_lit.substitute(theta) 315 | facts = list(map(lambda x: x.get_head(), m[i])) 316 | # print("Does {} occur in {}?".format(candidate,facts)) 317 | if candidate in facts: 318 | pass 319 | # print("Yes") 320 | else: 321 | add_fact = False 322 | 323 | new_fact = Clause(rule.get_head().substitute(theta), []) 324 | 325 | if add_fact and not new_fact in m[i + 1] and not new_fact in m[i]: 326 | m[i + 1].append(new_fact) 327 | # print("Added fact {} to m[{}]".format(str(new_fact),i+1)) 328 | # print(m[i+1]) 329 | 330 | # print(f"Finished iteration {i}") 331 | m[i + 1] = list(set(m[i + 1] + m[i])) 332 | # print("New model: "+str(m[i+1])) 333 | i += 1 334 | return m[i] 335 | 
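A hedged usage sketch of the fixpoint computation above (assumes Clause objects hash and compare by value, which Counter and set() rely on):

```python
from loreleai.language.lp import c_pred, c_var, Clause
from loreleai.learning.utilities import herbrand_model

parent = c_pred("parent", 2)
ancestor = c_pred("ancestor", 2)
X, Y = c_var("X"), c_var("Y")

clauses = [
    Clause(parent("ann", "bob"), []),    # ground facts
    Clause(parent("bob", "carl"), []),
    ancestor(X, Y) <= parent(X, Y),      # one rule
]

# m[i + 1] = m[i] plus the heads of rules whose grounded bodies hold in m[i];
# iteration stops when the model no longer grows
model = herbrand_model(clauses)
print([cl.get_head() for cl in model])   # parent/2 facts plus derived ancestor/2 facts
```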
336 | def find_allowed_positions(knowledge: Knowledge): 337 | """ 338 | Returns a dict x such that x[constant][predicate] contains 339 | all positions i such that `predicate` can have `constant` as 340 | argument at position i in the background knowledge. 341 | This is used to restrict the number of clauses generated through variable 342 | instantiation. 343 | If an atom is not implied by the background theory (i.e. is not in 344 | the Herbrand Model), there is no point in considering it, because 345 | it will never be true. 346 | """ 347 | facts = herbrand_model(list(knowledge.as_clauses())) 348 | predicates = set() 349 | 350 | # Build dict that will restrict where constants are allowed to appear 351 | # e.g. allowed_positions[homer][father] = [0,1] 352 | allowed_positions = dict() 353 | for atom in facts: 354 | args = atom.get_head().get_arguments() 355 | pred = atom.get_head().get_predicate() 356 | predicates = predicates.union({pred}) 357 | for i in range(len(args)): 358 | arg = args[i] 359 | if isinstance(arg,Constant): 360 | # New constant, initialize allowed_positions[constant] 361 | if not arg in allowed_positions.keys(): 362 | allowed_positions[arg] = dict() 363 | # Known constant, but not for this predicate 364 | if not pred in allowed_positions[arg]: 365 | allowed_positions[arg][pred] = [i] 366 | # Known constant, and predicate already seen for this constant 367 | else: 368 | if i not in allowed_positions[arg][pred]: 369 | allowed_positions[arg][pred] = allowed_positions[arg][pred]+[i] 370 | 371 | # Complete dict with empty lists for constant/predicate combinations 372 | # that were not observed in the background data 373 | for const in allowed_positions.keys(): 374 | for pred in list(predicates): 375 | if pred not in allowed_positions[const].keys(): 376 | allowed_positions[const][pred] = [] 377 | 378 | return allowed_positions 379 | 380 | def find_allowed_reflexivity(knowledge: Knowledge): 381 | """ 382 | Returns the set of predicates in `knowledge` that allow all of their 383 | arguments to be equal. That is, if `knowledge` contains a fact pred(x,x,x), 384 | pred will be in the return value. 385 | """ 386 | facts = herbrand_model(list(knowledge.as_clauses())) 387 | allow_reflexivity = set() 388 | for atom in facts: 389 | args = atom.get_head().get_arguments() 390 | pred = atom.get_head().get_predicate() 391 | if len(args) > 0: 392 | # If all arguments are equal 393 | if all(args[i] == args[0] for i in range(len(args))): 394 | allow_reflexivity.add(pred) 395 | 396 | return allow_reflexivity 397 | 398 | def find_frequent_constants(knowledge: Knowledge,min_frequency=0): 399 | """ 400 | Returns a list of all constants that occur at least `min_frequency` times in 401 | `knowledge` 402 | """ 403 | facts = herbrand_model(list(knowledge.as_clauses())) 404 | d = {} 405 | 406 | # Count occurrences of constants 407 | for atom in facts: 408 | args = atom.get_head().get_arguments() 409 | for arg in args: 410 | if isinstance(arg, Constant): 411 | if arg not in d.keys(): 412 | d[arg] = 1  # the first occurrence counts too 413 | else: 414 | d[arg] = d[arg] + 1 415 | 416 | return [const for const in d.keys() if d[const] >= min_frequency] 417 | 418 | def _flatten(l) -> Sequence: 419 | """ 420 | [[1],[2],[3]] -> [1,2,3] 421 | """ 422 | return [item for sublist in l for item in sublist] 423 | 424 | 425 | def _all_maps(l1, l2) -> Sequence[Dict[Variable, Constant]]: 426 | """ 427 | Return all maps between l1 and l2 428 | such that all elements of l1 have an entry in the map 429 | """ 430 | sols = [] 431 | for c in product(l2, repeat=len(l1)):  # product, not combinations_with_replacement: maps are order-sensitive 432 | sols.append({l1[i]: c[i] for i in range(len(l1))}) 433 | return sols -------------------------------------------------------------------------------- /loreleai/reasoning/__init__.py: -------------------------------------------------------------------------------- 1 | from pylo.engines.lpsolver import LPSolver 2 | 3 | __all__ = ['LPSolver'] 4 | -------------------------------------------------------------------------------- /loreleai/reasoning/lp/__init__.py: -------------------------------------------------------------------------------- 1 | # from loreleai.reasoning.lp.kanren.relationalsolver import RelationalSolver 2 | # from loreleai.reasoning.lp.prolog.Prolog import Prolog 3 | 
# from .datalog.datalogsolver import DatalogSolver 4 | # from .lpsolver import LPSolver 5 | from pylo.engines.datalog.datalogsolver import DatalogSolver 6 | from pylo.engines.kanren.relationalsolver import RelationalSolver 7 | from pylo.engines.lpsolver import LPSolver 8 | from pylo.engines.prolog.prologsolver import Prolog 9 | 10 | __all__ = [ 11 | 'DatalogSolver', 12 | 'LPSolver', 13 | 'RelationalSolver', 14 | # 'Prolog' 15 | ] -------------------------------------------------------------------------------- /loreleai/reasoning/lp/datalog/__init__.py: -------------------------------------------------------------------------------- 1 | # from loreleai.reasoning.lp.datalog.muz import MuZ 2 | # from .datalogsolver import DatalogSolver 3 | 4 | from pylo.engines.datalog import MuZ 5 | from pylo.engines.datalog.datalogsolver import DatalogSolver 6 | 7 | __all__ = [ 8 | 'MuZ', 9 | 'DatalogSolver' 10 | ] -------------------------------------------------------------------------------- /loreleai/reasoning/lp/kanren/__init__.py: -------------------------------------------------------------------------------- 1 | # from .minikanren import MiniKanren 2 | # #from .relationalsolver import RelationalSolver 3 | 4 | from pylo.engines.kanren import MiniKanren 5 | from pylo.engines.kanren.relationalsolver import RelationalSolver 6 | 7 | __all__ = [ 8 | 'MiniKanren', 9 | 'RelationalSolver' 10 | ] -------------------------------------------------------------------------------- /loreleai/reasoning/lp/prolog/__init__.py: -------------------------------------------------------------------------------- 1 | # from .GNUProlog import GNUProlog 2 | # from .SWIProlog import SWIProlog 3 | # from .XSBProlog import XSBProlog 4 | # #from .Prolog import Prolog 5 | 6 | engines = [] 7 | 8 | try: 9 | from pylo.engines.prolog import GNUProlog 10 | engines += ['GNUProlog'] 11 | except Exception: 12 | pass 13 | 14 | try: 15 | from pylo.engines.prolog import SWIProlog 16 | engines += ['SWIProlog'] 17 | except Exception: 18 | pass 19 | 20 | try: 21 | from pylo.engines.prolog import XSBProlog 22 | engines += ['XSBProlog'] 23 | except Exception: 24 | pass 25 | 26 | from pylo.engines.prolog import Prolog 27 | 28 | engines += ['Prolog'] 29 | __all__ = engines 30 | 31 | # __all__ = [ 32 | # "SWIProlog", 33 | # "XSBProlog", 34 | # "GNUProlog", 35 | # "Prolog" 36 | # ] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | networkx 3 | black 4 | matplotlib 5 | pygraphviz 6 | z3-solver 7 | miniKanren 8 | orderedset -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setup(name='loreleai', 7 | version='0.1.4', 8 | description='A library for program induction/synthesis and StarAI', 9 | long_description=long_description, 10 | long_description_content_type='text/markdown', 11 | url='https://github.com/sebdumancic/loreleai', 12 | author='Sebastijan Dumancic', 13 | author_email='sebastijan.dumancic@gmail.com', 14 | license='MIT', 15 | packages=find_packages(), 16 | # package_dir={'': 'loreleai'}, 17 | install_requires=[ 18 | 'pytest', 19 | 'networkx', 20 | 'matplotlib', 21 | 'z3-solver', 22 | 'miniKanren', 23 | 'orderedset' 24 | ], 25 | 
python_requires=">=3.6" 26 | ) 27 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sebdumancic/loreleai/3863c683b6acfdaa9c1e23a526110492e610fe74/tests/__init__.py -------------------------------------------------------------------------------- /tests/aleph_tests.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | from loreleai.learning.learners import Aleph 4 | from loreleai.learning.task import Task,Knowledge 5 | from loreleai.language.lp import c_type,Clause, Term, Variable, Constant, c_pred,c_const,c_var, Body 6 | from loreleai.learning import HypothesisSpace 7 | from loreleai.reasoning.lp.prolog import SWIProlog 8 | from loreleai.reasoning.lp.datalog import MuZ 9 | from loreleai.learning.eval_functions import Coverage, Compression, Accuracy 10 | from loreleai.learning.language_manipulation import plain_extension, variable_instantiation 11 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace 12 | from loreleai.learning.language_filtering import has_duplicated_literal, has_singleton_vars 13 | 14 | 15 | def learn_with_constants(): 16 | """ 17 | Consider a row of blocks [ block1 block2 block3 block4 block5 block6 block7 block8 ] 18 | The order of this row is expressed using follows(X,Y) 19 | The color of a block is expressed using color(X,Color) 20 | 21 | Goal: learn a function f that says: a block is positive when it is followed by a red block 22 | f(X) :- follows(X,Y), color(Y,red) 23 | """ 24 | block = c_type("block") 25 | col = c_type("col") 26 | 27 | block1 = c_const("block1",domain=block) # blue -> positive 28 | block2 = c_const("block2",domain=block) # red 29 | block3 = c_const("block3",domain=block) # green -> positive 30 | block4 = c_const("block4",domain=block) # red -> positive 31 | block5 = c_const("block5",domain=block) # red 32 | block6 = c_const("block6",domain=block) # green 33 | block7 = c_const("block7",domain=block) # blue 34 | block8 = c_const("block8",domain=block) # blue 35 | 36 | red = c_const("red",domain=col) 37 | green = c_const("green",domain=col) 38 | blue = c_const("blue",domain=col) 39 | 40 | follows = c_pred("follows",2,domains=[block,block]) 41 | color = c_pred("color",2,domains=[block,col]) 42 | 43 | # Predicate to learn: 44 | f = c_pred("f",1,domains=[block]) 45 | 46 | bk = Knowledge( 47 | follows(block1,block2), follows(block2,block3), follows(block3,block4), 48 | follows(block4,block5), follows(block5,block6), follows(block6,block7), 49 | follows(block7,block8), color(block1,blue), color(block2, red), 50 | color(block3,green), color(block4,red), color(block5,red), 51 | color(block6,green), color(block7,blue), color(block8,blue) 52 | ) 53 | 54 | pos = {f(x) for x in [block1,block3,block4]} 55 | neg = {f(x) for x in [block2,block5,block6,block7,block8]} 56 | 57 | task = Task(positive_examples=pos, negative_examples=neg) 58 | solver = SWIProlog() 59 | 60 | # EvalFn must return an upper bound on quality to prune search space. 61 | eval_fn1 = Coverage(return_upperbound=True) 62 | eval_fn2 = Compression(return_upperbound=True) 63 | eval_fn3 = Accuracy(return_upperbound=True) 64 | 65 | learners = [Aleph(solver,eval_fn,max_body_literals=4,do_print=False) 66 | for eval_fn in [eval_fn1,eval_fn3]] 67 | 68 | for learner in learners: 69 | res = learner.learn(task,bk,None,minimum_freq=1) 70 | print(res) 71 | 72 | 
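The comment above states the contract between eval functions and Aleph: with return_upperbound=True, evaluate yields a (value, upper_bound) pair, where the bound is the best score any refinement of the clause could still reach. A hedged numeric illustration for Coverage (value P − N; the bound below assumes a refinement can at best keep all P covered positives and drop every negative):

```python
# hypothetical coverage counts for one clause
P, N = 4, 3

value = P - N        # score of the clause itself: 1
upper_bound = P      # optimistic bound for any refinement: 4

# Aleph discards the clause without expanding it when even the bound
# cannot beat the current best (see prune_pool / _learn_one_clause above)
current_best = 5
if upper_bound <= current_best:
    print("pruned without expansion")
```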
73 | def learn_text(): 74 | """ 75 | We describe a piece of text spanning multiple lines: 76 | "node A red node B green node C blue " 77 | using the next/2, linestart/2, lineend/2, tokenlength/2 predicates 78 | """ 79 | token = c_type("token") 80 | num = c_type("num") 81 | 82 | next = c_pred("next",2,("token","token")) 83 | linestart = c_pred("linestart",2,("token","token")) 84 | lineend = c_pred("lineend",2,("token","token")) 85 | tokenlength = c_pred("tokenlength",2,("token","num")) 86 | 87 | n1 = c_const("n1",num) 88 | n3 = c_const("n3",num) 89 | n4 = c_const("n4",num) 90 | n5 = c_const("n5",num) 91 | node1 = c_const("node1",token) 92 | node2 = c_const("node2",token) 93 | node3 = c_const("node3",token) 94 | red = c_const("red",token) 95 | green = c_const("green",token) 96 | blue = c_const("blue",token) 97 | a_c = c_const("a_c",token) 98 | b_c = c_const("b_c",token) 99 | c_c = c_const("c_c",token) 100 | start = c_const("c_START",token) 101 | end = c_const("c_END",token) 102 | 103 | bk = Knowledge( 104 | next(start,node1),next(node1,a_c),next(a_c,red), 105 | next(red,node2),next(node2,green),next(green,b_c), 106 | next(b_c,node3),next(node3,c_c),next(c_c,blue), 107 | next(blue,end),tokenlength(node1,n4),tokenlength(node2,n4), 108 | tokenlength(node3,n4),tokenlength(a_c,n1),tokenlength(b_c,n1), 109 | tokenlength(c_c,n1),tokenlength(red,n3),tokenlength(green,n5), 110 | tokenlength(blue,n4),linestart(node1,node1),linestart(a_c,node1), 111 | linestart(red,node1),linestart(node2,node2),linestart(b_c,node2), 112 | linestart(green,node2),linestart(node3,node3),linestart(c_c,node3), 113 | linestart(blue,node3),lineend(node1,a_c),lineend(a_c,red), 114 | lineend(node2,red),lineend(b_c,green),lineend(node3,blue), 115 | lineend(c_c,blue),lineend(red,red),lineend(green,green), 116 | lineend(blue,blue)) 117 | 118 | solver = SWIProlog() 119 | eval_fn1 = Coverage(return_upperbound=True) 120 | learner = Aleph(solver,eval_fn1,max_body_literals=3,do_print=False) 121 | 122 | # 1. Consider the hypothesis: f1(word) :- word is the second word on a line 123 | if True: 124 | f1 = c_pred("f1",1,[token]) 125 | neg = {f1(x) for x in [node1,node2,node3,blue,green,red]} 126 | pos = {f1(x) for x in [a_c,b_c,c_c]} 127 | task = Task(positive_examples=pos, negative_examples=neg) 128 | 129 | res = learner.learn(task,bk,None) 130 | print(res) 131 | 132 | # 2. Consider the hypothesis: f2(word) :- word is the first word on a line 133 | if True: 134 | f2 = c_pred("f2",1,[token]) 135 | neg = {f2(x) for x in [a_c,b_c,c_c,blue,green,red]} 136 | pos = {f2(x) for x in [node1,node2,node3]} 137 | task2 = Task(positive_examples=pos, negative_examples=neg) 138 | 139 | res = learner.learn(task2,bk,None) 140 | print(res) 141 | 142 | # 3. Assume we have learned the predicate node(X) before (A, B and C are nodes).
143 | # We want to learn nodecolor(Node,Color): Color is the color token associated with Node 144 | if True: 145 | node = c_pred("node",1,[token]) 146 | color = c_pred("color",1,[token]) 147 | nodecolor = c_pred("nodecolor",2,[token,token]) 148 | a = c_var("A",token) 149 | b = c_var("B",token) 150 | bk_old = bk.get_all() 151 | bk = Knowledge(*bk_old, node(a_c),node(b_c),node(c_c), 152 | node(a_c), node(b_c),node(c_c), 153 | color(red),color(green),color(blue)) 154 | pos = {nodecolor(a_c,red),nodecolor(b_c,green),nodecolor(c_c,blue)} 155 | neg = set() 156 | neg = {nodecolor(node1,red),nodecolor(node2,red),nodecolor(node3,red), 157 | nodecolor(node1,blue),nodecolor(node2,blue),nodecolor(node3,blue), 158 | nodecolor(node1,green),nodecolor(node2,green),nodecolor(node3,green), 159 | nodecolor(a_c,green),nodecolor(a_c,blue),nodecolor(b_c,blue), 160 | nodecolor(b_c,red),nodecolor(c_c,red),nodecolor(c_c,green) 161 | } 162 | task3 = Task(positive_examples=pos, negative_examples=neg) 163 | 164 | # prog = learner.learn(task3,bk,None,initial_clause=Body(node(a),color(b))) 165 | result = learner.learn(task3,bk,None,initial_clause=Body(node(a),color(b)),minimum_freq=3) 166 | print(result) 167 | 168 | 169 | 170 | def learn_simpsons(): 171 | # define the predicates 172 | father = c_pred("father", 2) 173 | mother = c_pred("mother", 2) 174 | grandparent = c_pred("grandparent", 2) 175 | 176 | # specify the background knowledge 177 | background = Knowledge( 178 | father("homer", "bart"), father("homer", "lisa"), father("homer", "maggie"), 179 | mother("marge", "bart"), mother("marge", "lisa"),mother("marge","maggie"), 180 | mother("mona","homer"),father("abe","homer"), 181 | mother("jacqueline","marge"),father("clancy","marge") 182 | ) 183 | 184 | # positive examples 185 | pos = { 186 | grandparent("abe", "bart"), 187 | grandparent("abe", "lisa"), 188 | grandparent("abe", "maggie"), 189 | grandparent("mona", "bart"), 190 | grandparent("mona", "lisa"), 191 | grandparent("mona", "maggie"), 192 | grandparent("jacqueline", "bart"), 193 | grandparent("jacqueline", "lisa"), 194 | grandparent("jacqueline", "maggie"), 195 | grandparent("clancy", "bart"), 196 | grandparent("clancy", "lisa"), 197 | grandparent("clancy", "maggie"), 198 | } 199 | 200 | # negative examples 201 | neg = { 202 | grandparent("abe", "marge"), grandparent("abe", "homer"), grandparent("abe", "clancy"),grandparent("abe","jacqueline"), 203 | grandparent("homer","marge"), grandparent("homer","jacqueline"),grandparent("jacqueline","marge"), 204 | grandparent("clancy","homer"),grandparent("clancy","abe") 205 | } 206 | 207 | task = Task(positive_examples=pos, negative_examples=neg) 208 | solver = SWIProlog() 209 | 210 | # EvalFn must return an upper bound on quality to prune search space. 
211 | eval_fn = Coverage(return_upperbound=True) 212 | eval_fn2 = Compression(return_upperbound=True) 213 | eval_fn3 = Compression(return_upperbound=True) 214 | 215 | learner = Aleph(solver,eval_fn,max_body_literals=4,do_print=False) 216 | learner2 = Aleph(solver,eval_fn2,max_body_literals=4,do_print=False) 217 | learner3 = Aleph(solver,eval_fn3,max_body_literals=4,do_print=False) 218 | 219 | result = learner.learn(task,background,None) 220 | print(result) 221 | 222 | # pr = learner2.learn(task,background,None) 223 | # print("Final program: {}".format(str(pr))) 224 | 225 | # pr = learner3.learn(task,background,None) 226 | # print("Final program: {}".format(str(pr))) 227 | 228 | 229 | if __name__ == "__main__": 230 | learn_simpsons() 231 | learn_with_constants() 232 | learn_text() 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | -------------------------------------------------------------------------------- /tests/common_tasks.py: -------------------------------------------------------------------------------- 1 | from loreleai.language.commons import c_pred, c_const,c_type,c_var, Atom 2 | from loreleai.reasoning.lp.prolog import SWIProlog 3 | 4 | """ 5 | The description of these tasks is due to 6 | (Quinlan, 1990): Learning Logical Definitions from Relations 7 | """ 8 | 9 | 10 | def family_relationships(): 11 | """ 12 | Family tree: https://i.imgur.com/43607jf.png (Quinlan, 1990) 13 | """ 14 | 15 | wife = c_pred("wife",2) 16 | husband = c_pred("husband",2) 17 | mother = c_pred("mother",2) 18 | father = c_pred("father",2) 19 | daughter = c_pred("daughter",2) 20 | son = c_pred("son",2) 21 | sister = c_pred("sister",2) 22 | brother = c_pred("brother",2) 23 | aunt = c_pred("aunt",2) 24 | uncle = c_pred("uncle",2) 25 | niece = c_pred("niece",2) 26 | nephew = c_pred("nephew",2) 27 | 28 | christopher = c_const("christopher") 29 | penelope = c_const("penelope") 30 | margaret = c_const("margaret") 31 | arthur = c_const("arthur") 32 | victoria = c_const("victoria") 33 | james = c_const("james") 34 | colin = c_const("colin") 35 | charlotte = c_const("charlotte") 36 | andrew = c_const("andrew") 37 | christine = c_const("christine") 38 | jennifer = c_const("jennifer") 39 | charles = c_const("charles") 40 | roberto = c_const("roberto") 41 | maria = c_const("maria") 42 | gina = c_const("gina") 43 | emilio = c_const("emilio") 44 | lucia = c_const("lucia") 45 | marco = c_const("marco") 46 | alfonso = c_const("alfonso") 47 | sophia = c_const("sophia") 48 | pierro = c_const("pierro") 49 | francesca = c_const("francesca") 50 | angela = c_const("angela") 51 | tomaso = c_const("tomaso") 52 | 53 | bk = [ 54 | wife(christopher,penelope), 55 | wife(andrew,christine), 56 | wife(charles,jennifer), 57 | wife(james,victoria), 58 | wife(arthur,margaret), 59 | wife(roberto,maria), 60 | wife(pierro,francesca), 61 | wife(emilio,gina), 62 | wife(marco,lucia), 63 | wife(tomaso,angela), 64 | 65 | husband(penelope,christopher), 66 | husband(christine,andrew), 67 | husband(jennifer,charles), 68 | husband(victoria,james), 69 | husband(margaret,arthur), 70 | husband(maria,roberto), 71 | husband(francesca,pierro), 72 | husband(gina,emilio), 73 | husband(lucia,marco), 74 | husband(angela,tomaso), 75 | 76 | mother(penelope,arthur), 77 | mother(penelope,victoria), 78 | mother(victoria,colin), 79 | mother(victoria,charlotte), 80 | mother(christine,jennifer), 81 | mother(christine,james), 82 | mother(maria,emilio), 83 | mother(maria,lucia), 84 | mother(lucia,alfonso), 85 | mother(lucia,sophia), 86 | 
mother(francesca,angela), 87 | mother(francesca,marco), 88 | 89 | father(christopher,arthur), 90 | father(christopher,victoria), 91 | father(james,colin), 92 | father(james,charlotte), 93 | father(andrew,james), 94 | father(andrew,jennifer), 95 | father(roberto,emilio), 96 | father(roberto,lucia), 97 | father(marco,alfonso), 98 | father(marco,sophia), 99 | father(pierro,marco), 100 | father(pierro,angela), 101 | 102 | daughter(victoria,christopher), 103 | daughter(victoria,penelope), 104 | daughter(charlotte,james), 105 | daughter(charlotte,victoria), 106 | daughter(jennifer,christine), 107 | daughter(jennifer,andrew), 108 | daughter(lucia,maria), 109 | daughter(lucia,roberto), 110 | daughter(sophia,marco), 111 | daughter(sophia,lucia), 112 | daughter(angela,francesca), 113 | daughter(angela,pierro), 114 | 115 | son(arthur,christopher), 116 | son(arthur,penelope), 117 | son(james,andrew), 118 | son(james,christine), 119 | son(colin,victoria), 120 | son(colin,charlotte), 121 | son(emilio,roberto), 122 | son(emilio,maria), 123 | son(marco,pierro), 124 | son(marco,francesca), 125 | son(alfonso,lucia), 126 | son(alfonso,marco), 127 | 128 | sister(charlotte,colin), 129 | sister(victoria,arthur), 130 | sister(jennifer,james), 131 | sister(lucia,emilio), 132 | sister(sophia,alfonso), 133 | sister(angela,marco), 134 | 135 | brother(colin,charlotte), 136 | brother(arthur,victoria), 137 | brother(james,jennifer), 138 | brother(emilio,lucia), 139 | brother(alfonso,sophia), 140 | brother(marco,angela), 141 | 142 | aunt(margaret,colin), 143 | aunt(margaret,charlotte), 144 | aunt(jennifer,colin), 145 | aunt(jennifer,charlotte), 146 | aunt(gina,alfonso), 147 | aunt(gina,sophia), 148 | aunt(angela,sophia), 149 | aunt(angela,alfonso), 150 | 151 | uncle(arthur,colin), 152 | uncle(arthur,charlotte), 153 | uncle(charles,colin), 154 | uncle(charles,charlotte), 155 | uncle(emilio,alfonso), 156 | uncle(emilio,sophia), 157 | uncle(tomaso,alfonso), 158 | uncle(tomaso,sophia), 159 | 160 | niece(charlotte,arthur), 161 | niece(charlotte,charles), 162 | niece(sophia,emilio), 163 | niece(sophia,tomaso), 164 | niece(charlotte,margaret), 165 | niece(charlotte,jennifer), 166 | niece(sophia,angela), 167 | niece(sophia,gina), 168 | 169 | nephew(colin,margaret), 170 | nephew(colin,jennifer), 171 | nephew(alfonso,gina), 172 | nephew(alfonso,angela), 173 | nephew(colin,arthur), 174 | nephew(colin,charles), 175 | nephew(alfonso,emilio), 176 | nephew(alfonso,tomaso), 177 | ] 178 | 179 | return bk 180 | 181 | 182 | 183 | 184 | def michalski_trains(): 185 | """ 186 | Description of the classic eastbound/westbound trains problem (Michalski 1980) 187 | """ 188 | car = c_type("car") 189 | train = c_type("train") 190 | shape = c_type("shape") 191 | num = c_type("num") 192 | 193 | car_shape = c_pred("car_shape",2,domains=(car,shape)) 194 | short = c_pred("short",1,domains=(car)) 195 | closed = c_pred("closed",1,domains=(car)) 196 | long = c_pred("long",1,domains=(car)) 197 | open_car = c_pred("open_car",1,domains=(car)) 198 | load = c_pred("load",3,domains=(car,shape,num)) 199 | wheels = c_pred("wheels",2,domains=(car,shape)) 200 | has_car = c_pred("has_car",2,domains=(train,car)) 201 | double = c_pred("double",1,domains=(car)) 202 | jagged = c_pred("jagged",1,domains=(car)) 203 | 204 | car_11 = c_const("car_11",domain=car) 205 | car_12 = c_const("car_12",domain=car) 206 | car_13 = c_const("car_13",domain=car) 207 | car_14 = c_const("car_14",domain=car) 208 | car_21 = c_const("car_21",domain=car) 209 | car_22 = 
c_const("car_22",domain=car) 210 | car_23 = c_const("car_23",domain=car) 211 | car_31 = c_const("car_31",domain=car) 212 | car_32 = c_const("car_32",domain=car) 213 | car_33 = c_const("car_33",domain=car) 214 | car_41 = c_const("car_41",domain=car) 215 | car_42 = c_const("car_42",domain=car) 216 | car_43 = c_const("car_43",domain=car) 217 | car_44 = c_const("car_44",domain=car) 218 | car_51 = c_const("car_51",domain=car) 219 | car_52 = c_const("car_52",domain=car) 220 | car_53 = c_const("car_53",domain=car) 221 | car_61 = c_const("car_61",domain=car) 222 | car_62 = c_const("car_62",domain=car) 223 | car_71 = c_const("car_71",domain=car) 224 | car_72 = c_const("car_72",domain=car) 225 | car_73 = c_const("car_73",domain=car) 226 | car_81 = c_const("car_81",domain=car) 227 | car_82 = c_const("car_82",domain=car) 228 | car_91 = c_const("car_91",domain=car) 229 | car_92 = c_const("car_92",domain=car) 230 | car_93 = c_const("car_93",domain=car) 231 | car_94 = c_const("car_94",domain=car) 232 | car_101 = c_const("car_101",domain=car) 233 | car_102 = c_const("car_102",domain=car) 234 | 235 | east1 = c_const("east1",domain=train) 236 | east2 = c_const("east2",domain=train) 237 | east3 = c_const("east3",domain=train) 238 | east4 = c_const("east4",domain=train) 239 | east5 = c_const("east5",domain=train) 240 | west6 = c_const("west6",domain=train) 241 | west7 = c_const("west7",domain=train) 242 | west8 = c_const("west8",domain=train) 243 | west9 = c_const("west9",domain=train) 244 | west10 = c_const("west10",domain=train) 245 | 246 | elipse = c_const("elipse",domain=shape) 247 | hexagon = c_const("hexagon",domain=shape) 248 | rectangle = c_const("rectangle",domain=shape) 249 | u_shaped = c_const("u_shaped",domain=shape) 250 | triangle = c_const("triangle",domain=shape) 251 | circle = c_const("circle",domain=shape) 252 | nil = c_const("nil",domain=shape) 253 | 254 | n0 = c_const("n0",domain=num) 255 | n1 = c_const("n1",domain=num) 256 | n2 = c_const("n2",domain=num) 257 | n3 = c_const("n3",domain=num) 258 | 259 | #eastbound train 1 260 | bk = [ 261 | short(car_12), 262 | closed(car_12), 263 | long(car_11), 264 | long(car_13), 265 | short(car_14), 266 | open_car(car_11), 267 | open_car(car_13), 268 | open_car(car_14), 269 | car_shape(car_11,rectangle), 270 | car_shape(car_12,rectangle), 271 | car_shape(car_13,rectangle), 272 | car_shape(car_14,rectangle), 273 | load(car_11,rectangle,n3), 274 | load(car_12,triangle,n1), 275 | load(car_13,hexagon,n1), 276 | load(car_14,circle,n1), 277 | wheels(car_11,n2), 278 | wheels(car_12,n2), 279 | wheels(car_13,n3), 280 | wheels(car_14,n2), 281 | has_car(east1,car_11), 282 | has_car(east1,car_12), 283 | has_car(east1,car_13), 284 | has_car(east1,car_14), 285 | 286 | #eastbound train 287 | has_car(east2,car_21), 288 | has_car(east2,car_22), 289 | has_car(east2,car_23), 290 | short(car_21), 291 | short(car_22), 292 | short(car_23), 293 | car_shape(car_21,u_shaped), 294 | car_shape(car_22,u_shaped), 295 | car_shape(car_23,rectangle), 296 | open_car(car_21), 297 | open_car(car_22), 298 | closed(car_23), 299 | load(car_21,triangle,n1), 300 | load(car_22,rectangle,n1), 301 | load(car_23,circle,n2), 302 | wheels(car_21,n2), 303 | wheels(car_22,n2), 304 | wheels(car_23,n2), 305 | 306 | #eastbound train 307 | has_car(east3,car_31), 308 | has_car(east3,car_32), 309 | has_car(east3,car_33), 310 | short(car_31), 311 | short(car_32), 312 | long(car_33), 313 | car_shape(car_31,rectangle), 314 | car_shape(car_32,hexagon), 315 | car_shape(car_33,rectangle), 316 | open_car(car_31), 
317 | closed(car_32), 318 | closed(car_33), 319 | load(car_31,circle,n1), 320 | load(car_32,triangle,n1), 321 | load(car_33,triangle,n1), 322 | wheels(car_31,n2), 323 | wheels(car_32,n2), 324 | wheels(car_33,n3), 325 | 326 | #eastbound train 327 | has_car(east4,car_41), 328 | has_car(east4,car_42), 329 | has_car(east4,car_43), 330 | has_car(east4,car_44), 331 | short(car_41), 332 | short(car_42), 333 | short(car_43), 334 | short(car_44), 335 | car_shape(car_41,u_shaped), 336 | car_shape(car_42,rectangle), 337 | car_shape(car_43,elipse), 338 | car_shape(car_44,rectangle), 339 | double(car_42), 340 | open_car(car_41), 341 | open_car(car_42), 342 | closed(car_43), 343 | open_car(car_44), 344 | load(car_41,triangle,n1), 345 | load(car_42,triangle,n1), 346 | load(car_43,rectangle,n1), 347 | load(car_44,rectangle,n1), 348 | wheels(car_41,n2), 349 | wheels(car_42,n2), 350 | wheels(car_43,n2), 351 | wheels(car_44,n2), 352 | 353 | #eastbound train 354 | has_car(east5,car_51), 355 | has_car(east5,car_52), 356 | has_car(east5,car_53), 357 | short(car_51), 358 | short(car_52), 359 | short(car_53), 360 | car_shape(car_51,rectangle), 361 | car_shape(car_52,rectangle), 362 | car_shape(car_53,rectangle), 363 | double(car_51), 364 | open_car(car_51), 365 | closed(car_52), 366 | closed(car_53), 367 | load(car_51,triangle,n1), 368 | load(car_52,rectangle,n1), 369 | load(car_53,circle,n1), 370 | wheels(car_51,n2), 371 | wheels(car_52,n3), 372 | wheels(car_53,n2), 373 | 374 | #westbound train 375 | has_car(west6,car_61), 376 | has_car(west6,car_62), 377 | long(car_61), 378 | short(car_62), 379 | car_shape(car_61,rectangle), 380 | car_shape(car_62,rectangle), 381 | closed(car_61), 382 | open_car(car_62), 383 | load(car_61,circle,n3), 384 | load(car_62,triangle,n1), 385 | wheels(car_61,n2), 386 | wheels(car_62,n2), 387 | 388 | #westbound train 389 | has_car(west7,car_71), 390 | has_car(west7,car_72), 391 | has_car(west7,car_73), 392 | short(car_71), 393 | short(car_72), 394 | long(car_73), 395 | car_shape(car_71,rectangle), 396 | car_shape(car_72,u_shaped), 397 | car_shape(car_73,rectangle), 398 | double(car_71), 399 | open_car(car_71), 400 | open_car(car_72), 401 | jagged(car_73), 402 | load(car_71,circle,n1), 403 | load(car_72,triangle,n1), 404 | load(car_73,nil,n0), 405 | wheels(car_71,n2), 406 | wheels(car_72,n2), 407 | wheels(car_73,n2), 408 | 409 | #westbound train 410 | has_car(west8,car_81), 411 | has_car(west8,car_82), 412 | long(car_81), 413 | short(car_82), 414 | car_shape(car_81,rectangle), 415 | car_shape(car_82,u_shaped), 416 | closed(car_81), 417 | open_car(car_82), 418 | load(car_81,rectangle,n1), 419 | load(car_82,circle,n1), 420 | wheels(car_81,n3), 421 | wheels(car_82,n2), 422 | 423 | #westbound train 424 | has_car(west9,car_91), 425 | has_car(west9,car_92), 426 | has_car(west9,car_93), 427 | has_car(west9,car_94), 428 | short(car_91), 429 | long(car_92), 430 | short(car_93), 431 | short(car_94), 432 | car_shape(car_91,u_shaped), 433 | car_shape(car_92,rectangle), 434 | car_shape(car_93,rectangle), 435 | car_shape(car_94,u_shaped), 436 | open_car(car_91), 437 | jagged(car_92), 438 | open_car(car_93), 439 | open_car(car_94), 440 | load(car_91,circle,n1), 441 | load(car_92,rectangle,n1), 442 | load(car_93,rectangle,n1), 443 | load(car_93,circle,n1), 444 | wheels(car_91,n2), 445 | wheels(car_92,n2), 446 | wheels(car_93,n2), 447 | wheels(car_94,n2), 448 | 449 | # westbound train 1 450 | has_car(west10,car_101), 451 | has_car(west10,car_102), 452 | short(car_101), 453 | long(car_102), 454 | 
car_shape(car_101,u_shaped), 455 | car_shape(car_102,rectangle), 456 | open_car(car_101), 457 | open_car(car_102), 458 | load(car_101,rectangle,n1), 459 | load(car_102,rectangle,n2), 460 | wheels(car_101,n2), 461 | wheels(car_102,n2), 462 | ] 463 | return bk 464 | 465 | if __name__ == "__main__": 466 | 467 | bk = michalski_trains() 468 | pl = SWIProlog() 469 | 470 | for fact in bk: 471 | pl.assert_fact(fact) 472 | 473 | x = c_var("X") 474 | y = c_var("Y") 475 | 476 | t = c_const("west10") 477 | has_car = c_pred("has_car",2) 478 | 479 | sols = pl.query(has_car(t,y)) 480 | 481 | for sol in sols: 482 | print(f"has_car({t},{sol[y]})") 483 | 484 | 485 | 486 | 487 | 488 | 489 | 490 | 491 | 492 | -------------------------------------------------------------------------------- /tests/eval_functions_tests.py: -------------------------------------------------------------------------------- 1 | import typing 2 | from abc import ABC, abstractmethod 3 | import math 4 | 5 | from orderedset import OrderedSet 6 | 7 | from loreleai.language.lp import c_pred, Clause, Procedure, Atom, Variable 8 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace 9 | from loreleai.learning.language_filtering import has_singleton_vars, has_duplicated_literal 10 | from loreleai.learning.language_manipulation import plain_extension 11 | from loreleai.learning.task import Task, Knowledge 12 | from loreleai.reasoning.lp.prolog import SWIProlog, Prolog 13 | from loreleai.learning.learners import SimpleBreadthFirstLearner 14 | from loreleai.learning.eval_functions import Accuracy, Compression, Coverage, Entropy 15 | 16 | # define the predicates 17 | father = c_pred("father", 2) 18 | mother = c_pred("mother", 2) 19 | grandparent = c_pred("grandparent", 2) 20 | # specify the background knowledge 21 | background = Knowledge(father("a", "b"), mother("a", "b"), mother("b", "c"), 22 | father("e", "f"), father("f", "g"), 23 | mother("h", "i"), mother("i", "j")) 24 | 25 | # positive examples 26 | pos = {grandparent("a", "c"), grandparent("e", "g"), grandparent("h", "j")} 27 | 28 | # negative examples 29 | neg = {grandparent("a", "b"), grandparent("a", "g"), grandparent("i", "j")} 30 | task = Task(positive_examples=pos, negative_examples=neg) 31 | 32 | 33 | # 1. A clause of length 3 covers all positive examples and no negative examples 34 | covered = pos 35 | x = Variable("X") 36 | y = Variable("Y") 37 | # contents of clause don't matter, only length 38 | cl = grandparent(x,y) <= father(x,y) & father(x,y) & father(x,y) 39 | 40 | # accuracy = P/(N+P) 41 | acc = Accuracy() 42 | # coverage = P-N 43 | cov = Coverage() 44 | # compression = P-N-L+1 45 | comp = Compression() 46 | # entropy = -(p log p + (1-p) log (1-p)) where p = P/(P + N) 47 | entr = Entropy() 48 | 49 | assert(acc.evaluate(cl,task,covered) == 1) 50 | assert(cov.evaluate(cl,task,covered) == 3) 51 | assert(comp.evaluate(cl,task,covered) == 3-3+1) 52 | assert(entr.evaluate(cl,task,covered) == 0) 53 | 54 | # 2. A clause of length 2 covers 3 positive examples, 2 negative examples 55 | covered = list(pos) + [grandparent("a", "b"),grandparent("a", "g")] 56 | # contents of clause don't matter, only length 57 | cl = grandparent(x,y) <= father(x,y) & father(x,y) 58 | 59 | assert(acc.evaluate(cl,task,covered) == 3/5) 60 | assert(cov.evaluate(cl,task,covered) == 3-2) 61 | assert(comp.evaluate(cl,task,covered) == 3-2-2+1) 62 | assert(entr.evaluate(cl,task,covered) == 0.29228525323862886) 63 | 64 | # 3. 
A clause of length 4 covers 0 positive examples, 3 negative examples 65 | covered = neg 66 | # contents of clause don't matter, only length 67 | cl = grandparent(x,y) <= father(x,y) & father(x,y) & father(x,y) & father(x,y) 68 | 69 | assert(acc.evaluate(cl,task,covered) == 0/3) 70 | assert(cov.evaluate(cl,task,covered) == 0-3) 71 | assert(comp.evaluate(cl,task,covered) == -3-4+1) 72 | assert(entr.evaluate(cl,task,covered) == 0) 73 | 74 | 75 | -------------------------------------------------------------------------------- /tests/language_test.py: -------------------------------------------------------------------------------- 1 | from loreleai.language.lp import c_var, c_pred, c_const, Predicate, Constant, Variable, Clause, Atom, Disjunction 2 | from loreleai.learning.hypothesis_space import TopDownHypothesisSpace 3 | from loreleai.learning.language_filtering import has_singleton_vars, connected_clause 4 | from loreleai.learning.language_manipulation import plain_extension, BottomClauseExpansion 5 | 6 | 7 | class LanguageTest: 8 | 9 | def manual_constructs(self): 10 | p1 = c_const("p1") 11 | p2 = c_const("p2") 12 | p3 = c_const("p3") 13 | 14 | parent = c_pred("parent", 2) 15 | grandparent = c_pred("grandparent", 2) 16 | 17 | f1 = parent(p1, p2) 18 | f2 = parent(p2, p3) 19 | 20 | v1 = c_var("X") 21 | v2 = c_var("Y") 22 | v3 = c_var("Z") 23 | 24 | cl = grandparent(v1, v3) <= parent(v1, v2) & parent(v2, v3) 25 | 26 | assert isinstance(p1, Constant) 27 | assert isinstance(p2, Constant) 28 | assert isinstance(p3, Constant) 29 | 30 | assert isinstance(parent, Predicate) 31 | assert isinstance(grandparent, Predicate) 32 | 33 | assert isinstance(v1, Variable) 34 | assert isinstance(v2, Variable) 35 | assert isinstance(v3, Variable) 36 | 37 | assert isinstance(cl, Clause) 38 | 39 | assert isinstance(f1, Atom) 40 | assert isinstance(f2, Atom) 41 | 42 | def shorthand_constructs(self): 43 | parent = c_pred("parent", 2) 44 | grandparent = c_pred("grandparent", 2) 45 | 46 | f1 = parent("p1", "p2") 47 | f2 = parent("p2", "p3") 48 | f3 = grandparent("p1", "X") 49 | 50 | assert isinstance(parent, Predicate) 51 | assert isinstance(grandparent, Predicate) 52 | 53 | assert isinstance(f1, Atom) 54 | assert isinstance(f2, Atom) 55 | assert isinstance(f3, Atom) 56 | 57 | assert isinstance(f1.arguments[0], Constant) 58 | assert isinstance(f1.arguments[1], Constant) 59 | assert isinstance(f3.arguments[1], Variable) 60 | 61 | 62 | class LanguageManipulationTest: 63 | 64 | def plain_clause_extensions(self): 65 | parent = c_pred("parent", 2) 66 | grandparent = c_pred("grandparent", 2) 67 | 68 | cl1 = grandparent("X", "Y") <= parent("X", "Z") 69 | 70 | extensions = plain_extension(cl1, parent, connected_clauses=False) 71 | 72 | assert len(extensions) == 16 73 | 74 | def plain_clause_extensions_connected(self): 75 | parent = c_pred("parent", 2) 76 | grandparent = c_pred("grandparent", 2) 77 | 78 | cl1 = grandparent("X", "Y") <= parent("X", "Z") 79 | 80 | extensions = plain_extension(cl1, parent, connected_clauses=True) 81 | 82 | assert len(extensions) == 15 83 | 84 | def plain_procedure_extension(self): 85 | parent = c_pred("parent", 2) 86 | ancestor = c_pred("ancestor", 2) 87 | 88 | cl1 = ancestor("X", "Y") <= parent("X", "Y") 89 | cl2 = ancestor("X", "Y") <= parent("X", "Z") & parent("Z", "Y") 90 | 91 | proc = Disjunction([cl1, cl2]) 92 | 93 | extensions = plain_extension(proc, parent, connected_clauses=False) 94 | 95 | assert len(extensions) == 25 96 | 97 | 98 | class HypothesisSpace(): 99 | 100 | def 
top_down_plain(self): 101 | grandparent = c_pred("grandparent", 2) 102 | father = c_pred("father", 2) 103 | mother = c_pred("mother", 2) 104 | 105 | hs = TopDownHypothesisSpace(primitives=[lambda x: plain_extension(x, father), 106 | lambda x: plain_extension(x, mother)], 107 | head_constructor=grandparent) 108 | 109 | current_cand = hs.get_current_candidate() 110 | assert len(current_cand) == 3 111 | 112 | expansions = hs.expand(current_cand[0]) 113 | assert len(expansions) == 6 114 | 115 | expansions_2 = hs.expand(expansions[0]) 116 | assert len(expansions_2) == 10 117 | 118 | expansions3 = hs.expand(expansions[1]) 119 | assert len(expansions3) == 32 120 | 121 | hs.block(expansions[2]) 122 | expansions4 = hs.expand(expansions[2]) 123 | assert len(expansions4) == 0 124 | 125 | hs.remove(expansions[3]) 126 | expansions5 = hs.get_successors_of(current_cand[0]) 127 | assert len(expansions5) == 5 128 | 129 | hs.move_pointer_to(expansions[1]) 130 | current_cand = hs.get_current_candidate() 131 | assert current_cand[0] == expansions[1] 132 | 133 | hs.ignore(expansions[4]) 134 | hs.move_pointer_to(expansions[4]) 135 | expansions6 = hs.get_current_candidate() 136 | assert len(expansions6) == 0 137 | 138 | def top_down_limited(self): 139 | grandparent = c_pred("grandparent", 2) 140 | father = c_pred("father", 2) 141 | mother = c_pred("mother", 2) 142 | 143 | hs = TopDownHypothesisSpace(primitives=[lambda x: plain_extension(x, father, connected_clauses=False), 144 | lambda x: plain_extension(x, mother, connected_clauses=False)], 145 | head_constructor=grandparent, 146 | expansion_hooks_keep=[lambda x, y: connected_clause(x, y)], 147 | expansion_hooks_reject=[lambda x, y: has_singleton_vars(x, y)]) 148 | 149 | current_cand = hs.get_current_candidate() 150 | assert len(current_cand) == 3 151 | 152 | expansions = hs.expand(current_cand[1]) 153 | assert len(expansions) == 6 154 | 155 | expansion2 = hs.expand(expansions[1]) 156 | assert len(expansion2) == 16 157 | 158 | def bottom_up(self): 159 | a = c_pred("a", 2) 160 | b = c_pred("b", 2) 161 | c = c_pred("c", 1) 162 | h = c_pred("h", 2) 163 | 164 | cl = h("X", "Y") <= a("X", "Z") & b("Z", "Y") & c("X") 165 | 166 | bc = BottomClauseExpansion(cl) 167 | 168 | hs = TopDownHypothesisSpace(primitives=[lambda x: bc.expand(x)], 169 | head_constructor=h) 170 | 171 | cls = hs.get_current_candidate() 172 | assert len(cls) == 3 173 | 174 | exps = hs.expand(cls[1]) 175 | assert len(exps) == 2 176 | 177 | exps2 = hs.expand(exps[0]) 178 | assert len(exps2) == 4 179 | 180 | def recursions(self): 181 | parent = c_pred("parent", 2) 182 | ancestor = c_pred("ancestor", 2) 183 | 184 | hs = TopDownHypothesisSpace(primitives=[lambda x: plain_extension(x, parent, connected_clauses=True)], 185 | head_constructor=ancestor, 186 | recursive_procedures=True) 187 | 188 | cls = hs.get_current_candidate() 189 | cls2 = hs.expand(cls[1]) 190 | print(cls2) 191 | 192 | 193 | 194 | 195 | 196 | def test_language(): 197 | test = LanguageTest() 198 | test.manual_constructs() 199 | 200 | test_bias = LanguageManipulationTest() 201 | test_bias.plain_clause_extensions() 202 | test_bias.plain_clause_extensions_connected() 203 | test_bias.plain_procedure_extension() 204 | # 205 | test_hypothesis_space = HypothesisSpace() 206 | test_hypothesis_space.top_down_plain() 207 | test_hypothesis_space.top_down_limited() 208 | test_hypothesis_space.bottom_up() 209 | test_hypothesis_space.recursions() 210 | 211 | test_language() -------------------------------------------------------------------------------- 
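The hypothesis-space tests above poke the pointer machinery one call at a time; stitched together, a learner's inner loop looks roughly like this sketch (API usage as in the tests, traversal capped for illustration):

```python
from loreleai.language.lp import c_pred
from loreleai.learning.hypothesis_space import TopDownHypothesisSpace
from loreleai.learning.language_manipulation import plain_extension

parent = c_pred("parent", 2)
grandparent = c_pred("grandparent", 2)

hs = TopDownHypothesisSpace(
    primitives=[lambda x: plain_extension(x, parent)],
    head_constructor=grandparent,
)

frontier = list(hs.get_current_candidate())
visited = 0
while frontier and visited < 50:        # cap the walk for this sketch
    clause = frontier.pop(0)
    visited += 1
    frontier.extend(hs.expand(clause))  # a real learner would score/prune here
print(f"visited {visited} clauses")
```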
/tests/solver_datalog_test.py:
--------------------------------------------------------------------------------
1 | from loreleai.language.lp import c_const, c_var, c_pred
2 | from loreleai.reasoning.lp.datalog import MuZ
3 | 
4 | 
5 | class DatalogTests:
6 | 
7 |     def simple_grandparent(self):
8 |         p1 = c_const("p1")
9 |         p2 = c_const("p2")
10 |         p3 = c_const("p3")
11 | 
12 |         parent = c_pred("parent", 2)
13 |         grandparent = c_pred("grandparent", 2)
14 | 
15 |         f1 = parent(p1, p2)
16 |         f2 = parent(p2, p3)
17 | 
18 |         v1 = c_var("X")
19 |         v2 = c_var("Y")
20 |         v3 = c_var("Z")
21 | 
22 |         cl = (grandparent(v1, v3) <= parent(v1, v2) & parent(v2, v3))
23 | 
24 |         solver = MuZ()
25 | 
26 |         solver.assert_fact(f1)
27 |         solver.assert_fact(f2)
28 |         solver.assert_rule(cl)
29 | 
30 |         assert solver.has_solution(parent(v1, v2))
31 |         assert not solver.has_solution(parent(v1, v1))
32 |         assert len(solver.query(parent(v1, v2))) == 2
33 |         assert len(solver.query(parent(p1, v1))) == 1
34 |         assert solver.has_solution(parent(p1, p2))
35 |         assert not solver.has_solution(parent(p2, p1))
36 |         assert len(solver.query(parent(p1, v1), max_solutions=1)) == 1
37 | 
38 |         assert solver.has_solution(grandparent(v1, v2))
39 |         assert solver.has_solution(grandparent(p1, v1))
40 |         assert len(solver.query(grandparent(p1, v1), max_solutions=1)) == 1
41 |         assert solver.has_solution(grandparent(p1, p3))
42 |         assert not solver.has_solution(grandparent(p2, v1))
43 |         ans = solver.query(grandparent(p1, v1), max_solutions=1)[0]
44 |         assert ans[v1] == p3
45 |         ans = solver.query(grandparent(v1, v2), max_solutions=1)[0]
46 |         assert ans[v1] == p1 and ans[v2] == p3
47 | 
48 |         assert solver.has_solution(*cl.get_literals())
49 |         ans = solver.query(*cl.get_literals(), max_solutions=1)[0]
50 |         assert ans[v1] == p1 and ans[v3] == p3
51 |         assert len(solver.query(*cl.get_literals())) == 1
52 | 
53 |     def graph_connectivity(self):
54 |         v1 = c_const("v1")
55 |         v2 = c_const("v2")
56 |         v3 = c_const("v3")
57 |         v4 = c_const("v4")
58 | 
59 |         edge = c_pred("edge", 2)
60 |         path = c_pred("path", 2)
61 | 
62 |         f1 = edge(v1, v2)
63 |         f2 = edge(v1, v3)
64 |         f3 = edge(v2, v4)
65 | 
66 |         X = c_var("X")
67 |         Y = c_var("Y")
68 |         Z = c_var("Z")
69 | 
70 |         cl1 = path(X, Y) <= edge(X, Y)
71 |         cl2 = path(X, Y) <= path(X, Z) & edge(Z, Y)
72 | 
73 |         solver = MuZ()
74 | 
75 |         solver.assert_fact(f1)
76 |         solver.assert_fact(f2)
77 |         solver.assert_fact(f3)
78 | 
79 |         solver.assert_rule(cl1)
80 |         solver.assert_rule(cl2)
81 | 
82 |         assert solver.has_solution(path(v1, v2))
83 |         assert solver.has_solution(path(v1, v4))
84 |         assert not solver.has_solution(path(v3, v4))
85 | 
86 |         assert len(solver.query(path(v1, X), max_solutions=1)[0]) == 1
87 |         assert len(solver.query(path(X, v4), max_solutions=1)[0]) == 1
88 |         assert len(solver.query(path(X, Y), max_solutions=1)[0]) == 2
89 | 
90 |         assert len(solver.query(path(v1, X))) == 3
91 |         assert len(solver.query(path(X, Y))) == 4
92 | 
93 | 
94 | def test_datalog():
95 |     dtest = DatalogTests()
96 | 
97 |     dtest.simple_grandparent()
98 |     dtest.graph_connectivity()
99 | 
100 |     print("all tests done!")
101 | 
102 | test_datalog()
103 | 
--------------------------------------------------------------------------------
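Condensed to its essentials, the pattern in `solver_datalog_test.py` is: create a `MuZ` solver (the Z3-based Datalog engine), add ground facts with `assert_fact`, add rules with `assert_rule`, then either check entailment with `has_solution` or retrieve bindings with `query`, which returns one substitution dictionary per answer. A minimal sketch, with illustrative constants (`ann`, `bob`, `carl`) that are not part of the test suite:

```python
# Minimal sketch of the MuZ query API demonstrated in the test above.
from loreleai.language.lp import c_const, c_var, c_pred
from loreleai.reasoning.lp.datalog import MuZ

ann, bob, carl = c_const("ann"), c_const("bob"), c_const("carl")
parent = c_pred("parent", 2)
grandparent = c_pred("grandparent", 2)
X, Y, Z = c_var("X"), c_var("Y"), c_var("Z")

solver = MuZ()
solver.assert_fact(parent(ann, bob))
solver.assert_fact(parent(bob, carl))
solver.assert_rule(grandparent(X, Z) <= parent(X, Y) & parent(Y, Z))

assert solver.has_solution(grandparent(ann, carl))
# query() returns a list of substitutions, each keyed by variable
for answer in solver.query(grandparent(X, Z)):
    print(answer[X], answer[Z])  # expected: ann carl
```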
/tests/solver_kanren_test.py:
--------------------------------------------------------------------------------
1 | from loreleai.language.kanren import c_var, c_pred, c_const
2 | from loreleai.reasoning.lp.kanren import MiniKanren
3 | 
4 | 
5 | class KanrenTest:
6 | 
7 |     def simple_grandparent(self):
8 |         p1 = c_const("p1")
9 |         p2 = c_const("p2")
10 |         p3 = c_const("p3")
11 | 
12 |         parent = c_pred("parent", 2)
13 |         grandparent = c_pred("grandparent", 2)
14 | 
15 |         f1 = parent(p1, p2)
16 |         f2 = parent(p2, p3)
17 | 
18 |         v1 = c_var("X")
19 |         v2 = c_var("Y")
20 |         v3 = c_var("Z")
21 | 
22 |         cl = grandparent(v1, v3) <= parent(v1, v2) & parent(v2, v3)
23 | 
24 |         solver = MiniKanren()
25 | 
26 |         solver.assert_fact(f1)
27 |         solver.assert_fact(f2)
28 |         solver.assert_rule(cl)
29 | 
30 |         assert solver.has_solution(parent(v1, v2))
31 |         assert not solver.has_solution(parent(v1, v1))
32 |         assert len(solver.query(parent(v1, v2))) == 2
33 |         assert len(solver.query(parent(p1, v1))) == 1
34 |         assert solver.has_solution(parent(p1, p2))
35 |         assert not solver.has_solution(parent(p2, p1))
36 |         assert len(solver.query(parent(p1, v1), max_solutions=1)) == 1
37 | 
38 |         assert solver.has_solution(grandparent(v1, v2))
39 |         assert solver.has_solution(grandparent(p1, v1))
40 |         assert len(solver.query(grandparent(p1, v1), max_solutions=1)) == 1
41 |         assert solver.has_solution(grandparent(p1, p3))
42 |         assert not solver.has_solution(grandparent(p2, v1))
43 |         ans = solver.query(grandparent(p1, v1), max_solutions=1)[0]
44 |         assert ans[v1] == p3
45 |         ans = solver.query(grandparent(v1, v2), max_solutions=1)[0]
46 |         assert ans[v1] == p1 and ans[v2] == p3
47 | 
48 |         assert solver.has_solution(*cl.get_literals())
49 |         ans = solver.query(*cl.get_literals(), max_solutions=1)[0]
50 |         assert ans[v1] == p1 and ans[v3] == p3
51 |         assert len(solver.query(*cl.get_literals())) == 1
52 | 
53 |     def graph_connectivity(self):
54 |         v1 = c_const("v1")
55 |         v2 = c_const("v2")
56 |         v3 = c_const("v3")
57 |         v4 = c_const("v4")
58 | 
59 |         edge = c_pred("edge", 2)
60 |         path = c_pred("path", 2)
61 | 
62 |         f1 = edge(v1, v2)
63 |         f2 = edge(v1, v3)
64 |         f3 = edge(v2, v4)
65 | 
66 |         X = c_var("X")
67 |         Y = c_var("Y")
68 |         Z = c_var("Z")
69 | 
70 |         cl1 = path(X, Y) <= edge(X, Y)
71 |         cl2 = path(X, Y) <= edge(X, Z) & path(Z, Y)
72 | 
73 |         solver = MiniKanren()
74 | 
75 |         solver.assert_fact(f1)
76 |         solver.assert_fact(f2)
77 |         solver.assert_fact(f3)
78 | 
79 |         solver.assert_rule([cl1, cl2])
80 | 
81 |         assert solver.has_solution(path(v1, v2))
82 |         assert solver.has_solution(path(v1, v4))
83 |         assert not solver.has_solution(path(v3, v4))
84 | 
85 |         assert len(solver.query(path(v1, X), max_solutions=1)[0]) == 1
86 |         assert len(solver.query(path(X, v4), max_solutions=1)[0]) == 1
87 |         assert len(solver.query(path(X, Y), max_solutions=1)[0]) == 2
88 | 
89 |         assert len(solver.query(path(v1, X))) == 3
90 |         assert len(solver.query(path(X, Y))) == 4
91 | 
92 | 
93 | def test_kanren():
94 |     test = KanrenTest()
95 | 
96 |     test.simple_grandparent()
97 |     test.graph_connectivity()
98 | 
99 |     print("all tests done!")
100 | 
101 | test_kanren()
--------------------------------------------------------------------------------
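Two details distinguish the MiniKanren test from its Datalog twin: language constructs are imported from the kanren-specific module (`loreleai.language.kanren` rather than `loreleai.language.lp`), and `assert_rule` also accepts a list of clauses, which is how the base case and the recursive case of `path` are registered in a single call. A brief sketch of that list form, with illustrative constants `a`/`b`/`c`:

```python
# Minimal sketch of MiniKanren's list-form assert_rule, as used in the
# graph_connectivity test above. The constants a/b/c are illustrative.
from loreleai.language.kanren import c_const, c_pred, c_var
from loreleai.reasoning.lp.kanren import MiniKanren

a, b, c = c_const("a"), c_const("b"), c_const("c")
edge, path = c_pred("edge", 2), c_pred("path", 2)
X, Y, Z = c_var("X"), c_var("Y"), c_var("Z")

solver = MiniKanren()
for fact in [edge(a, b), edge(b, c)]:
    solver.assert_fact(fact)

# mutually dependent clauses are registered together as one list
solver.assert_rule([path(X, Y) <= edge(X, Y),
                    path(X, Y) <= edge(X, Z) & path(Z, Y)])

print(solver.query(path(a, Y)))  # expected bindings: Y = b and Y = c
```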
/tests/solver_prolog_gnu_test.py:
--------------------------------------------------------------------------------
1 | 
2 | from loreleai.language.lp import c_functor, c_pred, c_var, List
3 | from loreleai.reasoning.lp.prolog import GNUProlog
4 | 
5 | pl = GNUProlog()
6 | 
7 | p = c_pred("p", 2)
8 | f = c_functor("t", 3)
9 | f1 = p("a", "b")
10 | 
11 | pl.assertz(f1)
12 | 
13 | X = c_var("X")
14 | Y = c_var("Y")
15 | 
16 | query = p(X, Y)
17 | 
18 | r = pl.has_solution(query)
19 | print("has solution", r)
20 | 
21 | rv = pl.query(query)
22 | print("all solutions", rv)
23 | 
24 | f2 = p("a", "c")
25 | pl.assertz(f2)
26 | 
27 | rv = pl.query(query)
28 | print("all solutions after adding f2", rv)
29 | 
30 | func1 = f(1, 2, 3)
31 | f3 = p(func1, "b")
32 | pl.assertz(f3)
33 | 
34 | rv = pl.query(query)
35 | print("all solutions after adding structure", rv)
36 | 
37 | l = List([1, 2, 3, 4, 5])
38 | 
39 | member = c_pred("member", 2)
40 | 
41 | query2 = member(X, l)
42 | 
43 | rv = pl.query(query2)
44 | print("all solutions to list membership ", rv)
45 | 
46 | r = c_pred("r", 2)
47 | f4 = r("a", l)
48 | f5 = r("a", "b")
49 | 
50 | pl.asserta(f4)
51 | pl.asserta(f5)
52 | 
53 | query3 = r(X, Y)
54 | 
55 | rv = pl.query(query3)
56 | print("all solutions after adding list ", rv)
57 | 
--------------------------------------------------------------------------------
/tests/solver_prolog_swipy_test.py:
--------------------------------------------------------------------------------
1 | from loreleai.language.lp import c_pred, c_functor, c_var, List
2 | from loreleai.reasoning.lp.prolog import SWIProlog
3 | 
4 | pl = SWIProlog()
5 | 
6 | p = c_pred("p", 2)
7 | f = c_functor("t", 3)
8 | f1 = p("a", "b")
9 | 
10 | pl.assertz(f1)
11 | 
12 | X = c_var("X")
13 | Y = c_var("Y")
14 | 
15 | query = p(X, Y)
16 | 
17 | r = pl.has_solution(query)
18 | print("has solution", r)
19 | 
20 | rv = pl.query(query)
21 | print("all solutions", rv)
22 | 
23 | f2 = p("a", "c")
24 | pl.assertz(f2)
25 | 
26 | rv = pl.query(query)
27 | print("all solutions after adding f2", rv)
28 | 
29 | func1 = f(1, 2, 3)
30 | f3 = p(func1, "b")
31 | pl.assertz(f3)
32 | 
33 | rv = pl.query(query)
34 | print("all solutions after adding structure", rv)
35 | 
36 | l = List([1, 2, 3, 4, 5])
37 | 
38 | member = c_pred("member", 2)
39 | 
40 | query2 = member(X, l)
41 | 
42 | rv = pl.query(query2)
43 | print("all solutions to list membership ", rv)
44 | 
45 | r = c_pred("r", 2)
46 | f4 = r("a", l)
47 | f5 = r("a", "b")
48 | 
49 | pl.asserta(f4)
50 | pl.asserta(f5)
51 | 
52 | query3 = r(X, Y)
53 | 
54 | rv = pl.query(query3)
55 | print("all solutions after adding list ", rv)
--------------------------------------------------------------------------------
/tests/solver_prolog_xsb_test.py:
--------------------------------------------------------------------------------
1 | from loreleai.language.lp import c_pred, c_functor, c_var, List
2 | from loreleai.reasoning.lp.prolog import XSBProlog
3 | 
4 | pl = XSBProlog("/Users/seb/Documents/programs/XSB")  # path to a local XSB installation; adjust for your machine
5 | 
6 | p = c_pred("p", 2)
7 | f = c_functor("t", 3)
8 | f1 = p("a", "b")
9 | 
10 | pl.assertz(f1)
11 | 
12 | X = c_var("X")
13 | Y = c_var("Y")
14 | 
15 | query = p(X, Y)
16 | 
17 | r = pl.has_solution(query)
18 | print("has solution", r)
19 | 
20 | rv = pl.query(query)
21 | print("all solutions", rv)
22 | 
23 | f2 = p("a", "c")
24 | pl.assertz(f2)
25 | 
26 | rv = pl.query(query)
27 | print("all solutions after adding f2", rv)
28 | 
29 | func1 = f(1, 2, 3)
30 | f3 = p(func1, "b")
31 | pl.assertz(f3)
32 | 
33 | rv = pl.query(query)
34 | print("all solutions after adding structure", rv)
35 | 
36 | l = List([1, 2, 3, 4, 5])
37 | 
38 | member = c_pred("member", 2)
39 | pl.use_module("lists", predicates=[member])
40 | 
41 | query2 = member(X, l)
42 | 
43 | rv = pl.query(query2)
44 | print("all solutions to list membership ", rv)
45 | 
46 | r = c_pred("r", 2)
47 | f4 = r("a", l)
48 | f5 = r("a", "b")
49 | 
50 | pl.asserta(f4)
51 | pl.asserta(f5)
52 | 
53 | query3 = r(X, Y)
54 | 
55 | rv = pl.query(query3)
56 | print("all solutions after adding list ", rv)
57 | 
58 | q = c_pred("q", 2)
59 | cl = (q("X", "Y") <= r("X", "Y") & r("X", "Z"))
60 | 
61 | pl.assertz(cl)
62 | query4 = q("X", "Y")
63 | rv = pl.query(query4)
64 | print("all solutions to q: ", rv)
--------------------------------------------------------------------------------
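The three Prolog scripts above are nearly identical on purpose: `GNUProlog`, `SWIProlog`, and `XSBProlog` expose the same `assertz`/`asserta`/`has_solution`/`query` interface, so the backend is essentially a one-line choice, plus backend-specific setup such as XSB's installation path and its `use_module` call for `member/2`. A minimal sketch of that interchangeability, with illustrative `parent` facts that are not part of the test suite:

```python
# Minimal sketch: the shared Prolog-solver interface exercised by the
# three test scripts above. The parent facts are illustrative.
from loreleai.language.lp import c_pred, c_var
from loreleai.reasoning.lp.prolog import SWIProlog  # or GNUProlog, XSBProlog

pl = SWIProlog()  # XSBProlog additionally needs its installation path

parent = c_pred("parent", 2)
pl.assertz(parent("ann", "bob"))
pl.assertz(parent("bob", "carl"))

X, Y = c_var("X"), c_var("Y")
print(pl.has_solution(parent(X, Y)))  # True
print(pl.query(parent("ann", Y)))     # all bindings for Y
```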