├── .gitignore ├── LICENSE ├── README.md ├── examples ├── __init__.py └── walkthroughs │ ├── pylfppl_walkthrough_1.ipynb │ └── pylfppl_walkthrough_2.ipynb ├── pyppl ├── __init__.py ├── aux │ ├── __init__.py │ ├── ppl_transform_visitor.py │ └── ppl_visitor_template.py ├── backend │ ├── __init__.py │ ├── ppl_code_generator.py │ ├── ppl_graph_codegen.py │ ├── ppl_graph_factory.py │ └── ppl_graph_generator.py ├── distributions.py ├── fe_clojure │ ├── __init__.py │ ├── ppl_clojure_forms.py │ ├── ppl_clojure_lexer.py │ ├── ppl_clojure_parser.py │ ├── ppl_clojure_repr.py │ └── ppl_foppl_parser.py ├── fe_python │ ├── __init__.py │ └── ppl_python_parser.py ├── graphs.py ├── lexer.py ├── parser.py ├── ppl_ast.py ├── ppl_ast_annotators.py ├── ppl_base_model.py ├── ppl_branch_scopes.py ├── ppl_namespaces.py ├── ppl_symbol_table.py ├── tests │ └── factor_tests.py ├── transforms │ ├── __init__.py │ ├── ppl_functions_inliner.py │ ├── ppl_new_simplifier.py │ ├── ppl_raw_simplifier.py │ ├── ppl_simplifier.py │ ├── ppl_static_assignments.py │ ├── ppl_symbol_simplifier.py │ └── ppl_var_substitutor.py ├── types │ ├── __init__.py │ ├── ppl_type_inference.py │ ├── ppl_type_operations.py │ └── ppl_types.py └── utils │ ├── __init__.py │ └── core.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.idea 2 | __pycache__ 3 | *__pycache__ 4 | *.pyc 5 | *.ipynb_checkpoints 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Tobias Kohn 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, 
and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/examples/__init__.py -------------------------------------------------------------------------------- /pyppl/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 07. Feb 2018, Tobias Kohn 7 | # 26. Mar 2018, Tobias Kohn 8 | # 9 | from typing import Optional 10 | from . 
import distributions, parser 11 | from .backend import ppl_graph_generator 12 | 13 | 14 | 15 | def compile_model(source, *, 16 | language: Optional[str]=None, 17 | imports=None, 18 | base_class: Optional[str]=None, 19 | namespace: Optional[dict]=None): 20 | if type(imports) in (list, set, tuple): 21 | imports = '\n'.join(imports) 22 | if namespace is not None: 23 | ns = distributions.namespace.copy() 24 | ns.update(namespace) 25 | namespace = ns 26 | else: 27 | namespace = distributions.namespace 28 | ast = parser.parse(source, language=language, namespace=namespace) 29 | gg = ppl_graph_generator.GraphGenerator() 30 | gg.visit(ast) 31 | return gg.generate_model(base_class=base_class, imports=imports) 32 | 33 | 34 | def compile_model_from_file(filename: str, *, 35 | language: Optional[str]=None, 36 | imports=None, 37 | base_class: Optional[str]=None, 38 | namespace: Optional[dict]=None): 39 | with open(filename) as f: 40 | lines = ''.join(f.readlines()) 41 | return compile_model(lines, language=language, imports=imports, base_class=base_class, 42 | namespace=namespace) 43 | -------------------------------------------------------------------------------- /pyppl/aux/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/aux/__init__.py -------------------------------------------------------------------------------- /pyppl/aux/ppl_transform_visitor.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 15. Mar 2018, Tobias Kohn 7 | # 11. 
May 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ast import copy_location as _cl 11 | 12 | class TransformVisitor(ScopedVisitor): 13 | 14 | def do_visit_dict(self, items:dict): 15 | result = {} 16 | for key in items: 17 | n_item = self.visit(items[key]) 18 | if n_item is not items[key]: 19 | result[key] = n_item 20 | if len(result) > 0: 21 | return items.copy().update(result) 22 | else: 23 | return items 24 | 25 | def do_visit_items(self, items:list): 26 | use_original = True 27 | result = [] 28 | for item in items: 29 | n_item = self.visit(item) 30 | use_original = use_original and n_item is item 31 | result.append(n_item) 32 | if use_original: 33 | return items 34 | else: 35 | return result 36 | 37 | 38 | def visit_node(self, node: AstNode): 39 | return node 40 | 41 | def visit_attribute(self, node:AstAttribute): 42 | base = self.visit(node.base) 43 | if base is node.base: 44 | return node 45 | else: 46 | return node.clone(base=base) 47 | 48 | def visit_binary(self, node:AstBinary): 49 | left = self.visit(node.left) 50 | right = self.visit(node.right) 51 | if left is node.left and right is node.right: 52 | return node 53 | else: 54 | return node.clone(left=left, right=right) 55 | 56 | def visit_body(self, node:AstBody): 57 | items = self.do_visit_items(node.items) 58 | if items is node.items: 59 | return node 60 | else: 61 | return _cl(makeBody(items), node) 62 | 63 | def visit_call(self, node: AstCall): 64 | function = self.visit(node.function) 65 | args = self.do_visit_items(node.args) 66 | if function is node.function and args is node.args: 67 | return node 68 | else: 69 | return node.clone(function=function, args=args) 70 | 71 | def visit_compare(self, node: AstCompare): 72 | left = self.visit(node.left) 73 | right = self.visit(node.right) 74 | if left is node.left and right is node.right: 75 | return node 76 | else: 77 | return node.clone(left=left, right=right) 78 | 79 | def visit_def(self, node: AstDef): 80 | value = self.visit(node.value) 
81 | if value is node.value: 82 | return node 83 | else: 84 | return node.clone(value=value) 85 | 86 | def visit_dict(self, node: AstDict): 87 | items = self.do_visit_dict(node.items) 88 | if items is node.items: 89 | return node 90 | else: 91 | return node.clone(items=items) 92 | 93 | def visit_for(self, node: AstFor): 94 | source = self.visit(node.source) 95 | body = self.visit(node.body) 96 | if source is node.source and body is node.body: 97 | return node 98 | else: 99 | return node.clone(source=source, body=body) 100 | 101 | def visit_function(self, node: AstFunction): 102 | body = self.visit(node.body) 103 | if body is node.body: 104 | return node 105 | else: 106 | return node.clone(body=body) 107 | 108 | def visit_if(self, node: AstIf): 109 | test = self.visit(node.test) 110 | if_node = self.visit(node.if_node) 111 | else_node = self.visit(node.else_node) 112 | if test is node.test and if_node is node.if_node and else_node is node.else_node: 113 | return node 114 | else: 115 | return node.clone(test=test, if_node=if_node, else_node=else_node) 116 | 117 | def visit_let(self, node: AstLet): 118 | source = self.visit(node.source) 119 | body = self.visit(node.body) 120 | if source is node.source and body is node.body: 121 | return node 122 | else: 123 | return node.clone(source=source, body=body) 124 | 125 | def visit_list_for(self, node: AstListFor): 126 | source = self.visit(node.source) 127 | expr = self.visit(node.expr) 128 | if source is node.source and expr is node.expr: 129 | return node 130 | else: 131 | return node.clone(source=source, expr=expr) 132 | 133 | def visit_observe(self, node: AstObserve): 134 | dist = self.visit(node.dist) 135 | value = self.visit(node.value) 136 | if dist is node.dist and value is node.value: 137 | return node 138 | else: 139 | return node.clone(dist=dist, value=value) 140 | 141 | def visit_return(self, node: AstReturn): 142 | value = self.visit(node.value) 143 | if value is node.value: 144 | return node 145 | else: 146 | 
return node.clone(value=value) 147 | 148 | def visit_sample(self, node: AstSample): 149 | dist = self.visit(node.dist) 150 | if dist is node.dist: 151 | return node 152 | else: 153 | return node.clone(dist=dist) 154 | 155 | def visit_slice(self, node: AstSlice): 156 | base = self.visit(node.base) 157 | start = self.visit(node.start) 158 | stop = self.visit(node.stop) 159 | if base is node.base and start is node.start and stop is node.stop: 160 | return node 161 | else: 162 | return node.clone(base=base, start=start, stop=stop) 163 | 164 | def visit_subscript(self, node: AstSubscript): 165 | base = self.visit(node.base) 166 | index = self.visit(node.index) 167 | if base is node.base and index is node.index: 168 | return node 169 | else: 170 | return node.clone(base=base, index=index) 171 | 172 | def visit_unary(self, node: AstUnary): 173 | item = self.visit(node.item) 174 | if item is node.item: 175 | return node 176 | else: 177 | return node.clone(item=item) 178 | 179 | def visit_vector(self, node: AstVector): 180 | items = self.do_visit_items(node.items) 181 | if items is node.items: 182 | return node 183 | else: 184 | return node.clone(items=items) 185 | 186 | def visit_while(self, node: AstWhile): 187 | test = self.visit(node.test) 188 | body = self.visit(node.body) 189 | if test is node.test and body is node.body: 190 | return node 191 | else: 192 | return node.clone(test=test, body=body) 193 | -------------------------------------------------------------------------------- /pyppl/aux/ppl_visitor_template.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 01. Mar 2018, Tobias Kohn 7 | # 15. Mar 2018, Tobias Kohn 8 | # 9 | from pyppl.ppl_ast import * 10 | 11 | class MyVisitor(Visitor): 12 | """ 13 | This is a visitor-template. 
Copy/Paste it into a new file and then change it according to your own needs! 14 | """ 15 | 16 | def visit_attribute(self, node:AstAttribute): 17 | return self.visit_node(node) 18 | 19 | def visit_binary(self, node:AstBinary): 20 | return self.visit_node(node) 21 | 22 | def visit_body(self, node:AstBody): 23 | return self.visit_node(node) 24 | 25 | def visit_break(self, node: AstBreak): 26 | return self.visit_node(node) 27 | 28 | def visit_call(self, node: AstCall): 29 | return self.visit_node(node) 30 | 31 | def visit_compare(self, node: AstCompare): 32 | return self.visit_node(node) 33 | 34 | def visit_def(self, node: AstDef): 35 | return self.visit_node(node) 36 | 37 | def visit_dict(self, node: AstDict): 38 | return self.visit_node(node) 39 | 40 | def visit_for(self, node: AstFor): 41 | return self.visit_node(node) 42 | 43 | def visit_function(self, node: AstFunction): 44 | return self.visit_node(node) 45 | 46 | def visit_if(self, node: AstIf): 47 | return self.visit_node(node) 48 | 49 | def visit_import(self, node: AstImport): 50 | return self.visit_node(node) 51 | 52 | def visit_let(self, node: AstLet): 53 | return self.visit_node(node) 54 | 55 | def visit_list_for(self, node: AstListFor): 56 | return self.visit_node(node) 57 | 58 | def visit_observe(self, node: AstObserve): 59 | return self.visit_node(node) 60 | 61 | def visit_return(self, node: AstReturn): 62 | return self.visit_node(node) 63 | 64 | def visit_sample(self, node: AstSample): 65 | return self.visit_node(node) 66 | 67 | def visit_slice(self, node: AstSlice): 68 | return self.visit_node(node) 69 | 70 | def visit_subscript(self, node: AstSubscript): 71 | return self.visit_node(node) 72 | 73 | def visit_symbol(self, node: AstSymbol): 74 | return self.visit_node(node) 75 | 76 | def visit_unary(self, node: AstUnary): 77 | return self.visit_node(node) 78 | 79 | def visit_value(self, node: AstValue): 80 | return self.visit_node(node) 81 | 82 | def visit_value_vector(self, node: AstValueVector): 83 | 
return self.visit_node(node) 84 | 85 | def visit_vector(self, node: AstVector): 86 | return self.visit_node(node) 87 | 88 | def visit_while(self, node: AstWhile): 89 | return self.visit_node(node) 90 | -------------------------------------------------------------------------------- /pyppl/backend/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/backend/__init__.py -------------------------------------------------------------------------------- /pyppl/backend/ppl_code_generator.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 02. Mar 2018, Tobias Kohn 7 | # 22. Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..ppl_ast_annotators import get_info 11 | 12 | 13 | def _is_block(node): 14 | if isinstance(node, AstBody): 15 | return len(node) > 1 16 | elif isinstance(node, AstLet): 17 | return True 18 | elif isinstance(node, AstFor) or isinstance(node, AstWhile): 19 | return True 20 | elif isinstance(node, AstDef): 21 | return _is_block(node.value) 22 | elif isinstance(node, AstIf): 23 | return node.has_else or _is_block(node.if_node) or _is_block(node.else_node) 24 | else: 25 | return False 26 | 27 | 28 | def _push_return(node, f): 29 | """ 30 | Rewrites the AST to make a `return` or _assignment_ the last effective statement to be executed. 
31 | 32 | For instance, a LISP-based frontend might give us a code fragment such as (in pseudo-code): 33 | `return (let [x = 12] (x * 3))` 34 | We cannot translate this directly to Python, as it would result in invalid code, but have to rewrite it to: 35 | `(let [x = 12] (return x * 3))` 36 | 37 | This pushing of the `return`-statement (or any assignment) into the expression is done by this function. The `node` 38 | parameter stands for the expression (in the example above the `let`-expression), while `f` is a function that 39 | takes a node and wraps it into a `return`. 40 | 41 | Sample usage: `_push_return(let_node, lambda x: AstReturn(x))` 42 | 43 | :param node: The expression into which we need to push the `return` or _assignment_. 44 | :param f: A function that takes one argument of type `AstNode` and returns another `AstNode`-object, usually 45 | by wrapping its argument into an `AstReturn` or `AstDef`. 46 | :return: The original expression with the `return` applied to the last statement to be executed. 
47 | """ 48 | if node is None: 49 | return None 50 | elif isinstance(node, AstBody) and len(node) > 1: 51 | return AstBody(node.items[:-1] + [_push_return(node.items[-1], f)]) 52 | elif isinstance(node, AstLet): 53 | return AstLet(node.target, node.source, _push_return(node.body, f)) 54 | elif isinstance(node, AstFor): 55 | return AstFor(node.target, node.source, _push_return(node.body, f)) 56 | elif isinstance(node, AstDef): 57 | return AstDef(node.name, _push_return(node.value, f)) 58 | elif isinstance(node, AstIf): 59 | return AstIf(node.test, _push_return(node.if_node, f), _push_return(node.else_node, f)) 60 | elif isinstance(node, AstWhile): 61 | return AstBody([node, f(AstValue(None))]) 62 | else: 63 | return f(node) 64 | 65 | 66 | def _normalize_name(name): 67 | if type(name) is tuple: 68 | return ', '.join([_normalize_name(n) for n in name]) 69 | result = '' 70 | if name.endswith('?'): 71 | name = 'is_' + name[:-1] 72 | if name.endswith('!'): 73 | name = 'do_' + name[:-1] 74 | for n in name: 75 | if n in ('+', '-', '?', '!', '_'): 76 | result += '_' 77 | elif n == '*': 78 | result += '_STAR_' 79 | elif '0' <= n <= '9' or 'A' <= n <= 'Z' or 'a' <= n <= 'z': 80 | result += n 81 | elif n == '.': 82 | result += n 83 | return result 84 | 85 | 86 | class CodeGenerator(ScopedVisitor): 87 | 88 | def __init__(self): 89 | super().__init__() 90 | self.functions = [] 91 | self.imports = [] 92 | self._symbol_counter_ = 99 93 | self.short_names = False # used for debugging 94 | self.state_object = None # type:str 95 | 96 | def get_prefix(self): 97 | import datetime 98 | result = ['# {}'.format(datetime.datetime.now()), 99 | '\n'.join(self.imports), 100 | '\n\n'.join(self.functions)] 101 | return '\n'.join(result) 102 | 103 | def generate_symbol(self): 104 | self._symbol_counter_ += 1 105 | return "_{}_".format(self._symbol_counter_) 106 | 107 | def add_function(self, params:list, body:str): 108 | name = "__LAMBDA_FUNCTION__{}__".format(len(self.functions) + 1) 109 | 
self.functions.append("def {}({}):\n\t{}".format(name, ', '.join(params), body.replace('\n', '\n\t'))) 110 | return name 111 | 112 | def visit_attribute(self, node:AstAttribute): 113 | result = self.visit(node.base) 114 | return "{}.{}".format(result, node.attr) 115 | 116 | def visit_binary(self, node:AstBinary): 117 | left = self.visit(node.left) 118 | right = self.visit(node.right) 119 | return "({} {} {})".format(left, node.op, right) 120 | 121 | def visit_body(self, node:AstBody): 122 | if len(node) == 0: 123 | return "pass" 124 | items = [self.visit(item) for item in node.items] 125 | items = [item for item in items if item != ''] 126 | return '\n'.join(items) 127 | 128 | def visit_break(self, _): 129 | return "break" 130 | 131 | def visit_call(self, node: AstCall): 132 | function = self.visit(node.function) 133 | args = [self.visit(arg) for arg in node.args] 134 | keywords = [''] * node.pos_arg_count + ['{}='.format(key) for key in node.keywords] 135 | args = [a + b for a, b in zip(keywords, args)] 136 | return "{}({})".format(function, ', '.join(args)) 137 | 138 | def visit_compare(self, node: AstCompare): 139 | if node.second_right is None: 140 | left = self.visit(node.left) 141 | right = self.visit(node.right) 142 | return "({} {} {})".format(left, node.op, right) 143 | else: 144 | left = self.visit(node.left) 145 | right = self.visit(node.right) 146 | second_right = self.visit(node.second_right) 147 | return "({} {} {} {} {})".format(left, node.op, right, node.second_op, second_right) 148 | 149 | def visit_def(self, node: AstDef): 150 | name = _normalize_name(node.original_name if self.short_names else node.name) 151 | if self.state_object is not None: 152 | name = "{}['{}']".format(self.state_object, name) 153 | if isinstance(node.value, AstFunction): 154 | function = node.value 155 | params = function.parameters 156 | if function.vararg is not None: 157 | params.append("*" + function.vararg) 158 | body = self.visit(function.body).replace('\n', '\n\t') 
159 | return "def {}({}):\n\t{}".format(name, ', '.join(params), body) 160 | 161 | elif isinstance(node.value, AstWhile) or isinstance(node.value, AstObserve): 162 | result = self.visit(node.value) 163 | return "{}\n{} = None".format(result, name) 164 | 165 | elif _is_block(node.value): 166 | result = _push_return(node.value, lambda x: AstDef(node.name, x)) 167 | if not isinstance(result, AstDef): 168 | return self.visit(result) 169 | 170 | return "{} = {}".format(name, self.visit(node.value)) 171 | 172 | def visit_dict(self, node: AstDict): 173 | result = { key: self.visit(node.items[key]) for key in node.items } 174 | result = ["{}: {}".format(key, result[key]) for key in result] 175 | return "{" + ', '.join(result) + "}" 176 | 177 | def visit_for(self, node: AstFor): 178 | name = _normalize_name(node.original_target if self.short_names else node.target) 179 | source = self.visit(node.source) 180 | body = self.visit(node.body).replace('\n', '\n\t') 181 | return "for {} in {}:\n\t{}".format(name, source, body) 182 | 183 | def visit_function(self, node: AstFunction): 184 | params = node.parameters 185 | if node.vararg is not None: 186 | params.append("*" + node.vararg) 187 | body = self.visit(node.body) 188 | if '\n' in body or get_info(node.body).has_return: 189 | return self.add_function(params, body) 190 | else: 191 | return "(lambda {}: {})".format(', '.join(params), body) 192 | 193 | def visit_if(self, node: AstIf): 194 | test = self.visit(node.test) 195 | if_expr = self.visit(node.if_node) 196 | if node.has_else: 197 | else_expr = self.visit(node.else_node) 198 | if node.has_elif: 199 | if not else_expr.startswith("if"): 200 | enode = node.else_node 201 | etest = self.visit(enode.test) 202 | ebody = self.visit(enode.if_node) 203 | if enode.has_else: 204 | else_expr = "if {}:\n\t{}else:\n\t{}:".format(etest, ebody, self.visit(enode.else_body)) 205 | else: 206 | else_expr = "if {}:\n\t{}".format(etest, ebody) 207 | return "if {}:\n\t{}\nel{}".format(test, 
if_expr.replace('\n', '\n\t'), else_expr) 208 | elif '\n' in if_expr or '\n' in else_expr: 209 | return "if {}:\n\t{}\nelse:\n\t{}".format(test, if_expr.replace('\n', '\n\t'), 210 | else_expr.replace('\n', '\n\t')) 211 | else: 212 | return "{} if {} else {}".format(if_expr, test, else_expr) 213 | else: 214 | if '\n' in if_expr: 215 | return "if {}:\n\t{}".format(test, if_expr.replace('\n', '\n\t')) 216 | else: 217 | return "{} if {} else None".format(if_expr, test) 218 | 219 | def visit_import(self, node: AstImport): 220 | self.imports.append("import {}".format(node.module_name)) 221 | if node.imported_names is None: 222 | result = "import {}{}".format(node.module_name, "as {}".format(node.alias) if node.alias is not None else '') 223 | elif len(node.imported_names) == 1 and node.alias is not None: 224 | result = "from {} import {} as {}".format(node.module_name, node.imported_names[0], node.alias) 225 | else: 226 | result = "from {} import {}".format(node.module_name, ', '.join(node.imported_names)) 227 | return "" 228 | 229 | def visit_let(self, node: AstLet): 230 | name = _normalize_name(node.original_target if self.short_names else node.target) 231 | if isinstance(node.source, AstLet): 232 | result = self.visit(AstDef(node.target, node.source)) 233 | return result + "\n{}".format(self.visit(node.body)) 234 | else: 235 | return "{} = {}\n{}".format(name, self.visit(node.source), self.visit(node.body)) 236 | 237 | def visit_list_for(self, node: AstListFor): 238 | name = _normalize_name(node.original_target if self.short_names else node.target) 239 | expr = self.visit(node.expr) 240 | if _is_block(node.expr): 241 | expr = self.add_function([str(node.target)], expr) 242 | expr += "({})".format(str(node.target)) 243 | source = self.visit(node.source) 244 | test = (' if ' + self.visit(node.test)) if node.test is not None else '' 245 | return "[{} for {} in {}{}]".format(expr, name, source, test) 246 | 247 | def visit_multi_slice(self, node: AstMultiSlice): 248 | base 
= self.visit(node.base) 249 | slices = [self.visit(index) if index is not None else ':' for index in node.indices] 250 | return "{}[{}]".format(base, ','.join(slices)) 251 | 252 | def visit_observe(self, node: AstObserve): 253 | dist = self.visit(node.dist) 254 | return "observe({}, {})".format(dist, self.visit(node.value)) 255 | 256 | def visit_return(self, node: AstReturn): 257 | if node.value is None: 258 | return "return None" 259 | elif isinstance(node.value, AstWhile) or isinstance(node.value, AstObserve): 260 | result = self.visit(node.value) 261 | return result + "\nreturn None" 262 | elif _is_block(node.value): 263 | result = _push_return(node.value, lambda x: AstReturn(x)) 264 | if isinstance(result, AstReturn): 265 | return "return {}".format(self.visit(result)) 266 | else: 267 | return self.visit(result) 268 | else: 269 | return "return {}".format(self.visit(node.value)) 270 | 271 | def visit_sample(self, node: AstSample): 272 | dist = self.visit(node.dist) 273 | size = self.visit(node.size) 274 | if size is not None: 275 | return "sample({}, sample_size={})".format(dist, size) 276 | else: 277 | return "sample({})".format(dist) 278 | 279 | def visit_slice(self, node: AstSlice): 280 | base = self.visit(node.base) 281 | start = self.visit(node.start) if node.start is not None else '' 282 | stop = self.visit(node.stop) if node.stop is not None else '' 283 | return "{}[{}:{}]".format(base, start, stop) 284 | 285 | def visit_subscript(self, node: AstSubscript): 286 | base = self.visit(node.base) 287 | index = self.visit(node.index) 288 | if isinstance(node.base, AstDict) and node.default is not None: 289 | default = self.visit(node.default) 290 | return "{}.get({}, {})".format(base, index, default) 291 | else: 292 | return "{}[{}]".format(base, index) 293 | 294 | def visit_symbol(self, node: AstSymbol): 295 | if self.short_names: 296 | if self.state_object is not None and not node.predef and not '.' 
in node.original_name: 297 | return "{}['{}']".format(self.state_object, node.original_name) 298 | else: 299 | return node.original_name 300 | sym = self.resolve(node.name) 301 | if isinstance(sym, AstSymbol): 302 | name = _normalize_name(sym.name) 303 | else: 304 | name = _normalize_name(node.name) 305 | if self.state_object is not None and not node.predef and not '.' in name: 306 | name = "{}['{}']".format(self.state_object, name) 307 | return name 308 | 309 | def visit_unary(self, node: AstUnary): 310 | return "{}{}".format(node.op, self.visit(node.item)) 311 | 312 | def visit_value(self, node: AstValue): 313 | return repr(node.value) 314 | 315 | def visit_value_vector(self, node: AstValueVector): 316 | return repr(node.items) 317 | 318 | def visit_vector(self, node: AstVector): 319 | return "[{}]".format(', '.join([self.visit(item) for item in node.items])) 320 | 321 | def visit_while(self, node: AstWhile): 322 | test = self.visit(node.test) 323 | body = self.visit(node.body).replace('\n', '\n\t') 324 | return "while {}:\n\t{}".format(test, body) 325 | 326 | 327 | def generate_code(ast, *, code_generator=None, name=None, parameters=None, state_object=None): 328 | if code_generator is not None: 329 | if callable(code_generator): 330 | cg = code_generator() 331 | else: 332 | cg = code_generator 333 | else: 334 | cg = CodeGenerator() 335 | if state_object is not None: 336 | cg.state_object = state_object 337 | 338 | result = cg.visit(ast) 339 | if type(result) is list: 340 | result = [cg.get_prefix()] + result 341 | result = '\n\n'.join(result) 342 | else: 343 | result = cg.get_prefix() + '\n' + result 344 | 345 | if name is not None: 346 | assert type(name) is str, "name must be a string" 347 | if parameters is None: 348 | parameters = '' 349 | elif type(parameters) in (list, tuple): 350 | parameters = ', '.join(parameters) 351 | elif type(parameters) is not str: 352 | raise TypeError("'parameters' must be a list of strings, or a string") 353 | result = "def 
{}({}):\n\t{}".format(name, parameters, result) 354 | 355 | return result 356 | -------------------------------------------------------------------------------- /pyppl/backend/ppl_graph_codegen.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 12. Mar 2018, Tobias Kohn 7 | # 07. May 2018, Tobias Kohn 8 | # 9 | import datetime 10 | import importlib 11 | from ..graphs import * 12 | from ..ppl_ast import * 13 | 14 | 15 | class GraphCodeGenerator(object): 16 | """ 17 | In contrast to the more general code generator `CodeGenerator`, this class creates the code for a graph-based 18 | model. The output of the method `generate_model_code()` is therefore the code of a class `Model` with functions 19 | such as `gen_log_prob()` or `gen_prior_samples()`, including all necessary imports. 20 | 21 | You want to change this class if you need additional (or adapted) methods in your model-class. 22 | 23 | Usage: 24 | ``` 25 | graph = ... # <- actually generated by the graph-factory/generator 26 | graph_code_gen = GraphCodeGenerator(graph.nodes, state_object='state', imports='import distributions as dist') 27 | code = graph_code_gen.generate_model_code() 28 | my_globals = {} 29 | exec(code, my_globals) 30 | Model = my_globals['Model'] 31 | model = Model(graph.vertices, graph.arcs, graph.data, graph.conditionals) 32 | ``` 33 | The state-object specifies the actual name of the dictionary/map that holds the state, i. e. all the variables. 34 | When any state-object is given, the generated code reads, say, `state['x']` instead of purely `x`. 35 | 36 | Hacking: 37 | The `generate_model_code`-method uses three fixed methods to generate the code for `__init__`, `__repr__` as 38 | well as the doc-string: `_generate_doc_string`, `_generate_init_method`, `_generate_repr_method`. 
After that, 39 | it scans the object instance of `GraphCodeGenerator` for public methods, and assumes that each method returns 40 | the code for the respective method. 41 | 42 | Say, for instance, you wanted your Model-class to have a method `get_all_nodes` with the following code: 43 | ``` 44 | def get_all_nodes(self): 45 | return set.union(self.vertices, self.conditionals) 46 | ``` 47 | You then add the following method to the `GraphCodeGenerator` and be done. 48 | ``` 49 | def get_all_nodes(self): 50 | return "return set.union(self.vertices, self.conditionals)" 51 | ``` 52 | If, on the other hand, you need some additional parameters/arguments for your method, then you should return 53 | a tuple with the first element being the parameters as a string, and the second element being the code as before. 54 | ``` 55 | def get_all_nodes(self): 56 | return "param1, param2", "return set.union(self.vertices, self.conditionals)" 57 | ``` 58 | 59 | Of course, you do not need to actually change this class, but you can derive a new class from it, if you wish. 
60 | """ 61 | 62 | def __init__(self, nodes: list, state_object: Optional[str]=None, imports: Optional[str]=None): 63 | self.nodes = nodes 64 | self.state_object = state_object 65 | self.imports = imports 66 | self.bit_vector_name = None 67 | self.logpdf_suffix = None 68 | 69 | def _complete_imports(self, imports: str): 70 | if imports != '': 71 | has_dist = False 72 | uses_numpy = False 73 | uses_torch = False 74 | uses_pyfo = False 75 | for s in imports.split('\n'): 76 | s = s.strip() 77 | if s.endswith(' dist') or s.endswith('.dist'): 78 | has_dist = True 79 | if s.startswith('from'): 80 | s = s[5:] 81 | elif s.startswith('import'): 82 | s = s[7:] 83 | i = 0 84 | while i < len(s) and 'A' <= s[i].upper() <= 'Z': 85 | i += 1 86 | m = s[:i] 87 | uses_numpy = uses_numpy or m == 'numpy' 88 | uses_torch = uses_torch or m == 'torch' 89 | uses_pyfo = uses_pyfo or m == 'pyfo' 90 | if uses_torch or uses_numpy: 91 | self.logpdf_suffix = '' 92 | if not has_dist: 93 | if uses_torch and uses_pyfo: 94 | return 'import pyfo.distributions as dist\n' 95 | else: 96 | return 'import torch.distributions as dist\n' 97 | return '' 98 | 99 | 100 | def generate_model_code(self, *, 101 | class_name: str='Model', 102 | base_class: str='', 103 | imports: str='') -> str: 104 | 105 | if self.imports is not None: 106 | imports = self.imports + "\n" + imports 107 | if base_class is None: 108 | base_class = '' 109 | 110 | if '.' 
in base_class: 111 | idx = base_class.rindex('.') 112 | base_module = base_class[:idx] 113 | try: 114 | importlib.import_module(base_module) 115 | base_class = base_class[idx+1:] 116 | imports = "from {} import {}\n".format(base_module, base_class) + imports 117 | except: 118 | pass 119 | 120 | # try: 121 | # graph_module = 'pyppl.aux.graph_plots' 122 | # m = importlib.import_module(graph_module) 123 | # names = [n for n in dir(m) if not n.startswith('_')] 124 | # if len(names) > 1: 125 | # names = [n for n in names if n[0].isupper()] 126 | # if len(names) == 1: 127 | # if base_class != '': 128 | # base_class += ', ' 129 | # base_class += '_' + names[0] 130 | # imports = "from {} import {} as _{}\n".format(graph_module, names[0], names[0]) + imports 131 | # except ModuleNotFoundError: 132 | # pass 133 | 134 | imports = self._complete_imports(imports) + imports 135 | 136 | result = ["# {}".format(datetime.datetime.now()), 137 | imports, 138 | "class {}({}):".format(class_name, base_class)] 139 | 140 | doc_str = self._generate_doc_string() 141 | if doc_str is not None and doc_str != '': 142 | result.append('\t"""\n\t{}\n\t"""'.format(doc_str.replace('\n', '\n\t'))) 143 | result.append('') 144 | 145 | init_method = self._generate_init_method() 146 | if init_method is not None: 147 | result.append('\t' + init_method.replace('\n', '\n\t')) 148 | 149 | repr_method = self._generate_repr_method() 150 | if repr_method is not None: 151 | result.append('\t' + repr_method.replace('\n', '\n\t')) 152 | 153 | methods = [x for x in dir(self) if not x.startswith('_') and x != 'generate_model_code'] 154 | for method_name in methods: 155 | method = getattr(self, method_name) 156 | if callable(method): 157 | code = method() 158 | if type(code) is tuple and len(code) == 2: 159 | args, code = code 160 | args = 'self, ' + args 161 | else: 162 | args = 'self' 163 | code = code.replace('\n', '\n\t\t') 164 | result.append("\tdef {}({}):\n\t\t{}\n".format(method_name, args, code)) 165 | 166 
| return '\n'.join(result) 167 | 168 | def _generate_doc_string(self): 169 | return '' 170 | 171 | def _generate_init_method(self): 172 | return "def __init__(self, vertices: set, arcs: set, data: set, conditionals: set):\n" \ 173 | "\tsuper().__init__()\n" \ 174 | "\tself.vertices = vertices\n" \ 175 | "\tself.arcs = arcs\n" \ 176 | "\tself.data = data\n" \ 177 | "\tself.conditionals = conditionals\n" 178 | 179 | def _generate_repr_method(self): 180 | s = "def __repr__(self):\n" \ 181 | "\tV = '\\n'.join(sorted([repr(v) for v in self.vertices]))\n" \ 182 | "\tA = ', '.join(['({}, {})'.format(u.name, v.name) for (u, v) in self.arcs]) if len(self.arcs) > 0 else ' -'\n" \ 183 | "\tC = '\\n'.join(sorted([repr(v) for v in self.conditionals])) if len(self.conditionals) > 0 else ' -'\n" \ 184 | "\tD = '\\n'.join([repr(u) for u in self.data]) if len(self.data) > 0 else ' -'\n" \ 185 | "\tgraph = 'Vertices V:\\n{V}\\nArcs A:\\n {A}\\n\\nConditions C:\\n{C}\\n\\nData D:\\n{D}\\n'.format(V=V, A=A, C=C, D=D)\n" \ 186 | "\tgraph = '#Vertices: {}, #Arcs: {}\\n'.format(len(self.vertices), len(self.arcs)) + graph\n" \ 187 | "\treturn graph\n" 188 | return s 189 | 190 | def get_vertices(self): 191 | return "return self.vertices" 192 | 193 | def get_vertices_names(self): 194 | return "return [v.name for v in self.vertices]" 195 | 196 | def get_arcs(self): 197 | return "return self.arcs" 198 | 199 | def get_arcs_names(self): 200 | return "return [(u.name, v.name) for (u, v) in self.arcs]" 201 | 202 | def get_conditions(self): 203 | return "return self.conditionals" 204 | 205 | def gen_cond_vars(self): 206 | return "return [c.name for c in self.conditionals]" 207 | 208 | def gen_if_vars(self): 209 | return "return [v.name for v in self.vertices if v.is_conditional and v.is_sampled and v.is_continuous]" 210 | 211 | def gen_cont_vars(self): 212 | return "return [v.name for v in self.vertices if v.is_continuous and not v.is_conditional and v.is_sampled]" 213 | 214 | def 
gen_disc_vars(self): 215 | return "return [v.name for v in self.vertices if v.is_discrete and v.is_sampled]" 216 | 217 | def get_vars(self): 218 | return "return [v.name for v in self.vertices if v.is_sampled]" 219 | 220 | def is_torch_imported(self): 221 | return "import sys \nprint('torch' in sys.modules) \nprint(torch.__version__) \nprint(type(torch.tensor)) \nimport inspect \nprint(inspect.getfile(torch))" 222 | 223 | def _gen_code(self, buffer: list, code_for_vertex, *, want_data_node: bool=True, flags=None): 224 | distribution = None 225 | state = self.state_object 226 | if self.bit_vector_name is not None: 227 | if state is not None: 228 | buffer.append("{}['{}'] = 0".format(state, self.bit_vector_name)) 229 | else: 230 | buffer.append("{} = 0".format(self.bit_vector_name)) 231 | for node in self.nodes: 232 | name = node.name 233 | if state is not None: 234 | name = "{}['{}']".format(state, name) 235 | if isinstance(node, Vertex): 236 | if flags is not None: 237 | code = "dst_ = {}".format(node.get_code(**flags)) 238 | else: 239 | code = "dst_ = {}".format(node.get_code()) 240 | if code != distribution: 241 | buffer.append(code) 242 | distribution = code 243 | code = code_for_vertex(name, node) 244 | if type(code) is str: 245 | buffer.append(code) 246 | elif type(code) is list: 247 | buffer += code 248 | 249 | elif isinstance(node, ConditionNode) and self.bit_vector_name is not None: 250 | bit_vector = "{}['{}']".format(state, self.bit_vector_name) if state is not None else self.bit_vector_name 251 | code = "_c = {}\n{} = _c".format(node.get_code(), name) 252 | buffer.append(code) 253 | buffer.append("{} |= {} if _c else 0".format(bit_vector, node.bit_index)) 254 | 255 | elif want_data_node or not isinstance(node, DataNode): 256 | code = "{} = {}".format(name, node.get_code()) 257 | buffer.append(code) 258 | 259 | def gen_log_prob(self): 260 | def code_for_vertex(name: str, node: Vertex): 261 | cond_code = node.get_cond_code(state_object=self.state_object) 
262 | if cond_code is not None: 263 | result = cond_code + "\tlog_prob = log_prob + dst_.log_prob({})".format(name) 264 | else: 265 | result = "log_prob = log_prob + dst_.log_prob({})".format(name) 266 | if self.logpdf_suffix is not None: 267 | result = result + self.logpdf_suffix 268 | return result 269 | 270 | logpdf_code = ["log_prob = 0"] 271 | self._gen_code(logpdf_code, code_for_vertex=code_for_vertex, want_data_node=False) 272 | logpdf_code.append("return log_prob") 273 | logpdf_code.insert(0, "try:") 274 | # return 'state', '\n'.join(logpdf_code) 275 | code = ['\n\t'.join(logpdf_code), "\nexcept(ValueError, RuntimeError) as e:\n\tprint('****Warning: Target density is ill-defined****')"] 276 | return 'state', ''.join(code) 277 | 278 | 279 | # def gen_log_prob_transformed(self): 280 | # def code_for_vertex(name: str, node: Vertex): 281 | # cond_code = node.get_cond_code(state_object=self.state_object) 282 | # if cond_code is not None: 283 | # result = cond_code + "log_prob = log_prob + dst_.log_prob({})".format(name) 284 | # else: 285 | # result = "log_prob = log_prob + dst_.log_prob({})".format(name) 286 | # if self.logpdf_suffix is not None: 287 | # result += self.logpdf_suffix 288 | # return result 289 | # # Note to self : To change suffix for torch or numpy look at line 87-88 in compiled imports (above) 290 | # logpdf_code = ["log_prob = 0"] 291 | # self._gen_code(logpdf_code, code_for_vertex=code_for_vertex, want_data_node=False, flags={'transformed': True}) 292 | # logpdf_code.append("return log_prob.sum()") 293 | # return 'state', '\n'.join(logpdf_code) 294 | 295 | def gen_prior_samples(self): 296 | 297 | def code_for_vertex(name: str, node: Vertex): 298 | if node.has_observation: 299 | return "{} = {}".format(name, node.observation) 300 | sample_size = node.sample_size 301 | if sample_size is not None and sample_size > 1: 302 | return "{} = dst_.sample(sample_size={})".format(name, sample_size) 303 | else: 304 | return "{} = 
dst_.sample()".format(name) 305 | 306 | state = self.state_object 307 | sample_code = [] 308 | if state is not None: 309 | sample_code.append(state + " = {}") 310 | self._gen_code(sample_code, code_for_vertex=code_for_vertex, want_data_node=True) 311 | if state is not None: 312 | sample_code.append("return " + state) 313 | return '\n'.join(sample_code) 314 | 315 | def gen_cond_bit_vector(self): 316 | code = "result = 0\n" \ 317 | "for cond in self.conditionals:\n" \ 318 | "\tresult = cond.update_bit_vector(state, result)\n" \ 319 | "return result" 320 | return 'state', code 321 | 322 | -------------------------------------------------------------------------------- /pyppl/backend/ppl_graph_factory.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 12. Mar 2018, Tobias Kohn 7 | # 11. May 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..graphs import * 11 | from .ppl_code_generator import CodeGenerator 12 | from .ppl_graph_codegen import GraphCodeGenerator 13 | from .. 
import distributions 14 | import warnings 15 | 16 | 17 | class _ConditionCollector(Visitor): 18 | 19 | __visit_children_first__ = True 20 | 21 | def __init__(self): 22 | super().__init__() 23 | self.cond_nodes = set() 24 | 25 | def visit_symbol(self, node: AstSymbol): 26 | if isinstance(node.node, ConditionNode): 27 | self.cond_nodes.add(node.node) 28 | return self.visit_node(node) 29 | 30 | 31 | class GraphFactory(object): 32 | 33 | def __init__(self, code_generator=None): 34 | if code_generator is None: 35 | code_generator = CodeGenerator() 36 | code_generator.state_object = 'state' 37 | self._counter = 30000 38 | self.nodes = [] 39 | self.code_generator = code_generator 40 | self.cond_nodes_map = {} 41 | self.data_nodes_cache = {} 42 | 43 | def _generate_code_for_node(self, node: AstNode): 44 | return self.code_generator.visit(node) 45 | 46 | def generate_symbol(self, prefix: str): 47 | self._counter += 1 48 | return prefix + str(self._counter) 49 | 50 | def create_node(self, parents: set): 51 | assert type(parents) is set 52 | return None 53 | 54 | def create_condition_node(self, test: AstNode, parents: set): 55 | name = self.generate_symbol('cond_') 56 | code = self._generate_code_for_node(test) 57 | if code in self.cond_nodes_map: 58 | return self.cond_nodes_map[code] 59 | if isinstance(test, AstCompare) and is_zero(test.right) and test.second_right is None: 60 | result = ConditionNode(name, ancestors=parents, condition=code, 61 | function=self._generate_code_for_node(test.left), op=test.op) 62 | elif isinstance(test, AstCall) and test.function_name.startswith('torch.') and is_number(test.right): 63 | result = ConditionNode(name, ancestors=parents, condition=code, 64 | function=self._generate_code_for_node(test.left), op=test.function_name, 65 | compare_value=test.right.value) 66 | else: 67 | result = ConditionNode(name, ancestors=parents, condition=code) 68 | self.nodes.append(result) 69 | self.cond_nodes_map[code] = result 70 | return result 71 | 72 | def 
create_data_node(self, data: AstNode, parents: Optional[set]=None): 73 | if parents is None: 74 | parents = set() 75 | code = self._generate_code_for_node(data) 76 | if code in self.data_nodes_cache: 77 | return self.data_nodes_cache[code] 78 | name = self.generate_symbol('data_') 79 | result = DataNode(name, ancestors=parents, data=code) 80 | self.nodes.append(result) 81 | self.data_nodes_cache[code] = result 82 | return result 83 | 84 | def create_observe_node(self, dist: AstNode, value: AstNode, parents: set, conditions: set): 85 | arg_names = None 86 | if isinstance(dist, AstCall): 87 | func = dist.function_name 88 | args = [self._generate_code_for_node(arg) for arg in dist.args] 89 | # args = dist.add_keywords_to_args(args) 90 | trans = dist.get_keyword_arg_value("transform") 91 | distr = distributions.get_distribution_for_name(func) 92 | if distr is not None: 93 | if 0 < dist.pos_arg_count <= len(distr.params): 94 | arg_names = distr.params[:dist.pos_arg_count] + dist.keywords 95 | else: 96 | func = None 97 | args = None 98 | trans = None 99 | name = self.generate_symbol('y') 100 | d_code = self._generate_code_for_node(dist) 101 | v_code = self._generate_code_for_node(value) 102 | obs_value = value.value if is_value(value) else None 103 | cc = _ConditionCollector() 104 | cc.visit(dist) 105 | result = Vertex(name, ancestors=parents, distribution_code=d_code, distribution_name=_get_dist_name(dist), 106 | distribution_args=args, distribution_func=func, 107 | distribution_transform=trans, distribution_arg_names=arg_names, 108 | observation=v_code, 109 | observation_value=obs_value, conditions=conditions, 110 | condition_nodes=cc.cond_nodes if len(cc.cond_nodes) > 0 else None) 111 | self.nodes.append(result) 112 | return result 113 | 114 | def create_sample_node(self, dist: AstNode, size: int, parents: set, original_name: Optional[str]=None): 115 | arg_names = None 116 | if isinstance(dist, AstCall): 117 | func = dist.function_name 118 | args = 
[self._generate_code_for_node(arg) for arg in dist.args] 119 | # args = dist.add_keywords_to_args(args) 120 | trans = dist.get_keyword_arg_value("transform") 121 | distr = distributions.get_distribution_for_name(func) 122 | if distr is not None: 123 | if 0 < dist.pos_arg_count <= len(distr.params): 124 | arg_names = distr.params[:dist.pos_arg_count] + dist.keywords 125 | else: 126 | func = None 127 | args = None 128 | trans = None 129 | name = self.generate_symbol('x') 130 | code = self._generate_code_for_node(dist) 131 | 132 | # stop the use of factor in sample statements 133 | _is_factor = _get_dist_name(dist) 134 | if _is_factor.__contains__('factor'): 135 | import sys 136 | import warnings 137 | warnings.warn('{0} Model is not valid {0}'.format(10*'*'), stacklevel=5) 138 | warnings.warn('{0} factor statements cannot be placed within sample statements {0}'.format(10*'*'), stacklevel=5) 139 | sys.exit(1) 140 | 141 | 142 | result = Vertex(name, ancestors=parents, distribution_code=code, distribution_name=_get_dist_name(dist), 143 | distribution_args=args, distribution_func=func, distribution_transform=trans, 144 | distribution_arg_names=arg_names, 145 | sample_size=size, original_name=original_name) 146 | self.nodes.append(result) 147 | return result 148 | 149 | def generate_code(self, *, class_name: Optional[str] = None, imports: Optional[str]=None, 150 | base_class: Optional[str]=None): 151 | code_gen = GraphCodeGenerator(self.nodes, self.code_generator.state_object, 152 | imports=imports if imports is not None else '') 153 | return code_gen.generate_model_code(class_name=class_name, base_class=base_class) 154 | 155 | 156 | def _get_dist_name(dist: AstNode): 157 | if isinstance(dist, AstCall): 158 | result = dist.function_name 159 | if result.startswith('dist.'): 160 | result = result[5:] 161 | if result == 'factor_cont' or result =='factor_disc': 162 | warnings.warn('{0} compiler cannot guarantee that the function is analytic, as factor is being called 
{0}\n'.format(10*'*'), stacklevel=2) 163 | return result 164 | elif isinstance(dist, AstSubscript): 165 | if isinstance(dist.base, AstVector): 166 | names = set([_get_dist_name(x) for x in dist.base.items]) 167 | if len(names) == 1: 168 | return tuple(names)[0] 169 | 170 | raise Exception("Not a valid distribution: '{}'".format(repr(dist))) 171 | -------------------------------------------------------------------------------- /pyppl/backend/ppl_graph_generator.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 12. Mar 2018, Tobias Kohn 7 | # 23. Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..graphs import * 11 | from .ppl_graph_factory import GraphFactory 12 | 13 | 14 | class ConditionScope(object): 15 | 16 | def __init__(self, prev, condition): 17 | self.prev = prev 18 | self.condition = condition 19 | self.truth_value = True 20 | 21 | def switch_branch(self): 22 | self.truth_value = not self.truth_value 23 | 24 | def get_condition(self): 25 | return (self.condition, self.truth_value) 26 | 27 | 28 | class ConditionScopeContext(object): 29 | 30 | def __init__(self, visitor): 31 | self.visitor = visitor 32 | 33 | def __enter__(self): 34 | return self.visitor.conditions 35 | 36 | def __exit__(self, exc_type, exc_val, exc_tb): 37 | self.visitor.leave_condition() 38 | 39 | 40 | class GraphGenerator(ScopedVisitor): 41 | 42 | def __init__(self, factory: Optional[GraphFactory]=None): 43 | super().__init__() 44 | if factory is None: 45 | factory = GraphFactory() 46 | self.factory = factory 47 | self.nodes = [] 48 | self.conditions = None # type: ConditionScope 49 | self.imports = set() 50 | 51 | def enter_condition(self, condition): 52 | self.conditions = ConditionScope(self.conditions, condition) 53 | 54 | def leave_condition(self): 55 | 
self.conditions = self.conditions.prev 56 | 57 | def switch_condition(self): 58 | self.conditions.switch_branch() 59 | 60 | def create_condition(self, condition): 61 | self.enter_condition(condition) 62 | return ConditionScopeContext(self) 63 | 64 | def get_current_conditions(self): 65 | result = [] 66 | c = self.conditions 67 | while c is not None: 68 | result.append(c.get_condition()) 69 | c = c.prev 70 | return set(result) 71 | 72 | def _visit_dict(self, items): 73 | result = {} 74 | parents = set() 75 | for key in items.keys(): 76 | item, parent = self.visit(items[key]) 77 | result[key] = item 78 | parents = set.union(parents, parent) 79 | return result, parents 80 | 81 | def _visit_items(self, items): 82 | result = [] 83 | parents = set() 84 | for _item in (self.visit(item) for item in items): 85 | if _item is not None: 86 | item, parent = _item 87 | result.append(item) 88 | parents = set.union(parents, parent) 89 | else: 90 | result.append(None) 91 | return result, parents 92 | 93 | def visit_node(self, node: AstNode): 94 | raise RuntimeError("cannot compile '{}'".format(node)) 95 | 96 | def visit_attribute(self, node:AstAttribute): 97 | base, parents = self.visit(node.base) 98 | if base is node.base: 99 | return node, parents 100 | else: 101 | return AstAttribute(base, node.attr), parents 102 | 103 | def visit_binary(self, node:AstBinary): 104 | left, l_parents = self.visit(node.left) 105 | right, r_parents = self.visit(node.right) 106 | return AstBinary(left, node.op, right), set.union(l_parents, r_parents) 107 | 108 | def visit_body(self, node:AstBody): 109 | items, parents = self._visit_items(node.items) 110 | return makeBody(items), parents 111 | 112 | def visit_call(self, node: AstCall): 113 | function, f_parents = self.visit(node.function) 114 | args, a_parents = self._visit_items(node.args) 115 | parents = set.union(f_parents, a_parents) 116 | return AstCall(function, args, node.keywords), parents 117 | 118 | def visit_call_torch_function(self, node: 
AstCall): 119 | name = node.function_name 120 | if name.startswith("torch.") and node.arg_count == 1 and isinstance(node.args[0], AstValueVector): 121 | name = name[6:] 122 | if name in ('tensor', 'Tensor', 'FloatTensor', 'IntTensor', 'DoubleTensor', 'HalfTensor', 123 | 'ByteTensor', 'ShortTensor', 'LongTensor'): 124 | node = self.factory.create_data_node(node) 125 | if node is not None: 126 | self.nodes.append(node) 127 | return AstSymbol(node.name, node=node), set() 128 | 129 | elif name.startswith('torch.') and name[6:] in ('eq', 'ge', 'gt', 'le', 'lt', 'ne') and node.arg_count == 2: 130 | left, l_parents = self.visit(node.left) 131 | right, r_parents = self.visit(node.right) 132 | parents = set.union(l_parents, r_parents) 133 | cond_node = self.factory.create_condition_node(node.clone(args=[left, right]), parents) 134 | if cond_node is not None: 135 | self.nodes.append(cond_node) 136 | name = cond_node.name 137 | return AstSymbol(name, node=cond_node), parents 138 | 139 | return self.visit_call(node) 140 | 141 | def visit_compare(self, node: AstCompare): 142 | left, l_parents = self.visit(node.left) 143 | right, r_parents = self.visit(node.right) 144 | if node.second_right is not None: 145 | second_right, sc_parents = self.visit(node.second_right) 146 | parents = set.union(l_parents, r_parents) 147 | parents = set.union(parents, sc_parents) 148 | return AstCompare(left, node.op, right, node.second_op, second_right), parents 149 | else: 150 | return AstCompare(left, node.op, right), set.union(l_parents, r_parents) 151 | 152 | def visit_def(self, node: AstDef): 153 | self.define(node.name, self.visit(node.value)) 154 | return AstValue(None), set() 155 | 156 | def visit_dict(self, node: AstDict): 157 | items, parents = self._visit_dict(node.items) 158 | return AstDict(items), parents 159 | 160 | def visit_for(self, node: AstFor): 161 | source, s_parents = self.visit(node.source) 162 | body, b_parents = self.visit(node.body) 163 | parents = set.union(s_parents, 
b_parents) 164 | return AstFor(node.target, source, body), parents 165 | 166 | def visit_if(self, node: AstIf): 167 | test, parents = self.visit(node.test) 168 | cond_node = self.factory.create_condition_node(test, parents) 169 | if cond_node is not None: 170 | self.nodes.append(cond_node) 171 | name = cond_node.name 172 | test = AstSymbol(name, node=cond_node) 173 | 174 | with self.create_condition(cond_node): 175 | a_node, a_parents = self.visit(node.if_node) 176 | parents = set.union(parents, a_parents) 177 | self.switch_condition() 178 | b_node, b_parents = self.visit(node.else_node) 179 | parents = set.union(parents, b_parents) 180 | 181 | return AstIf(test, a_node, b_node), parents 182 | 183 | def visit_import(self, node: AstImport): 184 | self.imports.add(node.module_name) 185 | return AstValue(None), set() 186 | 187 | def visit_let(self, node: AstLet): 188 | self.define(node.target, self.visit(node.source)) 189 | return self.visit(node.body) 190 | 191 | def visit_list_for(self, node: AstListFor): 192 | source, s_parents = self.visit(node.source) 193 | expr, e_parents = self.visit(node.expr) 194 | parents = set.union(s_parents, e_parents) 195 | if node.test is not None: 196 | test, t_parents = self.visit(node.test) 197 | parents = set.union(parents, t_parents) 198 | else: 199 | test = None 200 | return AstListFor(node.target, source, expr, test), parents 201 | 202 | def visit_multi_slice(self, node: AstMultiSlice): 203 | items, parents = self._visit_items(node.indices) 204 | result = node.clone(indices=items) 205 | return result, parents 206 | 207 | def visit_observe(self, node: AstObserve): 208 | dist, d_parents = self.visit(node.dist) 209 | value, v_parents = self.visit(node.value) 210 | parents = set.union(d_parents, v_parents) 211 | node = self.factory.create_observe_node(dist, value, parents, self.get_current_conditions()) 212 | self.nodes.append(node) 213 | return AstSymbol(node.name, node=node), set() 214 | 215 | def visit_sample(self, node: 
AstSample): 216 | dist, d_parents = self.visit(node.dist) 217 | if node.size is not None: 218 | size, s_parents = self.visit(node.size) 219 | parents = set.union(d_parents, s_parents) 220 | if isinstance(size, AstValue): 221 | size = size.value 222 | else: 223 | raise RuntimeError("sample size must be a constant integer value instead of '{}'".format(size)) 224 | else: 225 | size = 1 226 | parents = d_parents 227 | node = self.factory.create_sample_node(dist, size, parents, original_name=getattr(node, 'original_name', None)) 228 | self.nodes.append(node) 229 | return AstSymbol(node.name, node=node), { node } 230 | 231 | def visit_slice(self, node: AstSlice): 232 | base, parents = self.visit(node.base) 233 | if node.start is not None: 234 | start, a_parents = self.visit(node.start) 235 | parents = set.union(parents, a_parents) 236 | else: 237 | start = None 238 | if node.stop is not None: 239 | stop, a_parents = self.visit(node.stop) 240 | parents = set.union(parents, a_parents) 241 | else: 242 | stop = None 243 | return AstSlice(base, start, stop), parents 244 | 245 | def visit_subscript(self, node: AstSubscript): 246 | base, b_parents = self.visit(node.base) 247 | index, i_parents = self.visit(node.index) 248 | if is_vector(base) and is_integer(index): 249 | return self.visit(base[index.value]) 250 | return makeSubscript(base, index), set.union(b_parents, i_parents) 251 | 252 | def visit_symbol(self, node: AstSymbol): 253 | item = self.resolve(node.name) 254 | if item is not None: 255 | return item 256 | elif node.node is not None: 257 | return node, { node.node } 258 | elif node.predef: 259 | return node, set() 260 | else: 261 | line = " [line {}]".format(node.lineno) if hasattr(node, 'lineno') else '' 262 | raise RuntimeError("symbol not found: '{}'{}".format(node.original_name, line)) 263 | 264 | def visit_unary(self, node: AstUnary): 265 | item, parents = self.visit(node.item) 266 | return AstUnary(node.op, item), parents 267 | 268 | def visit_value(self, node: 
AstValue): 269 | return node, set() 270 | 271 | def visit_value_vector(self, node: AstValueVector): 272 | if len(node) > 3: 273 | node = self.factory.create_data_node(node) 274 | if node is not None: 275 | self.nodes.append(node) 276 | return AstSymbol(node.name, node=node), set() 277 | return node, set() 278 | 279 | def visit_vector(self, node: AstVector): 280 | items, parents = self._visit_items(node.items) 281 | result = makeVector(items) 282 | return result, parents 283 | 284 | def generate_code(self, imports: Optional[str]=None, *, 285 | base_class: Optional[str]=None, 286 | class_name: Optional[str]=None): 287 | if len(self.imports) > 0: 288 | _imports = '\n'.join(['import {}'.format(item) for item in self.imports]) 289 | if imports is not None: 290 | _imports += '\n' + imports 291 | elif imports is not None: 292 | _imports = imports 293 | else: 294 | _imports = '' 295 | return self.factory.generate_code(class_name=class_name, imports=_imports, 296 | base_class=base_class) 297 | 298 | def generate_model(self, imports: Optional[str]=None, base_class: Optional[str]=None, class_name: str='Model'): 299 | vertices = set() 300 | arcs = set() 301 | data = set() 302 | conditionals = set() 303 | for node in self.nodes: 304 | if isinstance(node, Vertex): 305 | vertices.add(node) 306 | for a in node.ancestors: 307 | arcs.add((a, node)) 308 | elif isinstance(node, DataNode): 309 | data.add(node) 310 | elif isinstance(node, ConditionNode): 311 | conditionals.add(node) 312 | 313 | code = self.generate_code(imports=imports, base_class=base_class, class_name=class_name) 314 | c_globals = {} 315 | exec(code, c_globals) 316 | Model = c_globals[class_name] 317 | result = Model(vertices, arcs, data, conditionals) 318 | result.code = code 319 | return result 320 | -------------------------------------------------------------------------------- /pyppl/distributions.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of 
PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 17. Jan 2018, Tobias Kohn 7 | # 28. Mar 2018, Tobias Kohn 8 | # 9 | from enum import * 10 | 11 | ######################################################################### 12 | 13 | class DistributionType(Enum): 14 | 15 | CONTINUOUS = "continuous" 16 | DISCRETE = "discrete" 17 | UNDEFINED = "undefined" 18 | 19 | 20 | ######################################################################### 21 | 22 | class Distribution(object): 23 | 24 | def __init__(self, name:str, distributions_type:DistributionType=None, params:list=None, *, 25 | vector_sample:bool=False, 26 | foppl_name:str=None, 27 | python_name:str=None): 28 | assert type(name) is str 29 | self.name = name 30 | self.foppl_name = name.lower() if foppl_name is None else foppl_name 31 | self.python_name = name if python_name is None else python_name 32 | if distributions_type is None: 33 | self.distribution_type = DistributionType.CONTINUOUS 34 | else: 35 | self.distribution_type = distributions_type 36 | if params is None: 37 | self.params = [] 38 | else: 39 | self.params = params 40 | self._vector_sample = vector_sample 41 | 42 | @property 43 | def is_continuous(self): 44 | return self.distribution_type == DistributionType.CONTINUOUS 45 | 46 | @property 47 | def is_discrete(self): 48 | return self.distribution_type == DistributionType.DISCRETE 49 | 50 | @property 51 | def is_undefined(self): 52 | return self.distribution_type == DistributionType.UNDEFINED 53 | 54 | @property 55 | def parameter_count(self): 56 | return len(self.params) 57 | 58 | 59 | 60 | ######################################################################### 61 | distributions = { 62 | Distribution('Bernoulli', DistributionType.DISCRETE, ['probs']), 63 | Distribution('Beta', DistributionType.CONTINUOUS, ['alpha', 'beta']), 64 | Distribution('Binomial', DistributionType.DISCRETE, ['total_count', 
'probs']), 65 | Distribution('Categorical', DistributionType.DISCRETE, ['probs']), 66 | Distribution('Cauchy', DistributionType.CONTINUOUS, ['mu', 'gamma']), 67 | Distribution('Dirichlet', DistributionType.CONTINUOUS, ['alpha'], vector_sample=True), 68 | Distribution('Discrete', DistributionType.DISCRETE, None), 69 | Distribution('Exponential', DistributionType.CONTINUOUS, ['rate']), 70 | Distribution('Gamma', DistributionType.CONTINUOUS, ['alpha', 'beta']), 71 | Distribution('HalfCauchy', DistributionType.CONTINUOUS, ['mu', 'gamma'], foppl_name='half_cauchy'), 72 | Distribution('LogGamma', DistributionType.CONTINUOUS, ['alpha', 'beta']), 73 | Distribution('LogNormal', DistributionType.CONTINUOUS, ['mu', 'sigma'], foppl_name='log_normal'), 74 | Distribution('Multinomial', DistributionType.DISCRETE, ['total_count', 'probs', 'n']), 75 | Distribution('MultivariateNormal', 76 | DistributionType.CONTINUOUS, ['mean', 'covariance_matrix'], vector_sample=True, foppl_name='mvn'), 77 | Distribution('Normal', DistributionType.CONTINUOUS, ['loc', 'scale']), 78 | Distribution('Poisson', DistributionType.DISCRETE, ['rate']), 79 | Distribution('Uniform', DistributionType.CONTINUOUS, ['low', 'high']), 80 | Distribution('Exp', DistributionType.CONTINUOUS, ['values'], foppl_name='Exp'), 81 | Distribution('Log', DistributionType.CONTINUOUS, ['values'], foppl_name='Log'), 82 | Distribution('Sin', DistributionType.CONTINUOUS, ['theta'], foppl_name='Sin'), 83 | Distribution('Cos', DistributionType.CONTINUOUS, ['theta'], foppl_name='Cos'), 84 | Distribution('Poly', DistributionType.CONTINUOUS, ['coeff', 'order'], foppl_name='Poly'), 85 | Distribution('factor', DistributionType.UNDEFINED, ['log_p'] ) 86 | 87 | } 88 | 89 | namespace = { 90 | d.foppl_name: 'dist.' 
+ d.python_name for d in distributions 91 | } 92 | 93 | def get_distribution_for_name(name: str) -> Distribution: 94 | if name.startswith("dist."): 95 | return get_distribution_for_name(name[5:]) 96 | for dist in distributions: 97 | if dist.name == name or dist.python_name == name or dist.foppl_name == name: 98 | return dist 99 | return None 100 | -------------------------------------------------------------------------------- /pyppl/fe_clojure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/fe_clojure/__init__.py -------------------------------------------------------------------------------- /pyppl/fe_clojure/ppl_clojure_forms.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 20. Feb 2018, Tobias Kohn 7 | # 28. Feb 2018, Tobias Kohn 8 | # 9 | from typing import Optional 10 | import inspect 11 | from ast import copy_location 12 | 13 | class ClojureObject(object): 14 | 15 | _attributes = {'col_offset', 'lineno'} 16 | tag = None 17 | 18 | def visit(self, visitor): 19 | """ 20 | The visitor-object given as argument must provide at least one `visit_XXX`-method to be called by this method. 21 | If the visitor does not provide any specific `visit_XXX`-method to be called, the method will try and call 22 | `visit_node` or `generic_visit`, respectively. 23 | 24 | :param visitor: An object with a `visit_XXX`-method. 25 | :return: The result returned by the `visit_XXX`-method of the visitor. 
26 | """ 27 | name = self.__class__.__name__.lower() 28 | method_names = ['visit_' + name + '_form', 'visit_node', 'generic_visit'] 29 | methods = [getattr(visitor, name, None) for name in method_names] 30 | methods = [name for name in methods if name is not None] 31 | if len(methods) == 0 and callable(visitor): 32 | return visitor(self) 33 | elif len(methods) > 0: 34 | result = methods[0](self) 35 | if hasattr(result, '_attributes'): 36 | result = copy_location(result, self) 37 | return result 38 | else: 39 | raise RuntimeError("visitor '{}' has no visit-methods to call".format(type(visitor))) 40 | 41 | 42 | ####################################################################################################################### 43 | 44 | class Form(ClojureObject): 45 | 46 | def __init__(self, items:list, lineno:Optional[int]=None): 47 | self.items = items 48 | if lineno is not None: 49 | self.lineno = lineno 50 | self._special_names = { 51 | '->': 'arrow', 52 | '->>': 'double_arrow', 53 | '.': 'dot' 54 | } 55 | assert type(items) in [list, tuple] 56 | assert all([isinstance(item, ClojureObject) for item in items]) 57 | assert lineno is None or type(lineno) is int 58 | 59 | def __getitem__(self, item): 60 | return self.items[item] 61 | 62 | def __len__(self): 63 | return len(self.items) 64 | 65 | def __repr__(self): 66 | return "({})".format(' '.join([repr(item) for item in self.items])) 67 | 68 | def visit(self, visitor): 69 | name = self.name 70 | if name is not None: 71 | if name in self._special_names: 72 | name = '_sym_' + self._special_names[name] 73 | if name.endswith('?'): 74 | name = 'is_' + name[:-1] 75 | name = name.replace('-', '_').replace('.', '_').replace('/', '_') 76 | name = ''.join([n if n.islower() else "_" + n.lower() for n in name]) 77 | method = getattr(visitor, 'visit_' + name, None) 78 | if method is not None: 79 | arg_count = len(self.items) - 1 80 | has_varargs = inspect.getfullargspec(method).varargs is not None 81 | param_count = 
len(inspect.getfullargspec(method).args) - 1 82 | has_correct_arg_count = arg_count >= param_count if has_varargs else arg_count == param_count 83 | if not has_correct_arg_count: 84 | s = "at least" if has_varargs else "exactly" 85 | if param_count == 0: 86 | t = "no arguments" 87 | elif param_count == 1: 88 | t = "one argument" 89 | elif param_count == 2: 90 | t = "two arguments" 91 | elif param_count == 3: 92 | t = "three arguments" 93 | else: 94 | t = "{} arguments".format(param_count) 95 | pos = "(line {})".format(self.lineno) if self.lineno is not None else '' 96 | raise TypeError("{}() takes {} {} ({} given) {}".format(self.name, s, t, arg_count, pos)) 97 | 98 | result = method(*self.items[1:]) 99 | if hasattr(result, '_attributes'): 100 | result = copy_location(result, self) 101 | return result 102 | 103 | return super(Form, self).visit(visitor) 104 | 105 | @property 106 | def head(self): 107 | return self.items[0] 108 | 109 | @property 110 | def tail(self): 111 | return Form(self.items[1:]) 112 | 113 | @property 114 | def last(self): 115 | return self.items[-1] 116 | 117 | @property 118 | def name(self): 119 | if len(self.items) > 0 and isinstance(self.items[0], Symbol): 120 | return self.items[0].name 121 | else: 122 | return None 123 | 124 | @property 125 | def is_empty(self): 126 | return len(self.items) == 0 127 | 128 | @property 129 | def non_empty(self): 130 | return len(self.items) > 0 131 | 132 | @property 133 | def length(self): 134 | return len(self.items) 135 | 136 | 137 | class Map(ClojureObject): 138 | 139 | def __init__(self, items:list, lineno:Optional[int]=None): 140 | self.items = items 141 | assert type(items) is list 142 | assert all([isinstance(item, ClojureObject) for item in items]) 143 | assert len(self.items) % 2 == 0 144 | assert lineno is None or type(lineno) is int 145 | 146 | def __repr__(self): 147 | return "{" + ' '.join([repr(item) for item in self.items]) + "}" 148 | 149 | 150 | 151 | class Symbol(ClojureObject): 152 | 153 | 
def __init__(self, name:str, lineno:Optional[int]=None): 154 | self.name = name 155 | if lineno is not None: 156 | self.lineno = lineno 157 | assert type(name) is str 158 | assert lineno is None or type(lineno) is int 159 | 160 | def __repr__(self): 161 | return self.name 162 | 163 | 164 | class Value(ClojureObject): 165 | 166 | def __init__(self, value, lineno:Optional[int]=None): 167 | self.value = value 168 | if lineno is not None: 169 | self.lineno = lineno 170 | assert value is None or type(value) in [bool, complex, float, int, str] 171 | assert lineno is None or type(lineno) is int 172 | 173 | def __repr__(self): 174 | return repr(self.value) 175 | 176 | 177 | class Vector(ClojureObject): 178 | 179 | def __init__(self, items:list, lineno:Optional[int]=None): 180 | self.items = items 181 | if lineno is not None: 182 | self.lineno = lineno 183 | assert type(items) in [list, tuple] 184 | assert all([isinstance(item, ClojureObject) for item in items]) 185 | assert lineno is None or type(lineno) is int 186 | 187 | def __getitem__(self, item): 188 | return self.items[item] 189 | 190 | def __len__(self): 191 | return len(self.items) 192 | 193 | def __repr__(self): 194 | return "[{}]".format(' '.join([repr(item) for item in self.items])) 195 | 196 | @property 197 | def is_empty(self): 198 | return len(self.items) == 0 199 | 200 | @property 201 | def non_empty(self): 202 | return len(self.items) > 0 203 | 204 | @property 205 | def length(self): 206 | return len(self.items) 207 | 208 | ####################################################################################################################### 209 | 210 | class Visitor(object): 211 | 212 | def visit(self, ast): 213 | if isinstance(ast, ClojureObject): 214 | return ast.visit(self) 215 | elif hasattr(ast, '__iter__'): 216 | return [self.visit(item) for item in ast] 217 | else: 218 | raise TypeError("cannot walk/visit an object of type '{}'".format(type(ast))) 219 | 220 | def visit_node(self, 
node:ClojureObject): 221 | return node 222 | 223 | 224 | class LeafVisitor(Visitor): 225 | 226 | def visit_symbol(self, node:Symbol): 227 | self.visit_node(node) 228 | 229 | def visit_value(self, node:Value): 230 | self.visit_node(node) 231 | 232 | def visit_form_form(self, node:Form): 233 | for n in node.items: 234 | n.visit(self) 235 | 236 | def visit_map_form(self, node:Map): 237 | for n in node.items: 238 | n.visit(self) 239 | 240 | def visit_symbol_form(self, node:Symbol): 241 | self.visit_symbol(node) 242 | 243 | def visit_value_form(self, node:Value): 244 | self.visit_value(node) 245 | 246 | def visit_vector_form(self, node:Vector): 247 | for n in node.items: 248 | n.visit(self) 249 | 250 | 251 | ####################################################################################################################### 252 | 253 | def is_form(form): 254 | return isinstance(form, Form) 255 | 256 | def is_integer(form): 257 | if isinstance(form, Value): 258 | return type(form.value) is int 259 | else: 260 | return False 261 | 262 | def is_map(form): 263 | return isinstance(form, Map) 264 | 265 | def is_numeric(form): 266 | if isinstance(form, Value): 267 | return type(form.value) in [complex, float, int] 268 | else: 269 | return False 270 | 271 | def is_quoted(form): 272 | return isinstance(form, Form) and len(form) == 2 and is_symbol(form.items[0], 'quote') 273 | 274 | def is_string(form): 275 | if isinstance(form, Value): 276 | return type(form.value) is str 277 | else: 278 | return False 279 | 280 | def is_symbol(form, symbol:str=None): 281 | if isinstance(form, Symbol): 282 | return form.name == symbol if symbol is not None else True 283 | else: 284 | return False 285 | 286 | def is_symbol_vector(form): 287 | if isinstance(form, Vector): 288 | return all([isinstance(item, Symbol) for item in form.items]) 289 | else: 290 | return False 291 | 292 | def is_vector(form): 293 | return isinstance(form, Vector) 294 | 
# --------------------------------------------------------------------------------
# /pyppl/fe_clojure/ppl_clojure_lexer.py:
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 20. Feb 2018, Tobias Kohn
# 20. Mar 2018, Tobias Kohn
#
from .. import lexer
from ..fe_clojure import ppl_clojure_forms as clj
from ..lexer import CatCode, TokenType


#######################################################################################################################

class ClojureLexer(object):
    """
    Iterator turning Clojure source text into a sequence of top-level `ClojureObject`s
    (forms, vectors, maps, symbols and values).
    """

    # Reader-macro prefixes that wrap the following form into `(<name> form)`.
    _PREFIX_FORMS = {
        '@': 'deref',
        '\'': 'quote',
        '#\'': 'var',
    }

    def __init__(self, text: str):
        self.text = text
        self.lexer = lexer.Lexer(text)
        self.source = lexer.BufferedIterator(self.lexer)
        self.lexer.catcodes['\n', ','] = CatCode.WHITESPACE
        self.lexer.catcodes['!', '$', '*', '+', '-', '.', '/', '<', '>', '=', '?'] = CatCode.ALPHA
        self.lexer.catcodes[';'] = CatCode.LINE_COMMENT
        self.lexer.catcodes['#', '\'', '`', '~', '^', '@'] = CatCode.SYMBOL
        self.lexer.catcodes['&'] = CatCode.SYMBOL
        self.lexer.catcodes['%', ':'] = CatCode.PREFIX
        self.lexer.add_symbols('~@', '#\'')
        self.lexer.add_string_prefix('#')
        self.lexer.add_constant('false', False)
        self.lexer.add_constant('nil', None)
        self.lexer.add_constant('true', True)

    def __iter__(self):
        return self

    def __next__(self):
        """
        Reads the next complete object from the token stream.

        :raises StopIteration: When the input is exhausted.
        :raises SyntaxError: On mismatched/unclosed brackets or invalid tokens.
        """
        source = self.source
        if not source.has_next:
            raise StopIteration

        pos, token_type, value = source.next()
        lineno = self.lexer.get_line_from_pos(pos)

        if token_type == TokenType.LEFT_BRACKET:
            return self._read_collection(value, lineno)

        elif token_type == TokenType.NUMBER:
            return clj.Value(value, lineno=lineno)

        elif token_type == TokenType.STRING:
            # NOTE(review): `eval` converts the string token into a Python string. If Clojure sources may be
            # untrusted, consider `ast.literal_eval` -- confirm the token format produced by `lexer.Lexer` first.
            return clj.Value(eval(value), lineno=lineno)

        elif token_type == TokenType.VALUE:
            return clj.Value(value, lineno=lineno)

        elif token_type == TokenType.SYMBOL:
            return self._read_symbol(value, lineno)

        else:
            raise SyntaxError("invalid token: '{}' (line {})".format(token_type, lineno))

    def _read_collection(self, left, lineno):
        # Reads items up to the matching right bracket and builds a Form, Vector or Map depending on `left`.
        source = self.source
        result = []
        while source.has_next and source.peek()[1] != TokenType.RIGHT_BRACKET:
            result.append(self.__next__())

        # Running out of tokens here means the bracket opened at `lineno` was never closed.
        if not source.has_next:
            raise SyntaxError("missing closing bracket for '{}' (line {})".format(left, lineno))

        token = source.next()
        right = token[2]
        if token[1] != TokenType.RIGHT_BRACKET:
            raise SyntaxError("expected right parentheses or bracket instead of '{}' (line {})".format(
                right, self.lexer.get_line_from_pos(token[0])
            ))

        if left == '(' and right == ')':
            return clj.Form(result, lineno=lineno)

        elif left == '[' and right == ']':
            return clj.Vector(result, lineno=lineno)

        elif left == '{' and right == '}':
            if len(result) % 2 != 0:
                raise SyntaxError("map requires an even number of elements ({} given)".format(len(result)))
            return clj.Map(result, lineno=lineno)

        else:
            raise SyntaxError("mismatched parentheses: '{}' and '{}' (line {})".format(
                left, right, lineno
            ))

    def _read_symbol(self, value, lineno):
        # Handles the reader macros `#(...)`, `@x`, `'x` and `#'x`; any other symbol is returned as-is.
        if value == '#':
            form = self.__next__()
            if not isinstance(form, clj.Form):
                raise SyntaxError("'#' requires a form to build a function (line {})".format(lineno))

            # `Form`/`Vector` only accept ClojureObjects, so both `fn` and the parameter names must be Symbols.
            params = [clj.Symbol(p, lineno=lineno) for p in _ParameterExtractor().extract_parameters(form)]
            return clj.Form([clj.Symbol('fn', lineno=lineno), clj.Vector(params, lineno=lineno), form],
                            lineno=lineno)

        if value in self._PREFIX_FORMS:
            form = self.__next__()
            return clj.Form([clj.Symbol(self._PREFIX_FORMS[value], lineno=lineno), form], lineno=lineno)

        return clj.Symbol(value, lineno=lineno)

#######################################################################################################################

class _ParameterExtractor(clj.LeafVisitor):
    """
    Collects the implicit parameters (`%`, `%1` ... `%9`) used inside the body of an anonymous `#(...)` function.
    """

    def __init__(self):
        self.parameters = set()

    def visit_symbol(self, node: clj.Symbol):
        n = node.name
        if n.startswith('%'):
            if len(n) == 1:
                self.parameters.add(n)
            elif len(n) == 2 and '1' <= n[1] <= '9':
                self.parameters.add(n)
            else:
                raise SyntaxError("invalid parameter: '{}'".format(n))

    def extract_parameters(self, node):
        """
        Returns the parameter list of the anonymous function with body `node`.

        :return: `['%']` if only `%` is used, `['%1', ..., '%n']` up to the highest numbered parameter used, or an
                 empty list if the body uses no parameters.
        :raises TypeError: If `%` and numbered parameters are mixed.
        """
        self.parameters = set()
        self.visit(node)
        if '%' in self.parameters:
            if len(self.parameters) == 1:
                return ['%']
            raise TypeError("cannot combine parameters '%' and '%1', '%2', ... in one function")
        if not self.parameters:
            return []
        # Unused lower-numbered parameters still need to occupy their slot, e.g. `%2` alone yields [%1 %2].
        count = max(int(n[1]) for n in self.parameters)
        return ['%' + str(i) for i in range(1, count + 1)]

# --------------------------------------------------------------------------------
# /pyppl/fe_clojure/ppl_clojure_repr.py:
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 27. Feb 2018, Tobias Kohn
# 20. Mar 2018, Tobias Kohn
#
from ..ppl_ast import *

class ClojureRepr(Visitor):
    """
    Renders a PPL AST as Clojure source text.
    """

    def __init__(self):
        super().__init__()
        self.short_names = False # used for debugging

    def visit_indent(self, node, indent=2):
        # Visits `node` and indents every continuation line of the result by `indent`.
        if type(indent) is int:
            indent = ' ' * indent
        result = self.visit(node)
        if result is not None:
            result = result.replace('\n', '\n'+indent)
        return result

    def visit_attribute(self, node: AstAttribute):
        base = self.visit(node.base)
        return "(. {} {})".format(base, node.attr)

    def visit_binary(self, node: AstBinary):
        left = self.visit(node.left)
        right = self.visit(node.right)
        return "({} {} {})".format(node.op, left, right)

    def visit_body(self, node: AstBody):
        items = [self.visit_indent(item) for item in node.items]
        return "(do\n {})".format('\n '.join(items))

    def visit_break(self, node: AstBreak):
        return "(break)"

    def visit_call(self, node: AstCall):
        function = self.visit(node.function)
        args = [self.visit(item) for item in node.args]
        # Positional arguments get no prefix; keyword arguments are rendered as `:key value`.
        keywords = [''] * node.pos_arg_count + [':{} '.format(key) for key in node.keywords]
        args = [a + b for a, b in zip(keywords, args)]
        return "({} {})".format(function, ' '.join(args))

    def visit_compare(self, node: AstCompare):
        left = self.visit(node.left)
        right = self.visit(node.right)
        op = '=' if node.op == '==' else node.op
        if node.second_right is not None:
            third = self.visit(node.second_right)
            if node.op == node.second_op:
                return "({} {} {} {})".format(op, left, right, third)
            else:
                # Map Python's `==` to Clojure's `=` for the second comparison as well.
                second_op = '=' if node.second_op == '==' else node.second_op
                return "(and ({} {} {}) ({} {} {}))".format(op, left, right, second_op, right, third)
        else:
            return "({} {} {})".format(op, left, right)

    def visit_def(self, node: AstDef):
        name = node.original_name if self.short_names else node.name
        value = self.visit_indent(node.value)
        if '\n' in value:
            return "(def {}\n {})".format(name, value)
        else:
            return "(def {} {})".format(name, value)

    def visit_dict(self, node: AstDict):
        # NOTE(review): keys and values are formatted directly rather than visited -- confirm this is intended.
        items = ["{} {}".format(key, node.items[key]) for key in node.items]
        return "{" + ', '.join(items) + "}"

    def visit_for(self, node: AstFor):
        name = node.original_target if self.short_names else node.target
        source = self.visit(node.source)
        body = self.visit_indent(node.body)
        return "(doseq [{} {}]\n {})".format(name, source, body)

    def visit_function(self, node: AstFunction):
        # Copy the list: appending the vararg must not modify the AST node itself.
        params = list(node.parameters)
        if node.vararg is not None:
            params.append('& ' + node.vararg)
        body = self.visit_indent(node.body)
        return "(fn [{}]\n {})".format(' '.join(params), body)

    def visit_if(self, node: AstIf):
        test = self.visit(node.test)
        body = self.visit_indent(node.if_node)
        else_body = self.visit_indent(node.else_node)
        if else_body is not None:
            return "(if {}\n {}\n {})".format(test, body, else_body)
        else:
            return "(if {}\n {})".format(test, body)

    def visit_import(self, node: AstImport):
        if node.alias is not None:
            if node.imported_names is None:
                s = ":as {}".format(node.alias)
            else:
                s = "[{} :as {}]".format(node.imported_names[0], node.alias)
            return "(require '{} {})".format(node.module_name, s)
        elif node.imported_names is not None:
            if len(node.imported_names) == 1 and node.imported_names[0] == '*':
                return "(use '{})".format(node.module_name)
            else:
                return "(require '{} :refer [{}])".format(node.module_name, ' '.join(node.imported_names))
        return "(require '{})".format(node.module_name)

    def visit_let(self, node: AstLet):
        name = node.original_target if self.short_names else node.target
        source = self.visit(node.source)
        body = self.visit_indent(node.body)
        return "(let [{} {}]\n {})".format(name, source, body)

    def visit_list_for(self, node: AstListFor):
        name = node.original_target if self.short_names else node.target
        source = self.visit(node.source)
        body = self.visit_indent(node.expr)
        return "(for [{} {}]\n {})".format(name, source, body)

    def visit_observe(self, node: AstObserve):
        dist = self.visit(node.dist)
        value = self.visit(node.value)
        return "(observe {} {})".format(dist, value)

    def visit_return(self, node: AstReturn):
        value = self.visit(node.value)
        return "(return {})".format(value)

    def visit_sample(self, node: AstSample):
        dist = self.visit(node.dist)
        return "(sample {})".format(dist)

    def visit_slice(self, node: AstSlice):
        sequence = self.visit(node.base)
        start = self.visit(node.start)
        stop = self.visit(node.stop)
        if start is None and stop is None:
            # `x[:]` is the whole sequence.
            return sequence
        if stop is None:
            if node.start_as_int == 1:
                return "(rest {})".format(sequence)
            else:
                # Clojure's argument order is `(drop n coll)`.
                return "(drop {} {})".format(start, sequence)
        elif start is None:
            # Clojure's argument order is `(take n coll)`.
            return "(take {} {})".format(stop, sequence)
        else:
            return "(subvec {} {} {})".format(sequence, start, stop)

    def visit_subscript(self, node: AstSubscript):
        sequence = self.visit(node.base)
        index = self.visit(node.index)
        return "(get {} {})".format(sequence, index)

    def visit_symbol(self, node: AstSymbol):
        return node.original_name if self.short_names else node.name

    def visit_unary(self, node: AstUnary):
        item = self.visit(node.item)
        return "({} {})".format(node.op, item)

    def visit_value(self, node: AstValue):
        return repr(node)

    def visit_value_vector(self, node: AstValueVector):
        items = [repr(item) for item in node.items]
        return "[{}]".format(' '.join(items))

    def visit_vector(self, node: AstVector):
        items = [self.visit(item) for item in node.items]
        return "[{}]".format(' '.join(items))

    def visit_while(self, node: AstWhile):
        test = self.visit(node.test)
        body = self.visit_indent(node.body)
        return "(while {}\n {})".format(test, body)


def dump(ast):
    """
    Returns a string-representation of the AST which is valid `Clojure`-code.

    :param ast: The AST representing the program.
    :return: A string with valid `Clojure`-code.
    """
    result = ClojureRepr().visit(ast)
    if type(result) is list:
        return '\n'.join(result)
    else:
        return result

# --------------------------------------------------------------------------------
# /pyppl/fe_clojure/ppl_foppl_parser.py:
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 21. Feb 2018, Tobias Kohn
# 20. Mar 2018, Tobias Kohn
#
from ..fe_clojure import ppl_clojure_forms as clj
from ..ppl_ast import *
from .ppl_clojure_lexer import ClojureLexer
from .ppl_clojure_parser import ClojureParser


#######################################################################################################################

class FopplParser(ClojureParser):
    """
    Clojure parser extended with FOPPL's `loop` construct.
    """

    def visit_loop(self, count, initial_data, function, *args):
        """
        Unrolls `(loop count init f a...)` into nested calls `f(count-1, ... f(1, f(0, init, a...), a...) ..., a...)`.

        :raises SyntaxError: If `count` is not an integer literal.
        """
        if not clj.is_integer(count):
            raise SyntaxError("loop requires an integer value as first argument")
        count = count.value
        initial_data = initial_data.visit(self)
        function = function.visit(self)
        args = [arg.visit(self) for arg in args]
        result = initial_data
        for i in range(count):
            result = AstCall(function, [AstValue(i), result] + args)
        return result


#######################################################################################################################

def parse(source):
    """Parses FOPPL source text into a list of PPL ASTs, one per top-level form."""
    clj_ast = list(ClojureLexer(source))
    ppl_ast = FopplParser().visit(clj_ast)
    return ppl_ast

# --------------------------------------------------------------------------------
# /pyppl/fe_python/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/fe_python/__init__.py
# --------------------------------------------------------------------------------
# /pyppl/graphs.py:
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 20. Dec 2017, Tobias Kohn
# 07. Jun 2018, Tobias Kohn
#
from typing import Optional
from . import distributions


class GraphNode(object):
    """
    The base class for all nodes, including the actual graph vertices, but also conditionals, data, and possibly
    parameters.

    Each node has a name, which is usually generated automatically. The generation of the name is based on a simple
    counter. This generated name (i.e. the counter value inside the name) is used later on to impose a compute order
    on the nodes (see the method `get_ordered_list_of_all_nodes` in the `graph`). Hence, you should not change the
    naming scheme unless you know exactly what you are doing!

    The set of ancestors provides the edges for the graph and the graphical model, respectively. Note that all
    ancestors are always vertices. Conditions, parameters, data, etc. are held in other fields. This ensures that by
    looking at the ancestors of vertices, we get the pure graphical model.

    Finally, the methods `evaluate`, `update` and `update_pdf` are used by the model to sample values and compute
    log-pdf, etc. Of course, `evaluate` is just a placeholder here so as to define a minimal interface. Usually, you
    will use `update` and `update_pdf` instead of `evaluate`. However, given a `state`-dictionary holding all the
    necessary values, it is safe to call `evaluate`.
    """

    def __init__(self, name: str, ancestors: Optional[set] = None):
        if ancestors is None:
            ancestors = set()
        self.ancestors = ancestors
        self.name = name
        self.original_name = name
        assert type(self.ancestors) is set
        assert type(self.name) is str
        assert all([isinstance(item, GraphNode) for item in self.ancestors])

    @property
    def display_name(self):
        """
        A short name for display: the last dotted component of `original_name` with underscores removed, or the
        last three characters of `name` if there is no usable original name.
        """
        if hasattr(self, 'original_name'):
            name = self.original_name
            if name is not None and '.' in name:
                name = name.split('.')[-1]
            if name is not None and len(name) > 0:
                return name.replace('_', '')
        return self.name[-3:]

    def create_repr(self, caption: str, **fields):
        """
        Builds a multi-line, human-readable description: `caption`, then the name, the ancestors and every field
        whose value is not `None`, and finally the line number if a positive `line_number` attribute exists.
        """

        def fmt_field(key):
            # Collections of nodes are shown by node name; (node, value) pairs as `name=value`.
            value = fields[key]
            if value is None:
                return '-'
            elif type(value) in (list, set, tuple) and all([isinstance(item, GraphNode) for item in value]):
                return ', '.join([item.name for item in value])
            elif type(value) in (list, set, tuple) and \
                    all([type(item) is tuple and isinstance(item[0], GraphNode) for item in value]):
                return ', '.join(['{}={}'.format(item[0].name, item[1]) for item in value])
            else:
                return value

        if len(fields) > 0:
            key_len = max(max([len(key) for key in fields]), 9)
            fmt = " {:" + str(key_len+2) + "}{}"
            result = [fmt.format(key+':', fmt_field(key)) for key in fields if fields[key] is not None]
        else:
            fmt = " {:11}{}"
            result = []
        result.insert(0, fmt.format("Ancestors:", ', '.join([item.name for item in self.ancestors])))
        result.insert(0, fmt.format("Name:", self.name))
        line_no = getattr(self, 'line_number', -1)
        if line_no > 0:
            result.append(fmt.format("Line:", line_no))
        return "{}\n{}".format(caption, '\n'.join(result))

    def __repr__(self):
        return self.create_repr(self.name)

    def get_code(self):
        # `NotImplemented` is a constant, not an exception; raising it would produce a TypeError instead.
        raise NotImplementedError


#################################################################################################### 89 | 90 | class ConditionNode(GraphNode): 91 | """ 92 | A `ConditionNode` represents a condition that depends on stochastic variables (vertices). It is not directly 93 | part of the graphical model, but you can think of conditions to be attached to a specific vertex. 94 | 95 | Usually, we try to transform all conditions into the form `f(state) >= 0` (this is not possible for `f(X) == 0`, 96 | through). However, if the condition satisfies this format, the node object has an associated `function`, which 97 | can be evaluated on its own. In other words: you can not only check if a condition is `True` or `False`, but you 98 | can also gain information about the 'distance' to the 'border'. 99 | """ 100 | 101 | __condition_node_counter = 1 102 | 103 | def __init__(self, name: str, *, ancestors: Optional[set]=None, 104 | condition: str, 105 | function: Optional[str]=None, 106 | op: Optional[str]=None, 107 | compare_value: Optional[float]=None): 108 | super().__init__(name, ancestors) 109 | self.condition = condition 110 | self.function = function 111 | self.op = op 112 | self.compare_value = compare_value 113 | self.bit_index = self.__class__.__condition_node_counter 114 | self.__class__.__condition_node_counter *= 2 115 | for a in ancestors: 116 | if isinstance(a, Vertex): 117 | a.add_dependent_condition(self) 118 | 119 | def __repr__(self): 120 | return self.create_repr("Condition", Condition=self.condition, Function=self.function, Op=self.op, 121 | CompareValue=self.compare_value) 122 | 123 | def get_code(self): 124 | return self.condition 125 | 126 | def is_false_from_bit_vector(self, bit_vector): 127 | return (bit_vector & self.bit_index) == 0 128 | 129 | def is_true_from_bit_vector(self, bit_vector): 130 | return (bit_vector & self.bit_index) > 0 131 | 132 | def update_bit_vector(self, state, bit_vector): 133 | if state[self.name] is True: 134 | bit_vector |= 
self.bit_index 135 | return bit_vector 136 | 137 | 138 | class DataNode(GraphNode): 139 | """ 140 | Data nodes do not carry out any computation, but provide the data. They are used to keep larger data set out 141 | of the code, as large lists are replaced by symbols. 142 | """ 143 | 144 | def __init__(self, name: str, *, ancestors: Optional[set]=None, data: str): 145 | super().__init__(name, ancestors) 146 | self.data_code = data 147 | 148 | def __repr__(self): 149 | return self.create_repr("Data", Data=self.data_code) 150 | 151 | def get_code(self): 152 | return self.data_code 153 | 154 | 155 | class Vertex(GraphNode): 156 | """ 157 | Vertices play the crucial and central role in the graphical model. Each vertex represents either the sampling from 158 | a distribution, or the observation of such a sampled value. 159 | 160 | You can get the entire graphical model by taking the set of vertices and their `ancestors`-fields, containing all 161 | vertices, upon which this vertex depends. However, there is a plethora of additional fields, providing information 162 | about the node and its relationship and status. 163 | 164 | `name`: 165 | The generated name of the vertex. See also: `original_name`. 166 | `original_name`: 167 | In contrast to the `name`-field, this field either contains the name attributed to this value in the original 168 | code, or `None`. 169 | `ancestors`: 170 | The set of all parent vertices. This contains only the ancestors, which are in direct line, and not the parents 171 | of parents. Use the `get_all_ancestors()`-method to retrieve a full list of all ancestors (including parents of 172 | parents of parents of ...). 173 | `dist_ancestors`: 174 | The set of ancestors used for the distribution/sampling, without those used inside the conditions. 175 | `cond_ancestors`: 176 | The set of ancestors, which are linked through conditionals. 177 | `distribution_name`: 178 | The name of the distribution, such as `Normal` or `Gamma`. 
179 | `distribution_type`: 180 | Either `"continuous"` or `"discrete"`. You will usually query this field using one of the properties 181 | `is_continuous` or `is_discrete`. 182 | `observation`: 183 | The observation as a string containing Python-code. 184 | `conditions`: 185 | The set of all conditions under which this vertex is evaluated. Each item in the set is actually a tuple of 186 | a `ConditionNode` and a boolean value, to which the condition should evaluate. Note that the conditions are 187 | not owned by a vertex, but might be shared across several vertices. 188 | `dependent_conditions`: 189 | The set of all conditions that depend on this vertex. In other words, all conditions which contain this 190 | vertex in their `get_all_ancestors`-set. 191 | `sample_size`: 192 | The dimension of the samples drawn from this distribution. 193 | """ 194 | 195 | def __init__(self, name: str, *, 196 | ancestors: Optional[set]=None, 197 | condition_nodes: Optional[set]=None, 198 | conditions: Optional[set]=None, 199 | distribution_args: Optional[list]=None, 200 | distribution_arg_names: Optional[list]=None, 201 | distribution_code: str, 202 | distribution_func: Optional[str]=None, 203 | distribution_name: str, 204 | distribution_transform=None, 205 | observation: Optional[str]=None, 206 | observation_value: Optional=None, 207 | original_name: Optional[str]=None, 208 | sample_size: int = 1, 209 | line_number: int = -1): 210 | super().__init__(name, ancestors) 211 | self.condition_nodes = condition_nodes 212 | self.conditions = conditions 213 | self.distribution_args = distribution_args 214 | self.distribution_arg_names = distribution_arg_names 215 | self.distribution_code = distribution_code 216 | self.distribution_func = distribution_func 217 | self.distribution_name = distribution_name 218 | distr = distributions.get_distribution_for_name(distribution_name) 219 | self.distribution_type = distr.distribution_type if distr is not None else None 220 | 
self.distribution_transform = distribution_transform 221 | self.observation = observation 222 | self.observation_value = observation_value 223 | self.original_name = original_name 224 | self.line_number = line_number 225 | self.sample_size = sample_size 226 | self.dependent_conditions = set() 227 | if conditions is not None: 228 | if self.condition_nodes is None: 229 | self.condition_nodes = set() 230 | for cond, truth_value in conditions: 231 | self.condition_nodes.add(cond) 232 | self.condition_ancestors = set() 233 | if self.condition_nodes is not None: 234 | for cond in self.condition_nodes: 235 | self.condition_ancestors = set.union(self.condition_ancestors, cond.ancestors) 236 | if self.distribution_args is not None and self.distribution_arg_names is not None and \ 237 | len(self.distribution_args) == len(self.distribution_arg_names): 238 | self.distribution_arguments = { n: v for n, v in zip(self.distribution_arg_names, self.distribution_args) } 239 | else: 240 | self.distribution_arguments = None 241 | 242 | def __repr__(self): 243 | args = { 244 | "Conditions": self.conditions, 245 | "Cond-Ancs.": self.condition_ancestors, 246 | "Cond-Nodes": self.condition_nodes, 247 | "Dist-Args": self.distribution_arguments, 248 | "Dist-Code": self.distribution_code, 249 | "Dist-Name": self.distribution_name, 250 | "Dist-Type": self.distribution_type, 251 | "Dist-Transform": self.distribution_transform, 252 | "Sample-Size": self.sample_size, 253 | "Orig. 
Name": self.original_name, 254 | } 255 | if self.observation is not None: 256 | args["Observation"] = self.observation 257 | title = "Observe" if self.observation is not None else "Sample" 258 | return self.create_repr("Vertex {} [{}]".format(self.name, title), **args) 259 | 260 | def get_code(self, **flags): 261 | if self.distribution_func is not None and self.distribution_args is not None: 262 | args = self.distribution_args[:] 263 | if self.distribution_arg_names is not None: 264 | arg_names = self.distribution_arg_names 265 | if len(arg_names) < len(args): 266 | arg_names = ['{}='.format(n) for n in arg_names] 267 | arg_names = [''] * (len(args)-len(arg_names)) + arg_names 268 | args = ["{}{}".format(a, b) for a, b in zip(arg_names, args) if a not in flags.keys()] 269 | else: 270 | args = ["{}={}".format(a, b) for a, b in zip(arg_names, args) if a not in flags.keys()] 271 | for key in flags: 272 | args.append("{}={}".format(key, flags[key])) 273 | return "{}({})".format(self.distribution_func, ', '.join(args)) 274 | return self.distribution_code 275 | 276 | def get_cond_code(self, state_object: Optional[str]=None): 277 | if self.conditions is not None and len(self.conditions) > 0: 278 | result = [] 279 | for cond, truth_value in self.conditions: 280 | name = cond.name 281 | if state_object is not None: 282 | name = "{}['{}']".format(state_object, name) 283 | if truth_value: 284 | result.append(name) 285 | else: 286 | result.append('not ' + name) 287 | return "if {}:\n\t".format(' and '.join(result)) 288 | else: 289 | return None 290 | 291 | def add_dependent_condition(self, cond: ConditionNode): 292 | self.dependent_conditions.add(cond) 293 | for a in self.ancestors: 294 | a.add_dependent_condition(cond) 295 | 296 | @property 297 | def has_observation(self): 298 | return self.observation is not None 299 | 300 | @property 301 | def get_all_ancestors(self): 302 | result = [] 303 | for a in self.ancestors: 304 | if a not in result: 305 | result.append(a) 306 | 
result += list(a.get_all_ancestors()) 307 | return set(result) 308 | 309 | @property 310 | def is_conditional(self): 311 | return len(self.dependent_conditions) > 0 312 | 313 | @property 314 | def is_continuous(self): 315 | return self.distribution_type == distributions.DistributionType.CONTINUOUS 316 | 317 | @property 318 | def is_discrete(self): 319 | return self.distribution_type == distributions.DistributionType.DISCRETE 320 | 321 | @property 322 | def is_observed(self): 323 | return self.observation is not None 324 | 325 | @property 326 | def is_sampled(self): 327 | return self.observation is None 328 | 329 | @property 330 | def has_conditions(self): 331 | return self.conditions is not None and len(self.conditions) > 0 332 | -------------------------------------------------------------------------------- /pyppl/parser.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 22. Feb 2018, Tobias Kohn 7 | # 22. Mar 2018, Tobias Kohn 8 | # 9 | from typing import Optional 10 | 11 | from .transforms import (ppl_new_simplifier, ppl_raw_simplifier, ppl_functions_inliner, 12 | ppl_symbol_simplifier, ppl_static_assignments) 13 | from . 
import ppl_ast 14 | from .fe_clojure import ppl_foppl_parser 15 | from .fe_python import ppl_python_parser 16 | 17 | 18 | def _detect_language(s:str): 19 | for char in s: 20 | if char in ['#']: 21 | return 'py' 22 | 23 | elif char in [';', '(']: 24 | return 'clj' 25 | 26 | elif 'A' <= char <= 'Z' or 'a' <= char <= 'z' or char == '_': 27 | return 'py' 28 | 29 | elif char > ' ': 30 | return 'py' 31 | 32 | return None 33 | 34 | 35 | def parse(source:str, *, simplify:bool=True, language:Optional[str]=None, namespace:Optional[dict]=None): 36 | result = None 37 | if type(source) is str and str != '': 38 | lang = _detect_language(source) if language is None else language.lower() 39 | if lang in ['py', 'python']: 40 | result = ppl_python_parser.parse(source) 41 | 42 | elif lang in ['clj', 'clojure']: 43 | result = ppl_foppl_parser.parse(source) 44 | 45 | elif lang == 'foppl': 46 | result = ppl_foppl_parser.parse(source) 47 | 48 | if type(result) is list: 49 | result = ppl_ast.makeBody(result) 50 | 51 | if result is not None: 52 | if namespace is None: 53 | namespace = {} 54 | raw_sim = ppl_raw_simplifier.RawSimplifier(namespace) 55 | result = raw_sim.visit(result) 56 | if simplify: 57 | result = ppl_functions_inliner.FunctionInliner().visit(result) 58 | result = raw_sim.visit(result) 59 | 60 | if simplify and result is not None: 61 | result = ppl_static_assignments.StaticAssignments().visit(result) 62 | result = ppl_new_simplifier.Simplifier().visit(result) 63 | 64 | result = ppl_symbol_simplifier.SymbolSimplifier().visit(result) 65 | return result 66 | 67 | 68 | def parse_from_file(filename: str, *, simplify:bool=True, language:Optional[str]=None, namespace:Optional[dict]=None): 69 | with open(filename) as f: 70 | source = ''.join(f.readlines()) 71 | return parse(source, simplify=simplify, language=language, namespace=namespace) 72 | -------------------------------------------------------------------------------- /pyppl/ppl_ast_annotators.py: 
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 22. Feb 2018, Tobias Kohn
# 20. Mar 2018, Tobias Kohn
#
from typing import Optional
from .ppl_ast import *


class NodeInfo(object):
    """
    Static information about an AST (sub-)tree, as computed by `InfoAnnotator`.

    Infos passed as `base` are merged into the new info:
    - the variable sets are combined by union;
    - the boolean flags are OR-ed;
    - `return_count` and the per-variable change counts are summed.

    The fields `has_changed_vars`, `has_free_vars`, `can_embed` and `mutable_vars` are derived from the others.
    """

    # Fields recomputed by `_update_derived()`.
    _DERIVED_FIELDS = ('has_changed_vars', 'has_free_vars', 'can_embed', 'mutable_vars')

    def __init__(self, *, base=None,
                 changed_vars:Optional[set]=None,
                 cond_vars:Optional[set]=None,
                 free_vars:Optional[set]=None,
                 has_break:bool=False,
                 has_cond:bool=False,
                 has_observe:bool=False,
                 has_return:bool=False,
                 has_sample:bool=False,
                 has_side_effects:bool=False,
                 return_count:int=0):
        """
        :param base: `None`, a `NodeInfo`, or a list/set/tuple of `NodeInfo`s (`None` items are skipped) whose
                     information is merged into this one.
        :raises TypeError: if `base` is of any other type.
        """
        if changed_vars is None:
            changed_vars = set()
        if cond_vars is None:
            cond_vars = set()
        if free_vars is None:
            free_vars = set()

        if base is None:
            bases = []
        elif isinstance(base, NodeInfo):
            bases = [base]
        elif type(base) in (list, set, tuple) and all(item is None or isinstance(item, NodeInfo) for item in base):
            bases = [item for item in base if item is not None]
        else:
            raise TypeError("NodeInfo(): wrong type of 'base': '{}'".format(type(base)))

        self.changed_var_count = { k: 1 for k in changed_vars }  # type:dict
        self.changed_vars = changed_vars            # type:set
        self.cond_vars = cond_vars                  # type:set
        self.free_vars = free_vars                  # type:set
        self.has_break = has_break                  # type:bool
        self.has_cond = has_cond                    # type:bool
        self.has_observe = has_observe              # type:bool
        self.has_return = has_return                # type:bool
        self.has_sample = has_sample                # type:bool
        self.has_side_effects = has_side_effects    # type:bool
        self.return_count = return_count            # type:int
        for item in bases:
            # `set.union` builds new sets, so the bases' sets are never modified.
            self.changed_vars = set.union(self.changed_vars, item.changed_vars)
            self.cond_vars = set.union(self.cond_vars, item.cond_vars)
            self.free_vars = set.union(self.free_vars, item.free_vars)
            self.has_break = self.has_break or item.has_break
            self.has_cond = self.has_cond or item.has_cond
            self.has_observe = self.has_observe or item.has_observe
            self.has_return = self.has_return or item.has_return
            self.has_sample = self.has_sample or item.has_sample
            self.has_side_effects = self.has_side_effects or item.has_side_effects
            self.return_count += item.return_count
            for key, count in item.changed_var_count.items():
                self.changed_var_count[key] = self.changed_var_count.get(key, 0) + count

        self._update_derived()

        assert type(self.changed_vars) is set and all([type(item) is str for item in self.changed_vars])
        assert type(self.free_vars) is set and all([type(item) is str for item in self.free_vars])
        assert type(self.changed_var_count) is dict
        assert type(self.cond_vars) is set and all([type(item) is str for item in self.cond_vars]), cond_vars
        assert type(self.has_break) is bool
        assert type(self.has_cond) is bool
        assert type(self.has_observe) is bool
        assert type(self.has_return) is bool
        assert type(self.has_sample) is bool
        assert type(self.has_side_effects) is bool
        assert type(self.return_count) is int

    def _update_derived(self):
        """Recompute the fields that depend on the variable sets, counts and flags."""
        self.has_changed_vars = len(self.changed_vars) > 0
        self.has_free_vars = len(self.free_vars) > 0
        self.can_embed = not (self.has_observe or self.has_sample or self.has_side_effects or self.has_changed_vars)
        self.mutable_vars = {key for key, count in self.changed_var_count.items() if count > 1}

    def clone(self, binding_vars:Optional[set]=None, **kwargs):
        """
        Return a copy of this info, with the attributes given in `kwargs` overwritten.

        Names in `binding_vars` are removed from the changed/conditional/free variables. The derived fields are
        then recomputed, unless `kwargs` sets them explicitly.
        """
        result = NodeInfo(base=self)
        for key in kwargs:
            setattr(result, key, kwargs[key])
        if binding_vars is not None:
            result.changed_vars = set.difference(result.changed_vars, binding_vars)
            result.cond_vars = set.difference(result.cond_vars, binding_vars)
            result.free_vars = set.difference(result.free_vars, binding_vars)
            for n in binding_vars:
                result.changed_var_count.pop(n, None)
        # Removing bound names can empty `changed_vars` etc., so the derived fields must follow.
        result._update_derived()
        for key in NodeInfo._DERIVED_FIELDS:
            if key in kwargs:
                setattr(result, key, kwargs[key])
        return result

    def bind_var(self, name):
        """
        Return a copy in which `name` (a string or a list/set/tuple of strings) is bound, i.e. no longer free.

        :raises TypeError: if `name` is neither `None`, a string, nor a collection of strings.
        """
        if type(name) is str:
            return self.clone(binding_vars={name})

        elif type(name) in (list, set, tuple) and all([type(item) is str for item in name]):
            return self.clone(binding_vars=set(name))

        elif name is not None:
            raise TypeError("NodeInfo(): cannot bind '{}'".format(name))

        else:
            return self

    def change_var(self, name):
        """
        Return a new info that additionally marks `name` (string or collection of strings) as changed.

        The result has side effects. Inside a conditional context (`has_cond`), the names are also added to
        `cond_vars`.
        """
        if type(name) is str:
            name = {name}

        elif type(name) in (list, set, tuple) and all([type(item) is str for item in name]):
            name = set(name)

        elif name is not None:
            raise TypeError("NodeInfo(): cannot add var-name '{}'".format(name))

        if self.has_cond:
            return NodeInfo(base=self, changed_vars=name, has_side_effects=True, cond_vars=name)
        else:
            return NodeInfo(base=self, changed_vars=name, has_side_effects=True)

    def union(self, *other):
        """Merge this info with all non-`None` infos in `other`."""
        other = [item for item in other if item is not None]
        if len(other) == 0:
            return self
        elif all([isinstance(item, NodeInfo) for item in other]):
            return NodeInfo(base=[self] + other)
        else:
            raise TypeError("NodeInfo(): cannot build union with '{}'"
                            .format([item for item in other if not isinstance(item, NodeInfo)]))

    def is_independent(self, other):
        """True if neither info changes a variable that the other one reads or changes."""
        assert isinstance(other, NodeInfo)
        a = set.intersection(self.free_vars, other.changed_vars)
        b = set.intersection(self.changed_vars, other.free_vars)
        c = set.intersection(self.changed_vars, other.changed_vars)
        return len(a) == len(b) == len(c) == 0


class InfoAnnotator(Visitor):
    """Computes a `NodeInfo` for each visited AST node by combining the infos of its children."""

    def visit_node(self, node:AstNode):
        return NodeInfo()

    def visit_attribute(self, node: AstAttribute):
        return NodeInfo(base=self.visit(node.base), free_vars={node.attr})

    def visit_binary(self, node: AstBinary):
        return NodeInfo(base=(self.visit(node.left), self.visit(node.right)))

    def visit_body(self, node: AstBody):
        return NodeInfo(base=[self.visit(item) for item in node.items])

    def visit_break(self, _):
        return NodeInfo(has_break=True)

    def visit_call(self, node: AstCall):
        base = [self.visit(node.function)]
        args = [self.visit(arg) for arg in node.args]
        return NodeInfo(base=base + args)

    def visit_compare(self, node: AstCompare):
        return NodeInfo(base=[self.visit(node.left), self.visit(node.right), self.visit(node.second_right)])

    def visit_def(self, node: AstDef):
        result = self.visit(node.value)
        return result.change_var(node.name)

    def visit_dict(self, node: AstDict):
        return NodeInfo(base=[self.visit(node.items[key]) for key in node.items])

    def visit_for(self, node: AstFor):
        source = self.visit(node.source)
        body = self.visit(node.body).bind_var(node.target)
        return NodeInfo(base=[body, source])

    def visit_function(self, node: AstFunction):
        body = self.visit(node.body)
        return body.bind_var(node.parameters).bind_var(node.vararg)

    def visit_if(self, node: AstIf):
        if node.has_else:
            base = [self.visit(node.if_node), self.visit(node.else_node)]
        else:
            base = [self.visit(node.if_node)]
        # Variables changed in either branch only change conditionally.
        cond_vars = set.union(*[item.changed_vars for item in base])
        return NodeInfo(base=base + [self.visit(node.test)], cond_vars=cond_vars, has_cond=True)

    def visit_import(self, _):
        return NodeInfo()

    def visit_let(self, node: AstLet):
        result = self.visit(node.body).bind_var(node.target)
        result = result.union(self.visit(node.source))
        return result

    def visit_list_for(self, node: AstListFor):
        source = self.visit(node.source)
        expr = self.visit(node.expr).bind_var(node.target)
        return NodeInfo(base=[expr, source])

    def visit_observe(self, node: AstObserve):
        return NodeInfo(base=[self.visit(node.dist), self.visit(node.value)], has_observe=True)

    def visit_return(self, node: AstReturn):
        return NodeInfo(base=self.visit(node.value), has_return=True, return_count=1)

    def visit_sample(self, node: AstSample):
        return NodeInfo(base=self.visit(node.dist), has_sample=True)

    def visit_slice(self, node: AstSlice):
        base = [self.visit(node.base),
                self.visit(node.start),
                self.visit(node.stop)]
        return NodeInfo(base=base)

    def visit_subscript(self, node: AstSubscript):
        base = [self.visit(node.base), self.visit(node.index)]
        return NodeInfo(base=base)

    def visit_symbol(self, node: AstSymbol):
        return NodeInfo(free_vars={node.name})

    def visit_unary(self, node: AstUnary):
        return self.visit(node.item)

    def visit_value(self, _):
        return NodeInfo()

    def visit_value_vector(self, _):
        return NodeInfo()

    def visit_vector(self, node: AstVector):
        return NodeInfo(base=[self.visit(item) for item in node.items])

    def visit_while(self, node: AstWhile):
        base = [self.visit(node.test), self.visit(node.body)]
        return NodeInfo(base=base, has_side_effects=True)


class VarCountVisitor(Visitor):
    """Counts how many `AstSymbol` nodes in the visited tree carry the name `name`."""

    __visit_children_first__ = True

    def __init__(self, name:str):
        super().__init__()
        self.count = 0
        self.name = name
        assert type(self.name) is str and self.name != ''

    def visit_node(self, node:AstNode):
        return node

    def visit_symbol(self, node:AstSymbol):
        if node.name == self.name:
            self.count += 1
        return node


def get_info(ast:AstNode) -> NodeInfo:
    """Return the `NodeInfo` describing the whole tree `ast`."""
    return InfoAnnotator().visit(ast)

def count_variable_usage(name:str, ast:AstNode):
    """Return the number of occurrences of the symbol `name` in `ast`."""
    vcv = VarCountVisitor(name)
    vcv.visit(ast)
    return vcv.count

# --------------------------------------------------------------------------------
# /pyppl/ppl_base_model.py
# --------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Author: Bradley Gram-Hansen
Time created: 18:51
Date created: 19/03/2018

License: MIT
'''

from abc import ABC, abstractmethod, ABCMeta


class base_model(ABC):
    """Abstract interface of a generated model. Every method must be implemented by a concrete subclass."""

    @abstractmethod
    def get_vertices(self):
        '''
        Generates the vertices of the graphical model.
        :return: Set of vertices
        '''
        raise NotImplementedError

    @abstractmethod
    def get_vertices_names(self):
        raise NotImplementedError

    @abstractmethod
    def get_arcs(self):
        raise NotImplementedError

    @abstractmethod
    def get_arcs_names(self):
        raise NotImplementedError

    @abstractmethod
    def get_conditions(self):
        raise NotImplementedError

    @abstractmethod
    def gen_cond_vars(self):
        raise NotImplementedError

    @abstractmethod
    def gen_if_vars(self):
        raise NotImplementedError

    @abstractmethod
    def gen_cont_vars(self):
        raise NotImplementedError

    @abstractmethod
    def gen_disc_vars(self):
        raise NotImplementedError

    @abstractmethod
    def get_vars(self):
        raise NotImplementedError

    @abstractmethod
    def gen_log_prob(self):
        raise NotImplementedError

    # @abstractmethod
    # def gen_log_prob_transformed(self):
    #     return NotImplementedError

    @abstractmethod
    def gen_prior_samples(self):
        raise NotImplementedError

# --------------------------------------------------------------------------------
# /pyppl/ppl_branch_scopes.py
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 07. Mar 2018, Tobias Kohn
# 20. Mar 2018, Tobias Kohn
#
from .ppl_ast import *

def union(items):
    """
    Combine the `(test, value)` pairs of a branching into a single value.

    A single pair yields its value. Two pairs with equal values yield that value. Two pairs whose tests are
    negations of each other yield an `if`-expression over both values.

    :raises RuntimeError: for any other combination.
    """
    if len(items) == 1:
        return items[0][1]

    if len(items) == 2:
        tests = [i[0] for i in items]
        values = [i[1] for i in items]
        if values[0] == values[1]:
            return values[0]
        # Put the "positive" test first, so that `makeIf` gets `test, then, else` in the right order.
        if tests[0] is None or is_boolean_true(tests[0]) or \
                (isinstance(tests[0], AstUnary) and not isinstance(tests[1], AstUnary)):
            tests[0], tests[1] = tests[1], tests[0]
            values[0], values[1] = values[1], values[0]

        if is_negation_of(tests[0], tests[1]):
            return makeIf(tests[0], values[0], values[1])

    raise RuntimeError("cannot take the union of '{}'".format(items))


class BranchScope(object):
    """
    A node in a tree of scopes that mirror `if`/`else` branching.

    Each scope has an optional `condition`, a `parent`, its child `branches`, and a dict `values` from names to
    values. `current_branch` points to the innermost scope that is currently open.
    """

    def __init__(self, *, parent=None, names=None, condition:Optional[AstNode]=None):
        self.condition = condition
        self.parent = parent
        self.branches = []
        self.current_branch = self  # type:BranchScope
        if names is not None:
            self.values = { key: None for key in names }
        else:
            self.values = { }

    def new_branching(self, cond:AstNode):
        """Open a new child scope (the `if`-branch for `cond`) below the current branch and enter it."""
        result = BranchScope(parent=self.current_branch, condition=cond)
        self.current_branch.branches = [result]
        self.current_branch = result
        return result

    def switch_branch(self):
        """Leave the current branch and open its sibling, guarded by the negated condition (the `else`-branch)."""
        self.current_branch = self.current_branch.parent
        cond = AstUnary('not', self.current_branch.branches[-1].condition)
        result = BranchScope(parent=self.current_branch, condition=cond)
        self.current_branch.branches.append(result)
        self.current_branch = result
        return result

    def end_branching(self):
        """
        Close the current branching and merge the values of its branches with `union`.

        A name that is missing from some branch also gets the value visible in the enclosing scope, with a
        `None` test.
        """
        self.current_branch = self.current_branch.parent
        branch = self.current_branch
        names = set()
        for b in branch.branches:
            names = set.union(names, b.names)
        values = { key: [] for key in names }
        for key in names:
            if not all([key in b.values for b in branch.branches]):
                values[key].append((None, branch[key]))
        for b in branch.branches:
            for key in b.values:
                values[key].append((b.condition, b.values[key]))
        # NOTE(review): the merged values are stored in `self.values` rather than `branch.values`. The two differ
        # only for nested branchings -- confirm this is intended.
        for key in values:
            self.values[key] = union(values[key])
        branch.branches = []
        return branch

    def get_value(self, name:str):
        """Look `name` up in this scope, then in the parent scope(s); `None` if it is not found."""
        assert type(name) is str
        if name in self.values:
            return self.values[name]
        elif isinstance(self.parent, BranchScope):
            return self.parent.get_value(name)
        elif isinstance(self.parent, BranchScopeVisitor):
            return self.parent.branch.get_value(name)
        else:
            return None

    def set_value(self, name:str, value):
        assert type(name) is str
        self.values[name] = value

    def __getitem__(self, item):
        if type(item) is str:
            return self.get_value(item)
        else:
            raise TypeError("key must be of type 'str', not '{}'".format(type(item)))

    def __setitem__(self, key, value):
        if type(key) is str:
            return self.set_value(key, value)
        else:
            raise TypeError("key must be of type 'str', not '{}'".format(type(key)))

    @property
    def names(self):
        return set(self.values.keys())


class BranchScopeContext(object):
    """
    The `BranchScopeContext` is a thin layer used to support scoping in `with`-statements inside methods of
    `BranchScopedVisitor`, i.e. `with create_scope(): do something`.
    """

    def __init__(self, visitor):
        self.visitor = visitor

    def __enter__(self):
        return self.visitor.branch

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.visitor.leave_scope()


class LockScope(object):
    """
    One level in a chain of name locks.

    `names` holds read locks. `write_names` holds write locks, which also count as read locks. Lookups fall back
    to the enclosing scope `prev`.
    """

    def __init__(self, prev=None):
        self.prev = prev
        self.names = set()
        self.write_names = set()
        assert prev is None or isinstance(prev, LockScope)

    def lock(self, name:str):
        if type(name) is str and name != '_' and name != '':
            self.names.add(name)

    def lock_write(self, name:str):
        if type(name) is str and name != '_' and name != '':
            self.write_names.add(name)

    def unlock(self, name:str):
        """Remove any lock on `name` from this level only."""
        self.names.discard(name)
        self.write_names.discard(name)

    def is_locked(self, name:str):
        """True if `name` is read- or write-locked on this or any enclosing level."""
        if name in self.names or name in self.write_names:
            return True
        elif self.prev is not None:
            return self.prev.is_locked(name)
        else:
            return False

    def is_write_locked(self, name:str):
        """True if `name` is write-locked on this or any enclosing level."""
        if name in self.write_names:
            return True
        elif self.prev is not None:
            # Must stay a *write*-lock query all the way up the chain.
            return self.prev.is_write_locked(name)
        else:
            return False


class NameLockContext(object):
    """Context manager that pops the visitor's innermost `LockScope` on exit."""

    def __init__(self, visitor):
        self.visitor = visitor

    def __enter__(self):
        return self.visitor.name_lock

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.visitor.leave_name_lock()


class BranchScopeVisitor(Visitor):
    """
    Base visitor that tracks symbol values across branches (`BranchScope`) and can lock names (`LockScope`)
    so that their values are not resolved.
    """

    def __init__(self, symbols:list):
        super().__init__()
        self.branch = BranchScope(names=symbols)
        self.symbols = symbols
        self.name_lock = LockScope()

    def enter_scope(self, condition:AstNode):
        self.branch.new_branching(condition)

    def leave_scope(self):
        self.branch.end_branching()

    def enter_name_lock(self):
        self.name_lock = LockScope(self.name_lock)

    def leave_name_lock(self):
        self.name_lock = self.name_lock.prev
        assert isinstance(self.name_lock, LockScope)

    def create_scope(self, condition:AstNode):
        self.enter_scope(condition)
        return BranchScopeContext(self)

    def create_lock(self, *names):
        self.enter_name_lock()
        for n in names:
            self.name_lock.lock(n)
        return NameLockContext(self)

    def create_write_lock(self):
        self.enter_name_lock()
        self.lock_all_write()
        return NameLockContext(self)

    def switch_branch(self):
        self.branch.switch_branch()

    def define(self, name:str, value):
        """Record `value` for `name`, unless `name` is write-locked; in that case read-lock it instead."""
        if self.name_lock.is_write_locked(name):
            self.name_lock.lock(name)
        else:
            self.branch[name] = value

    def define_all(self, names:list, values:list, *, vararg:Optional[str]=None):
        """Define `names` pairwise with `values`. The surplus values go, as a vector, to `vararg` if it is given."""
        assert type(names) is list
        assert type(values) is list
        assert vararg is None or type(vararg) is str
        for name, value in zip(names, values):
            if isinstance(name, AstSymbol):
                name = name.name
            if type(name) is str:
                self.define(name, value)
        if vararg is not None:
            self.define(str(vararg), makeVector(values[len(names):]) if len(values) > len(names) else [])

    def resolve(self, name:str):
        """Return the recorded value of `name`, or `None` if the name is locked."""
        if not self.name_lock.is_locked(name):
            return self.branch[name]
        else:
            return None

    def lock_all(self):
        for symbol in self.symbols:
            self.name_lock.lock(symbol.full_name)

    def lock_all_write(self):
        for symbol in self.symbols:
            self.name_lock.lock_write(symbol.full_name)

    def lock_name(self, name:str):
        self.name_lock.lock(name)

    def unlock_name(self, name:str):
        self.name_lock.unlock(name)

    def lock_name_write(self, name:str):
        self.name_lock.lock_write(name)

    def is_constant(self, name:str):
        """True if the first symbol called `name` is read-only or never modified; False if there is none."""
        for sym in self.symbols:
            if sym.name == name:
                return sym.read_only or sym.modify_count == 0
        return False

    def get_usage_count(self, name:str):
        """Usage count of the first symbol called `name`, or `None` if there is none."""
        for sym in self.symbols:
            if sym.name == name:
                return sym.usage_count
        return None

# --------------------------------------------------------------------------------
# /pyppl/ppl_namespaces.py
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 12. Mar 2018, Tobias Kohn
# 12. Mar 2018, Tobias Kohn
#
from importlib import import_module

def namespace_from_module(module_name: str):
    """
    Import `module_name` and return `(module.__name__, public_names)`, where the public names are the
    entries of `dir(module)` that do not start with an underscore.
    """
    module = import_module(module_name)
    # `import_module` raises on failure; the `None` branch is purely defensive.
    if module is None:
        return None, []
    return module.__name__, [name for name in dir(module) if not name.startswith('_')]

# --------------------------------------------------------------------------------
# /pyppl/ppl_symbol_table.py
# --------------------------------------------------------------------------------
#
# This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python.
#
# License: MIT (see LICENSE.txt)
#
# 07. Mar 2018, Tobias Kohn
# 20.
# Mar 2018, Tobias Kohn
#
from .types import ppl_types, ppl_type_inference
from .ppl_ast import *
from .ppl_namespaces import namespace_from_module

class Symbol(object):
    """
    A named entity of the program, with usage and modification counts.

    `full_name` is the name used in the generated code. It equals `predef` when the symbol is predefined or
    imported, and otherwise defaults to `name`. A symbol created as `missing` (used without a definition) starts
    with `modify_count == -1`.
    """

    def __init__(self, name:str, read_only:bool=False, missing:bool=False, predef:Optional[str]=None):
        self.name = name            # type:str
        self.usage_count = 0        # type:int
        self.modify_count = 0       # type:int
        self.read_only = read_only  # type:bool
        self.value_type = None
        if predef is not None:
            self.full_name = predef
            self.is_predef = True
        elif '.' in self.name:
            # Dotted names (e.g. `module.attr`) refer to something defined outside the program.
            self.full_name = self.name
            self.is_predef = True
        else:
            self.full_name = name
            self.is_predef = False
        if missing:
            self.modify_count = -1
        assert type(self.name) is str
        assert type(self.read_only) is bool
        assert predef is None or type(predef) is str

    def use(self):
        self.usage_count += 1

    def modify(self):
        self.modify_count += 1

    def get_type(self):
        return self.value_type

    def set_type(self, tp):
        """Widen the recorded value type by `tp` (via `ppl_types.union`); `None` is ignored."""
        if self.value_type is not None and tp is not None:
            self.value_type = ppl_types.union(self.value_type, tp)
        elif tp is not None:
            self.value_type = tp

    def __repr__(self):
        return "{}[{}/{}{}]".format(self.full_name, self.usage_count, self.modify_count, 'R' if self.read_only else '')


class SymbolTableGenerator(ScopedVisitor):
    """
    Walks the AST and records all symbols, their definitions and usages. After walking the AST, the field `symbols`
    is a list of all symbols used in the program.

    Note that nodes of type `AstSymbol` are modified by walking the tree. In particular, the Symbol-Table-Generator
    sets the field `symbol` of `AstSymbol`-nodes and modifies the `name`-field, so that all names in the program are
    guaranteed to be unique.
    By relying on the fact that all names in the program are unique, we can later on use a flat list of symbol values
    without worrying about correct scoping (the scoping is taken care of here).
    """

    def __init__(self, namespace: Optional[dict]=None):
        super().__init__()
        self.symbols = []
        self.current_lineno = None
        self.type_inferencer = ppl_type_inference.TypeInferencer(self)
        # Maps source names to predefined full names (see `visit_symbol`).
        self.namespace = namespace if namespace is not None else {}

    def get_type(self, node:AstNode):
        """Inferred type of `node`, falling back to `ppl_types.AnyType`."""
        result = self.type_inferencer.visit(node)
        return result if result is not None else ppl_types.AnyType

    def get_full_name(self, name:str):
        """Full name of the first recorded symbol called `name`, or `name` itself if there is none."""
        for sym in self.symbols:
            if sym.name == name:
                return sym.full_name
        return name

    def get_item_type(self, node:AstNode):
        """Item type of `node` if it is a sequence type, otherwise `ppl_types.AnyType`."""
        tp = self.get_type(node)
        if isinstance(tp, ppl_types.SequenceType):
            return tp.item
        else:
            return ppl_types.AnyType

    def get_symbols(self):
        """Return all symbols, after marking those that are never modified as read-only."""
        for symbol in self.symbols:
            if symbol.modify_count == 0:
                symbol.read_only = True
        return self.symbols

    def create_symbol(self, name:str, read_only:bool=False, missing:bool=False):
        symbol = Symbol(name, read_only=read_only, missing=missing)
        self.symbols.append(symbol)
        return symbol

    def g_def(self, name:str, read_only:bool=False, value_type=None):
        """Define (or re-assign) `name` in the global scope; the placeholder `_` is ignored."""
        if name == '_':
            return None
        symbol = self.global_scope.resolve(name)
        if symbol is None:
            symbol = self.create_symbol(name, read_only)
            self.global_scope.define(name, symbol)
        else:
            symbol.modify()
        if symbol is not None and value_type is not None:
            symbol.set_type(value_type)
        return symbol

    def l_def(self, name:str, read_only:bool=False, value_type=None):
        """Define (or re-assign) `name` in the current scope; the placeholder `_` is ignored."""
        if name == '_':
            return None
        symbol = self.resolve(name)
        if symbol is None:
            symbol = self.create_symbol(name, read_only)
            self.define(name, symbol)
        else:
            symbol.modify()
        if symbol is not None and value_type is not None:
            symbol.set_type(value_type)
        return symbol

    def use_symbol(self, name:str):
        """Record a use of `name`. An unknown name becomes a global "missing" symbol."""
        if name == '_':
            return None
        symbol = self.resolve(name)
        if symbol is None:
            symbol = self.create_symbol(name, missing=True)
            self.global_scope.define(name, symbol)
        symbol.use()
        return symbol

    def import_symbol(self, name:str, full_name:str):
        """Bind `name` in the current scope to a read-only predefined symbol `full_name`."""
        symbol = Symbol(name, read_only=True, predef=full_name)
        self.define(name, symbol)
        return symbol

    def visit_node(self, node:AstNode):
        node.visit_children(self)

    def visit_def(self, node: AstDef):
        self.visit(node.value)
        sym = self.resolve(node.name)
        if sym is not None and sym.read_only:
            raise TypeError("[line {}] cannot modify '{}'".format(self.current_lineno, node.name))
        if node.global_context:
            sym = self.g_def(node.name, read_only=False, value_type=self.get_type(node.value))
        else:
            sym = self.l_def(node.name, read_only=False, value_type=self.get_type(node.value))
        if sym is not None:
            node.name = sym.full_name

    def visit_for(self, node: AstFor):
        self.visit(node.source)
        with self.create_scope():
            sym = self.l_def(node.target, read_only=True, value_type=self.get_item_type(node.source))
            if sym is not None:
                node.target = sym.full_name
            self.visit(node.body)

    def visit_function(self, node: AstFunction):
        with self.create_scope():
            for i, param in enumerate(node.parameters):
                sym = self.l_def(param)
                if sym is not None:
                    node.parameters[i] = sym.full_name
            if node.vararg is not None:
                sym = self.l_def(node.vararg)
                if sym is not None:
                    node.vararg = sym.full_name
            self.visit(node.body)
            node.f_locals = set(self.get_full_name(n) for n in self.scope.bindings.keys())

    def visit_import(self, node: AstImport):
        module, names = namespace_from_module(node.module_name)
        if node.imported_names is not None:
            if node.alias is None:
                for name in node.imported_names:
                    self.import_symbol(name, "{}.{}".format(module, name))
            else:
                self.import_symbol(node.alias, "{}.{}".format(module, node.imported_names[0]))

        else:
            m = node.module_name if node.alias is None else node.alias
            self.import_symbol(m, module)
            for name in names:
                self.import_symbol("{}.{}".format(m, name), "{}.{}".format(module, name))

    def visit_let(self, node: AstLet):
        self.visit(node.source)
        with self.create_scope():
            sym = self.l_def(node.target, read_only=True, value_type=self.get_type(node.source))
            if sym is not None:
                node.target = sym.full_name
            self.visit(node.body)

    def visit_list_for(self, node: AstListFor):
        self.visit(node.source)
        with self.create_scope():
            sym = self.l_def(node.target, read_only=True, value_type=self.get_item_type(node.source))
            if sym is not None:
                node.target = sym.full_name
            self.visit(node.test)
            self.visit(node.expr)

    def visit_symbol(self, node: AstSymbol):
        # Names found in the namespace are replaced by their predefined full names.
        if node.original_name in self.namespace:
            node.name = self.namespace[node.original_name]
            node.original_name = node.name
            node.predef = True
        if not node.predef:
            symbol = self.use_symbol(node.original_name)
            node.symbol = symbol
            node.name = symbol.full_name
            if symbol.is_predef:
                node.original_name = node.name

    def visit_while(self, node: AstWhile):
        return self.visit_node(node)


def generate_symbol_table(ast):
    """Walk `ast` with a fresh `SymbolTableGenerator` and return the list of recorded symbols."""
    table_generator = SymbolTableGenerator()
    table_generator.visit(ast)
    result = table_generator.symbols
    return result
-------------------------------------------------------------------------------- /pyppl/tests/factor_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Author: Bradley Gram-Hansen 5 | Time created: 13:55 6 | Date created: 07/01/2019 7 | 8 | License: MIT 9 | ''' 10 | 11 | from pyppl import compile_model 12 | from pyppl.utils.core import create_network_graph, display_graph 13 | 14 | model_rrhmc_clojure=""" 15 | (let [x (sample (uniform -6 6)) 16 | absx (max x (- x)) 17 | A 0.1 18 | z (- (sqrt (* x (* A x))))] 19 | (if (< (- absx 3) 0) 20 | (observe (factor z) 0) 21 | (observe (factor (- z 1)) 0 )) 22 | x) 23 | """ 24 | 25 | model_rrhmc_python= """ 26 | import torch 27 | x = sample(uniform(-6,6)) 28 | absx = max(x, -x) 29 | A = 0.1 30 | z = -torch.sqrt(x*A*x) 31 | if absx-3 < 0: 32 | observe(factor(z),None) 33 | observe(factor(z-1),None) 34 | """ 35 | compiled_clojure = compile_model(model_rrhmc_clojure, language='clojure') 36 | compiled_python = compile_model(model_rrhmc_python, language='python') 37 | 38 | print(compiled_python.code) 39 | -------------------------------------------------------------------------------- /pyppl/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/transforms/__init__.py -------------------------------------------------------------------------------- /pyppl/transforms/ppl_functions_inliner.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 20. Mar 2018, Tobias Kohn 7 | # 21. 
Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..aux.ppl_transform_visitor import TransformVisitor 11 | 12 | 13 | class FunctionInliner(TransformVisitor): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self._let_counter = 0 18 | 19 | def visit_call(self, node: AstCall): 20 | if isinstance(node.function, AstSymbol): 21 | function = self.resolve(node.function.name) 22 | elif isinstance(node.function, AstFunction): 23 | function = node.function 24 | else: 25 | function = None 26 | if isinstance(function, AstFunction): 27 | args = [self.visit(arg) for arg in node.args] 28 | tmp = generate_temp_var() 29 | params = function.parameters[:] 30 | if function.vararg is not None: 31 | params.append(function.vararg) 32 | args = function.order_arguments(args, node.keywords) 33 | arguments = [] 34 | for p, a in zip(params, args): 35 | if p != '_' and not isinstance(a, AstSymbol): 36 | arguments.append(AstDef(p + tmp, a)) 37 | elif not isinstance(a, AstSymbol): 38 | arguments.append(a) 39 | with self.create_scope(tmp): 40 | for p, a in zip(params, args): 41 | if p != '_': 42 | if isinstance(a, AstSymbol): 43 | self.define(p, a) 44 | else: 45 | self.define(p, AstSymbol(p + tmp)) 46 | result = self.visit(function.body) 47 | 48 | if isinstance(result, AstReturn): 49 | return makeBody(arguments, result.value) 50 | # result = result.value 51 | # for p, a in zip(reversed(params), reversed(args)): 52 | # if p != '_' and not isinstance(a, AstSymbol): 53 | # result = AstLet(p + tmp, a, result) 54 | # elif not isinstance(a, AstSymbol): 55 | # result = makeBody(a, result) 56 | # return result 57 | 58 | elif isinstance(result, AstBody) and result.last_is_return: 59 | if len(result) > 1: 60 | return makeBody(arguments, result.items[:-1], result.items[-1].value) 61 | else: 62 | return makeBody(arguments, result.items[-1].value) 63 | 64 | return super().visit_call(node) 65 | 66 | def visit_def(self, node: AstDef): 67 | if isinstance(node.value, AstFunction): 68 | 
self.define(node.name, node.value, globally=node.global_context) 69 | return node 70 | 71 | elif not node.global_context: 72 | tmp = self.scope.name 73 | if tmp is not None and tmp != '': 74 | value = self.visit(node.value) 75 | name = node.name + tmp 76 | self.define(node.name, AstSymbol(name)) 77 | return node.clone(name=name, value=value) 78 | 79 | return super().visit_def(node) 80 | 81 | def visit_let(self, node: AstLet): 82 | self._let_counter += 1 83 | tmp = self.scope.name 84 | if node.target != '_': 85 | if tmp is None: 86 | tmp = '__' 87 | tmp += 'L{}'.format(self._let_counter) 88 | source = self.visit(node.source) 89 | with self.create_scope(tmp): 90 | self.define(node.target, AstSymbol(node.target + tmp)) 91 | body = self.visit(node.body) 92 | return AstLet(node.target + tmp, source, body) 93 | 94 | else: 95 | return super().visit_let(node) 96 | 97 | def visit_symbol(self, node: AstSymbol): 98 | sym = self.resolve(node.name) 99 | if isinstance(sym, AstSymbol): 100 | return sym 101 | else: 102 | return node 103 | -------------------------------------------------------------------------------- /pyppl/transforms/ppl_new_simplifier.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 22. Feb 2018, Tobias Kohn 7 | # 23. 
Mar 2018, Tobias Kohn 8 | # 9 | from ast import copy_location as _cl 10 | from ..ppl_ast_annotators import * 11 | from ..aux.ppl_transform_visitor import TransformVisitor 12 | from ..types import ppl_types, ppl_type_inference 13 | 14 | 15 | class Simplifier(TransformVisitor): 16 | 17 | def __init__(self): 18 | super().__init__() 19 | self.type_inferencer = ppl_type_inference.TypeInferencer(self) 20 | self.bindings = {} 21 | 22 | def get_type(self, node: AstNode): 23 | result = self.type_inferencer.visit(node) 24 | return result 25 | 26 | def define_name(self, name: str, value): 27 | if name not in ('', '_'): 28 | self.bindings[name] = value 29 | 30 | def resolve_name(self, name: str): 31 | return self.bindings.get(name, None) 32 | 33 | 34 | def visit_binary(self, node:AstBinary): 35 | if is_symbol(node.left) and is_symbol(node.right) and \ 36 | node.op in ('-', '/', '//') and node.left.name == node.right.name: 37 | return AstValue(0 if node.op == '-' else 1) 38 | 39 | left = self.visit(node.left) 40 | right = self.visit(node.right) 41 | op = node.op 42 | if is_number(left) and is_number(right): 43 | return AstValue(node.op_function(left.value, right.value)) 44 | 45 | elif op == '+' and is_string(left) and is_string(right): 46 | return _cl(AstValue(left.value + right.value), node) 47 | 48 | elif op == '+' and isinstance(left, AstValueVector) and isinstance(right, AstValueVector): 49 | return _cl(AstValueVector(left.items + right.items), node) 50 | 51 | elif op == '*' and (is_string(left) and is_integer(right)) or (is_integer(left) and is_string(right)): 52 | return _cl(AstValue(left.value * right.value), node) 53 | 54 | elif op == '*' and isinstance(left, AstValueVector) and is_integer(right): 55 | return _cl(AstValueVector(left.items * right.value), node) 56 | 57 | elif op == '*' and is_integer(left) and isinstance(right, AstValueVector): 58 | return _cl(AstValueVector(left.value * right.items), node) 59 | 60 | elif is_number(left): 61 | value = left.value 62 | if 
value == 0: 63 | if op in ('+', '|', '^'): 64 | return right 65 | elif op == '-': 66 | return self.visit(_cl(AstUnary('-', right), node)) 67 | elif op in ('*', '/', '//', '%', '&', '<<', '>>', '**'): 68 | return left 69 | 70 | elif value == 1: 71 | if op == '*': 72 | return right 73 | 74 | elif value == -1: 75 | if op == '*': 76 | return self.visit(_cl(AstUnary('-', right), node)) 77 | 78 | if isinstance(right, AstBinary) and is_number(right.left): 79 | r_value = right.left.value 80 | if op == right.op and op in ('+', '-', '*', '&', '|'): 81 | return self.visit(_cl(AstBinary(AstValue(node.op_function(value, r_value)), 82 | '+' if op == '-' else op, 83 | right.right), node)) 84 | 85 | elif op == right.op and op == '/': 86 | return self.visit(_cl(AstBinary(AstValue(value / r_value), '*', right.right), node)) 87 | 88 | elif op in ['+', '-'] and right.op in ['+', '-']: 89 | return self.visit(_cl(AstBinary(AstValue(node.op_function(value, r_value)), '-', right.right), node)) 90 | 91 | elif is_number(right): 92 | value = right.value 93 | if value == 0: 94 | if op in ('+', '-', '|', '^'): 95 | return left 96 | elif op == '**': 97 | return AstValue(1) 98 | elif op == '*': 99 | return right 100 | 101 | elif value == 1: 102 | if op in ('*', '/', '**'): 103 | return left 104 | 105 | elif value == -1: 106 | if op in ('*', '/'): 107 | return self.visit(_cl(AstUnary('-', right), node)) 108 | 109 | if op == '-': 110 | op = '+' 111 | value = -value 112 | right = AstValue(value) 113 | elif op == '/' and value != 0: 114 | op = '*' 115 | value = 1 / value 116 | right = AstValue(value) 117 | 118 | if isinstance(left, AstBinary) and is_number(left.right): 119 | l_value = left.right.value 120 | if op == left.op and op in ('+', '*', '|', '&'): 121 | return self.visit(_cl(AstBinary(left.left, op, AstValue(node.op_function(l_value, value))), node)) 122 | 123 | elif op == left.op and op == '-': 124 | return self.visit(_cl(AstBinary(left.left, '-', AstValue(l_value + value)), node)) 125 | 
126 | elif op == left.op and op in ('/', '**'): 127 | return self.visit(_cl(AstBinary(left.left, '/', AstValue(l_value * value)), node)) 128 | 129 | elif op in ['+', '-'] and left.op in ('+', '-'): 130 | return self.visit(_cl(AstBinary(left.left, left.op, AstValue(l_value - value)), node)) 131 | 132 | if op in ('<<', '>>') and type(value) is int: 133 | base = 2 if op == '<<' else 0.5 134 | return _cl(AstBinary(left, '*', AstValue(base ** value)), node) 135 | 136 | elif is_boolean(left) and is_boolean(right): 137 | return _cl(AstValue(node.op_function(left.value, right.value)), node) 138 | 139 | elif is_boolean(left): 140 | if op == 'and': 141 | return right if left.value else AstValue(False) 142 | if op == 'or': 143 | return right if not left.value else AstValue(True) 144 | 145 | elif is_boolean(right): 146 | if op == 'and': 147 | return left if right.value else AstValue(False) 148 | if op == 'or': 149 | return left if not right.value else AstValue(True) 150 | 151 | if op == '-' and isinstance(right, AstUnary) and right.op == '-': 152 | return self.visit(_cl(AstBinary(left, '+', right.item), node)) 153 | 154 | if left is node.left and right is node.right: 155 | return node 156 | else: 157 | return _cl(AstBinary(left, op, right), node) 158 | 159 | def visit_call_clojure_core_conj(self, node: AstCall): 160 | args = [self.visit(arg) for arg in node.args] 161 | if is_vector(args[0]): 162 | result = args[0] 163 | for a in args[1:]: 164 | result = result.conj(a) 165 | return result 166 | else: 167 | return node.clone(args=args) 168 | 169 | def visit_call_len(self, node: AstCall): 170 | if node.arg_count == 1: 171 | arg = self.visit(node.args[0]) 172 | if is_vector(arg): 173 | return AstValue(len(arg)) 174 | arg_type = self.get_type(arg) 175 | if isinstance(arg_type, ppl_types.SequenceType): 176 | if arg_type.size is not None: 177 | return AstValue(arg_type.size) 178 | return self.visit_call(node) 179 | 180 | def visit_call_range(self, node:AstCall): 181 | args = 
[self.visit(arg) for arg in node.args] 182 | if 1 <= len(args) <= 2 and all([is_integer(arg) for arg in args]): 183 | if len(args) == 1: 184 | result = range(args[0].value) 185 | else: 186 | result = range(args[0].value, args[1].value) 187 | return _cl(AstValueVector(list(result)), node) 188 | 189 | return self.visit_call(node) 190 | 191 | def visit_compare(self, node:AstCompare): 192 | left = self.visit(node.left) 193 | right = self.visit(node.right) 194 | second_right = self.visit(node.second_right) 195 | 196 | if second_right is None: 197 | if is_unary_neg(left) and is_unary_neg(right): 198 | left, right = right.item, left.item 199 | elif is_unary_neg(left) and is_number(right): 200 | left, right = AstValue(-right.value), left.item 201 | elif is_number(left) and is_unary_neg(right) : 202 | right, left = AstValue(-left.value), right.item 203 | 204 | if is_binary_add_sub(left) and is_number(right): 205 | left = self.visit(AstBinary(left, '-', right)) 206 | right = AstValue(0) 207 | elif is_binary_add_sub(right) and is_number(left): 208 | right = self.visit(AstBinary(right, '-', left)) 209 | left = AstValue(0) 210 | 211 | if is_number(left) and is_number(right): 212 | result = node.op_function(left.value, right.value) 213 | if second_right is None: 214 | return _cl(AstValue(result), node) 215 | 216 | elif is_number(second_right): 217 | result = result and node.op_function_2(right.value, second_right.value) 218 | return _cl(AstValue(result), node) 219 | 220 | if node.op in ('in', 'not in') and is_vector(right) and second_right is None: 221 | op = node.op 222 | for item in right: 223 | if left == item: 224 | return AstValue(True if op == 'in' else False) 225 | return AstValue(False if op == 'in' else True) 226 | 227 | return _cl(AstCompare(left, node.op, right, node.second_op, second_right), node) 228 | 229 | def visit_def(self, node: AstDef): 230 | value = self.visit(node.value) 231 | if isinstance(value, AstSample): 232 | return node.clone(value=value) 233 | 
self.define_name(node.name, value) 234 | return AstBody([]) 235 | 236 | def visit_for(self, node: AstFor): 237 | source = self.visit(node.source) 238 | if is_vector(source): 239 | items = [] 240 | for item in source: 241 | items.append(AstDef(node.target, item)) 242 | items.append(node.body) 243 | return self.visit(makeBody(items)) 244 | else: 245 | src_type = self.get_type(source) 246 | if isinstance(src_type, ppl_types.SequenceType) and src_type.size is not None: 247 | items = [] 248 | for i in range(src_type.size): 249 | items.append(AstDef(node.target, makeSubscript(source, i))) 250 | items.append(node.body) 251 | return self.visit(makeBody(items)) 252 | 253 | raise RuntimeError("cannot unroll the for-loop [line {}]".format(getattr(node, 'lineno', '?'))) 254 | 255 | def visit_if(self, node: AstIf): 256 | test = self.visit(node.test) 257 | if isinstance(test, AstValue): 258 | if test.value is True: 259 | return self.visit(node.if_node) 260 | if test.value is False or test.value is None: 261 | return self.visit(node.else_node) 262 | 263 | if_node = self.visit(node.if_node) 264 | else_node = self.visit(node.else_node) 265 | if is_empty(if_node) and is_empty(else_node): 266 | return test 267 | return node.clone(test=test, if_node=if_node, else_node=else_node) 268 | 269 | def visit_list_for(self, node:AstListFor): 270 | source = self.visit(node.source) 271 | if is_vector(source): 272 | src_len = len(source) 273 | else: 274 | src_type = self.get_type(source) 275 | if isinstance(src_type, ppl_types.SequenceType): 276 | src_len = src_type.size 277 | else: 278 | src_len = None 279 | 280 | if node.test is None: 281 | if node.target == '_' and src_len is not None: 282 | if isinstance(node.expr, AstSample) and node.expr.size is None: 283 | return self.visit(node.expr.clone(size=AstValue(src_len))) 284 | else: 285 | return self.visit(_cl(makeVector([node.expr for _ in range(src_len)]), node)) 286 | 287 | if is_vector(source): 288 | items = [] 289 | for item in source: 290 | 
items.append(AstDef(node.target, item)) 291 | items.append(node.expr) 292 | return self.visit(makeVector(items)) 293 | 294 | elif src_len is not None: 295 | items = [] 296 | for i in range(src_len): 297 | items.append(AstDef(node.target, makeSubscript(source, i))) 298 | items.append(node.expr) 299 | return self.visit(makeVector(items)) 300 | 301 | raise RuntimeError("cannot unroll the for-loop [line {}]".format(getattr(node, 'lineno', '?'))) 302 | 303 | def visit_subscript(self, node: AstSubscript): 304 | base = self.visit(node.base) 305 | index = self.visit(node.index) 306 | if is_vector(base) and is_integer(index): 307 | return base[index.value] 308 | else: 309 | return node.clone(base=base, index=index) 310 | 311 | def visit_symbol(self, node: AstSymbol): 312 | value = self.resolve_name(node.name) 313 | if value is not None: 314 | return value 315 | else: 316 | return node 317 | 318 | def visit_unary(self, node:AstUnary): 319 | op = node.op 320 | if op == '+': 321 | return self.visit(node.item) 322 | 323 | if op == 'not': 324 | item = node.item._visit_expr(self) 325 | if isinstance(item, AstCompare) and item.second_right is None: 326 | return self.visit(_cl(AstCompare(item.left, item.neg_op, item.right), node)) 327 | 328 | if isinstance(item, AstBinary) and item.op in ('and', 'or'): 329 | return self.visit(_cl(AstBinary(AstUnary('not', item.left), 'and' if item.op == 'or' else 'or', 330 | AstUnary('not', item.right)), node)) 331 | 332 | if is_boolean(item): 333 | return _cl(AstValue(not item.value), node) 334 | 335 | if isinstance(node.item, AstUnary) and op == node.item.op: 336 | return self.visit(node.item.item) 337 | 338 | item = self.visit(node.item) 339 | if is_number(item): 340 | if op == '-': 341 | return _cl(AstValue(-item.value), node) 342 | 343 | if item is node.item: 344 | return node 345 | else: 346 | return node.clone(item=item) 347 | 348 | def visit_vector(self, node:AstVector): 349 | items = [self.visit(item) for item in node.items] 350 | if 
len(items) > 0 and all([isinstance(item, AstSample) and item.size is None for item in items]) and \ 351 | all([item.dist == items[0].dist for item in items]): 352 | result = _cl(AstSample(items[0].dist, size=AstValue(len(items))), node) 353 | original_name = getattr(node, 'original_name', None) 354 | if original_name is not None: 355 | result.original_name = original_name 356 | return result 357 | return makeVector(items) 358 | -------------------------------------------------------------------------------- /pyppl/transforms/ppl_raw_simplifier.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 09. Mar 2018, Tobias Kohn 7 | # 23. Mar 2018, Tobias Kohn 8 | # 9 | from ast import copy_location as _cl 10 | from ..ppl_ast import * 11 | from ..ppl_ast_annotators import get_info 12 | from ..ppl_namespaces import namespace_from_module 13 | 14 | 15 | class RawSimplifier(ScopedVisitor): 16 | 17 | def __init__(self, symbols:dict): 18 | super().__init__() 19 | self.imports = set() 20 | for key in symbols: 21 | self.define(key, AstSymbol(symbols[key], predef=True)) 22 | 23 | def split_expr(self, node:AstNode): 24 | if node is None: 25 | return None 26 | elif isinstance(node, AstBody): 27 | if len(node) == 0: 28 | return [], AstValue(None) 29 | elif len(node) == 1: 30 | return [], node[0] 31 | else: 32 | return node.items[:-1], node.items[-1] 33 | elif isinstance(node, AstCall) and not node.is_builtin: 34 | tmp = generate_temp_var() 35 | return [AstDef(tmp, node, global_context=False)], AstSymbol(tmp) 36 | else: 37 | return [], node 38 | 39 | def _visit_expr(self, node:AstNode): 40 | return self.split_expr(self.visit(node)) 41 | 42 | 43 | def visit_attribute(self, node:AstAttribute): 44 | base = self.visit(node.base) 45 | if isinstance(base, AstNamespace): 46 | if 
node.attr in base.bindings: 47 | return base.bindings[node.attr] 48 | else: 49 | return AstSymbol("{}.{}".format(base.name, node.attr)) 50 | elif base is node.base: 51 | return node 52 | else: 53 | return node.clone(base=base) 54 | 55 | def visit_binary(self, node:AstBinary): 56 | l_prefix, left = self._visit_expr(node.left) 57 | r_prefix, right = self._visit_expr(node.right) 58 | prefix = l_prefix + r_prefix 59 | 60 | if left is node.left and right is node.right: 61 | return node 62 | else: 63 | prefix.append(AstBinary(left, node.op, right)) 64 | return _cl(makeBody(prefix), node) 65 | 66 | def visit_body(self, node:AstBody): 67 | items = [self.visit(item) for item in node.items] 68 | 69 | i = len(items)-1 70 | while i >= 0: 71 | item = items[i] 72 | if isinstance(item, AstIf): 73 | if has_return(item.if_node) and not has_return(item.else_node): 74 | items[i] = self.visit(AstIf(item.test, item.if_node, makeBody(item.else_node, items[i+1:]))) 75 | items = items[:i+1] 76 | if has_return(item.else_node) and not has_return(item.if_node): 77 | items[i] = self.visit(AstIf(item.test, makeBody(item.if_node, items[i+1:]).item.else_node)) 78 | items = items[:i+1] 79 | i -= 1 80 | 81 | return _cl(makeBody(items), node) 82 | 83 | def visit_call(self, node: AstCall): 84 | if node.arg_count > 0: 85 | function = self.visit(node.function) 86 | prefix = [] 87 | args = [] 88 | for arg in node.args: 89 | p, a = self._visit_expr(arg) 90 | prefix += p 91 | args.append(a) 92 | return makeBody(prefix, node.clone(function=function, args=args)) 93 | else: 94 | function = self.visit(node.function) 95 | if function is node.function: 96 | return node 97 | else: 98 | return node.clone(function=function) 99 | 100 | def visit_compare(self, node: AstCompare): 101 | l_prefix, left = self._visit_expr(node.left) 102 | r_prefix, right = self._visit_expr(node.right) 103 | prefix = l_prefix + r_prefix 104 | if node.second_right is not None: 105 | s_prefix, sec_right = 
self._visit_expr(node.second_right) 106 | prefix += s_prefix 107 | else: 108 | sec_right = None 109 | 110 | if left is node.left and right is node.right and sec_right is node.second_right: 111 | return node 112 | else: 113 | prefix.append(AstCompare(left, node.op, right, node.second_op, node.second_right)) 114 | return _cl(makeBody(prefix), node) 115 | 116 | def visit_def(self, node: AstDef): 117 | if getattr(node.value, 'original_name', None) is None: 118 | node.value.original_name = node.name 119 | value = self.visit(node.value) 120 | if value is node.value: 121 | return node 122 | else: 123 | return node.clone(value=value) 124 | 125 | def visit_dict(self, node: AstDict): 126 | if len(node) > 0: 127 | prefix = [] 128 | result = {} 129 | for key in node.items: 130 | p, i = self.visit(node.items[key]) 131 | prefix += p 132 | result[key] = i 133 | return _cl(makeBody(prefix, AstDict(result)), node) 134 | else: 135 | return node 136 | 137 | def visit_for(self, node: AstFor): 138 | prefix, source = self._visit_expr(node.source) 139 | body = self.visit(node.body) 140 | target = node.target if node.target in get_info(body).free_vars else '_' 141 | if target is node.target and source is node.source and body is node.body: 142 | return node 143 | else: 144 | return _cl(makeBody(prefix, AstFor(target, source, body, original_target=node.original_target)), node) 145 | 146 | def visit_function(self, node: AstFunction): 147 | with self.create_scope(): 148 | body = self.visit(node.body) 149 | if body is node.body: 150 | return node 151 | else: 152 | return node.clone(body=body) 153 | 154 | def visit_if(self, node: AstIf): 155 | prefix, test = self._visit_expr(node.test) 156 | if_node = self.visit(node.if_node) 157 | else_node = self.visit(node.else_node) 158 | 159 | if isinstance(if_node, AstReturn): 160 | if isinstance(else_node, AstReturn): 161 | return _cl(makeBody(prefix, AstReturn(AstIf(test, if_node.value, else_node.value))), node) 162 | elif is_non_empty_body(else_node) 
and else_node.last_is_return: 163 | tmp = generate_temp_var() 164 | if_node = _cl(AstDef(tmp, if_node.value, global_context=False), if_node) 165 | else_node = _cl(makeBody(else_node.items[-1], AstDef(tmp, else_node.items[-1].value, global_context=False)), else_node) 166 | return _cl(makeBody(prefix, AstIf(test, if_node, else_node), AstReturn(AstSymbol(tmp))), node) 167 | 168 | elif is_non_empty_body(if_node) and if_node.last_is_return: 169 | if isinstance(else_node, AstReturn): 170 | tmp = generate_temp_var() 171 | if_node = _cl(makeBody(if_node.items[:-1], AstDef(tmp, if_node.items[-1].value, global_context=False)), if_node) 172 | else_node = _cl(AstDef(tmp, else_node.value, global_context=False), else_node) 173 | return _cl(makeBody(prefix, AstIf(test, if_node, else_node), AstReturn(AstSymbol(tmp))), node) 174 | elif is_non_empty_body(else_node) and else_node.last_is_return: 175 | tmp = generate_temp_var() 176 | if_node = _cl(makeBody(if_node.items[:-1], AstDef(tmp, if_node.items[-1].value, global_context=False)), if_node) 177 | else_node = _cl(makeBody(else_node.items[:-1], AstDef(tmp, else_node.items[-1].value, global_context=False)), else_node) 178 | return _cl(makeBody(prefix, AstIf(test, if_node, else_node), AstReturn(AstSymbol(tmp))), node) 179 | 180 | if test is node.test and if_node is node.if_node and else_node is node.else_node: 181 | return node 182 | else: 183 | return _cl(makeBody(prefix, AstIf(test, if_node, else_node)), node) 184 | 185 | def visit_import(self, node: AstImport): 186 | module_name, names = namespace_from_module(node.module_name) 187 | if node.imported_names is not None: 188 | if node.alias is None: 189 | for name in node.imported_names: 190 | self.define(name, AstSymbol("{}.{}".format(module_name, name), predef=True)) 191 | else: 192 | self.define(node.alias, AstSymbol("{}.{}".format(module_name, node.imported_names[0]), predef=True)) 193 | 194 | else: 195 | bindings = { key: AstSymbol("{}.{}".format(module_name, key), predef=True) 
for key in names } 196 | ns = AstNamespace(module_name, bindings) 197 | if node.alias is not None: 198 | self.define(node.alias, ns) 199 | else: 200 | self.define(node.module_name, ns) 201 | 202 | if module_name is not None: 203 | self.imports.add(module_name) 204 | else: 205 | self.imports.add(node.module_name) 206 | return node # AstBody([]) # _cl(AstImport(module_name), node) 207 | 208 | def visit_let(self, node: AstLet): 209 | prefix, source = self._visit_expr(node.source) 210 | body = self.visit(node.body) 211 | if source is node.source and body is node.body: 212 | return node 213 | else: 214 | return _cl(makeBody(prefix, AstLet(node.target, source, body, original_target=node.original_target)), node) 215 | 216 | def visit_list_for(self, node: AstListFor): 217 | prefix, source = self._visit_expr(node.source) 218 | expr = self.visit(node.expr) 219 | target = node.target if node.target in get_info(expr).free_vars else '_' 220 | if target is node.target and source is node.source and expr is node.expr: 221 | return node 222 | else: 223 | return makeBody(prefix, node.clone(target=target, source=source, expr=expr)) 224 | 225 | def visit_observe(self, node: AstObserve): 226 | d_prefix, dist = self._visit_expr(node.dist) 227 | v_prefix, value = self._visit_expr(node.value) 228 | # keep it from being over-zealous 229 | if len(d_prefix) == 1 and isinstance(d_prefix[0], AstDef) and isinstance(dist, AstSymbol) and \ 230 | d_prefix[0].name == dist.name and isinstance(d_prefix[0].value, AstCall): 231 | d_prefix, dist = [], d_prefix[0].value 232 | if dist is node.dist and value is node.value: 233 | return node 234 | else: 235 | prefix = d_prefix + v_prefix 236 | return makeBody(prefix, node.clone(dist=dist, value=value)) 237 | 238 | def visit_return(self, node: AstReturn): 239 | prefix, value = self._visit_expr(node.value) 240 | if value is node.value: 241 | return node 242 | else: 243 | return _cl(makeBody(prefix, AstReturn(value)), node) 244 | 245 | def visit_sample(self, 
node: AstSample): 246 | prefix, dist = self._visit_expr(node.dist) 247 | # keep it from being over zealous 248 | 249 | if len(prefix) == 1 and isinstance(prefix[0], AstDef) and isinstance(dist, AstSymbol) and \ 250 | prefix[0].name == dist.name and isinstance(prefix[0].value, AstCall): 251 | prefix, dist = [], prefix[0].value 252 | if node.size is not None: 253 | s_prefix, size = self._visit_expr(node.size) 254 | prefix += s_prefix 255 | else: 256 | size = None 257 | if dist is node.dist and size is node.size: 258 | return node 259 | else: 260 | return makeBody(prefix, node.clone(dist=dist, size=size)) 261 | 262 | def visit_slice(self, node: AstSlice): 263 | prefix, base = self._visit_expr(node.base) 264 | a_prefix, a = self._visit_expr(node.start) 265 | b_prefix, b = self._visit_expr(node.stop) 266 | prefix += a_prefix 267 | prefix += b_prefix 268 | if base is node.base and a is node.start and b is node.stop: 269 | return node 270 | else: 271 | return _cl(makeBody(prefix, AstSlice(base, a, b)), node) 272 | 273 | def visit_subscript(self, node: AstSubscript): 274 | base_prefix, base = self._visit_expr(node.base) 275 | index_prefix, index = self._visit_expr(node.index) 276 | if base is node.base and index is node.index: 277 | return node 278 | else: 279 | prefix = base_prefix + index_prefix 280 | return _cl(makeBody(prefix, makeSubscript(base, index)), node) 281 | 282 | def visit_symbol(self, node: AstSymbol): 283 | symbol = self.resolve(node.name) 284 | if symbol is not None: 285 | return symbol 286 | return node 287 | 288 | def visit_unary(self, node: AstUnary): 289 | # when applying an unary operator twice, it usually cancels, so we can get rid of it entirely 290 | if isinstance(node.item, AstUnary) and node.op == node.item.op: 291 | if node.op in ('not', '+', '-'): 292 | return self.visit(node.item.item) 293 | prefix, item = self._visit_expr(node.item) 294 | if item is node.item: 295 | return node 296 | else: 297 | prefix.append(AstUnary(node.op, item)) 298 | 
return _cl(makeBody(prefix), node) 299 | 300 | def visit_vector(self, node: AstVector): 301 | original_name = getattr(node, 'original_name', None) 302 | if original_name is not None: 303 | i = 0 304 | for item in node.items: 305 | if getattr(item, 'original_name', None) is None: 306 | item.original_name = "{}[{}]".format(original_name, i) 307 | i += 1 308 | prefix = [] 309 | items = [] 310 | for item in node.items: 311 | p, i = self._visit_expr(item) 312 | prefix += p 313 | items.append(i) 314 | return _cl(makeBody(prefix, makeVector(items)), node) 315 | 316 | def visit_while(self, node: AstWhile): 317 | return self.visit_node(node) 318 | -------------------------------------------------------------------------------- /pyppl/transforms/ppl_static_assignments.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 20. Mar 2018, Tobias Kohn 7 | # 21. 
Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..aux.ppl_transform_visitor import TransformVisitor 11 | from ast import copy_location as _cl 12 | 13 | 14 | class Symbol(object): 15 | 16 | def __init__(self, name): 17 | self.name = name 18 | self.counter = 0 19 | 20 | def get_new_instance(self): 21 | self.counter += 1 22 | return self.get_current_instance() 23 | 24 | def get_current_instance(self): 25 | if self.counter == 1: 26 | return self.name 27 | else: 28 | return self.name + str(self.counter) 29 | 30 | 31 | class SymbolScope(object): 32 | 33 | def __init__(self, prev, items=None, is_loop:bool=False): 34 | self.prev = prev 35 | self.bindings = {} 36 | self.items = items 37 | self.is_loop = is_loop 38 | 39 | def get_current_symbol(self, name: str): 40 | if name in self.bindings: 41 | return self.bindings[name] 42 | elif self.prev is not None: 43 | return self.prev.get_current_symbol(name) 44 | else: 45 | return name 46 | 47 | def has_current_symbol(self, name: str): 48 | if name in self.bindings: 49 | return True 50 | elif self.prev is not None: 51 | return self.prev.has_current_symbol(name) 52 | else: 53 | return False 54 | 55 | def set_current_symbol(self, name: str, instance_name: str): 56 | self.bindings[name] = instance_name 57 | 58 | def append(self, item): 59 | if self.items is not None: 60 | self.items.append(item) 61 | return True 62 | else: 63 | return False 64 | 65 | 66 | class StaticAssignments(TransformVisitor): 67 | 68 | def __init__(self): 69 | super().__init__() 70 | self.symbols = {} 71 | self.symbol_scope = SymbolScope(None) 72 | 73 | def new_symbol_instance(self, name: str): 74 | if name not in self.symbols: 75 | self.symbols[name] = Symbol(name) 76 | result = self.symbols[name].get_new_instance() 77 | self.symbol_scope.set_current_symbol(name, result) 78 | return result 79 | 80 | def access_symbol(self, name: str): 81 | result = self.symbol_scope.get_current_symbol(name) 82 | return result 83 | 84 | def has_symbol(self, 
name: str): 85 | return self.symbol_scope.has_current_symbol(name) 86 | 87 | def begin_scope(self, items=None, is_loop:bool=False): 88 | self.symbol_scope = SymbolScope(self.symbol_scope, items, is_loop) 89 | 90 | def end_scope(self): 91 | scope = self.symbol_scope 92 | self.symbol_scope = scope.prev 93 | return scope.bindings 94 | 95 | def append_to_body(self, item: AstNode): 96 | return self.symbol_scope.append(item) 97 | 98 | def is_loop_scope(self): 99 | return self.symbol_scope.is_loop 100 | 101 | def split_body(self, node: AstNode): 102 | if isinstance(node, AstBody): 103 | if len(node) == 0: 104 | return None, AstValue(None) 105 | elif len(node) == 1: 106 | return None, node[0] 107 | else: 108 | return node.items[:-1], node.items[-1] 109 | else: 110 | return None, node 111 | 112 | def visit_and_split(self, node: AstNode): 113 | return self.split_body(self.visit(node)) 114 | 115 | def visit_in_scope(self, node: AstNode, is_loop:bool=False): 116 | items = [] 117 | self.begin_scope(items, is_loop) 118 | if isinstance(node, AstBody): 119 | for item in node.items: 120 | items.append(self.visit(item)) 121 | else: 122 | items.append(self.visit(node)) 123 | result = _cl(makeBody(items), node) 124 | symbols = self.end_scope() 125 | return symbols, result 126 | 127 | 128 | def visit_attribute(self, node:AstAttribute): 129 | prefix, base = self.visit_and_split(node.base) 130 | if prefix is not None: 131 | return makeBody(prefix, node.clone(base=base)) 132 | if base is node.base: 133 | return node 134 | else: 135 | return node.clone(base=base) 136 | 137 | def visit_binary(self, node:AstBinary): 138 | prefix_l, left = self.visit_and_split(node.left) 139 | prefix_r, right = self.visit_and_split(node.right) 140 | if prefix_l is not None and prefix_r is not None: 141 | prefix = prefix_l + prefix_r 142 | return makeBody(prefix, node.clone(left=left, right=right)) 143 | elif prefix_l is not None: 144 | return makeBody(prefix_l, node.clone(left=left, right=right)) 145 | elif 
prefix_r is not None: 146 | return makeBody(prefix_r, node.clone(left=left, right=right)) 147 | 148 | if left is node.left and right is node.right: 149 | return node 150 | else: 151 | return node.clone(left=left, right=right) 152 | 153 | def _visit_call(self, node: AstCall): 154 | prefix = [] 155 | args = [] 156 | for item in node.args: 157 | p, a = self.visit_and_split(item) 158 | if p is not None: 159 | prefix += p 160 | args.append(a) 161 | 162 | if len(prefix) > 0: 163 | return makeBody(prefix, node.clone(args=args)) 164 | else: 165 | return node.clone(args=args) 166 | 167 | def visit_call(self, node: AstCall): 168 | tmp = generate_temp_var() 169 | result = AstDef(tmp, self._visit_call(node)) 170 | if self.append_to_body(result): 171 | return AstSymbol(tmp) 172 | else: 173 | return makeBody(result, AstSymbol(tmp)) 174 | 175 | def visit_call_range(self, node: AstCall): 176 | if node.arg_count == 1 and is_integer(node.args[0]): 177 | return makeVector(list(range(node.args[0].value))) 178 | else: 179 | return self.visit_call(node) 180 | 181 | def visit_compare(self, node: AstCompare): 182 | prefix_l, left = self.visit_and_split(node.left) 183 | prefix_r, right = self.visit_and_split(node.right) 184 | if node.second_right is not None: 185 | prefix_s, second_right = self.visit_and_split(node.second_right) 186 | else: 187 | prefix_s, second_right = None, None 188 | 189 | if prefix_l is not None or prefix_r is not None or prefix_s is not None: 190 | prefix = prefix_l if prefix_l is not None else [] 191 | if prefix_r is not None: prefix += prefix_r 192 | if prefix_s is not None: prefix += prefix_s 193 | return makeBody(prefix, node.clone(left=left, right=right, second_right=second_right)) 194 | 195 | if left is node.left and right is node.right and second_right is node.second_right: 196 | return node 197 | else: 198 | return node.clone(left=left, right=right, second_right=second_right) 199 | 200 | def visit_def(self, node: AstDef): 201 | if isinstance(node.value, 
AstObserve): 202 | # We can never assign an observe to something! 203 | result = [self.visit(node.value), 204 | self.visit(node.clone(value=AstValue(None)))] 205 | return makeBody(result) 206 | 207 | elif isinstance(node.value, AstSample): 208 | # We need to handle this as a special case in order to avoid an infinite loop 209 | value = self._visit_sample(node.value) 210 | name = self.new_symbol_instance(node.name) 211 | return node.clone(name=name, value=value) 212 | 213 | elif isinstance(node.value, AstCall): 214 | result = self._visit_call(node.value) 215 | name = self.new_symbol_instance(node.name) 216 | return node.clone(name=name, value=result) 217 | 218 | prefix, value = self.visit_and_split(node.value) 219 | if prefix is not None: 220 | return makeBody(prefix, self.visit(node.clone(value=value))) 221 | 222 | elif isinstance(value, AstFunction): 223 | return AstBody([]) 224 | 225 | name = self.new_symbol_instance(node.name) 226 | if name is node.name and value is node.value: 227 | return node 228 | else: 229 | return node.clone(name=name, value=value) 230 | 231 | def visit_dict(self, node: AstDict): 232 | prefix = [] 233 | items = {} 234 | for key in node.items: 235 | item = node.items[key] 236 | p, i = self.visit_and_split(item) 237 | if p is not None: 238 | prefix += p 239 | items[key] = i 240 | if len(prefix) > 0: 241 | return makeBody(prefix, AstDict(items)) 242 | else: 243 | return AstDict(items) 244 | 245 | def visit_for(self, node: AstFor): 246 | prefix, source = self.visit_and_split(node.source) 247 | if prefix is not None: 248 | return self.visit(makeBody(prefix, node.clone(source=source))) 249 | 250 | if is_vector(source): 251 | result = [] 252 | for item in source: 253 | result.append(AstLet(node.target, item, node.body)) 254 | return self.visit(makeBody(result)) 255 | 256 | _, body = self.visit_in_scope(node.body, is_loop=True) 257 | if source is node.source and body is node.body: 258 | return node 259 | else: 260 | return 
node.clone(source=source, body=body) 261 | 262 | def visit_if(self, node: AstIf): 263 | 264 | def phi(key, cond, left, right): 265 | return AstDef(key, AstIf(cond, AstSymbol(left), AstSymbol(right))) 266 | 267 | prefix, test = self.visit_and_split(node.test) 268 | if prefix is not None: 269 | return makeBody(prefix, self.visit(node.clone(test=test))) 270 | 271 | if isinstance(test, AstValue): 272 | if test.value is True: 273 | return self.visit(node.if_node) 274 | elif test.value is False or test.value is None: 275 | return self.visit(node.else_node) 276 | 277 | if_symbols, if_node = self.visit_in_scope(node.if_node) 278 | else_symbols, else_node = self.visit_in_scope(node.else_node) 279 | keys = set.union(set(if_symbols.keys()), set(else_symbols.keys())) 280 | if len(keys) == 0: 281 | if test is node.test and if_node is node.if_node and else_node is node.else_node: 282 | return node 283 | else: 284 | return node.clone(test=test, if_node=if_node, else_node=else_node) 285 | else: 286 | result = [] 287 | if not isinstance(test, AstSymbol): 288 | tmp = generate_temp_var() 289 | result.append(AstDef(tmp, test)) 290 | test = AstSymbol(tmp) 291 | result.append(node.clone(test=test, if_node=if_node, else_node=else_node)) 292 | for key in keys: 293 | if key in if_symbols and key in else_symbols: 294 | result.append(phi(self.new_symbol_instance(key), test, if_symbols[key], else_symbols[key])) 295 | elif not self.has_symbol(key): 296 | pass 297 | elif key in if_symbols: 298 | result.append(phi(self.new_symbol_instance(key), test, if_symbols[key], self.access_symbol(key))) 299 | elif key in else_symbols: 300 | result.append(phi(self.new_symbol_instance(key), test, self.access_symbol(key), else_symbols[key])) 301 | return makeBody(result) 302 | 303 | def visit_let(self, node: AstLet): 304 | if node.target == '_': 305 | result = makeBody(node.source, node.body) 306 | else: 307 | result = makeBody(AstDef(node.target, node.source), node.body) 308 | return self.visit(result) 309 | 
310 | def visit_list_for(self, node: AstListFor): 311 | prefix, source = self.visit_and_split(node.source) 312 | if prefix is not None: 313 | return makeBody(prefix, self.visit(node.clone(source=source))) 314 | 315 | if is_vector(source): 316 | result = [] 317 | for item in source: 318 | result.append(AstLet(node.target, item, node.expr)) 319 | return self.visit(makeVector(result)) 320 | 321 | if isinstance(node.expr, AstSample): 322 | expr = self._visit_sample(node.expr) 323 | elif isinstance(node.expr, AstCall): 324 | expr = self._visit_call(node.expr) 325 | else: 326 | expr = self.visit(node.expr) 327 | 328 | if source is node.source and expr is node.expr: 329 | return node 330 | else: 331 | return node.clone(source=source, expr=expr) 332 | 333 | def visit_observe(self, node: AstObserve): 334 | prefix, dist = self.visit_and_split(node.dist) 335 | if prefix is not None: 336 | return makeBody(prefix, self.visit(node.clone(dist=dist))) 337 | prefix, value = self.visit_and_split(node.value) 338 | if prefix is not None: 339 | return makeBody(prefix, node.clone(value=value)) 340 | if dist is node.dist and value is node.value: 341 | return node 342 | else: 343 | return node.clone(dist=dist, value=value) 344 | 345 | def _visit_sample(self, node: AstSample): 346 | prefix, dist = self.visit_and_split(node.dist) 347 | if prefix is not None: 348 | return makeBody(prefix, node.clone(dist=dist)) 349 | if dist is node.dist: 350 | return node 351 | else: 352 | return node.clone(dist=dist) 353 | 354 | def visit_sample(self, node: AstSample): 355 | tmp = generate_temp_var() 356 | assign = AstDef(tmp, self._visit_sample(node)) 357 | if self.append_to_body(assign): 358 | return AstSymbol(tmp) 359 | else: 360 | return makeBody([assign, AstSymbol(tmp)]) 361 | 362 | def visit_symbol(self, node: AstSymbol): 363 | name = self.access_symbol(node.name) 364 | if name != node.name: 365 | return node.clone(name=name) 366 | else: 367 | return node 368 | 369 | def visit_unary(self, node: 
AstUnary): 370 | prefix, item = self.visit_and_split(node.item) 371 | if prefix is not None: 372 | return makeBody(prefix, node.clone(item=item)) 373 | if item is node.item: 374 | return node 375 | else: 376 | return node.clone(item=item) 377 | 378 | def visit_vector(self, node: AstVector): 379 | prefix = [] 380 | items = [] 381 | for item in node.items: 382 | p, i = self.visit_and_split(item) 383 | if p is not None: 384 | prefix += p 385 | items.append(i) 386 | if len(prefix) > 0: 387 | return makeBody(prefix, makeVector(items)) 388 | else: 389 | return makeVector(items) 390 | 391 | def visit_while(self, node: AstWhile): 392 | prefix, test = self.visit_and_split(node.test) 393 | if prefix is not None: 394 | return makeBody(prefix, self.visit(node.clone(test=test))) 395 | 396 | _, body = self.visit_in_scope(node.body, is_loop=True) 397 | if test is node.test and body is node.body: 398 | return node 399 | else: 400 | return node.clone(test=test, body=body) 401 | -------------------------------------------------------------------------------- /pyppl/transforms/ppl_symbol_simplifier.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 20. Mar 2018, Tobias Kohn 7 | # 21. 
Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from ..aux.ppl_transform_visitor import TransformVisitor 11 | 12 | 13 | class SymbolSimplifier(TransformVisitor): 14 | 15 | def __init__(self): 16 | super().__init__() 17 | self.names_map = {} 18 | self.name_count = {} 19 | 20 | def simplify_symbol(self, name: str): 21 | if name in self.names_map: 22 | return self.names_map[name] 23 | elif name.startswith('__'): 24 | if '____' in name: 25 | short = name[:name.index('____')+2] 26 | if short not in self.name_count: 27 | self.name_count[short] = 1 28 | else: 29 | self.name_count[short] += 1 30 | short += "_{}".format(self.name_count[short]) 31 | self.names_map[name] = short 32 | return short 33 | else: 34 | return name 35 | elif '__' in name: 36 | short = name[:name.index('__')] 37 | if short not in self.name_count: 38 | self.name_count[short] = 1 39 | else: 40 | self.name_count[short] += 1 41 | short += "_{}".format(self.name_count[short]) 42 | self.names_map[name] = short 43 | return short 44 | else: 45 | self.names_map[name] = name 46 | if name not in self.name_count: 47 | self.name_count[name] = 1 48 | else: 49 | self.name_count[name] += 1 50 | return name 51 | 52 | def visit_def(self, node: AstDef): 53 | value = self.visit(node.value) 54 | name = self.simplify_symbol(node.name) 55 | if name != node.name or value is not node.value: 56 | return node.clone(name=name, value=value) 57 | else: 58 | return node 59 | 60 | def visit_let(self, node: AstLet): 61 | source = self.visit(node.source) 62 | name = self.simplify_symbol(node.target) 63 | body = self.visit(node.body) 64 | if name == node.target and source is node.source and body is node.body: 65 | return node 66 | else: 67 | return node.clone(target=name, source=source, body=body) 68 | 69 | def visit_symbol(self, node: AstSymbol): 70 | name = self.simplify_symbol(node.name) 71 | if name != node.name: 72 | return node.clone(name=name) 73 | else: 74 | return node 75 | 
-------------------------------------------------------------------------------- /pyppl/transforms/ppl_var_substitutor.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 15. Mar 2018, Tobias Kohn 7 | # 20. Mar 2018, Tobias Kohn 8 | # 9 | from ast import copy_location as _cl 10 | from ..ppl_ast import * 11 | 12 | class VarSubstitutor(Visitor): 13 | 14 | def __init__(self, bindings:dict): 15 | self.bindings = bindings 16 | assert type(self.bindings) is dict 17 | assert all([isinstance(self.bindings[key], AstNode) for key in self.bindings]) 18 | 19 | def parse_items(self, items:list): 20 | use_original = True 21 | result = [] 22 | for item in items: 23 | n_item = self.visit(item) 24 | use_original = use_original and n_item is item 25 | result.append(n_item) 26 | if use_original: 27 | return items 28 | else: 29 | return result 30 | 31 | def visit_node(self, node: AstNode): 32 | return node 33 | 34 | def visit_attribute(self, node:AstAttribute): 35 | base = self.visit(node.base) 36 | if base is node.base: 37 | return node 38 | else: 39 | return _cl(AstAttribute(base, node.attr), node) 40 | 41 | def visit_binary(self, node:AstBinary): 42 | left = self.visit(node.left) 43 | right = self.visit(node.right) 44 | if left is node.left and right is node.right: 45 | return node 46 | else: 47 | return _cl(AstBinary(left, node.op, right), node) 48 | 49 | def visit_body(self, node:AstBody): 50 | items = self.parse_items(node.items) 51 | if items is node.items: 52 | return node 53 | else: 54 | return _cl(makeBody(items), node) 55 | 56 | def visit_call(self, node: AstCall): 57 | args = self.parse_items(node.args) 58 | if args is node.args: 59 | return node 60 | else: 61 | return node.clone(args=args) 62 | 63 | def visit_compare(self, node: AstCompare): 64 | left = 
self.visit(node.left) 65 | right = self.visit(node.right) 66 | if left is node.left and right is node.right: 67 | return node 68 | else: 69 | return _cl(AstCompare(left, node.op, right), node) 70 | 71 | def visit_def(self, node: AstDef): 72 | value = self.visit(node.value) 73 | if value is node.value: 74 | return node 75 | else: 76 | return _cl(AstDef(node.name, value), node) 77 | 78 | def visit_dict(self, node: AstDict): 79 | return self.visit_node(node) 80 | 81 | def visit_for(self, node: AstFor): 82 | return self.visit_node(node) 83 | 84 | def visit_function(self, node: AstFunction): 85 | return self.visit_node(node) 86 | 87 | def visit_if(self, node: AstIf): 88 | return self.visit_node(node) 89 | 90 | def visit_let(self, node: AstLet): 91 | return self.visit_node(node) 92 | 93 | def visit_list_for(self, node: AstListFor): 94 | return self.visit_node(node) 95 | 96 | def visit_observe(self, node: AstObserve): 97 | dist = self.visit(node.dist) 98 | value = self.visit(node.value) 99 | if dist is node.dist and value is node.value: 100 | return node 101 | else: 102 | return _cl(AstObserve(dist, value), node) 103 | 104 | def visit_return(self, node: AstReturn): 105 | value = self.visit(node.value) 106 | if value is node.value: 107 | return node 108 | else: 109 | return _cl(AstReturn(value), node) 110 | 111 | def visit_sample(self, node: AstSample): 112 | dist = self.visit(node.dist) 113 | if dist is node.dist: 114 | return node 115 | else: 116 | return _cl(AstSample(dist), node) 117 | 118 | def visit_slice(self, node: AstSlice): 119 | return self.visit_node(node) 120 | 121 | def visit_subscript(self, node: AstSubscript): 122 | return self.visit_node(node) 123 | 124 | def visit_symbol(self, node: AstSymbol): 125 | name = node.name 126 | if name in self.bindings: 127 | return self.visit(self.bindings[name]) 128 | else: 129 | return node 130 | 131 | def visit_unary(self, node: AstUnary): 132 | item = self.visit(node.item) 133 | if item is node.item: 134 | return node 135 
| else: 136 | return _cl(AstUnary(node.op, item), node) 137 | 138 | def visit_vector(self, node: AstVector): 139 | items = self.parse_items(node.items) 140 | if items is node.items: 141 | return node 142 | else: 143 | return _cl(makeVector(items), node) 144 | 145 | def visit_while(self, node: AstWhile): 146 | test = self.visit(node.test) 147 | body = self.visit(node.body) 148 | if test is node.test and body is node.body: 149 | return node 150 | else: 151 | return _cl(AstWhile(test, body), node) 152 | -------------------------------------------------------------------------------- /pyppl/types/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bayesianbrad/PyLFPPL/5bc160cb00a3d7e9aa3910367ad88cac6b138d99/pyppl/types/__init__.py -------------------------------------------------------------------------------- /pyppl/types/ppl_type_inference.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 19. Feb 2018, Tobias Kohn 7 | # 22. 
Mar 2018, Tobias Kohn 8 | # 9 | from ..ppl_ast import * 10 | from .ppl_types import * 11 | 12 | class TypeInferencer(Visitor): 13 | 14 | __visit_children_first__ = True 15 | 16 | def __init__(self, parent): 17 | super().__init__() 18 | self.parent = parent 19 | 20 | def define(self, name:str, value): 21 | if name is None or name == '_': 22 | return 23 | if self.parent is not None and value is not None: 24 | result = self.parent.resolve(name) 25 | if hasattr(result, 'set_type'): 26 | result.set_type(value) 27 | 28 | def resolve(self, name:str): 29 | if self.parent is not None: 30 | result = self.parent.resolve(name) 31 | if isinstance(result, Type): 32 | return result 33 | elif isinstance(result, AstNode): 34 | return self.visit(result) 35 | return None 36 | 37 | def get_value_of(self, node: AstNode): 38 | if isinstance(node, AstValue): 39 | return node.value 40 | elif is_call(node, 'len') and node.arg_count == 1: 41 | result = self.visit(node.args[0]) 42 | if isinstance(result, SequenceType): 43 | return result.size 44 | return None 45 | 46 | 47 | def visit_binary(self, node: AstBinary): 48 | left = self.visit(node.left) 49 | right = self.visit(node.right) 50 | return node.op_function(left, right) 51 | 52 | def visit_body(self, node:AstBody): 53 | if node.is_empty: 54 | return NullType 55 | else: 56 | return node.items[-1].get_type() 57 | 58 | def visit_call(self, node: AstCall): 59 | return AnyType 60 | 61 | def visit_call_len(self, _): 62 | return Integer 63 | 64 | def visit_call_range(self, node: AstCall): 65 | if node.arg_count == 2: 66 | a = self.get_value_of(node.args[0]) 67 | b = self.get_value_of(node.args[1]) 68 | if a is not None and b is not None: 69 | return List[Integer][b-a] 70 | elif node.arg_count == 1: 71 | a = self.get_value_of(node.args[0]) 72 | if a is not None: 73 | return List[Integer][a] 74 | return List[Integer] 75 | 76 | def visit_call_torch_function(self, node: AstCall): 77 | name = node.function_name 78 | args = [self.visit(arg) for arg 
in node.args] 79 | f_name = name[6:] if name.startswith('torch.') else name 80 | if name.startswith('torch.cuda.'): 81 | f_name = f_name[5:] 82 | if node.arg_count == 1: 83 | if f_name in ('from_numpy',): 84 | return makeTensor(args[0]) 85 | elif f_name in ('ones', 'zeros'): 86 | return Tensor[AnyType, self.get_value_of(args[0])] 87 | elif f_name in ('ones_like', 'zeros_like', 'empty_like'): 88 | return args[0] 89 | elif f_name in ('arange',): 90 | return Tensor[Integer, self.get_value_of(args[0])] 91 | elif f_name in ('tensor', 'Tensor'): 92 | return makeTensor(args[0]) 93 | elif f_name in ('FloatTensor', 'IntTensor', 'DoubleTensor', 'HalfTensor', 94 | 'ByteTensor', 'ShortTensor', 'LongTensor'): 95 | return makeTensor(args[0], f_name) 96 | elif f_name in ('abs', 'acos', 'asin', 'atan', 'ceil', 'cos', 'cosh', 'erf', 'exp', 'expm1', 'floor', 97 | 'frac', 'log', 'log1p', 'neg', 'reciprocal', 'round', 'rsqrt', 'sigmoid', 'sign', 98 | 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'trunc'): 99 | return args[0] 100 | elif f_name in ('diag',): 101 | if isinstance(args[0], SequenceType): 102 | return Tensor[makeTensor(args[0]), args[0].size] 103 | if node.arg_count > 0: 104 | if f_name in ('eye',): 105 | d1 = self.get_value_of(args[0]) 106 | d2 = self.get_value_of(args[1]) if node.arg_count == 2 else d1 107 | return Tensor[Tensor[Float, d2], d1] 108 | elif f_name in ('arange', 'range'): 109 | pos = node.get_position_of_arg('step', 2) 110 | start = self.get_value_of(args[0]) 111 | stop = self.get_value_of(args[1]) 112 | if start is not None and stop is not None: 113 | count = stop - start 114 | if node.arg_count > pos: 115 | steps = self.get_value_of(args[pos]) 116 | if steps is not None and steps > 0: 117 | return Tensor[Integer, count / steps] 118 | else: 119 | return Tensor[Integer] 120 | else: 121 | return Tensor[Integer, count] 122 | else: 123 | return Tensor 124 | elif f_name in ('linspace', 'logspace'): 125 | pos = node.get_position_of_arg('steps', 2) 126 | if node.arg_count 
> pos: 127 | return Tensor[self.get_value_of(args[pos])] 128 | else: 129 | return Tensor[100] 130 | elif f_name in ('eq', 'ge', 'gt', 'le', 'lt', 'ne', 131 | 'add', 'atan2', 'clamp', 'div', 'fmod', 'lerp', 'mul', 'pow', 'remainder'): 132 | return args[0] 133 | elif f_name in ('equal', 'isnan'): 134 | return Boolean 135 | return Tensor 136 | 137 | def visit_compare(self, _): 138 | return Boolean 139 | 140 | def visit_def(self, node: AstDef): 141 | result = self.visit(node.value) 142 | self.define(node.name, result) 143 | return result 144 | 145 | def visit_dict(self, node: AstDict): 146 | base = union(*[self.visit(item) for item in node.items.values()]) 147 | return Dict[base][len(node.items)] 148 | 149 | def visit_for(self, node: AstFor): 150 | source = self.visit(node.source) 151 | if isinstance(source, SequenceType): 152 | self.define(node.target, source.item) 153 | return self.visit(node.body) 154 | else: 155 | return AnyType 156 | 157 | def visit_function(self, node: AstFunction): 158 | return Function 159 | 160 | def visit_if(self, node: AstIf): 161 | return union(node.if_node.get_type(), node.else_node.get_type()) 162 | 163 | def visit_let(self, node: AstLet): 164 | self.define(node.target, self.visit(node.source)) 165 | return node.body.get_type() 166 | 167 | def visit_list_for(self, node: AstListFor): 168 | source = self.visit(node.source) 169 | if isinstance(source, SequenceType): 170 | self.define(node.target, source.item) 171 | result = self.visit(node.expr) 172 | return List[result][source.size] 173 | else: 174 | return AnyType 175 | 176 | def visit_multi_slice(self, node: AstMultiSlice): 177 | base = self.visit(node.base) 178 | if base in Tensor: 179 | return Tensor 180 | elif base is Array: 181 | return Array 182 | else: 183 | return AnyType 184 | 185 | def visit_sample(self, node: AstSample): 186 | return Numeric 187 | 188 | def visit_slice(self, node: AstSlice): 189 | base = self.visit(node.base) 190 | if isinstance(base, SequenceType): 191 | return 
base.slice(node.start_as_int, node.stop_as_int) 192 | else: 193 | return AnyType 194 | 195 | def visit_subscript(self, node: AstSubscript): 196 | base = self.visit(node.base) 197 | if isinstance(base, SequenceType): 198 | return base.item_type 199 | else: 200 | return AnyType 201 | 202 | def visit_symbol(self, node: AstSymbol): 203 | result = self.resolve(node.name) 204 | return result if result is not None else AnyType 205 | 206 | def visit_unary(self, node: AstUnary): 207 | if node.op == 'not': 208 | return Boolean 209 | else: 210 | return self.visit(node.item) 211 | 212 | def visit_value(self, node: AstValue): 213 | return from_python(node.value) 214 | 215 | def visit_value_vector(self, node: AstValueVector): 216 | return from_python(node.items) 217 | 218 | def visit_vector(self, node: AstVector): 219 | base_type = union(*[self.visit(item) for item in node.items]) 220 | return List[base_type][len(node.items)] 221 | -------------------------------------------------------------------------------- /pyppl/types/ppl_type_operations.py: -------------------------------------------------------------------------------- 1 | # 2 | # This file is part of PyFOPPL, an implementation of a First Order Probabilistic Programming Language in Python. 3 | # 4 | # License: MIT (see LICENSE.txt) 5 | # 6 | # 19. Feb 2018, Tobias Kohn 7 | # 19. 
Feb 2018, Tobias Kohn 8 | # 9 | from .ppl_types import * 10 | 11 | ####################################################################################################################### 12 | 13 | def _binary_(left, right): 14 | return union(left, right) 15 | 16 | def add(left, right): 17 | return _binary_(left, right) 18 | 19 | def sub(left, right): 20 | return _binary_(left, right) 21 | 22 | def mul(left, right): 23 | if left in String and right in Integer: 24 | return left 25 | elif left in Integer and right in String: 26 | return right 27 | return _binary_(left, right) 28 | 29 | def div(left, right): 30 | return _binary_(left, right) 31 | 32 | def idiv(left, right): 33 | return _binary_(left, right) 34 | 35 | def mod(left, right): 36 | return _binary_(left, right) 37 | 38 | 39 | ####################################################################################################################### 40 | 41 | def neg(item): 42 | return item 43 | 44 | def pos(item): 45 | return item 46 | -------------------------------------------------------------------------------- /pyppl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Author: Bradley Gram-Hansen 5 | Time created: 14:24 6 | Date created: 08/06/2018 7 | 8 | License: MIT 9 | ''' 10 | -------------------------------------------------------------------------------- /pyppl/utils/core.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Author: Bradley Gram-Hansen 5 | Time created: 14:22 6 | Date created: 08/06/2018 7 | 8 | License: MIT 9 | ''' 10 | try: 11 | import networkx as _nx 12 | except ModuleNotFoundError: 13 | _nx = None 14 | try: 15 | import matplotlib.pyplot as _plt 16 | import matplotlib.patches as mpatches 17 | except ModuleNotFoundError: 18 | _plt = None 19 | 20 | 21 | def 
create_network_graph(vertices): 22 | """ 23 | Create a `networkx` graph. Used by the method `display_graph()`. 24 | :return: Either a `networkx.DiGraph` instance or `None`. 25 | """ 26 | if _nx: 27 | G = _nx.DiGraph() 28 | for v in vertices: 29 | G.add_node(v.display_name) 30 | for a in v.ancestors: 31 | G.add_edge(a.display_name, v.display_name) 32 | return G 33 | else: 34 | return None 35 | 36 | def display_graph(vertices): 37 | """ 38 | Transform the graph to a `networkx.DiGraph`-structure and display it using `matplotlib` -- if the necessary 39 | libraries are installed. 40 | :return: `True` if the graph was drawn, `False` otherwise. 41 | """ 42 | G =create_network_graph(vertices) 43 | _is_conditioned = None 44 | if _nx and _plt and G: 45 | try: 46 | from networkx.drawing.nx_agraph import graphviz_layout 47 | pos = graphviz_layout(G, prog='dot') 48 | except ModuleNotFoundError: 49 | from networkx.drawing.layout import shell_layout 50 | pos = shell_layout(G) 51 | except ImportError: 52 | from networkx.drawing.layout import shell_layout 53 | pos = shell_layout(G) 54 | _plt.subplot(111) 55 | _plt.axis('off') 56 | _nx.draw_networkx_nodes(G, pos, 57 | node_color='r', 58 | node_size=500, 59 | nodelist=[v.display_name for v in vertices if v.is_sampled], 60 | alpha=0.5) 61 | _nx.draw_networkx_nodes(G, pos, 62 | node_color='b', 63 | node_size=500, 64 | nodelist=[v.display_name for v in vertices if v.is_observed], 65 | alpha=0.5) 66 | 67 | for v in vertices: 68 | _nx.draw_networkx_edges(G, pos, arrows=True,arrowsize=22, 69 | edgelist=[(a.display_name, v.display_name) for a in v.ancestors]) 70 | if v.condition_ancestors is not None and len(v.condition_ancestors) > 0: 71 | _is_conditioned = 1 72 | _nx.draw_networkx_edges(G, pos, arrows=True, arrowsize=22, 73 | style='dashed', 74 | edge_color='g', 75 | alpha=0.5, 76 | edgelist=[(a.display_name, v.display_name) for a in v.condition_ancestors]) 77 | _nx.draw_networkx_labels(G, pos, font_size=8, font_color='k', 
font_weight='bold') 78 | 79 | # for node, _ in G.nodes(): 80 | red_patch = mpatches.Circle((0,0), radius=2, color='r', label='Sampled Variables') 81 | blue_patch = mpatches.Circle((0,0), radius=2, color='b', label='Observed Variables') 82 | green_patch = mpatches.Circle((0,0), radius=2, color='g', label='Conditioned Variables') if _is_conditioned else 0 83 | if _is_conditioned: 84 | _plt.legend(handles=[red_patch, blue_patch, green_patch]) 85 | else: 86 | _plt.legend(handles=[red_patch, blue_patch]) 87 | _plt.show() 88 | 89 | 90 | return True 91 | else: 92 | return False -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | enum34>=1.1.6 2 | numpy>=1.13.3 3 | scipy>=1.0.0 4 | matplotlib>=2.1.1 5 | networkx>=2.0 6 | graphviz>=0.8.2 7 | seaborn>=0.8 8 | tqdm>=4.19 9 | texttable>-1.2 10 | torch>=0.4.0 11 | torchvision 12 | jupyter -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | ''' 4 | Author: Bradley Gram-Hansen 5 | Time created: 12:26 6 | Date created: 08/06/2018 7 | 8 | License: MIT 9 | ''' 10 | from os.path import realpath, dirname, join 11 | from setuptools import setup, find_packages 12 | import sys 13 | 14 | 15 | DISTNAME = 'pyfo' 16 | DESCRIPTION = "Python for FOPPL" 17 | LONG_DESCRIPTION = open('README.md').read() 18 | MAINTAINER = 'Tobias Kohn ,Bradley Gram-hansen ' 19 | MAINTAINER_EMAIL = 'webmaster@tobiaskohn.ch , bradleygramhansen@gmail.com' 20 | AUTHOR = 'Tobias Kohn, Bradley Gram-Hansen' 21 | AUTHOR_EMAIL = 'webmaster@tobiaskohn.ch , bradley@robots.ox.ac.uk' 22 | URL = "http://github.com/bradleygramhansen/pyLFPPL" 23 | LICENSE = 'LICENSE.txt' 24 | VERSION = "0.1.0" 25 | PACKAGES = ['pylfppl'] 26 | classifiers = ['Development Status :: 1 - 
Production/UnStable', 27 | 'Programming Language :: Python', 28 | 'Programming Language :: Python :: 3.6', 29 | 'License :: OSI Approved :: MIT License', 30 | 'Intended Audience :: Science/Research', 31 | 'Topic :: Scientific/Engineering', 32 | 'Topic :: Scientific/Engineering :: Mathematics', 33 | 'Operating System :: OS Independent'] 34 | 35 | PROJECT_ROOT = dirname(realpath(__file__)) 36 | REQUIREMENTS_FILE = join(PROJECT_ROOT, 'requirements.txt') 37 | 38 | with open(REQUIREMENTS_FILE) as f: 39 | install_reqs = f.read().splitlines() 40 | 41 | if sys.version_info < (3, 4): 42 | install_reqs.append('enum34') 43 | 44 | # test_reqs = ['pytest', 'pytest-cov'] 45 | # if sys.version_info[0] == 2: # py3 has mock in stdlib 46 | # test_reqs.append('mock') 47 | 48 | 49 | if __name__ == "__main__": 50 | setup(name=DISTNAME, 51 | version=VERSION, 52 | maintainer=MAINTAINER, 53 | maintainer_email=MAINTAINER_EMAIL, 54 | description=DESCRIPTION, 55 | license=LICENSE, 56 | url=URL, 57 | long_description=LONG_DESCRIPTION, 58 | packages=find_packages(), 59 | package_data={'docs': ['*']}, 60 | include_package_data=True, 61 | classifiers=classifiers, 62 | install_requires=install_reqs) 63 | --------------------------------------------------------------------------------