├── .coveragerc ├── .gitignore ├── .python-version ├── .readthedocs.yml ├── .travis.yml ├── CHANGELOG.md ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE.md ├── Makefile ├── README.md ├── docs ├── Makefile ├── conf.py ├── contrib_folder.rst ├── core_contrib.rst ├── distributions.rst ├── foretold.rst ├── gettingstarted.rst ├── index.rst ├── inference.rst ├── jupyter_colab.rst ├── metaculus.rst ├── nbs_contrib.rst ├── notebook_style.rst ├── predictit.rst └── tips_google_sheets.rst ├── ergo ├── __init__.py ├── conditions │ ├── __init__.py │ ├── condition.py │ ├── crossentropy.py │ ├── interval.py │ ├── maxentropy.py │ ├── mean.py │ ├── mode.py │ ├── point_density.py │ ├── smoothness.py │ ├── variance.py │ └── wasserstein.py ├── contrib │ ├── __init__.py │ ├── el_paso │ │ ├── __init__.py │ │ ├── brachbach.py │ │ ├── krismoore.py │ │ ├── onlyasith.py │ │ ├── shaman.py │ │ └── texas_data.py │ ├── metac_qs_data │ │ └── metac_qs_data.csv │ ├── predictit │ │ ├── __init__.py │ │ └── fuzzy_search.py │ └── utils │ │ ├── __init__.py │ │ ├── core.py │ │ └── utils.py ├── distributions │ ├── __init__.py │ ├── base.py │ ├── constants.py │ ├── distribution.py │ ├── logistic.py │ ├── logistic_mixture.py │ ├── optimizable.py │ ├── point_density.py │ └── truncate.py ├── platforms │ ├── __init__.py │ ├── foretold.py │ ├── metaculus │ │ ├── __init__.py │ │ ├── metaculus.py │ │ └── question │ │ │ ├── __init__.py │ │ │ ├── binary.py │ │ │ ├── constants.py │ │ │ ├── continuous.py │ │ │ ├── linear.py │ │ │ ├── lineardate.py │ │ │ ├── log.py │ │ │ ├── question.py │ │ │ └── types.py │ └── predictit.py ├── ppl.py ├── scale.py ├── static.py ├── theme.py └── utils.py ├── mypy.ini ├── notebooks ├── assorted-predictions.ipynb ├── community-distributions.ipynb ├── covid-19-active.ipynb ├── covid-19-average-lockdown.ipynb ├── covid-19-inference.ipynb ├── covid-19-lockdowns.ipynb ├── covid-19-metaculus.ipynb ├── covid-19-tests-august-2020.ipynb ├── el-paso-workflow.ipynb ├── el-paso.ipynb ├── foretold-submission.ipynb ├── generative-models.ipynb ├── metac_qs_data.ipynb ├── prediction-dashboard.ipynb ├── predictit_clean_sweep.ipynb ├── quickstart.ipynb ├── rejection_sampling.ipynb ├── scrubbed │ ├── assorted-predictions.ipynb │ ├── community-distributions.ipynb │ ├── covid-19-active.ipynb │ ├── covid-19-average-lockdown.ipynb │ ├── covid-19-inference.ipynb │ ├── covid-19-lockdowns.ipynb │ ├── covid-19-metaculus.ipynb │ ├── covid-19-tests-august-2020.ipynb │ ├── el-paso-workflow.ipynb │ ├── el-paso.ipynb │ ├── foretold-submission.ipynb │ ├── generative-models.ipynb │ ├── metac_qs_data.ipynb │ ├── prediction-dashboard.ipynb │ ├── quickstart.ipynb │ ├── rejection_sampling.ipynb │ └── test-mixture-fitting.ipynb └── test-mixture-fitting.ipynb ├── poetry.lock ├── pull_request_template.md ├── pyproject.toml ├── pytest.ini ├── scripts ├── scrub_notebooks.py └── scrub_src.py ├── setup.cfg ├── setup.py └── tests ├── __init__.py ├── conftest.py ├── test_conditions.py ├── test_foretold.py ├── test_jax.py ├── test_logistic.py ├── test_mem.py ├── test_metaculus.py ├── test_point_density.py ├── test_ppl.py ├── test_predictit.py ├── test_rejection.py ├── test_scales.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = ergo/contrib/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .mypy_cache/ 3 | .venv/ 4 | ergo.egg-info/ 5 | 
*__pycache__* 6 | dist/ 7 | conda/ 8 | docs/build 9 | .env 10 | .coverage 11 | .coverage.* 12 | .ipynb_checkpoints 13 | **.ipynb_checkpoints 14 | **.ob-jupyter* 15 | **.prof 16 | .idea -------------------------------------------------------------------------------- /.python-version: -------------------------------------------------------------------------------- 1 | 3.6.9 2 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/conf.py 5 | 6 | build: 7 | image: latest 8 | 9 | formats: all 10 | 11 | python: 12 | version: 3.6 13 | install: 14 | - method: pip 15 | path: . 16 | extra_requirements: 17 | - docs 18 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | sudo: false 2 | language: python 3 | jobs: 4 | include: 5 | - python: 3.6.9 6 | script: make test_skip_metaculus 7 | - python: 3.7.9 8 | script: make test 9 | - python: 3.8.5 10 | script: make test_skip_metaculus 11 | before_install: 12 | - pip install poetry 13 | install: 14 | - poetry install 15 | script: 16 | - make lint 17 | - make docs 18 | after_success: 19 | - poetry run codecov 20 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) 6 | 7 | ## [Unreleased] 8 | 9 | ## [0.8.5] - 2020-07-24 10 | 11 | ### Changed 12 | 13 | - PointDensity.cdf() now does not interpolate and instead returns the cdf to the nearest grid point. 14 | - The default grid density for PoinDensity distributions is now 400 (from 200) 15 | 16 | 17 | ## [0.8.4] - 2020-07-15 18 | 19 | ### Changed 20 | 21 | - All arguments passed to Scales are now `float`s and all Scale fields are assumed to be `float`s 22 | 23 | 24 | 25 | ## [0.8.3] - 2020-07-12 26 | 27 | ### Added 28 | 29 | - `PointDensity()` distribution. We now primarily operate on point densities located in the center of what used to be the bins in our histogram. 30 | - "denorm_xs_only" option on PointDensity egress methods. This returns normalized probability densities but on a denormalized x-axis, as Metaculus displays. 31 | - support for LogScale and Metaculus/Elicit LogScale questions 32 | 33 | ### Changed 34 | 35 | - We operate on 200 point densities at the center of bins evenly spaced from 0 to 1 on a normalized scale. If we are passed in anything besides this in from_pairs we interpolate to get 200 points placed in this manner. 36 | - We always operate on normalized points and normalized densities internally, using the `normed_xs` and `normed_densities` fields respectively. We denormalize the density before returning in PDF. 
37 | - PointDensity.cdf(x) returns the cdf up to x (not the bin before x, as previously) 38 | 39 | ### Removed 40 | - the `Histrogram()` distribution 41 | 42 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @stuhlmueller 2 | /notebooks/ @brachbach 3 | /tests/ @stuhlmueller 4 | /tests/test_scales.py @djthorne 5 | /tests/test_metaculus.py @brachbach 6 | /tests/test_logistic.py @djthorne 7 | /scripts/ @djthorne 8 | /ergo/scale.py @djthorne 9 | /ergo/theme.py @djthorne 10 | /ergo/platforms/ @stuhlmueller 11 | /ergo/platforms/metaculus/ @brachbach 12 | /ergo/distributions/ @stuhlmueller 13 | /ergo/contrib/ @brachbach 14 | /ergo/conditions/ @stuhlmueller 15 | /ergo/conditions/mean.py @uvafan 16 | /ergo/conditions/mode.py @uvafan 17 | /ergo/conditions/variance.py @uvafan 18 | /docs/ @brachbach 19 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | Ergo is an open source project and we love contributions! 2 | 3 | There are [many open issues](https://github.com/oughtinc/ergo/projects/1), including plenty that are [good for newcomers](https://github.com/oughtinc/ergo/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22). 4 | 5 | Before you start implementation, please make a new issue or comment on an existing one to let us know what you're planning to do (so we can avoid duplicated work). You can also ping us at ergo@ought.org. 6 | 7 | ## I want to contribute to notebooks using Ergo 8 | 9 | You can help the Ergo project by: 10 | 11 | 1. Making your own innovative Colab/Jupyter Notebook using Ergo 12 | 2. Improving one of our [existing notebooks](/notebooks) 13 | 14 | If you'd like to improve one of our existing notebooks or submit your own notebook for inclusion in this repo, please follow [these instructions](https://ergo.ought.org/nbs_contrib.html). 15 | 16 | Even if you'd just like to make your own notebook, we'd love to see it! 17 | 18 | You can share it with us at ergo@ought.org. 19 | 20 | ## I want to contribute to Ergo core 21 | 22 | Read about how to get set up and make PRs [here](https://ergo.ought.org/en/latest/core_contrib.html). 23 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Ought 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | all: format lint xtest 2 | 3 | lint: FORCE ## Run isort, flake8, mypy and black (in check mode) 4 | poetry run isort -rc --check-only . 5 | poetry run flake8 6 | poetry run mypy . 7 | poetry run black . --check 8 | 9 | test: FORCE ## Run pytest 10 | poetry run python -m pytest --cov=ergo --ff --verbose -s --doctest-modules . 11 | 12 | test_skip_metaculus: FORCE ## Run pytest, but skip the Metaculus tests to avoid overburdening the Metaculus API 13 | poetry run python -m pytest --cov=ergo --ff --verbose -s --doctest-modules --ignore-glob='*test_metaculus.py' . 14 | 15 | xtest: FORCE ## Run pytest in parallel mode using xdist 16 | poetry run python -m pytest --cov=ergo --ff --verbose -s --doctest-modules -n auto . 17 | 18 | format: FORCE ## Run isort and black (rewriting files) 19 | poetry run isort -rc . 20 | poetry run black . 21 | 22 | docs: FORCE ## Build docs 23 | poetry run $(MAKE) -C docs html 24 | 25 | serve: FORCE ## Run Jupyter notebook server 26 | poetry run python -m jupyter lab 27 | 28 | scrub: FORCE ## Create scrubbed notebooks in notebooks/scrubbed from notebooks 29 | poetry run python scripts/scrub_notebooks.py notebooks notebooks/scrubbed 30 | 31 | scrub_src_only: FORCE ## Scrub notebooks in notebooks/scrubbed (without updating from notebooks) 32 | poetry run python scripts/scrub_src.py notebooks notebooks/scrubbed 33 | 34 | run_nb: FORCE ## scrub and run passed notebook 35 | poetry run python scripts/run_nb.py notebooks notebooks/src $(XFILE) 36 | 37 | .PHONY: help 38 | 39 | .DEFAULT_GOAL := help 40 | 41 | help: 42 | @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 43 | 44 | FORCE: 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/oughtinc/ergo.svg?branch=master)](https://travis-ci.org/oughtinc/ergo) [![Documentation Status](https://readthedocs.org/projects/ergo/badge/?version=latest)](https://ergo.ought.org/) [![Codecov Status](https://codecov.io/gh/oughtinc/ergo/branch/master/graph/badge.svg)](https://codecov.io/gh/oughtinc/ergo) 2 | 3 | # Ergo 4 | 5 | A Python library for integrating model-based and judgmental forecasting 6 | 7 | [Quickstart](#get-started-using-ergo) | [Docs](https://ergo.ought.org) | [Examples](#notebooks-using-ergo) 8 | 9 | **_Note: We're no longer actively developing Ergo._** 10 | 11 | ## Example 12 | 13 | We'll relate three questions on the [Metaculus](https://www.metaculus.com) crowd prediction platform using a generative model: 14 | 15 | ```py 16 | # Log into Metaculus 17 | metaculus = ergo.Metaculus(username="ought", password="") 18 | 19 | # Load three questions 20 | q_infections = metaculus.get_question(3529, name="Covid-19 infections in 2020") 21 | q_deaths = metaculus.get_question(3530, name="Covid-19 deaths in 2020") 22 | q_ratio = metaculus.get_question(3755, name="Covid-19 ratio of fatalities to infections") 23 | 24 | # Relate the 
three questions using a generative model 25 | def deaths_from_infections(): 26 | infections = q_infections.sample_community() 27 | ratio = q_ratio.sample_community() 28 | deaths = infections * ratio 29 | ergo.tag(deaths, "Covid-19 deaths in 2020") 30 | return deaths 31 | 32 | # Compute model predictions for the `deaths` question 33 | samples = ergo.run(deaths_from_infections, num_samples=5000) 34 | 35 | # Submit model predictions to Metaculus 36 | q_deaths.submit_from_samples(samples) 37 | ``` 38 | 39 | You can run the model [here](https://colab.research.google.com/github/oughtinc/ergo/blob/master/notebooks/community-distributions.ipynb). 40 | 41 | ## Get started using Ergo 42 | 43 | 1. Open [this Colab](https://colab.research.google.com/github/oughtinc/ergo/blob/master/notebooks/quickstart.ipynb) 44 | 2. Select "Runtime > Run all" in the menu 45 | 3. Edit the code to load other questions, improve the model, etc., and rerun 46 | 47 | ## Notebooks using Ergo 48 | 49 | This notebook is closest to a tutorial right now: 50 | 51 | - [El Paso workflow](notebooks/el-paso-workflow.ipynb) 52 | - This notebook shows multi-level decomposition, Metaculus community distributions, ensembling, and beta-binomial and log-normal distributions using part of the [El Paso Covid-19 model](notebooks/el-paso.ipynb). 53 | 54 | The notebooks below have been created at different points in time and use Ergo in inconsistent ways. Most are rough scratchpads of work-in-progress and haven't been cleaned up for public consumption: 55 | 56 | 1. [Relating Metaculus community distributions: Infections, Deaths, and IFR](notebooks/community-distributions.ipynb) 57 | 58 | - A notebook for the model shown above that uses a model to update Metaculus community distributions towards consistency 59 | 60 | 2. [Model-based predictions of Covid-19 spread](notebooks/covid-19-metaculus.ipynb) 61 | 62 | - End-to-end example: 63 | 1. Load multiple questions from Metaculus 64 | 2. Compute model predictions based on assumptions and external data 65 | 3. Submit predictions to Metaculus 66 | 67 | 3. [Model-based predictions of Covid-19 spread using inference from observed cases](notebooks/covid-19-inference.ipynb) 68 | 69 | - A version of the previous notebook that infers growth rates before and after lockdown decisions 70 | 71 | 4. [Metaculus questions data](notebooks/metac_qs_data.ipynb) 72 | 73 | - Get rich metadata on open Metaculus questions 74 | 75 | 5. [Prediction dashboard](notebooks/prediction-dashboard.ipynb) 76 | 77 | - Show Metaculus prediction results as a dataframe 78 | - Filter Metaculus questions by date and status. 79 | 80 | 6. [El Paso questions](notebooks/el-paso.ipynb) 81 | - Illustrates how to load all questions for a Metaculus category (in this case for the [El Paso series](https://pandemic.metaculus.com/questions/4161/el-paso-series-supporting-covid-19-response-planning-in-a-mid-sized-city/)) 82 | 83 | Outdated Ergo notebooks: 84 | 85 | 1. [Generative models in Ergo](notebooks/generative-models.ipynb) 86 | 87 | 2. [Predicting how long lockdowns will last in multiple locations](notebooks/covid-19-lockdowns.ipynb) 88 | 89 | 3. [Estimating the number of active Covid-19 infections in each country using multiple sources](notebooks/covid-19-active.ipynb) 90 | 91 | 4. [How long will the average American spend under lockdown?](notebooks/covid-19-average-lockdown.ipynb) 92 | 93 | 5. 
[Assorted COVID predictions](notebooks/assorted-predictions.ipynb) 94 | 95 | ## Local installation 96 | 97 | To install Ergo and its dependencies, we recommend PyEnv and Poetry: 98 | 99 | 1. Install [PyEnv](https://github.com/pyenv/pyenv-installer) for managing Python versions 100 | 2. Install the [Poetry](https://python-poetry.org/docs/) package manager 101 | 102 | Then: 103 | 104 | ``` 105 | mkdir my-ergo-project && cd my-ergo-project 106 | pyenv install 3.6.9 && pyenv local 3.6.9 107 | poetry init -n 108 | # Edit pyproject.toml to set python = "~3.6.9" 109 | poetry add git+https://github.com/oughtinc/ergo.git 110 | poetry install 111 | ``` 112 | 113 | Now Ergo is available in your project: 114 | 115 | ``` 116 | poetry run python 117 | >>> import ergo 118 | >>> ergo.flip(.5) 119 | DeviceArray(True, dtype=bool) 120 | ``` 121 | 122 | ## Contribute 123 | 124 | Ergo is an open source project and we love contributions! 125 | 126 | See our [instructions for contributors](CONTRIBUTING.md) for more. 127 | 128 | ## Philosophy 129 | 130 | The theory behind Ergo: 131 | 132 | 1. Many of the pieces necessary for good forecasting work are out there: 133 | - Prediction platforms 134 | - Probabilistic programming languages 135 | - Superforecasters + qualitative human judgments 136 | - Data science tools like numpy and pandas 137 | - Deep neural nets as expressive function approximators 138 | 2. But they haven't been connected yet in a productive workflow: 139 | - It's difficult to get data in and out of prediction platforms 140 | - Submitting questions to these platforms takes a long time 141 | - The questions on prediction platforms aren't connected to decisions, or even to other questions on the same platform 142 | - Human judgments don't scale 143 | - Models often can't take into account all relevant considerations 144 | - Workflows aren't made explicit so they can't be automated 145 | 3. This limits their potential: 146 | - Few people build models 147 | - Few people submit questions to prediction platforms, or predict on these platforms 148 | - Improvements to forecasting accrue slowly 149 | - Most decisions are not informed by systematic forecasts 150 | 4. Better infrastructure for forecasting can connect the pieces and help realize the potential of scalable high-quality forecasting 151 | 152 | ## Functionality 153 | 154 | Ergo is still at an early stage. Pre-alpha, or whatever the earliest possible stage is. Functionality and API are in flux. 155 | 156 | Here's what Ergo provides right now: 157 | 158 | - Express generative models in a probabilistic programming language 159 | - Ergo provides lightweight wrappers around [Pyro](https://pyro.ai) functions to make the models more readable 160 | - Specify distributions using 90% confidence intervals, e.g. 
`ergo.lognormal_from_interval(10, 100)` 161 | - For Bayesian inference, Ergo provides a wrapper around Pyro's variational inference algorithm 162 | - Get model results as Pandas dataframes 163 | - Interact with the [Metaculus](https://www.metaculus.com/) and [Foretold](https://www.foretold.io/) prediction platforms 164 | - Load question data given question ids 165 | - Use community distributions as variables in generative models 166 | - Submit model predictions to these platforms 167 | - For Metaculus, we automatically fit a mixture of logistic distributions for continuous-valued questions 168 | - Plot community distributions 169 | 170 | [WIP](https://github.com/oughtinc/ergo/projects/1): 171 | 172 | - Documentation 173 | - Clearer modeling API 174 | 175 | Planned: 176 | 177 | - Interfaces for all prediction platforms 178 | - Search questions on prediction platforms 179 | - Use distributions from any platform 180 | - Programmatically submit questions to platforms 181 | - Track community distribution changes 182 | - Common model components 183 | - Index/ensemble models that summarize fuzzy large questions like "What's going to happen with the economy next year?" 184 | - Model components for integrating qualitative adjustments into quantitative models 185 | - Simple probability decomposition models 186 | - E.g. see [The Model Thinker](https://www.amazon.com/Model-Thinker-What-Need-Know/dp/0465094627) (Scott Page) 187 | - Better tools for integrating models and platforms 188 | - Compute model-based predictions by constraining model variables to be close to the community distributions 189 | - Push/pull to and from repositories for generative models 190 | - Think [Forest](http://forestdb.org) + Github 191 | 192 | If there's something you want Ergo to do, [let us know](https://github.com/oughtinc/ergo/issues)! 193 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = ergo 8 | SOURCEDIR = "." 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright Contributors to the Pyro project. 2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | import os 5 | import sys 6 | 7 | import sphinx_rtd_theme 8 | 9 | 10 | # import pkg_resources 11 | 12 | # -*- coding: utf-8 -*- 13 | # 14 | # Configuration file for the Sphinx documentation builder. 15 | # 16 | # This file does only contain a selection of the most common options. 
For a 17 | # full list see the documentation: 18 | # http://www.sphinx-doc.org/en/master/config 19 | 20 | # -- Path setup -------------------------------------------------------------- 21 | 22 | # If extensions (or modules to document with autodoc) are in another directory, 23 | # add these directories to sys.path here. If the directory is relative to the 24 | # documentation root, use os.path.abspath to make it absolute, like shown here. 25 | # 26 | sys.path.insert(0, os.path.abspath("../..")) 27 | 28 | 29 | os.environ["SPHINX_BUILD"] = "1" 30 | 31 | 32 | # -- Project information ----------------------------------------------------- 33 | 34 | project = u"Ergo" 35 | copyright = u"2020, Ought, Inc" 36 | author = u"Ergo contributors" 37 | 38 | version = "" 39 | 40 | if "READTHEDOCS" not in os.environ: 41 | # if developing locally, use ergo.__version__ as version 42 | from ergo import __version__ # noqaE402 43 | 44 | version = __version__ 45 | 46 | # release version 47 | release = version 48 | 49 | 50 | # -- General configuration --------------------------------------------------- 51 | 52 | # If your documentation needs a minimal Sphinx version, state it here. 53 | # 54 | # needs_sphinx = '1.0' 55 | 56 | # Add any Sphinx extension module names here, as strings. They can be 57 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 58 | # ones. 59 | extensions = [ 60 | "sphinx.ext.autodoc", 61 | "sphinx.ext.doctest", 62 | "sphinx.ext.intersphinx", 63 | "sphinx.ext.mathjax", 64 | "sphinx.ext.viewcode", 65 | "sphinx_autodoc_typehints", 66 | ] 67 | 68 | # Disable documentation inheritance so as to avoid inheriting 69 | # docstrings in a different format, e.g. when the parent class 70 | # is a PyTorch class. 71 | 72 | autodoc_inherit_docstrings = False 73 | 74 | # autodoc_default_options = { 75 | # 'member-order': 'bysource', 76 | # 'show-inheritance': True, 77 | # 'special-members': True, 78 | # 'undoc-members': True, 79 | # 'exclude-members': '__dict__,__module__,__weakref__', 80 | # } 81 | 82 | # Add any paths that contain templates here, relative to this directory. 83 | templates_path = ["_templates"] 84 | 85 | # The suffix(es) of source filenames. 86 | # You can specify multiple suffix as a list of string: 87 | # 88 | # source_suffix = ['.rst', '.md'] 89 | source_suffix = ".rst" 90 | 91 | # The master toctree document. 92 | master_doc = "index" 93 | 94 | # The language for content autogenerated by Sphinx. Refer to documentation 95 | # for a list of supported languages. 96 | # 97 | # This is also used if you do content translation via gettext catalogs. 98 | # Usually you set "language" from the command line for these cases. 99 | language = None 100 | 101 | # List of patterns, relative to source directory, that match files and 102 | # directories to ignore when looking for source files. 103 | # This pattern also affects html_static_path and html_extra_path . 104 | exclude_patterns = [] 105 | 106 | # The name of the Pygments (syntax highlighting) style to use. 107 | pygments_style = "sphinx" 108 | 109 | 110 | # do not prepend module name to functions 111 | add_module_names = False 112 | 113 | # -- Options for HTML output ------------------------------------------------- 114 | 115 | # The theme to use for HTML and HTML Help pages. See the documentation for 116 | # a list of builtin themes. 
117 | # 118 | html_theme = "sphinx_rtd_theme" 119 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 120 | 121 | # Theme options are theme-specific and customize the look and feel of a theme 122 | # further. For a list of options available for each theme, see the 123 | # documentation. 124 | # 125 | # html_theme_options = {} 126 | 127 | # Add any paths that contain custom static files (such as style sheets) here, 128 | # relative to this directory. They are copied after the builtin static files, 129 | # so a file named "default.css" will overwrite the builtin "default.css". 130 | html_static_path = [] 131 | 132 | # Custom sidebar templates, must be a dictionary that maps document names 133 | # to template names. 134 | # 135 | # The default sidebars (for documents that don't match any pattern) are 136 | # defined by theme itself. Builtin themes are using these templates by 137 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 138 | # 'searchbox.html']``. 139 | # 140 | # html_sidebars = {} 141 | 142 | 143 | # -- Options for HTMLHelp output --------------------------------------------- 144 | 145 | # Output file base name for HTML help builder. 146 | htmlhelp_basename = "ergodoc" 147 | 148 | 149 | # -- Options for LaTeX output ------------------------------------------------ 150 | 151 | latex_elements = { 152 | # The paper size ('letterpaper' or 'a4paper'). 153 | # 154 | # 'papersize': 'letterpaper', 155 | # The font size ('10pt', '11pt' or '12pt'). 156 | # 157 | # 'pointsize': '10pt', 158 | # Additional stuff for the LaTeX preamble. 159 | # 160 | # 'preamble': '', 161 | # Latex figure (float) alignment 162 | # 163 | # 'figure_align': 'htbp', 164 | } 165 | 166 | # Grouping the document tree into LaTeX files. List of tuples 167 | # (source start file, target name, title, 168 | # author, documentclass [howto, manual, or own class]). 169 | latex_documents = [ 170 | (master_doc, "Ergo.tex", u"Ergo Documentation", u"Ergo contributors", "manual"), 171 | ] 172 | 173 | # -- Options for manual page output ------------------------------------------ 174 | 175 | # One entry per manual page. List of tuples 176 | # (source start file, name, description, authors, manual section). 177 | man_pages = [(master_doc, "Ergo", u"Ergo Documentation", [author], 1)] 178 | 179 | # -- Options for Texinfo output ---------------------------------------------- 180 | 181 | # Grouping the document tree into Texinfo files. 
List of tuples 182 | # (source start file, target name, title, author, 183 | # dir menu entry, description, category) 184 | texinfo_documents = [ 185 | ( 186 | master_doc, 187 | "Ergo", 188 | u"Ergo Documentation", 189 | author, 190 | "Ergo", 191 | "Integrating judgmental and model-based forecasting", 192 | "Miscellaneous", 193 | ), 194 | ] 195 | 196 | 197 | # -- Extension configuration ------------------------------------------------- 198 | 199 | # -- Options for intersphinx extension --------------------------------------- 200 | 201 | intersphinx_mapping = { 202 | "python": ("https://docs.python.org/3/", None), 203 | "numpy": ("http://docs.scipy.org/doc/numpy/", None), 204 | "jax": ("https://jax.readthedocs.io/en/latest/", None), 205 | "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), 206 | } 207 | -------------------------------------------------------------------------------- /docs/contrib_folder.rst: -------------------------------------------------------------------------------- 1 | Notebook contrib folder 2 | ======================= 3 | 4 | Adding new packages 5 | ------------------- 6 | 7 | For modules providing functionality specific to the questions 8 | addressed in a notebook, create a new package in contrib 9 | ``/ergo/contrib/{your_package}`` and include an ``__init__.py`` 10 | file. You can then access it in your notebook with: 11 | 12 | .. code-block:: python 13 | 14 | from ergo.contrib.{your_package} import {module_you_want} 15 | 16 | For modules providing more general functionality of use across 17 | notebooks (and perhaps a candidate for inclusion in core ergo), you 18 | can use ``/ergo/contrib/utils``. You can either add a new module or 19 | extend an existing one. You can then access it with: 20 | 21 | .. code-block:: python 22 | 23 | from ergo.contrib.utils import {module_you_want} 24 | 25 | Adding dependencies 26 | ------------------- 27 | 28 | 1. Usual poetry way with --optional flag 29 | 30 | .. code-block:: bash 31 | 32 | poetry add {pendulum} --optional 33 | 34 | 2. You can then (manually in the ``pyproject.toml``) add it to the 35 | 'notebook' group 36 | 37 | (Look for "extras" in ``pyproject.toml``) 38 | 39 | .. code-block:: toml 40 | 41 | [tool.poetry.extras] 42 | notebooks = [ 43 | "pendulum", 44 | "scikit-learn", 45 | "{your_dependency}" 46 | ] 47 | 48 | 49 | (To my knowledge) there is no way currently to do this second step 50 | with the CLI. 51 | 52 | This allows people to then install the additional 53 | notebook dependencies with: 
 54 | .. code-block:: bash 55 | 56 | poetry install -E notebooks 57 | -------------------------------------------------------------------------------- /docs/core_contrib.rst: -------------------------------------------------------------------------------- 1 | Contribute to Ergo core 2 | ======================= 3 | 4 | To get started: 5 | 6 | 1. ``git clone https://github.com/oughtinc/ergo.git`` 7 | 2. ``poetry install`` 8 | 3. ``poetry shell`` 9 | 10 | ``poetry`` 11 | ---------- 12 | Ergo uses poetry to manage its dependencies and environments. 13 | 14 | Follow these directions_ to install poetry if you don't already have it. 15 | 16 | Troubleshooting: If you get ``Could not find a version that satisfies the requirement jaxlib ...`` after using poetry to install, this is probably because your virtual environment has old version of pip due to how poetry choses pip versions_. 17 | 18 | Try: 19 | 20 | 1. ``poetry run pip install -U pip`` 21 | 2. ``poetry install`` again 22 | 23 | .. _directions: https://python-poetry.org/docs/#installation 24 | .. _versions: https://github.com/python-poetry/poetry/issues/732 25 | 26 | Before submitting a PR 27 | ---------------------- 28 | 29 | 1. Run ``poetry install`` to make sure you have the latest dependencies 30 | 2. Format code using ``make format`` (black, isort) 31 | 3. Run linting using ``make lint`` (flake8, mypy, black check) 32 | 4. Run tests using ``make test`` 33 | 34 | * To run the tests in ``test_metaculus.py``, you'll need our secret `.env` file_. 35 | If you don't have it, you can ask us for it, or rely on Travis CI to run those tests for you. 36 | 37 | 5. Generate docs using ``make docs``, load 38 | ``docs/build/html/index.html`` and review the generated docs 39 | 6. Or run all of the above using ``make all`` 40 | 41 | .. _file: https://docs.google.com/document/d/1_r_DrCumtO3oKaG2BryyzanexWPiwgtrcx9fxiNBgD4/edit 42 | 43 | Conventions 44 | ----------- 45 | 46 | Import ``numpy`` as follows: 47 | 48 | 49 | .. code-block:: python 50 | 51 | import jax.numpy as np 52 | import numpy as onp 53 | 54 | -------------------------------------------------------------------------------- /docs/distributions.rst: -------------------------------------------------------------------------------- 1 | Distributions 2 | ============= 3 | 4 | normal 5 | ------ 6 | .. autofunction:: ergo.distributions.base.normal 7 | 8 | normal_from_interval 9 | -------------------- 10 | .. autofunction:: ergo.distributions.base.normal_from_interval 11 | 12 | lognormal 13 | --------- 14 | .. autofunction:: ergo.distributions.base.lognormal 15 | 16 | lognormal_from_interval 17 | ----------------------- 18 | .. autofunction:: ergo.distributions.base.lognormal_from_interval 19 | 20 | uniform 21 | ------- 22 | .. autofunction:: ergo.distributions.base.uniform 23 | 24 | beta 25 | ---- 26 | .. autofunction:: ergo.distributions.base.beta 27 | 28 | beta_from_hits 29 | -------------- 30 | .. autofunction:: ergo.distributions.base.beta_from_hits 31 | 32 | categorical 33 | ----------- 34 | .. autofunction:: ergo.distributions.base.categorical 35 | 36 | halfnormal 37 | ------------------------ 38 | .. autofunction:: ergo.distributions.base.halfnormal 39 | 40 | halfnormal_from_interval 41 | ------------------------ 42 | .. autofunction:: ergo.distributions.base.halfnormal_from_interval 43 | 44 | random_choice 45 | ------------- 46 | .. autofunction:: ergo.distributions.base.random_choice 47 | 48 | random_integer 49 | -------------- 50 | .. 
autofunction:: ergo.distributions.base.random_integer 51 | 52 | flip 53 | ---- 54 | .. autofunction:: ergo.distributions.base.flip 55 | -------------------------------------------------------------------------------- /docs/foretold.rst: -------------------------------------------------------------------------------- 1 | Foretold 2 | ======== 3 | 4 | Foretold 5 | --------- 6 | .. autoclass:: ergo.platforms.foretold.Foretold 7 | :members: 8 | 9 | ForetoldQuestion 10 | ----------------- 11 | .. autoclass:: ergo.platforms.foretold.ForetoldQuestion 12 | :members: 13 | -------------------------------------------------------------------------------- /docs/gettingstarted.rst: -------------------------------------------------------------------------------- 1 | Getting Started 2 | =============== 3 | 4 | 1. To get started with a template to work from, load this Colab notebook_. 5 | 2. For more information about ergo, see the README_. 6 | 3. See the sections below to learn more about using ergo. 7 | 4. To learn about contributing, read our CONTRIBUTING.md_. 8 | 9 | .. _README: https://github.com/oughtinc/ergo/blob/master/README.md 10 | .. _notebook: https://colab.research.google.com/github/oughtinc/ergo/blob/master/notebooks/quickstart.ipynb 11 | .. _CONTRIBUTING.md: https://github.com/oughtinc/ergo/blob/master/CONTRIBUTING.md -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/oughtinc/ergo 2 | 3 | 4 | Ergo documentation 5 | ================== 6 | 7 | `Ergo `_ is a Python library for integrating model-based and judgmental forecasting. 8 | 9 | .. toctree:: 10 | :glob: 11 | :maxdepth: 2 12 | :caption: Usage 13 | 14 | gettingstarted 15 | 16 | .. toctree:: 17 | :glob: 18 | :maxdepth: 2 19 | :caption: Prediction Platforms 20 | 21 | metaculus 22 | foretold 23 | predictit 24 | 25 | .. toctree:: 26 | :glob: 27 | :maxdepth: 2 28 | :caption: Models 29 | 30 | inference 31 | distributions 32 | 33 | .. toctree:: 34 | :glob: 35 | :maxdepth: 2 36 | :caption: Contribute to ergo core 37 | 38 | core_contrib 39 | 40 | .. toctree:: 41 | :glob: 42 | :maxdepth: 2 43 | :caption: Contribute to ergo notebooks 44 | 45 | nbs_contrib 46 | jupyter_colab 47 | notebook_style 48 | contrib_folder 49 | 50 | .. toctree:: 51 | :glob: 52 | :maxdepth: 2 53 | :caption: Tips 54 | 55 | tips_google_sheets 56 | -------------------------------------------------------------------------------- /docs/inference.rst: -------------------------------------------------------------------------------- 1 | Inference 2 | ========= 3 | 4 | tag 5 | --- 6 | .. autofunction:: ergo.ppl.tag 7 | 8 | run 9 | --- 10 | .. autofunction:: ergo.ppl.run 11 | 12 | 13 | -------------------------------------------------------------------------------- /docs/jupyter_colab.rst: -------------------------------------------------------------------------------- 1 | Run a notebook in Colab or JupyterLab 2 | ===================================== 3 | 4 | Colab 5 | ----- 6 | 1. Go to https://colab.research.google.com/: 7 | 8 | 1. click "GitHub" on the "new notebook" dialog, then enter the notebook URL. Or: 9 | 2. go to "Upload" and upload the notebooks ``ipynb`` file. Or: 10 | 11 | 2. Install and use the Open in Colab Chrome extension_ 12 | 13 | JupyterLab 14 | ---------- 15 | 1. ``git clone https://github.com/oughtinc/ergo.git`` 16 | 2. ``poetry install`` 17 | 3. ``poetry shell`` 18 | 4. ``jupyter lab`` 19 | 20 | .. 
_extension: https://chrome.google.com/webstore/detail/open-in-colab/iogfkhleblhcpcekbiedikdehleodpjo?hl=en -------------------------------------------------------------------------------- /docs/metaculus.rst: -------------------------------------------------------------------------------- 1 | Metaculus 2 | ========= 3 | .. automodule:: ergo.platforms.metaculus 4 | 5 | Metaculus 6 | --------- 7 | .. autoclass:: ergo.platforms.metaculus.Metaculus 8 | :members: get_question, get_questions 9 | 10 | 11 | MetaculusQuestion 12 | ----------------- 13 | .. autoclass:: ergo.platforms.metaculus.question.MetaculusQuestion 14 | :members: 15 | 16 | ContinuousQuestion 17 | ------------------ 18 | .. autoclass:: ergo.platforms.metaculus.question.ContinuousQuestion 19 | :members: 20 | 21 | LinearQuestion 22 | -------------- 23 | .. autoclass:: ergo.platforms.metaculus.question.LinearQuestion 24 | :members: 25 | 26 | LogQuestion 27 | ----------- 28 | .. autoclass:: ergo.platforms.metaculus.question.LogQuestion 29 | :members: 30 | 31 | LinearDateQuestion 32 | ------------------ 33 | .. autoclass:: ergo.platforms.metaculus.question.LinearDateQuestion 34 | :members: 35 | 36 | BinaryQuestion 37 | -------------- 38 | .. autoclass:: ergo.platforms.metaculus.question.BinaryQuestion 39 | :members: 40 | -------------------------------------------------------------------------------- /docs/nbs_contrib.rst: -------------------------------------------------------------------------------- 1 | Contribute to Ergo notebooks 2 | ============================ 3 | 4 | How to change a notebook and make a PR 5 | -------------------------------------- 6 | 1. Open the `notebook`_ in JupyterLab or Colab (:doc:`jupyter_colab`) 7 | 2. Make your changes 8 | 3. Follow our :doc:`notebook_style` 9 | 4. Run the notebook in Colab. Save the .ipynb file (with output) in ``ergo/notebooks`` 10 | 5. Run `make scrub`. This will produce a scrubbed version of the notebook in `ergo/notebooks/scrubbed`/. 11 | 12 | 1. You can `git diff` the scrubbed version against the previous scrubbed version 13 | to more easily see what you changed 14 | 15 | 2. You may want to use nbdime_ for better diffing 16 | 17 | 6. You can now make a PR with your changes. If you make a PR in the original ergo repo 18 | (not a fork), you can then use the auto-comment from ReviewNB to more thoroughly vet your changes 19 | 20 | .. _notebook: https://github.com/oughtinc/ergo/tree/master/notebooks 21 | .. _nbdime: https://nbdime.readthedocs.io/en/latest/ 22 | -------------------------------------------------------------------------------- /docs/notebook_style.rst: -------------------------------------------------------------------------------- 1 | Notebook Style 2 | ============== 3 | 4 | How to clean up a notebook for us to feature: 5 | 6 | 1. Make sure that the notebook meets a high standard in general: 7 | 8 | 1. high-quality code 9 | 2. illuminating data analysis 10 | 3. clear communication of what you’re doing and your findings 11 | 4. as short as possible, but no shorter 12 | 5. this `random style guide`_ I found in a few minutes of Googling 13 | seems good, but it’s not our official style guide or anything 14 | 15 | 16 | 2. Do the following specific things to clean up: 17 | 18 | 1. as much as possible, avoid showing extraneous output from cells 19 | 20 | 1. you can use the ``%%capture`` magic to suppress all output 21 | from a cell (helpful if a function in the cell prints 22 | something) 23 | 2. 
you can add a ``;`` at the end of the last line in a cell to 24 | suppress printing the return value of the line 25 | 3. think about what cells the reader really needs to see 26 | vs. which ones just have to be there for setup or 27 | whatnot. Collapse the latter. 28 | 29 | 3. use the latest version of ``ergo`` 30 | 4. make sure that any secrets like passwords are removed from the 31 | notebook 32 | 5. Pull out any code not central to the main point of the model 33 | into a module in ``ergo/contrib/``. See :doc:`contrib_folder` for 34 | details. 35 | 36 | The featured notebooks in our README should be exemplars of the 37 | above, so refer to those to see what this looks like in practice. 38 | 39 | .. _random style guide: https://github.com/spacetelescope/style-guides/blob/master/guides/jupyter-notebooks.md 40 | .. _El Paso COVID predictions notebook: https://github.com/oughtinc/ergo/blob/master/notebooks/el-paso.ipynb 41 | -------------------------------------------------------------------------------- /docs/predictit.rst: -------------------------------------------------------------------------------- 1 | PredictIt 2 | ========= 3 | .. automodule:: ergo.platforms.predictit 4 | 5 | PredictIt 6 | --------- 7 | .. autoclass:: ergo.platforms.predictit.PredictIt 8 | :members: 9 | 10 | PredictItMarket 11 | ----------------- 12 | .. autoclass:: ergo.platforms.predictit.PredictItMarket 13 | :members: 14 | 15 | PredictItQuestion 16 | ----------------- 17 | .. autoclass:: ergo.platforms.predictit.PredictItQuestion 18 | :members: 19 | -------------------------------------------------------------------------------- /docs/tips_google_sheets.rst: -------------------------------------------------------------------------------- 1 | Loading data from Google Sheets 2 | =============================== 3 | Three methods for loading data from google sheets into a Colab Notebook 4 | 5 | 6 | Method 1 (Public CSV) 7 | --------------------- 8 | If you're willing to make your spreadsheet public, you can publish it as a CSV file on Google Sheets. Go to File > Publish to the Web, and select the CSV format. Then you can copy the published url, and load it in python using pandas. 9 | 10 | .. code-block:: python 11 | 12 | import pandas as pd 13 | df = pd.read_csv(url) 14 | 15 | Method 2 (OAuth) 16 | ---------------- 17 | This method requires the user of the colab to authorize it every time the colab runs, but can work with non-public sheets 18 | 19 | .. code-block:: python 20 | 21 | # Authentication 22 | import google 23 | google.colab.auth.authenticate_user() 24 | google_sheets_credentials = GoogleCredentials.get_application_default() 25 | gc = gspread.authorize(google_sheets_credentials) 26 | 27 | # Load spreadsheet 28 | wb = gc.open_by_url(url) 29 | sheet = wb.worksheet(sheet) 30 | values = sheet.get_all_values() 31 | 32 | Method 3 (Service Account) 33 | -------------------------- 34 | This method requires your to follow the instructions at https://gspread.readthedocs.io/en/latest/oauth2.html to create a google service account. You then need to share the google sheet with the service account email address. 35 | 36 | .. 
code-block:: 37 | 38 | # Need a newer version of gspread than included by default in Colab 39 | !pip install --upgrade gspread 40 | 41 | service_account_info = {} #JSON for google service account 42 | import gspread 43 | from google.oauth2.service_account import Credentials 44 | 45 | scope = ['https://spreadsheets.google.com/feeds', 46 | 'https://www.googleapis.com/auth/drive'] 47 | 48 | credentials = Credentials.from_service_account_info(service_account_info, scopes=scope) 49 | 50 | gc = gspread.authorize(credentials) 51 | 52 | # Load spreadsheet 53 | wb = gc.open_by_url(url) 54 | sheet = wb.worksheet(sheet) 55 | values = sheet.get_all_values() 56 | -------------------------------------------------------------------------------- /ergo/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.8.4" 2 | 3 | import ergo.conditions 4 | import ergo.distributions 5 | import ergo.platforms 6 | import ergo.ppl 7 | import ergo.scale 8 | import ergo.static 9 | import ergo.theme 10 | import ergo.utils 11 | 12 | from .distributions import ( 13 | BetaFromHits, 14 | Logistic, 15 | LogisticMixture, 16 | LogNormalFromInterval, 17 | NormalFromInterval, 18 | PointDensity, 19 | Truncate, 20 | bernoulli, 21 | beta, 22 | beta_from_hits, 23 | categorical, 24 | flip, 25 | halfnormal_from_interval, 26 | lognormal, 27 | lognormal_from_interval, 28 | normal, 29 | normal_from_interval, 30 | random_choice, 31 | random_integer, 32 | uniform, 33 | ) 34 | from .platforms import ( 35 | Foretold, 36 | ForetoldQuestion, 37 | Metaculus, 38 | MetaculusQuestion, 39 | PredictIt, 40 | PredictItMarket, 41 | PredictItQuestion, 42 | ) 43 | from .ppl import condition, mem, run, sample, tag 44 | from .utils import to_float 45 | -------------------------------------------------------------------------------- /ergo/conditions/__init__.py: -------------------------------------------------------------------------------- 1 | from .condition import Condition 2 | from .crossentropy import CrossEntropyCondition, PartialCrossEntropyCondition 3 | from .interval import IntervalCondition 4 | from .maxentropy import MaxEntropyCondition 5 | from .mean import MeanCondition 6 | from .mode import ModeCondition 7 | from .point_density import PointDensityCondition 8 | from .smoothness import SmoothnessCondition 9 | from .variance import VarianceCondition 10 | from .wasserstein import WassersteinCondition 11 | -------------------------------------------------------------------------------- /ergo/conditions/condition.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, Sequence, Tuple 3 | 4 | import jax.numpy as np 5 | import numpy as onp 6 | 7 | from ergo.scale import Scale 8 | import ergo.static as static 9 | 10 | 11 | def static_value(v): 12 | if isinstance(v, np.DeviceArray) or isinstance(v, onp.ndarray): 13 | return tuple(v) 14 | elif isinstance(v, tuple): 15 | return tuple(static_value(element) for element in v) 16 | else: 17 | return v 18 | 19 | 20 | class Condition(ABC): 21 | weight: float = 1.0 22 | 23 | def __init__(self, weight=1.0): 24 | self.weight = weight 25 | 26 | def __hash__(self): 27 | return hash(self.__key()) 28 | 29 | def __eq__(self, other): 30 | if isinstance(other, Condition): 31 | return self.__key() == other.__key() 32 | return NotImplemented 33 | 34 | def __key(self): 35 | cls, params = self.destructure() 36 | return (cls, tuple(static_value(param) for param in 
params)) 37 | 38 | def _describe_fit(self, dist) -> Dict[str, Any]: 39 | # convert to float for easy serialization 40 | return {"loss": self.loss(dist)} 41 | 42 | def normalize(self, scale: Scale): 43 | """ 44 | Assume that the condition's true range is scale. 45 | Return the normalized condition. 46 | 47 | :param scale: the true-scale 48 | :return: the condition normalized to [0,1] 49 | """ 50 | return self 51 | 52 | def denormalize(self, scale: Scale): 53 | """ 54 | Assume that the condition has been normalized to be over [0,1]. 55 | Return the condition on the true scale. 56 | 57 | :param scale: the true-scale 58 | :return: the condition on the true scale of [low, high] 59 | """ 60 | return self 61 | 62 | def describe_fit(self, dist) -> Dict[str, float]: 63 | """ 64 | Describe how well the distribution meets the condition 65 | 66 | :param dist: A probability distribution 67 | :return: A description of various aspects of how well 68 | the distribution meets the condition 69 | """ 70 | 71 | result = static.describe_fit(*dist.destructure(), *self.destructure()) 72 | return {k: float(v) for (k, v) in result.items()} 73 | 74 | @abstractmethod 75 | def loss(self, dist): 76 | """ 77 | Loss function for this condition when fitting a distribution. 78 | 79 | Should have max loss = 1 without considering weight 80 | Should multiply loss * weight 81 | 82 | :param dist: A probability distribution 83 | """ 84 | 85 | def shape_key(self): 86 | return (self.__class__.__name__,) 87 | 88 | @abstractmethod 89 | def destructure(self) -> Tuple["Condition", Sequence[Any]]: 90 | ... 91 | 92 | @classmethod 93 | def structure(cls, params) -> "Condition": 94 | class_params, numeric_params = params 95 | return cls(*numeric_params) 96 | -------------------------------------------------------------------------------- /ergo/conditions/crossentropy.py: -------------------------------------------------------------------------------- 1 | from jax import vmap 2 | import jax.numpy as np 3 | 4 | from ergo.distributions import point_density 5 | from ergo.scale import Scale 6 | 7 | from . 
import condition 8 | 9 | # TODO: Implement normalize/denormalize for CrossEntropyCondition 10 | 11 | 12 | class CrossEntropyCondition(condition.Condition): 13 | p_dist: point_density.PointDensity 14 | weight: float = 1.0 15 | 16 | def __init__(self, p_dist, weight=1.0): 17 | self.p_dist = p_dist 18 | super().__init__(weight) 19 | 20 | def loss(self, q_dist) -> float: 21 | return self.weight * self.p_dist.cross_entropy(q_dist) 22 | 23 | def destructure(self): 24 | dist_classes, dist_numeric = self.p_dist.destructure() 25 | cond_numeric = (self.weight,) 26 | return ((CrossEntropyCondition, dist_classes), (cond_numeric, dist_numeric)) 27 | 28 | @classmethod 29 | def structure(cls, params): 30 | class_params, numeric_params = params 31 | cond_class, dist_classes = class_params 32 | cond_numeric, dist_numeric = numeric_params 33 | dist_params = (dist_classes, dist_numeric) 34 | dist = dist_classes[0].structure(dist_params) 35 | return cls(dist, cond_numeric[0]) 36 | 37 | def __str__(self): 38 | return "Minimize the cross-entropy of the two distributions" 39 | 40 | def __repr__(self): 41 | return f"CrossEntropyCondition(p_dist={self.p_dist.normed_xs.size}, weight={self.weight})" 42 | 43 | 44 | class PartialCrossEntropyCondition(condition.Condition): 45 | """ 46 | Unlike CrossEntropyCondition, it's fine for (xs, ps) to 47 | only describe part of a distribution 48 | """ 49 | 50 | xs: np.DeviceArray 51 | ps: np.DeviceArray 52 | weight: float = 1.0 53 | 54 | def __init__(self, xs, ps, weight): 55 | self.xs = xs 56 | self.ps = ps 57 | super().__init__(weight) 58 | 59 | def loss(self, q_dist) -> float: 60 | q_logps = vmap(q_dist.logpdf)(self.xs) 61 | cross_entropy = -np.dot(self.ps, q_logps) 62 | return self.weight * cross_entropy 63 | 64 | def destructure(self): 65 | return (PartialCrossEntropyCondition, (self.xs, self.ps, self.weight)) 66 | 67 | @classmethod 68 | def structure(cls, params): 69 | return cls(params[0], params[1], params[2]) 70 | 71 | def __str__(self): 72 | return "Minimize the cross-entropy of the two distributions (p may be partial)" 73 | 74 | def normalize(self, scale: Scale): 75 | # TODO: Vectorization should be part of what Scale does 76 | # scale.normalize_point / scale.denormalize_point is pretty much 77 | # vectorized anyway 78 | normed_xs = vmap(scale.normalize_point)(self.xs) 79 | return PartialCrossEntropyCondition(normed_xs, self.ps, self.weight) 80 | 81 | def denormalize(self, scale: Scale): 82 | # TODO: Vectorization should be part of what Scale does 83 | denormed_xs = vmap(scale.denormalize_point)(self.xs) 84 | return PartialCrossEntropyCondition(denormed_xs, self.ps, self.weight) 85 | -------------------------------------------------------------------------------- /ergo/conditions/interval.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from ergo.scale import Scale 4 | 5 | from . 
import condition 6 | 7 | 8 | class IntervalCondition(condition.Condition): 9 | """ 10 | The specified interval should include as close to the specified 11 | probability mass as possible 12 | 13 | :raises ValueError: max must be strictly greater than min 14 | """ 15 | 16 | p: float 17 | min: Optional[float] 18 | max: Optional[float] 19 | weight: float 20 | 21 | def __init__(self, p, min=None, max=None, weight=1.0): 22 | self.p = p 23 | self.min = min 24 | self.max = max 25 | super().__init__(weight) 26 | 27 | def actual_p(self, dist) -> float: 28 | cdf_at_min = dist.cdf(self.min) if self.min is not None else 0 29 | cdf_at_max = dist.cdf(self.max) if self.max is not None else 1 30 | return cdf_at_max - cdf_at_min 31 | 32 | def loss(self, dist): 33 | actual_p = self.actual_p(dist) 34 | return self.weight * (actual_p - self.p) ** 2 35 | 36 | def _describe_fit(self, dist): 37 | description = super()._describe_fit(dist) 38 | description["p_in_interval"] = self.actual_p(dist) 39 | return description 40 | 41 | def normalize(self, scale: Scale): 42 | normalized_min = ( 43 | scale.normalize_point(self.min) if self.min is not None else None 44 | ) 45 | normalized_max = ( 46 | scale.normalize_point(self.max) if self.max is not None else None 47 | ) 48 | return self.__class__(self.p, normalized_min, normalized_max, self.weight) 49 | 50 | def denormalize(self, scale: Scale): 51 | denormalized_min = ( 52 | scale.denormalize_point(self.min) if self.min is not None else None 53 | ) 54 | denormalized_max = ( 55 | scale.denormalize_point(self.max) if self.max is not None else None 56 | ) 57 | return self.__class__(self.p, denormalized_min, denormalized_max, self.weight) 58 | 59 | def destructure(self): 60 | return ((IntervalCondition,), (self.p, self.min, self.max, self.weight)) 61 | 62 | def shape_key(self): 63 | return (self.__class__.__name__, self.min is None, self.max is None) 64 | 65 | def __str__(self): 66 | return f"There is a {self.p:.0%} chance that the value is in [{self.min}, {self.max}]" 67 | 68 | def __repr__(self): 69 | return f"IntervalCondition(p={self.p}, min={self.min}, max={self.max}, weight={self.weight})" 70 | -------------------------------------------------------------------------------- /ergo/conditions/maxentropy.py: -------------------------------------------------------------------------------- 1 | from . import condition 2 | 3 | 4 | class MaxEntropyCondition(condition.Condition): 5 | def loss(self, dist) -> float: 6 | return -self.weight * dist.entropy() 7 | 8 | def destructure(self): 9 | return ((MaxEntropyCondition,), (self.weight,)) 10 | 11 | def __str__(self): 12 | return "Maximize the entropy of the distribution" 13 | 14 | def __repr__(self): 15 | return f"MaxEntropyCondition(weight={self.weight})" 16 | -------------------------------------------------------------------------------- /ergo/conditions/mean.py: -------------------------------------------------------------------------------- 1 | from ergo.scale import Scale 2 | 3 | from . import condition 4 | 5 | 6 | class MeanCondition(condition.Condition): 7 | """ 8 | The distribution should have as close to the specified mean as possible. 
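    For example (illustrative values): ``MeanCondition(mean=2.5, weight=2.0)`` contributes ``2.0 * (dist.mean() - 2.5) ** 2`` to the overall fit loss (see ``loss`` below).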
9 | """ 10 | 11 | mean: float 12 | weight: float = 1.0 13 | 14 | def __init__(self, mean, weight=1.0): 15 | self.mean = mean 16 | super().__init__(weight) 17 | 18 | def actual_mean(self, dist) -> float: 19 | return dist.mean() 20 | 21 | def loss(self, dist) -> float: 22 | return self.weight * (self.actual_mean(dist) - self.mean) ** 2 23 | 24 | def _describe_fit(self, dist): 25 | description = super()._describe_fit(dist) 26 | description["mean"] = self.actual_mean(dist) 27 | return description 28 | 29 | def normalize(self, scale: Scale): 30 | normalized_mean = scale.normalize_point(self.mean) 31 | return self.__class__(normalized_mean, self.weight) 32 | 33 | def denormalize(self, scale: Scale): 34 | denormalized_mean = scale.denormalize_point(self.mean) 35 | return self.__class__(denormalized_mean, self.weight) 36 | 37 | def destructure(self): 38 | return ((MeanCondition,), (self.mean, self.weight)) 39 | 40 | def __str__(self): 41 | return f"The mean is {self.mean}." 42 | 43 | def __repr__(self): 44 | return f"MeanCondition(mean={self.mean}, weight={self.weight})" 45 | -------------------------------------------------------------------------------- /ergo/conditions/mode.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | 3 | from ergo.scale import Scale 4 | 5 | from . import condition 6 | 7 | 8 | class ModeCondition(condition.Condition): 9 | """ 10 | The specified outcome should be as close to being the most likely as possible. 11 | """ 12 | 13 | outcome: float 14 | weight: float = 1.0 15 | 16 | def __init__(self, outcome, weight=1.0): 17 | self.outcome = outcome 18 | super().__init__(weight) 19 | 20 | def loss(self, dist) -> float: 21 | p_outcome = dist.pdf(self.outcome) 22 | p_highest = np.max( 23 | dist.scale.denormalize_densities( 24 | dist.scale.denormalize_points(dist.normed_xs), dist.normed_densities 25 | ) 26 | ) 27 | return self.weight * (p_highest - p_outcome) ** 2 28 | 29 | def _describe_fit(self, dist): 30 | description = super()._describe_fit(dist) 31 | description["p_outcome"] = dist.pdf(self.outcome) 32 | description["p_highest"] = np.max( 33 | dist.scale.denormalize_densities( 34 | dist.scale.denormalize_points(dist.normed_xs), dist.normed_densities 35 | ) 36 | ) 37 | return description 38 | 39 | def normalize(self, scale: Scale): 40 | normalized_outcome = scale.normalize_point(self.outcome) 41 | return self.__class__(normalized_outcome, self.weight) 42 | 43 | def denormalize(self, scale: Scale): 44 | denormalized_outcome = scale.denormalize_point(self.outcome) 45 | return self.__class__(denormalized_outcome, self.weight) 46 | 47 | def destructure(self): 48 | return ((ModeCondition,), (self.outcome, self.weight)) 49 | 50 | def __str__(self): 51 | return f"The most likely outcome is {self.outcome}." 52 | 53 | def __repr__(self): 54 | return f"ModeCondition(outcome={self.outcome}, weight={self.weight})" 55 | -------------------------------------------------------------------------------- /ergo/conditions/point_density.py: -------------------------------------------------------------------------------- 1 | from jax import vmap 2 | import jax.numpy as np 3 | 4 | from ergo.scale import Scale 5 | 6 | from . 
import condition 7 | 8 | 9 | class PointDensityCondition(condition.Condition): 10 | """ 11 | The distribution should fit the specified histogram as closely as 12 | possible 13 | """ 14 | 15 | xs: np.DeviceArray 16 | densities: np.DeviceArray 17 | weight: float = 1.0 18 | 19 | def __init__(self, xs, densities, weight=1.0): 20 | self.xs = xs 21 | self.densities = densities 22 | super().__init__(weight) 23 | 24 | def loss(self, dist): 25 | entry_loss_fn = lambda x, density: (density - dist.pdf(x)) ** 2 # noqa: E731 26 | total_loss = np.sum(vmap(entry_loss_fn)(self.xs, self.densities)) 27 | return self.weight * total_loss / self.xs.size 28 | 29 | def normalize(self, scale: Scale): 30 | normed_xs = scale.normalize_points(self.xs) 31 | normed_densities = scale.normalize_densities(normed_xs, self.densities) 32 | return self.__class__(normed_xs, normed_densities, self.weight) 33 | 34 | def denormalize(self, scale: Scale): 35 | denormed_xs = scale.denormalize_points(self.xs) 36 | denormed_densities = scale.denormalize_densities(denormed_xs, self.densities) 37 | return self.__class__(denormed_xs, denormed_densities, self.weight) 38 | 39 | def destructure(self): 40 | return ((PointDensityCondition,), (self.xs, self.densities, self.weight)) 41 | 42 | def __key(self): 43 | return ( 44 | PointDensityCondition, 45 | (tuple(self.xs), tuple(self.densities), self.weight), 46 | ) 47 | 48 | def _describe_fit(self, dist): 49 | description = super()._describe_fit(dist) 50 | 51 | def entry_distance_fn(x, density): 52 | return abs(1.0 - density / dist.pdf(x)) 53 | 54 | distances = vmap(entry_distance_fn)(self.xs, self.densities) 55 | description["max_distance"] = np.max(distances) 56 | description["90th_distance"] = np.percentile(distances, 90) 57 | description["mean_distance"] = np.mean(distances) 58 | return description 59 | 60 | def __str__(self): 61 | return "The probability density function looks similar to the provided density function." 62 | -------------------------------------------------------------------------------- /ergo/conditions/smoothness.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | 3 | from ergo.utils import shift 4 | 5 | from . import condition 6 | 7 | 8 | class SmoothnessCondition(condition.Condition): 9 | def loss(self, dist) -> float: 10 | window_size = 5 11 | squared_distance = 0.0 12 | for i in range(1, window_size + 1): 13 | squared_distance += (1 / i ** 2) * np.sum( 14 | np.square( 15 | dist.normed_log_densities 16 | - shift(dist.normed_log_densities, i, dist.normed_log_densities[0]) 17 | ) 18 | ) 19 | return self.weight * squared_distance / dist.normed_log_densities.size 20 | 21 | def destructure(self): 22 | return ((SmoothnessCondition,), (self.weight,)) 23 | 24 | def __str__(self): 25 | return "Minimize rough edges in the distribution" 26 | 27 | def __repr__(self): 28 | return f"SmoothnessCondition(weight={self.weight})" 29 | -------------------------------------------------------------------------------- /ergo/conditions/variance.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | 3 | from ergo.scale import Scale 4 | 5 | from . import condition 6 | 7 | 8 | class VarianceCondition(condition.Condition): 9 | """ 10 | The distribution should have as close to the specified variance as possible. 
11 | """ 12 | 13 | variance: float 14 | weight: float = 1.0 15 | 16 | def __init__(self, variance, weight=1.0): 17 | self.variance = variance 18 | super().__init__(weight) 19 | 20 | def loss(self, dist) -> float: 21 | return self.weight * (np.log(dist.variance()) - np.log(self.variance)) ** 2 22 | 23 | def _describe_fit(self, dist): 24 | description = super()._describe_fit(dist) 25 | description["variance"] = dist.variance() 26 | return description 27 | 28 | def normalize(self, scale: Scale): 29 | normalized_variance = scale.normalize_variance(self.variance) 30 | return self.__class__(normalized_variance, self.weight) 31 | 32 | def denormalize(self, scale: Scale): 33 | denormalized_variance = scale.denormalize_point(self.variance) 34 | return self.__class__(denormalized_variance, self.weight) 35 | 36 | def destructure(self): 37 | return ((VarianceCondition,), (self.variance, self.weight)) 38 | 39 | def __str__(self): 40 | return f"The variance is {self.variance}." 41 | 42 | def __repr__(self): 43 | return f"VarianceCondition(mean={self.variance}, weight={self.weight})" 44 | -------------------------------------------------------------------------------- /ergo/conditions/wasserstein.py: -------------------------------------------------------------------------------- 1 | from ergo.distributions import point_density 2 | import ergo.static as static 3 | 4 | from . import condition 5 | 6 | 7 | class WassersteinCondition(condition.Condition): 8 | p_dist: point_density.PointDensity 9 | weight: float = 1.0 10 | 11 | def __init__(self, p_dist, weight=1.0): 12 | self.p_dist = p_dist 13 | super().__init__(weight) 14 | 15 | def loss(self, q_dist) -> float: 16 | return self.weight * static.wasserstein_distance( 17 | self.p_dist.normed_densities, q_dist.normed_densities 18 | ) 19 | 20 | def destructure(self): 21 | dist_classes, dist_numeric = self.p_dist.destructure() 22 | cond_numeric = (self.weight,) 23 | return ((WassersteinCondition, dist_classes), (cond_numeric, dist_numeric)) 24 | 25 | @classmethod 26 | def structure(cls, params): 27 | class_params, numeric_params = params 28 | cond_class, dist_classes = class_params 29 | cond_numeric, dist_numeric = numeric_params 30 | dist_params = (dist_classes, dist_numeric) 31 | dist = dist_classes[0].structure(dist_params) 32 | return cls(dist, cond_numeric[0]) 33 | 34 | def __str__(self): 35 | return "Minimize the Wasserstein distance between the two distributions" 36 | -------------------------------------------------------------------------------- /ergo/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oughtinc/ergo/c46f4ebd73bc7115771f39dd99cbd32fa7a3bde3/ergo/contrib/__init__.py -------------------------------------------------------------------------------- /ergo/contrib/el_paso/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = ["brachbach", "krismoore", "onlyasith", "texas_data", "shaman"] 2 | -------------------------------------------------------------------------------- /ergo/contrib/el_paso/brachbach.py: -------------------------------------------------------------------------------- 1 | from datetime import date, timedelta 2 | from typing import Callable 3 | 4 | import pandas as pd 5 | import sklearn 6 | 7 | 8 | def get_hospital_stay_days(): 9 | # from https://penn-chime.phl.io/ 10 | hospital_stay_days_point_estimate = 7 11 | 12 | # get_hospital_confirmed_from_daily_infected_model causes random choices 13 
| # (because hospital_stay_days is random) 14 | # but we’re calling it outside of the model functions so it’s only called once. 15 | # this will give weird/wrong results. 16 | # moving hospital_confirmed_from_daily_infected_model = 17 | # get_hospital_confirmed_from_daily_infected_model(daily_infections) 18 | # inside hospital_confirmed_for_date will fix this 19 | # but will make the model pretty slow since each call to the model reruns regression 20 | # hospital_stay_days_fuzzed = round( 21 | # float( 22 | # ergo.normal_from_interval( 23 | # hospital_stay_days_point_estimate * 0.5, 24 | # hospital_stay_days_point_estimate * 1.5, 25 | # ) 26 | # ) 27 | # ) 28 | 29 | # return max(1, hospital_stay_days_fuzzed) 30 | 31 | return hospital_stay_days_point_estimate 32 | 33 | 34 | def get_daily_hospital_confirmed( 35 | hospital_data: pd.DataFrame, daily_infections_fn: Callable[[date], int] 36 | ): 37 | """ 38 | Use a linear regression to predict 39 | the number of patients with COVID currently in the hospital 40 | from the total number of new confirmed cases over the past several days 41 | 42 | :param data: dataframe with index of dates, 43 | columns of "In hospital confirmed" 44 | :return: A function to predict 45 | the number of confirmed cases of COVID in the hospital on a date, 46 | given the total number of confirmed cases for each date 47 | """ 48 | 49 | hospital_stay_days = get_hospital_stay_days() 50 | 51 | has_hospital_confirmed = hospital_data[hospital_data["In hospital confirmed"].notna()] # type: ignore 52 | 53 | data_dates = has_hospital_confirmed.index 54 | 55 | hospital_confirmed = has_hospital_confirmed["In hospital confirmed"] 56 | 57 | def get_recent_cases_data(date): 58 | """ 59 | How many new confirmed cases were there over the past hospital_stay_days days? 
60 | """ 61 | return sum( 62 | [ 63 | daily_infections_fn(date - timedelta(n)) 64 | for n in range(0, hospital_stay_days) 65 | ] 66 | ) 67 | 68 | recent_cases = [[get_recent_cases_data(date)] for date in data_dates] 69 | 70 | reg = sklearn.linear_model.LinearRegression(fit_intercept=False).fit( 71 | recent_cases, hospital_confirmed 72 | ) 73 | # TODO: consider adding uncertainty to the fit here 74 | 75 | # now that we've related current hospitalized cases and recent confirmed cases, 76 | # return a function that allows us to predict hospitalized cases given estimates 77 | # of future confirmed cases 78 | def get_hospital_confirmed_from_daily_cases(date: date): 79 | recent_cases = get_recent_cases_data(date) 80 | return round(reg.predict([[recent_cases]])[0]) 81 | 82 | return get_hospital_confirmed_from_daily_cases 83 | -------------------------------------------------------------------------------- /ergo/contrib/el_paso/krismoore.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import pandas as pd 3 | import seaborn 4 | 5 | 6 | def get_krismoore_data(): 7 | """ 8 | Get data from Metaculus user @oKrisMoore's compilation of El Paso COVID data 9 | See this sheet for more information: 10 | https://docs.google.com/spreadsheets/d/1eGF9xYmDmvAkr-dCmd-N4efHzPyYEfVl0YmL9zBvH9Q/edit#gid=1694267458 11 | """ 12 | compiled_data = pd.read_csv( 13 | "https://docs.google.com/spreadsheets/d/e/2PACX-1vQEZk_8wZMF5MEm_f66wpev4nkWP7edQ8l6SwcbUd68zFZw6EVizh-jplw2_9gZBGyhNaJk5R_CG25k/pub?gid=0&single=true&output=csv", 14 | index_col="date", 15 | parse_dates=True, 16 | ) 17 | compiled_data = compiled_data.rename( 18 | columns={"in_hospital": "In hospital confirmed"} 19 | ) 20 | 21 | return compiled_data 22 | 23 | 24 | def graph_compiled_data(compiled_data): 25 | compiled_data_to_graph = compiled_data[ 26 | ["new_cases", "In hospital confirmed", "in_icu", "on_ventilator"] 27 | ].dropna() 28 | 29 | compiled_data_to_graph["date"] = compiled_data_to_graph.index 30 | 31 | melted_compiled = pd.melt( 32 | compiled_data_to_graph, id_vars=["date"], value_name="patients" 33 | ) 34 | 35 | ax = seaborn.lineplot(x="date", y="patients", hue="variable", data=melted_compiled) 36 | handles, labels = ax.get_legend_handles_labels() 37 | ax.legend(handles=handles[1:], labels=labels[1:]) 38 | 39 | plt.xticks(rotation=90) 40 | -------------------------------------------------------------------------------- /ergo/contrib/el_paso/onlyasith.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def get_onlyasith_results(): 5 | """ 6 | Get results from Metaculus user @onlyasith's model of cases in El Paso 7 | See this sheet for more information: 8 | https://docs.google.com/spreadsheets/d/1L6pzFAEJ6MfnUwt-ea6tetKyvdi0YubnK_70SGm436c/edit#gid=1807978187 9 | """ 10 | projected_cases = pd.read_csv( 11 | "https://docs.google.com/spreadsheets/d/e/2PACX-1vSurcOWEsa7DBCRfONFA2Gxf802Rj1FebYSyVzvACysenRcD79Fs0ykXWJakIhGcW48_ymgw35TKga-/pub?gid=1213113172&single=true&output=csv", 12 | index_col="Date", 13 | parse_dates=True, 14 | ) 15 | 16 | projected_cases = projected_cases.dropna() 17 | projected_cases["Cases so far"] = projected_cases["Cases so far"].apply( 18 | lambda str: int(str) if str != " " else None 19 | ) 20 | projected_cases["New cases"] = projected_cases["Cases so far"].diff() 21 | 22 | return projected_cases 23 | -------------------------------------------------------------------------------- 
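# Example (sketch): one way downstream code might consume the projections loaded by
# get_onlyasith_results() above. Assumes the published sheet is still reachable and that
# the chosen date falls inside its projection window; the names shown are illustrative.
#
#     projections = get_onlyasith_results()
#     projections.loc["2020-06-01", "New cases"]   # projected new cases for a single date
#     projections["New cases"].describe()          # summary of the projected trajectory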
/ergo/contrib/el_paso/shaman.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import pandas as pd 6 | 7 | import ergo 8 | from ergo.contrib.utils import daterange 9 | 10 | 11 | def extract_projections_for_param( 12 | county: str, param: str, column_prefix: str, raw_us_projections 13 | ): 14 | raw_county_projections = raw_us_projections[raw_us_projections.county == county] 15 | metadata = raw_county_projections[["county", "fips", "Date"]] 16 | metadata["var"] = param 17 | percentiles = ["2.5", "25", "50", "75", "97.5"] 18 | projection_column_names = [ 19 | f"{column_prefix}_{percentile}" for percentile in percentiles 20 | ] 21 | projections = raw_county_projections[projection_column_names] 22 | projections.columns = percentiles 23 | return pd.concat([metadata, projections], axis=1) 24 | 25 | 26 | def load_cu_projections(county: str): 27 | """ 28 | The COVID model from the Shaman lab at Columbia projects 29 | "daily new confirmed case, daily new infection (both reported and unreported), 30 | cumulative demand of hospital beds, ICU and ventilators 31 | as well as daily mortality (2.5, 25, 50, 75 and 97.5 percentiles)" 32 | (https://github.com/shaman-lab/COVID-19Projection) 33 | 34 | Load this data 35 | """ 36 | # pd has this option but mypy doesn't know about it 37 | pd.options.mode.chained_assignment = None # type: ignore 38 | scenarios = ["nointerv", "60contact", "70contact", "80contact"] 39 | cu_model_data = {} 40 | for scenario in scenarios: 41 | # pd.read_csv has a parse_dates option but mypy doesn't know about it 42 | raw_cases_df = pd.read_csv( # type: ignore 43 | f"https://raw.githubusercontent.com/shaman-lab/COVID-19Projection/master/Projection_April26/Projection_{scenario}.csv", 44 | parse_dates=["Date"], 45 | ) 46 | cases_projections_df = extract_projections_for_param( 47 | county, "cases", "report", raw_cases_df 48 | ) 49 | 50 | raw_covid_effects_df = pd.read_csv( # type: ignore 51 | f"https://raw.githubusercontent.com/shaman-lab/COVID-19Projection/master/Projection_April26/bed_{scenario}.csv", 52 | parse_dates=["Date"], 53 | ) 54 | 55 | hospital_projections_df = extract_projections_for_param( 56 | county, "hosp", "hosp_need", raw_covid_effects_df 57 | ) 58 | icu_projections_df = extract_projections_for_param( 59 | county, "ICU", "ICU_need", raw_covid_effects_df 60 | ) 61 | vent_projections_df = extract_projections_for_param( 62 | county, "vent", "vent_need", raw_covid_effects_df 63 | ) 64 | deaths_projections_df = extract_projections_for_param( 65 | county, "deaths", "death", raw_covid_effects_df 66 | ) 67 | 68 | all_projections_df = pd.concat( 69 | [ 70 | cases_projections_df, 71 | hospital_projections_df, 72 | icu_projections_df, 73 | vent_projections_df, 74 | deaths_projections_df, 75 | ] 76 | ) 77 | all_projections_df["Date"] = all_projections_df["Date"].apply( 78 | lambda x: x.date() 79 | ) 80 | cu_model_data[scenario] = all_projections_df 81 | return cu_model_data 82 | 83 | 84 | # The below were added for the workflow/tutorial nb 85 | # they're not yet used in the main El Paso notebook 86 | @ergo.mem 87 | def cu_model_scenario(scenarios: Tuple[str]): 88 | """Which of the model scenarios are we in?""" 89 | return ergo.random_choice(scenarios) 90 | 91 | 92 | @ergo.mem 93 | def cu_model_quantile(): 94 | """Where in the distribution of model outputs are we for this model run? 
95 | Want to be consistent across time, so we sample it once per model run""" 96 | return ergo.uniform() 97 | 98 | 99 | def cu_projection(param: str, date: date, cu_projections) -> int: 100 | """ 101 | Get the Columbia model's prediction 102 | of the param for the date 103 | """ 104 | scenario = cu_model_scenario(tuple([s for s in cu_projections.keys()])) 105 | quantile = cu_model_quantile() 106 | 107 | # Extract quantiles of the model distribution 108 | xs = np.array([0.025, 0.25, 0.5, 0.75, 0.975]) 109 | scenario_df = cu_projections[scenario] 110 | param_df = scenario_df[scenario_df["var"] == param] 111 | date_df = param_df[param_df["Date"] == date] 112 | if date_df.empty: 113 | raise KeyError(f"No Columbia project for param: {param}, date: {date}") 114 | 115 | ys = np.array(date_df[["2.5", "25", "50", "75", "97.5"]].iloc[0]) 116 | 117 | # Linearly interpolate 118 | # mypy doesn't know that there's an np.interp 119 | return int(round(np.interp(quantile, xs, ys))) # type: ignore 120 | 121 | 122 | def cu_projections_for_dates( 123 | param: str, start_date: date, end_date: date, cu_projections 124 | ): 125 | """ 126 | Get Columbia model projections over a range of dates 127 | """ 128 | date_range = daterange(start_date, end_date) 129 | return [cu_projection(param, date, cu_projections) for date in date_range] 130 | -------------------------------------------------------------------------------- /ergo/contrib/el_paso/texas_data.py: -------------------------------------------------------------------------------- 1 | from datetime import date 2 | import re 3 | 4 | import pandas as pd 5 | 6 | 7 | def get_el_paso_data(): 8 | """ 9 | Get El Paso COVID data from the Texas government's data at 10 | https://dshs.texas.gov/coronavirus/TexasCOVID19DailyCountyCaseCountData.xlsx 11 | """ 12 | texas_cases = pd.read_excel( 13 | "https://dshs.texas.gov/coronavirus/TexasCOVID19DailyCountyCaseCountData.xlsx" 14 | ) 15 | column_names = texas_cases.iloc[1] 16 | column_names[0] = "County Name" 17 | texas_cases.columns = column_names 18 | 19 | el_paso_cases = ( 20 | texas_cases.loc[texas_cases["County Name"] == "El Paso"] 21 | .drop(columns=["County Name", "Population"]) 22 | .transpose() 23 | ) 24 | 25 | el_paso_cases.columns = ["Cases so far"] 26 | 27 | def get_date(column_name): 28 | date_str = re.search("[0-9]{1,2}-[0-9]{1,2}", column_name).group(0) 29 | month_str, day_str = date_str.split("-") 30 | return date(2020, int(month_str), int(day_str)) 31 | 32 | el_paso_cases.index = [get_date(id) for id in el_paso_cases.index] 33 | 34 | el_paso_cases["New cases"] = el_paso_cases["Cases so far"].diff() 35 | 36 | return el_paso_cases 37 | -------------------------------------------------------------------------------- /ergo/contrib/predictit/__init__.py: -------------------------------------------------------------------------------- 1 | from .fuzzy_search import search_market, search_question 2 | -------------------------------------------------------------------------------- /ergo/contrib/predictit/fuzzy_search.py: -------------------------------------------------------------------------------- 1 | import operator 2 | import re 3 | from typing import List, Tuple 4 | 5 | from ergo import PredictIt, PredictItMarket, PredictItQuestion 6 | from fuzzywuzzy import fuzz 7 | 8 | 9 | def _get_name_matches(name: str, guess_words: List[str]) -> int: 10 | """ 11 | Return the number of common words in a str and list of words 12 | :param name: 13 | :param guess_words: 14 | :return: number of matches 15 | """ 16 | matches = 
sum(word in name for word in guess_words) 17 | return matches 18 | 19 | 20 | def _get_name_score(names: List[str], guess: str) -> int: 21 | """ 22 | Return similarity score of a guess to a name. Higher is better. 23 | :param names: 24 | :param guess: 25 | :return: score 26 | """ 27 | names = [re.sub(r"[^\w\s]", "", name).lower() for name in names] 28 | guess_words = guess.split() 29 | matches = max(_get_name_matches(name, guess_words) for name in names) 30 | diff = max(fuzz.token_sort_ratio(guess, name) for name in names) 31 | return matches * 100 + diff 32 | 33 | 34 | def _check_market(market: PredictItMarket, guess: str) -> Tuple[int, int]: 35 | """ 36 | Return the id and similarity score of a market to a guess. 37 | :param market: 38 | :param guess: 39 | :return: id and similarity score 40 | """ 41 | return market.id, _get_name_score([market.shortName, market.name], guess) 42 | 43 | 44 | def _check_question(question: PredictItQuestion, guess: str) -> Tuple[int, int]: 45 | """ 46 | Return the id and similarity score of a question to a guess. 47 | :param question: 48 | :param guess: 49 | :return: id and similarity score 50 | """ 51 | return question.id, _get_name_score([question.name], guess) 52 | 53 | 54 | def _get_best_market_id(pi: PredictIt, guess: str) -> int: 55 | """ 56 | Return the id of the market with the highest similarity score. 57 | :param pi: 58 | :param guess: 59 | :return: market id 60 | """ 61 | return max( 62 | (_check_market(market, guess) for market in pi.markets), 63 | key=operator.itemgetter(1), 64 | )[0] 65 | 66 | 67 | def _get_best_question_id(market: PredictItMarket, guess: str) -> int: 68 | """ 69 | Return the id of the question with the highest similarity score. 70 | :param market: 71 | :param guess: 72 | :return: question id 73 | """ 74 | return max( 75 | (_check_question(question, guess) for question in market.questions), 76 | key=operator.itemgetter(1), 77 | )[0] 78 | 79 | 80 | def search_market(pi: PredictIt, guess: str) -> PredictItMarket: 81 | """ 82 | Return a PredictIt market with the given name, 83 | using fuzzy matching if an exact match is not found. 84 | :param pi: 85 | :param guess: 86 | :return: market 87 | """ 88 | return pi.get_market(_get_best_market_id(pi, guess)) 89 | 90 | 91 | def search_question(market: PredictItMarket, guess: str) -> PredictItQuestion: 92 | """ 93 | Return the specified question given by the name of the question, 94 | using fuzzy matching in the case where the name isn't exact. 
95 | :param market: 96 | :param guess: 97 | :return: question 98 | """ 99 | return market.get_question(_get_best_question_id(market, guess)) 100 | -------------------------------------------------------------------------------- /ergo/contrib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import ( 2 | plot_question, 3 | question, 4 | rejection_sample, 5 | sample_from_ensemble, 6 | samplers, 7 | summarize_question_samples, 8 | ) 9 | from .utils import daterange 10 | -------------------------------------------------------------------------------- /ergo/contrib/utils/core.py: -------------------------------------------------------------------------------- 1 | from datetime import date, timedelta 2 | import functools 3 | 4 | import numpy as np 5 | 6 | import ergo 7 | 8 | # TODO consider turning this into a Class or factoring into Ergo proper 9 | 10 | # Rejection sampling 11 | 12 | 13 | def rejection_sample(fn, condition): 14 | """ 15 | Sample from fn until we get a value that satisfies 16 | condition, then return it. 17 | """ 18 | while True: 19 | value = fn() 20 | if condition(value): 21 | return value 22 | 23 | 24 | # Associate models with questions 25 | 26 | # We'll add a sampler here for each question we predict on. 27 | # Each sampler is a function that returns a single sample 28 | # from our model predicting on that question. 29 | samplers = {} 30 | 31 | 32 | # TODO probably curry this with the notbooks metaculus instance so we don't need to pass 33 | # it in on every question 34 | def question( 35 | metaculus, 36 | question_id, 37 | community_weight=0, 38 | community_fn=None, 39 | start_date=date.today(), 40 | ): 41 | q = metaculus.get_question(question_id) 42 | 43 | def decorator(func): 44 | tag = func.__name__ 45 | 46 | @functools.wraps(func) 47 | @ergo.mem 48 | def sampler(): 49 | if ergo.flip(community_weight): 50 | if community_fn: 51 | value = community_fn() 52 | else: 53 | value = q.sample_community() 54 | else: 55 | value = func() 56 | if isinstance(value, date): 57 | # FIXME: Ergo needs to handle dates 58 | ergo.tag(int((value - start_date).days), tag) 59 | else: 60 | ergo.tag(value, tag) 61 | return value 62 | 63 | sampler.question = q 64 | samplers[q.id] = sampler 65 | return sampler 66 | 67 | return decorator 68 | 69 | 70 | def summarize_question_samples(samples): 71 | sampler_tags = [sampler.__name__ for sampler in samplers.values()] 72 | tags_to_show = [tag for tag in sampler_tags if tag in samples.columns] 73 | samples_to_show = samples[tags_to_show] 74 | summary = samples_to_show.describe().transpose().round(2) 75 | display(summary) # noqa: F821 #TODO see if we need this display command 76 | 77 | 78 | def plot_question(sampler, num_samples=200, bw=None, start_date=date.today()): 79 | samples = ergo.run(sampler, num_samples=num_samples) 80 | 81 | summarize_question_samples(samples) 82 | 83 | q = sampler.question 84 | 85 | q_samples = samples[sampler.__name__] 86 | 87 | if ( 88 | q.id == 4128 89 | ): # Date question: Need to convert back to date from days (https://github.com/oughtinc/ergo/issues/144) 90 | q_samples = np.array([start_date + timedelta(s) for s in q_samples]) 91 | 92 | if bw is not None: 93 | q.show_prediction( 94 | samples=q_samples, show_community=True, percent_kept=0.9, bw=bw 95 | ) 96 | else: 97 | q.show_prediction(samples=q_samples, show_community=True, percent_kept=0.9) 98 | 99 | 100 | def sample_from_ensemble(models, params, weights=None, fallback=False, default=None): 101 | """Sample 
models in proportion to weights and execute with 102 | model_params. If fallback is true then call different model from 103 | ensemble if the selected model throws an error. If Default is not 104 | None then return default if all models fail 105 | 106 | """ 107 | if len(models) > 1: 108 | model = ergo.random_choice(models, weights) 109 | else: 110 | model = models[0] 111 | try: 112 | result = model(**params) 113 | if np.isnan(result): 114 | raise KeyError 115 | return result 116 | except (KeyError, IndexError): 117 | if fallback and len(models) > 1: 118 | models_copy = models.copy() 119 | weights_copy = weights.copy() 120 | i = models.index(model) 121 | del models_copy[i] 122 | del weights_copy[i] 123 | return sample_from_ensemble( 124 | models_copy, params, weights_copy, fallback, default 125 | ) 126 | return default 127 | -------------------------------------------------------------------------------- /ergo/contrib/utils/utils.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | 3 | 4 | def daterange(start_date, end_date): 5 | for n in range(int((end_date - start_date).days)): 6 | yield start_date + timedelta(n) 7 | -------------------------------------------------------------------------------- /ergo/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import ( 2 | BetaFromHits, 3 | Categorical, 4 | LogNormalFromInterval, 5 | NormalFromInterval, 6 | bernoulli, 7 | beta, 8 | beta_from_hits, 9 | categorical, 10 | flip, 11 | halfnormal, 12 | halfnormal_from_interval, 13 | lognormal, 14 | lognormal_from_interval, 15 | normal, 16 | normal_from_interval, 17 | random_choice, 18 | random_integer, 19 | uniform, 20 | ) 21 | from .constants import bin_sizes, grid, point_density_default_num_points, target_xs 22 | from .distribution import Distribution 23 | from .logistic import Logistic 24 | from .logistic_mixture import LogisticMixture 25 | from .point_density import PointDensity 26 | from .truncate import Truncate 27 | -------------------------------------------------------------------------------- /ergo/distributions/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides samplers for probability distributions. 3 | """ 4 | 5 | import math 6 | 7 | import jax.numpy as np 8 | import numpyro.distributions as dist 9 | 10 | from ergo.ppl import sample 11 | 12 | 13 | def bernoulli(p=0.5, **kwargs): 14 | return sample(dist.Bernoulli(probs=float(p)), **kwargs) 15 | 16 | 17 | def normal(mean=0, stdev=1, **kwargs): 18 | return sample(dist.Normal(mean, stdev), **kwargs) 19 | 20 | 21 | def lognormal(loc=0, scale=1, **kwargs): 22 | return sample(dist.LogNormal(loc, scale), **kwargs) 23 | 24 | 25 | def halfnormal(stdev=1, **kwargs): 26 | return sample(dist.HalfNormal(stdev), **kwargs) 27 | 28 | 29 | def uniform(low=0, high=1, **kwargs): 30 | return sample(dist.Uniform(low, high), **kwargs) 31 | 32 | 33 | def beta(a=1, b=1, **kwargs): 34 | return sample(dist.Beta(a, b), **kwargs) 35 | 36 | 37 | def categorical(ps, **kwargs): 38 | return sample(Categorical(ps), **kwargs) 39 | 40 | 41 | # Provide alternative parameterizations for primitive distributions 42 | 43 | 44 | def Categorical(scores): 45 | probs = scores / sum(scores) 46 | return dist.Categorical(probs=probs) 47 | 48 | 49 | def NormalFromInterval(low, high): 50 | """This assumes a centered 90% confidence interval, i.e. 
the left endpoint 51 | marks 0.05% on the CDF, the right 0.95%.""" 52 | mean = (high + low) / 2 53 | stdev = (high - mean) / 1.645 54 | return dist.Normal(mean, stdev) 55 | 56 | 57 | def HalfNormalFromInterval(high): 58 | """This assumes a 90% confidence interval starting at 0, 59 | i.e. right endpoint marks 90% on the CDF""" 60 | stdev = high / 1.645 61 | return dist.HalfNormal(stdev) 62 | 63 | 64 | def LogNormalFromInterval(low, high): 65 | """This assumes a centered 90% confidence interval, i.e. the left endpoint 66 | marks 0.05% on the CDF, the right 0.95%.""" 67 | loghigh = math.log(high) 68 | loglow = math.log(low) 69 | mean = (loghigh + loglow) / 2 70 | stdev = (loghigh - loglow) / (2 * 1.645) 71 | return dist.LogNormal(mean, stdev) 72 | 73 | 74 | def BetaFromHits(hits, total): 75 | return dist.Beta(1 + hits, 1 + (total - hits)) 76 | 77 | 78 | # Alternative names and parameterizations for primitive distribution samplers 79 | 80 | 81 | def normal_from_interval(low, high, **kwargs): 82 | return sample(NormalFromInterval(low, high), **kwargs) 83 | 84 | 85 | def lognormal_from_interval(low, high, **kwargs): 86 | return sample(LogNormalFromInterval(low, high), **kwargs) 87 | 88 | 89 | def halfnormal_from_interval(high, **kwargs): 90 | return sample(HalfNormalFromInterval(high), **kwargs) 91 | 92 | 93 | def beta_from_hits(hits, total, **kwargs): 94 | return sample(BetaFromHits(hits, total), **kwargs) 95 | 96 | 97 | def random_choice(options, ps=None): 98 | if ps is None: 99 | ps = np.full(len(options), 1 / len(options)) 100 | else: 101 | ps = np.array(ps) 102 | 103 | idx = sample(dist.Categorical(ps)) 104 | return options[idx] 105 | 106 | 107 | def random_integer(min: int, max: int, **kwargs) -> int: 108 | return int(math.floor(uniform(min, max, **kwargs).item())) 109 | 110 | 111 | flip = bernoulli 112 | -------------------------------------------------------------------------------- /ergo/distributions/constants.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as np 2 | 3 | # The default number of points PointDensity uses to represent a distribution 4 | point_density_default_num_points = 200 5 | bin_sizes = np.full( 6 | point_density_default_num_points, 1 / point_density_default_num_points 7 | ) 8 | grid = np.linspace(0, 1, point_density_default_num_points + 1) 9 | target_xs = (grid[1:] + grid[:-1]) / 2 10 | -------------------------------------------------------------------------------- /ergo/distributions/distribution.py: -------------------------------------------------------------------------------- 1 | """ 2 | Base Distribution Class 3 | 4 | Specifies interface for specific Distribution Classes 5 | """ 6 | 7 | from abc import ABC, abstractmethod 8 | 9 | from ergo.scale import Scale 10 | 11 | 12 | class Distribution(ABC): 13 | @abstractmethod 14 | def pdf(self, x): 15 | ... 16 | 17 | @abstractmethod 18 | def cdf(self, x): 19 | ... 20 | 21 | @abstractmethod 22 | def ppf(self, q): 23 | ... 24 | 25 | @abstractmethod 26 | def sample(self): 27 | ... 28 | 29 | @abstractmethod 30 | def normalize(self): 31 | ... 32 | 33 | @abstractmethod 34 | def denormalize(self, scale: Scale): 35 | ... 
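    # Example (sketch): percentiles() below turns ppf values into IntervalConditions.
    # A hedged illustration of calling code, using the Logistic subclass defined
    # elsewhere in this package; the exact numbers depend on the chosen parameters.
    #
    #     from ergo.distributions import Logistic
    #     from ergo.scale import Scale
    #
    #     dist = Logistic(loc=5.0, s=1.0, scale=Scale(0, 10))
    #     for cond in dist.percentiles([0.25, 0.5, 0.75]):
    #         print(cond)
    #     # e.g. "There is a 50% chance that the value is in [None, 5.0]"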
36 | 37 | def percentiles(self, percentiles=None): 38 | from ergo.conditions import IntervalCondition 39 | 40 | if percentiles is None: 41 | percentiles = [0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99] 42 | values = [self.ppf(q) for q in percentiles] 43 | return [ 44 | IntervalCondition(percentile, max=float(value)) 45 | for (percentile, value) in zip(percentiles, values) 46 | ] 47 | -------------------------------------------------------------------------------- /ergo/distributions/logistic.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logistic distribution 3 | """ 4 | from dataclasses import dataclass 5 | from typing import Any, Optional 6 | 7 | from jax import scipy 8 | import jax.numpy as np 9 | import scipy as oscipy 10 | 11 | from ergo.scale import Scale 12 | 13 | from .distribution import Distribution 14 | 15 | 16 | @dataclass 17 | class Logistic(Distribution): 18 | loc: float # normalized 19 | s: float # normalized 20 | scale: Scale 21 | metadata: Any = None 22 | 23 | def __init__( 24 | self, 25 | loc: float, 26 | s: float, 27 | scale: Optional[Scale] = None, 28 | metadata=None, 29 | normalized=False, 30 | ): 31 | # TODO (#303): Raise ValueError on scale < 0 32 | if normalized: 33 | self.loc = loc 34 | self.s = np.max([s, 0.0000001]) 35 | self.metadata = metadata 36 | if scale is not None: 37 | self.scale = scale 38 | else: 39 | self.scale = Scale(0, 1) 40 | self.true_s = self.s * self.scale.width 41 | self.true_loc = self.scale.denormalize_point(loc) 42 | 43 | elif scale is None: 44 | raise ValueError("Either a Scale or normalized parameters are required") 45 | else: 46 | self.loc = scale.normalize_point(loc) 47 | self.s = np.max([s, 0.0000001]) / scale.width 48 | self.scale = scale 49 | self.metadata = metadata 50 | self.true_s = s # convenience field only used in repr currently 51 | self.true_loc = loc # convenience field only used in repr currently 52 | 53 | def __repr__(self): 54 | return ( 55 | f"Logistic(scale={self.scale}, true_loc={self.true_loc}, " 56 | f"true_s={self.true_s}, normed_loc={self.loc}, normed_s={self.s}," 57 | f" metadata={self.metadata})" 58 | ) 59 | 60 | # Distribution 61 | 62 | def pdf(self, x): 63 | y = (self.scale.normalize_point(x) - self.loc) / self.s 64 | p = np.exp(scipy.stats.logistic.logpdf(y) - np.log(self.s)) 65 | return self.scale.denormalize_density(x, p) 66 | 67 | def logpdf(self, x): 68 | return np.log(self.pdf(x)) 69 | 70 | def cdf(self, x): 71 | y = (self.scale.normalize_point(x) - self.loc) / self.s 72 | return scipy.stats.logistic.cdf(y) 73 | 74 | def ppf(self, q): 75 | return self.scale.denormalize_point( 76 | oscipy.stats.logistic(loc=self.loc, scale=self.s).ppf(q) 77 | ) 78 | 79 | def sample(self): 80 | return self.scale.denormalize_point( 81 | oscipy.stats.logistic.rvs(loc=self.loc, scale=self.s) 82 | ) 83 | 84 | # Scaled 85 | 86 | def normalize(self): 87 | """ 88 | Return the normalized condition. 89 | 90 | :param scale: the true scale 91 | :return: the condition normalized to [0,1] 92 | """ 93 | return self.__class__( 94 | self.loc, self.s, Scale(0, 1), self.metadata, normalized=True 95 | ) 96 | 97 | def denormalize(self, scale: Scale): 98 | """ 99 | Assume that the distribution has been normalized to be over [0,1]. 
100 | Return the distribution on the true scale 101 | 102 | :param scale: the true scale 103 | """ 104 | 105 | return self.__class__(self.loc, self.s, scale, self.metadata, normalized=True) 106 | 107 | # Structured 108 | 109 | @classmethod 110 | def structure(self, params): 111 | class_params, numeric_params = params 112 | self_class, scale_classes = class_params 113 | self_numeric, scale_numeric = numeric_params 114 | scale = scale_classes[0].structure((scale_classes, scale_numeric)) 115 | return self_class( 116 | loc=self_numeric[0], s=self_numeric[1], scale=scale, normalized=True 117 | ) 118 | 119 | def destructure(self): 120 | scale_classes, scale_numeric = self.scale.destructure() 121 | class_params = (self.__class__, scale_classes) 122 | self_numeric = (self.loc, self.s) 123 | numeric_params = (self_numeric, scale_numeric) 124 | return (class_params, numeric_params) 125 | -------------------------------------------------------------------------------- /ergo/distributions/logistic_mixture.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Sequence 3 | 4 | from jax import nn, scipy 5 | import jax.numpy as np 6 | import numpy as onp 7 | import scipy as oscipy 8 | 9 | from ergo.scale import Scale 10 | 11 | from .base import categorical 12 | from .distribution import Distribution 13 | from .logistic import Logistic 14 | from .optimizable import Optimizable 15 | from .truncate import Truncate 16 | 17 | 18 | @dataclass 19 | class LogisticMixture(Distribution, Optimizable): 20 | """ 21 | Mixture of logistic distributions as used by Metaculus 22 | 23 | Metaculus mixture weights apply to the truncated renormalized 24 | logistics, not to the non-truncated prior. 25 | """ 26 | 27 | components: Sequence[Logistic] 28 | probs: Sequence[float] 29 | 30 | # Distribution 31 | 32 | def pdf(self, x): 33 | return np.sum([c.pdf(x) * p for (c, p) in zip(self.components, self.probs)]) 34 | 35 | def logpdf(self, x): 36 | scores = [] 37 | for (c, p) in zip(self.components, self.probs): 38 | scores.append(c.logpdf(x) + np.log(p)) 39 | return scipy.special.logsumexp(np.array(scores)) 40 | 41 | def cdf(self, x): 42 | return np.sum([c.cdf(x) * p for (c, p) in zip(self.components, self.probs)]) 43 | 44 | def ppf(self, q): 45 | """ 46 | Percent point function (inverse of cdf) at q. 
47 | 48 | Returns the smallest x where the mixture_cdf(x) is greater 49 | than the requested q provided: 50 | 51 | argmin{x} where mixture_cdf(x) > q 52 | 53 | The quantile of a mixture distribution can always be found 54 | within the range of its components quantiles: 55 | https://cran.r-project.org/web/packages/mistr/vignettes/mistr-introduction.pdf 56 | """ 57 | if len(self.components) == 1: 58 | return self.components[0].ppf(q) 59 | ppfs = [c.ppf(q) for c in self.components] 60 | cmin = np.min(ppfs) 61 | cmax = np.max(ppfs) 62 | 63 | return oscipy.optimize.bisect( 64 | lambda x: self.cdf(x) - q, 65 | cmin - abs(cmin / 100), 66 | cmax + abs(cmax / 100), 67 | maxiter=1000, 68 | ) 69 | 70 | def sample(self): 71 | i = categorical(np.array(self.probs)) 72 | component_dist = self.components[i] 73 | return component_dist.sample() 74 | 75 | # Scaled 76 | 77 | @property 78 | def scale(self): 79 | # We require that all scales are the same 80 | return self.components[0].scale 81 | 82 | def normalize(self): 83 | normed_components = [c.normalize() for c in self.components] 84 | return self.__class__(components=normed_components, probs=self.probs,) 85 | 86 | def denormalize(self, scale: Scale): 87 | denormed_components = [c.denormalize(scale) for c in self.components] 88 | return self.__class__(components=denormed_components, probs=self.probs,) 89 | 90 | # Structured 91 | 92 | @classmethod 93 | def structure(cls, params): 94 | (class_params, numeric_params) = params 95 | (mixture_class, component_classes) = class_params 96 | (mixture_params, component_params) = numeric_params 97 | (probs,) = mixture_params 98 | components = [ 99 | c_classes[0].structure((c_classes, c_params)) 100 | for (c_classes, c_params) in zip(component_classes, component_params) 101 | ] 102 | mixture_class = class_params[0] 103 | mixture = mixture_class(components=components, probs=probs,) 104 | return mixture 105 | 106 | def destructure(self): 107 | component_classes, component_numeric = zip( 108 | *[c.destructure() for c in self.components] 109 | ) 110 | self_numeric = (self.probs,) 111 | class_params = (self.__class__, component_classes) 112 | numeric_params = (self_numeric, component_numeric) 113 | return (class_params, numeric_params) 114 | 115 | # Optimizable 116 | 117 | @classmethod 118 | def from_params( 119 | cls, fixed_params, opt_params, scale=None, traceable=True 120 | ): # FIXME: traceable; why sometimes no Scale? 
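        # opt_params is a flat vector, reshaped below into one row of three raw values
        # per mixture component:
        #   column 0 -> raw location, squashed by expit into [loc_min, loc_max]
        #   column 1 -> raw scale,    squashed by expit into [s_min, s_max]
        #   column 2 -> raw weight,   passed through softmax with a 0.01 floor per prob
        # Each (loc, s) pair becomes a normalized Logistic, which is then truncated to
        # [floor, ceiling] before being bundled into the mixture.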
121 | if not scale: 122 | scale = Scale(0.0, 1.0) 123 | floor = fixed_params.get("floor", -np.inf) 124 | ceiling = fixed_params.get("ceiling", np.inf) 125 | # Allow logistic center to exceed the range by 20% 126 | loc_min = np.maximum(scale.low, floor) - 0.2 * scale.width 127 | loc_max = np.minimum(scale.high, ceiling) + 0.2 * scale.width 128 | loc_range = loc_max - loc_min 129 | structured_params = opt_params.reshape((-1, 3)) 130 | locs = loc_min + scipy.special.expit(structured_params[:, 0]) * loc_range 131 | # Allow logistic scales between 0.01 and 0.5 132 | # Don't allow tiny scales outside of the visible range 133 | s_min = 0.01 + 0.1 * np.where( 134 | (locs < scale.low), 135 | scale.low - locs, 136 | np.where(locs > scale.high, locs - scale.high, 0.0), 137 | ) 138 | s_max = 0.5 139 | s_range = s_max - s_min 140 | ss = s_min + scipy.special.expit(structured_params[:, 1]) * s_range 141 | # Allow probs > 0.01 142 | probs = list( 143 | 0.01 144 | + nn.softmax(structured_params[:, 2]) 145 | * (1 - 0.01 * structured_params[:, 2].size) 146 | ) 147 | # Bundle up components 148 | component_logistics = [ 149 | Logistic(l, s, scale, normalized=True) for (l, s) in zip(locs, ss) 150 | ] 151 | components = [ 152 | Truncate(base_dist=cl, floor=floor, ceiling=ceiling) 153 | for cl in component_logistics 154 | ] 155 | mixture = cls(components=components, probs=probs) 156 | return mixture 157 | 158 | @staticmethod 159 | def initialize_optimizable_params(fixed_params): 160 | num_components = fixed_params["num_components"] 161 | loc_multiplier = 3 162 | s_multiplier = 1.5 163 | locs = (onp.random.rand(num_components) - 0.5) * loc_multiplier 164 | scales = (onp.random.rand(num_components) - 0.5) * s_multiplier 165 | weights = onp.random.rand(num_components) 166 | components = onp.stack([locs, scales, weights]).transpose() 167 | return components.reshape(-1) 168 | 169 | @classmethod 170 | def normalize_fixed_params(self, fixed_params, scale): 171 | normed_fixed_params = dict(fixed_params) 172 | normed_fixed_params["floor"] = scale.normalize_point( 173 | fixed_params.get("floor", -np.inf) 174 | ) 175 | normed_fixed_params["ceiling"] = scale.normalize_point( 176 | fixed_params.get("ceiling", np.inf) 177 | ) 178 | return normed_fixed_params 179 | 180 | @classmethod 181 | def from_conditions(cls, *args, init_tries=100, opt_tries=2, **kwargs): 182 | # Increase default initialization and optimization tries 183 | return super(LogisticMixture, cls).from_conditions( 184 | *args, init_tries=init_tries, opt_tries=opt_tries, **kwargs 185 | ) 186 | 187 | @classmethod 188 | def from_samples(cls, *args, init_tries=100, opt_tries=2, **kwargs): 189 | # Increase default initialization and optimization tries 190 | return super(LogisticMixture, cls).from_samples( 191 | *args, init_tries=init_tries, opt_tries=opt_tries, **kwargs 192 | ) 193 | -------------------------------------------------------------------------------- /ergo/distributions/optimizable.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Sequence, Type, TypeVar 3 | 4 | import jax.numpy as np 5 | import numpy as onp 6 | 7 | from ergo.conditions import Condition 8 | from ergo.scale import Scale 9 | import ergo.static as static 10 | from ergo.utils import minimize 11 | 12 | T = TypeVar("T", bound="Optimizable") 13 | 14 | 15 | class Optimizable(ABC): 16 | @classmethod 17 | @abstractmethod 18 | def from_params(cls, fixed_params, opt_params, traceable=True): 19 | ... 
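    # Example (sketch): how an Optimizable subclass is typically fit from data via
    # from_samples() defined below. Uses the LogisticMixture subclass from this
    # package; the sample values and number of components are illustrative only.
    #
    #     from ergo.distributions import LogisticMixture
    #
    #     samples = [1.2, 2.1, 2.4, 2.9, 3.4]
    #     dist = LogisticMixture.from_samples(samples, fixed_params={"num_components": 2})
    #     dist.pdf(2.5)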
20 | 21 | @staticmethod 22 | @abstractmethod 23 | def initialize_optimizable_params(fixed_params): 24 | ... 25 | 26 | @abstractmethod 27 | def normalize(self): 28 | ... 29 | 30 | @abstractmethod 31 | def denormalize(self, scale: Scale): 32 | ... 33 | 34 | @classmethod 35 | def from_samples( 36 | cls: Type[T], 37 | data, 38 | fixed_params=None, 39 | scale=None, 40 | verbose=False, 41 | init_tries=1, 42 | opt_tries=1, 43 | ) -> T: 44 | if fixed_params is None: 45 | fixed_params = {} 46 | data = np.array(data) 47 | if scale is None: 48 | data_range = max(data) - min(data) 49 | scale = Scale( 50 | low=min(data) - 0.25 * data_range, high=max(data) + 0.25 * data_range, 51 | ) 52 | 53 | fixed_params = cls.normalize_fixed_params(fixed_params, scale) 54 | normalized_data = np.array(scale.normalize_points(data)) 55 | 56 | def loss(opt_params): 57 | return static.dist_logloss(cls, fixed_params, opt_params, normalized_data) 58 | 59 | def jac(opt_params): 60 | return static.dist_grad_logloss( 61 | cls, fixed_params, opt_params, normalized_data 62 | ) 63 | 64 | normalized_dist = cls.from_loss( 65 | loss, 66 | jac, 67 | fixed_params=fixed_params, 68 | verbose=verbose, 69 | init_tries=init_tries, 70 | opt_tries=opt_tries, 71 | ) 72 | 73 | return normalized_dist.denormalize(scale) 74 | 75 | @classmethod 76 | def from_conditions( 77 | cls: Type[T], 78 | conditions: Sequence[Condition], 79 | fixed_params=None, 80 | scale=None, 81 | verbose=False, 82 | init_tries=1, 83 | opt_tries=1, 84 | jit_all=False, 85 | ) -> T: 86 | 87 | if fixed_params is None: 88 | fixed_params = {} 89 | 90 | if scale is None: 91 | scale = Scale(0.0, 1.0) 92 | 93 | fixed_params = cls.normalize_fixed_params(fixed_params, scale) 94 | normalized_conditions = [condition.normalize(scale) for condition in conditions] 95 | cond_data = [condition.destructure() for condition in normalized_conditions] 96 | if cond_data: 97 | cond_classes, cond_params = zip(*cond_data) 98 | else: 99 | cond_classes, cond_params = [], [] 100 | 101 | if jit_all: 102 | jitted_loss = static.jitted_condition_loss 103 | jitted_jac = static.jitted_condition_loss_grad 104 | else: 105 | jitted_loss = static.condition_loss 106 | jitted_jac = static.condition_loss_grad 107 | 108 | def loss(opt_params): 109 | return jitted_loss(cls, fixed_params, opt_params, cond_classes, cond_params) 110 | 111 | def jac(opt_params): 112 | return jitted_jac(cls, fixed_params, opt_params, cond_classes, cond_params) 113 | 114 | normalized_dist = cls.from_loss( 115 | fixed_params=fixed_params, 116 | loss=loss, 117 | jac=jac, 118 | verbose=verbose, 119 | init_tries=init_tries, 120 | opt_tries=opt_tries, 121 | ) 122 | 123 | return normalized_dist.denormalize(scale) 124 | 125 | @classmethod 126 | def from_loss( 127 | cls: Type[T], 128 | loss, 129 | jac, 130 | fixed_params=None, 131 | verbose=False, 132 | init_tries=1, 133 | opt_tries=1, 134 | ) -> T: 135 | 136 | # fixed_params are assumed to be normalized 137 | 138 | if fixed_params is None: 139 | fixed_params = {} 140 | 141 | onp.random.seed(0) 142 | 143 | init = lambda: cls.initialize_optimizable_params(fixed_params) # noqa: E731 144 | 145 | fit_results = minimize( 146 | loss, 147 | init=init, 148 | jac=jac, 149 | init_tries=init_tries, 150 | opt_tries=opt_tries, 151 | verbose=verbose, 152 | ) 153 | if not fit_results.success and verbose: 154 | print(fit_results) 155 | optimized_params = fit_results.x 156 | 157 | return cls.from_params(fixed_params, optimized_params) 158 | 159 | @classmethod 160 | def normalize_fixed_params(self, fixed_params, 
scale): 161 | # They are not normalized unless a child class implements this method 162 | return fixed_params 163 | -------------------------------------------------------------------------------- /ergo/distributions/truncate.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import jax.numpy as np 4 | 5 | from ergo.scale import Scale 6 | 7 | from .distribution import Distribution 8 | 9 | 10 | @dataclass 11 | class Truncate(Distribution): 12 | base_dist: Distribution 13 | floor: float = -np.inf # true scale 14 | ceiling: float = np.inf # true scale 15 | 16 | def __post_init__(self): 17 | # https://github.com/tensorflow/probability/blob/master/discussion/where-nan.pdf 18 | self.p_below = np.where( 19 | self.floor == -np.inf, 20 | 0, 21 | self.base_dist.cdf(np.where(self.floor == -np.inf, 0, self.floor)), 22 | ) 23 | self.p_above = np.where( 24 | self.ceiling == np.inf, 25 | 0, 26 | 1.0 - self.base_dist.cdf(np.where(self.ceiling == np.inf, 1, self.ceiling)), 27 | ) 28 | self.p_inside = 1.0 - (self.p_below + self.p_above) 29 | 30 | # Distribution 31 | 32 | def pdf(self, x): 33 | p_x = np.exp(self.base_dist.logpdf(x) - np.log(self.p_inside)) 34 | return np.where(x < self.floor, 0.0, np.where(x > self.ceiling, 0.0, p_x),) 35 | 36 | def logpdf(self, x): 37 | logp_x = self.base_dist.logpdf(x) - np.log(self.p_inside) 38 | return np.where( 39 | x < self.floor, -np.inf, np.where(x > self.ceiling, -np.inf, logp_x) 40 | ) 41 | 42 | def cdf(self, x): 43 | c_x = (self.base_dist.cdf(x) - self.p_below) / self.p_inside 44 | return np.where(x < self.floor, 0.0, np.where(x > self.ceiling, 1.0, c_x)) 45 | 46 | def ppf(self, q): 47 | """ 48 | Percent point function (inverse of cdf) at q. 49 | """ 50 | return self.base_dist.ppf(self.p_below + q * self.p_inside) 51 | 52 | def sample(self): 53 | success = False 54 | while not success: 55 | s = self.base_dist.sample() 56 | if s > self.floor and s < self.ceiling: 57 | success = True 58 | return s 59 | 60 | # Scaled 61 | 62 | @property 63 | def scale(self): 64 | return self.base_dist.scale 65 | 66 | def normalize(self): 67 | normed_base_dist = self.base_dist.normalize() 68 | normed_floor = self.scale.normalize_point(self.floor) 69 | normed_ceiling = self.scale.normalize_point(self.ceiling) 70 | return self.__class__( 71 | base_dist=normed_base_dist, floor=normed_floor, ceiling=normed_ceiling, 72 | ) 73 | 74 | def denormalize(self, scale: Scale): 75 | denormed_base_dist = self.base_dist.denormalize(scale) 76 | denormed_floor = scale.denormalize_point(self.floor) 77 | denormed_ceiling = scale.denormalize_point(self.ceiling) 78 | return self.__class__( 79 | base_dist=denormed_base_dist, 80 | floor=denormed_floor, 81 | ceiling=denormed_ceiling, 82 | ) 83 | 84 | # Structured 85 | 86 | @classmethod 87 | def structure(self, params): 88 | class_params, numeric_params = params 89 | (self_class, base_classes) = class_params 90 | (self_numeric, base_numeric) = numeric_params 91 | base_dist = base_classes[0].structure((base_classes, base_numeric)) 92 | return self_class( 93 | base_dist=base_dist, floor=self_numeric[0], ceiling=self_numeric[1], 94 | ) 95 | 96 | def destructure(self): 97 | self_class, self_numeric = ( 98 | self.__class__, 99 | (self.floor, self.ceiling), 100 | ) 101 | base_classes, base_numeric = self.base_dist.destructure() 102 | class_params = (self_class, base_classes) 103 | numeric_params = (self_numeric, base_numeric) 104 | return (class_params, numeric_params) 105 | 
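# Example (sketch): truncating a logistic distribution to a closed range. A hedged
# illustration only; the parameter values are made up, and the exact densities depend
# on how much mass the base distribution places outside [floor, ceiling].
#
#     from ergo.distributions import Logistic, Truncate
#     from ergo.scale import Scale
#
#     base = Logistic(loc=5.0, s=2.0, scale=Scale(0, 10))
#     truncated = Truncate(base_dist=base, floor=1.0, ceiling=9.0)
#     truncated.pdf(0.5)   # 0.0 -- below the floor
#     truncated.pdf(5.0)   # exceeds base.pdf(5.0), since the kept mass is renormalized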
-------------------------------------------------------------------------------- /ergo/platforms/__init__.py: -------------------------------------------------------------------------------- 1 | from .foretold import Foretold, ForetoldQuestion 2 | from .metaculus import Metaculus, MetaculusQuestion 3 | from .predictit import PredictIt, PredictItMarket, PredictItQuestion 4 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/__init__.py: -------------------------------------------------------------------------------- 1 | from .metaculus import Metaculus 2 | from .question import MetaculusQuestion 3 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/__init__.py: -------------------------------------------------------------------------------- 1 | from .binary import BinaryQuestion 2 | from .continuous import ContinuousQuestion 3 | from .linear import LinearQuestion 4 | from .lineardate import LinearDateQuestion 5 | from .log import LogQuestion 6 | from .question import MetaculusQuestion 7 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/binary.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from datetime import datetime 3 | from typing import Any 4 | 5 | import requests 6 | 7 | from ergo.distributions.base import flip 8 | 9 | from .question import MetaculusQuestion 10 | 11 | 12 | @dataclass 13 | class ScoredPrediction: 14 | """ 15 | A prediction scored according to how it resolved or 16 | according to the current community prediction 17 | """ 18 | 19 | time: float 20 | prediction: Any 21 | resolution: float 22 | score: float 23 | question_name: str 24 | 25 | 26 | class BinaryQuestion(MetaculusQuestion): 27 | """ 28 | A binary Metaculus question -- how likely is this event to happen, from 0 to 1? 29 | """ 30 | 31 | def score_prediction(self, prediction, resolution: float) -> ScoredPrediction: 32 | """ 33 | Score a prediction relative to a resolution using a Brier Score. 34 | 35 | :param prediction: how likely is the event to happen, from 0 to 1? 36 | :param resolution: how likely is the event to happen, from 0 to 1? 
37 | (0 if it didn't, 1 if it did) 38 | :return: ScoredPrediction with Brier score, see 39 | https://en.wikipedia.org/wiki/Brier_score#Definition 40 | 0 is best, 1 is worst, 0.25 is chance 41 | """ 42 | predicted = prediction["x"] 43 | score = (resolution - predicted) ** 2 44 | return ScoredPrediction( 45 | prediction["t"], prediction, resolution, score, self.__str__() 46 | ) 47 | 48 | def change_since(self, since: datetime): 49 | """ 50 | Calculate change in community prediction between the argument and most recent 51 | prediction 52 | 53 | :param since: datetime 54 | :return: change in community prediction since datetime 55 | """ 56 | try: 57 | old = self.get_community_prediction(before=since) 58 | new = self.get_community_prediction() 59 | except LookupError: 60 | # Happens if no prediction predates since or no prediction yet 61 | return 0 62 | 63 | return new - old 64 | 65 | def score_my_predictions(self): 66 | """ 67 | Score all of my predictions according to the question resolution 68 | (or according to the current community prediction if the resolution 69 | isn't available) 70 | 71 | :return: List of ScoredPredictions with Brier scores 72 | """ 73 | resolution = self.resolution 74 | if resolution is None: 75 | last_community_prediction = self.prediction_timeseries[-1] 76 | resolution = last_community_prediction["distribution"]["avg"] 77 | predictions = self.my_predictions["predictions"] 78 | return [ 79 | self.score_prediction(prediction, resolution) for prediction in predictions 80 | ] 81 | 82 | def submit(self, p: float) -> requests.Response: 83 | """ 84 | Submit a prediction to my Metaculus account 85 | 86 | :param p: how likely is the event to happen, from 0 to 1? 87 | """ 88 | return self.metaculus.predict(self.id, {"prediction": p, "void": False},) 89 | 90 | def sample_community(self) -> bool: 91 | """ 92 | Sample from the Metaculus community distribution (Bernoulli). 93 | """ 94 | community_prediction = self.get_community_prediction() 95 | return flip(community_prediction) 96 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/constants.py: -------------------------------------------------------------------------------- 1 | # Max loc of 3 set based on API response to prediction on 2 | # https://pandemic.metaculus.com/questions/3920/what-will-the-cbo-estimate-to-be-the-cost-of-the-emergency-telework-act-s3561/ 3 | max_loc = 3 4 | 5 | # Min loc set based on API response to prediction on 6 | # https://www.metaculus.com/questions/3992/ 7 | min_loc = -2 8 | 9 | # Max scale of 10 set based on API response to prediction on 10 | # https://pandemic.metaculus.com/questions/3920/what-will-the-cbo-estimate-to-be-the-cost-of-the-emergency-telework-act-s3561/ 11 | min_scale = 0.01 12 | max_scale = 10 13 | 14 | # We're not really sure what the deal with the low and high is. 15 | # Presumably they're supposed to be the points at which Metaculus "cuts off" 16 | # your distribution and ignores porbability mass assigned below/above. 
17 | # But we're not actually trying to use them to "cut off" our distribution 18 | # in a smart way; we're just trying to include as much of our distribution 19 | # as we can without the API getting unhappy 20 | # (we believe that if you set the low higher than the value below 21 | # [or if you set the high lower], then the API will reject the prediction, 22 | # though we haven't tested that extensively) 23 | min_open_low = 0.01 24 | max_open_low = 0.98 25 | 26 | # Min high of (low + 0.01) set based on API response for 27 | # https://www.metaculus.com/api2/questions/3961/predict/ -- 28 | # {'prediction': ['high minus low must be at least 0.01']}" 29 | min_open_high = 0.01 30 | max_open_high = 0.99 31 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/linear.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict 3 | 4 | from ergo.distributions import Logistic, LogisticMixture 5 | from ergo.scale import Scale 6 | 7 | from .continuous import ContinuousQuestion 8 | 9 | 10 | @dataclass 11 | class LinearQuestion(ContinuousQuestion): 12 | """ 13 | A continuous Metaculus question that's on a linear (as opposed to a log) scale" 14 | """ 15 | 16 | scale: Scale 17 | 18 | def __init__( 19 | self, id: int, metaculus: Any, data: Dict, name=None, 20 | ): 21 | super().__init__(id, metaculus, data, name) 22 | self.scale = Scale( 23 | float(self.question_range["min"]), float(self.question_range["max"]) 24 | ) 25 | 26 | # TODO: also return low and high on the true scale, 27 | # and use those somehow in logistic.py 28 | def get_true_scale_logistic(self, normalized_dist: Logistic) -> Logistic: 29 | """ 30 | Convert a normalized logistic distribution to a logistic on 31 | the true scale of the question. 32 | 33 | :param normalized_dist: normalized logistic distribution 34 | :return: logistic distribution on the true scale of the question 35 | """ 36 | 37 | return normalized_dist.denormalize(self.scale) 38 | 39 | def get_true_scale_mixture( 40 | self, normalized_dist: LogisticMixture 41 | ) -> LogisticMixture: 42 | """ 43 | Convert a normalized logistic mixture distribution to a 44 | logistic on the true scale of the question. 
45 | 46 | :param normalized_dist: normalized logistic mixture dist 47 | :return: same distribution rescaled to the true scale of the question 48 | """ 49 | 50 | return normalized_dist.denormalize(self.scale) 51 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/lineardate.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Any, Dict 3 | 4 | import pandas as pd 5 | from plotnine import ( 6 | aes, 7 | element_text, 8 | facet_wrap, 9 | geom_histogram, 10 | ggplot, 11 | guides, 12 | scale_fill_brewer, 13 | scale_x_datetime, 14 | theme, 15 | ) 16 | 17 | from ergo.scale import TimeScale 18 | from ergo.theme import ergo_theme 19 | 20 | from .linear import ContinuousQuestion 21 | 22 | 23 | class LinearDateQuestion(ContinuousQuestion): 24 | scale: TimeScale 25 | 26 | def __init__( 27 | self, id: int, metaculus: Any, data: Dict, name=None, 28 | ): 29 | super().__init__(id, metaculus, data, name) 30 | self.scale = TimeScale( 31 | self.date_to_timestamp(self.possibilities["scale"]["min"]), 32 | self.date_to_timestamp(self.possibilities["scale"]["max"]), 33 | ) 34 | 35 | def _scale_x(self, xmin: float = None, xmax: float = None): 36 | return scale_x_datetime(limits=(xmin, xmax)) 37 | 38 | def date_to_timestamp(self, date: str): 39 | """ 40 | Turn a date string in %Y-%m-%d format into a timestamp. Metaculus 41 | uses this format for dates when specifying the range of a date question. 42 | We're assuming Metaculus is interpreting these date strings as UTC. 43 | 44 | :return: A Unix timestamp 45 | """ 46 | dt = datetime.datetime.strptime(date, "%Y-%m-%d") 47 | # To obtain UTC timestamp from datetime, used method described here: 48 | # https://docs.python.org/3/library/datetime.html#datetime.datetime.timestamp 49 | return dt.replace(tzinfo=datetime.timezone.utc).timestamp() 50 | 51 | # TODO enforce return type date/datetime 52 | def sample_community(self): 53 | """ 54 | Sample an approximation of the entire current community prediction, 55 | on the true scale of the question. 
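
# Worked example of the date handling in date_to_timestamp above (a standalone
# sketch that mirrors the method rather than calling it): parsing "2020-06-01"
# as a UTC date gives the Unix timestamp 1590969600.
import datetime

dt = datetime.datetime.strptime("2020-06-01", "%Y-%m-%d")
timestamp = dt.replace(tzinfo=datetime.timezone.utc).timestamp()
assert timestamp == 1590969600.0
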
56 | 57 | :return: One sample on the true scale 58 | """ 59 | normalized_sample = self.sample_normalized_community() 60 | return self.denormalize_samples(normalized_sample) 61 | 62 | def comparison_plot( # type: ignore 63 | self, df: pd.DataFrame, xmin=None, xmax=None, bins: int = 50, **kwargs 64 | ): 65 | 66 | return ( 67 | ggplot(df, aes(df.columns[1], fill=df.columns[0])) 68 | + scale_fill_brewer(type="qual", palette="Pastel1") 69 | + geom_histogram(position="identity", alpha=0.9, bins=bins) 70 | + self._scale_x(xmin, xmax) 71 | + facet_wrap(df.columns[0], ncol=1) 72 | + guides(fill=False) 73 | + ergo_theme 74 | + theme(axis_text_x=element_text(rotation=45, hjust=1)) 75 | ) 76 | 77 | def density_plot( # type: ignore 78 | self, 79 | df: pd.DataFrame, 80 | xmin=None, 81 | xmax=None, 82 | fill: str = "#fbb4ae", 83 | bins: int = 50, 84 | **kwargs, 85 | ): 86 | 87 | return ( 88 | ggplot(df, aes(df.columns[0])) 89 | + geom_histogram(fill=fill, bins=bins) 90 | + self._scale_x(xmin, xmax) 91 | + ergo_theme 92 | + theme(axis_text_x=element_text(rotation=45, hjust=1)) 93 | ) 94 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/log.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict 3 | 4 | from plotnine import scale_x_log10 5 | 6 | from ergo.scale import LogScale 7 | 8 | from .continuous import ContinuousQuestion 9 | 10 | 11 | @dataclass 12 | class LogQuestion(ContinuousQuestion): 13 | 14 | scale: LogScale 15 | 16 | def __init__( 17 | self, id: int, metaculus: Any, data: Dict, name=None, 18 | ): 19 | super().__init__(id, metaculus, data, name) 20 | self.scale = LogScale( 21 | float(self.question_range["min"]), 22 | float(self.question_range["max"]), 23 | float(self.possibilities["scale"]["deriv_ratio"]), 24 | ) 25 | 26 | def _scale_x(self, xmin: float = None, xmax: float = None): 27 | return scale_x_log10(limits=(xmin, xmax)) 28 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/question.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | from datetime import datetime 3 | from typing import Any, Dict, List, Optional 4 | 5 | import jax.numpy as np 6 | import pandas as pd 7 | 8 | import ergo.distributions as dist 9 | 10 | from .types import ArrayLikeType 11 | 12 | 13 | class MetaculusQuestion: 14 | """ 15 | A forecasting question on Metaculus 16 | 17 | :param id: Question id 18 | :param metaculus: Metaculus API instance 19 | :param data: Question JSON retrieved from Metaculus API 20 | :param name: Name to assign to question (used in models) 21 | 22 | :ivar activity: 23 | :ivar anon_prediction_count: 24 | :ivar author: 25 | :ivar author_name: 26 | :ivar can_use_powers: 27 | :ivar close_time: when the question closes 28 | :ivar comment_count: 29 | :ivar created_time: when the question was created 30 | :ivar id: question id 31 | :ivar is_continuous: is the question continuous or binary? 
32 | :ivar last_activity_time: 33 | :ivar page_url: url for the question page on Metaculus 34 | :ivar possibilities: 35 | :ivar prediction_histogram: histogram of the current community prediction 36 | :ivar prediction_timeseries: predictions on this question over time 37 | :ivar publish_time: when the question was published 38 | :ivar resolution: 39 | :ivar resolve_time: when the question will resolve 40 | :ivar status: 41 | :ivar title: 42 | :ivar type: 43 | :ivar url: 44 | :ivar votes: 45 | """ 46 | 47 | id: int 48 | data: Dict 49 | metaculus: Any 50 | name: Optional[str] 51 | 52 | def __init__( 53 | self, id: int, metaculus: Any, data: Dict, name=None, 54 | ): 55 | """ 56 | :param id: question id on Metaculus 57 | :param metaculus: Metaculus class instance, specifies which user to use for 58 | e.g. submitting predictions 59 | :param data: information about the question, 60 | e.g. as returned by the Metaculus API 61 | :param name: name for the question to be 62 | e.g. used in graph titles, defaults to None 63 | """ 64 | self.id = id 65 | self.data = data 66 | self.metaculus = metaculus 67 | self.name = name 68 | 69 | @property 70 | def question_url(self): 71 | return f"https://{self.metaculus.api_domain}.metaculus.com/questions/{self.id}" 72 | 73 | def __repr__(self): 74 | if self.name: 75 | return f'' 76 | elif self.data: 77 | return f'' 78 | else: 79 | return "" 80 | 81 | def __str__(self): 82 | return repr(self) 83 | 84 | def __getattr__(self, name): 85 | """ 86 | If an attribute isn't directly on the class, check whether it's in the 87 | raw question data. If it's a time, format it appropriately. 88 | 89 | :param name: attr name 90 | :return: attr value 91 | """ 92 | if name in self.data: 93 | if name.endswith("_time"): 94 | # could use dateutil.parser to deal with timezones better, 95 | # but opted for lightweight since datetime.fromisoformat 96 | # will fix this in python 3.7 97 | try: 98 | # attempt to parse with microseconds 99 | return datetime.strptime(self.data[name], "%Y-%m-%dT%H:%M:%S.%fZ") 100 | except ValueError: 101 | try: 102 | # attempt to parse without microseconds 103 | return datetime.strptime(self.data[name], "%Y-%m-%dT%H:%M:%SZ") 104 | except ValueError: 105 | print( 106 | f"The column {name} could not be converted into a datetime" 107 | ) 108 | return self.data[name] 109 | 110 | return self.data[name] 111 | else: 112 | raise AttributeError( 113 | f"Attribute {name} is neither directly on this class nor in the raw question data" 114 | ) 115 | 116 | def set_data(self, key: str, value: Any): 117 | """ 118 | Set key on data dict 119 | 120 | :param key: 121 | :param value: 122 | """ 123 | self.data[key] = value 124 | 125 | def refresh_question(self): 126 | """ 127 | Refetch the question data from Metaculus, 128 | used when the question data might have changed 129 | """ 130 | r = self.metaculus.s.get(f"{self.metaculus.api_url}/questions/{self.id}") 131 | self.data = r.json() 132 | 133 | def sample_community(self): 134 | """ 135 | Get one sample from the distribution of the Metaculus community's 136 | prediction on this question 137 | (sample is denormalized/on the the true scale of the question) 138 | """ 139 | raise NotImplementedError("This should be implemented by a subclass") 140 | 141 | def community_dist(self) -> dist.Distribution: 142 | raise NotImplementedError("This should be implemented by a subclass") 143 | 144 | @staticmethod 145 | def to_dataframe( 146 | questions: List["MetaculusQuestion"], 147 | columns: List[str] = ["id", "title", "resolve_time"], 148 | 
) -> pd.DataFrame: 149 | """ 150 | Summarize a list of questions in a dataframe 151 | 152 | :param questions: questions to summarize 153 | :param columns: list of column names as strings 154 | :return: pandas dataframe summarizing the questions 155 | """ 156 | 157 | data = [ 158 | [question.name if key == "name" else question.data[key] for key in columns] 159 | for question in questions 160 | ] 161 | 162 | return pd.DataFrame(data, columns=columns) 163 | 164 | def get_community_prediction(self, before: datetime = None): 165 | if len(self.prediction_timeseries) == 0: 166 | raise LookupError # No community prediction exists yet 167 | 168 | if before is None: 169 | return self.prediction_timeseries[-1]["community_prediction"] 170 | 171 | i = bisect.bisect_left( 172 | [prediction["t"] for prediction in self.prediction_timeseries], 173 | before.timestamp(), 174 | ) 175 | 176 | if i == len(self.prediction_timeseries): # No prediction predates 177 | raise LookupError 178 | 179 | return self.prediction_timeseries[i]["community_prediction"] 180 | 181 | @staticmethod 182 | def get_central_quantiles( 183 | df: ArrayLikeType, percent_kept: float = 0.95, side_cut_from: str = "both", 184 | ): 185 | """ 186 | Get the values that bound the central (percent_kept) of the sample distribution, 187 | i.e., cutting the tails beyond these values leaves the central (percent_kept). 188 | If passed a dataframe with multiple variables, the bounds that encompass 189 | all variables will be returned. 190 | 191 | :param df: pandas dataframe of one or more columns of samples 192 | :param percent_kept: percentage of sample distribution to keep 193 | :param side_cut_from: which side to cut tails from, 194 | either 'both', 'lower', or 'upper' 195 | :return: lower and upper values of the central (percent_kept) of 196 | the sample distribution.
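
# Usage sketch for get_community_prediction above and the change_since helper
# defined on the binary question class earlier in this dump. The question id is
# a placeholder; the "oughtpublic" login is the demo account used in the example
# notebooks. For a binary question the community prediction is a single
# probability, so the change below is a simple difference in probability.
import datetime
import ergo

metaculus = ergo.Metaculus(username="oughtpublic", password="123456")
question = metaculus.get_question(12345)  # hypothetical binary question id

latest = question.get_community_prediction()
week_ago = datetime.datetime.now() - datetime.timedelta(days=7)
change = question.change_since(week_ago)  # 0 if no prediction predates week_ago
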
197 | """ 198 | 199 | if side_cut_from not in ("both", "lower", "upper"): 200 | raise ValueError("side keyword must be either 'both','lower', or 'upper'") 201 | 202 | percent_cut = 1 - percent_kept 203 | if side_cut_from == "lower": 204 | _lb = percent_cut 205 | _ub = 1.0 206 | elif side_cut_from == "upper": 207 | _lb = 0.0 208 | _ub = 1 - percent_cut 209 | else: 210 | _lb = percent_cut / 2 211 | _ub = 1 - percent_cut / 2 212 | 213 | if isinstance(df, (pd.Series, np.ndarray)): 214 | _lq, _uq = df.quantile([_lb, _ub]) # type: ignore 215 | return (_lq, _uq) 216 | 217 | _lqs = [] 218 | _uqs = [] 219 | for col in df: 220 | _lq, _uq = df[col].quantile([_lb, _ub]) 221 | _lqs.append(_lq) 222 | _uqs.append(_uq) 223 | return (min(_lqs), max(_uqs)) 224 | -------------------------------------------------------------------------------- /ergo/platforms/metaculus/question/types.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import jax.numpy as np 4 | import numpy as onp 5 | import pandas as pd 6 | 7 | ArrayLikes = [pd.DataFrame, pd.Series, np.ndarray, np.DeviceArray, onp.ndarray] 8 | 9 | ArrayLikeType = Union[pd.DataFrame, pd.Series, np.ndarray, np.DeviceArray, onp.ndarray] 10 | -------------------------------------------------------------------------------- /ergo/platforms/predictit.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module lets you get question and prediction information from PredictIt 3 | via the API (https://predictit.freshdesk.com/support/solutions/articles/12000001878) 4 | """ 5 | from typing import Dict, Generator, List 6 | 7 | from dateutil.parser import parse 8 | import pandas as pd 9 | import requests 10 | 11 | from ergo.distributions.base import flip 12 | 13 | 14 | class PredictItQuestion: 15 | """ 16 | A single binary question in a PredictIt market. 17 | 18 | :param market: PredictIt market instance 19 | :param data: Contract JSON retrieved from PredictIt API 20 | 21 | :ivar PredictItMarket market: PredictIt market instance 22 | :ivar int id: id of the contract 23 | :ivar datetime.datetime dateEnd: end-date of a market, usually None 24 | :ivar str image: url of the image resource for the contract 25 | :ivar str name: name of the contract 26 | :ivar str shortName: shortened name of the contract 27 | :ivar str status: status of the contract. Closed markets aren't included in the API, so always "Open" 28 | :ivar float lastTradePrice: last price the contract was traded at 29 | :ivar float bestBuyYesCost: cost to buy a single Yes share 30 | :ivar float bestBuyNoCost: cost to buy a single No share 31 | :ivar float bestSellYesCost: cost to sell a single Yes share 32 | :ivar float bestSellNoCost: cost to sell a single No share 33 | :ivar float lastClosePrice: price the contract closed at the previous day 34 | :ivar int displayOrder: position of the contract in PredictIt. Defaults to 0 if sorted by lastTradePrice 35 | """ 36 | 37 | def __init__(self, market: "PredictItMarket", data: Dict): 38 | self.market = market 39 | self._data = data 40 | 41 | def __repr__(self): 42 | return f'' 43 | 44 | def __getattr__(self, name: str): 45 | """ 46 | If an attribute isn't directly on the class, check whether it's in the 47 | raw contract data. If it's a time, format it appropriately. 
48 | 49 | :param name: 50 | :return: attribute value 51 | """ 52 | if name not in self._data: 53 | raise AttributeError( 54 | f"Attribute {name} is neither directly on this class nor in the raw question data" 55 | ) 56 | if name != "dateEnd": 57 | return self._data[name] 58 | dateEnd = self._data["dateEnd"] 59 | if dateEnd == "N/A": 60 | return None 61 | try: 62 | return parse(dateEnd) 63 | except ValueError: 64 | print(f"The column {name} could not be converted into a datetime") 65 | return dateEnd 66 | 67 | @staticmethod 68 | def to_dataframe( 69 | questions: List["PredictItQuestion"], columns=None, 70 | ) -> pd.DataFrame: 71 | """ 72 | Summarize a list of questions in a dataframe 73 | 74 | :param questions: questions to summarize 75 | :param columns: list of column names as strings 76 | :return: pandas dataframe summarizing the questions 77 | """ 78 | if columns is None: 79 | columns = ["id", "name", "dateEnd"] 80 | data = [[question._data[key] for key in columns] for question in questions] 81 | 82 | return pd.DataFrame(data, columns=columns) 83 | 84 | def refresh(self): 85 | """ 86 | Refetch the market data from PredictIt and reload the question. 87 | """ 88 | self.market.refresh() 89 | self._data = self.market.get_question(self.id)._data 90 | 91 | def sample_community(self) -> bool: 92 | """ 93 | Sample from the PredictIt community distribution (Bernoulli). 94 | 95 | :return: true/false 96 | """ 97 | return flip(self.get_community_prediction()) 98 | 99 | 100 | class PredictItMarket: 101 | """ 102 | A PredictIt market. 103 | 104 | :param predictit: PredictIt API instance 105 | :param data: Market JSON retrieved from PredictIt API 106 | 107 | :ivar PredictIt predictit: PredictIT API instance 108 | :ivar str api_url: url of the PredictIt API for the given question 109 | :ivar int id: id of the market 110 | :ivar str name: name of the market 111 | :ivar str shortName: shortened name of the market 112 | :ivar str image: url of the image resource of the market 113 | :ivar str url: url of the market in PredictIt 114 | :ivar str status: status of the market. Closed markets aren't included in the API, so always "Open" 115 | :ivar datetime.datetime timeStamp: last time the market was updated. 116 | The API updates every minute, but timestamp can be earlier if it hasn't been traded in 117 | """ 118 | 119 | def __init__(self, predictit: "PredictIt", data: Dict): 120 | self.predictit = predictit 121 | self._data = data 122 | self.api_url = f"{self.predictit.api_url}/markets/{self.id}/" 123 | 124 | def _get(self, url: str) -> requests.Response: 125 | """ 126 | Send a get request to the PredictIt API. 127 | 128 | :param url: 129 | :return: response 130 | """ 131 | r = self.predictit.s.get(url) 132 | if r.status_code == 429: 133 | raise requests.RequestException("Hit API rate limit") 134 | return r 135 | 136 | def __repr__(self): 137 | return f'' 138 | 139 | def __getattr__(self, name: str): 140 | """ 141 | If an attribute isn't directly on the class, check whether it's in the 142 | raw contract data. If it's a time, format it appropriately. 
143 | 144 | :param name: 145 | :return: attribute value 146 | """ 147 | if name not in self._data: 148 | raise AttributeError( 149 | f"Attribute {name} is neither directly on this class nor in the raw question data" 150 | ) 151 | if name != "timeStamp": 152 | return self._data[name] 153 | date_end = self._data["timeStamp"] 154 | if date_end == "N/A": 155 | return None 156 | try: 157 | return parse(date_end) 158 | except ValueError: 159 | print(f"The column {name} could not be converted into a datetime") 160 | return date_end 161 | 162 | @property 163 | def questions(self) -> Generator[PredictItQuestion, None, None]: 164 | """ 165 | Generate all of the questions in the market. 166 | 167 | :return: generator of questions in market 168 | """ 169 | for data in self._data["contracts"]: 170 | yield PredictItQuestion(self, data) 171 | 172 | def refresh(self): 173 | """ 174 | Refetch the market data from PredictIt, 175 | used when the question data might have changed. 176 | """ 177 | r = self._get(self.api_url) 178 | self._data = r.json() 179 | 180 | def get_question(self, id: int) -> PredictItQuestion: 181 | """ 182 | Return the specified question given by the id number. 183 | 184 | :param id: question id 185 | :return: question 186 | """ 187 | for question in self.questions: 188 | if question.id == id: 189 | return question 190 | raise ValueError("Unable to find a question with that id.") 191 | 192 | 193 | class PredictIt: 194 | """ 195 | The main class for interacting with PredictIt. 196 | """ 197 | 198 | def __init__(self): 199 | self.api_url = "https://www.predictit.org/api/marketdata" 200 | self.s = requests.Session() 201 | self._data = self._get(f"{self.api_url}/all/").json() 202 | 203 | def _get(self, url: str) -> requests.Response: 204 | """ 205 | Send a get request to the PredictIt API. 206 | 207 | :param url: 208 | :return: response 209 | """ 210 | r = self.s.get(url) 211 | if r.status_code == 429: 212 | raise requests.RequestException("Hit API rate limit") 213 | return r 214 | 215 | def refresh_markets(self): 216 | """ 217 | Refetch all of the markets from the PredictIt API. 218 | """ 219 | self._data = self._get(f"{self.api_url}/all/").json() 220 | 221 | @property 222 | def markets(self) -> Generator[PredictItMarket, None, None]: 223 | """ 224 | Generate all of the markets currently in PredictIt. 225 | 226 | :return: iterator of PredictIt markets 227 | """ 228 | for data in self._data["markets"]: 229 | yield PredictItMarket(self, data) 230 | 231 | def get_market(self, id: int) -> PredictItMarket: 232 | """ 233 | Return the PredictIt market with the given id. 234 | A market's id can be found in the url of the market. 235 | 236 | :param id: market id 237 | :return: market 238 | """ 239 | for data in self._data["markets"]: 240 | if data["id"] == id: 241 | return PredictItMarket(self, data) 242 | raise ValueError("Unable to find a market with that ID.") 243 | -------------------------------------------------------------------------------- /ergo/ppl.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a few lightweight wrappers around probabilistic 3 | programming primitives from Numpyro. 
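
# Usage sketch for the PredictIt classes defined above. The market id is a
# placeholder (real ids appear in the market's URL on predictit.org), and the
# calls below hit the live PredictIt API, so treat this as illustrative.
from ergo.platforms.predictit import PredictIt, PredictItQuestion

predictit = PredictIt()
market = predictit.get_market(7456)  # hypothetical market id
questions = list(market.questions)
df = PredictItQuestion.to_dataframe(questions)  # default columns: id, name, dateEnd
print(df)
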
4 | """ 5 | 6 | import functools 7 | from typing import Dict, List 8 | 9 | import jax 10 | import jax.numpy as np 11 | import numpyro 12 | import numpyro.distributions as dist 13 | from numpyro.primitives import Messenger 14 | import pandas as pd 15 | from tqdm.autonotebook import tqdm 16 | 17 | # Random numbers 18 | 19 | _RNG_KEY = jax.random.PRNGKey(0) 20 | 21 | 22 | def onetime_rng_key(): 23 | global _RNG_KEY 24 | current_key, _RNG_KEY = jax.random.split(_RNG_KEY, 2) 25 | return current_key 26 | 27 | 28 | # Automatic naming of sampling sites 29 | 30 | 31 | class autoname(Messenger): 32 | """ 33 | If multiple sampling sites have the same name, automatically append a number and 34 | increment it by 1 for each repeated occurence. 35 | """ 36 | 37 | def __enter__(self): 38 | self._names = set() 39 | super(autoname, self).__enter__() 40 | 41 | def _increment_name(self, name, label): 42 | while (name, label) in self._names: 43 | try: 44 | base, count_str = name.rsplit("__", maxsplit=1) 45 | count = int(count_str) + 1 46 | except ValueError: 47 | base, count = name, 1 48 | name = f"{base}__{count}" 49 | return name 50 | 51 | def process_message(self, msg): 52 | if msg["type"] == "sample": 53 | msg["name"] = self._increment_name(msg["name"], "sample") 54 | 55 | def postprocess_message(self, msg): 56 | if msg["type"] == "sample": 57 | self._names.add((msg["name"], "sample")) 58 | 59 | 60 | # Sampling from probability distributions 61 | 62 | 63 | def sample(dist: dist.Distribution, name: str = None, **kwargs): 64 | """ 65 | Sample from a primitive distribution 66 | 67 | :param dist: A Pyro distribution 68 | :param name: Name to assign to this sampling site in the execution trace 69 | :return: A sample from the distribution 70 | """ 71 | # If a value isn't explicitly named, generate an automatic name, 72 | # relying on autoname handler for uniqueness. 73 | if not name: 74 | name = "_v" 75 | # The rng key provided below is only used when no Numpyro seed handler 76 | # is provided. This happens when we sample from distributions outside 77 | # an inference context. 
78 | return numpyro.sample(name, dist, rng_key=onetime_rng_key(), **kwargs) 79 | 80 | 81 | # Conditioning 82 | 83 | 84 | def condition(cond: bool, name: str = None): 85 | if not name: 86 | name = "_c" 87 | return numpyro.factor(name, 0 if cond else np.NINF) 88 | 89 | 90 | # Record deterministic values in trace 91 | 92 | 93 | def tag(value, name: str): 94 | return numpyro.deterministic(name, value) 95 | 96 | 97 | # Automatically record model return value in trace 98 | 99 | 100 | def tag_output(model): 101 | def wrapped(): 102 | value = model() 103 | if value is not None: 104 | tag(value, "output") 105 | return value 106 | 107 | return wrapped 108 | 109 | 110 | # Memoization 111 | 112 | memoized_functions = [] # FIXME: global state 113 | 114 | 115 | def mem(func): 116 | func = functools.lru_cache(None)(func) 117 | memoized_functions.append(func) 118 | return func 119 | 120 | 121 | def clear_mem(): 122 | for func in memoized_functions: 123 | func.cache_clear() 124 | 125 | 126 | def handle_mem(model): 127 | def wrapped(*args, **kwargs): 128 | clear_mem() 129 | return model(*args, **kwargs) 130 | 131 | return wrapped 132 | 133 | 134 | # Main inference function 135 | 136 | 137 | def is_singleton_array(value): 138 | return isinstance(value, np.DeviceArray) and value.size in ((1,), 1) 139 | 140 | 141 | def is_factor(entry): 142 | return ( 143 | entry.get("is_observed") 144 | and entry.get("fn") 145 | and isinstance(entry["fn"], numpyro.distributions.Unit) 146 | ) 147 | 148 | 149 | def factor_score(entry): 150 | return entry["fn"].log_factor 151 | 152 | 153 | def run(model, num_samples=5000, ignore_untagged=True, rng_seed=0) -> pd.DataFrame: 154 | """ 155 | Run model forward, record samples for variables. Return dataframe 156 | with one row for each execution. 
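
# Sketch of how the memoization helpers above combine with run(): within a
# single execution of the model a memoized function returns the same draw every
# time it is called, and handle_mem (wired into run below) clears the cache
# before the next execution. Assumes flip, tag and run are re-exported on the
# ergo package, as in the example notebooks; mem is imported from this module.
import ergo
from ergo.ppl import mem

@mem
def coin():
    return ergo.flip(0.5)

def model():
    first = coin()
    second = coin()  # cached: identical to `first` within this execution
    ergo.tag(first == second, "same")

samples = ergo.run(model, num_samples=100)
# Every entry of samples["same"] should be True if memoization behaves as described.
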
157 | """ 158 | model = numpyro.handlers.trace(handle_mem(tag_output(autoname(model)))) 159 | with numpyro.handlers.seed(rng_seed=rng_seed): 160 | samples: List[Dict[str, float]] = [] 161 | progress_bar = tqdm(total=num_samples) 162 | progress_bar.update(0) 163 | i = 0 164 | while i < num_samples: 165 | sample: Dict[str, float] = {} 166 | trace = model.get_trace() 167 | reject = False 168 | for name in trace.keys(): 169 | entry = trace[name] 170 | if entry["type"] in ("sample", "deterministic"): 171 | if is_factor(entry): 172 | score = factor_score(entry) 173 | if score == np.NINF: 174 | reject = True 175 | break 176 | elif score == 0: 177 | pass 178 | else: 179 | raise NotImplementedError( 180 | f"Weighted factors - got score {score}" 181 | ) 182 | else: 183 | if ignore_untagged and name.startswith("_"): 184 | continue 185 | value = entry["value"] 186 | if is_singleton_array(value): 187 | value = value.item() # FIXME 188 | sample[name] = value 189 | if reject: 190 | continue 191 | samples.append(sample) 192 | i += 1 193 | progress_bar.update(1) 194 | progress_bar.close() 195 | 196 | return pd.DataFrame(samples) # type: ignore 197 | -------------------------------------------------------------------------------- /ergo/scale.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, dataclass, field 2 | from datetime import timedelta 3 | import time 4 | from typing import TypeVar 5 | 6 | import jax.numpy as np 7 | 8 | 9 | @dataclass 10 | class Scale: 11 | low: float 12 | high: float 13 | width: float = field(init=False) 14 | 15 | def __post_init__(self): 16 | self.width = self.high - self.low 17 | 18 | def __hash__(self): 19 | return hash(self.__key()) 20 | 21 | def __eq__(self, other): 22 | if isinstance(other, Scale): 23 | return self.__key() == other.__key() 24 | return NotImplemented 25 | 26 | def __key(self): 27 | cls, params = self.destructure() 28 | return (cls, params) 29 | 30 | def normalize_point(self, point): 31 | return (point - self.low) / self.width 32 | 33 | def denormalize_point(self, point): 34 | return (point * self.width) + self.low 35 | 36 | def denormalize_points(self, points): 37 | return self.denormalize_point(np.array(points)) 38 | 39 | def normalize_points(self, points): 40 | return self.normalize_point(np.array(points)) 41 | 42 | def normalize_variance(self, variance): 43 | if variance is None: 44 | raise Exception("Point was None This shouldn't happen") 45 | return variance / (self.width ** 2) 46 | 47 | def denormalize_variance(self, variance): 48 | if variance is None: 49 | raise Exception("Point was None This shouldn't happen") 50 | return variance * (self.width ** 2) 51 | 52 | def normalize_density(self, _, density): 53 | return density * self.width 54 | 55 | def denormalize_density(self, _, density): 56 | return density / self.width 57 | 58 | def normalize_densities(self, _, densities): 59 | return densities * self.width 60 | 61 | def denormalize_densities(self, _, densities): 62 | return densities / self.width 63 | 64 | def destructure(self): 65 | return ((Scale,), (self.low, self.high)) 66 | 67 | @classmethod 68 | def structure(cls, params): 69 | classes, numeric = params 70 | return classes[0](*numeric) 71 | 72 | def export(self): 73 | export_dict = asdict(self) 74 | export_dict["class"] = type(self).__name__ 75 | return export_dict 76 | 77 | 78 | ScaleClass = TypeVar("ScaleClass", bound=Scale) 79 | 80 | 81 | @dataclass 82 | class LogScale(Scale): 83 | log_base: float 84 | 85 | def 
__hash__(self): 86 | return super().__hash__() 87 | 88 | def density_denorm_term(self, true_x): 89 | """ 90 | This is the term required to scale the density from the normalized scale to the 91 | true log scale. It accounts for the stretching of the axis from the exponentiation 92 | transformation. It is the derivative of the normalize_point transformation. 93 | 94 | :param true_x: the point on the true scale where the true density should be calculated 95 | :return: the term required to scale the normalized density to the true density 96 | 97 | """ 98 | return (self.log_base - 1) / ( 99 | np.log(self.log_base) 100 | * (self.log_base * (true_x - self.low) + self.high - true_x) 101 | ) 102 | 103 | def density_norm_term(self, normed_x): 104 | """ 105 | This is the term required to scale the density from the true log scale to the 106 | normalized scale. It accounts for the shrinking of the axis from the log 107 | transformation. It is the derivative of the denormalize_point transformation. 108 | 109 | :param normed_x: the point on the normed scale where the normed density should be calculated 110 | :return: the term required to scale the true density to the normed density 111 | 112 | """ 113 | return (self.log_base ** normed_x * np.log(self.log_base) * (self.width)) / ( 114 | self.log_base - 1 115 | ) 116 | 117 | def normalize_density(self, normed_x, density): 118 | return density * self.density_norm_term(normed_x) 119 | 120 | def denormalize_density(self, true_x, density): 121 | return density * self.density_denorm_term(true_x) 122 | 123 | def normalize_densities(self, normed_xs, densities): 124 | return densities * self.density_norm_term(normed_xs) 125 | 126 | def denormalize_densities(self, true_xs, densities): 127 | return densities * self.density_denorm_term(true_xs) 128 | 129 | def normalize_point(self, point): 130 | """ 131 | Get a prediction sample value on the normalized scale from a true-scale value 132 | 133 | :param point: a sample value on the true scale 134 | :return: a sample value on the normalized scale 135 | """ 136 | if point is None: 137 | raise Exception("Point was None This shouldn't happen") 138 | 139 | shifted = point - self.low 140 | numerator = shifted * (self.log_base - 1) 141 | scaled = numerator / self.width 142 | timber = 1 + scaled 143 | floored_timber = np.maximum(timber, 1e-9) 144 | 145 | return np.log(floored_timber) / np.log(self.log_base) 146 | 147 | def denormalize_point(self, point): 148 | """ 149 | Get a value on the true scale from a normalized-scale value 150 | 151 | :param point: a value on the normalized scale, 152 | typically in [0, 1] 153 | :return: the corresponding value on the true 154 | scale of the question 155 | """ 156 | if point is None: 157 | raise Exception("Point was None This shouldn't happen") 158 | 159 | deriv_term = (self.log_base ** point - 1) / (self.log_base - 1) 160 | scaled = self.width * deriv_term 161 | return self.low + scaled 162 | 163 | def destructure(self): 164 | return ((LogScale,), (self.low, self.high, self.log_base)) 165 | 166 | @classmethod 167 | def structure(cls, params): 168 | classes, numeric = params 169 | low, high, log_base = numeric 170 | return cls(low, high, log_base) 171 | 172 | 173 | @dataclass 174 | class TimeScale(Scale): 175 | def __repr__(self): 176 | return ( 177 | f"TimeScale(low={self.timestamp_to_str(self.low)}, " 178 | f"high={self.timestamp_to_str(self.high)}, " 179 | f"width={timedelta(seconds=self.width)})" 180 | ) 181 | 182 | def __hash__(self): 183 | return super().__hash__() 184 | 185 | def destructure(self): 186 |
return ( 187 | (TimeScale,), 188 | (self.low, self.high,), 189 | ) 190 | 191 | def timestamp_to_str(self, timestamp: float) -> str: 192 | return time.strftime("%Y-%m-%d", time.localtime(timestamp)) 193 | 194 | 195 | def scale_factory(scale_dict): 196 | scale_class = scale_dict["class"] 197 | low = float(scale_dict["low"]) 198 | high = float(scale_dict["high"]) 199 | 200 | if scale_class == "Scale": 201 | return Scale(low, high) 202 | if scale_class == "LogScale": 203 | return LogScale(low, high, float(scale_dict["log_base"])) 204 | if scale_class == "TimeScale": 205 | return TimeScale(low, high) 206 | raise NotImplementedError( 207 | f"reconstructing scales of class {scale_class} is not implemented." 208 | ) 209 | -------------------------------------------------------------------------------- /ergo/static.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from jax import grad, jit, vmap 4 | import jax.numpy as np 5 | import jax.scipy as scipy 6 | 7 | # Multi-condition loss, jitting entire function (used for logistic 8 | # mixture dist + histogram loss) 9 | 10 | 11 | @partial(jit, static_argnums=(0, 3)) 12 | def jitted_condition_loss( 13 | dist_class, dist_fixed_params, dist_opt_params, cond_classes, cond_params 14 | ): 15 | print( 16 | f"Tracing {dist_class.__name__} ({dist_fixed_params}) loss for {[c[0].__name__ for c in cond_classes]} ({str(cond_params)[:60]})" 17 | ) 18 | dist = dist_class.from_params(dist_fixed_params, dist_opt_params, traceable=True) 19 | total_loss = 0.0 20 | for (cond_class, cond_param) in zip(cond_classes, cond_params): 21 | condition = cond_class[0].structure((cond_class, cond_param)) 22 | total_loss += condition.loss(dist) 23 | return total_loss * 100 24 | 25 | 26 | jitted_condition_loss_grad = jit( 27 | grad(jitted_condition_loss, argnums=2), static_argnums=(0, 3) 28 | ) 29 | 30 | 31 | # Multi-condition loss, jitting only individual condition losses (used 32 | # for histogram dist + arbitrary losses) 33 | 34 | 35 | def condition_loss( 36 | dist_class, dist_fixed_params, dist_opt_params, cond_classes, cond_params 37 | ): 38 | total_loss = 0.0 39 | for (cond_class, cond_param) in zip(cond_classes, cond_params): 40 | total_loss += single_condition_loss( 41 | dist_class, dist_fixed_params, dist_opt_params, cond_class, cond_param 42 | ) 43 | return total_loss 44 | 45 | 46 | def condition_loss_grad( 47 | dist_class, dist_fixed_params, dist_opt_params, cond_classes, cond_params 48 | ): 49 | total_grad = 0.0 50 | for (cond_class, cond_param) in zip(cond_classes, cond_params): 51 | total_grad += single_condition_loss_grad( 52 | dist_class, dist_fixed_params, dist_opt_params, cond_class, cond_param 53 | ) 54 | return total_grad 55 | 56 | 57 | @partial(jit, static_argnums=(0, 3)) 58 | def single_condition_loss( 59 | dist_class, dist_fixed_params, dist_opt_params, cond_class, cond_param 60 | ): 61 | 62 | dist = dist_class.from_params(dist_fixed_params, dist_opt_params, traceable=True) 63 | condition = cond_class[0].structure((cond_class, cond_param)) 64 | loss = condition.loss(dist) * 100 65 | print( 66 | f"Tracing {cond_class[0].__name__} loss for {dist_class.__name__} distribution:\n" 67 | f"- Fixed: {dist_fixed_params}\n" 68 | f"- Optim: {dist_opt_params}\n" 69 | f"- Cond: {cond_param}\n" 70 | f"- Loss: {loss}\n\n" 71 | ) 72 | return loss 73 | 74 | 75 | single_condition_loss_grad = jit( 76 | grad(single_condition_loss, argnums=2), static_argnums=(0, 3) 77 | ) 78 | 79 | 80 | # Description of 
distribution/condition fit 81 | 82 | 83 | @partial(jit, static_argnums=(0, 2)) 84 | def describe_fit(dist_classes, dist_params, cond_class, cond_params): 85 | dist_class = dist_classes[0] 86 | dist = dist_class.structure((dist_classes, dist_params)) 87 | condition = cond_class[0].structure((cond_class, cond_params)) 88 | return condition._describe_fit(dist) 89 | 90 | 91 | # General negative log likelihood 92 | 93 | 94 | @partial(jit, static_argnums=0) 95 | def dist_logloss(dist_class, fixed_params, opt_params, data): 96 | dist = dist_class.from_params(fixed_params, opt_params, traceable=True) 97 | if data.size == 1: 98 | return -dist.logpdf(data) 99 | scores = vmap(dist.logpdf)(data) 100 | return -np.sum(scores) 101 | 102 | 103 | dist_grad_logloss = jit(grad(dist_logloss, argnums=2), static_argnums=0) 104 | 105 | 106 | # Logistic mixture 107 | 108 | 109 | @jit 110 | def logistic_logpdf(x, loc, scale): 111 | # x, loc, scale are assumed to be normalized 112 | # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.logistic.html 113 | y = (x - loc) / scale 114 | return scipy.stats.logistic.logpdf(y) - np.log(scale) 115 | 116 | 117 | @jit 118 | def logistic_mixture_logpdf(params, data): 119 | # params are assumed to be normalized 120 | if data.size == 1: 121 | return logistic_mixture_logpdf1(params, data) 122 | scores = vmap(partial(logistic_mixture_logpdf1, params))(data) 123 | return np.sum(scores) 124 | 125 | 126 | @jit 127 | def logistic_mixture_logpdf1(params, datum): 128 | # params are assumed to be normalized 129 | structured_params = params.reshape((-1, 3)) 130 | component_scores = [] 131 | for (loc, scale, component_prob) in structured_params: 132 | component_scores.append( 133 | logistic_logpdf(datum, loc, scale) + np.log(component_prob) 134 | ) 135 | return scipy.special.logsumexp(np.array(component_scores)) 136 | 137 | 138 | logistic_mixture_grad_logpdf = jit(grad(logistic_mixture_logpdf, argnums=0)) 139 | 140 | 141 | # Wasserstein distance 142 | 143 | 144 | @jit 145 | def wasserstein_distance(xs, ys): 146 | diffs = np.cumsum(xs - ys) 147 | abs_diffs = np.abs(diffs) 148 | return np.sum(abs_diffs) 149 | -------------------------------------------------------------------------------- /ergo/theme.py: -------------------------------------------------------------------------------- 1 | from plotnine import themes 2 | 3 | ergo_theme = themes.theme_bw() 4 | -------------------------------------------------------------------------------- /ergo/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import weakref 3 | 4 | import jax.numpy as np 5 | import scipy as oscipy 6 | 7 | 8 | def to_float(value): 9 | """Convert value to float""" 10 | return np.asscalar(value) 11 | 12 | 13 | def memoized_method(*lru_args, **lru_kwargs): 14 | def decorator(func): 15 | @functools.wraps(func) 16 | def wrapped_func(self, *args, **kwargs): 17 | # We're storing the wrapped method inside the instance. If we had 18 | # a strong reference to self the instance would never die. 
19 | self_weak = weakref.ref(self) 20 | 21 | @functools.wraps(func) 22 | @functools.lru_cache(*lru_args, **lru_kwargs) 23 | def cached_method(*args, **kwargs): 24 | return func(self_weak(), *args, **kwargs) 25 | 26 | setattr(self, func.__name__, cached_method) 27 | return cached_method(*args, **kwargs) 28 | 29 | return wrapped_func 30 | 31 | return decorator 32 | 33 | 34 | def minimize_random(fun, init, tries=100): 35 | best_x = None 36 | best_loss = float("+inf") 37 | while tries > 0: 38 | x = init() 39 | loss = fun(x) 40 | if best_x is None or loss < best_loss: 41 | best_x = x 42 | best_loss = loss 43 | tries -= 1 44 | return best_x 45 | 46 | 47 | def minimize(fun, *args, init=None, init_tries=1, opt_tries=1, verbose=False, **kwargs): 48 | """ 49 | Wrapper around scipy.optimize.minimize that supports retries 50 | """ 51 | if "x0" in kwargs: 52 | raise ValueError("Provide initialization function (init), not x0") 53 | 54 | best_results = None 55 | best_loss = float("+inf") 56 | while opt_tries > 0: 57 | init_params = minimize_random(fun, init, tries=init_tries) 58 | results = oscipy.optimize.minimize(fun, *args, x0=init_params, **kwargs) 59 | opt_tries -= 1 60 | if best_results is None or results.fun < best_loss: 61 | best_results = results 62 | best_loss = results.fun 63 | if opt_tries == 0: 64 | break 65 | return best_results 66 | 67 | 68 | def shift(xs, k, fill_value): 69 | return np.concatenate((np.full(k, fill_value), xs[:-k])) 70 | 71 | 72 | # Taken form https://github.com/google/jax/pull/3042/files (has been merged but not released) 73 | def trapz(y, x=None, dx=1.0): 74 | if x is not None: 75 | dx = np.diff(x) 76 | return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1) 77 | 78 | 79 | def safe_log(x, constant=1e-37): 80 | return np.log(x + constant) 81 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.6 3 | 4 | [mypy-seaborn] 5 | ignore_missing_imports = True 6 | 7 | [mypy-numpy] 8 | ignore_missing_imports = True 9 | 10 | [mypy-numpyro] 11 | ignore_missing_imports = True 12 | 13 | [mypy-numpyro.distributions] 14 | ignore_missing_imports = True 15 | 16 | [mypy-numpyro.primitives] 17 | ignore_missing_imports = True 18 | 19 | [mypy-tqdm] 20 | ignore_missing_imports = True 21 | 22 | [mypy-tqdm.autonotebook] 23 | ignore_missing_imports = True 24 | 25 | [mypy-country_converter] 26 | ignore_missing_imports = True 27 | 28 | [mypy-scipy] 29 | ignore_missing_imports = True 30 | 31 | [mypy-jax] 32 | ignore_missing_imports = True 33 | 34 | [mypy-jax.numpy] 35 | ignore_missing_imports = True 36 | 37 | [mypy-jax.scipy] 38 | ignore_missing_imports = True 39 | 40 | [mypy-jax.experimental.optimizers] 41 | ignore_missing_imports = True 42 | 43 | [mypy-jax.interpreters.xla] 44 | ignore_missing_imports = True 45 | 46 | [mypy-IPython] 47 | ignore_missing_imports = True 48 | 49 | [mypy-pytest] 50 | ignore_missing_imports = True 51 | 52 | [mypy-scipy.stats] 53 | ignore_missing_imports = True 54 | 55 | [mypy-scipy.integrate] 56 | ignore_missing_imports = True 57 | 58 | [mypy-scipy.interpolate] 59 | ignore_missing_imports = True 60 | 61 | [mypy-plotnine] 62 | ignore_missing_imports = True 63 | 64 | [mypy-sklearn] 65 | ignore_missing_imports = True 66 | 67 | 68 | [mypy-fuzzywuzzy] 69 | ignore_missing_imports = True 70 | 71 | [mypy-backports.cached_property] 72 | ignore_missing_imports = True 
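
# Usage sketch for ergo.utils.minimize above: unlike scipy.optimize.minimize it
# takes an `init` function rather than x0, and retries both initialization
# (init_tries) and optimization (opt_tries). The quadratic objective below is
# just a stand-in.
import numpy as onp
from ergo.utils import minimize

def loss(x):
    return float((x[0] - 3.0) ** 2 + 1.0)

result = minimize(
    loss,
    init=lambda: onp.random.uniform(-10, 10, size=1),
    init_tries=10,
    opt_tries=3,
)
# result.x should end up close to [3.0] and result.fun close to 1.0.
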
-------------------------------------------------------------------------------- /notebooks/rejection_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 3, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ergo\n", 10 | "import numpyro\n", 11 | "import jax.numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 4, 17 | "metadata": {}, 18 | "outputs": [ 19 | { 20 | "data": { 21 | "application/vnd.jupyter.widget-view+json": { 22 | "model_id": "bfa08e633e27469eb8a046db22f89c3d", 23 | "version_major": 2, 24 | "version_minor": 0 25 | }, 26 | "text/plain": [ 27 | "HBox(children=(FloatProgress(value=0.0), HTML(value='')))" 28 | ] 29 | }, 30 | "metadata": {}, 31 | "output_type": "display_data" 32 | }, 33 | { 34 | "name": "stdout", 35 | "output_type": "stream", 36 | "text": [ 37 | "\n" 38 | ] 39 | }, 40 | { 41 | "data": { 42 | "text/html": [ 43 | "
\n", 44 | "\n", 57 | "\n", 58 | " \n", 59 | " \n", 60 | " \n", 61 | " \n", 62 | " \n", 63 | " \n", 64 | " \n", 65 | " \n", 66 | " \n", 67 | " \n", 68 | " \n", 69 | " \n", 70 | " \n", 71 | " \n", 72 | " \n", 73 | " \n", 74 | " \n", 75 | " \n", 76 | " \n", 77 | " \n", 78 | " \n", 79 | " \n", 80 | " \n", 81 | " \n", 82 | " \n", 83 | " \n", 84 | " \n", 85 | " \n", 86 | "
xycount
0FalseTrue29
1TrueFalse35
2TrueTrue36
\n", 87 | "
" 88 | ], 89 | "text/plain": [ 90 | " x y count\n", 91 | "0 False True 29\n", 92 | "1 True False 35\n", 93 | "2 True True 36" 94 | ] 95 | }, 96 | "execution_count": 4, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "def model():\n", 103 | " x = ergo.flip(0.5, name=\"x\")\n", 104 | " y = ergo.flip(0.5, name=\"y\")\n", 105 | " ergo.condition(x or y)\n", 106 | "\n", 107 | "trace = ergo.run(model, num_samples=100)\n", 108 | "\n", 109 | "trace.groupby([\"x\", \"y\"]).size().reset_index(name=\"count\")" 110 | ] 111 | } 112 | ], 113 | "metadata": { 114 | "kernelspec": { 115 | "display_name": "Python 3", 116 | "language": "python", 117 | "name": "python3" 118 | }, 119 | "language_info": { 120 | "codemirror_mode": { 121 | "name": "ipython", 122 | "version": 3 123 | }, 124 | "file_extension": ".py", 125 | "mimetype": "text/x-python", 126 | "name": "python", 127 | "nbconvert_exporter": "python", 128 | "pygments_lexer": "ipython3", 129 | "version": "3.6.9" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /notebooks/scrubbed/community-distributions.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Show" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "- Making COVID infections, deaths, and infections/death ratio predictions on Metaculus more consistent with each other\n", 15 | "- More broadly: a workflow that connects judgmental and model-based forecasting" 16 | ] 17 | }, 18 | { 19 | "cell_type": "markdown", 20 | "metadata": {}, 21 | "source": [ 22 | "## Setup" 23 | ] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "execution_count": null, 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "%%capture\n", 32 | "%pip install --quiet poetry # Fixes https://github.com/python-poetry/poetry/issues/532\n", 33 | "%pip install --quiet git+https://github.com/oughtinc/ergo.git@7f88222f5f7a2e552eb1750d43b6c7924d2f0361" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import warnings\n", 43 | "warnings.filterwarnings(action=\"ignore\", category=FutureWarning)\n", 44 | "warnings.filterwarnings(module=\"plotnine\", action=\"ignore\")\n", 45 | "warnings.filterwarnings(module=\"jax\", action=\"ignore\")" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": null, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "import ergo\n", 55 | "import seaborn\n", 56 | "import numpy as np" 57 | ] 58 | }, 59 | { 60 | "cell_type": "code", 61 | "execution_count": null, 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "metaculus = ergo.Metaculus(username=\"oughtpublic\", password=\"123456\")" 66 | ] 67 | }, 68 | { 69 | "cell_type": "markdown", 70 | "metadata": {}, 71 | "source": [ 72 | "## Look at questions on Metaculus\n", 73 | "* [Total COVID infections before 2021](https://www.metaculus.com/questions/3529/how-many-infections-of-covid-19-will-be-estimated-to-have-occurred-before-2021-50k-1b-range/)\n", 74 | "* [Total COVID deaths before 2021](https://www.metaculus.com/questions/3530/how-many-people-will-die-as-a-result-of-the-2019-novel-coronavirus-covid-19-before-2021/)\n", 75 | "* [Chance of dying of COVID if you get 
it](https://www.metaculus.com/questions/3755/what-will-be-the-ratio-of-fatalities-to-total-estimated-infections-for-covid-19-by-the-end-of-2020/)\n", 76 | "\n" 77 | ] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "metadata": {}, 82 | "source": [ 83 | "### Load questions" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "q_infections = metaculus.get_question(3529, name=\"infections\")\n", 93 | "q_deaths = metaculus.get_question(3530, name=\"deaths\")\n", 94 | "q_ratio = metaculus.get_question(3755, name=\"ratio\")\n", 95 | "\n", 96 | "questions = [q_infections, q_deaths, q_ratio]" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "### Show community estimate for each question" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "[question.show_community_prediction() for question in questions];" 113 | ] 114 | }, 115 | { 116 | "cell_type": "markdown", 117 | "metadata": {}, 118 | "source": [ 119 | "## Model deaths based on the infection and ratio community estimates" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "Deaths = infections * deaths/infection" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "%%capture\n", 136 | "def deaths_from_infections():\n", 137 | " infections = q_infections.sample_community()\n", 138 | " ratio = q_ratio.sample_community()\n", 139 | " deaths = infections * ratio\n", 140 | " ergo.tag(deaths, \"deaths\")\n", 141 | " return deaths\n", 142 | "\n", 143 | "samples = ergo.run(deaths_from_infections, num_samples=5000)" 144 | ] 145 | }, 146 | { 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "### How does our model prediction compare to the community prediction?" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "metadata": {}, 157 | "outputs": [], 158 | "source": [ 159 | "q_deaths.show_prediction(samples[\"deaths\"], show_community=True)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "markdown", 164 | "metadata": {}, 165 | "source": [ 166 | "### Combine our model with the community prediction" 167 | ] 168 | }, 169 | { 170 | "cell_type": "markdown", 171 | "metadata": {}, 172 | "source": [ 173 | "Mostly defer to the community's predictions on the deaths question, but update a bit towards the model that's based on infections * deaths/infection." 
174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "%%capture\n", 183 | "def deaths_adjusted():\n", 184 | " if ergo.flip(.66):\n", 185 | " deaths = q_deaths.sample_community()\n", 186 | " else:\n", 187 | " deaths = deaths_from_infections()\n", 188 | " ergo.tag(deaths, \"adjusted_deaths\")\n", 189 | "\n", 190 | "samples = ergo.run(deaths_adjusted, num_samples=5000)\n", 191 | "\n", 192 | "adjusted_samples = samples[\"adjusted_deaths\"]" 193 | ] 194 | }, 195 | { 196 | "cell_type": "code", 197 | "execution_count": null, 198 | "metadata": {}, 199 | "outputs": [], 200 | "source": [ 201 | "q_deaths.show_prediction(adjusted_samples, show_community=True)" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "## Submit new distribution on deaths to Metaculus" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "%%capture\n", 218 | "q_deaths.submit_from_samples(adjusted_samples)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "markdown", 223 | "metadata": {}, 224 | "source": [ 225 | "**Exercise:** Apply the same idea to estimating infections from deaths" 226 | ] 227 | } 228 | ], 229 | "metadata": { 230 | "jupytext": { 231 | "cell_metadata_filter": "-all", 232 | "main_language": "python", 233 | "notebook_metadata_filter": "-all" 234 | } 235 | }, 236 | "nbformat": 4, 237 | "nbformat_minor": 4 238 | } 239 | -------------------------------------------------------------------------------- /notebooks/scrubbed/foretold-submission.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import numpy as np\n", 10 | "import scipy.stats\n", 11 | "import seaborn\n", 12 | "import ergo" 13 | ] 14 | }, 15 | { 16 | "cell_type": "markdown", 17 | "metadata": {}, 18 | "source": [ 19 | "## Testing the internal method used to build a CDF from a bag of samples:\n", 20 | "\n", 21 | "Generate some samples, and convert to Foretold CDF:" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "loc = 50\n", 31 | "scale = 5\n", 32 | "samples = np.random.normal(loc, scale, size=2000)\n", 33 | "cdf = ergo.foretold.ForetoldCdf.from_samples(samples, length=10)\n", 34 | "cdf" 35 | ] 36 | }, 37 | { 38 | "cell_type": "markdown", 39 | "metadata": {}, 40 | "source": [ 41 | "Compare the CDF derived from samples with the true CDF:" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "xs = np.linspace(loc - scale * 4, loc + scale * 4, 100)\n", 51 | "ys = scipy.stats.norm.cdf(xs, loc=loc, scale=scale)\n", 52 | "seaborn.lineplot(xs, ys);\n", 53 | "seaborn.lineplot(cdf.xs, cdf.ys);" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "## Testing submission:\n", 61 | "\n", 62 | "Submit samples as a prediction (measurement) of a [question](https://www.foretold.io/c/f45577e4-f1b0-4bba-8cf6-63944e63d70c/m/cf86da3f-c257-4787-b526-3ef3cb670cb4) outcome (measureable):" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "token = 
\"YOUR-TOKEN\"\n", 72 | "foretold = ergo.Foretold(token)\n", 73 | "question = foretold.get_question(\"cf86da3f-c257-4787-b526-3ef3cb670cb4\")\n", 74 | "response = question.submit_from_samples(samples, length=20)\n", 75 | "response" 76 | ] 77 | } 78 | ], 79 | "metadata": { 80 | "jupytext": { 81 | "cell_metadata_filter": "-all", 82 | "main_language": "python", 83 | "notebook_metadata_filter": "-all" 84 | } 85 | }, 86 | "nbformat": 4, 87 | "nbformat_minor": 4 88 | } 89 | -------------------------------------------------------------------------------- /notebooks/scrubbed/generative-models.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Setup" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install --quiet poetry # Fixes https://github.com/python-poetry/poetry/issues/532\n", 24 | "!pip install --quiet git+https://github.com/oughtinc/ergo.git" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "%load_ext google.colab.data_table" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": null, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import ergo" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "# Model" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": null, 55 | "metadata": {}, 56 | "outputs": [], 57 | "source": [ 58 | "def model():\n", 59 | " x = ergo.lognormal_from_interval(1, 10, name=\"x\")\n", 60 | " y = ergo.beta_from_hits(2, 10, name=\"y\")\n", 61 | " z = x * y \n", 62 | " ergo.tag(z, \"z\")\n", 63 | "\n", 64 | "samples = ergo.run(model, num_samples=10000)" 65 | ] 66 | }, 67 | { 68 | "cell_type": "markdown", 69 | "metadata": {}, 70 | "source": [ 71 | "# Analysis" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "Histogram:" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "samples[\"x\"].hist()" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "samples[\"z\"].hist()" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "Summary stats:" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "samples.describe()" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "jupytext": { 118 | "cell_metadata_filter": "-all", 119 | "main_language": "python", 120 | "notebook_metadata_filter": "-all" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 4 125 | } 126 | -------------------------------------------------------------------------------- /notebooks/scrubbed/prediction-dashboard.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Setup" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": 
[ 16 | "%%capture\n", 17 | "%pip install --quiet poetry # Fixes https://github.com/python-poetry/poetry/issues/532\n", 18 | "%pip install --quiet git+https://github.com/oughtinc/ergo.git@7f88222f5f7a2e552eb1750d43b6c7924d2f0361" 19 | ] 20 | }, 21 | { 22 | "cell_type": "code", 23 | "execution_count": null, 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "import ergo\n", 28 | "import datetime\n", 29 | "import pandas as pd;" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "%load_ext google.colab.data_table" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "metaculus = ergo.Metaculus(username=\"oughtpublic\", password=\"123456\", api_domain=\"pandemic\")" 48 | ] 49 | }, 50 | { 51 | "cell_type": "markdown", 52 | "metadata": {}, 53 | "source": [ 54 | "# Dashboard" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## Resolved questions" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "resolved_predictions = metaculus.make_questions_df(metaculus.get_questions_json(question_status=\"resolved\", player_status=\"predicted\", pages=9999))" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": null, 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "resolved_predictions[[\"title\", \"resolve_time\", \"i_created\", \"url\", \"id\"]]" 80 | ] 81 | }, 82 | { 83 | "cell_type": "markdown", 84 | "metadata": {}, 85 | "source": [ 86 | "## Open questions" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "open = metaculus.make_questions_df(metaculus.get_questions_json(question_status=\"open\", player_status=\"any\", pages=9999))" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "metadata": {}, 101 | "source": [ 102 | "### Published in or after Mar 1 2020, closing in 2020" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": null, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "open_mar_or_after = open[open[\"publish_time\"] >= datetime.datetime(2020,3,1)]\n", 112 | "and_closes_in_2020 = open_mar_or_after[open_mar_or_after[\"close_time\"] <= datetime.datetime(2020,12,31)]" 113 | ] 114 | }, 115 | { 116 | "cell_type": "code", 117 | "execution_count": null, 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "and_closes_in_2020" 122 | ] 123 | } 124 | ], 125 | "metadata": { 126 | "jupytext": { 127 | "cell_metadata_filter": "-all", 128 | "main_language": "python", 129 | "notebook_metadata_filter": "-all" 130 | } 131 | }, 132 | "nbformat": 4, 133 | "nbformat_minor": 4 134 | } 135 | -------------------------------------------------------------------------------- /notebooks/scrubbed/quickstart.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "\"Open" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "# Setup" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "metadata": {}, 21 | "outputs": [], 22 | "source": [ 23 | "!pip install --progress-bar off --quiet poetry\n", 24 | "!pip install 
--progress-bar off --quiet git+https://github.com/oughtinc/ergo.git" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "# Load questions" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": null, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import ergo\n", 41 | "import seaborn" 42 | ] 43 | }, 44 | { 45 | "cell_type": "markdown", 46 | "metadata": {}, 47 | "source": [ 48 | "Log into Metaculus:" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "metaculus = ergo.Metaculus()" 58 | ] 59 | }, 60 | { 61 | "cell_type": "markdown", 62 | "metadata": {}, 63 | "source": [ 64 | "Load some questions:" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "q_infections = metaculus.get_question(3529, name=\"Covid-19 infections in 2020\")\n", 74 | "q_ratio = metaculus.get_question(3755, name=\"Covid-19 ratio of fatalities to infections\")" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "Build a model:" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "def model():\n", 91 | " infections = q_infections.sample_community()\n", 92 | " ratio = q_ratio.sample_community()\n", 93 | " deaths = infections * ratio\n", 94 | " ergo.tag(deaths, \"Covid-19 deaths in 2020\")\n", 95 | "\n", 96 | "samples = ergo.run(model, num_samples=5000)" 97 | ] 98 | }, 99 | { 100 | "cell_type": "markdown", 101 | "metadata": {}, 102 | "source": [ 103 | "Show samples:" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "seaborn.distplot(samples[\"Covid-19 deaths in 2020\"])" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "jupytext": { 118 | "cell_metadata_filter": "-all", 119 | "main_language": "python", 120 | "notebook_metadata_filter": "-all" 121 | } 122 | }, 123 | "nbformat": 4, 124 | "nbformat_minor": 4 125 | } 126 | -------------------------------------------------------------------------------- /notebooks/scrubbed/rejection_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ergo\n", 10 | "import numpyro\n", 11 | "import jax.numpy as np" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "def model():\n", 21 | " x = ergo.flip(0.5, name=\"x\")\n", 22 | " y = ergo.flip(0.5, name=\"y\")\n", 23 | " ergo.condition(x or y)\n", 24 | "\n", 25 | "trace = ergo.run(model, num_samples=100)\n", 26 | "\n", 27 | "trace.groupby([\"x\", \"y\"]).size().reset_index(name=\"count\")" 28 | ] 29 | } 30 | ], 31 | "metadata": { 32 | "jupytext": { 33 | "cell_metadata_filter": "-all", 34 | "main_language": "python", 35 | "notebook_metadata_filter": "-all" 36 | } 37 | }, 38 | "nbformat": 4, 39 | "nbformat_minor": 4 40 | } 41 | -------------------------------------------------------------------------------- /notebooks/scrubbed/test-mixture-fitting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | 
"execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ergo\n", 10 | "import seaborn\n", 11 | "\n", 12 | "from ergo import Logistic, LogisticMixture\n", 13 | "from ergo.distributions.conditions import IntervalCondition, PercentileCondition\n", 14 | "from tqdm.autonotebook import tqdm\n", 15 | "from matplotlib import pyplot" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "def normalize(xs):\n", 25 | " z = sum(xs)\n", 26 | " return [x/z for x in xs]\n", 27 | "\n", 28 | "def sample_component():\n", 29 | " return Logistic(loc=ergo.uniform(-1, 2), scale=abs(ergo.lognormal_from_interval(0.2, 3)))\n", 30 | "\n", 31 | "def sample_condition(dist):\n", 32 | " case = ergo.random_choice([\"low_open\", \"bounded\", \"high_open\"])\n", 33 | " if case == \"low_open\":\n", 34 | " xmin = float(\"-inf\")\n", 35 | " xmax = ergo.uniform(-3, 3)\n", 36 | " elif case == \"bounded\":\n", 37 | " xmin = ergo.uniform(-3, 0) \n", 38 | " xmax = xmin + ergo.uniform(0, 3) \n", 39 | " elif case == \"high_open\":\n", 40 | " xmin = ergo.uniform(-3, 3)\n", 41 | " xmax = float(\"+inf\")\n", 42 | " p = actual_p(dist, xmin, xmax)\n", 43 | " return IntervalCondition(p, xmin, xmax)\n", 44 | "\n", 45 | "def sample_conditions(dist):\n", 46 | " num_conditions = ergo.random_choice([1, 2, 3, 5, 7])\n", 47 | " conditions = [sample_condition(dist) for _ in range(num_conditions)]\n", 48 | " return conditions\n", 49 | "\n", 50 | "def sample_mixture():\n", 51 | " num_components = ergo.random_choice([1, 2, 3])\n", 52 | " components = [sample_component() for _ in range(num_components)]\n", 53 | " probs = normalize([ergo.uniform(0, 1) for _ in range(num_components)])\n", 54 | " return LogisticMixture(components, probs)\n", 55 | " \n", 56 | "def actual_p(dist, xmin, xmax):\n", 57 | " cdf_at_min = dist.cdf(xmin) if not np.isneginf(xmin) else 0\n", 58 | " cdf_at_max = dist.cdf(xmax) if not np.isposinf(xmax) else 1\n", 59 | " return cdf_at_max - cdf_at_min\n", 60 | "\n", 61 | "def plot(dist, ax=None):\n", 62 | " xs = np.linspace(-4, 4, 100)\n", 63 | " ys = [float(mixture.pdf1(x)) for x in xs]\n", 64 | " # pyplot.figure()\n", 65 | " return seaborn.lineplot(xs, ys)\n", 66 | " \n", 67 | "def model():\n", 68 | " # 1. Sample a distribution with 1-3 peaks\n", 69 | " true_dist = sample_mixture()\n", 70 | "\n", 71 | " # 2. Sample 1-7 conditions\n", 72 | " conditions = sample_conditions(true_dist)\n", 73 | " \n", 74 | " # 3. Fit a mixture to those conditions\n", 75 | " fit_dist = LogisticMixture.from_conditions(conditions, num_components=3)\n", 76 | " \n", 77 | " # 4. 
Check that the conditions are satisfied\n", 78 | "    for condition in conditions:\n", 79 | "        fit = condition.describe_fit(fit_dist)\n", 80 | "        if fit[\"loss\"] > 0.000002:\n", 81 | "            print(true_dist)\n", 82 | "            print(fit_dist)\n", 83 | "            print(fit) \n", 84 | "            for condition in conditions:\n", 85 | "                print(conditions)\n", 86 | "            ax = plot(true_dist)\n", 87 | "            plot(fit_dist, ax=ax)\n", 88 | "            raise Exception(\"Failed to fit\")\n", 89 | "\n", 90 | "for i in tqdm(range(1000)):\n", 91 | "    model()" 92 | ] 93 | } 94 | ], 95 | "metadata": { 96 | "jupytext": { 97 | "cell_metadata_filter": "-all", 98 | "main_language": "python", 99 | "notebook_metadata_filter": "-all" 100 | } 101 | }, 102 | "nbformat": 4, 103 | "nbformat_minor": 4 104 | } 105 | -------------------------------------------------------------------------------- /notebooks/test-mixture-fitting.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 58, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import ergo\n", 10 | "import seaborn\n", 11 | "import jax.numpy as np\n", 12 | "from ergo import Logistic, LogisticMixture\n", 13 | "from ergo.distributions.conditions import IntervalCondition, PercentileCondition\n", 14 | "from tqdm.autonotebook import tqdm\n", 15 | "from matplotlib import pyplot" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [ 23 | { 24 | "data": { 25 | "application/vnd.jupyter.widget-view+json": { 26 | "model_id": "b676581438534590928b35e2d81e1cdc", 27 | "version_major": 2, 28 | "version_minor": 0 29 | }, 30 | "text/plain": [ 31 | "HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))" 32 | ] 33 | }, 34 | "metadata": {}, 35 | "output_type": "display_data" 36 | } 37 | ], 38 | "source": [ 39 | "def normalize(xs):\n", 40 | "    z = sum(xs)\n", 41 | "    return [x/z for x in xs]\n", 42 | "\n", 43 | "def sample_component():\n", 44 | "    return Logistic(loc=ergo.uniform(-1, 2), scale=abs(ergo.lognormal_from_interval(0.2, 3)))\n", 45 | "\n", 46 | "def sample_condition(dist):\n", 47 | "    case = ergo.random_choice([\"low_open\", \"bounded\", \"high_open\"])\n", 48 | "    if case == \"low_open\":\n", 49 | "        xmin = float(\"-inf\")\n", 50 | "        xmax = ergo.uniform(-3, 3)\n", 51 | "    elif case == \"bounded\":\n", 52 | "        xmin = ergo.uniform(-3, 0) \n", 53 | "        xmax = xmin + ergo.uniform(0, 3) \n", 54 | "    elif case == \"high_open\":\n", 55 | "        xmin = ergo.uniform(-3, 3)\n", 56 | "        xmax = float(\"+inf\")\n", 57 | "    p = actual_p(dist, xmin, xmax)\n", 58 | "    return IntervalCondition(p, xmin, xmax)\n", 59 | "\n", 60 | "def sample_conditions(dist):\n", 61 | "    num_conditions = ergo.random_choice([1, 2, 3, 5, 7])\n", 62 | "    conditions = [sample_condition(dist) for _ in range(num_conditions)]\n", 63 | "    return conditions\n", 64 | "\n", 65 | "def sample_mixture():\n", 66 | "    num_components = ergo.random_choice([1, 2, 3])\n", 67 | "    components = [sample_component() for _ in range(num_components)]\n", 68 | "    probs = normalize([ergo.uniform(0, 1) for _ in range(num_components)])\n", 69 | "    return LogisticMixture(components, probs)\n", 70 | "    \n", 71 | "def actual_p(dist, xmin, xmax):\n", 72 | "    cdf_at_min = dist.cdf(xmin) if not np.isneginf(xmin) else 0\n", 73 | "    cdf_at_max = dist.cdf(xmax) if not np.isposinf(xmax) else 1\n", 74 | "    return cdf_at_max - cdf_at_min\n", 75 | "\n", 76 | "def plot(dist, ax=None):\n", 77 | "    xs = np.linspace(-4, 4, 100)\n", 78 | "    ys = [float(dist.pdf1(x)) 
for x in xs]\n", 79 | " # pyplot.figure()\n", 80 | " return seaborn.lineplot(xs, ys)\n", 81 | " \n", 82 | "def model():\n", 83 | " # 1. Sample a distribution with 1-3 peaks\n", 84 | " true_dist = sample_mixture()\n", 85 | "\n", 86 | " # 2. Sample 1-7 conditions\n", 87 | " conditions = sample_conditions(true_dist)\n", 88 | " \n", 89 | " # 3. Fit a mixture to those conditions\n", 90 | " fit_dist = LogisticMixture.from_conditions(conditions, num_components=3)\n", 91 | " \n", 92 | " # 4. Check that the conditions are satisfied\n", 93 | " for condition in conditions:\n", 94 | " fit = condition.describe_fit(fit_dist)\n", 95 | " if fit[\"loss\"] > 0.000002:\n", 96 | " print(true_dist)\n", 97 | " print(fit_dist)\n", 98 | " print(fit) \n", 99 | " for condition in conditions:\n", 100 | " print(conditions)\n", 101 | " ax = plot(true_dist)\n", 102 | " plot(fit_dist, ax=ax)\n", 103 | " raise Exception(\"Failed to fit\")\n", 104 | "\n", 105 | "for i in tqdm(range(1000)):\n", 106 | " model()" 107 | ] 108 | } 109 | ], 110 | "metadata": { 111 | "kernelspec": { 112 | "display_name": "Python 3", 113 | "language": "python", 114 | "name": "python3" 115 | }, 116 | "language_info": { 117 | "codemirror_mode": { 118 | "name": "ipython", 119 | "version": 3 120 | }, 121 | "file_extension": ".py", 122 | "mimetype": "text/x-python", 123 | "name": "python", 124 | "nbconvert_exporter": "python", 125 | "pygments_lexer": "ipython3", 126 | "version": "3.6.9" 127 | } 128 | }, 129 | "nbformat": 4, 130 | "nbformat_minor": 4 131 | } 132 | -------------------------------------------------------------------------------- /pull_request_template.md: -------------------------------------------------------------------------------- 1 | # Quick checks (to delete) 2 | 3 | 1. Create an [issue](https://linear.app/ought) for each non-trivial PR 4 | 2. Name your branch using "copy Git branch name" in the issue 5 | 3. Don't include changes unrelated to the issue 6 | 4. Include tests or explain why not 7 | 8 | 9 | # PR (fill this out) 10 | 11 | QA: Link to [spreadsheet](https://docs.google.com/spreadsheets/d/1ilbckTFL0EocBTZsRav-yx_ImhgNkn7zgrcXSYfTCLU/edit#gid=2005290025) or describe process 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "ergo" 3 | version = "0.8.5" 4 | description = "" 5 | authors = ["Ought "] 6 | 7 | [tool.poetry.dependencies] 8 | python = "^3.6.8" 9 | requests = "2.21.0" 10 | country_converter = "0.6.7" 11 | tqdm = "4.45.0" 12 | scipy = "1.4.1" 13 | matplotlib = "3.2.1" 14 | seaborn = "0.10.0" 15 | typing_extensions = "3.7.4.2" 16 | jax = "0.1.75" 17 | jaxlib = "0.1.52" 18 | pandas = "1.0.3" 19 | sphinx-autodoc-typehints = "^1.10.3" 20 | plotnine = "^0.6.0" 21 | numpyro = {git = "https://github.com/pyro-ppl/numpyro.git", rev = "cda5fac353d162d34a76cc6b05033e24bd203fa0"} 22 | 23 | # Dataclasses only needed on 3.6, was merged into standard library for 3.7+ 24 | dataclasses = {version="0.7", python="~3.6"} 25 | 26 | # A list of all of the optional dependencies, some of which are included in the 27 | # below `extras`. They can be opted into by apps. 
28 | pendulum = { version = "2.1.0", optional = true } 29 | scikit-learn = { version = "0.22.2", optional = true } 30 | sklearn = {version = "^0.0", optional = true} 31 | fuzzywuzzy = {version = "0.18.0", optional = true} 32 | python-Levenshtein = {version = "0.12.0", optional = true} 33 | "backports.cached-property" = "^1.0.0" 34 | 35 | [tool.poetry.dev-dependencies] 36 | pytest = "^5.2" 37 | mypy = "^0.770" 38 | data-science-types = "^0.2.6" 39 | autopep8 = "^1.5" 40 | IPython = "^7.13.0" 41 | jupyter = "^1.0.0" 42 | codecov = "^2.0.22" 43 | flake8 = "^3.7.9" 44 | black = {version = "^19.10b0", allow-prereleases = true} 45 | pytest-cov = "^2.8.1" 46 | isort = "^4.3.21" 47 | sphinx = "^3.0.1" 48 | sphinx_rtd_theme = "^0.4.3" 49 | pytest-dotenv = "^0.4.0" 50 | pytest-xdist = "^1.31.0" 51 | jupyterlab = "^2.1.1" 52 | nbdime = "^2.0.0" 53 | jupytext = "^1.4.2" 54 | ipywidgets = "^7.5.1" 55 | pytest-sugar = "^0.9.3" 56 | 57 | [tool.poetry.extras] 58 | notebooks = [ 59 | "pendulum", 60 | "scikit-learn", 61 | "sklearn", 62 | "fuzzywuzzy", 63 | "python-Levenshtein" 64 | ] 65 | 66 | [build-system] 67 | requires = ["poetry>=0.12", "setuptools>=47.1.1,<50"] 68 | build-backend = "poetry.masonry.api" 69 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | norecursedirs=.* ergo/contrib build -------------------------------------------------------------------------------- /scripts/scrub_notebooks.py: -------------------------------------------------------------------------------- 1 | """ 2 | Invoke with `make scrub`. Will Create scrubbed notebooks in notebooks/src/ from 3 | notebooks/. 4 | """ 5 | 6 | import argparse 7 | import json 8 | import os 9 | from pathlib import Path 10 | import subprocess 11 | 12 | strip_metadata = { 13 | "jupytext": {"notebook_metadata_filter": "-all", "cell_metadata_filter": "-all"} 14 | } 15 | 16 | strip_metadata_string = json.dumps(strip_metadata) 17 | 18 | 19 | def scrub(notebooks_path, scrubbed_path): 20 | for notebook_file in notebooks_path.glob("*.ipynb"): 21 | scrubbed_file = Path(scrubbed_path) / notebook_file.name 22 | subprocess.run( 23 | f"jupytext --output '{scrubbed_file}.md' --to md '{notebook_file}' --update-metadata '{strip_metadata_string}' ", 24 | shell=True, 25 | check=True, 26 | stdout=subprocess.PIPE, 27 | ) 28 | res = subprocess.run( 29 | f"jupytext --output '{scrubbed_file}' --to notebook '{scrubbed_file}.md'", 30 | shell=True, 31 | check=True, 32 | universal_newlines=True, 33 | stdout=subprocess.PIPE, 34 | ) 35 | print(res.stdout.split("\n")[1]) 36 | subprocess.run( 37 | f"rm '{scrubbed_file}.md'", shell=True, check=True, stdout=subprocess.PIPE 38 | ) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("notebooks_path", type=Path) 44 | parser.add_argument("scrubbed_path", type=Path) 45 | p = parser.parse_args() 46 | assert os.path.exists(p.notebooks_path) 47 | assert os.path.exists(p.scrubbed_path) 48 | scrub(p.notebooks_path, p.scrubbed_path) 49 | -------------------------------------------------------------------------------- /scripts/scrub_src.py: -------------------------------------------------------------------------------- 1 | """ 2 | Invoke with `make scrub_src_only`. Will remove metadata and results from notebooks in 3 | `notebooks/src/`. 
4 | """ 5 | 6 | import argparse 7 | import json 8 | from pathlib import Path 9 | import subprocess 10 | 11 | strip_metadata = { 12 | "jupytext": {"notebook_metadata_filter": "-all", "cell_metadata_filter": "-all"} 13 | } 14 | 15 | strip_metadata_string = json.dumps(strip_metadata) 16 | 17 | 18 | def scrub(notebooks_path, scrubbed_path): 19 | for file_to_scrub in scrubbed_path.glob("*.ipynb"): 20 | scrubbed_file_stem = scrubbed_path / file_to_scrub.stem 21 | subprocess.run( 22 | f"jupytext --output '{scrubbed_file_stem}.md' --to md '{file_to_scrub}' --update-metadata '{strip_metadata_string}'", 23 | shell=True, 24 | check=True, 25 | stdout=subprocess.PIPE, 26 | ) 27 | res = subprocess.run( 28 | f"jupytext --output '{file_to_scrub}' --to notebook '{scrubbed_file_stem}.md'", 29 | shell=True, 30 | check=True, 31 | universal_newlines=True, 32 | stdout=subprocess.PIPE, 33 | ) 34 | print(res.stdout.split("\n")[1]) 35 | subprocess.run( 36 | f"rm '{scrubbed_file_stem}.md'", 37 | shell=True, 38 | check=True, 39 | stdout=subprocess.PIPE, 40 | ) 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("notebooks_path", type=Path) 46 | parser.add_argument("scrubbed_path", type=Path) 47 | p = parser.parse_args() 48 | assert p.notebooks_path.is_dir() 49 | assert p.scrubbed_path.is_dir() 50 | scrub(p.notebooks_path, p.scrubbed_path) 51 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | exclude = .venv 3 | ignore = E203, E266, E501, W503 4 | max-line-length = 88 5 | per-file-ignores = __init__.py:F401 6 | mypy_config = mypy.ini 7 | 8 | [isort] 9 | line_length = 88 10 | not_skip = __init__.py 11 | # skip ergo/contrib/el_paso/brachbach.py because it's unimportant and for some reason 12 | # I can't get isort running locally to agree with isort running on Travis 13 | # about whether to insert a newline 14 | skip_glob = .ipynb_checkpoints, **/.venv/**, ergo/contrib/el_paso/brachbach.py 15 | known_first_party = ergo 16 | sections = FUTURE, STDLIB, THIRDPARTY, FIRSTPARTY, LOCALFOLDER 17 | force_sort_within_sections = true 18 | multi_line_output = 3 19 | skip = docs 20 | include_trailing_comma = true 21 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from distutils.core import setup 4 | 5 | setup( 6 | name="ergo", 7 | version="0.8.3", 8 | description="A Python library for integrating model-based and judgmental forecasting", 9 | author="Ought", 10 | author_email="ergo@ought.org", 11 | url="https://ought.org", 12 | packages=["ergo"], 13 | ) 14 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/oughtinc/ergo/c46f4ebd73bc7115771f39dd99cbd32fa7a3bde3/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from types import SimpleNamespace 4 | from typing import cast 5 | 6 | from dotenv import load_dotenv 7 | import jax.numpy as np 8 | import pytest 9 | 10 | import ergo 11 | from ergo.distributions import Logistic, LogisticMixture, 
Truncate 12 | from ergo.scale import LogScale, Scale, TimeScale 13 | 14 | METACULUS_USERNAME = "oughttest" 15 | METACULUS_PASSWORD = "6vCo39Mz^rrb" 16 | METACULUS_USER_ID = "112420" 17 | 18 | 19 | def three_sd_scale(loc, s): 20 | sd = s * math.pi / math.sqrt(3) 21 | return Scale(loc - 3 * sd, loc + 3 * sd) 22 | 23 | 24 | def easyLogistic(loc, scale): 25 | return Logistic(loc, scale, three_sd_scale(loc, scale)) 26 | 27 | 28 | @pytest.fixture(scope="module") 29 | def normalized_logistic_mixture(): 30 | return LogisticMixture( 31 | components=[ 32 | Logistic(loc=0.15, s=0.037034005, scale=Scale(0, 1)), 33 | Logistic(loc=0.85, s=0.032395907, scale=Scale(0, 1)), 34 | ], 35 | probs=[0.6, 0.4], 36 | ) 37 | 38 | 39 | @pytest.fixture(scope="module") 40 | def logistic_mixture(): 41 | xscale = Scale(0, 150000) 42 | return LogisticMixture( 43 | components=[ 44 | Logistic(loc=10000, s=1000, scale=xscale), 45 | Logistic(loc=100000, s=10000, scale=xscale), 46 | ], 47 | probs=[0.8, 0.2], 48 | ) 49 | 50 | 51 | @pytest.fixture(scope="module") 52 | def smooth_logistic_mixture(): 53 | xscale = Scale(1, 1000000.0) 54 | return LogisticMixture( 55 | components=[ 56 | Logistic(loc=400000, s=100000, scale=xscale), 57 | Logistic(loc=700000, s=50000, scale=xscale), 58 | ], 59 | probs=[0.8, 0.2], 60 | ) 61 | 62 | 63 | @pytest.fixture(scope="module") 64 | def logistic_mixture10(): 65 | xscale = Scale(-20, 40) 66 | return LogisticMixture( 67 | components=[ 68 | Logistic(loc=15, s=2.3658268, scale=xscale), 69 | Logistic(loc=5, s=2.3658268, scale=xscale), 70 | ], 71 | probs=[0.5, 0.5], 72 | ) 73 | 74 | 75 | @pytest.fixture(scope="module") 76 | def logistic_mixture_p_uneven(): 77 | xscale = Scale(-10, 20) 78 | return LogisticMixture( 79 | components=[ 80 | Logistic(loc=10, s=3, scale=xscale), 81 | Logistic(loc=5, s=5, scale=xscale), 82 | ], 83 | probs=[1.8629593e-29, 1.0], 84 | ) 85 | 86 | 87 | @pytest.fixture(scope="module") 88 | def truncated_logistic_mixture(): 89 | xscale = Scale(5000, 120000) 90 | return LogisticMixture( 91 | components=[ 92 | Truncate( 93 | Logistic(loc=10000, s=1000, scale=xscale), floor=5000, ceiling=500000 94 | ), 95 | Truncate( 96 | Logistic(loc=100000, s=10000, scale=xscale), floor=5000, ceiling=500000 97 | ), 98 | ], 99 | probs=[0.8, 0.2], 100 | ) 101 | 102 | 103 | @pytest.fixture(scope="module") 104 | def logistic_mixture_p_overlapping(): 105 | xscale = three_sd_scale(4000000.035555004, 200000.02) 106 | return LogisticMixture( 107 | components=[ 108 | Logistic(4000000.035555004, 200000.02, xscale), 109 | Logistic(4000000.0329152746, 200000.0, xscale), 110 | ], 111 | probs=[0.5, 0.5], 112 | ) 113 | 114 | 115 | @pytest.fixture(scope="module") 116 | def logistic_mixture_norm_test(): 117 | xscale = Scale(-50, 50) 118 | return LogisticMixture( 119 | components=[Logistic(-40, 1, xscale), Logistic(50, 10, xscale)], 120 | probs=[0.5, 0.5], 121 | ) 122 | 123 | 124 | @pytest.fixture(scope="module") 125 | def logistic_mixture15(): 126 | xscale = Scale(-10, 40) 127 | return LogisticMixture( 128 | components=[ 129 | Logistic(loc=10, s=3.658268, scale=xscale), 130 | Logistic(loc=20, s=3.658268, scale=xscale), 131 | ], 132 | probs=[0.5, 0.5], 133 | ) 134 | 135 | 136 | @pytest.fixture(scope="module") 137 | def logistic_mixture_samples(logistic_mixture, n=1000): 138 | return np.array([logistic_mixture.sample() for _ in range(0, n)]) 139 | 140 | 141 | @pytest.fixture(scope="module") 142 | def log_question_data(): 143 | return { 144 | "id": 0, 145 | "possibilities": { 146 | "type": "continuous", 147 | "scale": 
{"deriv_ratio": 10, "min": 1, "max": 10}, 148 | }, 149 | "title": "question_title", 150 | } 151 | 152 | 153 | @pytest.fixture(scope="module") 154 | def metaculus(): 155 | load_dotenv() 156 | uname = METACULUS_USERNAME 157 | pwd = METACULUS_PASSWORD 158 | user_id_str = METACULUS_USER_ID 159 | if None in [uname, pwd, user_id_str]: 160 | raise ValueError( 161 | ".env is missing METACULUS_USERNAME, METACULUS_PASSWORD, or METACULUS_USER_ID" 162 | ) 163 | user_id = int(user_id_str) 164 | metaculus = ergo.Metaculus() 165 | metaculus.login_via_username_and_password(username=uname, password=pwd) 166 | assert metaculus.user_id == user_id 167 | return metaculus 168 | 169 | 170 | @pytest.fixture(scope="module") 171 | def metaculus_via_api_keys(): 172 | load_dotenv() 173 | user_api_key = cast(str, os.getenv("METACULUS_USER_WWW_API_KEY")) 174 | org_api_key = cast(str, os.getenv("METACULUS_ORG_API_KEY")) 175 | if None in [user_api_key, org_api_key]: 176 | raise ValueError( 177 | ".env is missing METACULUS_ORG_API_KEY or METACULUS_USER_WWW_API_KEY" 178 | ) 179 | metaculus = ergo.Metaculus() 180 | metaculus.login_via_api_keys( 181 | user_api_key=user_api_key, org_api_key=org_api_key, 182 | ) 183 | return metaculus 184 | 185 | 186 | @pytest.fixture(scope="module") 187 | def unauthenticated_metaculus(): 188 | return ergo.Metaculus() 189 | 190 | 191 | @pytest.fixture(scope="module") 192 | def metaculus_questions(metaculus, log_question_data): 193 | questions = SimpleNamespace() 194 | questions.continuous_linear_closed_question = metaculus.get_question(3963) 195 | questions.continuous_linear_open_question = metaculus.get_question(3962) 196 | questions.continuous_linear_date_open_question = metaculus.get_question(4212) 197 | questions.continuous_log_open_question = metaculus.get_question(3961) 198 | questions.closed_question = metaculus.get_question(3965) 199 | questions.binary_question = metaculus.get_question(3966) 200 | questions.log_question = metaculus.make_question_from_data(log_question_data) 201 | return questions 202 | 203 | 204 | @pytest.fixture(scope="module") 205 | def date_samples(metaculus_questions, normalized_logistic_mixture): 206 | return metaculus_questions.continuous_linear_date_open_question.denormalize_samples( 207 | np.array([normalized_logistic_mixture.sample() for _ in range(0, 1000)]) 208 | ) 209 | 210 | 211 | @pytest.fixture(scope="module") 212 | def point_densities(): 213 | return make_point_densities() 214 | 215 | 216 | @pytest.fixture(scope="module") 217 | def predictit(): 218 | return ergo.PredictIt() 219 | 220 | 221 | @pytest.fixture(scope="module") 222 | def predictit_markets(): 223 | return list(ergo.PredictIt().markets)[0:3] 224 | 225 | 226 | def make_point_densities(): 227 | xs = np.array( 228 | [ 229 | -0.22231131421566422, 230 | 0.2333153619512007, 231 | 0.6889420381180656, 232 | 1.1445687142849306, 233 | 1.6001953904517954, 234 | 2.0558220666186604, 235 | 2.5114487427855257, 236 | 2.9670754189523905, 237 | ] 238 | ) 239 | densities = np.array( 240 | [ 241 | 0.05020944540593859, 242 | 0.3902426887736647, 243 | 0.5887675161478794, 244 | 0.19516571803813396, 245 | 0.33712516238248535, 246 | 0.4151935926066581, 247 | 0.16147625748938946, 248 | 0.03650993407810862, 249 | ] 250 | ) 251 | return {"xs": xs, "densities": densities} 252 | 253 | 254 | scales_to_test = [ 255 | Scale(0, 1), 256 | Scale(0, 10000), 257 | Scale(-1, 1), 258 | LogScale(0.01, 100, 10), 259 | LogScale(0.01, 1028, 2), 260 | TimeScale(631152000, 946684800), 261 | TimeScale(2000, 2051222400), 262 | ] 263 | 
-------------------------------------------------------------------------------- /tests/test_foretold.py: -------------------------------------------------------------------------------- 1 | from http import HTTPStatus 2 | 3 | import jax.numpy as np 4 | import numpy as onp 5 | import pandas as pd 6 | import pytest 7 | import scipy.stats 8 | 9 | from ergo.platforms.foretold import Foretold, ForetoldCdf, _measurement_query 10 | from tests.utils import random_seed 11 | 12 | 13 | class TestForetold: 14 | @random_seed 15 | def test_foretold_sampling(self): 16 | foretold = Foretold() 17 | # https://www.foretold.io/c/f45577e4-f1b0-4bba-8cf6-63944e63d70c/m/cf86da3f-c257-4787-b526-3ef3cb670cb4 18 | 19 | # Distribution is mm(10 to 20, 200 to 210), a mixture model with 20 | # most mass split between 10 - 20 and 200 - 210. 21 | dist = foretold.get_question("cf86da3f-c257-4787-b526-3ef3cb670cb4") 22 | assert dist.quantile(0.25) < 100 23 | assert dist.quantile(0.75) > 100 24 | 25 | num_samples = 5000 26 | samples = np.array([dist.sample_community() for _ in range(num_samples)]) 27 | # Probability mass is split evenly between both modes of the distribution, 28 | # so approximately half of the samples should be lower than 100 29 | assert float(onp.count_nonzero(samples > 100)) == pytest.approx( 30 | num_samples / 2, 0.1 31 | ) 32 | 33 | def test_foretold_multiple_questions(self): 34 | foretold = Foretold() 35 | # https://www.foretold.io/c/f45577e4-f1b0-4bba-8cf6-63944e63d70c/m/cf86da3f-c257-4787-b526-3ef3cb670cb4 36 | 37 | ids = [ 38 | "cf86da3f-c257-4787-b526-3ef3cb670cb4", 39 | "77936da2-a581-48c7-add1-8a4ebc647c8c", 40 | "9b0b01fb-f439-4bbe-8722-f57034ffc96e", 41 | ] 42 | has_community_prediction_list = [True, True, False] 43 | questions = foretold.get_questions(ids) 44 | for id, question, has_community_prediction in zip( 45 | ids, questions, has_community_prediction_list 46 | ): 47 | assert question is not None 48 | assert question.id == id 49 | assert question.community_prediction_available == has_community_prediction 50 | 51 | def test_foretold_multiple_questions_error(self): 52 | foretold = Foretold() 53 | with pytest.raises(NotImplementedError): 54 | ids = ["cf86da3f-c257-4787-b526-3ef3cb670cb4"] * 1000 55 | foretold.get_questions(ids) 56 | 57 | def test_cdf_from_samples_numpy(self): 58 | samples = onp.random.normal(loc=0, scale=1, size=1000) 59 | cdf = ForetoldCdf.from_samples(samples, length=100) 60 | xs = np.array(cdf.xs) 61 | ys = np.array(cdf.ys) 62 | true_ys = scipy.stats.norm.cdf(xs, loc=0, scale=1) 63 | assert len(cdf.xs) == 100 64 | assert len(cdf.ys) == 100 65 | assert type(cdf.xs[0]) == float 66 | assert type(cdf.ys[0]) == float 67 | # Check that `xs` is sorted as expected by Foretold. 
68 | assert np.all(np.diff(xs) >= 0) 69 | assert np.all(0 <= ys) and np.all(ys <= 1) 70 | assert np.all(np.abs(true_ys - ys) < 0.1) 71 | 72 | def test_cdf_from_samples_pandas(self): 73 | df = pd.DataFrame({"samples": onp.random.normal(loc=0, scale=1, size=100)}) 74 | cdf = ForetoldCdf.from_samples(df["samples"], length=50) 75 | assert len(cdf.xs) == 50 76 | assert len(cdf.ys) == 50 77 | assert type(cdf.xs[0]) == float 78 | assert type(cdf.ys[0]) == float 79 | 80 | def test_measurement_query(self): 81 | cdf = ForetoldCdf([0.0, 1.0, 2.0], [1.0, 2.0, 3.0]) 82 | query = _measurement_query("cf86da3f-c257-4787-b526-3ef3cb670cb4", cdf) 83 | assert type(query) == str 84 | 85 | @pytest.mark.skip(reason="API token required") 86 | def test_create_measurement(self): 87 | foretold = Foretold(token="") 88 | question = foretold.get_question("cf86da3f-c257-4787-b526-3ef3cb670cb4") 89 | samples = onp.random.normal(loc=150, scale=5, size=1000) 90 | r = question.submit_from_samples(samples, length=20) 91 | assert r.status_code == HTTPStatus.OK 92 | -------------------------------------------------------------------------------- /tests/test_jax.py: -------------------------------------------------------------------------------- 1 | from jax import grad, jit 2 | import pytest 3 | 4 | 5 | def f(x, y): 6 | return x * 2.0 + y * 3.0 7 | 8 | 9 | def test_jax(): 10 | gf = jit(grad(f, (0, 1))) 11 | grads = gf(1.0, 1.0) 12 | assert float(grads[0]) == pytest.approx(2.0) 13 | assert float(grads[1]) == pytest.approx(3.0) 14 | 15 | 16 | test_jax() 17 | -------------------------------------------------------------------------------- /tests/test_mem.py: -------------------------------------------------------------------------------- 1 | import ergo 2 | 3 | 4 | def test_nomem(): 5 | """ 6 | Without mem, different calls to foo() should differ sometimes 7 | """ 8 | 9 | def foo(): 10 | return ergo.lognormal_from_interval(1, 10) 11 | 12 | def model(): 13 | x = foo() 14 | y = foo() 15 | return x == y 16 | 17 | samples = ergo.run(model, num_samples=1000) 18 | assert sum(samples["output"]) != 1000 19 | 20 | 21 | def test_mem(): 22 | """ 23 | With mem, different calls to foo() should always have the same value 24 | """ 25 | 26 | @ergo.mem 27 | def foo(): 28 | return ergo.lognormal_from_interval(1, 10) 29 | 30 | def model(): 31 | x = foo() 32 | y = foo() 33 | return x == y 34 | 35 | samples = ergo.run(model, num_samples=1000) 36 | assert sum(samples["output"]) == 1000 37 | 38 | 39 | def test_mem_2(): 40 | """ 41 | Check that mem is cleared at the start of each run 42 | """ 43 | 44 | @ergo.mem 45 | def model(): 46 | return ergo.lognormal_from_interval(1, 10) 47 | 48 | samples = ergo.run(model, num_samples=100) 49 | assert samples["output"].unique().size == 100 50 | -------------------------------------------------------------------------------- /tests/test_ppl.py: -------------------------------------------------------------------------------- 1 | import ergo 2 | 3 | 4 | class TestPPL: 5 | def test_sampling(self): 6 | def model(): 7 | x = ergo.lognormal_from_interval(1, 10, name="x") 8 | y = ergo.beta_from_hits(1, 9, name="y") 9 | z = x * y 10 | ergo.tag(z, "z") 11 | 12 | samples = ergo.run(model, num_samples=2000) 13 | stats = samples.describe() 14 | assert 3.5 < stats["x"]["mean"] < 4.5 15 | assert 0.1 < stats["y"]["mean"] < 0.3 16 | assert 0.6 < stats["z"]["mean"] < 1.0 17 | 18 | def test_tag_output(self): 19 | def model(): 20 | return ergo.normal(7, 1, name="x") 21 | 22 | samples = ergo.run(model, num_samples=2000) 23 | stats = 
samples.describe() 24 | assert 6 < stats["output"]["mean"] < 8 25 | -------------------------------------------------------------------------------- /tests/test_predictit.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | 3 | import ergo 4 | 5 | 6 | def test_market_attributes(predictit_markets): 7 | """ 8 | Ensure that the PredictIt API hasn't changed. 9 | This test goes through the various attributes of a market and makes sure they were created. 10 | """ 11 | attrs = { 12 | "predictit": ergo.PredictIt, 13 | "api_url": str, 14 | "id": int, 15 | "name": str, 16 | "shortName": str, 17 | "image": str, 18 | "url": str, 19 | "status": str, 20 | "timeStamp": datetime.datetime, 21 | } 22 | for market in predictit_markets: 23 | for attr in attrs.items(): 24 | assert type(getattr(market, attr[0])) is attr[1] 25 | 26 | 27 | def test_question_attributes(predictit_markets): 28 | """ 29 | Ensure that the PredictIt API hasn't changed. 30 | This test goes through the various attributes of a question and makes sure they were created. 31 | """ 32 | attrs = { 33 | "market": ergo.PredictItMarket, 34 | "id": int, 35 | "name": str, 36 | "shortName": str, 37 | "image": str, 38 | "status": str, 39 | "displayOrder": int, 40 | } 41 | for market in predictit_markets: 42 | for question in market.questions: 43 | for attr in attrs.items(): 44 | assert type(getattr(question, attr[0])) is attr[1] 45 | assert ( 46 | type(getattr(question, "dateEnd")) is datetime.datetime 47 | or getattr(question, "dateEnd") is None 48 | ) 49 | -------------------------------------------------------------------------------- /tests/test_rejection.py: -------------------------------------------------------------------------------- 1 | import ergo 2 | 3 | 4 | def test_rejection(): 5 | def model(): 6 | x = ergo.flip() 7 | y = ergo.flip() 8 | ergo.condition(x or y) 9 | return x == y 10 | 11 | samples = ergo.run(model, num_samples=1000) 12 | assert 266 < sum(samples["output"]) < 466 13 | -------------------------------------------------------------------------------- /tests/test_scales.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import scipy.stats 4 | 5 | from ergo.scale import LogScale, Scale, TimeScale, scale_factory 6 | from tests.conftest import scales_to_test 7 | 8 | 9 | def test_serialization(): 10 | assert hash(Scale(0, 100)) == hash(Scale(0, 100)) 11 | assert hash(Scale(0, 100)) != hash(Scale(100, 200)) 12 | 13 | assert hash(LogScale(0, 100, 10)) == hash(LogScale(0, 100, 10)) 14 | assert hash(LogScale(0, 100, 10)) != hash(LogScale(0, 100, 100)) 15 | 16 | assert hash(TimeScale(946684800, 1592914415)) == hash( 17 | TimeScale(946684800, 1592914415) 18 | ) 19 | assert hash(TimeScale(631152000, 1592914415)) != hash( 20 | TimeScale(946684800, 1592914415) 21 | ) 22 | 23 | assert ( 24 | hash(LogScale(0, 100, 1)) 25 | != hash(Scale(0, 100)) 26 | != hash(TimeScale(631152000, 946684800)) 27 | ) 28 | 29 | 30 | def test_export_import(): 31 | log_scale = LogScale(low=-1, high=1, log_base=2) 32 | log_scale_export = log_scale.export() 33 | assert log_scale_export["width"] == 2 34 | assert log_scale_export["class"] == "LogScale" 35 | 36 | assert (scale_factory(log_scale.export())) == log_scale 37 | 38 | linear_scale = Scale(low=1, high=10000) 39 | assert (scale_factory(linear_scale.export())) == linear_scale 40 | 41 | linear_date_scale = TimeScale(low=631152000, high=946684800) 42 | assert 
(scale_factory(linear_date_scale.export())) == linear_date_scale 43 | 44 | 45 | @pytest.mark.parametrize("scale", scales_to_test) 46 | def test_density_norm_denorm_roundtrip(scale: Scale): 47 | rv = scipy.stats.logistic(loc=0.5, scale=0.15) 48 | normed_xs = np.linspace(0.01, 1, 201) 49 | normed_densities_truth_set = rv.pdf(normed_xs) 50 | xs = scale.denormalize_points(normed_xs) 51 | 52 | denormed_densities = scale.denormalize_densities(xs, normed_densities_truth_set) 53 | normed_densities = scale.normalize_densities(normed_xs, denormed_densities) 54 | 55 | assert np.allclose(normed_densities_truth_set, normed_densities) # type: ignore 56 | -------------------------------------------------------------------------------- /tests/utils.py: -------------------------------------------------------------------------------- 1 | import numpyro 2 | 3 | 4 | def random_seed(fn): 5 | def wrapped(*args, **kwargs): 6 | with numpyro.handlers.seed(rng_seed=0): 7 | return fn(*args, **kwargs) 8 | 9 | return wrapped 10 | --------------------------------------------------------------------------------