├── .gitignore ├── README.md └── src └── trimlogic ├── __init__.py ├── algorithm.py ├── counting.py ├── foil.py ├── graph.py ├── partialordering.py ├── predicate.py ├── stdlib.py ├── term.py ├── test.py ├── test ├── FamilyTreeTestCase.py ├── ListTestCase.py ├── __init__.py └── helper.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FOIL Python 2 | This is a Python implementation of FOIL, First Order Inductive Learner, described in J.R. Quinlan's paper [Learning Logical Definitions from Relations](http://link.springer.com/article/10.1023%2FA%3A1022699322624). In addition, this includes implementations of unification, resolution, and a number of Prolog's standard predicates. 3 | 4 | My main goal in writing this was merely to experiment with machine learning via inductive logic, and reproducing J.R. Quinlan's results seemed like a good place to start. There isn't a fancy UI or interactive prompt, but you can look at the test cases to see how to use the library. This isn't intended for reuse or as an example of clean, idiomatic Python code; it's just an academic exercise, so keep that in mind. 5 | 6 | ## What is FOIL? 7 | For an in-depth description, see the paper cited above. In short, FOIL learns a rule, expressed as a set of Horn clauses, that defines a relation, given a set of examples that are in the relation and a set of examples that are not. For example, suppose you had the following facts: 8 | 9 | ```prolog 10 | father(frank, abe). 11 | father(frank, alan). 12 | father(alan, sean). 13 | father(sean, jane). 14 | father(george, bob). 15 | father(george, tim). 16 | father(bob, jan). 17 | father(tim, tom). 18 | father(tom, thomas). 19 | father(ian, ann). 20 | father(thomas, billy). 21 | 22 | mother(rebecca, alan). 
23 | mother(rebecca, abe). 24 | mother(joan, sean). 25 | mother(jane, ann). 26 | mother(jannet, tim). 27 | mother(jannet, bob). 28 | mother(tammy, tom). 29 | mother(tipsy, thomas). 30 | mother(debrah, billy). 31 | mother(jill, jan). 32 | mother(jan, jane). 33 | ``` 34 | 35 | Now, suppose you didn't know the rule for the `ancestor` relation, but did have some positive examples of the relation (e.g. Tim is an ancestor of Tom, Jill is an ancestor of Ann, etc.) and some negative examples of the relation (e.g. Ann is not an ancestor of Billy, Tom is not an ancestor of George, etc.). From those examples and the above facts, FOIL can generate a rule for the `ancestor` relation such as the following: 36 | 37 | ```prolog 38 | ancestor(X, Y) :- father(X, Y). 39 | ancestor(X, Y) :- mother(X, Y). 40 | ancestor(X, Y) :- father(Z, Y), ancestor(X, Z). 41 | ancestor(X, Y) :- mother(Z, Y), ancestor(X, Z). 42 | ``` 43 | 44 | ## Test Cases 45 | ### Family Tree 46 | This test case demonstrates learning the `grandparent` and `ancestor` relations. Execute the following, from the project root: 47 | 48 | ``` 49 | cd src 50 | python trimlogic/test/FamilyTreeTestCase.py 51 | ``` 52 | 53 | You should see output like the following: 54 | 55 | ``` 56 | Rules for ancestor : 57 | ancestor(PARAM_0, PARAM_1) :- father(PARAM_0, PARAM_1). 58 | ancestor(PARAM_0, PARAM_1) :- mother(PARAM_0, PARAM_1). 59 | ancestor(PARAM_0, PARAM_1) :- father(VAR_13, PARAM_1), ancestor(PARAM_0, VAR_13). 60 | ancestor(PARAM_0, PARAM_1) :- mother(VAR_36, PARAM_1), ancestor(PARAM_0, VAR_36). 61 | .. 62 | Rules for grandfather : 63 | grandfather(PARAM_0, PARAM_1) :- father(VAR_4, PARAM_1), father(PARAM_0, VAR_4). 64 | grandfather(PARAM_0, PARAM_1) :- mother(VAR_21, PARAM_1), father(PARAM_0, VAR_21). 65 | . 
66 | ---------------------------------------------------------------------- 67 | Ran 3 tests in 8.578s 68 | 69 | OK 70 | 71 | ``` 72 | 73 | ### List 74 | This test case demonstrates learning the `member` relation for lists. Execute the following from the project root: 75 | 76 | ``` 77 | cd src 78 | python trimlogic/test/ListTestCase.py 79 | ``` 80 | 81 | You should see output like the following: 82 | 83 | ``` 84 | Rules for member : 85 | member(PARAM_0, PARAM_1) :- components(PARAM_1, PARAM_0, VAR_4). 86 | member(PARAM_0, PARAM_1) :- components(PARAM_1, VAR_12, VAR_13), member(PARAM_0, VAR_13). 87 | . 88 | ---------------------------------------------------------------------- 89 | Ran 1 test in 51.804s 90 | 91 | OK 92 | ``` 93 | 94 | -------------------------------------------------------------------------------- /src/trimlogic/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/johntrimble/foil-python/8628dc3113da66f8990ac673ed280d29cf82246a/src/trimlogic/__init__.py -------------------------------------------------------------------------------- /src/trimlogic/algorithm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import trimlogic.predicate 3 | from collections import deque 4 | from trimlogic.stdlib import cut 5 | from trimlogic.term import * 6 | from trimlogic.util import * 7 | logger = logging.getLogger(__name__) 8 | 9 | def log_unify(unify_func): 10 | def _unify(s1, s2, bindings=None): 11 | logger.debug("unify( " + str(s1) + ", " + str(s2) + " ) :: " + str(bindings)) 12 | mgu = unify_func(s1, s2, bindings) 13 | if mgu != None: logger.debug("unified " + str(s1) + " and " + str(s2) + " with " + str(mgu)) 14 | else: logger.debug("failed to unify " + str(s1) + " and " + str(s2)) 15 | return mgu 16 | return _unify 17 | 18 | def compose(f, g, composition=None): 19 | if composition == None: composition = {} 20 | 
f_composition_not_same_instance = not composition is f 21 | for key in f.keys(): 22 | if isinstance(f[key], Term): 23 | composition[key] = f[key].apply_bindings(g) 24 | elif f_composition_not_same_instance: 25 | composition[key] = f[key] 26 | composition.update(g) 27 | return composition 28 | 29 | def unify(s1, s2, bindings=None): 30 | """ 31 | Term, Term -> Map 32 | Implementation Notes: 33 | This is a basic implementation of the most general unifier algorithm presented 34 | in Artificial Intelligence: A Modern Approach 2nd Edition by Stuart J. Russell 35 | and Peter Norvig. 36 | """ 37 | logger = logging.getLogger('unify') 38 | logger.debug("start unify(%s, %s, %s)" % (s1, s2, bindings)) 39 | if bindings == None: 40 | bindings = {} 41 | if isinstance(s1, list): 42 | s1 = tuple(s1) 43 | if isinstance(s2, list): 44 | s2 = tuple(s2) 45 | return _unify(s1, s2, bindings) 46 | 47 | def _unify_var(var, x, bindings): 48 | if bindings.has_key(var): 49 | return _unify(bindings[var], x, bindings) 50 | try: 51 | if bindings.has_key(x): 52 | return _unify(var, bindings[x], bindings) 53 | except TypeError: 54 | pass 55 | return compose(bindings, {var : x}, bindings) 56 | 57 | def apply_bindings_seq(seq, bindings): 58 | l = [] 59 | for x in seq: 60 | try: 61 | l.append(x.apply_bindings(bindings)) 62 | except: 63 | l.append(x) 64 | return tuple(l) 65 | 66 | def _unify(s1, s2, bindings): 67 | s1_is_tuple, s2_is_tuple = isinstance(s1, tuple), isinstance(s2, tuple) 68 | s1_and_s2_tuple = s1_is_tuple and s2_is_tuple 69 | # base cases 70 | if s1 == s2: 71 | return bindings 72 | # recursive cases 73 | elif isinstance(s1, Var): 74 | return _unify_var(s1, s2, bindings) 75 | elif isinstance(s2, Var): 76 | return _unify_var(s2, s1, bindings) 77 | elif isinstance(s1, Pred) and isinstance(s2, Pred): 78 | P, P_terms = s1.predicate, s1.terms 79 | F, F_terms = s2.predicate, s2.terms 80 | if P == F: 81 | return _unify(P_terms, F_terms, bindings) 82 | elif s1_and_s2_tuple and len(s1) == len(s2): 
83 | new_bindings = _unify(s1[0], s2[0], bindings) 84 | if new_bindings is not None: 85 | return _unify(apply_bindings_seq(s1[1:], new_bindings), apply_bindings_seq(s2[1:], new_bindings), new_bindings) 86 | else: 87 | return None 88 | 89 | def fol_bc_ask(goals, substitutions): 90 | """ 91 | Attempts to satisfy a given set of goals, if one or more of the goals contains unbound variables, 92 | this algorithm will find every binding for every variable so that the goals are satisfied. The 93 | solutions are given as a sequence of variable mappings that satisfy the goals. If there is no 94 | mapping that will satisfy the goals then the generator yields no results. 95 | """ 96 | logger.debug("fol_bc_ask( " + str(goals) + " ) :: " + str(substitutions)) 97 | if len(goals) == 0: 98 | yield substitutions 99 | return 100 | goal = goals[0].apply_bindings(substitutions) 101 | logger.debug("goal after substitution: " + str(goal)) 102 | for mgu,new_goals,variables in goal.predicate._resolve(goal.terms): 103 | for child_answers in fol_bc_ask(new_goals + goals[1:], compose(substitutions, mgu)): 104 | if child_answers == None: 105 | logger.debug("received None for answers") 106 | yield None 107 | return 108 | for var in variables: 109 | try: 110 | del child_answers[var] 111 | logger.debug("Removed answer for %s" % var) 112 | except: pass 113 | logger.debug("yielding %s" % child_answers) 114 | yield child_answers 115 | if goal.predicate == cut: 116 | logger.debug("Cut encountered, stopping back tracking") 117 | yield None 118 | -------------------------------------------------------------------------------- /src/trimlogic/counting.py: -------------------------------------------------------------------------------- 1 | def choose(l, pick): 2 | if pick == 0: 3 | yield [] 4 | return 5 | for i in xrange(len(l)): 6 | item = l[i] 7 | for l2 in choose(l[(i+1):], pick-1): 8 | l2.insert(0, item) 9 | yield l2 10 | del l2[0] 11 | 12 | def permute(l, pick=None): 13 | if pick == None: pick = len(l) 14 | if 
pick == 0: 15 | yield [] 16 | return 17 | for i in xrange(len(l)): 18 | item = l[i] 19 | del l[i] 20 | for l2 in permute(l, pick-1): 21 | l2.insert(0, item) 22 | yield l2 23 | del l2[0] 24 | l.insert(i, item) 25 | -------------------------------------------------------------------------------- /src/trimlogic/foil.py: -------------------------------------------------------------------------------- 1 | import math, logging, sys, itertools, operator, time 2 | from trimlogic.term import UniqueVariableFactory, VariableFactory, Var 3 | from trimlogic.term import Atom, Pred 4 | from trimlogic.algorithm import fol_bc_ask 5 | from trimlogic.counting import choose, permute 6 | from trimlogic.predicate import Rule, MutableRule 7 | from trimlogic.partialordering import find_ordering 8 | from trimlogic.partialordering import create_partial_comparator 9 | logger = logging.getLogger(__name__) 10 | 11 | ############################################################################## 12 | # Constants. 13 | ############################################################################## 14 | NEW_VARIABLE_GAIN_BIAS = 0.001 15 | MINIMUM_LITERAL_GAIN_TO_ADD = 0.80 16 | 17 | ############################################################################## 18 | # Data structures for storing and managing positive and negative examples. 19 | ############################################################################## 20 | class ExampleTree: 21 | """ 22 | This class represents a single base example and its extensions. This allows 23 | for greater compaction of the example data. The class also provides 24 | functionality to produce the actual extensions of the base example through 25 | the extend method. 
26 | """ 27 | class Node: 28 | 29 | def __init__(self, variables, values, parent=None): 30 | self.parent = parent 31 | self.values = values 32 | self.variables = variables 33 | self.children = [] 34 | 35 | def __str__(self): 36 | return ("" 37 | % (self.values,len(self.children))) 38 | 39 | def __repr__(self): return str(self) 40 | 41 | def __init__(self, variables, values): 42 | assert len(variables) == len(values) 43 | self.root = ExampleTree.Node(variables, values) 44 | self.variables = variables 45 | self.size = len(variables) 46 | self.levels = 1 47 | self._count_list = [1] 48 | 49 | def extend(self, goals, variables): 50 | """ 51 | Extends the current example by using the fol_bc_ask algorithm to determine 52 | values for the new variables. 53 | 54 | @param goals: The goals by which the extension of the example is to be 55 | calculated. 56 | @type goals: A list of Term objects. 57 | @param variables: The new variables in the extension. 58 | @type variables: A list of Var objects. 59 | @return: The number of examples created from extending. 
60 | """ 61 | logger.debug("extend( " + str(goals) + ", " + str(variables) + " )") 62 | extension_count = 0 63 | examples_count = 0 64 | for node, bindings in self.enumerate_nodes_bindings(): 65 | extended = False 66 | logger.debug("calling " + str(fol_bc_ask) + " with '" + str(goals) 67 | + "' '" + str(bindings) + "'") 68 | for answer in fol_bc_ask(goals, bindings): 69 | if answer != None: 70 | examples_count += 1 71 | extended = True 72 | value = [] 73 | for var in variables: 74 | value.append(answer[var]) 75 | node.children.append(ExampleTree.Node(variables, value, node)) 76 | if extended: 77 | extension_count += 1 78 | self.levels += 1 79 | self._count_list.append(examples_count) 80 | logger.debug("Performed " + str(extension_count) + " extensions.") 81 | return extension_count 82 | 83 | def enumerate_nodes_bindings(self, i=None, node=None, bindings=None): 84 | if i == None: i = self.levels - 1 85 | if node == None: node = self.root 86 | if bindings == None: bindings = {} 87 | logger.debug("enumerate_nodes_bindings( " + str(i) + ", " + str(node) + ", " + str(bindings) + " )") 88 | for var,const in zip(node.variables, node.values): 89 | assert not bindings.has_key(var) 90 | bindings[var] = const 91 | logger.debug("Bindings '" + str(bindings) + "'.") 92 | if i == 0: 93 | yield node, bindings 94 | return 95 | k = i - 1 96 | for child in node.children: 97 | for x in self.enumerate_nodes_bindings(k, child, bindings): 98 | yield x 99 | for var in child.variables: 100 | del bindings[var] 101 | 102 | def enumerate_examples(self, level=None, node=None): 103 | if level == None: level = self.levels - 1 104 | if node == None: node = self.root 105 | logger.debug("enumerate_examples( " + str(level) + ", " 106 | + str(node) + " )") 107 | if level == 0: 108 | yield node.values 109 | return 110 | k = level - 1 111 | for child in node.children: 112 | for l in self.enumerate_examples(k, child): 113 | yield node.values + l 114 | 115 | def __iter__(self): 116 | return 
self.enumerate_examples() 117 | 118 | def cut_levels(self, i): 119 | self._cut_levels(i, self.root) 120 | 121 | def _cut_levels(self, i, node): 122 | if i == 1: 123 | for child in node.children: 124 | self.unlink_subtree(child) 125 | return 126 | k = i - 1 127 | for child in node.children: 128 | self._cut_levels(k, child) 129 | 130 | def rollback(self): 131 | logger.debug("rollback( )") 132 | self.cut_levels(self.levels-1) 133 | self.levels -= 1 134 | self._count_list.pop() 135 | 136 | def reset(self): 137 | while self.levels > 1: 138 | self.rollback() 139 | 140 | def unlink_subtree(self, node): 141 | s = [node] 142 | while s: 143 | n = s.pop() 144 | n.parent.children.remove(n) 145 | n.parent = None 146 | s.extend(n.children) 147 | 148 | def __len__(self): 149 | return self._count_list[self.levels-1] 150 | 151 | def __str__(self): 152 | s = "" 153 | for example in self.enumerate_examples(): 154 | s += str(example) 155 | s += "\n" 156 | return s 157 | 158 | 159 | class ExampleCollection: 160 | 161 | def __init__(self, predicate, formals, examples): 162 | self.predicate = predicate 163 | self._examples = [] 164 | for example in examples: 165 | self._examples.append(ExampleTree(formals, example)) 166 | self.variables = formals 167 | 168 | def __len__(self): 169 | return sum(map(len, self._examples)) 170 | 171 | def __iter__(self): 172 | return itertools.chain(*self._examples) 173 | 174 | def rollback(self): 175 | for tree in self._examples: 176 | tree.rollback() 177 | 178 | def extend(self, goals, variables): 179 | return sum(map(lambda t: t.extend(goals, variables), self._examples)) 180 | 181 | def reset(self): 182 | map(lambda x: x.reset(), self._examples) 183 | 184 | def prune_covered(self): 185 | logger.debug("Examples to consider for pruning: ") 186 | logger.debug(str(self._examples)) 187 | for ex in self._examples[:]: 188 | prune = False 189 | logger.debug("Calling fol_bc_ask(" + str(self.predicate(*ex.root.values)) + ", {})") 190 | for answer in 
fol_bc_ask([self.predicate(*ex.root.values)], {}): 191 | logger.debug("fol_bc_ask(" + str(self.predicate(*ex.root.values)) 192 | + ", {}) -> " + str(answer)) 193 | if answer != None and answer != False: 194 | self._examples.remove(ex) 195 | prune = True 196 | logger.debug("Pruned '%s'." % ex) 197 | break 198 | if not prune: logger.debug("Did not prune '%s'." % ex) 199 | 200 | def __repr__(self): 201 | s = "" 202 | for e in self._examples: 203 | s += str(e) 204 | return s 205 | 206 | def __str__(self): 207 | return self.__repr__() 208 | 209 | 210 | class TrainingSet: 211 | 212 | def __init__(self, predicate, formals, positive_examples, negative_examples): 213 | self.positive_examples = self._insureExampleCollection(predicate, 214 | formals, 215 | positive_examples) 216 | self.negative_examples = self._insureExampleCollection(predicate, 217 | formals, 218 | negative_examples) 219 | self._variables = [formals] 220 | self._extensions = 0 221 | 222 | def get_variables(self): 223 | vars = [] 224 | for l in self._variables: 225 | vars.extend(l) 226 | return vars 227 | 228 | def _insureExampleCollection(self, predicate, formals, examples): 229 | if isinstance(examples, ExampleCollection): 230 | return examples 231 | else: 232 | return ExampleCollection(predicate, formals, examples) 233 | 234 | def rollback(self): 235 | assert self._extensions > 0 236 | self._extensions -= 1 237 | self.positive_examples.rollback() 238 | self.negative_examples.rollback() 239 | self._variables.pop() 240 | 241 | def reset(self): 242 | self._extensions = 0 243 | self._variables = self._variables[:1] 244 | self.positive_examples.reset() 245 | self.negative_examples.reset() 246 | 247 | def has_variable(self, variable): 248 | for variables in self._variables: 249 | if variable in variables: 250 | return True 251 | return False 252 | 253 | def extend(self, goals, variables): 254 | self._extensions += 1 255 | self._variables.append(variables) 256 | return (self.positive_examples.extend(goals, 
variables), 257 | self.negative_examples.extend(goals, variables)) 258 | 259 | def get_information_measure(self): 260 | return -math.log(float(len(self.positive_examples) + 1) 261 | / float(len(self.negative_examples) 262 | + len(self.positive_examples) + 1), 2) 263 | 264 | def get_maximum_possible_gain(self): 265 | return len(self.positive_examples)*self.get_information_measure() 266 | 267 | def __str__(self): 268 | s = "" 269 | for example in itertools.chain(self.positive_examples, self.negative_examples): 270 | s += str(example) 271 | return s 272 | 273 | variables = property(fget=get_variables) 274 | 275 | 276 | ############################################################################## 277 | # Functions for building a clause. 278 | ############################################################################## 279 | def construct_clause_recursive(predicate, rule, training_set, bk, 280 | variable_factory=None, ordering=None, depth=0): 281 | assert isinstance(rule, Rule) 282 | assert isinstance(training_set, TrainingSet) 283 | logger.debug("Starting construct_clause_recursive(...).") 284 | logger.debug("Training set contains %s positive examples and %s negative " 285 | "examples." % (len(training_set.positive_examples), 286 | len(training_set.negative_examples))) 287 | ordering = find_partial_ordering_of_terms(rule) 288 | if variable_factory == None: variable_factory = UniqueVariableFactory() 289 | head, body = rule.terms, rule.body 290 | if len(training_set.negative_examples) > 0: 291 | logger.debug("Clause so far %s :- %s." % (head, body)) 292 | new_literals, determinate_literals = ( 293 | find_gainful_and_determinate_literals(predicate, 294 | rule, 295 | training_set, 296 | bk, 297 | variable_factory, 298 | ordering)) 299 | logger.debug("New literals: " + str(new_literals)) 300 | gain = new_literals[0][0] 301 | gain_ratio = gain / training_set.get_maximum_possible_gain() 302 | logger.debug("Gain ratio %s." 
% str(gain_ratio)) 303 | depth += 1 304 | if( gain_ratio < MINIMUM_LITERAL_GAIN_TO_ADD 305 | and len(determinate_literals) > 0 ): 306 | for literal, new_variables in determinate_literals: 307 | for var in new_variables: var.depth = depth 308 | body.append(literal) 309 | training_set.extend(body, new_variables) 310 | logger.debug("Determinate literals added.") 311 | if construct_clause_recursive(predicate, 312 | rule, 313 | training_set, 314 | bk, 315 | ordering=ordering, 316 | variable_factory=variable_factory, 317 | depth=depth): 318 | return True 319 | else: 320 | logger.debug("Adding determinates of no use, back tracking.") 321 | for i in xrange(len(determinate_literals)): 322 | training_set.rollback() 323 | body.pop() 324 | if gain < 0.001: 325 | logger.debug("Returning False because gain of %s < 0.001." % gain) 326 | return False 327 | logger.debug("Gainful literals to try: %s" % new_literals) 328 | for gain, literal, new_variables in new_literals: 329 | for var in new_variables: var.depth = depth 330 | body.append(literal) 331 | training_set.extend(body, new_variables) 332 | logger.debug("Trying solution: " + str((gain, literal, new_variables))) 333 | if construct_clause_recursive(predicate, 334 | rule, 335 | training_set, 336 | bk, 337 | ordering=ordering, 338 | variable_factory=variable_factory, 339 | depth=depth): 340 | return True 341 | logger.debug("Trying next solution.") 342 | training_set.rollback() 343 | body.pop() 344 | return False 345 | elif len(training_set.negative_examples) == 0: 346 | rule.body = tuple(rule.body) 347 | if rule.is_recursive(): 348 | pass 349 | predicate.rules.remove(rule) 350 | predicate.rules.append(rule.immutable_instance) 351 | logger.debug("Found a rule %s." % predicate.rules[len(predicate.rules)-1]) 352 | training_set.reset() 353 | training_set.positive_examples.prune_covered() 354 | return True 355 | raise RuntimeError("Fell through!") 
356 | 357 | def variablization(predicate, vars, variable_factory): 358 | if len(vars) == 0 or vars[len(vars)-1].depth < 4: 359 | for i in xrange(1, predicate.arity+1): 360 | for old_vars in choose(vars, i): 361 | new_vars = variable_factory.next_variable_sequence( 362 | predicate.arity-i, 363 | predicate.name[0].upper() 364 | + predicate.name[1:] + '_') 365 | for seq in permute(new_vars + old_vars): 366 | yield predicate(*seq), new_vars 367 | else: 368 | for old_vars in choose(vars, predicate.arity): 369 | for seq in permute(old_vars): 370 | yield predicate(*seq), [] 371 | 372 | def insert_literal((gain, literal, new_variables), literals, length): 373 | if len(literals) == 0: 374 | literals.append((gain, literal, new_variables)) 375 | else: 376 | for i in xrange(len(literals)): 377 | if gain > literals[i][0]: 378 | literals.insert(i, (gain, literal, new_variables)) 379 | break 380 | if len(literals) > length: 381 | literals.pop() 382 | 383 | 384 | ############################################################################## 385 | # Functions for determining the soundness of recursive literals. 
386 | ############################################################################## 387 | def determine_param_orderings(predicate): 388 | logger.debug("Determining ordering for: " + str(predicate)) 389 | 390 | def establish_relationship(value_pair, index_pair, op, cmp_map): 391 | x,y = value_pair 392 | i,k = index_pair 393 | if not(cmp_map.has_key((i,k)) and cmp_map[(i,k)] == None): 394 | if op(x,y): 395 | if cmp_map.has_key((i,k)): 396 | if op != cmp_map[(i,k)]: 397 | cmp_map[(i,k)] = None 398 | else: 399 | cmp_map[(i,k)] = op 400 | elif cmp_map.has_key((i,k)) and op == cmp_map[(i,k)]: 401 | cmp_map[(i,k)] = None 402 | logger.debug(str(cmp_map)) 403 | # end establish_relationship 404 | 405 | v = UniqueVariableFactory() 406 | type_map = {} 407 | types = predicate.param_types 408 | for i, type in zip(range(0, len(types)), types): 409 | if not type_map.has_key(type): 410 | type_map[type] = [] 411 | type_map[type].append(i) 412 | pairs = [] 413 | for type in type_map.keys(): 414 | for x in choose(type_map[type], 2): 415 | pairs.append(list(x)) 416 | cmp_map = {} 417 | variables = v.next_variable_sequence(predicate.arity) 418 | logger.debug("Calling: fol_bc_ask( %s )" % predicate(*variables)) 419 | for answer in fol_bc_ask([predicate(*variables)], {}): 420 | logger.debug("Answer: " + str(answer)) 421 | for pair in pairs: 422 | i,k = pair 423 | x,y = answer[variables[i]], answer[variables[k]] 424 | logger.debug("Comparing %s and %s." % (x, y)) 425 | try: 426 | for op in [operator.lt, operator.gt, operator.eq]: 427 | establish_relationship((x,y), (i,k), op, cmp_map) 428 | except: 429 | pass 430 | return cmp_map 431 | 432 | def create_unique_variable_sequence(prefix, length): 433 | return map(lambda x: Var(prefix + str(x)), range(0, length)) 434 | 435 | def find_partial_ordering_of_terms(rule): 436 | logger.debug("Finding ordering for rule '%s'." 
% rule) 437 | constraints = [] 438 | for literal in rule.body: 439 | if isinstance(literal, Pred): 440 | predicate = literal.predicate 441 | if predicate.param_orderings: 442 | for key in predicate.param_orderings.keys(): 443 | x = literal.terms[key[0]] 444 | y = literal.terms[key[1]] 445 | constraints.append((predicate.param_orderings[key], x, y)) 446 | logger.debug("Found constraints: " + str(constraints)) 447 | return create_partial_comparator(constraints) 448 | 449 | def will_halt(predicate, recursive_literal, variables, ordering=None): 450 | will_halt = False 451 | head_terms = variables[:predicate.arity] 452 | recr_terms = recursive_literal.terms 453 | logger.debug("Determining if %s > %s." % (predicate(*head_terms), 454 | recursive_literal)) 455 | for i in xrange(predicate.arity): 456 | if head_terms[i] == recr_terms[i]: 457 | continue 458 | if not recr_terms[i] in ordering: 459 | logger.debug("Term '%s' not in ordering." % recr_terms[i]) 460 | will_halt = False 461 | break 462 | if not head_terms[i] in ordering: 463 | logger.debug("Term '%s' not in ordering." % head_terms[i]) 464 | will_halt = False 465 | break 466 | if ordering.eq(head_terms[i], recr_terms[i]): 467 | continue 468 | if ordering.gt(head_terms[i], recr_terms[i]): 469 | will_halt = True 470 | break 471 | elif ordering.lt(head_terms[i], recr_terms[i]): 472 | will_halt = False 473 | break 474 | if will_halt: 475 | logger.debug("Found that %s > %s." 
% (predicate(*head_terms), 476 | recursive_literal)) 477 | return will_halt 478 | 479 | def gen_variablization_space(predicate, 480 | path_finding_func, 481 | variables, 482 | variable_factory, 483 | parameters=None, 484 | new_variable_positions=None): 485 | if parameters == None: 486 | parameters = variable_factory.next_variable_sequence(predicate.arity) 487 | new_variable_positions = range(len(parameters)) 488 | for i in xrange(len(parameters)): 489 | new_var = parameters[i] 490 | pos = new_variable_positions[i] 491 | del new_variable_positions[i] 492 | for old_var in variables: 493 | parameters[i] = old_var 494 | for x in gen_variablization_space(predicate, 495 | path_finding_func, 496 | variables, 497 | variable_factory, 498 | parameters, 499 | new_variable_positions): 500 | yield x 501 | new_variable_positions.insert(i, pos) 502 | parameters[i] = new_var 503 | else: 504 | new_variables = [] 505 | for k in new_variable_positions: 506 | new_variables.append(parameters[k]) 507 | literal = predicate(*parameters) 508 | yield (literal, new_variables) 509 | if path_finding_func(literal, new_variables): 510 | for i in xrange(len(new_variable_positions)): 511 | pos = new_variable_positions[i] 512 | del new_variable_positions[i] 513 | new_var = parameters[pos] 514 | for old_var in variables: 515 | parameters[pos] = old_var 516 | for x in gen_variablization_space(predicate, 517 | path_finding_func, 518 | variables, 519 | variable_factory, 520 | parameters, 521 | new_variable_positions): 522 | yield x 523 | new_variable_positions.insert(i, pos) 524 | parameters[pos] = new_var 525 | 526 | def find_gainful_and_determinate_literals(predicate, 527 | rule, 528 | training_set, 529 | bk, 530 | variable_factory, 531 | ordering=None, 532 | clause=None, 533 | determinate_literals=None, 534 | new_literals=None, 535 | grab_size=1): 536 | logger.debug("Finding a new literal.") 537 | logger.debug("Rules so far:") 538 | for rule in predicate.rules: 539 | logger.debug(str(rule)) 540 | 
head, body = rule.terms, rule.body 541 | if determinate_literals == None: determinate_literals = [] 542 | if new_literals == None: new_literals = [] 543 | variables = training_set.variables 544 | current_depth = 0 545 | best_literal = None 546 | best_gain = -10000 547 | best_new_variables = None 548 | old_info_value = training_set.get_information_measure() 549 | for next_predicate in bk: 550 | continue_search = True 551 | def path_finding_func(literal, new_variables): 552 | return continue_search 553 | for literal, new_variables in gen_variablization_space(next_predicate, 554 | path_finding_func, 555 | variables, 556 | variable_factory): 557 | continue_search = True 558 | if predicate == next_predicate: 559 | if not will_halt(predicate, 560 | literal, 561 | training_set.variables, 562 | ordering): 563 | logger.debug("Adding '%s' may lead to infinite recursion." 564 | % literal) 565 | continue 566 | else: 567 | logger.debug("Adding '%s' will not lead to infinite recursion." 568 | % literal) 569 | body.append(literal) 570 | old_len_pos = len(training_set.positive_examples) 571 | old_len_neg = len(training_set.negative_examples) 572 | s_pos, s_neg = training_set.extend(body, new_variables) 573 | s = s_pos + s_neg 574 | new_info_value = training_set.get_information_measure() 575 | gain = foil_gain(s, old_info_value, new_info_value) 576 | if len(new_variables) > 0: gain += NEW_VARIABLE_GAIN_BIAS 577 | if (len(training_set.positive_examples) > 0 578 | and len(training_set.negative_examples) == 0): 579 | """ 580 | Return 'literal' as the best literal as it excludes all negative 581 | examples but includes at least 1 positive example. We do this 582 | primarily to reduce the complexity of individual rules as well as 583 | prevent excessive branching. 584 | """ 585 | training_set.rollback() 586 | body.pop() 587 | logger.debug("Found literal '%s' which completes subset of " 588 | "relation, choosing as best literal." 
% literal) 589 | return ([(gain, literal, new_variables)], []) 590 | if (s_pos == old_len_pos == len(training_set.positive_examples) 591 | and s_neg <= old_len_neg 592 | and s_neg == len(training_set.negative_examples) 593 | and len(new_variables) > 0): 594 | determinate_vars = variable_factory.next_variable_sequence( 595 | len(new_variables)) 596 | remap_bindings = {} 597 | for nvar,dvar in zip(new_variables, determinate_vars): 598 | remap_bindings[nvar] = dvar 599 | determinate_literals.append((literal.apply_bindings(remap_bindings), 600 | determinate_vars)) 601 | logger.debug("Considering literal '%s' with gain of %s which yields" 602 | " %s positive and %s negative extensions." % (literal, 603 | gain, 604 | s_pos, 605 | s_neg)) 606 | if len(training_set.positive_examples) > 0: 607 | insert_literal((gain, literal, new_variables), 608 | new_literals, grab_size) 609 | training_set.rollback() 610 | body.pop() 611 | if s * old_info_value < best_gain: 612 | continue_search = False 613 | logger.debug("Best literal found '%s'." % best_literal) 614 | logger.debug("Determinate literals found '%s'." % determinate_literals) 615 | return (new_literals, determinate_literals) 616 | 617 | def foil_gain(s, old_information_value, new_information_value): 618 | return s * ( old_information_value - new_information_value ) 619 | 620 | ############################################################################## 621 | # Functions for cleaning learned predicate rules. 
622 | ############################################################################## 623 | def determine_covered(predicate, tuple): 624 | for x in fol_bc_ask([predicate(*tuple)], {}): 625 | return True 626 | return False 627 | 628 | def determine_tuples_covered(predicate, tuples): 629 | tuples = [x for x in tuples if determine_covered(predicate, x)] 630 | return tuples 631 | 632 | def determine_tuples_covered_same_or_better(predicate, 633 | positive_tuples, 634 | negative_tuples, 635 | positive_subset=None, 636 | negative_subset=None): 637 | if positive_subset == None: positive_subset = positive_tuples 638 | if negative_subset == None: negative_subset = negative_tuples 639 | covered_pos = determine_tuples_covered(predicate, positive_tuples) 640 | covered_neg = determine_tuples_covered(predicate, negative_tuples) 641 | if( reduce(lambda x,y: x and y, 642 | map(lambda x: x in covered_pos, positive_subset), True) 643 | and reduce(lambda x,y: x and y, 644 | map(lambda x: x in negative_subset, covered_neg), 645 | True) ): 646 | return True 647 | return False 648 | 649 | def will_rule_halt(predicate, rule): 650 | body_temp = [] 651 | for term in rule.body: 652 | if hasattr(term, 'predicate') and term.predicate == predicate: 653 | ordering = find_partial_ordering_of_terms( 654 | Rule(predicate, rule.terms, body_temp)) 655 | if not will_halt(predicate, term, rule.terms, ordering): 656 | return False 657 | body_temp.append(term) 658 | else: 659 | body_temp.append(term) 660 | return True 661 | 662 | def predicate_rules_postprocessing_compact(predicate, 663 | positive_tuples, 664 | negative_tuples): 665 | """ 666 | Removes redundant rules and terms from a predicate. This is an approximation 667 | algorithm since the optimal result is undecidable. 668 | 669 | @param predicate: The predicate whose rules are to be compacted. 670 | @type predicate: Predicate 671 | @param positive_tuples: Tuples covered by the predicate. 
These are used to 672 | ensure that compaction does not affect the coverage of the predicate. 673 | @param negative_tuples: Tuples not covered by the predicate. These are used 674 | to ensure that compaction does not affect the coverage of the predicate. 675 | """ 676 | logger.debug("Start " + predicate_rules_postprocessing_compact.func_name) 677 | rules = predicate.rules 678 | predicate.rules = [] 679 | for rule in rules: 680 | new_rule = rule 681 | predicate.rules.append(rule) 682 | covered_positive_tuples, covered_negative_tuples = ( 683 | map(lambda x: determine_tuples_covered(predicate, x), 684 | [positive_tuples, negative_tuples])) 685 | predicate.rules.pop() 686 | for i in xrange(1, len(rule.body)): 687 | found_rule = False 688 | for sub_rule_body in choose(rule.body, len(rule.body) - i): 689 | sub_rule = Rule(predicate, rule.terms, sub_rule_body) 690 | if not will_rule_halt(predicate, sub_rule): 691 | continue 692 | predicate.rules.append(sub_rule) 693 | better_rule = determine_tuples_covered_same_or_better(predicate, 694 | positive_tuples, 695 | negative_tuples, 696 | covered_positive_tuples, 697 | covered_negative_tuples) 698 | predicate.rules.pop() 699 | if better_rule: 700 | new_rule = sub_rule 701 | found_rule = True 702 | break 703 | if not found_rule: 704 | break 705 | predicate.rules.append(new_rule) 706 | rule_removed = True 707 | while rule_removed: 708 | rule_removed = False 709 | for i in xrange(len(predicate.rules)): 710 | rule = predicate.rules[i] 711 | del predicate.rules[i] 712 | if determine_tuples_covered_same_or_better(predicate, 713 | positive_tuples, 714 | negative_tuples): 715 | rule_removed = True 716 | break 717 | else: 718 | predicate.rules.insert(i, rule) 719 | logger.debug("End " + predicate_rules_postprocessing_compact.func_name) 720 | 721 | def predicate_rules_postprocessing(predicate, 722 | positive_tuples, 723 | negative_tuples): 724 | s = time.clock() 725 | predicate_rules_postprocessing_compact(predicate, 726 |
positive_tuples, 727 | negative_tuples) 728 | f = time.clock() 729 | logger.debug(predicate_rules_postprocessing_compact.func_name 730 | + " completed in %s seconds." % (f-s)) 731 | 732 | ############################################################################## 733 | # Main entry point for the FOIL algorithm. 734 | ############################################################################## 735 | def foil(predicate, positive_tuples, negative_tuples, bk, ordering=None): 736 | s = time.clock() 737 | foil_main(predicate, positive_tuples, negative_tuples, bk, ordering) 738 | f = time.clock() 739 | logger.debug(foil_main.func_name 740 | + " completed in %s seconds." % (f-s)) 741 | predicate_rules_postprocessing(predicate, positive_tuples, negative_tuples) 742 | 743 | def foil_main(predicate, positive_tuples, negative_tuples, bk, ordering=None): 744 | arity = predicate.arity 745 | clauses = set([]) 746 | variable_factory = UniqueVariableFactory() 747 | params = variable_factory.next_variable_sequence(predicate.arity, 748 | prefix="PARAM_") 749 | training_set = TrainingSet(predicate, 750 | params, 751 | positive_tuples, 752 | negative_tuples) 753 | while len(training_set.positive_examples) > 0: 754 | head = tuple(training_set.variables) 755 | for x in head: x.depth = 0 756 | body = [] 757 | rule = MutableRule(predicate, head, body) 758 | predicate.rules.append(rule) 759 | construct_clause_recursive(predicate, 760 | rule, 761 | training_set, 762 | bk, 763 | variable_factory=variable_factory, 764 | ordering=ordering) 765 | -------------------------------------------------------------------------------- /src/trimlogic/graph.py: -------------------------------------------------------------------------------- 1 | from trimlogic.algorithm import fol_bc_ask 2 | from trimlogic.term import VariableFactory 3 | from trimlogic.predicate import * 4 | from trimlogic.stdlib import * 5 | 6 | def consult(query): 7 | print str(query) 8 | for answer in fol_bc_ask(query, {}): print
"Answer: " + str(answer) 9 | 10 | v = VariableFactory() 11 | edge_list = ( (0,1), (1,2), (0,3), (3,2), (3,4), (4,5), (4,6), (7,6), (6,8), (7,8) ) 12 | linkedto = RuleBasedPredicate('linked-to') 13 | map(lambda x: linkedto.add_rule( Head = x ), edge_list) 14 | 15 | canreach = RuleBasedPredicate('can-reach') 16 | canreach.add_rule( Head=( v.X, v.Y ), 17 | Body=( linkedto(v.X, v.Y), ) ) 18 | canreach.add_rule( Head=( v.X, v.Y ), 19 | Body=( linkedto(v.X, v.Z), canreach(v.Z, v.Y) ) ) 20 | 21 | consult( [linkedto(1,2)] ) 22 | consult( [linkedto(2,1)] ) 23 | consult( [canreach(0,v.X)] ) 24 | -------------------------------------------------------------------------------- /src/trimlogic/partialordering.py: -------------------------------------------------------------------------------- 1 | import operator 2 | 3 | class Node: 4 | 5 | def __init__(self, values): 6 | self.values = values 7 | self.edge_list = set() 8 | self.inverse_edge_list = set() 9 | 10 | def add_target(self, node): 11 | self.edge_list.add(node) 12 | node.inverse_edge_list.add(self) 13 | 14 | def remove_target(self, node): 15 | self.edge_list.remove(node) 16 | node.inverse_edge_list.remove(self) 17 | 18 | def has_target(self, node): 19 | return node in self.edge_list 20 | 21 | def clear(self): 22 | self.values = None 23 | self.edge_list = None 24 | self.inverse_edge_list = None 25 | 26 | def __hash__(self): 27 | return sum(map(hash, self.values)) 28 | 29 | def __repr__(self): 30 | s = "(" + str(self.values) + ", {" 31 | if self.edge_list: 32 | for node in self.edge_list: 33 | s += str(node.values) + ", " 34 | s += "})" 35 | return s 36 | 37 | 38 | class DirectedGraph(object): 39 | 40 | def __init__(self): 41 | self.object_node_map = {} 42 | self.nodes = set([]) 43 | 44 | def enumerate_nodes(self): 45 | for node in self.object_node_map.values(): 46 | yield node 47 | 48 | def remove_node(self, node): 49 | self.nodes.remove(node) 50 | for value in node.values: 51 | if self.object_node_map[value] == node:
52 | del self.object_node_map[value] 53 | for targeting_node in node.inverse_edge_list: 54 | targeting_node.edge_list.remove(node) 55 | for targeted_node in node.edge_list: 56 | targeted_node.inverse_edge_list.remove(node) 57 | node.clear() 58 | 59 | def remove(self, a): 60 | self.remove_node(self._get_or_create_node(a)) 61 | 62 | def merge(self, a, b): 63 | node1 = self._get_or_create_node(a) 64 | node2 = self._get_or_create_node(b) 65 | if node1 != node2: 66 | new_node = Node(node1.values + node2.values) 67 | for old_node in [node1, node2]: 68 | # Adjust value to node mapping so that values for old_node now point to new_node. 69 | for value in old_node.values: 70 | self.object_node_map[value] = new_node 71 | # make all nodes targeting old_node now target new_node. 72 | for targeting_node in old_node.inverse_edge_list: 73 | targeting_node.add_target(new_node) 74 | # make all targets of old_node the targets of new_node. 75 | for target in old_node.edge_list: 76 | new_node.add_target(target) 77 | # clear out the references held by old_node to make the GC happy. 78 | self.remove_node(old_node) 79 | return new_node 80 | 81 | def insert_edge(self, a, b): 82 | node1 = self._get_or_create_node(a) 83 | node2 = self._get_or_create_node(b) 84 | node1.add_target(node2) 85 | return (node1, node2) 86 | 87 | def _get_or_create_node(self, a): 88 | if not self.object_node_map.has_key(a): 89 | node = Node([a]) 90 | self.object_node_map[a] = node 91 | self.nodes.add(node) 92 | return self.object_node_map[a] 93 | 94 | def __del__(self): 95 | """ 96 | Removes cyclic references between nodes and then calls the object.__del__() 97 | method to finish garbage collection.
98 | """ 99 | for node in self.nodes: 100 | node.clear() 101 | 102 | 103 | def _topological_sort(graph): 104 | ordering = [] 105 | while graph.nodes: 106 | removable_nodes = filter(lambda x: len(x.inverse_edge_list) == 0, graph.nodes) 107 | if not(removable_nodes): 108 | raise ValueError("not an acyclic graph") 109 | while removable_nodes: 110 | node = removable_nodes.pop() 111 | ordering.extend(node.values) 112 | graph.remove_node(node) 113 | return ordering 114 | 115 | 116 | class PartialOrdering: 117 | 118 | def __init__(self, graph): 119 | self.graph = graph 120 | 121 | def lt(self, a, b): 122 | stack = [] 123 | stack.append(self.graph.object_node_map[a]) 124 | while stack: 125 | node = stack.pop() 126 | if b in node.values: 127 | return True 128 | stack.extend(node.edge_list) 129 | return False 130 | 131 | def gt(self, a, b): 132 | stack = [] 133 | stack.append(self.graph.object_node_map[a]) 134 | while stack: 135 | node = stack.pop() 136 | if b in node.values: 137 | return True 138 | stack.extend(node.inverse_edge_list) 139 | return False 140 | 141 | def eq(self, a, b): 142 | return self.graph.object_node_map[a] == self.graph.object_node_map[b] 143 | 144 | def __contains__(self, a): 145 | return self.graph.object_node_map.has_key(a) 146 | 147 | 148 | def create_partial_comparator(constraints): 149 | graph = _build_graph(constraints) 150 | return PartialOrdering(graph) 151 | 152 | def _build_graph(constraints): 153 | graph = DirectedGraph() 154 | for (op, x, y) in constraints: 155 | if op == operator.gt: 156 | graph.insert_edge(y,x) 157 | elif op == operator.lt: 158 | graph.insert_edge(x,y) 159 | elif op == operator.eq and x != y: 160 | graph.merge(x, y) 161 | return graph 162 | 163 | def find_ordering(constraints): 164 | graph = _build_graph(constraints) 165 | ordering = _topological_sort(graph) 166 | return ordering 167 | 168 | _TEST = False 169 | if _TEST: 170 | lt = operator.lt 171 | print str(find_ordering([(lt, 'x', 'y'), 172 | (lt, 'z', 'y'), 173 | (lt, 'x', 'z'),
174 | ])) 175 | -------------------------------------------------------------------------------- /src/trimlogic/predicate.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from trimlogic.term import * 3 | 4 | logger = logging.getLogger() 5 | 6 | def create_python_boolean_predicate(boolean_function, name): 7 | 8 | class PythonBooleanPredicate(RuleBasedPredicate): 9 | 10 | def __init__(self): 11 | RuleBasedPredicate.__init__(self, name) 12 | 13 | def _resolve(self, terms): 14 | try: 15 | if boolean_function(*terms): 16 | yield ({}, [], set()) 17 | except: 18 | pass 19 | 20 | 21 | return PythonBooleanPredicate() 22 | 23 | 24 | class KnowledgeBase: 25 | 26 | def __init__(self): 27 | self._map = {} 28 | 29 | def __getitem__(self, key): 30 | if not self._map.has_key(key): 31 | raise KeyError("No predicate with name '" + str(key) + "' exists.") 32 | return self._map[key] 33 | 34 | def remove(self, predicate): 35 | del self._map[predicate.name] 36 | 37 | def add(self, predicate): 38 | self._map[predicate.name] = predicate 39 | 40 | def add_all(self, l): 41 | map(self.add, l) 42 | 43 | def __iter__(self): 44 | return self._map.itervalues() 45 | 46 | 47 | class Predicate: 48 | 49 | def __init__(self, arity = 1): 50 | self.arity = arity 51 | self.param_types = None 52 | self.param_orderings = None 53 | 54 | def contains(self): 55 | pass 56 | 57 | def _resolve(self, terms): 58 | pass 59 | 60 | def __call__(self, *terms): 61 | return pred(self, *terms) 62 | 63 | 64 | class RuleBasedPredicate(Predicate): 65 | 66 | def __init__(self, name = None, types=None): 67 | Predicate.__init__(self) 68 | self.rules = [] 69 | self.name = name 70 | if types: 71 | self.param_types = tuple(types) 72 | if types == None: 73 | self.arity = None 74 | else: 75 | self.arity = len(types) 76 | 77 | def _resolve(self, terms): 78 | from trimlogic.algorithm import unify 79 | logging.debug(str(self) + "._resolve( " + str(terms) + " ) ") 80 | for rule
in self.rules: 81 | rule = rule.instantiate() 82 | logging.debug("considering rule: " + str(rule)) 83 | Head, Body, variables = rule.terms, rule.body, rule.variables 84 | mgu = unify(terms, Head, {}) 85 | if mgu != None: 86 | logging.debug("substitutions: " + str(mgu)) 87 | yield (mgu, list(Body), variables) 88 | 89 | def add_rule(self, Head=None, Body=None): 90 | if self.arity == None: self.arity = len(Head) 91 | if Body == None: self.rules.append(Fact(self, Head)) 92 | else: self.rules.append(Rule(self, Head, Body)) 93 | 94 | def __str__(self): 95 | if self.name != None: return self.name 96 | else: return object.__str__(self) 97 | 98 | def __repr__(self): return str(self) 99 | 100 | 101 | class CutPredicate(RuleBasedPredicate, Pred): 102 | """ 103 | This predicate always succeeds yielding no variable bindings, no new goals, and 104 | no new variables. 105 | """ 106 | def __init__(self): 107 | self.name, self.predicate, self.terms = "!", self, () 108 | self.variables = None 109 | self.arity = 0 110 | def _resolve(self, terms): 111 | yield ({}, [], set([])) 112 | def __str__(self): 113 | return str(self.name) 114 | def __repr__(self): 115 | return str(self) 116 | 117 | class FailPredicate(RuleBasedPredicate, Pred): 118 | def __init__(self): 119 | RuleBasedPredicate.__init__(self, "fail") 120 | self.predicate, self.terms = self, () 121 | self.arity = 0 122 | self.variables = None 123 | 124 | class NegationAsFailure(RuleBasedPredicate): 125 | def __init__(self): 126 | """ 127 | neg(Goal) :- Goal,!,fail. 128 | neg(Goal). 
129 | """ 130 | RuleBasedPredicate.__init__(self, 'neg') 131 | Goal = Var("Goal") 132 | self.add_rule( Head=( Goal, ), 133 | Body=( Goal, cut, fail ) ) 134 | 135 | class ListPredicate(RuleBasedPredicate): 136 | def __call__(self, *terms): 137 | return ListPred(self, *terms) 138 | 139 | class IsPredicate(Predicate): 140 | def __init__(self): 141 | Predicate.__init__(self) 142 | self.name = "is" 143 | self.arity = 2 144 | def _resolve(self, terms): 145 | x = terms[0] 146 | f = terms[1] 147 | if isinstance(f, Function): 148 | result = f.function(*f.terms) 149 | yield ({ x : result }, [], []) 150 | elif isinstance(x, int): 151 | pass 152 | 153 | class Rule: 154 | 155 | def __init__(self, predicate, terms, body): 156 | self.predicate = predicate 157 | self.terms = tuple(terms) 158 | self.body = tuple(body) 159 | self._variables = None 160 | 161 | def get_variables(self): 162 | if self._variables == None: 163 | variables = [] 164 | find_variables(self.terms, variables) 165 | find_variables(self.body, variables) 166 | self._variables = set(variables) 167 | return self._variables 168 | 169 | def instantiate(self): 170 | instance_var_map = {} 171 | for var in self.variables: instance_var_map[var] = Var.get_unique(var) 172 | new_terms = [] 173 | for term in self.terms: 174 | if isinstance(term, Term): new_terms.append(term.apply_bindings(instance_var_map)) 175 | else: new_terms.append(term) 176 | new_body = [] 177 | for term in self.body: 178 | if isinstance(term, Term): new_body.append(term.apply_bindings(instance_var_map)) 179 | else: new_body.append(term) 180 | return Rule(self.predicate, new_terms, new_body) 181 | 182 | def is_recursive(self): 183 | for literal in self.body: 184 | try: 185 | if literal.predicate == self.predicate: 186 | return True 187 | except AttributeError: 188 | pass # literal doesn't have a predicate, but that's okay.
189 | return False 190 | 191 | def __str__(self): 192 | return str(self.predicate) + str(self.terms) + " :- " + ", ".join(map(str, self.body)) + "." 193 | 194 | variables = property(fget=get_variables) 195 | 196 | 197 | class MutableRule(Rule): 198 | 199 | def __init__(self, predicate, terms, body): 200 | self.predicate = predicate 201 | self.terms = terms 202 | self.body = body 203 | 204 | def get_variables(self): 205 | variables = [] 206 | find_variables(self.terms, variables) 207 | find_variables(self.body, variables) 208 | return set(variables) 209 | 210 | def get_immutable_instance(self): 211 | return Rule(self.predicate, self.terms, self.body) 212 | 213 | variables = property(fget=get_variables) 214 | immutable_instance = property(fget=get_immutable_instance) 215 | 216 | 217 | class Fact(Rule): 218 | def __init__(self, predicate, terms): 219 | Rule.__init__(self, predicate, terms, ()) -------------------------------------------------------------------------------- /src/trimlogic/stdlib.py: -------------------------------------------------------------------------------- 1 | def plist(l, Tail=None): 2 | if Tail == None: 3 | Tail = [] 4 | if l == []: 5 | return Tail 6 | return dot( l[0], plist(l[1:], Tail=Tail) ) 7 | 8 | def __define_list_predicates(): 9 | import operator 10 | from trimlogic.term import VariableFactory, Term 11 | from trimlogic.predicate import RuleBasedPredicate, ListPredicate 12 | global dot, car, cdr, cons, append, reverse, components 13 | 14 | v = VariableFactory() 15 | dot = ListPredicate('.') 16 | 17 | car = RuleBasedPredicate('car') 18 | car.add_rule( Head=( dot(v.H, v.T), v.H ) ) 19 | 20 | cdr = RuleBasedPredicate('cdr') 21 | cdr.add_rule( Head=( dot(v.H, v.T), v.T ) ) 22 | 23 | cons = RuleBasedPredicate('cons') 24 | cons.add_rule( Head=( v.H, dot(v.H1, v.T), dot(v.H, dot(v.H1, v.T)) ) ) 25 | cons.add_rule( Head=( v.H, [], dot(v.H, []) ) ) 26 | 27 | components = RuleBasedPredicate('components', (dot, Term, Term)) 28 | 
components.add_rule( Head=( v.X, v.H, v.T ), Body=( car(v.X, v.H), cdr(v.X, v.T ) ) ) 29 | components.param_orderings = {(0, 2):operator.gt} 30 | 31 | append = RuleBasedPredicate('append') 32 | append.add_rule( Head=( [], v.L, v.L ) ) 33 | append.add_rule( Head=( dot(v.X, v.A), v.B, dot(v.X, v.C) ), 34 | Body=( append(v.A, v.B, v.C), ) ) 35 | 36 | reverse = RuleBasedPredicate('reverse') 37 | reverse.add_rule( Head=( [], [] ) ) 38 | reverse.add_rule( Head=( dot(v.H, v.T), v.Rev ), 39 | Body=( reverse(v.T, v.Trev), append(v.Trev, plist([v.H]), v.Rev) ) ) 40 | 41 | def __define_arithmetic_predicates(): 42 | from trimlogic.predicate import IsPredicate 43 | global est 44 | est = IsPredicate() 45 | 46 | def __define_type_predicates(): 47 | from trimlogic.predicate import RuleBasedPredicate, create_python_boolean_predicate 48 | global is_atom, is_integer, is_number, is_compound, is_list, is_variable, is_atomic 49 | is_atom = create_python_boolean_predicate(lambda x: not isinstance(x, Term) and not isinstance(x, int) and isinstance(x, str) , 'is_atom') 50 | is_integer = create_python_boolean_predicate(lambda x: isinstance(x, int), 'is_integer') 51 | class IsNumberPredicate(RuleBasedPredicate): 52 | def __init__(self): 53 | RuleBasedPredicate.__init__(self, 'is_number') 54 | def _resolve(self, terms): 55 | for x in is_integer._resolve(terms): 56 | yield x 57 | 58 | def __define_algorithmetic_predicates(): 59 | from trimlogic.term import VariableFactory 60 | from trimlogic.predicate import CutPredicate, FailPredicate, RuleBasedPredicate 61 | global fail, cut, neg, eql 62 | v = VariableFactory() 63 | cut = CutPredicate() 64 | fail = FailPredicate() 65 | neg = RuleBasedPredicate('neg') 66 | neg.add_rule( Head=( v.Goal, ), 67 | Body=( v.Goal, cut, fail ) ) 68 | neg.add_rule( Head=( v.Goal, ) ) 69 | eql = RuleBasedPredicate('eql') 70 | eql.add_rule( Head=( v.X, v.X ) ) 71 | 72 | __define_type_predicates() 73 | __define_algorithmetic_predicates() 74 | 
__define_arithmetic_predicates() 75 | __define_list_predicates() 76 | -------------------------------------------------------------------------------- /src/trimlogic/term.py: -------------------------------------------------------------------------------- 1 | def find_variables(terms, variables=None): 2 | l = None 3 | if not(isinstance(terms, list) or isinstance(terms, tuple)): 4 | l = [terms] 5 | elif isinstance(terms,tuple): 6 | l = list(terms) 7 | else: 8 | l = terms 9 | if variables == None: 10 | variables = [] 11 | for term in l: 12 | if isinstance(term, Pred): 13 | if term.variables == None: 14 | term.variables = set(find_variables(term.terms)) 15 | variables.extend(term.variables) 16 | elif isinstance(term, Var): 17 | variables.append(term) 18 | elif isinstance(term, list) or isinstance(term, tuple): 19 | l.extend(term) 20 | return variables 21 | 22 | 23 | class Term: 24 | 25 | def apply_bindings(self, bindings): 26 | return self 27 | 28 | 29 | class Pred(Term): 30 | 31 | def __init__(self, predicate, *terms): 32 | self.predicate, self.terms = predicate, terms 33 | self.variables = None 34 | 35 | def __repr__(self): 36 | return str(self) 37 | 38 | def __str__(self): 39 | if len(self.terms) > 0: return str(self.predicate) + "(" + str(reduce(lambda x,y: str(x) + ", " + str(y), self.terms)) + ")" 40 | else: return str(self.predicate) 41 | 42 | def __hash__(self): 43 | return hash(self.predicate.name) + len(self.terms) 44 | 45 | def apply_bindings(self, bindings): 46 | new_terms = [] 47 | for term in self.terms: 48 | if isinstance(term, Term): new_terms.append(term.apply_bindings(bindings)) 49 | else: new_terms.append(term) 50 | return self.predicate(*new_terms) 51 | 52 | 53 | class ListPred(Pred): 54 | 55 | def __generate_str(self, l): 56 | if [] == l: 57 | return "" 58 | if isinstance(l, Var): return "|" + str(l) 59 | if not hasattr(l, 'terms'): return "." 
+ str(l) 60 | H, T = l.terms 61 | s = str(H) 62 | st = self.__generate_str(T) 63 | if st != "": s += ", " + st 64 | return s 65 | 66 | def __str__(self): 67 | return "plist([" + self.__generate_str(self) + "])" 68 | 69 | def __eq__(self, other): 70 | return isinstance(other, ListPred) and self.terms[0] == other.terms[0] and self.terms[1] == other.terms[1] 71 | 72 | 73 | class Function(Term): 74 | 75 | def __init__(self, function, *terms): 76 | self.function = function 77 | self.terms = terms 78 | 79 | def apply_bindings(self, bindings): 80 | new_terms = [] 81 | for term in self.terms: 82 | if isinstance(term, Term): new_terms.append(term.apply_bindings(bindings)) 83 | else: new_terms.append(term) 84 | return Function(self.function, *new_terms) 85 | 86 | 87 | class Atom(Term): 88 | 89 | def __init__(self, name): 90 | self.name = name 91 | self._hash = hash(self.name) 92 | 93 | def __str__(self): 94 | return self.name 95 | 96 | def __repr__(self): 97 | return str(self) 98 | 99 | def __eq__(self, other): 100 | return isinstance(other, Atom) and self.name == other.name 101 | 102 | def __hash__(self): 103 | return self._hash 104 | 105 | def apply_bindings(self, bindings): 106 | if bindings.has_key(self): return bindings[self] 107 | else: return self 108 | 109 | 110 | class Var(Term): 111 | 112 | unique_count = 0 113 | 114 | def __init__(self, name): 115 | self.name = name 116 | self._hash_value = hash(self.name) 117 | self.scope = None 118 | 119 | def __str__(self): 120 | return str(self.name) 121 | 122 | def __repr__(self): 123 | return str(self) 124 | 125 | def __eq__(self, other): 126 | return isinstance(other, Var) and self.name == other.name 127 | 128 | def __hash__(self): 129 | return self._hash_value 130 | 131 | def apply_bindings(self, bindings): 132 | if bindings.has_key(self): return bindings[self] 133 | else: return self 134 | 135 | def get_unique(var): 136 | Var.unique_count += 1 137 | return Var('@_' + str(Var.unique_count) + '_' + var.name) 138 | 139 | 
get_unique = staticmethod(get_unique) 140 | 141 | 142 | class VariableFactory: 143 | 144 | def __getattr__(self, name): 145 | return Var(name) 146 | 147 | 148 | class AtomFactory: 149 | 150 | def __getattr__(self, name): 151 | return Atom(name) 152 | 153 | 154 | class UniqueVariableFactory: 155 | 156 | instance_count = 0 157 | 158 | def __init__(self): 159 | self.variable_map = {} 160 | 161 | def _next_count(self, prefix): 162 | try: 163 | self.variable_map[prefix] += 1 164 | except: 165 | self.variable_map[prefix] = 0 166 | return str(self.variable_map[prefix]) 167 | 168 | def next_variable(self, prefix = "VAR_"): 169 | count = self._next_count(prefix) 170 | var = Var(prefix + str(count)) 171 | return var 172 | 173 | def next_variable_sequence(self, length, prefix="VAR_"): 174 | return map(lambda x: self.next_variable(prefix), range(length)) 175 | 176 | def reset(self): 177 | self.variable_map = {} 178 | 179 | def pred(predicate, *terms): 180 | return Pred(predicate, *terms) 181 | 182 | def func(function, *terms): 183 | return Function(function, *terms) 184 | -------------------------------------------------------------------------------- /src/trimlogic/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(level=logging.INFO) 3 | 4 | import unittest 5 | 6 | from trimlogic.term import * 7 | from trimlogic.predicate import * 8 | from trimlogic.algorithm import * 9 | from trimlogic.stdlib import * 10 | 11 | class PrologTestCase(unittest.TestCase): 12 | def assertHaveSameElements(self, expected, given): 13 | self.assertEquals(len(expected), len(given)) 14 | for (e1, e2) in zip(expected, given): 15 | self.assertEquals(e1, e2) 16 | 17 | class AlgPredicatesTestCase(PrologTestCase): 18 | def testCutNeg(self): 19 | v = VariableFactory() 20 | a = RuleBasedPredicate('a') 21 | b = RuleBasedPredicate('b') 22 | c = RuleBasedPredicate('c') 23 | a.add_rule( Head=( v.X, v.Y ), 24 | Body=( b(v.X), cut, 
c(v.Y) ) ) 25 | b.add_rule( Head=( 1, ) ) 26 | b.add_rule( Head=( 2, ) ) 27 | b.add_rule( Head=( 3, ) ) 28 | c.add_rule( Head=( 1, ) ) 29 | c.add_rule( Head=( 2, ) ) 30 | c.add_rule( Head=( 3, ) ) 31 | self.assertHaveSameElements( ({v.Q : 1, v.R : 1}, 32 | {v.Q : 1, v.R : 2}, 33 | {v.Q : 1, v.R : 3}, 34 | None), 35 | list(fol_bc_ask([a(v.Q, v.R)], {})) ) 36 | self.assertHaveSameElements( (None,), list(fol_bc_ask([neg(b(1))], {})) ) 37 | self.assertHaveSameElements( ({},), list(fol_bc_ask([neg(b(4))], {})) ) 38 | def testEql(self): 39 | v = VariableFactory() 40 | self.assertHaveSameElements( ({},), 41 | list(fol_bc_ask([eql(1,1)], {})) ) 42 | self.assertHaveSameElements( (), 43 | list(fol_bc_ask([eql(1,2)], {})) ) 44 | self.assertHaveSameElements( ({v.X : 1},), 45 | list(fol_bc_ask([eql(1,v.X)], {})) ) 46 | 47 | class ListTestCase(PrologTestCase): 48 | def testBasicPredicates(self): 49 | v = VariableFactory() 50 | self.assertHaveSameElements( ({v.X : 1},), 51 | list(fol_bc_ask([car(plist([1, 2, 3, 4]), v.X)], {})) ) 52 | self.assertHaveSameElements( ({v.X : plist([2, 3, 4])},), 53 | list(fol_bc_ask([cdr(plist([1, 2, 3, 4]), v.X)], {})) ) 54 | self.assertHaveSameElements( ({v.X : plist([1, 2])},), 55 | list(fol_bc_ask([cons(1, plist([2]), v.X)], {})) ) 56 | self.assertHaveSameElements( ({v.X : plist([1, 2, 3, 4])},), 57 | list(fol_bc_ask([append(plist([1, 2]), plist([3, 4]), v.X)], {})) ) 58 | def testReversePredicate(self): 59 | v = VariableFactory() 60 | self.assertHaveSameElements( ({v.X : plist([3, 2, 1])},), 61 | list(fol_bc_ask([reverse(plist([1, 2, 3]), v.X)], {})) ) 62 | 63 | class TypePredicatesTestCase(PrologTestCase): 64 | def testNumberPredicates(self): 65 | v = VariableFactory() 66 | self.assertHaveSameElements( ({},), 67 | list(fol_bc_ask([is_integer(1)], {})) ) 68 | self.assertHaveSameElements( (), 69 | list(fol_bc_ask([is_integer("not an integer")], {})) ) 70 | 71 | def testOOP(): 72 | father = RuleBasedPredicate('father') 73 | 
father.add_rule(Head=('steven', 'john')) 74 | father.add_rule(Head=('olin', 'gayle')) 75 | 76 | mother = RuleBasedPredicate('mother') 77 | mother.add_rule(Head=('gayle', 'john')) 78 | 79 | parent = RuleBasedPredicate('parent') 80 | X, Y, Z = Var('X'), Var('Y'), Var('Z') 81 | parent.add_rule(Head=(X, Y), Body=(pred(mother, X, Y),)) 82 | parent.add_rule(Head=(X, Y), Body=(pred(father, X, Y),)) 83 | logging.debug("running algorithm") 84 | for answer in fol_bc_ask([parent(Z, 'john')], {}): print "Answer: " + str(answer) 85 | logging.debug("done") 86 | 87 | def testReverseList(): 88 | """ 89 | reverse([],[]). % the empty list is its own reverse. Base for induction. 90 | reverse([H|T], Rev) :- reverse(T, Trev), append(Trev, [H], Rev). 91 | """ 92 | X = Var("X") 93 | for query in ([reverse(plist([1,2]), X)],): 94 | print query 95 | for answer in fol_bc_ask(query, {}): print "Answer: " + str(answer) 96 | 97 | def testOOP01(): 98 | for answer in fol_bc_ask([est(Var('X'), func(int.__add__, 1, 2))], {}): print str(answer) 99 | 100 | if __name__ == '__main__': 101 | unittest.main() 102 | -------------------------------------------------------------------------------- /src/trimlogic/test/FamilyTreeTestCase.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import unittest 3 | from trimlogic.test.helper import FoilTestCase 4 | from trimlogic.predicate import KnowledgeBase 5 | from trimlogic.predicate import RuleBasedPredicate 6 | from trimlogic.predicate import VariableFactory, UniqueVariableFactory, AtomFactory 7 | from trimlogic.foil import TrainingSet, find_gainful_and_determinate_literals, construct_clause_recursive, foil, find_partial_ordering_of_terms, determine_param_orderings 8 | from trimlogic.algorithm import fol_bc_ask 9 | from trimlogic.term import Atom 10 | 11 | class FamilyMemberFactory: 12 | 13 | def __getattr__(self, name): 14 | return FamilyMember(name) 15 | 16 | 17 | class FamilyMember(Atom): 18 | 19 | 
def __init__(self, name): 20 | Atom.__init__(self, name) 21 | 22 | def set_ordering(self, ordering): 23 | self._ordering = ordering 24 | def compare(x,y): 25 | return cmp(self._ordering.index(x), self._ordering.index(y)) 26 | FamilyMember.__cmp__ = compare 27 | 28 | set_ordering = classmethod(set_ordering) 29 | 30 | 31 | class FamilyTreeTestCase(FoilTestCase): 32 | 33 | def setUp(self): 34 | pass 35 | 36 | def loadFamilyTree1(self): 37 | """ 38 | Frank -------- Rebecca George -------- Jannet 39 | | | | | 40 | Abe Alan ------ Joan Jill ------ Bob Tim ---- Tammy 41 | | | | 42 | Sean -------- Jan Tipsy --- Tom 43 | | | 44 | Jane ---- Ian Thomas --- Debrah 45 | | | 46 | Ann Billy 47 | """ 48 | v, a = self.v, self.a 49 | ordering = [a.frank, a.rebecca, a.george, a.jannet, a.abe, a.alan, a.joan, a.jill, a.bob, a.tim, a.tammy, a.sean, a.jan, a.tipsy, a.tom, a.jane, a.ian, a.thomas, a.debrah, a.ann, a.billy] 50 | FamilyMember.set_ordering(ordering) 51 | self.father.rules = [] 52 | self.father.param_orderings = None 53 | map(self.father.add_rule, [(a.frank, a.abe), 54 | (a.frank, a.alan), 55 | (a.alan, a.sean), 56 | (a.sean, a.jane), 57 | (a.george, a.bob), 58 | (a.george, a.tim), 59 | (a.bob, a.jan), 60 | (a.tim, a.tom), 61 | (a.tom, a.thomas), 62 | (a.ian, a.ann), 63 | (a.thomas, a.billy)]) 64 | self.mother.rules = [] 65 | self.mother.param_orderings = None 66 | map(self.mother.add_rule, [(a.rebecca, a.alan), 67 | (a.rebecca, a.abe), 68 | (a.joan, a.sean), 69 | (a.jane, a.ann), 70 | (a.jannet, a.tim), 71 | (a.jannet, a.bob), 72 | (a.tammy, a.tom), 73 | (a.tipsy, a.thomas), 74 | (a.debrah, a.billy), 75 | (a.jill, a.jan), 76 | (a.jan, a.jane)]) 77 | for predicate in [self.mother, self.father]: 78 | try: 79 | predicate.param_orderings = determine_param_orderings(predicate) 80 | except: 81 | pass 82 | 83 | def setUp(self): 84 | self.v, self.a = VariableFactory(), FamilyMemberFactory() 85 | self.kb = KnowledgeBase() 86 | v, a = self.v, self.a 87 | self.father = 
            RuleBasedPredicate('father', (FamilyMember, FamilyMember))
        self.mother = RuleBasedPredicate('mother', (FamilyMember, FamilyMember))
        self.kb.add_all([self.mother, self.father])

    def testFindRecursiveRules(self):
        from trimlogic.predicate import Rule, MutableRule
        from trimlogic.foil import will_halt
        v, a = self.v, self.a
        self.loadFamilyTree1()
        mother = self.mother
        predicate = RuleBasedPredicate('ancestor', (FamilyMember, FamilyMember))
        predicate.rules.append( Rule( predicate, (v.X, v.Y), (mother(v.X, v.Y),) ) )
        current_rule = MutableRule( predicate, (v.X, v.Y), (mother(v.Z, v.Y),) )
        predicate.rules.append(current_rule)
        recursive_literal = predicate(v.X, v.Z)
        ordering = find_partial_ordering_of_terms(current_rule)
        self.assertTrue(will_halt(predicate, recursive_literal, [v.X, v.Y], ordering))

    def testAncestor(self):
        v, a = self.v, self.a
        self.loadFamilyTree1()
        ancestor = RuleBasedPredicate('ancestor', (FamilyMember, FamilyMember))
        self.kb.add(ancestor)
        positive_tuples = [[a.jane, a.ann],
                           [a.tim, a.tom],
                           [a.tipsy, a.billy],
                           [a.jannet, a.billy],
                           [a.joan, a.jane],
                           [a.rebecca, a.ann],
                           [a.frank, a.jane],
                           [a.jan, a.ann],
                           [a.jill, a.ann]]
        negative_tuples = [[a.tim, a.ann],
                           [a.ann, a.billy],
                           [a.jane, a.frank],
                           [a.tom, a.debrah],
                           [a.tim, a.tammy],
                           [a.tom, a.george],
                           [a.jane, a.joan]]
        foil(ancestor, positive_tuples, negative_tuples, self.kb, ordering=None)
        self.assertFollows(ancestor(a.bob, a.jane))
        self.assertNotFollows(ancestor(a.bob, a.ian))
        self.kb.remove(ancestor)
        self.print_rules(ancestor)

    def testGrandparent(self):
        import sys
        v, a = self.v, self.a
        self.loadFamilyTree1()
        grandfather = RuleBasedPredicate('grandfather', (FamilyMember, FamilyMember))
        grandfather.arity = 2
        self.kb.add(grandfather)
        positive_tuples = [[a.frank, a.sean],
                           [a.tom, a.billy],
                           [a.george, a.jan],
                           [a.bob, a.jane],
                           [a.sean, a.ann],
                           [a.frank, a.sean]]
        negative_tuples = [[a.jannet, a.tim],
                           [a.jane, a.alan],
                           [a.rebecca, a.sean],
                           [a.jan, a.ann]]
        foil(grandfather, positive_tuples, negative_tuples, self.kb, ordering=None)
        self.assertFollows(grandfather(a.george, a.tom))
        self.assertFollows(grandfather(a.alan, a.jane))
        self.assertNotFollows(grandfather(a.tipsy, a.billy))
        self.assertNotFollows(grandfather(a.bob, a.ian))
        self.kb.remove(grandfather)
        self.print_rules(grandfather)


if __name__ == "__main__":
    unittest.main()

--------------------------------------------------------------------------------
/src/trimlogic/test/ListTestCase.py:
--------------------------------------------------------------------------------
import unittest
import logging
from trimlogic.test.helper import FoilTestCase
from trimlogic.predicate import KnowledgeBase
from trimlogic.predicate import RuleBasedPredicate
from trimlogic.predicate import VariableFactory, UniqueVariableFactory
from trimlogic.predicate import AtomFactory
from trimlogic.foil import TrainingSet, construct_clause_recursive
from trimlogic.foil import find_gainful_and_determinate_literals, foil
from trimlogic.foil import find_partial_ordering_of_terms
from trimlogic.foil import determine_param_orderings
from trimlogic.algorithm import fol_bc_ask
from trimlogic.term import Atom, Term
from trimlogic.stdlib import components, dot, plist

class ListTestCase(FoilTestCase):

    def setUp(self):
        logging.getLogger("foil_construct_clause").setLevel(logging.DEBUG)
        logging.getLogger("foil_determine_ordering").setLevel(logging.DEBUG)
        logging.getLogger("foil_new_literal").setLevel(logging.DEBUG)
        logging.getLogger("predicate_rules_postprocessing_compact").setLevel(logging.DEBUG)
        logging.getLogger("foil_profiling").setLevel(logging.DEBUG)

    def testMemberPredicate(self):
        print ""
        v, a = VariableFactory(), AtomFactory()
        kb = KnowledgeBase()
        kb.add(components)
        member = RuleBasedPredicate("member", (Term, dot))
        kb.add(member)
        positive_tuples = ((1, plist([1,2,3,4])),
                           (4, plist([1,2,3,4])),
                           (5, plist([5,4,3,2])),
                           (2, plist([1,2,3,4])),
                           (3, plist([1,2,3,4])))
        negative_tuples = ((1, plist([2,3,4,5])),
                           (2, plist([1,3,4,5])),
                           (3, plist([1,2,4,5])),
                           (4, plist([1,2,3,5])),
                           (5, plist([1,2,3,4])),
                           (plist([1,2]), plist([1,2,3,4])))
        foil(member, positive_tuples, negative_tuples, kb, ordering=None)
        self.assertFollows(member(11, plist([2,3,5,7,11,13])))
        self.assertNotFollows(member(12, plist([2,3,5,7,11,13])))
        self.print_rules(member)

if __name__ == "__main__":
    unittest.main()
--------------------------------------------------------------------------------
/src/trimlogic/test/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/johntrimble/foil-python/8628dc3113da66f8990ac673ed280d29cf82246a/src/trimlogic/test/__init__.py
--------------------------------------------------------------------------------
/src/trimlogic/test/helper.py:
--------------------------------------------------------------------------------
import unittest
from trimlogic.algorithm import fol_bc_ask

PRINT_RULES = True

class FoilTestCase(unittest.TestCase):
    def print_rules(self, predicate):
        if PRINT_RULES:
            print ""
            print "Rules for", predicate, ":"
            for rule in predicate.rules:
                print str(rule)

    def assertFollows(self, arg0, arg1=None):
        msg, term = None, None
        if isinstance(arg0, str):
            msg = arg0
            term = arg1
        else:
            term = arg0
        # A single answer is enough to show that the term follows.
        for x in fol_bc_ask([term], {}):
            return
        if msg:
            self.fail(msg)
        else:
            self.fail()

    def assertNotFollows(self, arg0, arg1=None):
        msg, term = None, None
        if isinstance(arg0, str):
            msg = arg0
            term = arg1
        else:
            term = arg0
        # Any answer at all means the term follows, so fail on the first one.
        for x in fol_bc_ask([term], {}):
            if msg:
                self.fail(msg)
            else:
                self.fail()

--------------------------------------------------------------------------------
/src/trimlogic/util.py:
--------------------------------------------------------------------------------
def is_tuple_or_list(x):
    return isinstance(x, (tuple, list))
--------------------------------------------------------------------------------