├── .github
│   └── workflows
│       └── ci.yml
├── .gitignore
├── CHANGES.md
├── CLAUDE.md
├── LICENSE
├── Makefile
├── README.md
├── __init__.py
├── poetry.lock
├── pyproject.toml
├── squigglepy
│   ├── __init__.py
│   ├── bayes.py
│   ├── correlation.py
│   ├── distributions.py
│   ├── numbers.py
│   ├── rng.py
│   ├── samplers.py
│   ├── utils.py
│   └── version.py
└── tests
    ├── __init__.py
    ├── integration.py
    ├── strategies.py
    ├── test_bayes.py
    ├── test_correlation.py
    ├── test_distributions.py
    ├── test_numbers.py
    ├── test_rng.py
    ├── test_samplers.py
    └── test_utils.py

/.github/workflows/ci.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 | 
4 | name: CI
5 | 
6 | on:
7 |   push:
8 |     branches:
9 |       - main
10 |   pull_request:
11 |     branches:
12 |       - main
13 | 
14 | jobs:
15 |   test:
16 |     name: Test (pytest)
17 |     runs-on: ubuntu-latest
18 |     strategy:
19 |       fail-fast: false
20 |       matrix:
21 |         python-version: ["3.9", "3.10", "3.11"]
22 |     steps:
23 |       - uses: actions/checkout@v3
24 |         name: Checkout repository
25 | 
26 |       - name: Install poetry
27 |         run: pipx install poetry
28 | 
29 |       - name: Set up Python ${{ matrix.python-version }}
30 |         uses: actions/setup-python@v4
31 |         with:
32 |           python-version: ${{ matrix.python-version }}
33 |           cache: 'poetry'
34 | 
35 |       - name: Install dependencies
36 |         run: |
37 |           poetry install
38 | 
39 |       - name: Test with pytest
40 |         run: |
41 |           poetry run pytest
42 | 
43 |   lint:
44 |     name: "Lint (ruff)"
45 |     runs-on: ubuntu-latest
46 |     steps:
47 |       - uses: actions/checkout@v3
48 |       - uses: chartboost/ruff-action@v1
49 | 
50 |   format:
51 |     name: "Format (black)"
52 |     runs-on: ubuntu-latest
53 |     steps:
54 |       - uses: actions/checkout@v3
55 |       - uses: psf/black@stable
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | build/*
2 | dist/*
3 | *.egg-info
4 | __pycache__
5 | .ruff_cache
6 | .pytest-runtimes
7 | .hypothesis
--------------------------------------------------------------------------------
/CHANGES.md:
--------------------------------------------------------------------------------
1 | ## v0.30 - development version
2 | 
3 | * Uses prettier `tqdm` output that is now aware of Jupyter notebooks.
4 | 
5 | ## v0.29 - latest release
6 | 
7 | * Fixes a bug where `max_gain` and `modeled_gain` were incorrect in kelly output.
8 | * Fixes a bug where `error` was not correctly passed from `half_kelly`, `third_kelly`, and `quarter_kelly`.
9 | * Added `invlognorm` as a new distribution.
10 | * Added `bucket_percentages` to more easily get the percentage of values within a bucket.
11 | * Added `third_kelly` as an alias for `kelly` with deference = 0.66.
12 | * Allows Bernoulli distributions to be defined with p=0 or p=1.
13 | * Added a `Makefile` to help simplify testing and linting workflows.
14 | 
15 | 
16 | ## v0.28
17 | 
18 | * **[Breaking change]** `sq.pareto` previously sampled from a Lomax distribution due to a complication with numpy. Now it properly samples from a Pareto distribution.
19 | * **[Breaking change]** lclip / rclip have been removed from the triangular distribution because they don't make sense there.
20 | * **[Breaking change]** You can now nest mixture and discrete distributions within mixture distributions.
21 | * **[Breaking change]** `sq.kelly` now raises an error if you put in a price below the market price. You can pass `error=False` to disable this and return to the old behavior.
22 | * Added `pert` distribution.
23 | * Added `sharpe_ratio` to utilities.
24 | * `get_percentiles`, `get_log_percentiles`, `get_mean_and_ci`, and `get_median_and_ci` now can all take an optional `weights` parameter to do a weighted version.
25 | 
26 | 
27 | ## v0.27
28 | 
29 | * **[Breaking change]** This package now only supports Python 3.9 and higher.
30 | * **[Breaking change]** `get_percentiles` and `get_log_percentiles` now always return a dictionary, even if there's only one element.
31 | * **[Breaking change]** `.type` is now removed from distribution objects.
32 | * You can now create correlated variables using `sq.correlate`.
33 | * Added `geometric` distribution.
34 | * Distribution objects now have the version of squigglepy they were created with, which can be accessed via `obj._version`. This should be helpful for debugging and noticing stale objects, especially when squigglepy distributions are stored in caches.
35 | * Distributions can now be hashed with `hash`.
36 | * Fixed a bug where `tdist` would not return multiple samples if defined with `t` alone.
37 | * Package load time is now ~2x faster.
38 | * Mixture sampling is now ~2x faster.
39 | * Pandas and matplotlib have been removed as required dependencies, but their related features are lazily enabled when the modules are available. These packages are still available for install as extras, installable with `pip install squigglepy[plots]` (for plotting-related functionality, matplotlib for now), `pip install squigglepy[ecosystem]` (for pandas, and in the future other related packages), or `pip install squigglepy[all]` (for all extras).
40 | * Multicore distribution now does extra checks to avoid crashing from race conditions.
41 | * Using black now for formatting.
42 | * Switched from `flake8` to `ruff`.
43 | 
44 | 
45 | ## v0.26
46 | 
47 | * **[Breaking change]** `lognorm` can now be defined either referencing the mean and sd of the underlying normal distribution via `norm_mean` / `norm_sd` or via the mean and sd of the lognormal distribution itself via `lognorm_mean` / `lognorm_sd`. To further disambiguate, `mean` and `sd` are no longer variables that can be passed to `lognorm`.
48 | 
49 | 
50 | ## v0.25
51 | 
52 | * Added `plot` as a method to more easily plot distributions.
53 | * Added `dist_log` and `dist_exp` operators on distributions.
54 | * Added `growth_rate_to_doubling_time` and `doubling_time_to_growth_rate` convenience functions. These take numbers, numpy arrays, or distributions.
55 | * Mixture distributions now print with weights in addition to distributions.
56 | * Changes `get_log_percentiles` to report in scientific notation.
57 | * `bayes` now supports separate arguments for `memcache_load` and `memcache_save` to better customize how memcache behavior works. `memcache` remains a parameter that sets both `memcache_load` and `memcache_save` to True.
58 | 
59 | 
60 | ## v0.24
61 | 
62 | * Distributions can now be negated with `-` (e.g., `-lognorm(0.1, 1)`).
63 | * Numpy ints and floats can now be used for determining the number of samples.
64 | * Fixed some typos in the documentation.
65 | 
66 | 
67 | ## v0.23
68 | 
69 | * Added `pareto` distribution.
70 | * Added `get_median_and_ci` to return the median and a given confidence interval for data.
71 | * `discrete` and `mixture` distributions now give more detail when printed.
72 | * Fixed some typos in the documentation.
73 | 
74 | 
75 | ## v0.22
76 | 
77 | * Added `extremize` to extremize predictions.
78 | * Added `normalize` to normalize a list of numbers to sum to 1.
79 | * Added `get_mean_and_ci` to return the mean and a given confidence interval for data.
80 | * Added `is_dist` to determine if an object is a Squigglepy distribution.
81 | * Added `is_sampleable` to determine if an object can be sampled using `sample`.
82 | * Support for working within Pandas is now explicitly added. `pandas` has been added as a requirement.
83 | * `discrete` sampling now will compress a large array if possible for more efficient sampling.
84 | * `clip`, `lclip`, and `rclip` can now be used without needing distributions.
85 | * Some functions (e.g., `geomean`) previously only supported lists, dictionaries, and numpy arrays. They have been expanded to support all iterables.
86 | * `dist_max` and `dist_min` now support pipes (`>>`).
87 | * `get_percentiles` now coerces output to integer if `digits` is less than or equal to 0, instead of just exactly 0.
88 | 
89 | 
90 | ## v0.21
91 | 
92 | * Mixture sampling is now 4-23x faster.
93 | * You can now get the version of squigglepy via `sq.__version__`.
94 | * Fixes a bug where the tqdm progress bar was displayed with the incorrect count when collecting cores during a multicore `sample`.
95 | 
96 | 
97 | ## v0.20
98 | 
99 | * Fixes how package dependencies are handled in `setup.py` and specifies Python >= 3.7 must be used. This should fix install errors.
100 | 
101 | 
102 | ## v0.19
103 | 
104 | #### Bugfixes
105 | 
106 | * Fixes a bug where `lclip` and/or `rclip` on `mixture` distribution were not working correctly.
107 | * Fixes a bug where `dist_fn` did not work with `np.vectorize` functions.
108 | * Fixes a bug where in-memory caching was invoked for `bayesnet` when not desired.
109 | 
110 | #### Caching and Multicore
111 | 
112 | * **[Breaking change]** `bayesnet` caching is now based on binary files instead of pickle files (uses `msgspec` as the underlying library).
113 | * **[Breaking change]** `sample` caching is now based on numpy files instead of pickle files.
114 | * A cache can now be loaded via `sample(load_cache=cachefile)` or `bayesnet(load_cache=cachefile)`, without needing to pass the distribution / function.
115 | * `bayesnet` and `sample` now take an argument `cores` (default 1). If greater than 1, will run the calculations on multiple cores using the pathos package.
116 | 
117 | #### Other
118 | 
119 | * Functions that take `weights` now can instead take a parameter `relative_weights`, where weights are automatically normalized to sum to 1 (instead of erroring, which is still the behavior if using `weights`).
120 | * Verbose output for `bayesnet` and `sample` is now clearer (and slightly more verbose).
121 | 
122 | 
123 | ## v0.18
124 | 
125 | * **[Breaking change]** The default `t` for t-distributions has changed from 1 to 20.
126 | * `sample` results can now be cached in-memory using `memcache=True`. They can also be cached to a file -- use `dump_cache_file` to write the file and `load_cache_file` to load from the file.
127 | * _(Non-visible backend change)_ Weights that are set to 0 are now dropped entirely.
128 | 
129 | 
130 | ## v0.17
131 | 
132 | * When `verbose=True` is used in `sample`, the progress bar now pops up in more relevant places and is much more likely to get triggered when relevant.
133 | * `discrete_sample` and `mixture_sample` can now take a `verbose` parameter.
134 | 
135 | 
136 | ## v0.16
137 | 
138 | * `zero_inflated` can create an arbitrary zero-inflated distribution.
139 | * Individual sampling functions (`normal_sample`, `lognormal_sample`, etc.) can now take an argument `samples` to generate multiple samples.
140 | * A large speedup has been achieved to sampling from the same distribution multiple times.
141 | * `requirements.txt` has been updated.
142 | 
143 | 
144 | ## v0.15
145 | 
146 | * **[Breaking change]** `bayesnet` function now refers to parameter `memcache` where previously this parameter was called `cache`.
147 | * **[Breaking change]** If `get_percentiles` or `get_log_percentiles` is called with just one element for `percentiles`, it will return that value instead of a dict.
148 | * Fixed a bug where `get_percentiles` would not round correctly.
149 | * `bayesnet` results can now be cached to a file. Use `dump_cache_file` to write the file and `load_cache_file` to load from the file.
150 | * `discrete` now works with numpy arrays in addition to lists.
151 | * Added `one_in` as a shorthand to convert percentages into "1 in X" notation.
152 | * Distributions can now be compared with `==` and `!=`.
153 | 
154 | 
155 | ## v0.14
156 | 
157 | * Nested sampling now works as intended.
158 | * You can now use `>>` for pipes for distributions. For example, `sq.norm(1, 2) >> dist_ceil`.
159 | * Distributions can now be compared with `>`, `<`, `>=`, and `<=`.
160 | * `dist_max` can be used to get the maximum value between two distributions. This family of functions is not evaluated until the distribution is sampled, and they work with pipes.
161 | * `dist_min` can be used to get the minimum value between two distributions.
162 | * `dist_round` can be used to round the final output of a distribution. This makes the distribution discrete.
163 | * `dist_ceil` can be used to ceiling round the final output of a distribution. This makes the distribution discrete.
164 | * `dist_floor` can be used to floor round the final output of a distribution. This makes the distribution discrete.
165 | * `lclip` can be used to clip a distribution to a lower bound. This is the same functionality that is available within the distribution and the `sample` method.
166 | * `rclip` can be used to clip a distribution to an upper bound. This is the same functionality that is available within the distribution and the `sample` method.
167 | * `clip` can be used to clip a distribution to both an upper bound and a lower bound. This is the same functionality that is available within the distribution and the `sample` method.
168 | * `sample` can now be used directly on numbers. This makes `const` functionally obsolete, but `const` is maintained for backwards compatibility and in case it is useful.
169 | * `sample(None)` now returns `None` instead of an error.
170 | 
171 | 
172 | ## v0.13
173 | 
174 | * Sample shorthand notation can go in either order. That is, `100 @ sq.norm(1, 2)` now works and is the same as `sq.norm(1, 2) @ 100`, which is the same as `sq.sample(sq.norm(1, 2), n=100)`.
175 | 
176 | 
177 | ## v0.12
178 | 
179 | * Distributions now implement math directly. That is, you can do things like `sq.norm(2, 3) + sq.norm(4, 5)`, whereas previously this would not work. Thanks to Dawn Drescher for helping me implement this.
180 | * `~sq.norm(1, 2)` is now a shorthand for `sq.sample(sq.norm(1, 2))`. Thanks to Dawn Drescher for helping me implement this shorthand.
181 | * `sq.norm(1, 2) @ 100` is now a shorthand for `sq.sample(sq.norm(1, 2), n=100)`.
182 | 
183 | 
184 | ## v0.11
185 | 
186 | #### Distributions
187 | 
188 | * **[Breaking change]** `tdist` and `log_tdist` have been modified to better approximate the desired credible intervals.
189 | * `tdist` now can be defined by just `t`, producing a classic t-distribution.
190 | * `tdist` now has a default value for `t`: 1.
191 | * Added `chisquare` distribution.
192 | * `lognormal` now returns an error if it is defined with a zero or negative value.
193 | 
194 | #### Other
195 | 
196 | * All functions now have docstrings.
197 | * Added `kelly` to calculate Kelly criterion for bet sizing with probabilities.
198 | * Added `full_kelly`, `half_kelly`, `quarter_kelly` as helpful aliases.
199 | 
200 | 
201 | ## v0.10
202 | 
203 | * **[Breaking change]** `credibility` is now defined using a number out of 100 (e.g., `credibility=80` to define an 80% CI) rather than a decimal out of 1 (e.g., `credibility=0.8` to define an 80% CI).
204 | * Distribution objects now print their parameters.
205 | 
206 | 
207 | ## v0.9
208 | 
209 | * `geomean` and `geomean_odds` now can take the nested-list-based and dictionary-based formats for passing weights.
210 | 
211 | 
212 | ## v0.8
213 | 
214 | #### Bayesian library updates
215 | 
216 | * **[Breaking change]** `bayes.update` now updates normal distributions from the distribution rather than from samples.
217 | * **[Breaking change]** `bayes.update` no longer takes a `type` parameter but can now infer the type from the passed distribution.
218 | * **[Breaking change]** Corrected a bug in how `bayes.update` implemented `evidence_weight` when updating normal distributions.
219 | 
220 | #### Non-visible backend changes
221 | 
222 | * Distributions are now implemented as classes (rather than lists).
223 | 
224 | 
225 | ## v0.7
226 | 
227 | #### Bugfixes
228 | 
229 | * Fixes an issue with sampling from the `bernoulli` distribution.
230 | * Fixes a bug with the implementation of `lclip` and `rclip`.
231 | 
232 | #### New distributions
233 | 
234 | * Adds `discrete` to calculate a discrete distribution. Example: `discrete({'A': 0.3, 'B': 0.3, 'C': 0.4})` will return A 30% of the time, B 30% of the time, and C 40% of the time.
235 | * Adds `poisson(lam)` to calculate a poisson distribution.
236 | * Adds `gamma(size, scale)` to calculate a gamma distribution.
237 | 
238 | #### Bayesian library updates
239 | 
240 | * Adds `bayes.bayesnet` to do Bayesian inference (see README).
241 | * `bayes.update` now can take an `evidence_weight` parameter. Typically this would be equal to the number of samples.
242 | * **[Breaking change]** `bayes.bayes` has been renamed `bayes.simple_bayes`.
243 | 
244 | #### Other
245 | 
246 | * **[Breaking change]** `credibility`, which defines the size of the interval (e.g., `credibility=0.8` for an 80% CI), is now a property of the distribution rather than the sampler. That is, you should now call `sample(norm(1, 3, credibility=0.8))` whereas previously it was `sample(norm(1, 3), credibility=0.8)`. This will allow mixing of distributions with different credible ranges.
247 | * **[Breaking change]** Numbers have been changed from functions to global variables. Use `thousand` or `K` instead of `thousand()` (old/deprecated).
248 | * `sample` now has a nice progress reporter if `verbose=True`.
249 | * The `exponential` distribution now implements `lclip` and `rclip`.
250 | * The `mixture` distribution can infer equal weights if no weights are given.
251 | * The `mixture` distribution can infer the last weight if the last weight is not given. 252 | * `geomean` and `geomean_odds` can infer the last weight if the last weight is not given. 253 | * You can use `flip_coin` and `roll_die(sides)` to flip a coin or roll a die. 254 | * `event_happens` and `event` are aliases for `event_occurs`. 255 | * `get_percentiles` will now cast output to `int` if `digits=0`. 256 | * `get_log_percentiles` now has a default value for `percentiles`. 257 | * You can now set the seed for the RNG using `sq.set_seed`. 258 | 259 | #### Non-visible backend changes 260 | 261 | * Now has tests via pytest. 262 | * The random numbers now come from a numpy generator as opposed to the previous deprecated `np.random` methods. 263 | * The `sample` module (containing the `sample` function) has been renamed `samplers`. 264 | 265 | 266 | ## v0.6 267 | 268 | #### New distributions 269 | 270 | * Add `triangular(left, mode, right)` to calculate a triangular distribution. 271 | * Add `binomial(n, p)` to calculate a binomial distribution. 272 | * Add `beta(a, b)` to calculate a beta distribution. 273 | * Add `bernoulli(p)` to calculate a bernoulli distribution. 274 | * Add `exponential(scale)` to calculate an exponential distribution. 275 | 276 | #### New Bayesian library 277 | 278 | * Add `bayes.update` to get a posterior distribution from a prior distribution and an evidence distribution. 279 | * Add `bayes.average` to average distributions (via a mixture). 280 | 281 | #### New utility functions 282 | 283 | * Add `laplace` to calculate Laplace's Law of Succession. If `s` and `n` are passed, it will calculate `(s+1)/(n+2)`. If `s`, `time_passed`, and `time_remaining` are passed, it will use the [time invariant version](https://www.lesswrong.com/posts/wE7SK8w8AixqknArs/a-time-invariant-version-of-laplace-s-rule). Use `time_fixed=True` for fixed time periods and `time_fixed=False` (default) otherwise. 284 | * Add `geomean` to calculate the geometric mean. 285 | * Add `p_to_odds` to convert probability to odds. Also `odds_to_p` to convert odds to probability. 286 | * Add `geomean_odds` to calculate the geometric mean of odds, converted to and from probabilities. 287 | 288 | #### Other 289 | 290 | * If a distribution is defined with `sd` but not `mean`, `mean` will be inferred to be 0. 291 | * `sample` can now take `lclip` and `rclip` directly, in addition to defining `lclip` and `rclip` on the distribution itself. If both are defined, the most restrictive of the two bounds will be used. 292 | 293 | 294 | ## v0.5 295 | 296 | * Fix critical bug to `tdist` and `log_tdist` introduced in v0.3. 297 | 298 | 299 | ## v0.4 300 | 301 | * Fix critical bug introduced in v0.3. 302 | 303 | 304 | ## v0.3 305 | 306 | * Be able to define distributions using `mean` and `sd` instead of defining the interval. 307 | 308 | 309 | ## v0.2 310 | 311 | * **[Breaking change]** Change `distributed_log` to `mixture` (to follow Squiggle) and allow it to implement any sub-distribution. 312 | * **[Breaking change]** Changed library to single import. 313 | * **[Breaking change]** Remove `weighted_log` as a distribution. 314 | 315 | 316 | ## v0.1 317 | 318 | * Initial library 319 | -------------------------------------------------------------------------------- /CLAUDE.md: -------------------------------------------------------------------------------- 1 | # CLAUDE.md 2 | 3 | This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 
4 | 5 | ## Build/Test/Lint Commands 6 | - Install: `poetry install --with dev` 7 | - Run all tests: `make test` or `pytest && pip3 install . && python3 tests/integration.py` 8 | - Run single test: `pytest tests/test_file.py::test_function_name -v` 9 | - Format code: `make format` or `black . && ruff check . --fix` 10 | - Lint code: `make lint` or `ruff check .` 11 | 12 | ## Style Guidelines 13 | - Line length: 99 characters (configured for both Black and Ruff) 14 | - Imports: stdlib first, third-party next, local imports last 15 | - Naming: CamelCase for classes, snake_case for functions/vars, UPPER_CASE for constants 16 | - Documentation: NumPy-style docstrings with examples, parameters, returns 17 | - Type hints: Use throughout codebase 18 | - Error handling: Validate inputs, use ValueError with descriptive messages 19 | - Use operator overloading (`__add__`, `__mul__`, etc.) and custom operators (`@` for sampling) 20 | - Tests: Descriptive names, unit tests match module structure, use hypothesis for property testing -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Peter Wildeford 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # Variables 2 | POETRY = poetry 3 | PYTHON = $(POETRY) run python 4 | 5 | # Install dependencies 6 | install: 7 | $(POETRY) install 8 | 9 | install-dev: 10 | $(POETRY) install --with dev 11 | 12 | # Format code 13 | format: 14 | $(POETRY) run black . 15 | $(POETRY) run ruff check . --fix 16 | 17 | # Run linting 18 | lint: 19 | $(POETRY) run ruff check . 20 | 21 | # Run tests 22 | test: 23 | pytest && pip3 install . 
&& python3 tests/integration.py
24 | 
25 | # Help
26 | help:
27 | 	@echo "Available commands:"
28 | 	@echo "  make install      Install production dependencies"
29 | 	@echo "  make install-dev  Install all dependencies including dev tools"
30 | 	@echo "  make format       Format code using Black and Ruff"
31 | 	@echo "  make lint         Run Ruff for linting"
32 | 	@echo "  make test         Run tests"
33 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Squigglepy: Implementation of Squiggle in Python
2 | 
3 | [Squiggle](https://www.squiggle-language.com/) is a "simple programming language for intuitive probabilistic estimation". It serves as its own standalone programming language with its own syntax, but it is implemented in JavaScript. I like the features of Squiggle and intend to use it frequently, but I also sometimes want to use similar functionalities in Python, especially alongside other Python statistical programming packages like Numpy, Pandas, and Matplotlib. The **squigglepy** package here implements many Squiggle-like functionalities in Python.
4 | 
5 | ## Installation
6 | 
7 | ```shell
8 | pip install squigglepy
9 | ```
10 | 
11 | For plotting support, you can also use the `plots` extra:
12 | 
13 | ```shell
14 | pip install squigglepy[plots]
15 | ```
16 | 
17 | ## Usage
18 | 
19 | ### Piano Tuners Example
20 | 
21 | Here's the Squigglepy implementation of [the example from Squiggle Docs](https://www.squiggle-language.com/docs/Overview):
22 | 
23 | ```Python
24 | import squigglepy as sq
25 | import numpy as np
26 | import matplotlib.pyplot as plt
27 | from squigglepy.numbers import K, M
28 | from pprint import pprint
29 | 
30 | pop_of_ny_2022 = sq.to(8.1*M, 8.4*M) # This means that you're 90% confident the value is between 8.1 and 8.4 Million.
31 | pct_of_pop_w_pianos = sq.to(0.2, 1) * 0.01 # We assume there are almost no people with multiple pianos
32 | pianos_per_piano_tuner = sq.to(2*K, 50*K)
33 | piano_tuners_per_piano = 1 / pianos_per_piano_tuner
34 | total_tuners_in_2022 = pop_of_ny_2022 * pct_of_pop_w_pianos * piano_tuners_per_piano
35 | samples = total_tuners_in_2022 @ 1000 # Note: `@ 1000` is shorthand to get 1000 samples
36 | 
37 | # Get mean and SD
38 | print('Mean: {}, SD: {}'.format(round(np.mean(samples), 2),
39 |                                 round(np.std(samples), 2)))
40 | 
41 | # Get percentiles
42 | pprint(sq.get_percentiles(samples, digits=0))
43 | 
44 | # Histogram
45 | plt.hist(samples, bins=200)
46 | plt.show()
47 | 
48 | # Shorter histogram
49 | total_tuners_in_2022.plot()
50 | ```
51 | 
52 | And the version from the Squiggle doc that incorporates time:
53 | 
54 | ```Python
55 | import squigglepy as sq
56 | from squigglepy.numbers import K, M
57 | 
58 | pop_of_ny_2022 = sq.to(8.1*M, 8.4*M)
59 | pct_of_pop_w_pianos = sq.to(0.2, 1) * 0.01
60 | pianos_per_piano_tuner = sq.to(2*K, 50*K)
61 | piano_tuners_per_piano = 1 / pianos_per_piano_tuner
62 | 
63 | def pop_at_time(t): # t = Time in years after 2022
64 |     avg_yearly_pct_change = sq.to(-0.01, 0.05) # We're expecting NYC to continuously grow with a mean of roughly between -1% and +5% per year
65 |     return pop_of_ny_2022 * ((avg_yearly_pct_change + 1) ** t)
66 | 
67 | def total_tuners_at_time(t):
68 |     return pop_at_time(t) * pct_of_pop_w_pianos * piano_tuners_per_piano
69 | 
70 | # Get total piano tuners at 2030
71 | sq.get_percentiles(total_tuners_at_time(2030-2022) @ 1000)
72 | ```
73 | 
74 | **WARNING:** Be careful about dividing by `K`, `M`, etc.
`1/2*K` = 500 in Python. Use `1/(2*K)` instead to get the expected outcome. 75 | 76 | **WARNING:** Be careful about using `K` to get sample counts. Use `sq.norm(2, 3) @ (2*K)`... `sq.norm(2, 3) @ 2*K` will return only two samples, multiplied by 1000. 77 | 78 | ### Distributions 79 | 80 | ```Python 81 | import squigglepy as sq 82 | 83 | # Normal distribution 84 | sq.norm(1, 3) # 90% interval from 1 to 3 85 | 86 | # Distribution can be sampled with mean and sd too 87 | sq.norm(mean=0, sd=1) 88 | 89 | # Shorthand to get one sample 90 | ~sq.norm(1, 3) 91 | 92 | # Shorthand to get more than one sample 93 | sq.norm(1, 3) @ 100 94 | 95 | # Longhand version to get more than one sample 96 | sq.sample(sq.norm(1, 3), n=100) 97 | 98 | # Nice progress reporter 99 | sq.sample(sq.norm(1, 3), n=1000, verbose=True) 100 | 101 | # Other distributions exist 102 | sq.lognorm(1, 10) 103 | sq.invlognorm(1, 10) 104 | sq.tdist(1, 10, t=5) 105 | sq.triangular(1, 2, 3) 106 | sq.pert(1, 2, 3, lam=2) 107 | sq.binomial(p=0.5, n=5) 108 | sq.beta(a=1, b=2) 109 | sq.bernoulli(p=0.5) 110 | sq.poisson(10) 111 | sq.chisquare(2) 112 | sq.gamma(3, 2) 113 | sq.pareto(1) 114 | sq.exponential(scale=1) 115 | sq.geometric(p=0.5) 116 | 117 | # Discrete sampling 118 | sq.discrete({'A': 0.1, 'B': 0.9}) 119 | 120 | # Can return integers 121 | sq.discrete({0: 0.1, 1: 0.3, 2: 0.3, 3: 0.15, 4: 0.15}) 122 | 123 | # Alternate format (also can be used to return more complex objects) 124 | sq.discrete([[0.1, 0], 125 | [0.3, 1], 126 | [0.3, 2], 127 | [0.15, 3], 128 | [0.15, 4]]) 129 | 130 | sq.discrete([0, 1, 2]) # No weights assumes equal weights 131 | 132 | # You can mix distributions together 133 | sq.mixture([sq.norm(1, 3), 134 | sq.norm(4, 10), 135 | sq.lognorm(1, 10)], # Distributions to mix 136 | [0.3, 0.3, 0.4]) # These are the weights on each distribution 137 | 138 | # This is equivalent to the above, just a different way of doing the notation 139 | sq.mixture([[0.3, sq.norm(1,3)], 140 | [0.3, sq.norm(4,10)], 141 | [0.4, sq.lognorm(1,10)]]) 142 | 143 | # Make a zero-inflated distribution 144 | # 60% chance of returning 0, 40% chance of sampling from `norm(1, 2)`. 145 | sq.zero_inflated(0.6, sq.norm(1, 2)) 146 | ``` 147 | 148 | ### Additional Features 149 | 150 | ```Python 151 | import squigglepy as sq 152 | 153 | # You can add and subtract distributions 154 | (sq.norm(1,3) + sq.norm(4,5)) @ 100 155 | (sq.norm(1,3) - sq.norm(4,5)) @ 100 156 | (sq.norm(1,3) * sq.norm(4,5)) @ 100 157 | (sq.norm(1,3) / sq.norm(4,5)) @ 100 158 | 159 | # You can also do math with numbers 160 | ~((sq.norm(sd=5) + 2) * 2) 161 | ~(-sq.lognorm(0.1, 1) * sq.pareto(1) / 10) 162 | 163 | # You can change the CI from 90% (default) to 80% 164 | sq.norm(1, 3, credibility=80) 165 | 166 | # You can clip 167 | sq.norm(0, 3, lclip=0, rclip=5) # Sample norm with a 90% CI from 0-3, but anything lower than 0 gets clipped to 0 and anything higher than 5 gets clipped to 5. 168 | 169 | # You can also clip with a function, and use pipes 170 | sq.norm(0, 3) >> sq.clip(0, 5) 171 | 172 | # You can correlate continuous distributions 173 | a, b = sq.uniform(-1, 1), sq.to(0, 3) 174 | a, b = sq.correlate((a, b), 0.5) # Correlate a and b with a correlation of 0.5 175 | # You can even pass your own correlation matrix! 
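# (The matrix must describe Spearman rank correlations: it must be symmetric and
#  positive semi-definite, with off-diagonal entries strictly between -1 and 1.
#  Correlation is induced via the Iman-Conover method, which reshuffles samples
#  and so preserves each variable's original marginal distribution.)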
176 | a, b = sq.correlate((a, b), [[1, 0.5], [0.5, 1]]) 177 | ``` 178 | 179 | #### Example: Rolling a Die 180 | 181 | An example of how to use distributions to build tools: 182 | 183 | ```Python 184 | import squigglepy as sq 185 | 186 | def roll_die(sides, n=1): 187 | return sq.discrete(list(range(1, sides + 1))) @ n if sides > 0 else None 188 | 189 | roll_die(sides=6, n=10) 190 | # [2, 6, 5, 2, 6, 2, 3, 1, 5, 2] 191 | ``` 192 | 193 | This is already included standard in the utils of this package. Use `sq.roll_die`. 194 | 195 | ### Bayesian inference 196 | 197 | 1% of women at age forty who participate in routine screening have breast cancer. 198 | 80% of women with breast cancer will get positive mammographies. 199 | 9.6% of women without breast cancer will also get positive mammographies. 200 | 201 | A woman in this age group had a positive mammography in a routine screening. 202 | What is the probability that she actually has breast cancer? 203 | 204 | We can approximate the answer with a Bayesian network (uses rejection sampling): 205 | 206 | ```Python 207 | import squigglepy as sq 208 | from squigglepy import bayes 209 | from squigglepy.numbers import M 210 | 211 | def mammography(has_cancer): 212 | return sq.event(0.8 if has_cancer else 0.096) 213 | 214 | def define_event(): 215 | cancer = ~sq.bernoulli(0.01) 216 | return({'mammography': mammography(cancer), 217 | 'cancer': cancer}) 218 | 219 | bayes.bayesnet(define_event, 220 | find=lambda e: e['cancer'], 221 | conditional_on=lambda e: e['mammography'], 222 | n=1*M) 223 | # 0.07723995880535531 224 | ``` 225 | 226 | Or if we have the information immediately on hand, we can directly calculate it. Though this doesn't work for very complex stuff. 227 | 228 | ```Python 229 | from squigglepy import bayes 230 | bayes.simple_bayes(prior=0.01, likelihood_h=0.8, likelihood_not_h=0.096) 231 | # 0.07763975155279504 232 | ``` 233 | 234 | You can also make distributions and update them: 235 | 236 | ```Python 237 | import matplotlib.pyplot as plt 238 | import squigglepy as sq 239 | from squigglepy import bayes 240 | from squigglepy.numbers import K 241 | import numpy as np 242 | 243 | print('Prior') 244 | prior = sq.norm(1,5) 245 | prior_samples = prior @ (10*K) 246 | plt.hist(prior_samples, bins = 200) 247 | plt.show() 248 | print(sq.get_percentiles(prior_samples)) 249 | print('Prior Mean: {} SD: {}'.format(np.mean(prior_samples), np.std(prior_samples))) 250 | print('-') 251 | 252 | print('Evidence') 253 | evidence = sq.norm(2,3) 254 | evidence_samples = evidence @ (10*K) 255 | plt.hist(evidence_samples, bins = 200) 256 | plt.show() 257 | print(sq.get_percentiles(evidence_samples)) 258 | print('Evidence Mean: {} SD: {}'.format(np.mean(evidence_samples), np.std(evidence_samples))) 259 | print('-') 260 | 261 | print('Posterior') 262 | posterior = bayes.update(prior, evidence) 263 | posterior_samples = posterior @ (10*K) 264 | plt.hist(posterior_samples, bins = 200) 265 | plt.show() 266 | print(sq.get_percentiles(posterior_samples)) 267 | print('Posterior Mean: {} SD: {}'.format(np.mean(posterior_samples), np.std(posterior_samples))) 268 | 269 | print('Average') 270 | average = bayes.average(prior, evidence) 271 | average_samples = average @ (10*K) 272 | plt.hist(average_samples, bins = 200) 273 | plt.show() 274 | print(sq.get_percentiles(average_samples)) 275 | print('Average Mean: {} SD: {}'.format(np.mean(average_samples), np.std(average_samples))) 276 | ``` 277 | 278 | #### Example: Alarm net 279 | 280 | This is the alarm network from [Bayesian 
Artificial Intelligence - Section 2.5.1](https://bayesian-intelligence.com/publications/bai/book/BAI_Chapter2.pdf):
281 | 
282 | > Assume your house has an alarm system against burglary.
283 | >
284 | > You live in a seismically active area and the alarm system can occasionally get set off by an earthquake.
285 | >
286 | > You have two neighbors, Mary and John, who do not know each other.
287 | > If they hear the alarm they call you, but this is not guaranteed.
288 | >
289 | > The chance of a burglary on a particular day is 0.1%.
290 | > The chance of an earthquake on a particular day is 0.2%.
291 | >
292 | > The alarm will go off 95% of the time with both a burglary and an earthquake, 94% of the time with just a burglary, 29% of the time with just an earthquake, and 0.1% of the time with nothing (total false alarm).
293 | >
294 | > John will call you 90% of the time when the alarm goes off. But on 5% of the days, John will just call to say "hi".
295 | > Mary will call you 70% of the time when the alarm goes off. But on 1% of the days, Mary will just call to say "hi".
296 | 
297 | ```Python
298 | import squigglepy as sq
299 | from squigglepy import bayes
300 | from squigglepy.numbers import M
301 | 
302 | def p_alarm_goes_off(burglary, earthquake):
303 |     if burglary and earthquake:
304 |         return 0.95
305 |     elif burglary and not earthquake:
306 |         return 0.94
307 |     elif not burglary and earthquake:
308 |         return 0.29
309 |     elif not burglary and not earthquake:
310 |         return 0.001
311 | 
312 | def p_john_calls(alarm_goes_off):
313 |     return 0.9 if alarm_goes_off else 0.05
314 | 
315 | def p_mary_calls(alarm_goes_off):
316 |     return 0.7 if alarm_goes_off else 0.01
317 | 
318 | def define_event():
319 |     burglary_happens = sq.event(p=0.001)
320 |     earthquake_happens = sq.event(p=0.002)
321 |     alarm_goes_off = sq.event(p_alarm_goes_off(burglary_happens, earthquake_happens))
322 |     john_calls = sq.event(p_john_calls(alarm_goes_off))
323 |     mary_calls = sq.event(p_mary_calls(alarm_goes_off))
324 |     return {'burglary': burglary_happens,
325 |             'earthquake': earthquake_happens,
326 |             'alarm_goes_off': alarm_goes_off,
327 |             'john_calls': john_calls,
328 |             'mary_calls': mary_calls}
329 | 
330 | # What are the chances that both John and Mary call if an earthquake happens?
331 | bayes.bayesnet(define_event,
332 |                n=1*M,
333 |                find=lambda e: (e['mary_calls'] and e['john_calls']),
334 |                conditional_on=lambda e: e['earthquake'])
335 | # Result will be ~0.19, though it varies because it is based on a random sample.
336 | # This also may take a minute to run.
337 | 
338 | # If both John and Mary call, what is the chance there's been a burglary?
339 | bayes.bayesnet(define_event,
340 |                n=1*M,
341 |                find=lambda e: e['burglary'],
342 |                conditional_on=lambda e: (e['mary_calls'] and e['john_calls']))
343 | # Result will be ~0.27, though it varies because it is based on a random sample.
344 | # This will run quickly because there is a built-in cache.
345 | # Use `memcache=False` to not build a cache and `reload_cache=True` to recalculate the cache.
346 | ```
347 | 
348 | Note that the amount of Bayesian analysis that squigglepy can do is pretty limited. For more complex Bayesian analysis, consider [sorobn](https://github.com/MaxHalford/sorobn), [pomegranate](https://github.com/jmschrei/pomegranate), [bnlearn](https://github.com/erdogant/bnlearn), or [pyMC](https://github.com/pymc-devs/pymc).
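
In addition to the in-memory cache, `bayesnet` simulations can be persisted to disk via the `dump_cache_file` and `load_cache_file` parameters (a `.sqcache` file is written; see the `bayes.bayesnet` docstring). A minimal sketch, reusing `define_event` from the alarm net above (the `alarm_net` file name is just an illustrative placeholder):

```Python
from squigglepy import bayes
from squigglepy.numbers import M

# First run: simulate once and write the results to `alarm_net.sqcache`
bayes.bayesnet(define_event, n=1*M, dump_cache_file='alarm_net')

# Later runs can load the cached simulations without re-passing `define_event`
bayes.bayesnet(load_cache_file='alarm_net',
               n=1*M,
               find=lambda e: e['burglary'],
               conditional_on=lambda e: (e['mary_calls'] and e['john_calls']))
```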
349 | 
350 | #### Example: A Demonstration of the Monty Hall Problem
351 | 
352 | ```Python
353 | import squigglepy as sq
354 | from squigglepy import bayes
355 | from squigglepy.numbers import K, M, B, T
356 | 
357 | 
358 | def monte_hall(door_picked, switch=False):
359 |     doors = ['A', 'B', 'C']
360 |     car_is_behind_door = ~sq.discrete(doors)
361 |     reveal_door = ~sq.discrete([d for d in doors if d != door_picked and d != car_is_behind_door])
362 | 
363 |     if switch:
364 |         old_door_picked = door_picked
365 |         door_picked = [d for d in doors if d != old_door_picked and d != reveal_door][0]
366 | 
367 |     won_car = (car_is_behind_door == door_picked)
368 |     return won_car
369 | 
370 | 
371 | def define_event():
372 |     door = ~sq.discrete(['A', 'B', 'C'])
373 |     switch = sq.event(0.5)
374 |     return {'won': monte_hall(door_picked=door, switch=switch),
375 |             'switched': switch}
376 | 
377 | RUNS = 10*K
378 | r = bayes.bayesnet(define_event,
379 |                    find=lambda e: e['won'],
380 |                    conditional_on=lambda e: e['switched'],
381 |                    verbose=True,
382 |                    n=RUNS)
383 | print('Win {}% of the time when switching'.format(int(r * 100)))
384 | 
385 | r = bayes.bayesnet(define_event,
386 |                    find=lambda e: e['won'],
387 |                    conditional_on=lambda e: not e['switched'],
388 |                    verbose=True,
389 |                    n=RUNS)
390 | print('Win {}% of the time when not switching'.format(int(r * 100)))
391 | 
392 | # Win 66% of the time when switching
393 | # Win 34% of the time when not switching
394 | ```
395 | 
396 | #### Example: More complex coin/dice interactions
397 | 
398 | > Imagine that I flip a coin. If heads, I take a random die out of my blue bag. If tails, I take a random die out of my red bag.
399 | > The blue bag contains only 6-sided dice. The red bag contains a 4-sided die, a 6-sided die, a 10-sided die, and a 20-sided die.
400 | > I then roll the random die I took. What is the chance that I roll a 6?
401 | 
402 | ```Python
403 | import squigglepy as sq
404 | from squigglepy.numbers import K, M, B, T
405 | from squigglepy import bayes
406 | 
407 | def define_event():
408 |     if sq.flip_coin() == 'heads': # Blue bag
409 |         return sq.roll_die(6)
410 |     else: # Red bag
411 |         return sq.discrete([4, 6, 10, 20]) >> sq.roll_die
412 | 
413 | 
414 | bayes.bayesnet(define_event,
415 |                find=lambda e: e == 6,
416 |                verbose=True,
417 |                n=100*K)
418 | # This run for me returned 0.12306, which is pretty close to the correct answer of 0.12292
419 | ```
420 | 
421 | ### Kelly betting
422 | 
423 | You can combine a probability you've generated with your bankroll to determine bet sizing using the [Kelly criterion](https://en.wikipedia.org/wiki/Kelly_criterion).
424 | 
425 | For example, if you want to Kelly bet and you've...
426 | 
427 | - determined that your price (your probability of the event in question happening / the market in question resolving in your favor) is $0.70 (70%)
428 | - seen that the market is pricing at $0.65
429 | - got a bankroll of $1000 that you are willing to bet
430 | 
431 | You should bet as follows:
432 | 
433 | ```Python
434 | import squigglepy as sq
435 | kelly_data = sq.kelly(my_price=0.70, market_price=0.65, bankroll=1000)
436 | kelly_data['kelly'] # What fraction of my bankroll should I bet on this?
437 | # 0.143
438 | kelly_data['target'] # How much money should be invested in this?
439 | # 142.86
440 | kelly_data['expected_roi'] # What is the expected ROI of this bet?
441 | # 0.077
442 | ```
443 | 
444 | ### More examples
445 | 
446 | You can see more examples of squigglepy in action [here](https://github.com/peterhurford/public-botecs).
447 | 
448 | ## Run tests
449 | 
450 | The main test engine is `pytest && pip3 install . && python3 tests/integration.py`, which can also be run via `make test`.
451 | 
452 | Formatting is `black . && ruff check . --fix`, which can also be run via `make format`.
453 | 
454 | Linting is `ruff check .`, which can also be run via `make lint`.
455 | 
456 | 
457 | ## Disclaimers
458 | 
459 | This package is unofficial and supported by me and Rethink Priorities. It is not affiliated with or associated with the Quantified Uncertainty Research Institute, which maintains the Squiggle language (in JavaScript).
460 | 
461 | This package is also new and not yet in a stable production version, so you may encounter bugs and other errors. Please report those so they can be fixed. It's also possible that future versions of the package may introduce breaking changes.
462 | 
463 | This package is available under an MIT License.
464 | 
465 | ## Acknowledgements
466 | 
467 | - The primary author of this package is Peter Wildeford. Agustín Covarrubias and Bernardo Baron contributed several key features and developments.
468 | - Thanks to Ozzie Gooen and the Quantified Uncertainty Research Institute for creating and maintaining the original Squiggle language.
469 | - Thanks to Dawn Drescher for helping me implement math between distributions.
470 | - Thanks to Dawn Drescher for coming up with the idea to use `~` as a shorthand for `sample`, as well as helping me implement it.
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rethinkpriorities/squigglepy/24f631d246bd82619d68b0c4f792c9ebd05fc34c/__init__.py
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.poetry]
2 | name = "squigglepy"
3 | version = "0.30-dev0"
4 | description = "Squiggle programming language for intuitive probabilistic estimation features in Python"
5 | authors = ["Peter Wildeford "]
6 | license = "MIT"
7 | readme = "README.md"
8 | include = ["CHANGES.md"]
9 | classifiers = [
10 |     "Development Status :: 3 - Alpha",
11 |     "Programming Language :: Python :: 3",
12 |     "License :: OSI Approved :: MIT License",
13 |     "Operating System :: OS Independent",
14 | ]
15 | repository = "https://github.com/rethinkpriorities/squigglepy"
16 | 
17 | [tool.poetry.dependencies]
18 | python = ">=3.9,<3.12"
19 | setuptools = "^69.0.0"
20 | numpy = "^1.24.3"
21 | scipy = "^1.10.1"
22 | tqdm = "^4.65.0"
23 | pathos = "^0.3.0"
24 | msgspec = "^0.15.1"
25 | matplotlib = { version = "^3.7.1", optional = true }
26 | pandas = { version = "^2.0.2", optional = true }
27 | 
28 | 
29 | [tool.poetry.group.dev.dependencies]
30 | ruff = "^0.0.272"
31 | pytest = "^7.3.2"
32 | pytest-mock = "^3.10.0"
33 | black = "^24.10.0"
34 | seaborn = "^0.12.2"
35 | hypothesis = "^6.78.3"
36 | hypofuzz = "^23.6.1"
37 | 
38 | [tool.poetry.extras]
39 | plots = ["matplotlib"]
40 | ecosystem = ["pandas"]
41 | all = ["plots", "ecosystem"]
42 | 
43 | [build-system]
44 | requires = ["poetry-core"]
45 | build-backend = "poetry.core.masonry.api"
46 | 
47 | [tool.ruff]
48 | line-length = 99
49 | 
50 | [tool.black]
51 | line-length = 99
52 | 
--------------------------------------------------------------------------------
/squigglepy/__init__.py:
-------------------------------------------------------------------------------- 1 | from .distributions import * # noqa ignore=F405 2 | from .numbers import * # noqa ignore=F405 3 | from .samplers import * # noqa ignore=F405 4 | from .utils import * # noqa ignore=F405 5 | from .rng import * # noqa ignore=F405 6 | from .correlation import * # noqa ignore=F405 7 | from .version import __version__ # noqa ignore=F405 8 | -------------------------------------------------------------------------------- /squigglepy/bayes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import msgspec 5 | 6 | import numpy as np 7 | import pathos.multiprocessing as mp 8 | 9 | from datetime import datetime 10 | 11 | from .distributions import BetaDistribution, NormalDistribution, norm, beta, mixture 12 | from .utils import _core_cuts, _init_tqdm, _tick_tqdm, _flush_tqdm 13 | 14 | 15 | _squigglepy_internal_bayesnet_caches = {} 16 | 17 | 18 | def simple_bayes(likelihood_h, likelihood_not_h, prior): 19 | """ 20 | Calculate Bayes rule. 21 | 22 | p(h|e) = (p(e|h)*p(h)) / (p(e|h)*p(h) + p(e|~h)*(1-p(h))) 23 | p(h|e) is called posterior 24 | p(e|h) is called likelihood 25 | p(h) is called prior 26 | 27 | Parameters 28 | ---------- 29 | likelihood_h : float 30 | The likelihood (given that the hypothesis is true), aka p(e|h) 31 | likelihood_not_h : float 32 | The likelihood given the hypothesis is not true, aka p(e|~h) 33 | prior : float 34 | The prior probability, aka p(h) 35 | 36 | Returns 37 | ------- 38 | float 39 | The result of Bayes rule, aka p(h|e) 40 | 41 | Examples 42 | -------- 43 | # Cancer example: prior of having cancer is 1%, the likelihood of a positive 44 | # mammography given cancer is 80% (true positive rate), and the likelihood of 45 | # a positive mammography given no cancer is 9.6% (false positive rate). 46 | # Given this, what is the probability of cancer given a positive mammography? 47 | >>> simple_bayes(prior=0.01, likelihood_h=0.8, likelihood_not_h=0.096) 48 | 0.07763975155279504 49 | """ 50 | return (likelihood_h * prior) / (likelihood_h * prior + likelihood_not_h * (1 - prior)) 51 | 52 | 53 | def bayesnet( 54 | event_fn=None, 55 | n=1, 56 | find=None, 57 | conditional_on=None, 58 | reduce_fn=None, 59 | raw=False, 60 | memcache=True, 61 | memcache_load=True, 62 | memcache_save=True, 63 | reload_cache=False, 64 | dump_cache_file=None, 65 | load_cache_file=None, 66 | cache_file_primary=False, 67 | verbose=False, 68 | cores=1, 69 | ): 70 | """ 71 | Calculate a Bayesian network. 72 | 73 | Allows you to find conditional probabilities of custom events based on 74 | rejection sampling. 75 | 76 | Parameters 77 | ---------- 78 | event_fn : function 79 | A function that defines the bayesian network 80 | n : int 81 | The number of samples to generate 82 | find : a function or None 83 | What do we want to know the probability of? 84 | conditional_on : a function or None 85 | When finding the probability, what do we want to condition on? 86 | reduce_fn : a function or None 87 | When taking all the results of the simulations, how do we aggregate them 88 | into a final answer? Defaults to ``np.mean``. 89 | raw : bool 90 | If True, just return the results of each simulation without aggregating. 91 | memcache : bool 92 | If True, cache the results in-memory for future calculations. Each cache 93 | will be matched based on the ``event_fn``. Default ``True``. 94 | memcache_load : bool 95 | If True, load cache from the in-memory. 
This will be true if ``memcache``
96 |         is True. Cache will be matched based on the ``event_fn``. Default ``True``.
97 |     memcache_save : bool
98 |         If True, save results to an in-memory cache. This will be true if ``memcache``
99 |         is True. Cache will be matched based on the ``event_fn``. Default ``True``.
100 |     reload_cache : bool
101 |         If True, any existing cache will be ignored and recalculated. Default ``False``.
102 |     dump_cache_file : str or None
103 |         If present, will write out the cache to a binary file with this path with
104 |         ``.sqcache`` appended to the file name.
105 |     load_cache_file : str or None
106 |         If present, will first attempt to load and use a cache from a file with this
107 |         path with ``.sqcache`` appended to the file name.
108 |     cache_file_primary : bool
109 |         If both an in-memory cache and file cache are present, the file
110 |         cache will be used for the cache if this is True, and the in-memory cache
111 |         will be used otherwise. Defaults to False.
112 |     verbose : bool
113 |         If True, will print out statements on computational progress.
114 |     cores : int
115 |         If 1, runs on a single core / process. If greater than 1, will run on a multiprocessing
116 |         pool with that many cores / processes.
117 | 
118 |     Returns
119 |     -------
120 |     various
121 |         The result of ``reduce_fn`` on ``n`` simulations of ``event_fn``.
122 | 
123 |     Examples
124 |     --------
125 |     # Cancer example: prior of having cancer is 1%, the likelihood of a positive
126 |     # mammography given cancer is 80% (true positive rate), and the likelihood of
127 |     # a positive mammography given no cancer is 9.6% (false positive rate).
128 |     # Given this, what is the probability of cancer given a positive mammography?
129 |     >> def mammography(has_cancer):
130 |     >>     p = 0.8 if has_cancer else 0.096
131 |     >>     return bool(sq.sample(sq.bernoulli(p)))
132 |     >>
133 |     >> def define_event():
134 |     >>     cancer = sq.sample(sq.bernoulli(0.01))
135 |     >>     return({'mammography': mammography(cancer),
136 |     >>             'cancer': cancer})
137 |     >>
138 |     >> bayes.bayesnet(define_event,
139 |     >>                find=lambda e: e['cancer'],
140 |     >>                conditional_on=lambda e: e['mammography'],
141 |     >>                n=1*M)
142 |     0.07723995880535531
143 |     """
144 |     events = None
145 |     if memcache is True:
146 |         memcache_load = True
147 |         memcache_save = True
148 |     elif memcache is False:
149 |         memcache_load = False
150 |         memcache_save = False
151 |     has_in_mem_cache = event_fn in _squigglepy_internal_bayesnet_caches
152 |     cache_path = load_cache_file + ".sqcache" if load_cache_file else None
153 |     has_file_cache = os.path.exists(cache_path) if load_cache_file else False
154 | 
155 |     if load_cache_file or dump_cache_file or cores > 1:
156 |         encoder = msgspec.msgpack.Encoder()
157 |         decoder = msgspec.msgpack.Decoder()
158 | 
159 |     if load_cache_file and not has_file_cache and verbose:
160 |         print("Warning: cache file `{}.sqcache` not found.".format(load_cache_file))
161 | 
162 |     if not reload_cache:
163 |         if load_cache_file and has_file_cache and (not has_in_mem_cache or cache_file_primary):
164 |             if verbose:
165 |                 print("Loading from cache file (`{}`)...".format(cache_path))
166 |             with open(cache_path, "rb") as f:
167 |                 events = decoder.decode(f.read())
168 | 
169 |         elif memcache_load and has_in_mem_cache:
170 |             if verbose:
171 |                 print("Loading from in-memory cache...")
172 |             events = _squigglepy_internal_bayesnet_caches.get(event_fn)
173 | 
174 |     if events:
175 |         if events["metadata"]["n"] < n:
176 |             raise ValueError(
177 |                 ("insufficient samples - {} results cached but " + "requested 
{}").format( 178 | events["metadata"]["n"], n 179 | ) 180 | ) 181 | 182 | events = events["events"] 183 | if verbose: 184 | print("...Loaded") 185 | 186 | elif verbose: 187 | print("Reloading cache...") 188 | 189 | if events is None: 190 | if event_fn is None: 191 | return None 192 | 193 | def run_event_fn(pbar=None, total_cores=1): 194 | _tick_tqdm(pbar, total_cores) 195 | return event_fn() 196 | 197 | if cores == 1: 198 | if verbose: 199 | print("Generating Bayes net...") 200 | r_ = range(n) 201 | pbar = _init_tqdm(verbose=verbose, total=n) 202 | events = [run_event_fn(pbar=pbar, total_cores=1) for _ in r_] 203 | _flush_tqdm(pbar) 204 | else: 205 | if verbose: 206 | print("Generating Bayes net with {} cores...".format(cores)) 207 | with mp.ProcessingPool(cores) as pool: 208 | cuts = _core_cuts(n, cores) 209 | 210 | def multicore_event_fn(core, total_cores=1, verbose=False): 211 | r_ = range(cuts[core]) 212 | pbar = _init_tqdm(verbose=verbose, total=n) 213 | batch = [run_event_fn(pbar=pbar, total_cores=total_cores) for _ in r_] 214 | _flush_tqdm(pbar) 215 | 216 | if verbose: 217 | print("Shuffling data...") 218 | 219 | while not os.path.exists("test-core-{}.sqcache".format(core)): 220 | with open("test-core-{}.sqcache".format(core), "wb") as outfile: 221 | encoder = msgspec.msgpack.Encoder() 222 | outfile.write(encoder.encode(batch)) 223 | if verbose: 224 | print("Writing data...") 225 | time.sleep(1) 226 | 227 | pool_results = pool.amap(multicore_event_fn, list(range(cores - 1))) 228 | multicore_event_fn(cores - 1, total_cores=cores, verbose=verbose) 229 | if verbose: 230 | print("Waiting for other cores...") 231 | while not pool_results.ready(): 232 | if verbose: 233 | print(".", end="", flush=True) 234 | time.sleep(1) 235 | 236 | if cores > 1: 237 | if verbose: 238 | print("Collecting data...") 239 | events = [] 240 | pbar = _init_tqdm(verbose=verbose, total=cores) 241 | for c in range(cores): 242 | _tick_tqdm(pbar, 1) 243 | with open("test-core-{}.sqcache".format(c), "rb") as infile: 244 | events += decoder.decode(infile.read()) 245 | os.remove("test-core-{}.sqcache".format(c)) 246 | _flush_tqdm(pbar) 247 | if verbose: 248 | print("...Collected!") 249 | 250 | metadata = {"n": n, "last_generated": datetime.now()} 251 | cache_data = {"events": events, "metadata": metadata} 252 | if memcache_save and (not has_in_mem_cache or reload_cache): 253 | if verbose: 254 | print("Caching in-memory...") 255 | _squigglepy_internal_bayesnet_caches[event_fn] = cache_data 256 | if verbose: 257 | print("...Cached!") 258 | 259 | if dump_cache_file: 260 | cache_path = dump_cache_file + ".sqcache" 261 | if verbose: 262 | print("Writing cache to file `{}`...".format(cache_path)) 263 | with open(cache_path, "wb") as f: 264 | f.write(encoder.encode(cache_data)) 265 | if verbose: 266 | print("...Cached!") 267 | 268 | if conditional_on is not None: 269 | if verbose: 270 | print("Filtering conditional...") 271 | events = [e for e in events if conditional_on(e)] 272 | 273 | if len(events) < 1: 274 | raise ValueError("insufficient samples for condition") 275 | 276 | if conditional_on and verbose: 277 | print("...Filtered!") 278 | 279 | if find is None: 280 | if verbose: 281 | print("...Reducing") 282 | events = events if reduce_fn is None else reduce_fn(events) 283 | if verbose: 284 | print("...Reduced!") 285 | else: 286 | if verbose: 287 | print("...Finding") 288 | events = [find(e) for e in events] 289 | if verbose: 290 | print("...Found!") 291 | if not raw: 292 | if verbose: 293 | print("...Reducing") 294 | 
reduce_fn = np.mean if reduce_fn is None else reduce_fn
295 |             events = reduce_fn(events)
296 |             if verbose:
297 |                 print("...Reduced!")
298 |     if verbose:
299 |         print("...All done!")
300 |     return events
301 | 
302 | 
303 | def update(prior, evidence, evidence_weight=1):
304 |     """
305 |     Update a distribution.
306 | 
307 |     Starting with a prior distribution, use Bayesian inference to perform an update,
308 |     producing a posterior distribution from the evidence distribution.
309 | 
310 |     Parameters
311 |     ----------
312 |     prior : Distribution
313 |         The prior distribution. Currently must either be normal or beta type. Other
314 |         types are not yet supported.
315 |     evidence : Distribution
316 |         The distribution used to update the prior. Currently must either be normal
317 |         or beta type. Other types are not yet supported.
318 |     evidence_weight : float
319 |         How much weight to put on the evidence distribution? Currently this only matters
320 |         for normal distributions, where this should be equivalent to the sample weight.
321 | 
322 |     Returns
323 |     -------
324 |     Distribution
325 |         The posterior distribution.
326 | 
327 |     Examples
328 |     --------
329 |     >> prior = sq.norm(1,5)
330 |     >> evidence = sq.norm(2,3)
331 |     >> bayes.update(prior, evidence)
332 |     norm(mean=2.53, sd=0.29)
333 |     """
334 |     if isinstance(prior, NormalDistribution) and isinstance(evidence, NormalDistribution):
335 |         prior_mean = prior.mean
336 |         prior_var = prior.sd**2
337 |         evidence_mean = evidence.mean
338 |         evidence_var = evidence.sd**2
339 |         return norm(
340 |             mean=(
341 |                 (evidence_var * prior_mean + evidence_weight * (prior_var * evidence_mean))
342 |                 / (evidence_weight * prior_var + evidence_var)
343 |             ),
344 |             sd=math.sqrt(
345 |                 (evidence_var * prior_var) / (evidence_weight * prior_var + evidence_var)
346 |             ),
347 |         )
348 |     elif isinstance(prior, BetaDistribution) and isinstance(evidence, BetaDistribution):
349 |         prior_a = prior.a
350 |         prior_b = prior.b
351 |         evidence_a = evidence.a
352 |         evidence_b = evidence.b
353 |         return beta(prior_a + evidence_a, prior_b + evidence_b)
354 |     elif not isinstance(prior, type(evidence)):
355 |         # The two distributions are of different types (e.g., normal prior, beta evidence)
356 |         raise ValueError("can only update distributions of the same type.")
357 |     else:
358 |         raise ValueError("type `{}` not supported.".format(prior.__class__.__name__))
359 | 
360 | 
361 | def average(prior, evidence, weights=[0.5, 0.5], relative_weights=None):
362 |     """
363 |     Average two distributions.
364 | 
365 |     Parameters
366 |     ----------
367 |     prior : Distribution
368 |         The prior distribution.
369 |     evidence : Distribution
370 |         The distribution used to average with the prior.
371 |     weights : list or np.array or float
372 |         How much weight to put on ``prior`` versus ``evidence`` when averaging? If
373 |         only one weight is passed, the other weight will be inferred to make the
374 |         total weights sum to 1. Defaults to 50-50 weights.
375 |     relative_weights : list or None
376 |         Relative weights, which if given will be weights that are normalized
377 |         to sum to 1.
378 | 
379 |     Returns
380 |     -------
381 |     Distribution
382 |         A mixture distribution that accords weights to ``prior`` and ``evidence``.
383 | 384 | Examples 385 | -------- 386 | >> prior = sq.norm(1,5) 387 | >> evidence = sq.norm(2,3) 388 | >> bayes.average(prior, evidence) 389 | mixture 390 | """ 391 | return mixture(dists=[prior, evidence], weights=weights, relative_weights=relative_weights) 392 | -------------------------------------------------------------------------------- /squigglepy/correlation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements the Iman-Conover method for inducing correlations between distributions. 3 | 4 | Some of the code has been adapted from Abraham Lee's mcerp package (https://github.com/tisimst/mcerp/). 5 | """ 6 | 7 | # Parts of `induce_correlation` are licensed as follows: 8 | 9 | # BSD 3-Clause License 10 | 11 | # Copyright (c) 2018, Abraham Lee 12 | # All rights reserved. 13 | 14 | # Redistribution and use in source and binary forms, with or without 15 | # modification, are permitted provided that the following conditions are met: 16 | 17 | # * Redistributions of source code must retain the above copyright notice, this 18 | # list of conditions and the following disclaimer. 19 | 20 | # * Redistributions in binary form must reproduce the above copyright notice, 21 | # this list of conditions and the following disclaimer in the documentation 22 | # and/or other materials provided with the distribution. 23 | 24 | # * Neither the name of the copyright holder nor the names of its 25 | # contributors may be used to endorse or promote products derived from 26 | # this software without specific prior written permission. 27 | 28 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 29 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 30 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 31 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 32 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 33 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 34 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 35 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 36 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 37 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 38 | 39 | from __future__ import annotations 40 | 41 | from dataclasses import dataclass 42 | import numpy as np 43 | from scipy.linalg import cholesky 44 | from scipy.stats import rankdata, spearmanr 45 | from scipy.stats.distributions import norm as _scipy_norm 46 | from numpy.typing import NDArray 47 | from copy import deepcopy 48 | 49 | from typing import TYPE_CHECKING, Union 50 | 51 | if TYPE_CHECKING: 52 | from .distributions import OperableDistribution 53 | 54 | 55 | def correlate( 56 | variables: tuple[OperableDistribution, ...], 57 | correlation: Union[NDArray[np.float64], list[list[float]], np.float64, float], 58 | tolerance: Union[float, np.float64, None] = 0.05, 59 | _min_unique_samples: int = 100, 60 | ): 61 | """ 62 | Correlate a set of variables according to a rank correlation matrix. 63 | 64 | This employs the Iman-Conover method to induce the correlation while 65 | preserving the original marginal distributions. 66 | 67 | This method works on a best-effort basis, and may fail to induce the desired 68 | correlation depending on the distributions provided. 
An exception will be raised 69 | if that's the case. 70 | 71 | Parameters 72 | ---------- 73 | variables : tuple of distributions 74 | The variables to correlate as a tuple of distributions. 75 | 76 | The distributions must be able to produce enough unique samples for the method 77 | to be able to induce the desired correlation by shuffling the samples. 78 | 79 | Discrete distributions are notably hard to correlate this way, 80 | as it's common for them to result in very few unique samples. 81 | 82 | correlation : 2d-array or float 83 | An n-by-n array that defines the desired Spearman rank correlation coefficients. 84 | This matrix must be symmetric and positive semi-definite, and must not be confused with 85 | a covariance matrix. 86 | 87 | Correlation parameters can only be between -1 and 1, exclusive 88 | (including extremely close approximations). 89 | 90 | If a float is provided, all variables will be correlated with the same coefficient. 91 | 92 | tolerance : float, optional 93 | If provided, overrides the absolute tolerance used to check if the resulting 94 | correlation matrix matches the desired correlation matrix. Defaults to 0.05. 95 | 96 | Checking can also be disabled by passing None. 97 | 98 | Returns 99 | ------- 100 | correlated_variables : tuple of distributions 101 | The correlated variables as a tuple of distributions in the same order as 102 | the input variables. 103 | 104 | Examples 105 | -------- 106 | Suppose we want to correlate two variables with a correlation coefficient of 0.7: 107 | >>> solar_radiation, temperature = sq.gamma(300, 100), sq.to(22, 28) 108 | >>> solar_radiation, temperature = sq.correlate((solar_radiation, temperature), 0.7) 109 | >>> print(np.corrcoef(solar_radiation @ 1000, temperature @ 1000)[0, 1]) 110 | 0.6975960649767123 111 | 112 | Or you could pass a correlation matrix: 113 | >>> funding_gap, cost_per_delivery, effect_size = ( 114 | sq.to(20_000, 80_000), sq.to(30, 80), sq.beta(2, 5) 115 | ) 116 | >>> funding_gap, cost_per_delivery, effect_size = sq.correlate( 117 | (funding_gap, cost_per_delivery, effect_size), 118 | [[1, 0.6, -0.5], [0.6, 1, -0.2], [-0.5, -0.2, 1]] 119 | ) 120 | >>> print(np.corrcoef((funding_gap @ 1000, cost_per_delivery @ 1000, effect_size @ 1000))) 121 | array([[ 1. , 0.580962, -0.480149], 122 | [ 0.580962, 1. , -0.187831], 123 | [-0.480149, -0.187831, 1. ]]) 124 | 125 | """ 126 | if not isinstance(variables, tuple): 127 | variables = tuple(variables) 128 | 129 | if len(variables) < 2: 130 | raise ValueError("You must provide at least two variables to correlate.") 131 | 132 | assert all(v.correlation_group is None for v in variables) 133 | 134 | # Convert a float to a correlation matrix 135 | if ( 136 | isinstance(correlation, float) 137 | or isinstance(correlation, np.floating) 138 | or isinstance(correlation, int) 139 | ): 140 | correlation_parameter = np.float64(correlation) 141 | 142 | assert ( 143 | -1 < correlation_parameter < 1 144 | ), "Correlation parameter must be between -1 and 1, exclusive."
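# For example (illustrative values, not from the docstring above): correlation=0.8 with three variables expands to [[1.0, 0.8, 0.8], [0.8, 1.0, 0.8], [0.8, 0.8, 1.0]] via the code below.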
145 | # Generate a correlation matrix with 146 | # pairwise correlations equal to the correlation parameter 147 | correlation_matrix: NDArray[np.float64] = np.full( 148 | (len(variables), len(variables)), correlation_parameter 149 | ) 150 | # Set the diagonal to 1 151 | np.fill_diagonal(correlation_matrix, 1) 152 | else: 153 | # Coerce the correlation matrix into a numpy array 154 | correlation_matrix: NDArray[np.float64] = np.array(correlation, dtype=np.float64) 155 | 156 | tolerance = float(tolerance) if tolerance is not None else None 157 | 158 | # Deepcopy the variables to avoid modifying the originals 159 | variables = deepcopy(variables) 160 | 161 | # Create the correlation group 162 | CorrelationGroup(variables, correlation_matrix, tolerance, _min_unique_samples) 163 | 164 | return variables 165 | 166 | 167 | @dataclass 168 | class CorrelationGroup: 169 | """ 170 | An object that holds metadata for a group of correlated distributions. 171 | This object is not intended to be used directly by the user, but 172 | rather during sampling to induce correlations between distributions. 173 | """ 174 | 175 | correlated_dists: tuple[OperableDistribution] 176 | correlation_matrix: NDArray[np.float64] 177 | correlation_tolerance: Union[float, None] = 0.05 178 | min_unique_samples: int = 100 179 | 180 | def __post_init__(self): 181 | # Check that the correlation matrix is square of the expected size 182 | assert ( 183 | self.correlation_matrix.shape[0] 184 | == self.correlation_matrix.shape[1] 185 | == len(self.correlated_dists) 186 | ), "Correlation matrix must be square, with size equal to the number of distributions provided." 187 | 188 | # Check that the diagonal of the correlation matrix is all ones 189 | assert np.all(np.diag(self.correlation_matrix) == 1), "Diagonal must be all ones." 190 | 191 | # Check that values are between -1 and 1 192 | assert ( 193 | -1 <= np.min(self.correlation_matrix) and np.max(self.correlation_matrix) <= 1 194 | ), "Correlation matrix values must be between -1 and 1." 195 | 196 | # Check that the correlation matrix is positive semi-definite 197 | assert np.all( 198 | np.linalg.eigvals(self.correlation_matrix) >= 0 199 | ), "Matrix must be positive semi-definite." 200 | 201 | # Check that the correlation matrix is symmetric 202 | assert np.all( 203 | self.correlation_matrix == self.correlation_matrix.T 204 | ), "Matrix must be symmetric." 205 | 206 | # Link the correlation group to each distribution 207 | for dist in self.correlated_dists: 208 | dist.correlation_group = self 209 | 210 | def induce_correlation(self, data: NDArray[np.float64]) -> NDArray[np.float64]: 211 | """ 212 | Induce a set of correlations on a column-wise dataset. 213 | 214 | Parameters 215 | ---------- 216 | data : 2d-array 217 | An m-by-n array where m is the number of samples and n is the 218 | number of independent variables, each column of the array corresponding 219 | to each variable. 220 | 221 | Notes 222 | ----- 223 | The desired correlation coefficients (between -1 and 1) are taken from ``self.correlation_matrix``, which must be symmetric and positive-definite in order to induce. 224 | 225 | Returns 226 | ------- 227 | new_data : 2d-array 228 | An m-by-n array that has the desired correlations. 229 | 230 | """ 231 | # Check that each column doesn't have too few unique values 232 | for column in data.T: 233 | if not self.has_sufficient_sample_diversity(column): 234 | raise ValueError( 235 | "The data has too many repeated values to induce a correlation. 
" 236 | "This might be because of too few samples, or too many repeated samples." 237 | ) 238 | 239 | # If the correlation matrix is the identity matrix, just return the data 240 | if np.all(self.correlation_matrix == np.eye(self.correlation_matrix.shape[0])): 241 | return data 242 | 243 | # Create a rank-matrix 244 | data_rank = np.vstack([rankdata(datai, method="min") for datai in data.T]).T 245 | 246 | # Generate van der Waerden scores 247 | data_rank_score = data_rank / (data_rank.shape[0] + 1.0) 248 | data_rank_score = _scipy_norm(0, 1).ppf(data_rank_score) 249 | 250 | # Calculate the lower triangular matrix of the Cholesky decomposition 251 | # of the desired correlation matrix 252 | p = cholesky(self.correlation_matrix, lower=True) 253 | 254 | # Calculate the current correlations 255 | t = np.corrcoef(data_rank_score, rowvar=False) 256 | 257 | # Calculate the lower triangular matrix of the Cholesky decomposition 258 | # of the current correlation matrix 259 | q = cholesky(t, lower=True) 260 | 261 | # Calculate the re-correlation matrix 262 | s = np.dot(p, np.linalg.inv(q)) 263 | 264 | # Calculate the re-sampled matrix 265 | new_data = np.dot(data_rank_score, s.T) 266 | 267 | # Create the new rank matrix 268 | new_data_rank = np.vstack([rankdata(datai, method="min") for datai in new_data.T]).T 269 | 270 | # Sort the original data according to the new rank matrix 271 | self._sort_data_according_to_rank(data, data_rank, new_data_rank) 272 | 273 | # # Check correlation 274 | if self.correlation_tolerance: 275 | self._check_empirical_correlation(data) 276 | 277 | return data 278 | 279 | def _sort_data_according_to_rank( 280 | self, 281 | data: NDArray[np.float64], 282 | data_rank: NDArray[np.float64], 283 | new_data_rank: NDArray[np.float64], 284 | ): 285 | """Sorts the original data according to new_data_rank, in place.""" 286 | assert ( 287 | data.shape == data_rank.shape == new_data_rank.shape 288 | ), "All input arrays must have the same shape" 289 | for i in range(data.shape[1]): 290 | _, order = np.unique( 291 | np.hstack((data_rank[:, i], new_data_rank[:, i])), return_inverse=True 292 | ) 293 | old_order = order[: new_data_rank.shape[0]] 294 | new_order = order[-new_data_rank.shape[0] :] 295 | tmp = data[np.argsort(old_order), i][new_order] 296 | data[:, i] = tmp[:] 297 | 298 | def _check_empirical_correlation(self, samples: NDArray[np.float64]): 299 | """ 300 | Ensures that the empirical correlation matrix is 301 | the same as the desired correlation matrix. 302 | """ 303 | assert self.correlation_tolerance is not None 304 | 305 | # Compute the empirical correlation matrix 306 | empirical_correlation = spearmanr(samples).statistic 307 | if len(self.correlated_dists) == 2: 308 | # empirical_correlation is a scalar 309 | properly_correlated = np.isclose( 310 | empirical_correlation, 311 | self.correlation_matrix[0, 1], 312 | atol=self.correlation_tolerance, 313 | rtol=0, 314 | ) 315 | else: 316 | # empirical_correlation is a matrix 317 | properly_correlated = np.allclose( 318 | empirical_correlation, 319 | self.correlation_matrix, 320 | atol=self.correlation_tolerance, 321 | rtol=0, 322 | ) 323 | if not properly_correlated: 324 | raise RuntimeError( 325 | "Failed to induce the desired correlation between samples. " 326 | "This might be because of too little diversity in the samples. " 327 | "You can relax the tolerance by passing `tolerance` to correlate()." 
328 | ) 329 | 330 | def has_sufficient_sample_diversity( 331 | self, 332 | samples: NDArray[np.float64], 333 | relative_threshold: float = 0.7, 334 | absolute_threshold=None, 335 | ) -> bool: 336 | """ 337 | Check whether there are sufficient unique samples to work with in the data. 338 | """ 339 | 340 | if absolute_threshold is None: 341 | absolute_threshold = self.min_unique_samples 342 | 343 | unique_samples = len(np.unique(samples, axis=0)) 344 | n_samples = len(samples) 345 | 346 | diversity = unique_samples / n_samples 347 | 348 | return (diversity >= relative_threshold) and (unique_samples >= absolute_threshold) 349 | -------------------------------------------------------------------------------- /squigglepy/numbers.py: -------------------------------------------------------------------------------- 1 | thousand = 10**3 2 | K = thousand 3 | 4 | million = 10**6 5 | M = million 6 | 7 | billion = 10**9 8 | B = billion 9 | 10 | trillion = 10**12 11 | T = trillion 12 | 13 | quadrillion = 10**15 14 | 15 | quintillion = 10**18 16 | 17 | sextillion = 10**21 18 | 19 | septillion = 10**24 20 | 21 | octillion = 10**27 22 | 23 | nonillion = 10**30 24 | 25 | decillion = 10**33 26 | -------------------------------------------------------------------------------- /squigglepy/rng.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | _squigglepy_internal_rng = np.random.default_rng() 4 | 5 | 6 | def set_seed(seed): 7 | """ 8 | Set the seed of the random number generator used by Squigglepy. 9 | 10 | The RNG is a ``np.random.default_rng`` under the hood. 11 | 12 | Parameters 13 | ---------- 14 | seed : int or None 15 | The seed to use for the RNG. 16 | 17 | Returns 18 | ------- 19 | np.random.Generator 20 | The RNG used internally.
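Re-seeding with the same value makes subsequent sampling reproducible: two runs seeded identically return identical draws.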
21 | 22 | Examples 23 | -------- 24 | >>> set_seed(42) 25 | Generator(PCG64) at 0x127EDE9E0 26 | """ 27 | global _squigglepy_internal_rng 28 | _squigglepy_internal_rng = np.random.default_rng(seed) 29 | return _squigglepy_internal_rng 30 | -------------------------------------------------------------------------------- /squigglepy/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.30-dev0" 2 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rethinkpriorities/squigglepy/24f631d246bd82619d68b0c4f792c9ebd05fc34c/tests/__init__.py -------------------------------------------------------------------------------- /tests/integration.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | from tqdm import tqdm 5 | from squigglepy.distributions import LogTDistribution 6 | 7 | 8 | RUNS = 10_000 9 | 10 | 11 | def _within(actual, expected, tolerance_ratio=None, abs_tolerance=None): 12 | if expected == 0 or actual == 0: 13 | ratio = None 14 | elif actual < expected: 15 | ratio = expected / actual 16 | else: 17 | ratio = actual / expected 18 | 19 | abs_diff = np.abs(actual - expected) 20 | 21 | if abs_tolerance is not None and abs_diff < abs_tolerance: 22 | return True 23 | elif tolerance_ratio is not None and ratio < tolerance_ratio: 24 | return True 25 | else: 26 | return False 27 | 28 | 29 | def _mark_time(start, expected_sec, label, tolerance_ratio=1.05, tolerance_ms_threshold=5): 30 | end = time.time() 31 | delta_sec = end - start 32 | use_delta = delta_sec 33 | expected = expected_sec 34 | delta_label = "sec" 35 | if delta_sec < 1: 36 | delta_ms = delta_sec * 1000 37 | expected = expected_sec * 1000 38 | use_delta = delta_ms 39 | delta_label = "ms" 40 | use_delta = round(use_delta, 2) 41 | print( 42 | "...{} in {}{} (expected ~{}{})".format( 43 | label, use_delta, delta_label, expected, delta_label 44 | ) 45 | ) 46 | if delta_label == "ms": 47 | deviation = not _within(use_delta, expected, tolerance_ratio, tolerance_ms_threshold) 48 | else: 49 | deviation = not _within(use_delta, expected, tolerance_ratio) 50 | if deviation: 51 | print("!!! 
WARNING: Unexpected timing deviation") 52 | return {"timing(sec)": delta_sec, "deviation": deviation} 53 | 54 | 55 | def pct_of_pop_w_pianos(): 56 | percentage = sq.to(0.2, 1) 57 | return sq.sample(percentage) * 0.01 58 | 59 | 60 | def piano_tuners_per_piano(): 61 | pianos_per_piano_tuner = sq.to(2 * K, 50 * K) 62 | return 1 / sq.sample(pianos_per_piano_tuner) 63 | 64 | 65 | def total_tuners_in_2022(): 66 | return sq.sample(pop_of_ny_2022) * pct_of_pop_w_pianos() * piano_tuners_per_piano() 67 | 68 | 69 | def pop_at_time(t): 70 | avg_yearly_pct_change = sq.to(-0.01, 0.05) 71 | return sq.sample(pop_of_ny_2022) * ((sq.sample(avg_yearly_pct_change) + 1) ** t) 72 | 73 | 74 | def total_tuners_at_time(t): 75 | return pop_at_time(t) * pct_of_pop_w_pianos() * piano_tuners_per_piano() 76 | 77 | 78 | def pop_at_time2(t): 79 | return pop_of_ny_2022 * ((sq.to(-0.01, 0.05) + 1) ** t) 80 | 81 | 82 | def total_tuners_at_time2(t): 83 | piano_tuners_per_piano = 1 / sq.to(2 * K, 50 * K) 84 | pct_of_pop_w_pianos = sq.to(0.2, 1) * 0.01 85 | return pop_at_time2(t) * pct_of_pop_w_pianos * piano_tuners_per_piano 86 | 87 | 88 | def roll_die(sides, n=1): 89 | return sq.sample(sq.discrete(list(range(1, sides + 1))), n=n) if sides > 0 else None 90 | 91 | 92 | def roll_die2(sides, n=1): 93 | return sq.discrete(list(range(1, sides + 1))) @ n if sides > 0 else None 94 | 95 | 96 | def mammography(has_cancer): 97 | return sq.event(0.8 if has_cancer else 0.096) 98 | 99 | 100 | def mammography_event(): 101 | cancer = ~sq.bernoulli(0.01) 102 | return {"mammography": mammography(cancer), "cancer": cancer} 103 | 104 | 105 | def p_alarm_goes_off(burglary, earthquake): 106 | if burglary and earthquake: 107 | return 0.95 108 | elif burglary and not earthquake: 109 | return 0.94 110 | elif not burglary and earthquake: 111 | return 0.29 112 | elif not burglary and not earthquake: 113 | return 0.001 114 | 115 | 116 | def p_john_calls(alarm_goes_off): 117 | return 0.9 if alarm_goes_off else 0.05 118 | 119 | 120 | def p_mary_calls(alarm_goes_off): 121 | return 0.7 if alarm_goes_off else 0.01 122 | 123 | 124 | def alarm_net(): 125 | burglary_happens = sq.event(p=0.001) 126 | earthquake_happens = sq.event(p=0.002) 127 | alarm_goes_off = sq.event(p_alarm_goes_off(burglary_happens, earthquake_happens)) 128 | john_calls = sq.event(p_john_calls(alarm_goes_off)) 129 | mary_calls = sq.event(p_mary_calls(alarm_goes_off)) 130 | return { 131 | "burglary": burglary_happens, 132 | "earthquake": earthquake_happens, 133 | "alarm_goes_off": alarm_goes_off, 134 | "john_calls": john_calls, 135 | "mary_calls": mary_calls, 136 | } 137 | 138 | 139 | def monte_hall(door_picked, switch=False): 140 | doors = ["A", "B", "C"] 141 | car_is_behind_door = ~sq.discrete(doors) 142 | reveal_door = [d for d in doors if d != door_picked and d != car_is_behind_door] 143 | reveal_door = ~sq.discrete(reveal_door) 144 | 145 | if switch: 146 | old_door_picked = door_picked 147 | door_picked = [d for d in doors if d != old_door_picked and d != reveal_door][0] 148 | 149 | won_car = car_is_behind_door == door_picked 150 | return won_car 151 | 152 | 153 | def monte_hall_event(): 154 | door = ~sq.discrete(["A", "B", "C"]) 155 | switch = sq.event(0.5) 156 | return {"won": monte_hall(door_picked=door, switch=switch), "switched": switch} 157 | 158 | 159 | def coins_and_dice(): 160 | flip = sq.flip_coin() 161 | if flip == "heads": 162 | dice_sides = 6 163 | else: 164 | dice_sides = ~sq.discrete([4, 6, 10, 20]) 165 | return sq.roll_die(dice_sides) 166 | 167 | 168 | def model(): 169 
| prior = sq.exponential(12) 170 | guess = sq.norm(10, 14) 171 | days = bayes.average(prior, guess, weights=[0.3, 0.7]) 172 | 173 | def move_days(days): 174 | if days < 4 and sq.event(0.9): 175 | days = 4 176 | if days < 7 and sq.event(0.9): 177 | diff_days = 7 - days 178 | days = days + sq.norm(diff_days / 1.5, diff_days * 1.5) 179 | return days 180 | 181 | return sq.dist_fn(days, fn=move_days) >> sq.dist_round >> sq.lclip(3) 182 | 183 | 184 | if __name__ == "__main__": 185 | print("Test 0 (LOAD SQ)") 186 | start0 = time.time() 187 | import squigglepy as sq 188 | from squigglepy.numbers import K, M 189 | from squigglepy import bayes 190 | 191 | _mark_time(start0, 0.033, "Test 0 complete") 192 | 193 | print("Test 1 (PIANO TUNERS, NO TIME, LONG FORMAT)...") 194 | sq.set_seed(42) 195 | start1 = time.time() 196 | pop_of_ny_2022 = sq.to(8.1 * M, 8.4 * M) 197 | out = sq.get_percentiles(sq.sample(total_tuners_in_2022, n=100), digits=1) 198 | expected = { 199 | 1: 0.6, 200 | 5: 0.9, 201 | 10: 1.1, 202 | 20: 2.0, 203 | 30: 2.6, 204 | 40: 3.1, 205 | 50: 3.9, 206 | 60: 4.6, 207 | 70: 6.1, 208 | 80: 8.1, 209 | 90: 11.8, 210 | 95: 19.6, 211 | 99: 36.8, 212 | } 213 | if out != expected: 214 | print("ERROR 1") 215 | import pdb 216 | 217 | pdb.set_trace() 218 | _mark_time(start1, 0.033, "Test 1 complete") 219 | 220 | print("Test 2 (PIANO TUNERS, NO TIME, SHORT FORMAT)...") 221 | sq.set_seed(42) 222 | start1b = time.time() 223 | 224 | pop_of_ny_2022 = sq.to(8.1 * M, 8.4 * M) 225 | pct_of_pop_w_pianos_ = sq.to(0.2, 1) * 0.01 226 | pianos_per_piano_tuner_ = sq.to(2 * K, 50 * K) 227 | piano_tuners_per_piano_ = 1 / pianos_per_piano_tuner_ 228 | total_tuners_in_2022 = pop_of_ny_2022 * pct_of_pop_w_pianos_ * piano_tuners_per_piano_ 229 | samples = total_tuners_in_2022 @ 1000  # Note: `@ 1000` is shorthand to get 1000 samples 230 | out = sq.get_percentiles(samples, digits=1) 231 | expected = { 232 | 1: 0.3, 233 | 5: 0.5, 234 | 10: 0.8, 235 | 20: 1.3, 236 | 30: 1.9, 237 | 40: 2.6, 238 | 50: 3.5, 239 | 60: 4.5, 240 | 70: 6.2, 241 | 80: 8.6, 242 | 90: 14.0, 243 | 95: 22.1, 244 | 99: 48.1, 245 | } 246 | if out != expected: 247 | print("ERROR 1B") 248 | import pdb 249 | 250 | pdb.set_trace() 251 | _mark_time(start1b, 0.001, "Test 2 complete") 252 | 253 | print("Test 3 (PIANO TUNERS, TIME COMPONENT, LONG FORMAT)...") 254 | sq.set_seed(42) 255 | start2 = time.time() 256 | out = sq.get_percentiles(sq.sample(lambda: total_tuners_at_time(2030 - 2022), n=100), digits=1) 257 | expected = { 258 | 1: 0.7, 259 | 5: 1.0, 260 | 10: 1.3, 261 | 20: 2.1, 262 | 30: 2.7, 263 | 40: 3.4, 264 | 50: 4.3, 265 | 60: 6.0, 266 | 70: 7.4, 267 | 80: 9.4, 268 | 90: 14.1, 269 | 95: 19.6, 270 | 99: 24.4, 271 | } 272 | 273 | if out != expected: 274 | print("ERROR 2") 275 | import pdb 276 | 277 | pdb.set_trace() 278 | _mark_time(start2, 0.046, "Test 3 complete") 279 | 280 | print("Test 4 (PIANO TUNERS, TIME COMPONENT, SHORT FORMAT)...") 281 | sq.set_seed(42) 282 | start3 = time.time() 283 | out = sq.get_percentiles(total_tuners_at_time2(2030 - 2022) @ 100, digits=1) 284 | expected = { 285 | 1: 0.5, 286 | 5: 0.6, 287 | 10: 1.1, 288 | 20: 1.5, 289 | 30: 1.8, 290 | 40: 2.4, 291 | 50: 3.1, 292 | 60: 4.4, 293 | 70: 7.3, 294 | 80: 9.8, 295 | 90: 16.6, 296 | 95: 28.4, 297 | 99: 85.4, 298 | } 299 | 300 | if out != expected: 301 | print("ERROR 3") 302 | import pdb 303 | 304 | pdb.set_trace() 305 | _mark_time(start3, 0.001, "Test 4 complete") 306 | 307 | print("Test 5 (VARIOUS DISTRIBUTIONS, LONG FORMAT)...") 308 | sq.set_seed(42) 309 | start4 = time.time()
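# One-sample smoke tests over every distribution constructor. Note that sq.norm(a, b) is parameterized by a 90% credible interval (see the inline comments below), while sq.norm(mean=..., sd=...) sets the parameters directly.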
310 | sq.sample(sq.norm(1, 3)) # 90% interval from 1 to 3 311 | sq.sample(sq.norm(mean=0, sd=1)) 312 | sq.sample(sq.norm(-1.67, 1.67)) # This is equivalent to mean=0, sd=1 313 | sq.sample(sq.norm(1, 3), n=100) 314 | sq.sample(sq.lognorm(1, 10)) 315 | sq.sample(sq.invlognorm(1, 10)) 316 | sq.sample(sq.tdist(1, 10, t=5)) 317 | sq.sample(sq.triangular(1, 2, 3)) 318 | sq.sample(sq.pert(1, 2, 3, lam=2)) 319 | sq.sample(sq.binomial(p=0.5, n=5)) 320 | sq.sample(sq.beta(a=1, b=2)) 321 | sq.sample(sq.bernoulli(p=0.5)) 322 | sq.sample(sq.poisson(10)) 323 | sq.sample(sq.chisquare(2)) 324 | sq.sample(sq.gamma(3, 2)) 325 | sq.sample(sq.pareto(1)) 326 | sq.sample(sq.exponential(scale=1)) 327 | sq.sample(sq.geometric(p=0.5)) 328 | sq.sample(sq.discrete({"A": 0.1, "B": 0.9})) 329 | sq.sample(sq.discrete({0: 0.1, 1: 0.3, 2: 0.3, 3: 0.15, 4: 0.15})) 330 | sq.sample(sq.discrete([[0.1, 0], [0.3, 1], [0.3, 2], [0.15, 3], [0.15, 4]])) 331 | sq.sample(sq.discrete([0, 1, 2])) 332 | sq.sample(sq.mixture([sq.norm(1, 3), sq.norm(4, 10), sq.lognorm(1, 10)], [0.3, 0.3, 0.4])) 333 | sq.sample(sq.mixture([[0.3, sq.norm(1, 3)], [0.3, sq.norm(4, 10)], [0.4, sq.lognorm(1, 10)]])) 334 | sq.sample(lambda: sq.sample(sq.norm(1, 3)) + sq.sample(sq.norm(4, 5)), n=100) 335 | sq.sample(lambda: sq.sample(sq.norm(1, 3)) - sq.sample(sq.norm(4, 5)), n=100) 336 | sq.sample(lambda: sq.sample(sq.norm(1, 3)) * sq.sample(sq.norm(4, 5)), n=100) 337 | sq.sample(lambda: sq.sample(sq.norm(1, 3)) / sq.sample(sq.norm(4, 5)), n=100) 338 | sq.sample(sq.norm(1, 3, credibility=80)) 339 | sq.sample(sq.norm(0, 3, lclip=0, rclip=5)) 340 | sq.sample(sq.const(4)) 341 | sq.sample(sq.zero_inflated(0.6, sq.norm(1, 2))) 342 | roll_die(sides=6, n=10) 343 | _mark_time(start4, 0.110, "Test 5 complete") 344 | 345 | print("Test 6 (VARIOUS DISTRIBUTIONS, SHORT FORMAT)...") 346 | sq.set_seed(42) 347 | start5 = time.time() 348 | ~sq.norm(1, 3) 349 | ~sq.norm(mean=0, sd=1) 350 | ~sq.norm(-1.67, 1.67) 351 | sq.norm(1, 3) @ 100 352 | ~sq.lognorm(1, 10) 353 | ~sq.invlognorm(1, 10) 354 | ~sq.tdist(1, 10, t=5) 355 | ~sq.triangular(1, 2, 3) 356 | ~sq.pert(1, 2, 3, lam=2) 357 | ~sq.binomial(p=0.5, n=5) 358 | ~sq.beta(a=1, b=2) 359 | ~sq.bernoulli(p=0.5) 360 | ~sq.poisson(10) 361 | ~sq.chisquare(2) 362 | ~sq.gamma(3, 2) 363 | ~sq.pareto(1) 364 | ~sq.exponential(scale=1) 365 | ~sq.geometric(p=0.5) 366 | ~sq.discrete({"A": 0.1, "B": 0.9}) 367 | ~sq.discrete({0: 0.1, 1: 0.3, 2: 0.3, 3: 0.15, 4: 0.15}) 368 | ~sq.discrete([[0.1, 0], [0.3, 1], [0.3, 2], [0.15, 3], [0.15, 4]]) 369 | ~sq.discrete([0, 1, 2]) 370 | ~sq.mixture([sq.norm(1, 3), sq.norm(4, 10), sq.lognorm(1, 10)], [0.3, 0.3, 0.4]) 371 | ~sq.mixture([[0.3, sq.norm(1, 3)], [0.3, sq.norm(4, 10)], [0.4, sq.invlognorm(1, 10)]]) 372 | ~sq.norm(1, 3) + ~sq.norm(4, 5) 373 | ~sq.norm(1, 3) - ~sq.norm(4, 5) 374 | ~sq.norm(1, 3) / ~sq.norm(4, 5) 375 | ~sq.norm(1, 3) * ~sq.norm(4, 5) 376 | ~(sq.norm(1, 3) + ~sq.norm(4, 5)) 377 | ~(sq.norm(1, 3) - ~sq.norm(4, 5)) 378 | ~(sq.norm(1, 3) / ~sq.norm(4, 5)) 379 | ~(sq.norm(1, 3) * ~sq.norm(4, 5)) 380 | (sq.norm(1, 3) + ~sq.norm(4, 5)) @ 100 381 | (sq.norm(1, 3) - ~sq.norm(4, 5)) @ 100 382 | (sq.norm(1, 3) / ~sq.norm(4, 5)) @ 100 383 | (sq.norm(1, 3) * ~sq.norm(4, 5)) @ 100 384 | ~(-sq.lognorm(0.1, 1) * sq.pareto(1) / 10) 385 | ~sq.norm(1, 3, credibility=80) 386 | ~sq.norm(0, 3, lclip=0, rclip=5) 387 | ~sq.const(4) 388 | ~sq.zero_inflated(0.6, sq.norm(1, 2)) 389 | roll_die2(sides=6, n=10) 390 | _mark_time(start5, 0.005, "Test 6 complete") 391 | 392 | print("Test 7 (MAMMOGRAPHY BAYES)...") 
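# Analytic cross-check (a hand calculation, not part of the test): Bayes' rule with prior P(cancer) = 0.01 and likelihoods 0.8 / 0.096 gives P(cancer | positive) = (0.8 * 0.01) / (0.8 * 0.01 + 0.096 * 0.99) ≈ 0.078, so the sampled 0.09 at n=10,000 is within Monte Carlo error of the true value.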
393 | sq.set_seed(42) 394 | start7 = time.time() 395 | out = bayes.bayesnet( 396 | mammography_event, 397 | find=lambda e: e["cancer"], 398 | conditional_on=lambda e: e["mammography"], 399 | memcache=False, 400 | n=RUNS, 401 | ) 402 | expected = 0.09 403 | if round(out, 2) != expected: 404 | print("ERROR 7") 405 | import pdb 406 | 407 | pdb.set_trace() 408 | test_7_mark = _mark_time(start7, 0.187, "Test 7 complete") 409 | 410 | print("Test 8 (SIMPLE BAYES)...") 411 | sq.set_seed(42) 412 | start8 = time.time() 413 | out = bayes.simple_bayes(prior=0.01, likelihood_h=0.8, likelihood_not_h=0.096) 414 | expected = 0.08 415 | if round(out, 2) != expected: 416 | print("ERROR 8") 417 | import pdb 418 | 419 | pdb.set_trace() 420 | _mark_time(start8, 0.00001, "Test 8 complete") 421 | 422 | print("Test 9 (BAYESIAN UPDATE)...") 423 | sq.set_seed(42) 424 | start9 = time.time() 425 | prior = sq.norm(1, 5) 426 | evidence = sq.norm(2, 3) 427 | posterior = bayes.update(prior, evidence) 428 | if round(posterior.mean, 2) != 2.53 or round(posterior.sd, 2) != 0.29: 429 | print("ERROR 9") 430 | import pdb 431 | 432 | pdb.set_trace() 433 | _mark_time(start9, 0.0004, "Test 9 complete") 434 | 435 | print("Test 10 (BAYESIAN AVERAGE)...") 436 | sq.set_seed(42) 437 | start10 = time.time() 438 | average = bayes.average(prior, evidence) 439 | average_samples = sq.sample(average, n=K) 440 | out = (np.mean(average_samples), np.std(average_samples)) 441 | if round(out[0], 2) != 2.76 and round(out[1], 2) != 0.10: 442 | print("ERROR 10") 443 | import pdb 444 | 445 | pdb.set_trace() 446 | _mark_time(start10, 0.002, "Test 10 complete") 447 | 448 | print("Test 11 (ALARM NET)...") 449 | sq.set_seed(42) 450 | start11 = time.time() 451 | out = bayes.bayesnet( 452 | alarm_net, 453 | n=RUNS * 3, 454 | find=lambda e: (e["mary_calls"] and e["john_calls"]), 455 | conditional_on=lambda e: e["earthquake"], 456 | ) 457 | if round(out, 2) != 0.19: 458 | print("ERROR 11") 459 | import pdb 460 | 461 | pdb.set_trace() 462 | _mark_time(start11, 0.68, "Test 11 complete") 463 | 464 | print("Test 12 (ALARM NET II)...") 465 | sq.set_seed(42) 466 | start12 = time.time() 467 | out = bayes.bayesnet( 468 | alarm_net, 469 | n=RUNS * 3, 470 | find=lambda e: e["burglary"], 471 | conditional_on=lambda e: (e["mary_calls"] and e["john_calls"]), 472 | ) 473 | if round(out, 2) != 0.35: 474 | print("ERROR 12") 475 | import pdb 476 | 477 | pdb.set_trace() 478 | _mark_time(start12, 0.0025, "Test 12 complete") 479 | 480 | print("Test 13 (MONTE HALL)...") 481 | sq.set_seed(42) 482 | start13 = time.time() 483 | out = bayes.bayesnet( 484 | monte_hall_event, 485 | find=lambda e: e["won"], 486 | conditional_on=lambda e: e["switched"], 487 | n=RUNS, 488 | ) 489 | if round(out, 2) != 0.67: 490 | print("ERROR 13") 491 | import pdb 492 | 493 | pdb.set_trace() 494 | _mark_time(start13, 1.26, "Test 13 complete") 495 | 496 | print("Test 14 (MONTE HALL II)...") 497 | sq.set_seed(42) 498 | start14 = time.time() 499 | out = bayes.bayesnet( 500 | monte_hall_event, 501 | find=lambda e: e["won"], 502 | conditional_on=lambda e: not e["switched"], 503 | n=RUNS, 504 | ) 505 | if round(out, 2) != 0.34: 506 | print("ERROR 14") 507 | import pdb 508 | 509 | pdb.set_trace() 510 | _mark_time(start14, 0.003, "Test 14 complete") 511 | 512 | print("Test 15 (COINS AND DICE)...") 513 | sq.set_seed(42) 514 | start15 = time.time() 515 | out = bayes.bayesnet(coins_and_dice, find=lambda e: e == 6, n=RUNS) 516 | if round(out, 2) != 0.12: 517 | print("ERROR 15") 518 | import pdb 519 | 520 | 
pdb.set_trace() 521 | _mark_time(start15, 1.24, "Test 15 complete") 522 | 523 | print("Test 16 (PIPES)...") 524 | sq.set_seed(42) 525 | start16 = time.time() 526 | samples = sq.sample(model, n=1000) 527 | if not all(isinstance(s, np.int64) for s in samples): 528 | print("ERROR 16") 529 | import pdb 530 | 531 | pdb.set_trace() 532 | _mark_time(start16, 0.247, "Test 16 complete") 533 | 534 | print("Test 17 (T TEST)...") 535 | sq.set_seed(42) 536 | start17 = time.time() 537 | # TODO: Accuracy with t<20 538 | ts = [20, 40, 50] 539 | vals = [[1, 10], [0, 3], [-4, 4], [5, 10], [100, 200]] 540 | credibilities = [80, 90] 541 | tqdm_ = tqdm(total=len(ts) * len(vals) * len(credibilities) * 2) 542 | for t in ts: 543 | for val in vals: 544 | for credibility in credibilities: 545 | for dist in [sq.tdist, sq.log_tdist]: 546 | if not (dist == sq.log_tdist and val[0] < 1): 547 | dist = dist(val[0], val[1], t, credibility=credibility) 548 | pctiles = sq.get_percentiles( 549 | dist @ (20 * K), 550 | percentiles=[ 551 | (100 - credibility) / 2, 552 | 100 - ((100 - credibility) / 2), 553 | ], 554 | ) 555 | tol = 140 / t if isinstance(dist, LogTDistribution) else 1.35 556 | if not _within( 557 | pctiles[(100 - credibility) / 2], val[0], tol, tol 558 | ) or not _within( 559 | pctiles[100 - ((100 - credibility) / 2)], val[1], tol, tol 560 | ): 561 | print("ERROR 17 on {}".format(str(dist))) 562 | print(pctiles) 563 | import pdb 564 | 565 | pdb.set_trace() 566 | tqdm_.update(1) 567 | tqdm_.close() 568 | _mark_time(start17, 0.082, "Test 17 complete") 569 | 570 | print("Test 18 (SPEED TEST, 10M SAMPLES)...") 571 | start18 = time.time() 572 | samps = (sq.norm(1, 3) + sq.norm(4, 5)) @ (10 * M) 573 | if len(samps) != (10 * M): 574 | print("ERROR ON 18") 575 | import pdb 576 | 577 | pdb.set_trace() 578 | _mark_time(start18, 0.327, "Test 18 complete") 579 | 580 | print("Test 19 (LCLIP FIDELITY, 1M SAMPLES)...") 581 | start19 = time.time() 582 | dist = sq.mixture([[0.1, 0], [0.8, sq.norm(0, 3)], [0.1, sq.norm(7, 11)]], lclip=0) 583 | samps = dist @ (1 * M) 584 | if any(samps < 0): 585 | print("ERROR ON 19") 586 | import pdb 587 | 588 | pdb.set_trace() 589 | _mark_time(start19, 1.5, "Test 19 complete") 590 | 591 | print("Test 20 (RCLIP FIDELITY, 1M SAMPLES)...") 592 | start20 = time.time() 593 | dist = sq.mixture([[0.1, 0], [0.1, sq.norm(0, 3)], [0.8, sq.norm(7, 11)]], rclip=3) 594 | samps = dist @ (1 * M) 595 | if any(samps > 3): 596 | print("ERROR ON 20") 597 | import pdb 598 | 599 | pdb.set_trace() 600 | test_20_mark = _mark_time(start20, 1.5, "Test 20 complete") 601 | 602 | print("Test 21 (MULTICORE SAMPLE, 10M SAMPLES)...") 603 | start21 = time.time() 604 | dist = sq.mixture([[0.1, 0], [0.1, sq.norm(0, 3)], [0.8, sq.norm(7, 11)]], rclip=3) 605 | samps = sq.sample(dist, cores=7, n=10 * M, verbose=True) 606 | if len(samps) != (10 * M) or any(samps > 3): 607 | print("ERROR ON 21") 608 | import pdb 609 | 610 | pdb.set_trace() 611 | test_21_mark = _mark_time(start21, 3.6, "Test 21 complete") 612 | print("1 core 10M RUNS expected {}sec".format(round(test_20_mark["timing(sec)"] * 10, 1))) 613 | print("7 core 10M RUNS ideal {}sec".format(round(test_20_mark["timing(sec)"] * 10 / 7, 1))) 614 | print("7 core 10M RUNS actual {}sec".format(round(test_21_mark["timing(sec)"], 1))) 615 | 616 | print("Test 22 (MAMMOGRAPHY BAYES MULTICORE)...") 617 | sq.set_seed(42) 618 | start22 = time.time() 619 | out = bayes.bayesnet( 620 | mammography_event, 621 | find=lambda e: e["cancer"], 622 | conditional_on=lambda e: e["mammography"], 623 | 
n=10 * M, 624 | verbose=True, 625 | memcache=False, 626 | cores=7, 627 | ) 628 | expected = 0.08 629 | if round(out, 2) != expected: 630 | print("ERROR ON 22") 631 | import pdb 632 | 633 | pdb.set_trace() 634 | test_22_mark = _mark_time(start22, 84.87, "Test 22 complete") 635 | print("1 core 10M RUNS expected {}sec".format(round(test_7_mark["timing(sec)"] * K, 1))) 636 | print("7 core 10M RUNS ideal {}sec".format(round(test_7_mark["timing(sec)"] * K / 7, 1))) 637 | print("7 core 10M RUNS actual {}sec".format(round(test_22_mark["timing(sec)"], 1))) 638 | 639 | print("Test 23 (DISCRETE COMPRESSION)...") 640 | start23 = time.time() 641 | large_array = sq.mixture([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]) @ (10 * M) 642 | dist = sq.discrete(large_array) 643 | samps = sq.sample(dist, n=1 * M, verbose=True) 644 | test_23_mark = _mark_time(start23, 20.53, "Test 23 complete") 645 | 646 | print("Test 24 (DISCRETE COMPRESSION, MULTICORE)...") 647 | start24 = time.time() 648 | large_array = sq.mixture([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3], [0.4, 0.4]]) @ (10 * M) 649 | dist = sq.discrete(large_array) 650 | samps = sq.sample(dist, cores=7, n=10 * M, verbose=True) 651 | test_24_mark = _mark_time(start24, 31, "Test 24 complete") 652 | print("1 core 10M RUNS expected {}sec".format(round(test_23_mark["timing(sec)"] * 10, 1))) 653 | print("7 core 10M RUNS ideal {}sec".format(round(test_23_mark["timing(sec)"] * 10 / 7, 1))) 654 | print("7 core 10M RUNS actual {}sec".format(round(test_24_mark["timing(sec)"], 1))) 655 | 656 | print("Test 25 (VERSION)...") 657 | print("Squigglepy version is {}".format(sq.__version__)) 658 | 659 | # END 660 | _mark_time(start0, 150, "Integration tests complete") 661 | print("DONE! INTEGRATION TEST SUCCESS!") 662 | -------------------------------------------------------------------------------- /tests/strategies.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import hypothesis.strategies as st 4 | import numpy as np 5 | from hypothesis import assume, note 6 | from hypothesis.extra.numpy import arrays 7 | 8 | from .. 
import squigglepy as sq 9 | 10 | CONTINUOUS_DISTRIBUTIONS = [ 11 | sq.uniform, 12 | sq.norm, 13 | sq.lognorm, 14 | # sq.to, # Disabled to help isolate errors to either normal or lognormal 15 | sq.beta, 16 | sq.tdist, 17 | # sq.log_tdist, # TODO: Re-enable when overflows are fixed 18 | sq.triangular, 19 | sq.chisquare, 20 | sq.exponential, 21 | sq.gamma, 22 | sq.pareto, 23 | sq.pert, 24 | ] 25 | 26 | DISCRETE_DISTRIBUTIONS = [ 27 | sq.binomial, 28 | sq.bernoulli, 29 | sq.discrete, 30 | sq.poisson, 31 | ] 32 | 33 | ALL_DISTRIBUTIONS = CONTINUOUS_DISTRIBUTIONS + DISCRETE_DISTRIBUTIONS 34 | 35 | 36 | @st.composite 37 | def distributions_with_correlation(draw, min_size=2, max_size=20, continuous_only=False): 38 | dists = tuple( 39 | draw( 40 | st.lists( 41 | random_distributions(continuous_only), 42 | min_size=min_size, 43 | max_size=max_size, 44 | ) 45 | ) 46 | ) 47 | corr = draw(correlation_matrices(min_size=len(dists), max_size=len(dists))) 48 | note(f"Distributions: {dists}") 49 | note(f"Correlation matrix: {corr}") 50 | return dists, corr 51 | 52 | 53 | @st.composite 54 | def correlation_matrices(draw, min_size=2, max_size=20): 55 | # Generate a random list of correlations 56 | n_variables = draw(st.integers(min_size, max_size)) 57 | correlation_matrix = draw( 58 | arrays(np.float64, (n_variables, n_variables), elements=st.floats(-0.99, 0.99)) 59 | ) 60 | # Reflect the matrix 61 | correlation_matrix = np.tril(correlation_matrix) + np.tril(correlation_matrix, -1).T 62 | # Fill the diagonal with 1s 63 | np.fill_diagonal(correlation_matrix, 1) 64 | 65 | # Reject if not positive semi-definite 66 | assume(np.all(np.linalg.eigvals(correlation_matrix) >= 0)) 67 | 68 | return correlation_matrix 69 | 70 | 71 | @st.composite 72 | def random_distributions( 73 | draw, continuous_only: bool = False, discrete_only: bool = False 74 | ) -> sq.OperableDistribution: 75 | assert not (continuous_only and discrete_only), "Cannot be both continuous and discrete" 76 | 77 | if continuous_only: 78 | dist = instantiate_with_parameters(draw, draw(st.sampled_from(CONTINUOUS_DISTRIBUTIONS))) 79 | assert isinstance(dist, sq.ContinuousDistribution), f"{dist} is not continuous" 80 | elif discrete_only: 81 | dist = instantiate_with_parameters(draw, draw(st.sampled_from(DISCRETE_DISTRIBUTIONS))) 82 | assert isinstance(dist, sq.DiscreteDistribution), f"{dist} is not discrete" 83 | else: 84 | dist = instantiate_with_parameters(draw, draw(st.sampled_from(ALL_DISTRIBUTIONS))) 85 | assert isinstance(dist, sq.OperableDistribution), f"{dist} is not an operable distribution" 86 | 87 | return dist 88 | 89 | 90 | def instantiate_with_parameters(draw, dist_fn: Callable) -> sq.OperableDistribution: 91 | if dist_fn == sq.uniform: 92 | a = draw( 93 | st.floats(-1e30, 1e30, allow_nan=False, allow_infinity=False, allow_subnormal=False) 94 | ) 95 | b = draw( 96 | st.floats( 97 | min_value=a + 2, 98 | max_value=a + 1e30, 99 | allow_nan=False, 100 | allow_subnormal=False, 101 | exclude_min=True, 102 | ) 103 | ) 104 | return dist_fn(a, b) 105 | elif dist_fn in (sq.norm, sq.tdist): 106 | # Distributions that receive confidence intervals 107 | a = draw( 108 | st.floats(-1e30, 1e30, allow_infinity=False, allow_nan=False, allow_subnormal=False) 109 | ) 110 | b = draw( 111 | st.floats( 112 | min_value=a + 0.01, 113 | max_value=a + 1e30, 114 | allow_infinity=False, 115 | allow_nan=False, 116 | exclude_min=True, 117 | allow_subnormal=False, 118 | ) 119 | ) 120 | return dist_fn(a, b) 121 | elif dist_fn in (sq.lognorm, sq.log_tdist): 122 | # 
Distributions that receive confidence intervals starting from 0 123 | a = draw( 124 | st.floats( 125 | 0.005, 126 | 1e20, 127 | allow_infinity=False, 128 | allow_nan=False, 129 | allow_subnormal=False, 130 | exclude_min=True, 131 | ) 132 | ) 133 | b = draw( 134 | st.floats( 135 | min_value=a + 0.05, 136 | max_value=a + 1e20, 137 | allow_infinity=False, 138 | allow_nan=False, 139 | exclude_min=True, 140 | allow_subnormal=False, 141 | ) 142 | ) 143 | return dist_fn(a, b) 144 | elif dist_fn == sq.binomial: 145 | n = draw(st.integers(1, 500)) 146 | p = draw(st.floats(0.01, 0.999, exclude_min=True, exclude_max=True)) 147 | return dist_fn(n, p) 148 | elif dist_fn == sq.bernoulli: 149 | p = draw(st.floats(0.01, 0.999, exclude_min=True, exclude_max=True, allow_subnormal=False)) 150 | return dist_fn(p) 151 | elif dist_fn == sq.discrete: 152 | items = draw( 153 | st.dictionaries( 154 | st.floats(allow_infinity=False, allow_nan=False), 155 | st.floats(0, 1, exclude_min=True), 156 | min_size=1, 157 | ) 158 | ) 159 | # Normalize the probabilities 160 | normalized_items = dict() 161 | value_sum = sum(items.values()) 162 | for k, v in items.items(): 163 | normalized_items[k] = v / value_sum 164 | 165 | return dist_fn(normalized_items) 166 | elif dist_fn == sq.exponential: 167 | a = draw( 168 | st.floats( 169 | min_value=0, 170 | max_value=1e20, # Prevents overflow 171 | exclude_min=True, 172 | exclude_max=True, 173 | allow_infinity=False, 174 | allow_nan=False, 175 | allow_subnormal=False, # Prevents overflow (again) 176 | ) 177 | ) 178 | return dist_fn(a) 179 | elif dist_fn == sq.beta: 180 | a = draw( 181 | st.floats( 182 | min_value=0.01, 183 | max_value=100, 184 | exclude_min=True, 185 | allow_nan=False, 186 | allow_infinity=False, 187 | allow_subnormal=False, 188 | ) 189 | ) 190 | b = draw( 191 | st.floats( 192 | min_value=0.01, 193 | max_value=100, 194 | exclude_min=True, 195 | allow_nan=False, 196 | allow_infinity=False, 197 | allow_subnormal=False, 198 | ) 199 | ) 200 | return dist_fn(a, b) 201 | 202 | elif dist_fn == sq.triangular: 203 | a = draw( 204 | st.floats(-1e30, 1e30, allow_infinity=False, allow_nan=False, allow_subnormal=False) 205 | ) 206 | b = draw( 207 | st.floats( 208 | min_value=a + 0.05, 209 | max_value=a + 1e30, 210 | allow_infinity=False, 211 | allow_nan=False, 212 | allow_subnormal=False, 213 | exclude_min=True, 214 | ) 215 | ) 216 | c = draw( 217 | st.floats( 218 | min_value=a, 219 | max_value=b, 220 | allow_infinity=False, 221 | allow_nan=False, 222 | allow_subnormal=False, 223 | ) 224 | ) 225 | return dist_fn(a, c, b) 226 | 227 | elif dist_fn == sq.pert: 228 | low = draw( 229 | st.floats(-1e10, 1e10, allow_infinity=False, allow_nan=False, allow_subnormal=False) 230 | ) 231 | mode_offset = draw( 232 | st.floats( 233 | min_value=0.01, 234 | max_value=1e10, 235 | allow_infinity=False, 236 | allow_nan=False, 237 | ) 238 | ) 239 | high_offset = draw( 240 | st.floats( 241 | min_value=0.01, 242 | max_value=1e10, 243 | allow_infinity=False, 244 | allow_nan=False, 245 | ) 246 | ) 247 | shape = draw( 248 | st.floats( 249 | min_value=0.01, 250 | max_value=1e30, 251 | allow_infinity=False, 252 | allow_nan=False, 253 | ) 254 | ) 255 | return dist_fn(low, low + mode_offset, low + mode_offset + high_offset, shape) 256 | 257 | elif dist_fn == sq.poisson: 258 | lambda_ = draw(st.integers(1, 1000)) 259 | return dist_fn(lambda_) 260 | 261 | elif dist_fn == sq.chisquare: 262 | df = draw(st.integers(1, 1000)) 263 | return dist_fn(df) 264 | 265 | elif dist_fn == sq.gamma: 266 | shape = draw( 267 | 
st.floats( 268 | min_value=0.01, 269 | max_value=100, 270 | exclude_min=True, 271 | allow_nan=False, 272 | allow_infinity=False, 273 | allow_subnormal=False, 274 | ) 275 | ) 276 | scale = draw( 277 | st.floats( 278 | min_value=0.01, 279 | max_value=100, 280 | exclude_min=True, 281 | allow_nan=False, 282 | allow_infinity=False, 283 | allow_subnormal=False, 284 | ) 285 | ) 286 | return dist_fn(shape, scale) 287 | 288 | elif dist_fn == sq.pareto: 289 | b = draw( 290 | st.floats( 291 | min_value=1.01, 292 | max_value=10.0, 293 | exclude_min=True, 294 | allow_nan=False, 295 | allow_infinity=False, 296 | allow_subnormal=False, 297 | ) 298 | ) 299 | return dist_fn(b) 300 | else: 301 | raise NotImplementedError(f"Unknown distribution {dist_fn}") 302 | -------------------------------------------------------------------------------- /tests/test_bayes.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | from ..squigglepy.bayes import simple_bayes, bayesnet, update, average 5 | from ..squigglepy.samplers import sample 6 | from ..squigglepy.distributions import discrete, norm, beta, gamma 7 | from ..squigglepy.rng import set_seed 8 | from ..squigglepy.distributions import BetaDistribution, MixtureDistribution, NormalDistribution 9 | 10 | 11 | def test_simple_bayes(): 12 | out = simple_bayes(prior=0.01, likelihood_h=0.8, likelihood_not_h=0.096) 13 | assert round(out, 2) == 0.08 14 | 15 | 16 | def test_bayesnet(): 17 | set_seed(42) 18 | out = bayesnet( 19 | lambda: {"a": 1, "b": 2}, 20 | find=lambda e: e["a"], 21 | conditional_on=lambda e: e["b"], 22 | n=100, 23 | ) 24 | assert out == 1 25 | 26 | 27 | def test_bayesnet_noop(): 28 | out = bayesnet() 29 | assert out is None 30 | 31 | 32 | def test_bayesnet_conditional(): 33 | def define_event(): 34 | a = sample(discrete([1, 2])) 35 | b = 1 if a == 1 else 2 36 | return {"a": a, "b": b} 37 | 38 | set_seed(42) 39 | out = bayesnet(define_event, find=lambda e: e["a"] == 1, n=100) 40 | assert round(out, 1) == 0.5 41 | 42 | out = bayesnet( 43 | define_event, 44 | find=lambda e: e["a"] == 1, 45 | conditional_on=lambda e: e["b"] == 1, 46 | n=100, 47 | ) 48 | assert round(out, 1) == 1 49 | 50 | out = bayesnet( 51 | define_event, 52 | find=lambda e: e["a"] == 2, 53 | conditional_on=lambda e: e["b"] == 1, 54 | n=100, 55 | ) 56 | assert round(out, 1) == 0 57 | 58 | out = bayesnet( 59 | define_event, 60 | find=lambda e: e["a"] == 1, 61 | conditional_on=lambda e: e["b"] == 2, 62 | n=100, 63 | ) 64 | assert round(out, 1) == 0 65 | 66 | 67 | def test_bayesnet_reduce_fn(): 68 | out = bayesnet(lambda: {"a": 1, "b": 2}, find=lambda e: e["a"], reduce_fn=sum, n=100) 69 | assert out == 100 70 | 71 | 72 | def test_bayesnet_raw(): 73 | out = bayesnet(lambda: {"a": 1, "b": 2}, find=lambda e: e["a"], raw=True, n=100) 74 | assert out == [1] * 100 75 | 76 | 77 | def test_bayesnet_cache(): 78 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 79 | 80 | n_caches = len(_squigglepy_internal_bayesnet_caches) 81 | 82 | def define_event(): 83 | return {"a": 1, "b": 2} 84 | 85 | bayesnet(define_event, find=lambda e: e["a"], n=100) 86 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 87 | 88 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 89 | assert n_caches < n_caches2 90 | 91 | bayesnet(define_event, find=lambda e: e["a"], n=100) 92 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 93 | 94 | n_caches3 = len(_squigglepy_internal_bayesnet_caches) 95 | 
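# Repeating the call with the same event function and n should reuse the in-memory cache rather than adding a new entry.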
assert n_caches2 == n_caches3 96 | 97 | bayesnet(define_event, find=lambda e: e["b"], n=100) 98 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 99 | 100 | n_caches4 = len(_squigglepy_internal_bayesnet_caches) 101 | assert n_caches2 == n_caches4 102 | assert _squigglepy_internal_bayesnet_caches.get(define_event)["metadata"]["n"] == 100 103 | 104 | 105 | def test_bayesnet_cache_multiple(): 106 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 107 | 108 | n_caches = len(_squigglepy_internal_bayesnet_caches) 109 | 110 | def define_event(): 111 | return {"a": 1, "b": 2} 112 | 113 | bayesnet(define_event, find=lambda e: e["a"], n=100) 114 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 115 | 116 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 117 | assert n_caches < n_caches2 118 | 119 | bayesnet(define_event, find=lambda e: e["a"], n=100) 120 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 121 | 122 | n_caches3 = len(_squigglepy_internal_bayesnet_caches) 123 | assert n_caches2 == n_caches3 124 | 125 | def define_event2(): 126 | return {"a": 4, "b": 6} 127 | 128 | bayesnet(define_event2, find=lambda e: e["b"], n=1000) 129 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 130 | 131 | n_caches4 = len(_squigglepy_internal_bayesnet_caches) 132 | assert n_caches2 < n_caches4 133 | assert _squigglepy_internal_bayesnet_caches.get(define_event)["metadata"]["n"] == 100 134 | assert _squigglepy_internal_bayesnet_caches.get(define_event2)["metadata"]["n"] == 1000 135 | 136 | bayesnet(define_event2, find=lambda e: e["a"], n=100) 137 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 138 | 139 | n_caches5 = len(_squigglepy_internal_bayesnet_caches) 140 | assert n_caches4 == n_caches5 141 | 142 | bayesnet(define_event, find=lambda e: e["a"], n=100) 143 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 144 | 145 | n_caches6 = len(_squigglepy_internal_bayesnet_caches) 146 | assert n_caches4 == n_caches6 147 | 148 | 149 | def test_bayesnet_reload_cache(): 150 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 151 | 152 | n_caches = len(_squigglepy_internal_bayesnet_caches) 153 | 154 | def define_event(): 155 | return {"a": 1, "b": 2} 156 | 157 | bayesnet(define_event, find=lambda e: e["a"], n=100) 158 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 159 | 160 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 161 | assert n_caches < n_caches2 162 | 163 | bayesnet(define_event, find=lambda e: e["a"], n=100) 164 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 165 | 166 | n_caches3 = len(_squigglepy_internal_bayesnet_caches) 167 | assert n_caches2 == n_caches3 168 | 169 | bayesnet(define_event, find=lambda e: e["b"], n=100, reload_cache=True) 170 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 171 | 172 | n_caches4 = len(_squigglepy_internal_bayesnet_caches) 173 | assert n_caches3 == n_caches4 174 | assert _squigglepy_internal_bayesnet_caches.get(define_event)["metadata"]["n"] == 100 175 | 176 | 177 | def test_bayesnet_dont_use_cache(): 178 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 179 | 180 | n_caches = len(_squigglepy_internal_bayesnet_caches) 181 | 182 | def define_event(): 183 | return {"a": 1, "b": 2} 184 | 185 | bayesnet(define_event, find=lambda e: e["a"], memcache=False, n=100) 186 | from ..squigglepy.bayes import 
_squigglepy_internal_bayesnet_caches 187 | 188 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 189 | assert n_caches == n_caches2 190 | 191 | 192 | def test_bayesnet_cache_n_error(): 193 | def define_event(): 194 | return {"a": 1, "b": 2} 195 | 196 | bayesnet(define_event, find=lambda e: e["a"], n=100) 197 | with pytest.raises(ValueError) as excinfo: 198 | bayesnet(define_event, find=lambda e: e["a"], n=1000) 199 | assert "100 results cached but requested 1000" in str(excinfo.value) 200 | 201 | 202 | def test_bayesnet_insufficent_samples_error(): 203 | with pytest.raises(ValueError) as excinfo: 204 | bayesnet( 205 | lambda: {"a": 1, "b": 2}, 206 | find=lambda e: e["a"], 207 | conditional_on=lambda e: e["b"] == 3, 208 | n=100, 209 | ) 210 | assert "insufficient samples" in str(excinfo.value) 211 | 212 | 213 | @pytest.fixture 214 | def cachefile(): 215 | cachefile = "testcache" 216 | yield cachefile 217 | os.remove(cachefile + ".sqcache") 218 | 219 | 220 | def test_bayesnet_cachefile(cachefile): 221 | assert not os.path.exists(cachefile + ".sqcache") 222 | 223 | def define_event(): 224 | return {"a": 1, "b": 2} 225 | 226 | bayesnet(define_event, find=lambda e: e["a"], dump_cache_file=cachefile, n=100) 227 | 228 | out = bayesnet(define_event, find=lambda e: e["a"], raw=True, n=100) 229 | assert os.path.exists(cachefile + ".sqcache") 230 | assert set(out) == set([1]) 231 | 232 | out = bayesnet( 233 | define_event, 234 | find=lambda e: e["a"], 235 | conditional_on=lambda e: e["b"] == 2, 236 | raw=True, 237 | n=100, 238 | ) 239 | assert os.path.exists(cachefile + ".sqcache") 240 | assert set(out) == set([1]) 241 | 242 | def define_event(): 243 | return {"a": 2, "b": 3} 244 | 245 | out = bayesnet( 246 | define_event, 247 | find=lambda e: e["a"], 248 | load_cache_file=cachefile, 249 | memcache=False, 250 | raw=True, 251 | n=100, 252 | ) 253 | assert set(out) == set([1]) 254 | 255 | out = bayesnet(define_event, find=lambda e: e["a"], memcache=False, raw=True, n=100) 256 | assert set(out) == set([2]) 257 | 258 | 259 | def test_bayesnet_cachefile_primary(cachefile): 260 | assert not os.path.exists(cachefile + ".sqcache") 261 | 262 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 263 | 264 | n_caches = len(_squigglepy_internal_bayesnet_caches) 265 | 266 | def define_event(): 267 | return {"a": 1, "b": 2} 268 | 269 | bayesnet(define_event, find=lambda e: e["a"], n=100) 270 | 271 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 272 | 273 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 274 | assert n_caches2 == n_caches + 1 275 | assert not os.path.exists(cachefile + ".sqcache") 276 | 277 | def define_event2(): 278 | return {"a": 2, "b": 3} 279 | 280 | bayesnet( 281 | define_event2, 282 | find=lambda e: e["a"], 283 | dump_cache_file=cachefile, 284 | memcache=False, 285 | n=100, 286 | ) 287 | 288 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 289 | 290 | n_caches3 = len(_squigglepy_internal_bayesnet_caches) 291 | assert n_caches3 == n_caches2 292 | assert os.path.exists(cachefile + ".sqcache") 293 | 294 | out = bayesnet(define_event, find=lambda e: e["a"], raw=True, n=100) 295 | assert set(out) == set([1]) 296 | 297 | out = bayesnet( 298 | define_event, 299 | load_cache_file=cachefile, 300 | cache_file_primary=False, 301 | find=lambda e: e["a"], 302 | raw=True, 303 | n=100, 304 | ) 305 | assert set(out) == set([1]) 306 | 307 | out = bayesnet( 308 | define_event, 309 | load_cache_file=cachefile, 310 | cache_file_primary=True, 
311 | find=lambda e: e["a"], 312 | raw=True, 313 | n=100, 314 | ) 315 | assert set(out) == set([2]) 316 | 317 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 318 | 319 | n_caches4 = len(_squigglepy_internal_bayesnet_caches) 320 | assert n_caches4 == n_caches2 321 | assert os.path.exists(cachefile + ".sqcache") 322 | 323 | 324 | def test_bayesnet_cachefile_will_also_memcache(cachefile): 325 | assert not os.path.exists(cachefile + ".sqcache") 326 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 327 | 328 | n_caches = len(_squigglepy_internal_bayesnet_caches) 329 | 330 | def define_event(): 331 | return {"a": 1, "b": 2} 332 | 333 | out = bayesnet( 334 | define_event, 335 | find=lambda e: e["a"], 336 | dump_cache_file=cachefile, 337 | memcache=False, 338 | raw=True, 339 | n=100, 340 | ) 341 | 342 | assert os.path.exists(cachefile + ".sqcache") 343 | assert set(out) == set([1]) 344 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 345 | 346 | n_caches2 = len(_squigglepy_internal_bayesnet_caches) 347 | assert n_caches2 == n_caches 348 | 349 | out = bayesnet( 350 | define_event, 351 | find=lambda e: e["a"], 352 | dump_cache_file=cachefile, 353 | memcache=True, 354 | raw=True, 355 | n=100, 356 | ) 357 | 358 | assert os.path.exists(cachefile + ".sqcache") 359 | assert set(out) == set([1]) 360 | from ..squigglepy.bayes import _squigglepy_internal_bayesnet_caches 361 | 362 | n_caches3 = len(_squigglepy_internal_bayesnet_caches) 363 | assert n_caches3 == n_caches + 1 364 | 365 | 366 | def test_bayesnet_cachefile_insufficent_samples_error(cachefile): 367 | assert not os.path.exists(cachefile + ".sqcache") 368 | 369 | def define_event(): 370 | return {"a": 1, "b": 2} 371 | 372 | bayesnet(define_event, find=lambda e: e["a"], dump_cache_file=cachefile, n=100) 373 | assert os.path.exists(cachefile + ".sqcache") 374 | 375 | with pytest.raises(ValueError) as excinfo: 376 | bayesnet(define_event, load_cache_file=cachefile, find=lambda e: e["a"], n=1000) 377 | assert "insufficient samples" in str(excinfo.value) 378 | 379 | 380 | def test_bayesnet_multicore(): 381 | def define_event(): 382 | return {"a": 1, "b": 2} 383 | 384 | out = bayesnet(define_event, find=lambda e: e["a"], cores=2, n=100) 385 | assert out == 1 386 | assert not os.path.exists("test-core-0.sqcache") 387 | 388 | 389 | def test_update_normal(): 390 | out = update(norm(1, 10), norm(5, 15)) 391 | assert isinstance(out, NormalDistribution) 392 | assert round(out.mean, 2) == 7.51 393 | assert round(out.sd, 2) == 2.03 394 | 395 | 396 | def test_update_normal_evidence_weight(): 397 | out = update(norm(1, 10), norm(5, 15), evidence_weight=3) 398 | assert isinstance(out, NormalDistribution) 399 | assert round(out.mean, 2) == 8.69 400 | assert round(out.sd, 2) == 1.48 401 | 402 | 403 | def test_update_beta(): 404 | out = update(beta(1, 1), beta(2, 2)) 405 | assert isinstance(out, BetaDistribution) 406 | assert out.a == 3 407 | assert out.b == 3 408 | 409 | 410 | def test_update_not_implemented(): 411 | with pytest.raises(ValueError) as excinfo: 412 | update(gamma(1), gamma(2)) 413 | assert "not supported" in str(excinfo.value) 414 | 415 | 416 | def test_update_not_matching(): 417 | with pytest.raises(ValueError) as excinfo: 418 | update(norm(1, 2), beta(1, 2)) 419 | assert "can only update distributions of the same type" in str(excinfo.value) 420 | 421 | 422 | def test_average(): 423 | out = average(norm(1, 2), norm(3, 4)) 424 | assert isinstance(out, MixtureDistribution) 425 | assert 
isinstance(out.dists[0], NormalDistribution) 426 | assert out.dists[0].x == 1 427 | assert out.dists[0].y == 2 428 | assert isinstance(out.dists[1], NormalDistribution) 429 | assert out.dists[1].x == 3 430 | assert out.dists[1].y == 4 431 | assert out.weights == [0.5, 0.5] 432 | -------------------------------------------------------------------------------- /tests/test_correlation.py: -------------------------------------------------------------------------------- 1 | import scipy.stats as stats 2 | 3 | from .. import squigglepy as sq 4 | from .strategies import distributions_with_correlation 5 | from hypothesis import given, assume, note, example 6 | import hypothesis.strategies as st 7 | import numpy as np 8 | import warnings 9 | 10 | 11 | def check_correlation_from_matrix(dists, corr, atol=0.08): 12 | samples = np.column_stack([dist @ 3_000 for dist in dists]) 13 | estimated_corr = stats.spearmanr(samples).statistic 14 | note(f"Estimated correlation: {estimated_corr}") 15 | if len(dists) == 2: 16 | note(f"Desired correlation: {corr[0, 1]}") 17 | assert np.all(np.isclose(estimated_corr, corr[0, 1], atol=atol)) 18 | else: 19 | note(f"Desired correlation: {corr}") 20 | assert np.all(np.isclose(estimated_corr, corr, atol=atol)) 21 | return samples 22 | 23 | 24 | def check_correlation_from_parameter(dists_or_samples, corr, atol=0.08): 25 | assert len(dists_or_samples) == 2 26 | if isinstance(dists_or_samples[0], sq.OperableDistribution): 27 | # Sample 28 | samples = np.column_stack([dists_or_samples[0] @ 5_000, dists_or_samples[1] @ 5_000]) 29 | else: 30 | assert isinstance(dists_or_samples[0], np.ndarray) 31 | samples = np.column_stack(dists_or_samples) 32 | 33 | note(f"Desired correlation: {corr}") 34 | estimated_corr = stats.spearmanr(samples).statistic 35 | note(f"Estimated correlation: {estimated_corr}") 36 | assert np.all(np.isclose(estimated_corr, corr, atol=atol)) 37 | 38 | 39 | @given(st.floats(-0.999, 0.999)) 40 | @example(corr=0.5).via("discovered failure") 41 | def test_basic_correlates(corr): 42 | """ 43 | Test a basic example of correlation between two distributions. 44 | This ensures that the resulting distributions are correlated as expected. 45 | """ 46 | with warnings.catch_warnings(): 47 | warnings.simplefilter("error") 48 | 49 | a_params = (-1, 1) 50 | b_params = (0, 1) 51 | 52 | a, b = sq.UniformDistribution(*a_params), sq.NormalDistribution( 53 | mean=b_params[0], sd=b_params[1] 54 | ) 55 | a, b = sq.correlate((a, b), corr, tolerance=None) 56 | 57 | # Sample 58 | a_samples = a @ 3_000 59 | b_samples = b @ 3_000 60 | check_correlation_from_parameter((a, b), corr) 61 | 62 | # Check marginal distributions 63 | # a (uniform) 64 | assert np.isclose( 65 | np.mean(a_samples), np.mean(a_params), atol=0.08 66 | ), f"Mean: {np.mean(a_samples)} != {np.mean(a_params)}" 67 | expected_sd = np.sqrt((a_params[1] - a_params[0]) ** 2 / 12) 68 | assert np.isclose( 69 | np.std(a_samples), expected_sd, atol=0.08 70 | ), f"SD: {np.std(a_samples)} != {expected_sd}" 71 | # b (normal) 72 | assert np.isclose(np.mean(b_samples), b_params[0], atol=0.08), np.mean(b_samples) 73 | assert np.isclose(np.std(b_samples), b_params[1], atol=0.08), np.std(b_samples) 74 | 75 | 76 | @given(distributions_with_correlation()) 77 | def test_arbitrary_correlates(dist_corrs): 78 | """ 79 | Test multi-variable correlation with a series of arbitrary random variables, 80 | and an arbitrarily generated correlation matrix. 81 | 82 | Ensures the resulting corr. 
matrix is as expected, and that the marginal 83 | distributions remain intact. 84 | """ 85 | 86 | uncorrelated_dists, corrs = dist_corrs 87 | correlated_dists = sq.correlate( 88 | uncorrelated_dists, corrs, tolerance=None, _min_unique_samples=1_000 89 | ) 90 | try: 91 | # The tolerance is quite high here, given that we're only 92 | # interested in very significant errors. The user would be warned 93 | # if the correlation was too far off anyway (with less tolerance). 94 | check_correlation_from_matrix(correlated_dists, corrs, atol=0.1) 95 | except ValueError as e: 96 | assume("repeated values" not in str(e)) 97 | raise e 98 | 99 | # Check that marginal distributions are preserved 100 | group = correlated_dists[0].correlation_group 101 | assert group is not None 102 | uncorr_samples = np.column_stack( 103 | [sq.sample(dist, 3_000, _correlate_if_needed=False) for dist in group.correlated_dists] 104 | ) 105 | corr_samples = group.induce_correlation(uncorr_samples) 106 | 107 | assert np.isclose( 108 | np.mean(uncorr_samples), np.mean(corr_samples), rtol=0.01 109 | ), "Means are not equal, violating integrity of marginal distributions" 110 | assert np.isclose( 111 | np.std(uncorr_samples), np.std(corr_samples), rtol=0.01 112 | ), "SDs are not equal, violating integrity of marginal distributions" 113 | assert np.isclose(np.median(uncorr_samples), np.median(corr_samples), rtol=0.01) 114 | assert np.isclose(np.max(uncorr_samples), np.max(corr_samples), rtol=0.01) 115 | assert np.isclose(np.min(uncorr_samples), np.min(corr_samples), rtol=0.01) 116 | 117 | 118 | def test_correlated_resampling(): 119 | """ 120 | Tests that correlated distributions can be resampled 121 | without stale samples being used (self._correlated_samples) 122 | """ 123 | uncorrelated_dists = sq.to(2, 30), sq.uniform(-3, 6), sq.beta(50, 100) 124 | correlated_dists = sq.correlate(uncorrelated_dists, 0.8, tolerance=None) 125 | 126 | first_samples = np.column_stack([d @ 1_000 for d in correlated_dists]) 127 | second_samples = np.column_stack([d @ 1_000 for d in correlated_dists]) 128 | 129 | assert not np.allclose( 130 | first_samples, second_samples 131 | ), "Resampling correlated distributions produces the same samples" 132 | -------------------------------------------------------------------------------- /tests/test_numbers.py: -------------------------------------------------------------------------------- 1 | from ..squigglepy.numbers import K, M, B, T 2 | 3 | 4 | def test_thousand(): 5 | assert K == 1000 6 | 7 | 8 | def test_million(): 9 | assert M == 10**6 10 | 11 | 12 | def test_billion(): 13 | assert B == 10**9 14 | 15 | 16 | def test_trillion(): 17 | assert T == 10**12 18 | -------------------------------------------------------------------------------- /tests/test_rng.py: -------------------------------------------------------------------------------- 1 | from ..squigglepy.rng import set_seed 2 | from ..squigglepy.samplers import sample 3 | from ..squigglepy.distributions import norm 4 | 5 | 6 | def test_seed(): 7 | set_seed(42) 8 | test = sample(norm(1, 10000)) 9 | set_seed(42) 10 | expected = sample(norm(1, 10000)) 11 | assert test == expected 12 | -------------------------------------------------------------------------------- /tests/test_samplers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | import numpy as np 4 | from unittest.mock import patch, Mock 5 | 6 | from ..squigglepy.distributions import ( 7 | const, 8 | uniform, 9 | norm, 10 | 
lognorm, 11 | binomial, 12 | beta, 13 | bernoulli, 14 | discrete, 15 | tdist, 16 | log_tdist, 17 | triangular, 18 | pert, 19 | chisquare, 20 | poisson, 21 | exponential, 22 | gamma, 23 | pareto, 24 | mixture, 25 | zero_inflated, 26 | inf0, 27 | geometric, 28 | dist_min, 29 | dist_max, 30 | dist_round, 31 | dist_ceil, 32 | dist_floor, 33 | lclip, 34 | rclip, 35 | clip, 36 | dist_fn, 37 | ) 38 | from ..squigglepy import samplers 39 | from ..squigglepy.utils import _is_numpy 40 | from ..squigglepy.samplers import ( 41 | normal_sample, 42 | lognormal_sample, 43 | mixture_sample, 44 | discrete_sample, 45 | log_t_sample, 46 | t_sample, 47 | sample, 48 | ) 49 | from ..squigglepy.distributions import NormalDistribution 50 | 51 | 52 | class FakeRNG: 53 | def normal(self, mu, sigma, n): 54 | return round(mu, 2), round(sigma, 2) 55 | 56 | def lognormal(self, mu, sigma, n): 57 | return round(mu, 2), round(sigma, 2) 58 | 59 | def uniform(self, low, high, n): 60 | return low, high 61 | 62 | def binomial(self, n, p, nsamp): 63 | return n, p 64 | 65 | def beta(self, a, b, n): 66 | return a, b 67 | 68 | def bernoulli(self, p, n): 69 | return p 70 | 71 | def gamma(self, shape, scale, n): 72 | return shape, scale 73 | 74 | def pareto(self, shape, n): 75 | return shape 76 | 77 | def poisson(self, lam, n): 78 | return lam 79 | 80 | def exponential(self, scale, n): 81 | return scale 82 | 83 | def triangular(self, left, mode, right, n): 84 | return left, mode, right 85 | 86 | def standard_t(self, t, n): 87 | return t 88 | 89 | def chisquare(self, df, n): 90 | return df 91 | 92 | def geometric(self, p, n): 93 | return p 94 | 95 | 96 | def test_noop(): 97 | assert sample() is None 98 | 99 | 100 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 101 | def test_norm(): 102 | assert normal_sample(1, 2) == (1, 2) 103 | 104 | 105 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 106 | def test_sample_norm(): 107 | assert sample(norm(1, 2)) == (1.5, 0.3) 108 | 109 | 110 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 111 | def test_sample_norm_shorthand(): 112 | assert ~norm(1, 2) == (1.5, 0.3) 113 | 114 | 115 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 116 | def test_sample_norm_with_credibility(): 117 | assert sample(norm(1, 2, credibility=70)) == (1.5, 0.48) 118 | 119 | 120 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 121 | def test_sample_norm_with_just_sd_infers_zero_mean(): 122 | assert sample(norm(sd=2)) == (0, 2) 123 | 124 | 125 | @patch.object(samplers, "normal_sample", Mock(return_value=-100)) 126 | def test_sample_norm_passes_lclip(): 127 | assert sample(norm(1, 2)) == -100 128 | assert sample(norm(1, 2, lclip=1)) == 1 129 | 130 | 131 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 132 | def test_sample_norm_passes_rclip(): 133 | assert sample(norm(1, 2)) == 100 134 | assert sample(norm(1, 2, rclip=3)) == 3 135 | 136 | 137 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 138 | def test_sample_norm_passes_lclip_rclip(): 139 | assert sample(norm(1, 2)) == 100 140 | assert sample(norm(1, 2, lclip=1, rclip=3)) == 3 141 | assert ~norm(1, 2, lclip=1, rclip=3) == 3 142 | 143 | 144 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 145 | def test_sample_norm_competing_clip(): 146 | assert sample(norm(1, 2)) == 100 147 | assert sample(norm(1, 2, rclip=3)) == 3 148 | assert sample(norm(1, 2, rclip=3), rclip=2) == 2 149 | assert sample(norm(1, 2, rclip=2), rclip=3) == 
2 150 | 151 | 152 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 153 | def test_sample_norm_competing_clip_multiple_values(): 154 | assert all(sample(norm(1, 2), n=3) == np.array([100, 100, 100])) 155 | assert all(sample(norm(1, 2, rclip=3), n=3) == np.array([3, 3, 3])) 156 | assert all(sample(norm(1, 2, rclip=3), rclip=2, n=3) == np.array([2, 2, 2])) 157 | assert all(sample(norm(1, 2, rclip=2), rclip=3, n=3) == np.array([2, 2, 2])) 158 | 159 | 160 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 161 | def test_lognorm(): 162 | assert lognormal_sample(1, 2) == (1, 2) 163 | 164 | 165 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 166 | def test_sample_lognorm(): 167 | assert sample(lognorm(1, 2)) == (0.35, 0.21) 168 | 169 | 170 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 171 | def test_sample_lognorm_with_credibility(): 172 | assert sample(lognorm(1, 2, credibility=70)) == (0.35, 0.33) 173 | 174 | 175 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 176 | def test_sample_shorthand_lognorm(): 177 | assert ~lognorm(1, 2) == (0.35, 0.21) 178 | 179 | 180 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 181 | def test_sample_shorthand_lognorm_with_credibility(): 182 | assert ~lognorm(1, 2, credibility=70) == (0.35, 0.33) 183 | 184 | 185 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 186 | def test_sample_lognorm_with_just_normsd_infers_zero_mean(): 187 | assert sample(lognorm(norm_sd=2)) == (0, 2) 188 | 189 | 190 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 191 | def test_sample_lognorm_with_just_lognormsd_infers_unit_mean(): 192 | assert sample(lognorm(lognorm_sd=2)) == (-0.8, 1.27) 193 | 194 | 195 | @patch.object(samplers, "lognormal_sample", Mock(return_value=100)) 196 | def test_sample_lognorm_passes_lclip_rclip(): 197 | assert sample(lognorm(1, 2)) == 100 198 | assert sample(lognorm(1, 2, lclip=1, rclip=3)) == 3 199 | 200 | 201 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 202 | def test_sample_uniform(): 203 | assert sample(uniform(1, 2)) == (1, 2) 204 | 205 | 206 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 207 | def test_sample_binomial(): 208 | assert sample(binomial(10, 0.1)) == (10, 0.1) 209 | 210 | 211 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 212 | def test_sample_beta(): 213 | assert sample(beta(10, 1)) == (10, 1) 214 | 215 | 216 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 217 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 218 | def test_sample_bernoulli(): 219 | assert sample(bernoulli(0.1)) == 1 220 | 221 | 222 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 223 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 224 | def test_tdist(): 225 | assert round(t_sample(1, 2, 3), 2) == 1 226 | 227 | 228 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 229 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 230 | def test_tdist_t(): 231 | assert round(t_sample(), 2) == 20 232 | 233 | 234 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 235 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 236 | def test_tdist_with_credibility(): 237 | assert round(t_sample(1, 2, 3, credibility=70), 2) == 1 238 | 239 | 240 | def test_tdist_low_gt_high(): 241 | with pytest.raises(ValueError) as execinfo: 242 | t_sample(10, 5, 3) 243 
| assert "`high value` cannot be lower than `low value`" in str(execinfo.value) 244 | 245 | 246 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 247 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 248 | def test_sample_tdist(): 249 | assert round(sample(tdist(1, 2, 3)), 2) == 1 250 | 251 | 252 | @patch.object(samplers, "t_sample", Mock(return_value=100)) 253 | def test_sample_tdist_passes_lclip_rclip(): 254 | assert sample(tdist(1, 2, 3)) == 100 255 | assert sample(tdist(1, 2, 3, lclip=1, rclip=3)) == 3 256 | 257 | 258 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 259 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 260 | def test_log_tdist(): 261 | assert round(log_t_sample(1, 2, 3), 2) == 2.72 262 | 263 | 264 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 265 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 266 | def test_log_tdist_with_credibility(): 267 | assert round(log_t_sample(1, 2, 3, credibility=70), 2) == 2.72 268 | 269 | 270 | def test_log_tdist_low_gt_high(): 271 | with pytest.raises(ValueError) as execinfo: 272 | log_t_sample(10, 5, 3) 273 | assert "`high value` cannot be lower than `low value`" in str(execinfo.value) 274 | 275 | 276 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 277 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 278 | def teslog_t_sample_log_tdist(): 279 | assert round(sample(log_tdist(1, 2, 3)), 2) == 1 / 3 280 | 281 | 282 | @patch.object(samplers, "log_t_sample", Mock(return_value=100)) 283 | def teslog_t_sample_log_tdist_passes_lclip_rclip(): 284 | assert sample(log_tdist(1, 2, 3)) == 100 285 | assert sample(log_tdist(1, 2, 3, lclip=1, rclip=3)) == 3 286 | 287 | 288 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 289 | def test_sample_triangular(): 290 | assert sample(triangular(10, 20, 30)) == (10, 20, 30) 291 | 292 | 293 | @patch.object(samplers, "pert_sample", Mock(return_value=100)) 294 | def test_sample_pert(): 295 | assert sample(pert(10, 20, 30, 40)) == 100 296 | 297 | 298 | @patch.object(samplers, "pert_sample", Mock(return_value=100)) 299 | def test_sample_pert_passes_lclip_rclip(): 300 | assert sample(pert(1, 2, 3, 4)) == 100 301 | assert sample(pert(1, 2, 3, 4, lclip=1, rclip=3)) == 3 302 | 303 | 304 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 305 | def test_sample_exponential(): 306 | assert sample(exponential(10)) == 10 307 | 308 | 309 | @patch.object(samplers, "exponential_sample", Mock(return_value=100)) 310 | def test_sample_exponential_passes_lclip_rclip(): 311 | assert sample(exponential(1)) == 100 312 | assert sample(exponential(1, lclip=1, rclip=3)) == 3 313 | 314 | 315 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 316 | def test_sample_chisquare(): 317 | assert sample(chisquare(9)) == 9 318 | 319 | 320 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 321 | def test_sample_poisson(): 322 | assert sample(poisson(10)) == 10 323 | 324 | 325 | @patch.object(samplers, "poisson_sample", Mock(return_value=100)) 326 | def test_sample_poisson_passes_lclip_rclip(): 327 | assert sample(poisson(1)) == 100 328 | assert sample(poisson(1, lclip=1, rclip=3)) == 3 329 | 330 | 331 | def test_sample_const(): 332 | assert sample(const(11)) == 11 333 | 334 | 335 | def test_sample_const_shorthand(): 336 | assert ~const(11) == 11 337 | 338 | 339 | def test_nested_const_does_not_resolve(): 340 | assert isinstance((~const(norm(1, 
2))), NormalDistribution) 341 | assert (~const(norm(1, 2))).x == 1 342 | assert (~const(norm(1, 2))).y == 2 343 | 344 | 345 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 346 | def test_nested_const_double_resolve(): 347 | assert ~~const(norm(1, 2)) == (1.5, 0.3) 348 | 349 | 350 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 351 | def test_sample_gamma_default(): 352 | assert sample(gamma(10)) == (10, 1) 353 | 354 | 355 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 356 | def test_sample_gamma(): 357 | assert sample(gamma(10, 2)) == (10, 2) 358 | 359 | 360 | @patch.object(samplers, "gamma_sample", Mock(return_value=100)) 361 | def test_sample_gamma_passes_lclip_rclip(): 362 | assert sample(gamma(1, 2)) == 100 363 | assert sample(gamma(1, 2, lclip=1, rclip=3)) == 3 364 | 365 | 366 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 367 | def test_sample_pareto_default(): 368 | assert sample(pareto(10)) == 11 369 | 370 | 371 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 372 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 373 | def test_discrete(): 374 | assert discrete_sample([0, 1, 2]) == 0 375 | 376 | 377 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 378 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 379 | def test_discrete_alt_format(): 380 | assert discrete_sample([[0.9, "a"], [0.1, "b"]]) == "a" 381 | 382 | 383 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 384 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 385 | def test_discrete_alt2_format(): 386 | assert discrete_sample({"a": 0.9, "b": 0.1}) == "a" 387 | 388 | 389 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 390 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 391 | def test_sample_discrete(): 392 | assert sample(discrete([0, 1, 2])) == 0 393 | 394 | 395 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 396 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 397 | def test_sample_discrete_alt_format(): 398 | assert sample(discrete([[0.9, "a"], [0.1, "b"]])) == "a" 399 | 400 | 401 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 402 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 403 | def test_sample_discrete_alt2_format(): 404 | assert sample(discrete({"a": 0.9, "b": 0.1})) == "a" 405 | 406 | 407 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 408 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 409 | def test_sample_discrete_shorthand(): 410 | assert ~discrete([0, 1, 2]) == 0 411 | assert ~discrete([[0.9, "a"], [0.1, "b"]]) == "a" 412 | assert ~discrete({"a": 0.9, "b": 0.1}) == "a" 413 | 414 | 415 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 416 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 417 | def test_sample_discrete_cannot_mixture(): 418 | obj = ~discrete([norm(1, 2), norm(3, 4)]) 419 | # Instead of sampling `norm(1, 2)`, discrete just returns it unsampled. 420 | assert isinstance(obj, NormalDistribution) 421 | assert obj.x == 1 422 | assert obj.y == 2 423 | 424 | 425 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 426 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 427 | def test_sample_discrete_indirect_mixture(): 428 | # You would have to double resolve this to get a value. 
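# (Illustrative note: the first `~` just picks an entry from the discrete options — here the unsampled norm(1, 2), as shown in the test above — and the second `~` then samples that distribution.)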
429 | assert ~~discrete([norm(1, 2), norm(3, 4)]) == (1.5, 0.3) 430 | 431 | 432 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 433 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 434 | def test_mixture_sample(): 435 | assert mixture_sample([norm(1, 2), norm(3, 4)], [0.2, 0.8]) == (1.5, 0.3) 436 | 437 | 438 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 439 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 440 | def test_mixture_sample_relative_weights(): 441 | assert mixture_sample([norm(1, 2), norm(3, 4)], relative_weights=[1, 1]) == (1.5, 0.3) 442 | 443 | 444 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 445 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 446 | def test_mixture_sample_alt_format(): 447 | assert mixture_sample([[0.2, norm(1, 2)], [0.8, norm(3, 4)]]) == (1.5, 0.3) 448 | 449 | 450 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 451 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 452 | def test_mixture_sample_rclip_lclip(): 453 | assert mixture_sample([norm(1, 2), norm(3, 4)], [0.2, 0.8]) == 100 454 | assert mixture_sample([norm(1, 2, rclip=3), norm(3, 4)], [0.2, 0.8]) == 3 455 | 456 | 457 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 458 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 459 | def test_mixture_sample_no_weights(): 460 | assert mixture_sample([norm(1, 2), norm(3, 4)]) == (1.5, 0.3) 461 | 462 | 463 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 464 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 465 | def test_mixture_sample_different_distributions(): 466 | assert mixture_sample([lognorm(1, 2), norm(3, 4)]) == (0.35, 0.21) 467 | 468 | 469 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 470 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 471 | def test_mixture_sample_with_numbers(): 472 | assert mixture_sample([2, norm(3, 4)]) == 2 473 | 474 | 475 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 476 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 477 | def test_sample_mixture(): 478 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8])) == (1.5, 0.3) 479 | 480 | 481 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 482 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 483 | def test_sample_mixture_alt_format(): 484 | assert sample(mixture([[0.2, norm(1, 2)], [0.8, norm(3, 4)]])) == (1.5, 0.3) 485 | 486 | 487 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 488 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 489 | def test_sample_mixture_lclip(): 490 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8])) == 1 491 | assert sample(mixture([norm(1, 2, lclip=3), norm(3, 4)], [0.2, 0.8])) == 3 492 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8], lclip=3)) == 3 493 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8]), lclip=3) == 3 494 | 495 | 496 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 497 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 498 | def test_sample_mixture_lclip_multiple_values(): 499 | assert all(sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8]), n=3) == np.array([1, 1, 1])) 500 | assert all( 501 | sample(mixture([norm(1, 2, lclip=3), norm(3, 4)], [0.2, 0.8]), n=3) == np.array([3, 3, 3]) 502 | ) 503 | assert all( 504 | sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8], lclip=3), n=3) == np.array([3, 3, 3]) 505 | ) 506 | assert all( 507 | sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8]), lclip=3, n=3) == np.array([3, 3, 3]) 508 | ) 509 | 510 | 511 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 512 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 513 | def test_sample_mixture_lclip_alt_format(): 514 | assert sample(mixture([[0.2, norm(1, 2, lclip=3)], [0.8, norm(3, 4)]])) == 3 515 | 516 | 517 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 518 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 519 | def test_mixture_sample_lclip_alt_format_multiple_values(): 520 | assert all( 521 | mixture_sample([[0.2, norm(1, 2, lclip=3)], [0.8, norm(3, 4)]], samples=3) 522 | == np.array([3, 3, 3]) 523 | ) 524 | 525 | 526 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 527 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 528 | def test_sample_mixture_lclip_alt_format_multiple_values(): 529 | assert all( 530 | sample(mixture([[0.2, norm(1, 2, lclip=3)], [0.8, norm(3, 4)]]), n=3) 531 | == np.array([3, 3, 3]) 532 | ) 533 | 534 | 535 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 536 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 537 | def test_sample_mixture_lclip_alt_format2(): 538 | assert ~mixture([[0.2, norm(1, 2, lclip=3)], [0.8, norm(3, 4)]]) == 3 539 | 540 | 541 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 542 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 543 | def test_sample_mixture_lclip_alt_format2_multiple_values(): 544 | assert all(mixture([[0.2, norm(1, 2, lclip=3)], [0.8, norm(3, 4)]]) @ 3 == np.array([3, 3, 3])) 545 | 546 | 547 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 548 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 549 | def test_sample_mixture_rclip(): 550 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8])) == 100 551 | assert sample(mixture([norm(1, 2, rclip=3), norm(3, 4)], [0.2, 0.8])) == 3 552 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8], rclip=3)) == 3 553 | assert sample(mixture([norm(1, 2), norm(3, 4)], [0.2, 0.8]), rclip=3) == 3 554 | 555 | 556 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 557 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 558 | def test_sample_mixture_competing_clip(): 559 | assert sample(mixture([norm(1, 2, rclip=3), norm(3, 4)], [0.2, 0.8])) == 3 560 | assert sample(mixture([norm(1, 2, rclip=2), norm(3, 4)], [0.2, 0.8], rclip=3)) == 2 561 | assert sample(mixture([norm(1, 2, rclip=3), norm(3, 4)], [0.2, 0.8]), rclip=2) == 2 562 | 563 | 564 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 565 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 566 | def test_sample_mixture_no_weights(): 567 | assert sample(mixture([norm(1, 2), norm(3, 4)])) == (1.5, 0.3) 568 | 569 | 570 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 571 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 572 | def test_sample_mixture_different_distributions(): 573 | assert sample(mixture([lognorm(1, 2), norm(3, 4)])) == (0.35, 0.21) 574 | 575 | 576 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 577 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 578 | def test_sample_mixture_with_numbers(): 579 | assert 
sample(mixture([2, norm(3, 4)])) == 2 580 | 581 | 582 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 583 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 584 | def test_sample_mixture_can_be_discrete(): 585 | assert ~mixture([0, 1, 2]) == 0 586 | assert ~mixture([[0.9, "a"], [0.1, "b"]]) == "a" 587 | assert ~mixture({"a": 0.9, "b": 0.1}) == "a" 588 | assert ~mixture([norm(1, 2), norm(3, 4)]) == (1.5, 0.3) 589 | 590 | 591 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 592 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 593 | def test_sample_mixture_contains_discrete(): 594 | assert sample(mixture([lognorm(1, 2), discrete([3, 4])])) == (0.35, 0.21) 595 | 596 | 597 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 598 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 599 | def test_sample_mixture_contains_mixture(): 600 | assert sample(mixture([lognorm(1, 2), mixture([1, discrete([3, 4])])])) == ( 601 | 0.35, 602 | 0.21, 603 | ) 604 | 605 | 606 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 607 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 608 | def test_sample_zero_inflated(): 609 | assert ~zero_inflated(0.6, norm(1, 2)) == 0 610 | 611 | 612 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 613 | @patch.object(samplers, "uniform_sample", Mock(return_value=0)) 614 | def test_sample_inf0(): 615 | assert ~inf0(0.6, norm(1, 2)) == 0 616 | 617 | 618 | @patch.object(samplers, "_get_rng", Mock(return_value=FakeRNG())) 619 | def test_sample_geometric(): 620 | assert sample(geometric(0.1)) == 0.1 621 | 622 | 623 | def test_sample_n_gt_1_norm(): 624 | out = sample(norm(1, 2), n=5) 625 | assert _is_numpy(out) 626 | assert len(out) == 5 627 | 628 | 629 | def test_sample_n_gt_1_lognorm(): 630 | out = sample(lognorm(1, 2), n=5) 631 | assert _is_numpy(out) 632 | assert len(out) == 5 633 | 634 | 635 | def test_sample_n_gt_1_binomial(): 636 | out = sample(binomial(5, 0.1), n=5) 637 | assert _is_numpy(out) 638 | assert len(out) == 5 639 | 640 | 641 | def test_sample_n_gt_1_beta(): 642 | out = sample(beta(5, 10), n=5) 643 | assert _is_numpy(out) 644 | assert len(out) == 5 645 | 646 | 647 | def test_sample_n_gt_1_bernoulli(): 648 | out = sample(bernoulli(0.1), n=5) 649 | assert _is_numpy(out) 650 | assert len(out) == 5 651 | 652 | 653 | def test_sample_n_gt_1_poisson(): 654 | out = sample(poisson(0.1), n=5) 655 | assert _is_numpy(out) 656 | assert len(out) == 5 657 | 658 | 659 | def test_sample_n_gt_1_chisquare(): 660 | out = sample(chisquare(10), n=5) 661 | assert _is_numpy(out) 662 | assert len(out) == 5 663 | 664 | 665 | def test_sample_n_gt_1_gamma(): 666 | out = sample(gamma(10, 10), n=5) 667 | assert _is_numpy(out) 668 | assert len(out) == 5 669 | 670 | 671 | def test_sample_n_gt_1_triangular(): 672 | out = sample(triangular(1, 2, 3), n=5) 673 | assert _is_numpy(out) 674 | assert len(out) == 5 675 | 676 | 677 | def test_sample_n_gt_1_tdist(): 678 | out = sample(tdist(1, 2, 3), n=5) 679 | assert _is_numpy(out) 680 | assert len(out) == 5 681 | 682 | 683 | def test_sample_n_gt_1_tdist_t(): 684 | out = sample(tdist(), n=5) 685 | assert _is_numpy(out) 686 | assert len(out) == 5 687 | 688 | 689 | def test_sample_n_gt_1_log_tdist(): 690 | out = sample(log_tdist(1, 2, 3), n=5) 691 | assert _is_numpy(out) 692 | assert len(out) == 5 693 | 694 | 695 | def test_sample_n_gt_1_const(): 696 | out = sample(const(1), n=5) 697 | assert _is_numpy(out) 698 | 
assert len(out) == 5 699 | 700 | 701 | def test_sample_n_gt_1_uniform(): 702 | out = sample(uniform(0, 1), n=5) 703 | assert _is_numpy(out) 704 | assert len(out) == 5 705 | 706 | 707 | def test_sample_n_gt_1_discrete(): 708 | out = sample(discrete([1, 2, 3]), n=5) 709 | assert _is_numpy(out) 710 | assert len(out) == 5 711 | 712 | 713 | def test_sample_n_gt_1_mixture(): 714 | out = sample(mixture([norm(1, 2), norm(3, 4)]), n=5) 715 | assert _is_numpy(out) 716 | assert len(out) == 5 717 | 718 | 719 | def test_sample_n_gt_1_geometric(): 720 | out = sample(geometric(0.1), n=5) 721 | assert _is_numpy(out) 722 | assert len(out) == 5 723 | 724 | 725 | def test_sample_n_gt_1_raw_float(): 726 | out = sample(0.1, n=5) 727 | assert _is_numpy(out) 728 | assert len(out) == 5 729 | 730 | 731 | def test_sample_n_gt_1_raw_int(): 732 | out = sample(1, n=5) 733 | assert _is_numpy(out) 734 | assert len(out) == 5 735 | 736 | 737 | def test_sample_n_gt_1_raw_str(): 738 | out = sample("a", n=5) 739 | assert _is_numpy(out) 740 | assert len(out) == 5 741 | 742 | 743 | def test_sample_n_gt_1_complex(): 744 | out = sample(uniform(0, 1) + 5 >> dist_ceil, n=5) 745 | assert _is_numpy(out) 746 | assert len(out) == 5 747 | 748 | 749 | def test_sample_n_gt_1_callable(): 750 | def _fn(): 751 | return norm(1, 2) + norm(3, 4) 752 | 753 | out = sample(_fn, n=5) 754 | assert _is_numpy(out) 755 | assert len(out) == 5 756 | 757 | 758 | def test_sample_shorthand_n_gt_1(): 759 | out = norm(1, 2) @ 5 760 | assert _is_numpy(out) 761 | assert len(out) == 5 762 | 763 | 764 | def test_sample_shorthand_n_is_var(): 765 | n = 2 + 3 766 | out = norm(1, 2) @ n 767 | assert _is_numpy(out) 768 | assert len(out) == n 769 | 770 | 771 | def test_sample_shorthand_n_is_float(): 772 | out = norm(1, 2) @ 7.0 773 | assert _is_numpy(out) 774 | assert len(out) == 7.0 775 | 776 | 777 | def test_sample_shorthand_n_is_numpy_int(): 778 | out = norm(1, 2) @ np.int64(4) 779 | assert _is_numpy(out) 780 | assert len(out) == 4 781 | out = norm(1, 2) @ np.int32(4) 782 | assert _is_numpy(out) 783 | assert len(out) == 4 784 | 785 | 786 | def test_sample_shorthand_n_gt_1_alt(): 787 | out = 5 @ norm(1, 2) 788 | assert _is_numpy(out) 789 | assert len(out) == 5 790 | 791 | 792 | def test_sample_n_is_0_is_error(): 793 | with pytest.raises(ValueError) as execinfo: 794 | sample(norm(1, 5), n=0) 795 | assert "n must be >= 1" in str(execinfo.value) 796 | 797 | 798 | def test_sample_n_is_0_is_error_shorthand(): 799 | with pytest.raises(ValueError) as execinfo: 800 | norm(1, 5) @ 0 801 | assert "n must be >= 1" in str(execinfo.value) 802 | 803 | 804 | def test_sample_n_is_0_is_error_shorthand_alt(): 805 | with pytest.raises(ValueError) as execinfo: 806 | 0 @ norm(1, 5) 807 | assert "n must be >= 1" in str(execinfo.value) 808 | 809 | 810 | def test_sample_int(): 811 | assert sample(4) == 4 812 | 813 | 814 | def test_sample_float(): 815 | assert sample(3.14) == 3.14 816 | 817 | 818 | def test_sample_str(): 819 | assert sample("a") == "a" 820 | 821 | 822 | def test_sample_none(): 823 | assert sample(None) is None 824 | 825 | 826 | def test_sample_callable(): 827 | def sample_fn(): 828 | return 1 829 | 830 | assert sample(sample_fn) == 1 831 | 832 | 833 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 834 | @patch.object(samplers, "lognormal_sample", Mock(return_value=4)) 835 | def test_sample_more_complex_callable(): 836 | def sample_fn(): 837 | return max(~norm(1, 4), ~lognorm(1, 10)) 838 | 839 | assert sample(sample_fn) == 4 840 | 841 | 842 | 
@patch.object(samplers, "normal_sample", Mock(return_value=1)) 843 | @patch.object(samplers, "lognormal_sample", Mock(return_value=4)) 844 | def test_sample_callable_resolves_fully(): 845 | def sample_fn(): 846 | return norm(1, 4) + lognorm(1, 10) 847 | 848 | assert sample(sample_fn) == 5 849 | 850 | 851 | @patch.object(samplers, "normal_sample", Mock(return_value=1)) 852 | @patch.object(samplers, "lognormal_sample", Mock(return_value=4)) 853 | def test_sample_callable_resolves_fully2(): 854 | def really_inner_sample_fn(): 855 | return 1 856 | 857 | def inner_sample_fn(): 858 | return norm(1, 4) + lognorm(1, 10) + really_inner_sample_fn() 859 | 860 | def outer_sample_fn(): 861 | return inner_sample_fn 862 | 863 | assert sample(outer_sample_fn) == 6 864 | 865 | 866 | def test_sample_invalid_input(): 867 | with pytest.raises(ValueError) as execinfo: 868 | sample([1, 5]) 869 | assert "not a sampleable type" in str(execinfo.value) 870 | 871 | 872 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 873 | def test_sample_math(): 874 | assert ~(norm(0, 1) + norm(1, 2)) == 200 875 | 876 | 877 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 878 | @patch.object(samplers, "lognormal_sample", Mock(return_value=100)) 879 | def test_sample_complex_math(): 880 | obj = (2 ** norm(0, 1)) - (8 * 6) + 2 + (lognorm(10, 100) / 11) + 8 881 | expected = (2**10) - (8 * 6) + 2 + (100 / 11) + 8 882 | assert ~obj == expected 883 | 884 | 885 | @patch.object(samplers, "normal_sample", Mock(return_value=100)) 886 | def test_sample_equality(): 887 | assert ~(norm(0, 1) == norm(1, 2)) 888 | 889 | 890 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 891 | def test_pipe(): 892 | assert ~(norm(0, 1) >> rclip(2)) == 2 893 | assert ~(norm(0, 1) >> lclip(2)) == 10 894 | 895 | 896 | @patch.object(samplers, "normal_sample", Mock(return_value=1.6)) 897 | def test_two_pipes(): 898 | assert ~(norm(0, 1) >> rclip(10) >> dist_round) == 2 899 | 900 | 901 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 902 | def test_dist_fn(): 903 | def mirror(x): 904 | return 1 - x if x > 0.5 else x 905 | 906 | assert ~dist_fn(norm(0, 1), mirror) == -9 907 | 908 | 909 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 910 | def test_dist_fn2(): 911 | def mirror(x, y): 912 | return 1 - x if x > y else x 913 | 914 | assert ~dist_fn(norm(0, 10), norm(1, 2), mirror) == 10 915 | 916 | 917 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 918 | def test_dist_fn_list(): 919 | def mirror(x): 920 | return 1 - x if x > 0.5 else x 921 | 922 | def mirror2(x): 923 | return 1 + x if x > 0.5 else x 924 | 925 | assert ~dist_fn(norm(0, 1), [mirror, mirror2]) == -9 926 | 927 | 928 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 929 | @patch.object(samplers, "lognormal_sample", Mock(return_value=20)) 930 | def test_max(): 931 | assert ~dist_max(norm(0, 1), lognorm(0.1, 1)) == 20 932 | 933 | 934 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 935 | @patch.object(samplers, "lognormal_sample", Mock(return_value=20)) 936 | def test_min(): 937 | assert ~dist_min(norm(0, 1), lognorm(0.1, 1)) == 10 938 | 939 | 940 | @patch.object(samplers, "normal_sample", Mock(return_value=3.1415)) 941 | def test_round(): 942 | assert ~dist_round(norm(0, 1)) == 3 943 | 944 | 945 | @patch.object(samplers, "normal_sample", Mock(return_value=3.1415)) 946 | def test_round_two_digits(): 947 | assert ~dist_round(norm(0, 1), digits=2) == 3.14 948 | 949 | 950 | 
@patch.object(samplers, "normal_sample", Mock(return_value=3.1415)) 951 | def test_ceil(): 952 | assert ~dist_ceil(norm(0, 1)) == 4 953 | 954 | 955 | @patch.object(samplers, "normal_sample", Mock(return_value=3.1415)) 956 | def test_floor(): 957 | assert ~dist_floor(norm(0, 1)) == 3 958 | 959 | 960 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 961 | def test_lclip(): 962 | assert ~lclip(norm(0, 1), 0.5) == 10 963 | 964 | 965 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 966 | def test_rclip(): 967 | assert ~rclip(norm(0, 1), 0.5) == 0.5 968 | 969 | 970 | @patch.object(samplers, "normal_sample", Mock(return_value=10)) 971 | def test_clip(): 972 | assert ~clip(norm(0, 1), 0.5, 0.9) == 0.9 973 | 974 | 975 | def test_sample_cache(): 976 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 977 | 978 | n_caches = len(_squigglepy_internal_sample_caches) 979 | 980 | sample(norm(1, 2), memcache=True) 981 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 982 | 983 | n_caches2 = len(_squigglepy_internal_sample_caches) 984 | assert n_caches < n_caches2 985 | 986 | sample(norm(1, 2), memcache=True) 987 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 988 | 989 | n_caches3 = len(_squigglepy_internal_sample_caches) 990 | assert n_caches2 == n_caches3 991 | 992 | sample(norm(1, 2)) 993 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 994 | 995 | n_caches4 = len(_squigglepy_internal_sample_caches) 996 | assert n_caches2 == n_caches4 997 | 998 | sample(norm(3, 4)) 999 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1000 | 1001 | n_caches5 = len(_squigglepy_internal_sample_caches) 1002 | assert n_caches2 == n_caches5 1003 | 1004 | sample(norm(3, 4), memcache=True) 1005 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1006 | 1007 | n_caches6 = len(_squigglepy_internal_sample_caches) 1008 | assert n_caches6 > n_caches5 1009 | 1010 | 1011 | def test_sample_reload_cache(): 1012 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1013 | 1014 | n_caches = len(_squigglepy_internal_sample_caches) 1015 | 1016 | out1 = sample(norm(5, 6), memcache=True) 1017 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1018 | 1019 | n_caches2 = len(_squigglepy_internal_sample_caches) 1020 | assert n_caches < n_caches2 1021 | 1022 | out2 = sample(norm(5, 6), memcache=True) 1023 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1024 | 1025 | n_caches3 = len(_squigglepy_internal_sample_caches) 1026 | assert n_caches2 == n_caches3 1027 | assert out1 == out2 1028 | 1029 | out3 = sample(norm(5, 6), memcache=True, reload_cache=True) 1030 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1031 | 1032 | n_caches4 = len(_squigglepy_internal_sample_caches) 1033 | assert n_caches3 == n_caches4 1034 | assert out3 != out2 1035 | 1036 | 1037 | @pytest.fixture 1038 | def cachefile(): 1039 | cachefile = "testcache" 1040 | yield cachefile 1041 | try: 1042 | os.remove(cachefile + ".sqcache.npy") 1043 | except FileNotFoundError: 1044 | pass 1045 | 1046 | 1047 | def test_sample_cachefile(cachefile): 1048 | assert not os.path.exists(cachefile + ".sqcache.npy") 1049 | sample(norm(1, 2), dump_cache_file=cachefile) 1050 | assert os.path.exists(cachefile + ".sqcache.npy") 1051 | 1052 | 1053 | def test_sample_cachefile_primary(cachefile): 1054 | assert not os.path.exists(cachefile + ".sqcache.npy") 1055 | 1056 | 
from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1057 | 1058 | n_caches = len(_squigglepy_internal_sample_caches) 1059 | 1060 | sample(norm(10, 20), dump_cache_file=cachefile, memcache=True) 1061 | 1062 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1063 | 1064 | n_caches2 = len(_squigglepy_internal_sample_caches) 1065 | assert n_caches2 == n_caches + 1 1066 | assert os.path.exists(cachefile + ".sqcache.npy") 1067 | 1068 | o1 = sample(norm(10, 20), load_cache_file=cachefile, memcache=True, cache_file_primary=True) 1069 | o2 = sample(norm(10, 20), load_cache_file=cachefile, memcache=True, cache_file_primary=False) 1070 | assert o1 == o2 1071 | 1072 | 1073 | def test_sample_load_noop_cachefile(cachefile): 1074 | assert not os.path.exists(cachefile + ".sqcache.npy") 1075 | 1076 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1077 | 1078 | n_caches = len(_squigglepy_internal_sample_caches) 1079 | 1080 | o1 = sample(norm(100, 200), dump_cache_file=cachefile, memcache=True) 1081 | 1082 | from ..squigglepy.samplers import _squigglepy_internal_sample_caches 1083 | 1084 | n_caches2 = len(_squigglepy_internal_sample_caches) 1085 | assert n_caches2 == n_caches + 1 1086 | assert os.path.exists(cachefile + ".sqcache.npy") 1087 | 1088 | o2 = sample(load_cache_file=cachefile) 1089 | assert o1 == o2 1090 | 1091 | 1092 | def test_sample_load_noop_nonexisting_cachefile(cachefile): 1093 | assert not os.path.exists(cachefile + ".sqcache.npy") 1094 | assert sample(load_cache_file=cachefile) is None 1095 | 1096 | 1097 | def test_sample_multicore(): 1098 | sample(norm(100, 200), n=100, cores=2) 1099 | assert not os.path.exists("test-core-0.sqcache") 1100 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | 4 | from datetime import datetime, timedelta 5 | from ..squigglepy.utils import ( 6 | _core_cuts, 7 | _process_weights_values, 8 | _process_discrete_weights_values, 9 | event_occurs, 10 | event_happens, 11 | event, 12 | _weighted_percentile, 13 | get_percentiles, 14 | get_log_percentiles, 15 | get_mean_and_ci, 16 | get_median_and_ci, 17 | geomean, 18 | p_to_odds, 19 | odds_to_p, 20 | geomean_odds, 21 | laplace, 22 | growth_rate_to_doubling_time, 23 | doubling_time_to_growth_rate, 24 | roll_die, 25 | flip_coin, 26 | kelly, 27 | full_kelly, 28 | half_kelly, 29 | third_kelly, 30 | quarter_kelly, 31 | one_in, 32 | extremize, 33 | sharpe_ratio, 34 | normalize, 35 | bucket_percentages, 36 | ) 37 | from ..squigglepy.rng import set_seed 38 | from ..squigglepy.distributions import bernoulli, beta, norm, dist_round, const 39 | 40 | 41 | def test_process_weights_values_simple_case(): 42 | test = _process_weights_values(weights=[0.1, 0.9], values=[2, 3]) 43 | expected = ([0.1, 0.9], [2, 3]) 44 | assert test == expected 45 | 46 | 47 | def test_process_weights_values_numpy_arrays(): 48 | test = _process_weights_values(weights=np.array([0.1, 0.9]), values=np.array([2, 3])) 49 | expected = ([0.1, 0.9], [2, 3]) 50 | assert test == expected 51 | 52 | 53 | def test_process_weights_values_length_one(): 54 | test = _process_weights_values(weights=[1], values=[2]) 55 | expected = ([1], [2]) 56 | assert test == expected 57 | 58 | 59 | def test_process_weights_values_alt_format(): 60 | test = _process_weights_values(values=[[0.1, 2], [0.2, 3], [0.3, 4], [0.4, 5]]) 61 | expected = ([0.1, 0.2, 
0.3, 0.4], [2, 3, 4, 5]) 62 | assert test == expected 63 | 64 | 65 | def test_process_weights_values_alt2_format(): 66 | test = _process_weights_values(values={2: 0.1, 3: 0.2, 4: 0.3, 5: 0.4}) 67 | expected = ([0.1, 0.2, 0.3, 0.4], [2, 3, 4, 5]) 68 | assert test == expected 69 | 70 | 71 | def test_process_weights_values_dict_error(): 72 | with pytest.raises(ValueError) as execinfo: 73 | _process_weights_values( 74 | weights=[0.1, 0.2, 0.3, 0.4], values={2: 0.1, 3: 0.2, 4: 0.3, 5: 0.4} 75 | ) 76 | assert "cannot pass dict and weights separately" in str(execinfo.value) 77 | 78 | 79 | def test_process_weights_values_weight_inference(): 80 | test = _process_weights_values(weights=[0.9], values=[2, 3]) 81 | expected = ([0.9, 0.1], [2, 3]) 82 | test[0][1] = round(test[0][1], 1) # fix floating point errors 83 | assert test == expected 84 | 85 | 86 | def test_process_weights_values_weight_inference_not_list(): 87 | test = _process_weights_values(weights=0.9, values=[2, 3]) 88 | expected = ([0.9, 0.1], [2, 3]) 89 | test[0][1] = round(test[0][1], 1) # fix floating point errors 90 | assert test == expected 91 | 92 | 93 | def test_process_weights_values_weight_inference_no_weights(): 94 | test = _process_weights_values(values=[2, 3]) 95 | expected = ([0.5, 0.5], [2, 3]) 96 | assert test == expected 97 | 98 | 99 | def test_process_weights_values_weight_inference_relative_weights(): 100 | test = _process_weights_values(values=[2, 3], relative_weights=[1, 3]) 101 | expected = ([0.25, 0.75], [2, 3]) 102 | assert test == expected 103 | 104 | 105 | def test_process_weights_values_weight_inference_no_weights_len4(): 106 | test = _process_weights_values(values=[2, 3, 4, 5]) 107 | expected = ([0.25, 0.25, 0.25, 0.25], [2, 3, 4, 5]) 108 | assert test == expected 109 | 110 | 111 | def test_process_weights_values_weights_must_be_list_error(): 112 | with pytest.raises(ValueError) as excinfo: 113 | _process_weights_values(weights="error", values=[2, 3]) 114 | assert "passed weights must be an iterable" in str(excinfo.value) 115 | 116 | 117 | def test_process_weights_values_values_must_be_list_error(): 118 | with pytest.raises(ValueError) as excinfo: 119 | _process_weights_values(weights=[0.1, 0.9], values="error") 120 | assert "passed values must be an iterable" in str(excinfo.value) 121 | 122 | 123 | def test_process_weights_values_weights_must_sum_to_1_error(): 124 | with pytest.raises(ValueError) as excinfo: 125 | _process_weights_values(weights=[0.2, 0.9], values=[2, 3]) 126 | assert "weights don't sum to 1 - they sum to 1.1" in str(excinfo.value) 127 | 128 | 129 | def test_process_weights_values_length_mismatch_error(): 130 | with pytest.raises(ValueError) as excinfo: 131 | _process_weights_values(weights=[0.1, 0.9], values=[2, 3, 4]) 132 | assert "weights and values not same length" in str(excinfo.value) 133 | 134 | 135 | def test_process_weights_values_negative_weights(): 136 | with pytest.raises(ValueError) as excinfo: 137 | _process_weights_values(weights=[-0.1, 0.2, 0.9], values=[2, 3, 4]) 138 | assert "weight cannot be negative" in str(excinfo.value) 139 | 140 | 141 | def test_process_weights_values_remove_zero_weights(): 142 | test = _process_weights_values(weights=[0, 0.3, 0, 0.7, 0], values=[1, 2, 3, 4, 5]) 143 | expected = ([0.3, 0.7], [2, 4]) 144 | assert test == expected 145 | 146 | 147 | def test_process_weights_values_handle_none(): 148 | test = _process_weights_values(weights=None, values=[1, None, 3, 4, 5]) 149 | expected = ([0.2, 0.2, 0.2, 0.2, 0.2], [1, None, 3, 4, 5]) 150 | assert test == 
expected 151 | 152 | 153 | def test_process_weights_values_can_drop_none(): 154 | test = _process_weights_values(weights=None, values=[1, None, 3, 4, 5], drop_na=True) 155 | expected = ([0.25, 0.25, 0.25, 0.25], [1, 3, 4, 5]) 156 | assert test == expected 157 | 158 | 159 | def test_process_weights_values_attempt_drop_none_with_weights_error(): 160 | with pytest.raises(ValueError) as execinfo: 161 | _process_weights_values( 162 | weights=[0.2, 0.2, 0.2, 0.2, 0.2], values=[1, None, 3, 4, 5], drop_na=True 163 | ) 164 | assert "cannot drop NA and process weights" in str(execinfo.value) 165 | with pytest.raises(ValueError) as execinfo: 166 | _process_weights_values( 167 | relative_weights=[1, 1, 1, 1, 1], values=[1, None, 3, 4, 5], drop_na=True 168 | ) 169 | assert "cannot drop NA and process weights" in str(execinfo.value) 170 | 171 | 172 | def test_process_discrete_weights_values_simple_case(): 173 | test = _process_discrete_weights_values([[0.1, 2], [0.9, 3]]) 174 | expected = ([0.1, 0.9], [2, 3]) 175 | assert test == expected 176 | test = _process_discrete_weights_values({2: 0.1, 3: 0.9}) 177 | expected = ([0.1, 0.9], [2, 3]) 178 | assert test == expected 179 | 180 | 181 | def test_process_discrete_weights_values_compress(): 182 | items = [round((x % 10) / 10, 1) for x in range(1000)] 183 | test = _process_discrete_weights_values(items) 184 | expected = ( 185 | [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1], 186 | [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], 187 | ) 188 | assert test == expected 189 | 190 | 191 | def test_normalize(): 192 | assert normalize([0.1, 0.4]) == [0.2, 0.8] 193 | 194 | 195 | def test_event_occurs(): 196 | set_seed(42) 197 | assert event_occurs(0.9) 198 | set_seed(42) 199 | assert not event_occurs(0.1) 200 | 201 | 202 | def test_event_occurs_can_handle_distributions(): 203 | set_seed(42) 204 | assert event_occurs(bernoulli(0.9)) 205 | 206 | 207 | def test_event_occurs_can_handle_distributions2(): 208 | set_seed(42) 209 | assert event_occurs(beta(10, 1)) 210 | 211 | 212 | def test_event_occurs_can_handle_distributions_callable(): 213 | def get_p(): 214 | return 0.9 215 | 216 | assert event_occurs(get_p) 217 | 218 | 219 | def test_event_happens(): 220 | set_seed(42) 221 | assert event_happens(0.9) 222 | set_seed(42) 223 | assert not event_happens(0.1) 224 | 225 | 226 | def test_event(): 227 | set_seed(42) 228 | assert event(0.9) 229 | set_seed(42) 230 | assert not event(0.1) 231 | 232 | 233 | def test_one_in(): 234 | assert one_in(0.1) == "1 in 10" 235 | assert one_in(0.02) == "1 in 50" 236 | assert one_in(0.00002) == "1 in 50,000" 237 | 238 | 239 | def test_one_in_w_rounding(): 240 | assert one_in(0.1415) == "1 in 7" 241 | assert one_in(0.1415, digits=1) == "1 in 7.1" 242 | assert one_in(0.1415, digits=2) == "1 in 7.07" 243 | assert one_in(0.1415, digits=3) == "1 in 7.067" 244 | 245 | 246 | def test_one_in_not_verbose(): 247 | assert one_in(0.1415, digits=3, verbose=False) == 7.067 248 | 249 | 250 | def test_get_percentiles(): 251 | test = get_percentiles(range(1, 901)) 252 | expected = { 253 | 1: 9.99, 254 | 5: 45.95, 255 | 10: 90.9, 256 | 20: 180.8, 257 | 30: 270.7, 258 | 40: 360.6, 259 | 50: 450.5, 260 | 60: 540.4, 261 | 70: 630.3, 262 | 80: 720.2, 263 | 90: 810.1, 264 | 95: 855.05, 265 | 99: 891.01, 266 | } 267 | assert test == expected 268 | 269 | 270 | def test_get_percentiles_change_percentiles(): 271 | test = get_percentiles(range(1, 901), percentiles=[20, 80]) 272 | expected = {20: 180.8, 80: 720.2} 273 | assert test == expected 274 | 275 | test = 
get_percentiles(range(1, 901), percentiles=[25, 75]) 276 | expected = {25: 225.75, 75: 675.25} 277 | assert test == expected 278 | 279 | 280 | def test_get_percentiles_reverse(): 281 | test = get_percentiles(range(1, 901), percentiles=[20, 80], reverse=True) 282 | expected = {20: 720.2, 80: 180.8} 283 | assert test == expected 284 | 285 | 286 | def test_get_percentiles_digits(): 287 | test = get_percentiles(range(1, 901), percentiles=[25, 75], digits=1) 288 | expected = {25: 225.8, 75: 675.2} 289 | assert test == expected 290 | 291 | 292 | def test_get_percentiles_length_one(): 293 | test = get_percentiles(range(1, 901), percentiles=[25], digits=1) 294 | assert test == 225.8 295 | test = get_percentiles(range(1, 901), percentiles=25, digits=1) 296 | assert test == 225.8 297 | 298 | 299 | def test_get_percentiles_zero_digits(): 300 | test = get_percentiles(range(1, 901), percentiles=[25, 75], digits=0) 301 | expected = {25: 226, 75: 675} 302 | assert test == expected 303 | assert isinstance(test[25], int) 304 | assert isinstance(test[75], int) 305 | 306 | 307 | def test_get_percentiles_negative_one_digits(): 308 | test = get_percentiles(range(1, 901), percentiles=[25, 75], digits=-1) 309 | expected = {25: 230, 75: 680} 310 | assert test == expected 311 | assert isinstance(test[25], int) 312 | assert isinstance(test[75], int) 313 | 314 | 315 | def test_weighted_percentile(): 316 | test = _weighted_percentile( 317 | range(1, 901), weights=[1 for _ in range(900)], percentiles=[25, 75] 318 | ) 319 | assert test[0] == 225.5 320 | assert test[1] == 675.5 321 | test = _weighted_percentile(range(1, 901), weights=range(900), percentiles=[25, 75]) 322 | assert round(test[0], 1) == 450.7 323 | assert round(test[1], 1) == 780.0 324 | 325 | 326 | def test_weighted_percentile_requires_weights_len_equal_data_len(): 327 | with pytest.raises(ValueError) as execinfo: 328 | _weighted_percentile([1, 2, 3], [1, 2], [1, 2, 3]) 329 | assert "must be of the same length" in str(execinfo.value) 330 | 331 | 332 | def test_get_percentiles_weighted(): 333 | test = get_percentiles(range(1, 901), weights=range(900), digits=-1) 334 | expected = { 335 | 1: 90, 336 | 5: 200, 337 | 10: 290, 338 | 20: 400, 339 | 30: 490, 340 | 40: 570, 341 | 50: 640, 342 | 60: 700, 343 | 70: 750, 344 | 80: 810, 345 | 90: 850, 346 | 95: 880, 347 | 99: 900, 348 | } 349 | assert test == expected 350 | 351 | 352 | def test_get_percentiles_with_weights_requires_weights_len_equal_data_len(): 353 | with pytest.raises(ValueError) as execinfo: 354 | get_percentiles(range(1, 901), weights=range(800), digits=-1) 355 | assert "must be of the same length" in str(execinfo.value) 356 | 357 | 358 | def test_get_log_percentiles(): 359 | test = get_log_percentiles([10**x for x in range(1, 10)]) 360 | expected = { 361 | 1: "1.7e+01", 362 | 5: "4.6e+01", 363 | 10: "8.2e+01", 364 | 20: "6.4e+02", 365 | 30: "4.6e+03", 366 | 40: "2.8e+04", 367 | 50: "1.0e+05", 368 | 60: "8.2e+05", 369 | 70: "6.4e+06", 370 | 80: "4.6e+07", 371 | 90: "2.8e+08", 372 | 95: "6.4e+08", 373 | 99: "9.3e+08", 374 | } 375 | assert test == expected 376 | 377 | 378 | def test_get_log_percentiles_change_percentiles(): 379 | test = get_log_percentiles([10**x for x in range(1, 10)], percentiles=[20, 80]) 380 | expected = {20: "6.4e+02", 80: "4.6e+07"} 381 | assert test == expected 382 | 383 | 384 | def test_get_log_percentiles_reverse(): 385 | test = get_log_percentiles([10**x for x in range(1, 10)], percentiles=[20, 80], reverse=True) 386 | expected = {20: "4.6e+07", 80: 
"6.4e+02"} 387 | assert test == expected 388 | 389 | 390 | def test_get_log_percentiles_no_display(): 391 | test = get_log_percentiles([10**x for x in range(1, 10)], percentiles=[20, 80], display=False) 392 | expected = {20: 2.8, 80: 7.7} 393 | assert test == expected 394 | 395 | 396 | def test_get_log_percentiles_zero_digits(): 397 | test = get_log_percentiles( 398 | [10**x for x in range(1, 10)], percentiles=[20, 80], display=False, digits=0 399 | ) 400 | expected = {20: 3, 80: 8} 401 | assert test == expected 402 | 403 | 404 | def test_get_log_percentiles_length_one(): 405 | test = get_log_percentiles( 406 | [10**x for x in range(1, 10)], percentiles=[20], display=False, digits=0 407 | ) 408 | assert test == 3 409 | test = get_log_percentiles( 410 | [10**x for x in range(1, 10)], percentiles=20, display=False, digits=0 411 | ) 412 | assert test == 3 413 | 414 | 415 | def test_get_log_percentiles_weighted(): 416 | test = get_log_percentiles([10**x for x in range(1, 10)], weights=range(9)) 417 | expected = { 418 | 1: "7.5e+01", 419 | 5: "8.8e+02", 420 | 10: "6.8e+03", 421 | 20: "7.9e+04", 422 | 30: "6.6e+05", 423 | 40: "4.1e+06", 424 | 50: "1.0e+07", 425 | 60: "6.0e+07", 426 | 70: "1.8e+08", 427 | 80: "6.2e+08", 428 | 90: "1.0e+09", 429 | 95: "1.0e+09", 430 | 99: "1.0e+09", 431 | } 432 | assert test == expected 433 | 434 | 435 | def test_get_mean_and_ci(): 436 | test1 = get_mean_and_ci(range(1, 901), digits=1) 437 | assert test1 == {"mean": 450.5, "ci_low": 46.0, "ci_high": 855.0} 438 | test2 = get_mean_and_ci([1, 2, 6], digits=1) 439 | assert test2 == {"mean": 3, "ci_low": 1.1, "ci_high": 5.6} 440 | 441 | 442 | def test_get_mean_and_80_pct_ci(): 443 | test = get_mean_and_ci(range(1, 901), digits=1, credibility=80) 444 | assert test == {"mean": 450.5, "ci_low": 90.9, "ci_high": 810.1} 445 | 446 | 447 | def test_get_mean_and_ci_weighted(): 448 | test1 = get_mean_and_ci(range(1, 901), weights=range(900), digits=1) 449 | assert test1 == {"mean": 450.5, "ci_low": 202.1, "ci_high": 877.7} 450 | 451 | 452 | def test_get_median_and_ci(): 453 | test1 = get_median_and_ci(range(1, 901), digits=1) 454 | assert test1 == {"median": 450.5, "ci_low": 46.0, "ci_high": 855.0} 455 | test2 = get_median_and_ci([1, 2, 6], digits=1) 456 | assert test2 == {"median": 2, "ci_low": 1.1, "ci_high": 5.6} 457 | 458 | 459 | def test_get_median_and_80_pct_ci(): 460 | test = get_median_and_ci(range(1, 901), digits=1, credibility=80) 461 | assert test == {"median": 450.5, "ci_low": 90.9, "ci_high": 810.1} 462 | 463 | 464 | def test_get_median_and_ci_weighted(): 465 | test1 = get_median_and_ci(range(1, 901), weights=range(900), digits=1) 466 | assert test1 == {"median": 637.0, "ci_low": 202.1, "ci_high": 877.7} 467 | 468 | 469 | def test_geomean(): 470 | assert round(geomean([0.1, 0.2, 0.3, 0.4, 0.5]), 2) == 0.26 471 | 472 | 473 | def test_geomean_numpy(): 474 | assert round(geomean(np.array([0.1, 0.2, 0.3, 0.4, 0.5])), 2) == 0.26 475 | 476 | 477 | def test_weighted_geomean(): 478 | assert round(geomean([0.1, 0.2, 0.3, 0.4, 0.5], weights=[0.5, 0.1, 0.1, 0.1, 0.2]), 2) == 0.19 479 | 480 | 481 | def test_relative_weighted_geomean(): 482 | assert round(geomean([0.1, 0.2, 0.3, 0.4, 0.5], relative_weights=[5, 1, 1, 1, 2]), 2) == 0.19 483 | 484 | 485 | def test_geomean_with_none_value(): 486 | assert round(geomean([0.1, 0.2, None, 0.3, 0.4, None, 0.5]), 2) == 0.26 487 | 488 | 489 | def test_weighted_geomean_alt_format(): 490 | assert round(geomean([[0.5, 0.1], [0.1, 0.2], [0.1, 0.3], [0.1, 0.4], [0.2, 0.5]]), 2) == 0.19 491 | 
def test_weighted_geomean_alt2_format():
    assert round(geomean({0.1: 0.5, 0.2: 0.1, 0.3: 0.1, 0.4: 0.1, 0.5: 0.2}), 2) == 0.19


def test_weighted_geomean_errors_with_none_value():
    with pytest.raises(ValueError) as execinfo:
        geomean({0.1: 0.5, 0.2: 0.1, 0.3: None, 0.4: 0.1, 0.5: 0.2})
    assert "cannot handle NA-like values in weights" in str(execinfo.value)
    with pytest.raises(ValueError) as execinfo:
        geomean([[0.5, 0.1], [0.1, None], [0.1, 0.3], [0.1, 0.4], [0.2, 0.5]])
    assert "cannot drop NA and process weights" in str(execinfo.value)


def test_p_to_odds():
    assert round(p_to_odds(0.1), 2) == 0.11


def test_odds_to_p():
    assert round(odds_to_p(1 / 9), 2) == 0.1


def test_p_to_odds_handles_none():
    assert p_to_odds(None) is None


def test_odds_to_p_handles_none():
    assert odds_to_p(None) is None


def test_p_to_odds_handles_multiple():
    assert all(np.round(p_to_odds([0.1, 0.2, 0.3]), 2) == np.array([0.11, 0.25, 0.43]))


def test_odds_to_p_handles_multiple():
    assert all(np.round(odds_to_p([0.1, 0.2, 0.3]), 2) == np.array([0.09, 0.17, 0.23]))


def test_geomean_odds():
    assert round(geomean_odds([0.1, 0.2, 0.3, 0.4, 0.5]), 2) == 0.28


def test_geomean_odds_numpy():
    assert round(geomean_odds(np.array([0.1, 0.2, 0.3, 0.4, 0.5])), 2) == 0.28


def test_weighted_geomean_odds():
    assert (
        round(
            geomean_odds([0.1, 0.2, 0.3, 0.4, 0.5], weights=[0.5, 0.1, 0.1, 0.1, 0.2]),
            2,
        )
        == 0.2
    )


def test_weighted_geomean_odds_alt_format():
    assert (
        round(
            geomean_odds([[0.5, 0.1], [0.1, 0.2], [0.1, 0.3], [0.1, 0.4], [0.2, 0.5]]),
            2,
        )
        == 0.2
    )


def test_weighted_geomean_odds_alt2_format():
    assert round(geomean_odds({0.1: 0.5, 0.2: 0.1, 0.3: 0.1, 0.4: 0.1, 0.5: 0.2}), 2) == 0.2


def test_laplace_simple():
    test = laplace(0, 1)
    expected = 1 / 3
    assert test == expected


def test_laplace_s_is_1():
    test = laplace(1, 1)
    expected = 2 / 3
    assert test == expected


def test_laplace_s_gt_n():
    with pytest.raises(ValueError) as excinfo:
        laplace(3, 2)
    assert "`s` cannot be greater than `n`" in str(excinfo.value)


def test_time_invariant_laplace_zero_s():
    assert laplace(s=0, time_passed=2, time_remaining=2) == 0.5


def test_time_invariant_laplace_one_s_time_fixed():
    assert laplace(s=1, time_passed=2, time_remaining=2, time_fixed=True) == 0.75


def test_time_invariant_laplace_one_s_time_variable():
    assert laplace(s=1, time_passed=2, time_remaining=2) == 0.5


def test_time_invariant_laplace_infer_time_remaining():
    assert round(laplace(s=1, time_passed=2), 2) == 0.33


def test_time_invariant_laplace_two_s():
    assert laplace(s=2, time_passed=2, time_remaining=2) == 0.75


def test_laplace_only_s():
    with pytest.raises(ValueError) as excinfo:
        laplace(3)
    assert "Must define `time_passed` or `n`" in str(excinfo.value)


def test_laplace_no_time_passed():
    with pytest.raises(ValueError) as excinfo:
        laplace(3, time_remaining=1)
    assert "Must define `time_passed`" in str(excinfo.value)
    with pytest.raises(ValueError) as excinfo:
        laplace(3, n=10, time_remaining=1)
    assert "Must define `time_passed`" in str(excinfo.value)

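
# The Laplace expectations above are consistent with the rule of
# succession, (s + 1) / (n + 2), in the discrete case -- e.g.
# laplace(0, 1) == 1/3 and laplace(1, 1) == 2/3 -- and, in the
# time-invariant case with s >= 1, with
# 1 - (time_passed / (time_passed + time_remaining)) ** s (exponent s + 1
# when time_fixed=True; time_remaining appears to default to 1 when
# omitted). The doubling-time tests below match ln(2) / ln(1 + r) and its
# inverse 2 ** (1 / t) - 1. Minimal illustrative sketches (hypothetical
# helpers, not squigglepy's implementation):
def _sketch_laplace(s, n):
    # rule of succession; e.g. _sketch_laplace(1, 1) == 2/3
    return (s + 1) / (n + 2)


def _sketch_doubling_time(growth_rate):
    # e.g. _sketch_doubling_time(0.01) ~= 69.66
    return np.log(2) / np.log(1 + growth_rate)
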
def test_growth_rate_to_doubling_time_float():
    assert round(growth_rate_to_doubling_time(0.01), 2) == 69.66
    assert round(growth_rate_to_doubling_time(0.5), 2) == 1.71
    assert round(growth_rate_to_doubling_time(1.0), 2) == 1.0


def test_growth_rate_to_doubling_time_nparray():
    result = growth_rate_to_doubling_time(np.array([0.01, 0.5, 1.0]))
    assert round(result[0], 2) == 69.66
    assert round(result[1], 2) == 1.71
    assert round(result[2], 2) == 1


def test_growth_rate_to_doubling_time_dist():
    assert round(growth_rate_to_doubling_time(const(0.01)) @ 1, 2) == 69.66


def test_doubling_time_to_growth_rate_float():
    assert round(doubling_time_to_growth_rate(12), 2) == 0.06
    assert round(doubling_time_to_growth_rate(5.5), 2) == 0.13
    assert round(doubling_time_to_growth_rate(1), 2) == 1


def test_doubling_time_to_growth_rate_nparray():
    result = doubling_time_to_growth_rate(np.array([12, 5.5, 1]))
    assert round(result[0], 2) == 0.06
    assert round(result[1], 2) == 0.13
    assert round(result[2], 2) == 1


def test_doubling_time_to_growth_rate_dist():
    assert round(doubling_time_to_growth_rate(const(12)) @ 1, 2) == 0.06


def test_roll_die():
    set_seed(42)
    assert roll_die(6) == 5


def test_roll_die_different_sides():
    set_seed(42)
    assert roll_die(4) == 4


def test_roll_die_with_distribution():
    set_seed(42)
    assert (norm(2, 6) >> dist_round >> roll_die) == 2


def test_roll_one_sided_die():
    with pytest.raises(ValueError) as excinfo:
        roll_die(1)
    assert "cannot roll less than a 2-sided die" in str(excinfo.value)


def test_roll_nonint_die():
    with pytest.raises(ValueError) as excinfo:
        roll_die(2.5)
    assert "can only roll an integer number of sides" in str(excinfo.value)


def test_roll_nonint_n():
    with pytest.raises(ValueError) as excinfo:
        roll_die(6, 2.5)
    assert "can only roll an integer number of times" in str(excinfo.value)


def test_roll_five_die():
    set_seed(42)
    assert list(roll_die(4, 4)) == [4, 2, 4, 3]


def test_flip_coin():
    set_seed(42)
    assert flip_coin() == "heads"


def test_flip_five_coins():
    set_seed(42)
    assert flip_coin(5) == ["heads", "tails", "heads", "heads", "tails"]


def test_kelly_market_price_error():
    for val in [0, 1, 2, -1]:
        with pytest.raises(ValueError) as execinfo:
            kelly(my_price=0.99, market_price=val)
        assert "market_price must be >0 and <1" in str(execinfo.value)


def test_kelly_my_price_error():
    for val in [0, 1, 2, -1]:
        with pytest.raises(ValueError) as execinfo:
            kelly(my_price=val, market_price=0.99)
        assert "my_price must be >0 and <1" in str(execinfo.value)


def test_kelly_deference_error():
    for val in [-1, 2]:
        with pytest.raises(ValueError) as execinfo:
            kelly(my_price=0.01, market_price=0.99, deference=val)
        assert "deference must be >=0 and <=1" in str(execinfo.value)

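
# The Kelly expectations below are consistent with standard Kelly sizing,
# with `deference` shrinking your price toward the market's:
#     adj_price = deference * market_price + (1 - deference) * my_price
#     kelly     = (adj_price - market_price) / (1 - market_price)
#     target    = kelly * bankroll;  delta = target - current
# e.g. in test_kelly_worked_example below: adj_price = 0.66 * 0.17 +
# 0.34 * 0.6 ~= 0.32, kelly = (0.3162 - 0.17) / 0.83 ~= 0.176, and
# target ~= 0.176 * 46288 ~= 8153. A minimal illustrative sketch (this
# helper is hypothetical, not squigglepy's implementation):
def _sketch_kelly_fraction(my_price, market_price, deference=0.0):
    adj_price = deference * market_price + (1 - deference) * my_price
    return (adj_price - market_price) / (1 - market_price)
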
def test_kelly_user_below_market_price_error():
    with pytest.raises(ValueError) as execinfo:
        kelly(my_price=0.1, market_price=0.2)
    assert "below the market price" in str(execinfo.value)
    assert "bypass this issue" in str(execinfo.value)


def test_kelly_user_below_market_price_error_can_be_overridden():
    assert isinstance(kelly(my_price=0.1, market_price=0.2, error=False), dict)


def test_kelly_defaults():
    obj = kelly(my_price=0.99, market_price=0.01)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0
    assert obj["adj_price"] == 0.99
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.98
    assert obj["kelly"] == 0.99
    assert obj["target"] == 0.99
    assert obj["current"] == 0
    assert obj["delta"] == 0.99
    assert obj["max_gain"] == 98
    assert obj["modeled_gain"] == 97.01
    assert obj["expected_roi"] == 98
    assert obj["expected_arr"] is None
    assert obj["resolve_date"] is None


def test_full_kelly():
    obj = full_kelly(my_price=0.99, market_price=0.01)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0
    assert obj["adj_price"] == 0.99
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.98
    assert obj["kelly"] == 0.99
    assert obj["target"] == 0.99
    assert obj["current"] == 0
    assert obj["delta"] == 0.99
    assert obj["max_gain"] == 98
    assert obj["modeled_gain"] == 97.01
    assert obj["expected_roi"] == 98
    assert obj["expected_arr"] is None
    assert obj["resolve_date"] is None


def test_full_kelly_passes_error_parameter():
    obj = full_kelly(my_price=0.1, market_price=0.2, error=False)
    assert isinstance(obj, dict)

    with pytest.raises(ValueError) as excinfo:
        full_kelly(my_price=0.1, market_price=0.2, error=True)
    assert "below the market price" in str(excinfo.value)


def test_half_kelly():
    obj = half_kelly(my_price=0.99, market_price=0.01)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0.5
    assert obj["adj_price"] == 0.5
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.49
    assert obj["kelly"] == 0.495
    assert obj["target"] == 0.49
    assert obj["current"] == 0
    assert obj["delta"] == 0.49
    assert obj["max_gain"] == 49
    assert obj["modeled_gain"] == 24.25
    assert obj["expected_roi"] == 49
    assert obj["expected_arr"] is None
    assert obj["resolve_date"] is None


def test_half_kelly_passes_error_parameter():
    obj = half_kelly(my_price=0.1, market_price=0.2, error=False)
    assert isinstance(obj, dict)

    with pytest.raises(ValueError) as excinfo:
        half_kelly(my_price=0.1, market_price=0.2, error=True)
    assert "below the market price" in str(excinfo.value)

obj["max_gain"] == 32.67 817 | assert obj["modeled_gain"] == 10.78 818 | assert obj["expected_roi"] == 32.673 819 | assert obj["expected_arr"] is None 820 | assert obj["resolve_date"] is None 821 | 822 | 823 | def test_third_kelly_passes_error_parameter(): 824 | obj = third_kelly(my_price=0.1, market_price=0.2, error=False) 825 | assert isinstance(obj, dict) 826 | 827 | with pytest.raises(ValueError) as excinfo: 828 | third_kelly(my_price=0.1, market_price=0.2, error=True) 829 | assert "below the market price" in str(excinfo.value) 830 | 831 | 832 | def test_quarter_kelly(): 833 | obj = quarter_kelly(my_price=0.99, market_price=0.01) 834 | assert obj["my_price"] == 0.99 835 | assert obj["market_price"] == 0.01 836 | assert obj["deference"] == 0.75 837 | assert obj["adj_price"] == 0.26 838 | assert obj["delta_price"] == 0.98 839 | assert obj["adj_delta_price"] == 0.24 840 | assert obj["kelly"] == 0.247 841 | assert obj["target"] == 0.25 842 | assert obj["current"] == 0 843 | assert obj["delta"] == 0.25 844 | assert obj["max_gain"] == 24.5 845 | assert obj["modeled_gain"] == 6.06 846 | assert obj["expected_roi"] == 24.5 847 | assert obj["expected_arr"] is None 848 | assert obj["resolve_date"] is None 849 | 850 | 851 | def test_quarter_kelly_passes_error_parameter(): 852 | obj = quarter_kelly(my_price=0.1, market_price=0.2, error=False) 853 | assert isinstance(obj, dict) 854 | 855 | with pytest.raises(ValueError) as excinfo: 856 | quarter_kelly(my_price=0.1, market_price=0.2, error=True) 857 | assert "below the market price" in str(excinfo.value) 858 | 859 | 860 | def test_kelly_with_bankroll(): 861 | obj = kelly(my_price=0.99, market_price=0.01, bankroll=1000) 862 | assert obj["my_price"] == 0.99 863 | assert obj["market_price"] == 0.01 864 | assert obj["deference"] == 0 865 | assert obj["adj_price"] == 0.99 866 | assert obj["delta_price"] == 0.98 867 | assert obj["adj_delta_price"] == 0.98 868 | assert obj["kelly"] == 0.99 869 | assert obj["target"] == 989.9 870 | assert obj["current"] == 0 871 | assert obj["delta"] == 989.9 872 | assert obj["max_gain"] == 98000 873 | assert obj["modeled_gain"] == 97010.1 874 | assert obj["expected_roi"] == 98 875 | assert obj["expected_arr"] is None 876 | assert obj["resolve_date"] is None 877 | 878 | 879 | def test_kelly_with_current(): 880 | obj = kelly(my_price=0.99, market_price=0.01, bankroll=1000, current=100) 881 | assert obj["my_price"] == 0.99 882 | assert obj["market_price"] == 0.01 883 | assert obj["deference"] == 0 884 | assert obj["adj_price"] == 0.99 885 | assert obj["delta_price"] == 0.98 886 | assert obj["adj_delta_price"] == 0.98 887 | assert obj["kelly"] == 0.99 888 | assert obj["target"] == 989.9 889 | assert obj["current"] == 100 890 | assert obj["delta"] == 889.9 891 | assert obj["max_gain"] == 98000 892 | assert obj["modeled_gain"] == 97010.1 893 | assert obj["expected_roi"] == 98 894 | assert obj["expected_arr"] is None 895 | assert obj["resolve_date"] is None 896 | 897 | 898 | def test_kelly_with_resolve_date(): 899 | one_year_from_today = datetime.now() + timedelta(days=365) 900 | one_year_from_today_str = one_year_from_today.strftime("%Y-%m-%d") 901 | obj = kelly(my_price=0.99, market_price=0.01, resolve_date=one_year_from_today_str) 902 | assert obj["my_price"] == 0.99 903 | assert obj["market_price"] == 0.01 904 | assert obj["deference"] == 0 905 | assert obj["adj_price"] == 0.99 906 | assert obj["delta_price"] == 0.98 907 | assert obj["adj_delta_price"] == 0.98 908 | assert obj["kelly"] == 0.99 909 | assert obj["target"] == 
def test_kelly_with_resolve_date():
    one_year_from_today = datetime.now() + timedelta(days=365)
    one_year_from_today_str = one_year_from_today.strftime("%Y-%m-%d")
    obj = kelly(my_price=0.99, market_price=0.01, resolve_date=one_year_from_today_str)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0
    assert obj["adj_price"] == 0.99
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.98
    assert obj["kelly"] == 0.99
    assert obj["target"] == 0.99
    assert obj["current"] == 0
    assert obj["delta"] == 0.99
    assert obj["max_gain"] == 98
    assert obj["modeled_gain"] == 97.01
    assert obj["expected_roi"] == 98
    assert obj["expected_arr"] == 99.258
    assert obj["resolve_date"] == datetime(
        one_year_from_today.year,
        one_year_from_today.month,
        one_year_from_today.day,
        0,
        0,
    )


def test_kelly_with_resolve_date2():
    two_years_from_today = datetime.now() + timedelta(days=365 * 2)
    two_years_from_today_str = two_years_from_today.strftime("%Y-%m-%d")
    obj = kelly(my_price=0.99, market_price=0.01, resolve_date=two_years_from_today_str)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0
    assert obj["adj_price"] == 0.99
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.98
    assert obj["kelly"] == 0.99
    assert obj["target"] == 0.99
    assert obj["current"] == 0
    assert obj["delta"] == 0.99
    assert obj["max_gain"] == 98
    assert obj["modeled_gain"] == 97.01
    assert obj["expected_roi"] == 98
    assert obj["expected_arr"] == 8.981
    assert obj["resolve_date"] == datetime(
        two_years_from_today.year,
        two_years_from_today.month,
        two_years_from_today.day,
        0,
        0,
    )


def test_kelly_with_resolve_date0pt5():
    half_year_from_today = datetime.now() + timedelta(days=int(round(365 * 0.5)))
    half_year_from_today_str = half_year_from_today.strftime("%Y-%m-%d")
    obj = kelly(my_price=0.99, market_price=0.01, resolve_date=half_year_from_today_str)
    assert obj["my_price"] == 0.99
    assert obj["market_price"] == 0.01
    assert obj["deference"] == 0
    assert obj["adj_price"] == 0.99
    assert obj["delta_price"] == 0.98
    assert obj["adj_delta_price"] == 0.98
    assert obj["kelly"] == 0.99
    assert obj["target"] == 0.99
    assert obj["current"] == 0
    assert obj["delta"] == 0.99
    assert obj["max_gain"] == 98
    assert obj["modeled_gain"] == 97.01
    assert obj["expected_roi"] == 98
    assert obj["expected_arr"] == 10575.628
    assert obj["resolve_date"] == datetime(
        half_year_from_today.year,
        half_year_from_today.month,
        half_year_from_today.day,
        0,
        0,
    )


def test_kelly_worked_example():
    half_year_from_today = datetime.now() + timedelta(days=int(round(365 * 0.5)))
    half_year_from_today_str = half_year_from_today.strftime("%Y-%m-%d")
    obj = kelly(
        my_price=0.6,
        market_price=0.17,
        deference=0.66,
        bankroll=46288,
        resolve_date=half_year_from_today_str,
        current=7300,
    )
    assert obj["my_price"] == 0.6
    assert obj["market_price"] == 0.17
    assert obj["deference"] == 0.66
    assert obj["adj_price"] == 0.32
    assert obj["delta_price"] == 0.43
    assert obj["adj_delta_price"] == 0.15
    assert obj["kelly"] == 0.176
    assert obj["target"] == 8153.38
    assert obj["current"] == 7300
    assert obj["delta"] == 853.38
    assert obj["max_gain"] == 39807.68
    assert obj["modeled_gain"] == 7011.91
    assert obj["expected_roi"] == 0.86
    assert obj["expected_arr"] == 2.495
    assert obj["resolve_date"] == datetime(
        half_year_from_today.year,
        half_year_from_today.month,
        half_year_from_today.day,
        0,
        0,
    )

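
# The extremize expectations below are consistent with a piecewise
# power-law transform that pushes probabilities away from 0.5:
#     extremize(p, e) = p**e           if p <= 0.5
#                       1 - (1 - p)**e  otherwise
# e.g. 1 - 0.3**1.73 ~= 0.875, 0.2**1.73 ~= 0.062, and e=1 is the identity.
# Illustrative sketch (hypothetical helper, not necessarily squigglepy's
# implementation):
def _sketch_extremize(p, e):
    return p**e if p <= 0.5 else 1 - (1 - p) ** e
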
def test_extremize():
    assert round(extremize(p=0.7, e=1), 3) == 0.7
    assert round(extremize(p=0.7, e=1.73), 3) == 0.875
    assert round(extremize(p=0.2, e=1.73), 3) == 0.062


def test_extremize_out_of_bounds():
    for p in [-1, 0, 1, 2]:
        with pytest.raises(ValueError) as execinfo:
            extremize(p=p, e=1.73)
        assert "must be greater than 0 and less than 1" in str(execinfo.value)


def test_core_cuts():
    assert _core_cuts(10, 2) == [5, 5]
    assert _core_cuts(10, 3) == [3, 3, 4]


def test_sharpe_ratio():
    assert round(sharpe_ratio([0.04, -0.03, 0.05, 0.02, 0.03]), 4) == 0.7898


def test_bucket_percentages():
    data = np.array([1, 3, 5, 7, 9, 11])
    result = bucket_percentages(data, bins=[-np.inf, 2, 4, 6, 8, 10, np.inf])
    assert len(result) == 6
    for val in result.values():
        assert abs(val - 16.67) < 0.01


def test_bucket_percentages_custom_bins():
    data = np.array([1, 3, 5, 7, 9, 11])
    result = bucket_percentages(data, bins=[0, 5, 10, 15])
    expected = {
        "[0, 5)": 33.33,
        "[5, 10)": 50.0,
        "[10, 15)": 16.67,
    }
    for key in result:
        assert abs(result[key] - expected[key]) < 0.01


def test_bucket_percentages_custom_ranges_and_labels():
    data = np.array([1, 3, 5, 7, 9, 11])
    custom_bins = [(-np.inf, 5), (5, 10), (10, np.inf)]
    labels = ["Low", "Medium", "High"]
    result = bucket_percentages(data, custom_bins=custom_bins, labels=labels)
    expected = {
        "Low": 33.33,
        "Medium": 50.0,
        "High": 16.67,
    }
    for key in result:
        assert abs(result[key] - expected[key]) < 0.01


def test_bucket_percentages_counts():
    data = np.array([1, 3, 5, 7, 9, 11])
    result = bucket_percentages(data, bins=[0, 5, 10, 15], normalize=False, as_percentage=False)
    expected = {
        "[0, 5)": 2,
        "[5, 10)": 3,
        "[10, 15)": 1,
    }
    assert result == expected


def test_bucket_percentages_label_mismatch():
    data = np.array([1, 3, 5, 7, 9, 11])
    custom_bins = [(-np.inf, 5), (5, 10), (10, np.inf)]
    labels = ["Low", "Medium"]  # Missing one label

    with pytest.raises(ValueError) as excinfo:
        bucket_percentages(data, custom_bins=custom_bins, labels=labels)
    assert "Number of labels" in str(excinfo.value)
--------------------------------------------------------------------------------