├── tests ├── __init__.py └── test_smolppl.py ├── smolppl ├── __init__.py └── smolppl.py ├── requirements.txt ├── README.md ├── LICENSE ├── .github └── workflows │ └── python-app.yml └── .gitignore /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /smolppl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.7.3 2 | numpy==1.22.0 -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # smolppl 2 | A Probabilistic Programming Language in 70 lines of Python. Code for the blog post https://mrandri19.github.io/2022/01/12/a-PPL-in-70-lines-of-python.html 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrea Cognolato 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /.github/workflows/python-app.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a single version of Python 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Python application 5 | 6 | on: 7 | push: 8 | branches: [ main ] 9 | pull_request: 10 | branches: [ main ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | 17 | steps: 18 | - uses: actions/checkout@v2 19 | - name: Set up Python 3.10 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: "3.10" 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install flake8 pytest 27 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 28 | - name: Lint with flake8 29 | run: | 30 | # stop the build if there are Python syntax errors or undefined names 31 | flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics 32 | # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 33 | flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 34 | - name: Test with pytest 35 | run: | 36 | pytest 37 | -------------------------------------------------------------------------------- /tests/test_smolppl.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from scipy.stats import norm 3 | from smolppl.smolppl import ( 4 | Normal, 5 | LatentVariable, 6 | ObservedVariable, 7 | evaluate_log_density, 8 | ) 9 | from numpy.testing import assert_almost_equal 10 | 11 | 12 | class TestSmolppl(unittest.TestCase): 13 | def test_log_density_correct_1(self): 14 | # z <- x 15 | z = LatentVariable("z", Normal, [0.0, 5.0]) 16 | x = ObservedVariable("x", Normal, [z, 1.0], observed=5.0) 17 | 18 | assert_almost_equal( 19 | evaluate_log_density(x, {"z": 1.5}), 20 | norm.logpdf(1.5, 0, 5) + norm.logpdf(5, 1.5, 1.0), 21 | ) 22 | 23 | def test_log_density_correct_2(self): 24 | # z <- 25 | # +- x 26 | # w <- 27 | z = LatentVariable("z", Normal, [0.0, 5.0]) 28 | w = LatentVariable("w", Normal, [0.0, 4.0]) 29 | x = ObservedVariable("x", Normal, [z, w], observed=5.0) 30 | 31 | assert_almost_equal( 32 | evaluate_log_density(x, {"z": 1.5, "w": 0.5}), 33 | norm.logpdf(1.5, 0, 5) 34 | + norm.logpdf(0.5, 0.0, 4.0) 35 | + norm.logpdf(5, 1.5, 0.5), 36 | ) 37 | 38 | def test_log_density_correct_3(self): 39 | # z <- w <- x 40 | # ^---------+ 41 | z = LatentVariable("z", Normal, [0.0, 5.0]) 42 | w = LatentVariable("w", Normal, [z, 5.0]) 43 | x = ObservedVariable("x", Normal, [z, w], observed=5.0) 44 | 45 | assert_almost_equal( 46 | evaluate_log_density(x, {"z": 1.5, "w": 0.5}), 47 | norm.logpdf(1.5, 0, 5) 48 | + norm.logpdf(0.5, 1.5, 5) 49 | + norm.logpdf(5, 1.5, 0.5), 50 | ) 51 | 52 | def test_log_density_correct_4(self): 53 | mu = LatentVariable("mu", Normal, [0.0, 5.0]) 54 | y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=5.0) 55 | 56 | assert_almost_equal( 57 | evaluate_log_density(y_bar, {"mu": 
import numbers

from scipy.stats import norm


class Distribution:
    """Base class for distributions; subclasses provide ``log_density``."""

    @staticmethod
    def log_density(point, params):
        """Return the log density of ``point`` given the list ``params``."""
        raise NotImplementedError("Must be implemented by a subclass")


class Normal(Distribution):
    """Normal distribution backed by ``scipy.stats.norm``."""

    @staticmethod
    def log_density(point, params):
        """Return ``norm.logpdf(point, params[0], params[1])`` as a float.

        ``params[0]`` is forwarded as scipy's ``loc`` and ``params[1]`` as
        its ``scale``.
        """
        return float(norm.logpdf(point, params[0], params[1]))


class LatentVariable:
    """A named random variable whose value is supplied at evaluation time.

    ``dist_args`` is a list whose items are either numeric constants or
    other variables (the parents of this node in the model graph).
    """

    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args


class ObservedVariable:
    """A named random variable whose value is fixed to ``observed``.

    ``dist_args`` has the same meaning as for ``LatentVariable``.
    """

    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed


def _is_constant(arg):
    """Return True when ``arg`` is a plain number rather than a variable."""
    # numbers.Real covers float and int, so ``Normal, [0, 5]`` works the
    # same as ``Normal, [0.0, 5.0]``.
    return isinstance(arg, numbers.Real)


def _collect_variables(root):
    """Return every variable reachable from ``root``, each exactly once.

    Variables are listed in depth-first pre-order, following ``dist_args``
    left to right. Numeric constants are skipped.
    """
    visited = set()
    ordered = []

    def visit(node):
        if _is_constant(node):
            return
        if not isinstance(node, (LatentVariable, ObservedVariable)):
            raise TypeError(
                "distribution arguments must be numbers or variables, "
                f"got {type(node).__name__}"
            )
        # A variable can be reached through several parents (diamond-shaped
        # graphs); its density must only be counted once.
        if node in visited:
            return
        visited.add(node)
        ordered.append(node)
        for arg in node.dist_args:
            visit(arg)

    visit(root)
    return ordered


def _latent_value(variable, latent_values):
    """Look up the value assigned to the latent ``variable``."""
    try:
        return latent_values[variable.name]
    except KeyError:
        raise KeyError(
            f"no value supplied for latent variable {variable.name!r}"
        ) from None


def _resolve_param(arg, latent_values):
    """Turn one ``dist_args`` item into the concrete number it stands for."""
    if _is_constant(arg):
        return arg
    if isinstance(arg, LatentVariable):
        return _latent_value(arg, latent_values)
    if isinstance(arg, ObservedVariable):
        # An observed parent contributes its fixed observation as the
        # parameter value instead of being silently dropped.
        return arg.observed
    raise TypeError(
        "distribution arguments must be numbers or variables, "
        f"got {type(arg).__name__}"
    )


def evaluate_log_density(variable, latent_values):
    """Evaluate the joint log density of the model rooted at ``variable``.

    Args:
        variable: The variable to start from; all variables reachable
            through ``dist_args`` are included in the sum.
        latent_values: Mapping from latent variable name to its value.

    Returns:
        The sum of the log densities of every reachable variable, where
        latent variables are evaluated at their entry in ``latent_values``
        and observed variables at their ``observed`` value.

    Raises:
        KeyError: If a reachable latent variable has no entry in
            ``latent_values``.
        TypeError: If a distribution argument is neither a number nor a
            variable.
    """
    log_density = 0.0
    for node in _collect_variables(variable):
        params = [_resolve_param(arg, latent_values) for arg in node.dist_args]
        if isinstance(node, LatentVariable):
            point = _latent_value(node, latent_values)
        else:
            point = node.observed
        log_density += node.dist_class.log_density(point, params)

    return log_density
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | --------------------------------------------------------------------------------