├── .github ├── release.yml └── workflows │ ├── ci.yml │ ├── label-precommit-prs.yml │ ├── release.yml │ ├── rtd-link-preview.yml │ └── uml.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CONTRIBUTING.md ├── LICENSE ├── Makefile ├── README.md ├── causalpy ├── __init__.py ├── custom_exceptions.py ├── data │ ├── AJR2001.csv │ ├── GDP_in_dollars_billions.csv │ ├── __init__.py │ ├── ancova_generated.csv │ ├── banks.csv │ ├── datasets.py │ ├── deaths_and_temps_england_wales.csv │ ├── did.csv │ ├── drinking.csv │ ├── geolift1.csv │ ├── geolift_multi_cell.csv │ ├── gt_social_media_data.csv │ ├── its.csv │ ├── its_simple.csv │ ├── nhefs.csv │ ├── regression_discontinuity.csv │ ├── schoolingReturns.csv │ ├── simulate_data.py │ └── synthetic_control.csv ├── experiments │ ├── __init__.py │ ├── base.py │ ├── diff_in_diff.py │ ├── instrumental_variable.py │ ├── interrupted_time_series.py │ ├── inverse_propensity_weighting.py │ ├── prepostnegd.py │ ├── regression_discontinuity.py │ ├── regression_kink.py │ └── synthetic_control.py ├── plot_utils.py ├── pymc_models.py ├── skl_models.py ├── tests │ ├── __init__.py │ ├── conftest.py │ ├── test_data_loading.py │ ├── test_input_validation.py │ ├── test_integration_pymc_examples.py │ ├── test_integration_skl_examples.py │ ├── test_misc.py │ ├── test_model_experiment_compatability.py │ ├── test_pymc_models.py │ ├── test_synthetic_data.py │ └── test_utils.py ├── utils.py └── version.py ├── codecov.yml ├── docs ├── Makefile ├── make.bat └── source │ ├── .codespell │ ├── codespell-whitelist.txt │ ├── notebook_to_markdown.py │ ├── test_data │ │ └── test_notebook.ipynb │ └── test_notebook_to_markdown.py │ ├── _static │ ├── classes.png │ ├── favicon_logo.png │ ├── flat_logo.png │ ├── flat_logo_darkmode.png │ ├── interrogate_badge.svg │ ├── iv_reg1.png │ ├── iv_reg2.png │ ├── logo.png │ ├── packages.png │ └── pymc-labs-log.png │ ├── _templates │ └── autosummary │ │ ├── base.rst │ │ ├── class.rst │ │ ├── method.rst │ │ 
└── module.rst │ ├── api │ └── index.md │ ├── conf.py │ ├── index.md │ ├── knowledgebase │ ├── causal_video_resources.md │ ├── causal_written_resources.md │ ├── design_notation.md │ ├── glossary.rst │ ├── index.md │ └── quasi_dags.ipynb │ ├── notebooks │ ├── ancova_pymc.ipynb │ ├── did_pymc.ipynb │ ├── did_pymc_banks.ipynb │ ├── did_skl.ipynb │ ├── geolift.csv │ ├── geolift1.ipynb │ ├── index.md │ ├── inv_prop_pymc.ipynb │ ├── its_covid.ipynb │ ├── its_pymc.ipynb │ ├── its_skl.ipynb │ ├── iv_pymc.ipynb │ ├── iv_weak_instruments.ipynb │ ├── multi_cell_geolift.ipynb │ ├── rd_pymc.ipynb │ ├── rd_pymc_drinking.ipynb │ ├── rd_skl.ipynb │ ├── rd_skl_drinking.ipynb │ ├── rkink_pymc.ipynb │ ├── sc_pymc.ipynb │ ├── sc_pymc_brexit.ipynb │ └── sc_skl.ipynb │ └── references.bib ├── environment.yml └── pyproject.toml /.github/release.yml: -------------------------------------------------------------------------------- 1 | # This file has been mostly taken verbatim from https://github.com/pymc-devs/pymc/blob/main/.github/release.yml 2 | # 3 | # This file contains configuration for the automatic generation of release notes in GitHub. 4 | # It's not perfect, but it makes it a little less laborious to write informative release notes. 
5 | # Also see https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes 6 | changelog: 7 | exclude: 8 | labels: 9 | - no releasenotes 10 | categories: 11 | - title: Major Changes 🛠 12 | labels: 13 | - major 14 | - title: New Features 🎉 15 | labels: 16 | - enhancement 17 | - feature request 18 | - title: Bugfixes 🐛 19 | labels: 20 | - bug 21 | - title: Documentation 📖 22 | labels: 23 | - documentation 24 | - title: Maintenance 🔧 25 | labels: 26 | - "*" 27 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | pull_request: 5 | branches: [main] 6 | push: 7 | branches: [main] 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: ["3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Update pip and setuptools 23 | run: pip install --upgrade pip setuptools 24 | - name: Setup environment 25 | run: pip install -e .[test] 26 | - name: Run doctests 27 | run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ 28 | - name: Run extra tests 29 | run: pytest docs/source/.codespell/test_notebook_to_markdown.py 30 | - name: Run tests 31 | run: pytest --cov-report=xml --no-cov-on-fail 32 | - name: Check codespell for notebooks 33 | run: | 34 | python ./docs/source/.codespell/notebook_to_markdown.py --tempdir tmp_markdown 35 | codespell 36 | - name: Upload coverage to Codecov 37 | uses: codecov/codecov-action@v4 38 | with: 39 | token: ${{ secrets.CODECOV_TOKEN }} # use token for more robust uploads 40 | name: ${{ matrix.python-version }} 41 | fail_ci_if_error: false 42 | -------------------------------------------------------------------------------- 
/.github/workflows/label-precommit-prs.yml: -------------------------------------------------------------------------------- 1 | name: Label Pre-Commit PRs 2 | 3 | on: 4 | pull_request: 5 | types: [opened, synchronize] 6 | 7 | jobs: 8 | label: 9 | if: github.actor == 'pre-commit-ci[bot]' 10 | runs-on: ubuntu-latest 11 | permissions: 12 | pull-requests: write 13 | steps: 14 | - name: Add "no releasenotes" label 15 | uses: actions-ecosystem/action-add-labels@v1 16 | with: 17 | github_token: ${{ secrets.GITHUB_TOKEN }} 18 | labels: no releasenotes 19 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: PyPI release 2 | 3 | on: 4 | workflow_dispatch: 5 | pull_request: 6 | branches: [main] 7 | push: 8 | branches: [main] 9 | release: 10 | types: [published] 11 | 12 | jobs: 13 | build: 14 | name: Build source distribution 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v4 18 | with: 19 | fetch-depth: 0 20 | - uses: actions/setup-python@v5 21 | with: 22 | python-version: 3.11 23 | - name: Build the sdist and the wheel 24 | run: | 25 | pip install build 26 | python -m build 27 | ls dist # List the contents of the dist directory 28 | - name: Check the sdist installs and imports 29 | run: | 30 | mkdir -p test-sdist 31 | cd test-sdist 32 | python -m venv venv-sdist 33 | venv-sdist/bin/python -m pip install ../dist/causalpy*.tar.gz 34 | echo "Checking import and version number (on release)" 35 | venv-sdist/bin/python -c "import causalpy; assert causalpy.__version__ == '${{ github.ref_name }}' if '${{ github.ref_type }}' == 'tag' else causalpy.__version__; print(causalpy.__version__)" 36 | cd .. 
37 | - name: Check the bdist installs and imports 38 | run: | 39 | mkdir -p test-bdist 40 | cd test-bdist 41 | python -m venv venv-bdist 42 | venv-bdist/bin/python -m pip install ../dist/causalpy*.whl 43 | echo "Checking import and version number (on release)" 44 | venv-bdist/bin/python -c "import causalpy; assert causalpy.__version__ == '${{ github.ref_name }}' if '${{ github.ref_type }}' == 'tag' else causalpy.__version__; print(causalpy.__version__)" 45 | cd .. 46 | - uses: actions/upload-artifact@v4 47 | with: 48 | name: artifact 49 | path: dist/* 50 | 51 | test: 52 | name: Upload to Test PyPI 53 | permissions: 54 | id-token: write 55 | needs: [build] 56 | runs-on: ubuntu-latest 57 | if: github.event_name == 'release' && github.event.action == 'published' 58 | steps: 59 | - uses: actions/download-artifact@v4 60 | with: 61 | name: artifact 62 | path: dist 63 | - uses: pypa/gh-action-pypi-publish@release/v1 64 | with: 65 | skip_existing: true 66 | repository_url: https://test.pypi.org/legacy/ 67 | - uses: actions/setup-python@v5 68 | with: 69 | python-version: 3.11 70 | - name: Test pip install from test.pypi 71 | run: | 72 | # Give time to test.pypi to update its index. 
If we don't wait, 73 | # we might request to install before test.pypi is aware that it actually has the package 74 | sleep 5s 75 | python -m venv venv-test-pypi 76 | venv-test-pypi/bin/python -m pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple causalpy 77 | echo "Checking import and version number" 78 | venv-test-pypi/bin/python -c "import causalpy; assert causalpy.__version__ == '${{ github.ref_name }}'" 79 | 80 | publish: 81 | environment: release 82 | permissions: 83 | id-token: write 84 | name: Upload release to PyPI 85 | needs: [build, test] 86 | runs-on: ubuntu-latest 87 | if: github.event_name == 'release' && github.event.action == 'published' 88 | steps: 89 | - uses: actions/download-artifact@v4 90 | with: 91 | name: artifact 92 | path: dist 93 | - uses: pypa/gh-action-pypi-publish@release/v1 94 | -------------------------------------------------------------------------------- /.github/workflows/rtd-link-preview.yml: -------------------------------------------------------------------------------- 1 | name: Read the Docs Pull Request Preview 2 | on: 3 | pull_request_target: 4 | types: 5 | - opened 6 | 7 | permissions: 8 | pull-requests: write 9 | 10 | jobs: 11 | documentation-links: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: readthedocs/actions/preview@v1 15 | with: 16 | project-slug: "causalpy" 17 | -------------------------------------------------------------------------------- /.github/workflows/uml.yml: -------------------------------------------------------------------------------- 1 | name: Update the UML Diagrams 2 | on: 3 | workflow_dispatch: 4 | schedule: 5 | - cron: '0 12 * * 1' 6 | 7 | jobs: 8 | build: 9 | runs-on: ubuntu-latest 10 | permissions: write-all 11 | steps: 12 | 13 | - name: Checkout repository 14 | uses: actions/checkout@v4 15 | with: 16 | ref: main 17 | 18 | - name: Set up Python 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: "3.10" 22 | 23 | - name: Install 
dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install 'causalpy[docs]' 27 | sudo apt-get update && sudo apt-get install -y graphviz 28 | 29 | - name: Install pylint explicitly 30 | run: python -m pip install pylint 31 | 32 | - name: Verify pylint and pyreverse 33 | run: | 34 | python -m pip show pylint 35 | which pyreverse 36 | pyreverse --version 37 | 38 | - name: Configure Git Identity 39 | run: | 40 | git config user.name 'github-actions[bot]' 41 | git config user.email 'github-actions[bot]@users.noreply.github.com' 42 | 43 | - name: Update the UML Diagrams 44 | run: | 45 | make uml 46 | 47 | - name: Detect UML changes 48 | id: changes 49 | run: | 50 | git add docs/source/_static/*.png 51 | if git diff --staged --exit-code; then 52 | echo "No changes to commit" 53 | echo "changes_exist=false" >> $GITHUB_OUTPUT 54 | else 55 | echo "changes_exist=true" >> $GITHUB_OUTPUT 56 | fi 57 | 58 | - name: Create PR for changes 59 | if: steps.changes.outputs.changes_exist == 'true' 60 | run: | 61 | git checkout -b update-uml-diagrams 62 | git commit -m "Update UML Diagrams" 63 | git push -u origin update-uml-diagrams 64 | gh pr create \ 65 | --base main \ 66 | --title "Update UML Diagrams" \ 67 | --body "This PR updates the UML diagrams 68 | This PR was created automatically by the [UML workflow](https://github.com/pymc-labs/CausalPy/blob/main/.github/workflows/uml.yml). 69 | See the logs [here](https://github.com/pymc-labs/CausalPy/actions/workflows/uml.yml) for more details." 
\ 70 | --label "no releasenotes" \ 71 | --reviewer drbenvincent 72 | env: 73 | GH_TOKEN: ${{ github.token }} 74 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | _build 3 | .ipynb_checkpoints 4 | *_checkpoints 5 | *.DS_Store 6 | *.egg-info 7 | build/ 8 | dist/ 9 | *.vscode 10 | .coverage 11 | *.jupyterlab-workspace 12 | 13 | # Sphinx documentation 14 | docs/build/ 15 | docs/jupyter_execute/ 16 | docs/source/api/generated/ 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | 4 | # See https://pre-commit.com for more information 5 | # See https://pre-commit.com/hooks.html for more hooks 6 | repos: 7 | - repo: https://github.com/lucianopaz/head_of_apache 8 | rev: "0.1.1" 9 | hooks: 10 | - id: head_of_apache 11 | args: 12 | - --author=The PyMC Labs Developers 13 | - --exclude=docs/ 14 | - --exclude=scripts/ 15 | - repo: https://github.com/pre-commit/pre-commit-hooks 16 | rev: v5.0.0 17 | hooks: 18 | - id: debug-statements 19 | - id: trailing-whitespace 20 | exclude_types: [svg] 21 | - id: end-of-file-fixer 22 | exclude_types: [svg] 23 | - id: check-yaml 24 | - id: check-added-large-files 25 | exclude: &exclude_pattern 'iv_weak_instruments.ipynb' 26 | args: ["--maxkb=1500"] 27 | - repo: https://github.com/astral-sh/ruff-pre-commit 28 | rev: v0.11.11 29 | hooks: 30 | # Run the linter 31 | - id: ruff 32 | types_or: [ python, pyi, jupyter ] 33 | args: [ --fix ] 34 | # Run the formatter 35 | - id: ruff-format 36 | types_or: [ python, pyi, jupyter ] 37 | - repo: https://github.com/econchick/interrogate 38 | rev: 1.7.0 39 | hooks: 40 | - id: interrogate 41 | # needed to make excludes in pyproject.toml work 42 | # see here 
https://github.com/econchick/interrogate/issues/60#issuecomment-735436566 43 | pass_filenames: false 44 | - repo: https://github.com/codespell-project/codespell 45 | rev: v2.4.1 46 | hooks: 47 | - id: codespell 48 | additional_dependencies: 49 | # Support pyproject.toml configuration 50 | - tomli 51 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-lts-latest 11 | tools: 12 | python: "3.11" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/source/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - method: pip 30 | path: . 31 | extra_requirements: 32 | - docs 33 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Guidelines for Contributing 2 | 3 | CausalPy welcomes contributions from interested individuals or groups. These guidelines are provided to give potential contributors information to make their contribution compliant with the conventions of the CausalPy project, and maximize the probability that such contributions are merged as quickly and efficiently as possible. Contributors need not be experts, but should be interested in the project, willing to learn, and share knowledge. 
4 | 5 | There are 4 main ways of contributing to the CausalPy project (in ascending order of difficulty or scope): 6 | 7 | 1. Submitting issues related to bugs or desired enhancements. 8 | 2. Contributing to or improving the documentation (docs) or examples. 9 | 3. Fixing outstanding issues (bugs) with the existing codebase. They range from low-level software bugs to higher-level design problems. 10 | 4. Adding new or improved functionality to the existing codebase. 11 | 12 | Items 2-4 require setting up a local development environment; see [Local development steps](#Local-development-steps) for more information. 13 | 14 | ## Opening issues 15 | 16 | We appreciate being notified of problems with the existing CausalPy code. We prefer that issues be filed on the [GitHub Issue Tracker](https://github.com/pymc-labs/CausalPy/issues), rather than on social media or by direct email to the developers. 17 | 18 | Please verify that your issue is not currently being addressed by other issues or pull requests by using the GitHub search tool to look for keywords in the project issue tracker. 19 | 20 | ## Contributing code via pull requests 21 | 22 | While issue reporting is valuable, we strongly encourage users who are inclined to do so to submit patches for new or existing issues via pull requests. This is particularly the case for simple fixes, such as typos or tweaks to documentation, which do not require a heavy investment of time and attention. 23 | 24 | Contributors are also encouraged to contribute new code to enhance CausalPy's functionality, via pull requests. 25 | 26 | The preferred workflow for contributing to CausalPy is to fork the GitHub repository, clone it to your local machine, and develop on a feature branch. 27 | 28 | For more instructions see the [Pull request checklist](#pull-request-checklist) 29 | 30 | ## Local development steps 31 | 32 | 1. 
If you have not already done so, fork the [project repository](https://github.com/pymc-labs/CausalPy) by clicking on the 'Fork' button near the top right of the main repository page. This creates a copy of the code under your GitHub user account. 33 | 34 | 1. Clone your fork of the `CausalPy` repo from your GitHub account to your local disk, and add the base repository as a remote: 35 | 36 | ```bash 37 | git clone git@github.com:<your GitHub handle>/CausalPy.git 38 | cd CausalPy 39 | git remote add upstream git@github.com:pymc-labs/CausalPy.git 40 | ``` 41 | 42 | 1. Create a feature branch (e.g. `my-feature`) to hold your development changes: 43 | 44 | ```bash 45 | git checkout -b my-feature 46 | ``` 47 | 48 | Always use a feature branch. It's good practice to never routinely work on the `main` branch of any repository. 49 | 50 | 1. Create the environment from the `environment.yml` file. 51 | 52 | ```bash 53 | mamba env create -f environment.yml 54 | ``` 55 | 56 | Activate the environment. 57 | 58 | ```bash 59 | conda activate CausalPy 60 | ``` 61 | 62 | Install the package (in editable mode) and its development dependencies. The `--no-deps` flag is used to avoid installing the dependencies of `CausalPy` as they are already installed when installing the development dependencies. This can end up interfering with the conda-only install of pymc. 63 | 64 | ```bash 65 | pip install --no-deps -e . 66 | ``` 67 | 68 | Install development dependencies 69 | 70 | ```bash 71 | pip install 'causalpy[dev]' 72 | pip install 'causalpy[docs]' 73 | pip install 'causalpy[test]' 74 | pip install 'causalpy[lint]' 75 | pip install 'pylint' 76 | ``` 77 | 78 | It may also be necessary to [install](https://pandoc.org/installing.html) `pandoc`. On a Mac, run `brew install pandoc`. 
79 | 80 | Set [pre-commit hooks](https://pre-commit.com/) 81 | 82 | ```bash 83 | pre-commit install 84 | ``` 85 | 86 | If you are editing or writing new examples in the form of Jupyter notebooks, you may have to run the following command to make Jupyter Lab aware of the `CausalPy` environment. 87 | 88 | ``` 89 | python -m ipykernel install --user --name CausalPy 90 | ``` 91 | 92 | 1. You can then work on your changes locally, in your feature branch. Add changed files using `git add` and then `git commit` files: 93 | 94 | ```bash 95 | git add modified_files 96 | git commit -m "Message summarizing commit changes" 97 | ``` 98 | 99 | to record your changes locally. 100 | After committing, it is a good idea to sync with the base repository in case there have been any changes: 101 | 102 | ```bash 103 | git fetch upstream 104 | git rebase upstream/main 105 | ``` 106 | 107 | Then push the changes to your GitHub account with: 108 | 109 | ```bash 110 | git push -u origin my-feature 111 | ``` 112 | 113 | 1. Before you submit a Pull Request, follow the [Pull request checklist](#pull-request-checklist). 114 | 115 | 1. Finally, to submit a pull request, go to the GitHub web page of your fork of the CausalPy repo. Click the 'Pull request' button to send your changes to the project's maintainers for review. This will send an email to the committers. 116 | 117 | ## Pull request checklist 118 | 119 | We recommend that your contribution complies with the following guidelines before you submit a pull request: 120 | 121 | - If your pull request addresses an issue, please use the pull request title to describe the issue and mention the issue number in the pull request description. This will make sure a link back to the original issue is created. 122 | 123 | - All public methods must have informative docstrings with sample usage when appropriate. 
124 | 125 | - Example usage in docstrings is tested via doctest, which can be run via: 126 | 127 | ```bash 128 | make doctest 129 | ``` 130 | 131 | - Doctest can also be run directly via pytest, which can be helpful to run only specific tests during development. The following commands run all doctests, only doctests in the pymc_models module, and only the doctests for the `PyMCModel` class in pymc_models: 132 | 133 | ```bash 134 | pytest --doctest-modules causalpy/ 135 | pytest --doctest-modules causalpy/pymc_models.py 136 | pytest --doctest-modules causalpy/pymc_models.py::causalpy.pymc_models.PyMCModel 137 | ``` 138 | 139 | - To indicate a work in progress please mark the PR as `draft`. Drafts may be useful to (1) indicate you are working on something to avoid duplicated work, (2) request broad review of functionality or API, or (3) seek collaborators. 140 | 141 | - All other tests pass when everything is rebuilt from scratch. Tests can be run with: 142 | 143 | ```bash 144 | make test 145 | ``` 146 | 147 | - When adding additional functionality, either edit an existing example, or create a new example (typically in the form of a Jupyter Notebook). Have a look at other examples for reference. Examples should demonstrate why the new functionality is useful in practice. 148 | 149 | - Documentation and high-coverage tests are necessary for enhancements to be accepted. 150 | 151 | - Documentation follows the [NumPy style guide](https://numpydoc.readthedocs.io/en/latest/format.html) 152 | 153 | - If you have changed the documentation, you should [build the docs locally](#Building-the-documentation-locally) and check that the changes look correct. 154 | 155 | - Run any of the pre-existing examples in `CausalPy/docs/source/*` that contain analyses that would be affected by your changes to ensure that nothing breaks. 
This is a useful opportunity to not only check your work for bugs that might not be revealed by unit test, but also to show how your contribution improves CausalPy for end users. 156 | 157 | - Your code passes linting tests. Run the line below to check linting errors: 158 | 159 | ```bash 160 | make check_lint 161 | ``` 162 | If you want to fix linting errors automatically, run 163 | 164 | ```bash 165 | make lint 166 | ``` 167 | 168 | ## Building the documentation locally 169 | 170 | A local build of the docs is achieved by: 171 | 172 | ```bash 173 | cd docs 174 | make html 175 | ``` 176 | 177 | Sometimes not all changes are recognised. In that case run this (again from within the `docs` folder): 178 | 179 | ```bash 180 | make clean && make html 181 | ``` 182 | 183 | Docs are built in `docs/_build`, but these docs are _not_ committed to the GitHub repository due to `.gitignore`. 184 | 185 | ## Overview of code structure 186 | 187 | Classes 188 | ![](docs/source/_static/classes.png) 189 | 190 | Packages 191 | ![](docs/source/_static/packages.png) 192 | 193 | UML diagrams can be created with the command below. 194 | 195 | ```bash 196 | make uml 197 | ``` 198 | 199 | --- 200 | 201 | This guide takes some inspiration from the [Bambi guide to contributing](https://github.com/bambinos/bambi/blob/main/docs/CONTRIBUTING.md) 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: init lint check_lint test 2 | 3 | init: 4 | python -m pip install -e . --no-deps 5 | 6 | lint: 7 | ruff check --fix . 8 | ruff format . 9 | 10 | check_lint: 11 | ruff check . 12 | ruff format --diff --check . 13 | interrogate . 
14 | 15 | doctest: 16 | pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ 17 | 18 | test: 19 | pytest 20 | 21 | uml: 22 | pyreverse -o png causalpy --output-directory docs/source/_static --ignore tests 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |
4 | 5 | ---- 6 | 7 | ![Build Status](https://github.com/pymc-labs/CausalPy/actions/workflows/ci.yml/badge.svg?branch=main) 8 | [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) 9 | [![PyPI version](https://badge.fury.io/py/CausalPy.svg)](https://badge.fury.io/py/CausalPy) 10 | ![GitHub Repo stars](https://img.shields.io/github/stars/pymc-labs/causalpy?style=social) 11 | ![Read the Docs](https://img.shields.io/readthedocs/causalpy) 12 | ![PyPI - Downloads](https://img.shields.io/pypi/dm/causalpy) 13 | ![Interrogate](docs/source/_static/interrogate_badge.svg) 14 | [![codecov](https://codecov.io/gh/pymc-labs/CausalPy/branch/main/graph/badge.svg?token=FDKNAY5CZ9)](https://codecov.io/gh/pymc-labs/CausalPy) 15 | 16 | # CausalPy 17 | 18 | A Python package focussing on causal inference in quasi-experimental settings. The package allows for sophisticated Bayesian model fitting methods to be used in addition to traditional OLS. 
19 | 20 | ## Installation 21 | 22 | To get the latest release: 23 | ```bash 24 | pip install CausalPy 25 | ``` 26 | 27 | Alternatively, if you want the very latest version of the package you can install from GitHub: 28 | 29 | ```bash 30 | pip install git+https://github.com/pymc-labs/CausalPy.git 31 | ``` 32 | 33 | ## Quickstart 34 | 35 | ```python 36 | import causalpy as cp 37 | import matplotlib.pyplot as plt 38 | 39 | # Import and process data 40 | df = ( 41 | cp.load_data("drinking") 42 | .rename(columns={"agecell": "age"}) 43 | .assign(treated=lambda df_: df_.age > 21) 44 | ) 45 | 46 | # Run the analysis 47 | result = cp.RegressionDiscontinuity( 48 | df, 49 | formula="all ~ 1 + age + treated", 50 | running_variable_name="age", 51 | model=cp.pymc_models.LinearRegression(), 52 | treatment_threshold=21, 53 | ) 54 | 55 | # Visualize outputs 56 | fig, ax = result.plot(); 57 | 58 | # Get a results summary 59 | result.summary() 60 | 61 | plt.show() 62 | ``` 63 | 64 | ## Roadmap 65 | 66 | Plans for the repository can be seen in the [Issues](https://github.com/pymc-labs/CausalPy/issues). 67 | 68 | ## Videos 69 | Click on the thumbnail below to watch a video about CausalPy on YouTube. 70 | [![Youtube video thumbnail image](https://img.youtube.com/vi/gV6wzTk3o1U/maxresdefault.jpg)](https://www.youtube.com/watch?v=gV6wzTk3o1U) 71 | 72 | ## Features 73 | 74 | CausalPy has a broad range of quasi-experimental methods for causal inference: 75 | 76 | | Method | Description | 77 | |-|-| 78 | | Synthetic control | Constructs a synthetic version of the treatment group from a weighted combination of control units. Used for causal inference in comparative case studies when a single unit is treated, and there are multiple control units.| 79 | | Geographical lift | Measures the impact of an intervention in a specific geographic area by comparing it to similar areas without the intervention. Commonly used in marketing to assess regional campaigns. 
| 80 | | ANCOVA | Analysis of Covariance combines ANOVA and regression to control for the effects of one or more quantitative covariates. Used when comparing group means while controlling for other variables. | 81 | | Differences in Differences | Compares the changes in outcomes over time between a treatment group and a control group. Used in observational studies to estimate causal effects by accounting for time trends. | 82 | | Regression discontinuity | Identifies causal effects by exploiting a cutoff or threshold in an assignment variable. Used when treatment is assigned based on a threshold value of an observed variable, allowing comparison just above and below the cutoff. | 83 | | Regression kink designs | Focuses on changes in the slope (kinks) of the relationship between variables rather than jumps at cutoff points. Used to identify causal effects when treatment intensity changes at a threshold. | 84 | | Interrupted time series | Analyzes the effect of an intervention by comparing time series data before and after the intervention. Used when data is collected over time and an intervention occurs at a known point, allowing assessment of changes in level or trend. | 85 | | Instrumental variable regression | Addresses endogeneity by using an instrument variable that is correlated with the endogenous explanatory variable but uncorrelated with the error term. Used when explanatory variables are correlated with the error term, providing consistent estimates of causal effects. | 86 | | Inverse Propensity Score Weighting | Weights observations by the inverse of the probability of receiving the treatment. Used in causal inference to create a synthetic sample where the treatment assignment is independent of measured covariates, helping to adjust for confounding variables in observational studies. 
| 87 | 88 | ## License 89 | 90 | [Apache License 2.0](LICENSE) 91 | 92 | --- 93 | 94 | ## Support 95 | 96 | 97 | 98 | This repository is supported by [PyMC Labs](https://www.pymc-labs.com). 99 | 100 | If you are interested in seeing what PyMC Labs can do for you, then please email [ben.vincent@pymc-labs.com](mailto:ben.vincent@pymc-labs.com). We work with companies at a variety of scales and with varying levels of existing modeling capacity. We also run corporate workshop training events and can provide sessions ranging from introduction to Bayes to more advanced topics. 101 | -------------------------------------------------------------------------------- /causalpy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import arviz as az 15 | 16 | import causalpy.pymc_models as pymc_models 17 | import causalpy.skl_models as skl_models 18 | from causalpy.skl_models import create_causalpy_compatible_class 19 | from causalpy.version import __version__ 20 | 21 | from .data import load_data 22 | from .experiments.diff_in_diff import DifferenceInDifferences 23 | from .experiments.instrumental_variable import InstrumentalVariable 24 | from .experiments.interrupted_time_series import InterruptedTimeSeries 25 | from .experiments.inverse_propensity_weighting import InversePropensityWeighting 26 | from .experiments.prepostnegd import PrePostNEGD 27 | from .experiments.regression_discontinuity import RegressionDiscontinuity 28 | from .experiments.regression_kink import RegressionKink 29 | from .experiments.synthetic_control import SyntheticControl 30 | 31 | az.style.use("arviz-darkgrid") 32 | 33 | __all__ = [ 34 | "__version__", 35 | "DifferenceInDifferences", 36 | "create_causalpy_compatible_class", 37 | "InstrumentalVariable", 38 | "InterruptedTimeSeries", 39 | "InversePropensityWeighting", 40 | "load_data", 41 | "PrePostNEGD", 42 | "pymc_models", 43 | "RegressionDiscontinuity", 44 | "RegressionKink", 45 | "skl_models", 46 | "SyntheticControl", 47 | ] 48 | -------------------------------------------------------------------------------- /causalpy/custom_exceptions.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BadIndexException(Exception):
    """Raised when the type of the dataframe index does not match the type of an
    event, typically a treatment or intervention.

    :param message: Human-readable description of the problem. It is stored on
        ``self.message`` and also forwarded to :class:`Exception`.
    """

    def __init__(self, message: str):
        # Forward the message to Exception so that ``args`` and ``str(exc)`` are
        # populated even when the message is passed by keyword; without this,
        # ``BadIndexException(message="...")`` renders as an empty string.
        super().__init__(message)
        self.message = message


class FormulaException(Exception):
    """Raised when there is an error in a user-provided model formula.

    :param message: Human-readable description of the problem. It is stored on
        ``self.message`` and also forwarded to :class:`Exception`.
    """

    def __init__(self, message: str):
        # See BaseException.args: forwarding keeps str()/repr()/pickling consistent.
        super().__init__(message)
        self.message = message


class DataException(Exception):
    """Raised when there is an error in a user-provided dataframe.

    :param message: Human-readable description of the problem. It is stored on
        ``self.message`` and also forwarded to :class:`Exception`.
    """

    def __init__(self, message: str):
        # See BaseException.args: forwarding keeps str()/repr()/pickling consistent.
        super().__init__(message)
        self.message = message
2009-04-01,3.86954,0.796545,0.94162,16.7534,4.41372,0.50829,5.05375,6.6453,5.16171,4.1882090000000005,4.915681,0.116259,1.634432,0.47916,7.71903,0.466515,2.72044,10.32867,1.485509,4.60431,151.61772 6 | 2009-07-01,3.88115,0.799937,0.95352,16.82878,4.42898,0.51299,5.06237,6.68237,5.24132,4.212578,4.912656,0.118747,1.6409820000000002,0.48188,7.724,0.47063099999999997,2.6055,10.32328,1.5025060000000001,4.60722,152.16647 7 | 2009-10-01,3.91028,0.8038230000000001,0.96117,17.02503,4.433,0.50903,5.09832,6.73155,5.22482,4.226248,4.974765,0.119302,1.6508660000000002,0.48805,7.72812,0.470856,2.7842,10.37107,1.5151385000000002,4.62152,153.79155 8 | 2010-01-01,3.92716,0.80051,0.96615,17.23041,4.47128,0.51413,5.11625,6.78621,4.91128,4.237582,5.027625,0.121414,1.647748,0.49349,7.87891,0.47473099999999996,2.57598,10.64833,1.5258639999999999,4.6538,154.56059 9 | 2010-04-01,3.95387,0.811277,0.97567,17.32057,4.50184,0.52833,5.14098,6.93903,5.10614,4.269259,5.086985,0.122733,1.654981,0.49609,7.79516,0.477354,2.73434,10.86674,1.5384847,4.70655,156.05628 10 | 2010-07-01,3.98175,0.81855,0.97974,17.44332,4.57321,0.52616,5.17273,6.99577,5.00671,4.289668,5.178457,0.120794,1.6622929999999998,0.48887,7.60172,0.477754,2.61706,10.99898,1.5475226999999998,4.73954,157.26282 11 | 2010-10-01,4.01663,0.828113,0.98469,17.63825,4.56292,0.53721,5.20808,7.05251,5.19581,4.313207,5.137305,0.121695,1.681046,0.48355,7.81512,0.47682699999999995,2.79971,11.17711,1.5593448,4.74526,158.07995 12 | 2011-01-01,4.0046,0.8352289999999999,0.98943,17.77148,4.57286,0.54021,5.26173,7.18926,4.98014,4.333769,5.083540999999999,0.123155,1.690748,0.49175,7.82243,0.473676,2.58719,11.21731,1.5693595999999999,4.7672,157.69911 13 | 2011-04-01,4.05656,0.837287,0.99148,17.8061,4.6187,0.53927,5.25703,7.19607,5.11759,4.335718,5.039066,0.121965,1.689237,0.49628,7.79009,0.471644,2.72481,11.25308,1.5767164000000002,4.77202,158.76839 14 | 
2011-07-01,4.11156,0.842913,0.99432,18.05176,4.56125,0.53993,5.28283,7.25962,5.20997,4.313057,5.16075,0.12346,1.6892429999999998,0.50209,7.91561,0.467962,2.59493,11.40015,1.5707735999999999,4.78726,158.70684 15 | 2011-10-01,4.15576,0.84265,0.99722,18.19392,4.59853,0.54008,5.29302,7.23616,5.28546,4.27216,5.154163,0.123072,1.679051,0.5062,7.91969,0.46104300000000004,2.7307,11.24116,1.5725503,4.79335,160.48702 16 | 2012-01-01,4.19499,0.8493639999999999,0.99936,18.20558,4.59533,0.53772,5.29395,7.2513,5.11087,4.225777,5.22599,0.123716,1.675556,0.50695,8.09152,0.45852699999999996,2.52313,11.26089,1.5888463000000002,4.82583,161.79968 17 | 2012-04-01,4.22634,0.845857,0.99935,18.26496,4.59879,0.53203,5.28606,7.26643,5.18428,4.194572,5.178464,0.12370600000000001,1.67641,0.51008,8.1106,0.45234599999999997,2.62527,11.2776,1.587918,4.82299,162.53726 18 | 2012-07-01,4.25166,0.844361,1.00178,18.28984,4.60311,0.52998,5.29839,7.28685,5.21788,4.172838,5.158104,0.125012,1.669214,0.513,7.99303,0.447237,2.5334,11.26257,1.5990306,4.88205,162.82151 19 | 2012-10-01,4.27205,0.84392,1.00132,18.32766,4.59567,0.52957,5.29273,7.25432,5.29916,4.141455,5.154555,0.127251,1.6575339999999998,0.51894,8.05836,0.44016900000000003,2.62919,11.18393,1.5975844000000001,4.87,163.00035 20 | 2013-01-01,4.29019,0.8424010000000001,0.99857,18.49206,4.62158,0.52503,5.29616,7.22122,5.33395,4.102825,5.226566,0.127256,1.662886,0.52179,8.0705,0.442,2.46846,11.33892,1.6044659,4.89236,164.41485 21 | 2013-04-01,4.30869,0.843032,1.0044,18.59938,4.62501,0.52745,5.33164,7.29839,5.42017,4.103628,5.273670999999999,0.129726,1.6598810000000002,0.52008,8.1288,0.44535199999999997,2.59228,11.32686,1.6188392,4.92451,164.64402 22 | 2013-07-01,4.34261,0.846905,1.00751,18.75096,4.65431,0.52929,5.32621,7.33924,5.4072,4.113084000000001,5.323485,0.12981399999999998,1.669891,0.52474,8.20853,0.44473199999999996,2.51199,11.37816,1.6313214,4.96229,165.94743 23 | 
2013-10-01,4.37749,0.85024,1.00971,18.94795,4.66367,0.52834,5.36163,7.36043,5.59834,4.1039829999999995,5.317018,0.128722,1.68049,0.52764,8.19398,0.44960199999999995,2.63075,11.45938,1.6367806,4.98839,167.1276 24 | 2014-01-01,4.41048,0.8483289999999999,1.01405,18.97892,4.67548,0.52426,5.36183,7.43211,5.43661,4.11047,5.3616,0.131386,1.678897,0.53332,8.23666,0.446716,2.48197,11.55366,1.6470401000000001,5.03492,166.54247 25 | 2014-04-01,4.43084,0.853993,1.01674,19.15226,4.67209,0.52511,5.36653,7.43211,5.43975,4.109159,5.264161,0.129623,1.688726,0.53587,8.29753,0.448202,2.60733,11.64891,1.6581819,5.07551,168.68109 26 | 2014-07-01,4.45209,0.85319,1.02376,19.33594,4.7503,0.52714,5.39723,7.46994,5.63207,4.1133809999999995,5.267956999999999,0.13231600000000002,1.692948,0.5425,8.33241,0.448627,2.5478,11.73151,1.6673167999999998,5.10988,170.64616 27 | 2014-10-01,4.46827,0.854363,1.02911,19.46974,4.76734,0.52588,5.39537,7.53046,5.61836,4.102336,5.29257,0.135719,1.7081110000000002,0.54982,8.43427,0.452256,2.68448,11.82275,1.6761612,5.13852,171.41235 28 | 2015-01-01,4.50884,0.856751,1.03446,19.36275,4.79998,0.52179,5.4238,7.49188,5.51916,4.113187,5.375484,0.134703,1.718281,0.55348,8.41262,0.45517199999999997,2.57233,11.98956,1.6733119,5.17165,172.80647 29 | 2015-04-01,4.5137,0.858388,1.04141,19.31005,4.82243,0.52973,5.42994,7.5471,5.86313,4.129486,5.382865,0.13561600000000001,1.723575,0.5585,8.45721,0.45652699999999996,2.72132,12.10879,1.6853151000000002,5.20984,173.80875 30 | 2015-07-01,4.56084,0.863285,1.0434,19.37835,4.83698,0.52968,5.44213,7.58115,5.81588,4.1394,5.387195999999999,0.134972,1.729509,0.56733,8.55631,0.457046,2.66375,12.27554,1.6959884,5.23783,174.3708 31 | 2015-10-01,4.58825,0.8643789999999999,1.04775,19.39286,4.84776,0.53265,5.45297,7.61443,5.9103,4.161275,5.378318,0.13584,1.729793,0.57376,8.48118,0.45923699999999995,2.8185,12.36629,1.7029607999999998,5.27344,174.62579 32 | 
2016-01-01,4.62773,0.873659,1.04835,19.49923,4.9036,0.53922,5.48707,7.67836,5.73509,4.172618,5.419015,0.139863,1.745688,0.57866,8.49252,0.46138,2.65844,12.35977,1.7101123999999999,5.29673,175.65465 33 | 2016-04-01,4.658,0.872602,1.05411,19.40335,4.96993,0.54006,5.47082,7.71241,6.13539,4.181198,5.411239,0.14183,1.7497989999999999,0.58517,8.48983,0.46276300000000004,2.81788,12.36546,1.7190254999999999,5.32731,176.18581 34 | 2016-07-01,4.66437,0.877404,1.05596,19.60344,5.0121,0.54616,5.48611,7.73208,6.24267,4.204893,5.42041,0.142868,1.7693329999999998,0.58918,8.46376,0.468196,2.75793,12.40342,1.7283486,5.35103,177.24489 35 | 2016-10-01,4.71319,0.886402,1.06138,19.71351,5.04821,0.54784,5.51862,7.76007,6.45202,4.216942,5.427695,0.14357899999999998,1.78415,0.59054,8.57712,0.472559,2.90415,12.51088,1.7343463000000001,5.39059,178.1256 36 | 2017-01-01,4.72423,0.8904139999999999,1.0683,19.92778,5.08124,0.55343,5.56005,7.85351,6.19735,4.2401,5.47353,0.14136100000000001,1.7935070000000002,0.59832,8.66243,0.47809199999999996,2.76574,12.57115,1.7372078,5.42562,178.96623 37 | 2017-04-01,4.75627,0.895741,1.07088,20.13165,5.13804,0.56048,5.60345,7.92008,6.41612,4.25706,5.495453,0.143345,1.809282,0.60619,8.74589,0.48050800000000005,2.94157,12.73491,1.7424922,5.44248,179.96802 38 | 2017-07-01,4.80272,0.9003669999999999,1.07019,20.21658,5.1192,0.56257,5.65092,7.97833,6.3905,4.27331,5.537455,0.145172,1.822236,0.61264,8.80582,0.48384699999999997,2.87429,12.85438,1.75342,5.46579,181.26226 39 | 2017-10-01,4.82362,0.906802,1.07877,20.3213,5.15784,0.56621,5.68609,8.05096,6.59169,4.29624,5.5409180000000005,0.145756,1.8360329999999998,0.62027,8.75565,0.48784099999999997,3.03707,12.87352,1.7726275,5.48781,182.96685 40 | 2018-01-01,4.86504,0.914896,1.08326,20.49916,5.18425,0.56768,5.69028,8.00425,6.62937,4.294731,5.552345,0.14615,1.844232,0.62654,8.82385,0.491389,2.86562,12.92577,1.7915179,5.5008,184.36262 41 | 
2018-04-01,4.90809,0.91826,1.08833,20.6601,5.21031,0.56749,5.7137,8.06099,6.75676,4.295742,5.573858,0.14558100000000002,1.856078,0.63435,8.85039,0.49522,3.04836,13.06163,1.8080886999999999,5.53087,185.90004 42 | 2018-07-01,4.92894,0.9195169999999999,1.09254,20.80268,5.24168,0.56655,5.73572,7.99744,6.68631,4.300846,5.537012,0.147342,1.860633,0.63601,8.90615,0.497823,2.96887,12.95859,1.8039333,5.56581,186.79599 43 | 2018-10-01,4.93524,0.928973,1.10307,20.87359,5.26786,0.56651,5.77073,8.06099,6.77453,4.313607,5.522476,0.148252,1.86831,0.64537,8.90915,0.500855,3.14974,13.12264,1.8081592000000002,5.58448,187.21281 44 | 2019-01-01,4.95551,0.936745,1.10577,20.89251,5.25637,0.57061,5.8107,8.11867,6.90307,4.322133,5.5518469999999995,0.150063,1.8813479999999998,0.64928,8.88936,0.50527,2.99029,13.19624,1.8128046,5.62033,188.33195 45 | 2019-04-01,4.98973,0.9325439999999999,1.11054,21.0995,5.3092,0.57579,5.8489,8.10506,6.87472,4.334204000000001,5.57382,0.15265299999999998,1.889184,0.65371,8.89098,0.508123,3.15507,13.27584,1.8202787,5.62779,189.82528 46 | 2019-07-01,5.0306,0.934451,1.11798,21.16842,5.33348,0.57554,5.8492,8.11413,6.74077,4.334605,5.572678000000001,0.152008,1.89659,0.66126,8.91728,0.5104730000000001,3.05611,13.3004,1.8296510999999998,5.65362,191.12653 47 | 2019-10-01,5.05036,0.932082,1.12513,21.23207,5.31726,0.57405,5.83324,8.13532,6.97884,4.299539,5.408475,0.151945,1.9040839999999999,0.66254,9.05169,0.514682,3.24228,13.34498,1.8385548,5.65109,192.0231 48 | 2020-01-01,5.03496,0.9083310000000001,1.08966,20.77403,5.27728,0.57338,5.50817,8.01957,6.62288,4.045370999999999,5.438035999999999,0.149451,1.87627,0.65584,8.93389,0.491961,2.89305,13.32553,1.8093797,5.50835,189.51992 49 | 2020-04-01,4.69275,0.80489,0.96287,18.48005,4.96031,0.53833,4.76258,7.25923,6.13008,3.533511,5.002934000000001,0.14042,1.728407,0.59528,8.5071,0.416939,2.51187,12.25111,1.6983431,4.43817,172.58205 50 | 
2020-07-01,4.85727,0.891973,1.0774,20.14029,5.27494,0.56453,5.63803,7.9129,6.23727,4.10219,5.276015,0.152086,1.83597,0.67859,8.89123,0.478115,2.82149,13.15607,1.8052762,5.2191,185.60774 51 | 2020-10-01,5.01644,0.875193,1.07616,20.58185,5.28059,0.56862,5.58853,7.96207,6.56287,4.037671,5.362419,0.15393600000000002,1.8364179999999999,0.6717,8.95496,0.47943,2.99307,13.14522,1.8059775,5.29647,187.67778 52 | 2021-01-01,5.1059,0.87101,1.09033,20.8068,5.27907,0.56794,5.59176,7.84519,6.42829,4.046204,5.343331999999999,0.157143,1.837799,0.68192,8.95172,0.46540400000000004,2.79732,13.35595,1.8016901,5.2344,190.55655 53 | 2021-04-01,5.14784,0.908738,1.10899,20.642,5.40941,0.57661,5.64981,7.99649,6.74324,4.153119,5.367795,0.158468,1.90723,0.69828,9.00515,0.48568,3.00089,13.45755,1.8372548000000002,5.52521,193.6831 54 | 2021-07-01,5.05413,0.939523,1.1323,20.91196,5.47941,0.5818,5.83899,8.05928,6.67206,4.266435,5.33882,0.159873,1.936301,0.66858,9.36779,0.49891599999999997,2.97498,13.72724,1.871626,5.577,194.78893 55 | 2021-10-01,5.23725,0.932196,1.13713,21.24709,5.63489,0.58634,5.87112,8.05701,6.82061,4.294219,5.391243,0.161441,1.9498929999999999,0.68504,9.37914,0.507602,3.27744,13.88566,1.8745770000000002,5.64812,198.0629 56 | 2022-01-01,5.27676,0.9466310000000001,1.1433,21.40751,5.60826,0.5893,5.85862,8.12132,6.89605,4.298887,5.391991999999999,0.163454,1.9587679999999998,0.6846,9.29927,0.520615,3.10012,13.7735,1.8835115,5.69182,197.27918 57 | -------------------------------------------------------------------------------- /causalpy/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Code for loading datasets.""" 15 | 16 | from .datasets import load_data 17 | 18 | __all__ = ["load_data"] 19 | -------------------------------------------------------------------------------- /causalpy/data/ancova_generated.csv: -------------------------------------------------------------------------------- 1 | group,pre,post 2 | 0,8.489252455317661,7.824477455926374 3 | 1,12.419853815210045,14.796265003859528 4 | 0,11.131001267370635,10.693494775255136 5 | 0,10.503789109304071,10.532152654532391 6 | 0,10.599760912049257,9.731499501477701 7 | 0,9.20585567484764,8.631473651277016 8 | 1,12.981286194970231,14.858409910487993 9 | 1,11.235838962316713,12.67802892339558 10 | 1,11.052268558744334,11.981392508857457 11 | 0,11.032876418087163,10.735915114198521 12 | 1,10.797986215792887,12.602345899059923 13 | 0,9.53432237712092,9.7465648655858 14 | 0,8.777175028501961,9.185299260280473 15 | 0,10.483186257178556,9.779337261557618 16 | 0,8.18693217372839,8.827809032010238 17 | 0,9.888550758723023,10.06543301874182 18 | 1,12.025545572764294,13.433202294587286 19 | 1,10.72641157652606,12.536558873383596 20 | 0,8.06197411122309,7.425364373941061 21 | 0,11.040779722618025,11.723017219921568 22 | 1,11.30603790090078,13.385634372409909 23 | 0,10.447600888850474,11.154060930494943 24 | 1,13.067201979470134,14.28235454219127 25 | 1,12.17750632339526,14.14039333760221 26 | 0,10.807168530049271,11.24970995818306 27 | 1,11.184307019963528,13.14627760120039 28 | 0,9.586159304495181,9.826129608943418 29 | 
1,11.755038401678656,13.84496736394065 30 | 1,10.625903042225632,12.689955968858188 31 | 1,12.257685256277034,14.408471522629108 32 | 0,8.896622773357983,8.273050312066276 33 | 1,12.214940002590598,14.454534799661321 34 | 1,11.835561812225908,14.247528917945397 35 | 0,9.433030539280715,8.574508097683603 36 | 0,10.686248326783621,11.040360143424008 37 | 0,9.85838654577244,10.302354190227469 38 | 1,9.608614059642928,12.299587472723434 39 | 1,11.421790910329685,12.658228985028742 40 | 0,10.669638134729697,10.358722385594032 41 | 1,12.192272071263583,14.117735405267519 42 | 1,13.038834337670504,15.514850208000231 43 | 1,12.368314676020525,14.518772266846883 44 | 0,10.617557026965056,11.016801339134387 45 | 0,10.626532269883139,10.892646857034633 46 | 0,10.300614333902987,10.330920058552755 47 | 1,10.102849381841654,12.230359457045822 48 | 1,11.008791245054956,12.733090909633706 49 | 1,12.574573335103288,14.657540816061372 50 | 1,13.098366871280593,16.027890469131716 51 | 0,10.13314185203344,9.772772883269145 52 | 1,12.444696778109526,15.429437827915937 53 | 0,11.183433582379267,11.823413215893098 54 | 0,8.542795699329139,8.295219192316004 55 | 1,11.527270787720067,13.365372906900005 56 | 0,9.879649326756322,10.36380907877101 57 | 0,10.683112840950638,11.148616607826034 58 | 1,11.973614356210454,13.689535712453113 59 | 1,12.21887877910744,14.84175871850897 60 | 1,11.055525478539108,13.066599895661387 61 | 0,8.453381640575538,8.111537419587666 62 | 0,9.870089190698051,10.244954780256606 63 | 1,12.777355697190721,15.264886253657998 64 | 1,12.317011998370008,13.97135908558011 65 | 0,7.492808733899154,6.523860954287915 66 | 1,12.368024481821628,14.53041599596784 67 | 0,9.36416007790942,8.864468005622566 68 | 0,10.619712036376995,11.492278368396999 69 | 0,10.994893683611522,10.713579618741603 70 | 1,10.630885220476767,12.932536617478073 71 | 0,10.18920300527993,10.175662908120165 72 | 0,11.549276043566842,11.632633532005107 73 | 1,10.520615328912166,11.684547910255223 74 | 
1,11.645072905040799,13.45623886750313 75 | 0,10.45875937643886,10.625967437146144 76 | 1,10.060866941403077,11.62348553911652 77 | 0,9.124687051984628,9.009220889339772 78 | 1,11.498360775319536,12.746703456629843 79 | 1,11.006238229752102,12.48465079862449 80 | 0,10.722870069639583,10.757471649322138 81 | 1,12.398698791433201,14.771181680989793 82 | 0,9.770691699718377,8.685202159817438 83 | 0,11.985158572294631,12.641466377478256 84 | 0,8.594048648834946,9.116827003382555 85 | 1,10.98177675963099,13.342589713080896 86 | 0,9.648471015491806,8.337936823299579 87 | 1,11.416965761785413,13.716362069518569 88 | 0,11.703769282145567,11.56360776974179 89 | 1,13.654731090563079,16.322820102174276 90 | 1,10.50982748639981,12.363066668386029 91 | 1,12.184260698273881,14.220574568667647 92 | 0,9.117870231724542,8.517782751200494 93 | 0,10.899705622412764,10.131533553924468 94 | 0,10.833022280927885,10.539488123705226 95 | 0,9.692716357998181,9.355244323730048 96 | 0,11.100332278658298,10.768836085743134 97 | 1,12.49699625130767,14.642322214070033 98 | 1,11.726863918384433,14.10163077915994 99 | 1,12.438566140714228,14.771487851461018 100 | 1,10.90730329459493,13.097411299715988 101 | 1,12.278901475190365,13.552951756078636 102 | 1,12.265156438310129,14.67737747755582 103 | 1,13.691286929642066,16.627878023869545 104 | 0,10.1951357226745,9.45160547840346 105 | 1,11.000528354049498,13.790898849686132 106 | 0,8.893284860677767,8.948932364973224 107 | 1,11.017152087022655,13.405675844916233 108 | 1,12.845189458215584,14.695396946894284 109 | 1,11.930183856716571,13.756428807657489 110 | 1,12.192407095379489,14.677900510965449 111 | 1,11.884725802514529,13.973533220264589 112 | 0,9.204366676101822,8.238565832347659 113 | 1,11.924056566110686,13.41118690511502 114 | 0,9.733192805687251,9.919719653574846 115 | 0,8.460665781880559,8.6999387553229 116 | 0,9.020778090418924,10.03288785482543 117 | 1,11.447580675708176,13.432136586882441 118 | 0,8.580959268916429,8.291049437131987 
119 | 0,9.971720891948188,10.282006486576645 120 | 1,11.032668688977394,13.40967978250074 121 | 1,12.416818567306535,13.945708586691158 122 | 0,10.182172058583925,10.09331090683779 123 | 1,13.163171546449826,15.424084801147393 124 | 0,10.42095673336904,10.409510045922122 125 | 1,13.975016510347821,15.216115879129374 126 | 1,11.49732187912045,13.846578205619414 127 | 1,13.440602163936193,14.886634015814483 128 | 1,11.930700474328532,13.780182905112268 129 | 1,11.618453058418556,14.126668329679438 130 | 0,9.751172804679456,9.788918717778754 131 | 0,9.902487911531106,9.632698810164237 132 | 0,9.398150334179403,9.438726118848791 133 | 0,10.55877667049362,11.501197656653362 134 | 1,12.225712172022552,13.732858661492681 135 | 1,11.199763117297856,12.993311863721898 136 | 0,8.503962964614457,8.787744515593845 137 | 1,12.382539639759282,13.869004229739318 138 | 1,11.264860317136874,13.471615470374822 139 | 0,10.118365138760117,10.342860170948581 140 | 1,13.015502974296197,14.267548895909114 141 | 1,13.989377837690588,16.190181459644702 142 | 0,11.127158566720949,10.677570054928793 143 | 1,10.646242343119749,12.11896339648179 144 | 0,10.164330881912127,10.89793960809276 145 | 1,12.452471680236114,14.41318894029163 146 | 0,8.929478965826593,9.399685573592025 147 | 0,10.887105276402668,11.21233752513777 148 | 1,12.788793113841145,14.611853869573453 149 | 0,10.054823274810728,10.84174421439569 150 | 0,10.450628491489136,9.801219714604168 151 | 0,10.071076168392345,9.988065245821273 152 | 1,14.65619257260125,16.21037548076697 153 | 1,10.752644418784662,12.757671745957854 154 | 1,12.367512031476597,13.990647903061141 155 | 0,10.478471646252299,11.245810158940271 156 | 1,11.932590274724973,14.08356956327396 157 | 0,11.347258305364148,11.348332988923774 158 | 1,11.503201853400503,14.214057326884314 159 | 1,12.518869965621,14.137961853644292 160 | 0,10.626820133170224,11.188432926449703 161 | 1,11.538511555990787,14.002453123582423 162 | 0,8.306059310792339,7.727924552988854 163 | 
1,12.442128762602495,13.810433501760814 164 | 0,9.537521007599878,9.334443627143134 165 | 1,12.967972843155685,14.36171721318257 166 | 0,9.757145635663939,10.003194740646963 167 | 1,11.998471541228344,14.73831178011914 168 | 1,10.292481711678189,12.737319514475427 169 | 1,12.314756886184675,13.116844779993592 170 | 0,10.43849975260881,10.346747225019238 171 | 1,12.919621939942418,14.681619928311472 172 | 1,11.677649306775082,13.487347797970331 173 | 0,9.941965032199327,10.269646730835463 174 | 1,12.647579585136072,15.046848419329748 175 | 0,9.628431468232515,9.735896298922077 176 | 0,10.374851996408715,10.583207722350805 177 | 0,9.876906469705084,9.830578613616709 178 | 0,11.251897875890322,10.639444030967656 179 | 1,12.349030894865676,13.492613886511528 180 | 1,12.267519542091463,13.059608239132452 181 | 1,11.583461058198159,13.230230139069306 182 | 0,10.444945878227408,9.922913907294015 183 | 1,13.374851599228663,15.856004353433418 184 | 1,13.551659152951943,16.30466714176413 185 | 1,10.52815952038451,12.209413797528292 186 | 1,10.927962912150592,13.376284546621207 187 | 0,9.629546929224926,10.112458762424659 188 | 0,10.822862686692964,10.323347958244627 189 | 1,11.77033081064472,14.573590716769981 190 | 0,8.493079379186764,8.209446538006098 191 | 1,13.440558253502417,16.094839169420357 192 | 1,9.70734572727937,11.653537179686879 193 | 1,11.790148012749617,12.855907075040314 194 | 1,11.013447374072008,12.774959683965271 195 | 1,11.946928386480945,13.715643446567078 196 | 1,10.649070940308492,12.60415030897177 197 | 0,8.567264509757244,9.237374661074682 198 | 0,11.4829954066837,11.979219164829882 199 | 0,10.04930108326449,10.7058776385638 200 | 0,11.136769789873284,11.40344282008195 201 | 0,9.599327203689894,9.927868251548668 202 | -------------------------------------------------------------------------------- /causalpy/data/datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs 
# Registry of the bundled example datasets: public dataset name -> metadata.
# Each entry currently carries only the CSV file name inside the data directory.
DATASETS = {
    "banks": {"filename": "banks.csv"},
    "brexit": {"filename": "GDP_in_dollars_billions.csv"},
    "covid": {"filename": "deaths_and_temps_england_wales.csv"},
    "did": {"filename": "did.csv"},
    "drinking": {"filename": "drinking.csv"},
    "its": {"filename": "its.csv"},
    "its simple": {"filename": "its_simple.csv"},
    "rd": {"filename": "regression_discontinuity.csv"},
    "sc": {"filename": "synthetic_control.csv"},
    "anova1": {"filename": "ancova_generated.csv"},
    "geolift1": {"filename": "geolift1.csv"},
    "geolift_multi_cell": {"filename": "geolift_multi_cell.csv"},
    "risk": {"filename": "AJR2001.csv"},
    "nhefs": {"filename": "nhefs.csv"},
    "schoolReturns": {"filename": "schoolingReturns.csv"},
}


def _get_data_home() -> pathlib.Path:
    """Return the path of the data directory.

    The CSV files are shipped in the same directory as this module, so the
    directory is derived from this module's own location. That avoids reaching
    back into the parent package and hard-coding its directory name.
    """
    return pathlib.Path(__file__).parent


def load_data(dataset: str = None) -> pd.DataFrame:
    """Load the requested example dataset and return it as a pandas DataFrame.

    :param dataset: Name of the dataset to load; must be one of the keys of
        ``DATASETS``.
    :returns: The contents of the corresponding CSV file, read with
        ``pandas.read_csv`` using default options.
    :raises ValueError: If ``dataset`` is not a known dataset name (including
        the default ``None``). The message lists the valid names.
    """
    if dataset not in DATASETS:
        available = ", ".join(sorted(DATASETS))
        raise ValueError(
            f"Dataset {dataset} not found! Available datasets: {available}"
        )

    file_path = _get_data_home() / DATASETS[dataset]["filename"]
    return pd.read_csv(file_path)
2008-11-01,6.1,39504,2008,11,34,True 37 | 2008-12-01,3.1,53594,2008,12,35,True 38 | 2009-01-01,2.8,55045,2009,1,36,True 39 | 2009-02-01,3.7,41433,2009,2,37,True 40 | 2009-03-01,6.1,42395,2009,3,38,True 41 | 2009-04-01,8.9,40270,2009,4,39,True 42 | 2009-05-01,10.8,36568,2009,5,40,True 43 | 2009-06-01,13.7,38851,2009,6,41,True 44 | 2009-07-01,15.1,37975,2009,7,42,True 45 | 2009-08-01,15.3,33606,2009,8,43,True 46 | 2009-09-01,13.2,39127,2009,9,44,True 47 | 2009-10-01,10.4,40187,2009,10,45,True 48 | 2009-11-01,7.3,40122,2009,11,46,True 49 | 2009-12-01,2.1,45769,2009,12,47,True 50 | 2010-01-01,0.9,48363,2010,1,48,True 51 | 2010-02-01,1.8,41048,2010,2,49,True 52 | 2010-03-01,5.1,45138,2010,3,50,True 53 | 2010-04-01,7.9,40584,2010,4,51,True 54 | 2010-05-01,9.7,36517,2010,5,52,True 55 | 2010-06-01,14.1,40168,2010,6,53,True 56 | 2010-07-01,15.6,36888,2010,7,54,True 57 | 2010-08-01,14.2,36083,2010,8,55,True 58 | 2010-09-01,12.8,39423,2010,9,56,True 59 | 2010-10-01,9.4,38613,2010,10,57,True 60 | 2010-11-01,4.3,42123,2010,11,58,True 61 | 2010-12-01,-0.9,48294,2010,12,59,True 62 | 2011-01-01,3.1,49992,2011,1,60,True 63 | 2011-02-01,5.3,39350,2011,2,61,True 64 | 2011-03-01,5.8,44209,2011,3,62,True 65 | 2011-04-01,10.7,36943,2011,4,63,True 66 | 2011-05-01,11.0,40100,2011,5,64,True 67 | 2011-06-01,12.6,40000,2011,6,65,True 68 | 2011-07-01,14.2,35646,2011,7,66,True 69 | 2011-08-01,14.1,38383,2011,8,67,True 70 | 2011-09-01,13.8,38358,2011,9,68,True 71 | 2011-10-01,11.2,37200,2011,10,69,True 72 | 2011-11-01,8.7,40624,2011,11,70,True 73 | 2011-12-01,4.8,43562,2011,12,71,True 74 | 2012-01-01,4.6,46897,2012,1,72,True 75 | 2012-02-01,4.2,44537,2012,2,73,True 76 | 2012-03-01,7.7,44142,2012,3,74,True 77 | 2012-04-01,6.3,41685,2012,4,75,True 78 | 2012-05-01,10.5,44008,2012,5,76,True 79 | 2012-06-01,12.3,36680,2012,6,77,True 80 | 2012-07-01,14.1,39293,2012,7,78,True 81 | 2012-08-01,15.3,39035,2012,8,79,True 82 | 2012-09-01,11.9,35216,2012,9,80,True 83 | 2012-10-01,8.2,43169,2012,10,81,True 
84 | 2012-11-01,5.7,42124,2012,11,82,True 85 | 2012-12-01,3.8,42545,2012,12,83,True 86 | 2013-01-01,3.3,52898,2013,1,84,True 87 | 2013-02-01,2.7,43778,2013,2,85,True 88 | 2013-03-01,2.2,44915,2013,3,86,True 89 | 2013-04-01,6.3,49735,2013,4,87,True 90 | 2013-05-01,9.5,42273,2013,5,88,True 91 | 2013-06-01,12.8,35866,2013,6,89,True 92 | 2013-07-01,17.0,39806,2013,7,90,True 93 | 2013-08-01,15.6,35691,2013,8,91,True 94 | 2013-09-01,12.7,36775,2013,9,92,True 95 | 2013-10-01,11.2,42322,2013,10,93,True 96 | 2013-11-01,5.5,39941,2013,11,94,True 97 | 2013-12-01,5.7,42790,2013,12,95,True 98 | 2014-01-01,4.7,49026,2014,1,96,True 99 | 2014-02-01,5.2,41199,2014,2,97,True 100 | 2014-03-01,6.7,41217,2014,3,98,True 101 | 2014-04-01,9.2,41487,2014,4,99,True 102 | 2014-05-01,11.2,39422,2014,5,100,True 103 | 2014-06-01,14.2,38505,2014,6,101,True 104 | 2014-07-01,16.3,41244,2014,7,102,True 105 | 2014-08-01,13.9,35959,2014,8,103,True 106 | 2014-09-01,13.9,40979,2014,9,104,True 107 | 2014-10-01,11.0,43159,2014,10,105,True 108 | 2014-11-01,7.6,39457,2014,11,106,True 109 | 2014-12-01,4.4,49770,2014,12,107,True 110 | 2015-01-01,3.7,60891,2015,1,108,True 111 | 2015-02-01,3.5,46721,2015,2,109,True 112 | 2015-03-01,5.5,47895,2015,3,110,True 113 | 2015-04-01,7.9,45178,2015,4,111,True 114 | 2015-05-01,9.5,39343,2015,5,112,True 115 | 2015-06-01,12.7,42082,2015,6,113,True 116 | 2015-07-01,14.4,40512,2015,7,114,True 117 | 2015-08-01,14.7,36199,2015,8,115,True 118 | 2015-09-01,11.9,41573,2015,9,116,True 119 | 2015-10-01,10.0,42232,2015,10,117,True 120 | 2015-11-01,8.1,41520,2015,11,118,True 121 | 2015-12-01,7.9,45509,2015,12,119,True 122 | 2016-01-01,4.5,47457,2016,1,120,True 123 | 2016-02-01,3.8,46021,2016,2,121,True 124 | 2016-03-01,5.3,48665,2016,3,122,True 125 | 2016-04-01,6.5,46856,2016,4,123,True 126 | 2016-05-01,11.3,41384,2016,5,124,True 127 | 2016-06-01,13.9,42012,2016,6,125,True 128 | 2016-07-01,15.3,38983,2016,7,126,True 129 | 2016-08-01,15.5,40786,2016,8,127,True 130 | 
2016-09-01,14.6,40367,2016,9,128,True 131 | 2016-10-01,9.8,40448,2016,10,129,True 132 | 2016-11-01,4.9,46514,2016,11,130,True 133 | 2016-12-01,5.9,45555,2016,12,131,True 134 | 2017-01-01,3.8,57368,2017,1,132,True 135 | 2017-02-01,5.2,47766,2017,2,133,True 136 | 2017-03-01,7.3,48664,2017,3,134,True 137 | 2017-04-01,8.0,39101,2017,4,135,True 138 | 2017-05-01,12.1,44279,2017,5,136,True 139 | 2017-06-01,14.4,42175,2017,6,137,True 140 | 2017-07-01,15.1,38425,2017,7,138,True 141 | 2017-08-01,14.5,41074,2017,8,139,True 142 | 2017-09-01,12.5,40095,2017,9,140,True 143 | 2017-10-01,11.2,43597,2017,10,141,True 144 | 2017-11-01,5.7,45580,2017,11,142,True 145 | 2017-12-01,4.1,45129,2017,12,143,True 146 | 2018-01-01,4.0,64154,2018,1,144,True 147 | 2018-02-01,2.3,49177,2018,2,145,True 148 | 2018-03-01,3.8,51229,2018,3,146,True 149 | 2018-04-01,8.4,46469,2018,4,147,True 150 | 2018-05-01,12.0,42784,2018,5,148,True 151 | 2018-06-01,14.8,39767,2018,6,149,True 152 | 2018-07-01,17.2,40723,2018,7,150,True 153 | 2018-08-01,15.2,40192,2018,8,151,True 154 | 2018-09-01,12.4,37137,2018,9,152,True 155 | 2018-10-01,9.5,44440,2018,10,153,True 156 | 2018-11-01,7.3,43978,2018,11,154,True 157 | 2018-12-01,5.8,41539,2018,12,155,True 158 | 2019-01-01,3.4,53910,2019,1,156,True 159 | 2019-02-01,6.0,45795,2019,2,157,True 160 | 2019-03-01,6.8,43944,2019,3,158,True 161 | 2019-04-01,8.4,44121,2019,4,159,True 162 | 2019-05-01,10.0,44389,2019,5,160,True 163 | 2019-06-01,13.2,38603,2019,6,161,True 164 | 2019-07-01,16.4,42308,2019,7,162,True 165 | 2019-08-01,15.8,38843,2019,8,163,True 166 | 2019-09-01,13.1,40011,2019,9,164,True 167 | 2019-10-01,8.9,46238,2019,10,165,True 168 | 2019-11-01,5.3,45219,2019,11,166,True 169 | 2019-12-01,5.1,47460,2019,12,167,True 170 | 2020-01-01,5.6,56704,2020,1,168,False 171 | 2020-02-01,5.1,43650,2020,2,169,False 172 | 2020-03-01,5.6,49723,2020,3,170,False 173 | 2020-04-01,9.1,88141,2020,4,171,False 174 | 2020-05-01,11.3,52363,2020,5,172,False 175 | 
2020-06-01,14.0,42614,2020,6,173,False 176 | 2020-07-01,14.3,40778,2020,7,174,False 177 | 2020-08-01,15.9,37184,2020,8,175,False 178 | 2020-09-01,12.8,42494,2020,9,176,False 179 | 2020-10-01,9.4,46282,2020,10,177,False 180 | 2020-11-01,7.7,51317,2020,11,178,False 181 | 2020-12-01,4.3,56672,2020,12,179,False 182 | 2021-01-01,2.2,73315,2021,1,180,False 183 | 2021-02-01,4.1,58767,2021,2,181,False 184 | 2021-03-01,6.4,48624,2021,3,182,False 185 | 2021-04-01,5.7,41513,2021,4,183,False 186 | 2021-05-01,9.1,37864,2021,5,184,False 187 | 2021-06-01,14.2,41223,2021,6,185,False 188 | 2021-07-01,16.6,43264,2021,7,186,False 189 | 2021-08-01,15.0,43151,2021,8,187,False 190 | 2021-09-01,14.7,47520,2021,9,188,False 191 | 2021-10-01,10.9,46511,2021,10,189,False 192 | 2021-11-01,7.0,51602,2021,11,190,False 193 | 2021-12-01,5.3,52859,2021,12,191,False 194 | 2022-01-01,4.7,53158,2022,1,192,False 195 | 2022-02-01,5.6,45869,2022,2,193,False 196 | 2022-03-01,6.7,49489,2022,3,194,False 197 | 2022-04-01,8.1,45919,2022,4,195,False 198 | 2022-05-01,11.8,48611,2022,5,196,False 199 | -------------------------------------------------------------------------------- /causalpy/data/did.csv: -------------------------------------------------------------------------------- 1 | group,t,unit,post_treatment,y 2 | 0,0.0,0,False,0.897122432901507 3 | 0,1.0,0,True,1.9612135788421983 4 | 1,0.0,1,False,1.2335249009813691 5 | 1,1.0,1,True,2.7527941327437286 6 | 0,0.0,2,False,1.149207391077308 7 | 0,1.0,2,True,1.9107194958946412 8 | 1,0.0,3,False,1.2096028435304764 9 | 1,1.0,3,True,2.7870530562317772 10 | 0,0.0,4,False,1.0182211686591378 11 | 0,1.0,4,True,2.1355782741951903 12 | 1,0.0,5,False,1.2566023467285772 13 | 1,1.0,5,True,2.6352164140993417 14 | 0,0.0,6,False,1.1206312917156163 15 | 0,1.0,6,True,2.0293786635661104 16 | 1,0.0,7,False,1.2253914316635341 17 | 1,1.0,7,True,2.836234979171606 18 | 0,0.0,8,False,1.0937901142584816 19 | 0,1.0,8,True,2.0046646527573992 20 | 1,0.0,9,False,1.1311676279399658 21 | 
1,1.0,9,True,2.597416938762001 22 | 0,0.0,10,False,1.1338268148431594 23 | 0,1.0,10,True,2.0396150424632604 24 | 1,0.0,11,False,1.2769574784336464 25 | 1,1.0,11,True,2.7237901979669057 26 | 0,0.0,12,False,1.0548219817786735 27 | 0,1.0,12,True,2.0966644540989554 28 | 1,0.0,13,False,1.2941834769826859 29 | 1,1.0,13,True,2.828746461772019 30 | 0,0.0,14,False,1.0011352011986534 31 | 0,1.0,14,True,2.2367233120727237 32 | 1,0.0,15,False,1.2621457689408864 33 | 1,1.0,15,True,2.737756363134591 34 | 0,0.0,16,False,1.0613566957247114 35 | 0,1.0,16,True,2.105012700050028 36 | 1,0.0,17,False,1.228130146156384 37 | 1,1.0,17,True,2.6887857541638813 38 | 0,0.0,18,False,1.2259823349004162 39 | 0,1.0,18,True,2.097530059810398 40 | 1,0.0,19,False,1.263074342393256 41 | 1,1.0,19,True,2.697326984058356 42 | -------------------------------------------------------------------------------- /causalpy/data/drinking.csv: -------------------------------------------------------------------------------- 1 | ,agecell,all,allfitted,internal,internalfitted,external,externalfitted,alcohol,alcoholfitted,homicide,homicidefitted,suicide,suicidefitted,mva,mvafitted,drugs,drugsfitted,externalother,externalotherfitted 2 | 0,19.068493,92.8254,91.70615,16.61759,16.73813,76.20782,74.96801,0.63913804,0.7943445,16.316818,16.284573,11.203714,11.5921,35.829327,34.81778,3.8724246,3.448835,8.534373,8.388236 3 | 1,19.150684,95.10074,91.88372,18.327684,16.920654,76.773056,74.963066,0.6774093,0.8375749,16.859964,16.270697,12.193368,11.593611,35.639256,34.63389,3.2365112,3.4700222,8.655786,8.530174 4 | 2,19.232876,92.144295,92.049065,18.911053,17.098843,73.23324,74.950226,0.8664426,0.8778347,15.219254,16.262882,11.715812,11.595129,34.20565,34.446735,3.2020707,3.4920695,8.513741,8.662681 5 | 3,19.31507,88.42776,92.20214,16.10177,17.27268,72.32598,74.92947,0.86730844,0.9151149,16.742825,16.261148,11.27501,11.596655,32.278957,34.2563,3.2806885,3.5149798,8.258285,8.7857275 6 | 
4,19.39726,88.70494,92.34292,17.36352,17.442156,71.341415,74.90076,1.0191631,0.94940656,14.947726,16.26551,10.984314,11.598189,32.650967,34.062588,3.5481975,3.5387552,8.417533,8.899288 7 | 5,19.479452,90.19179,92.471344,17.872105,17.607254,72.31968,74.86409,1.1713219,0.98070073,15.642815,16.27599,12.166634,11.599731,32.721443,33.86558,3.211689,3.5633986,7.9725456,9.003332 8 | 6,19.561644,96.22031,92.58739,16.414942,17.767965,79.80537,74.81942,0.8699163,1.0089884,16.263653,16.292604,12.405763,11.601281,36.385197,33.66527,3.8578897,3.5889127,10.287705,9.097831 9 | 7,19.643835,89.615555,92.69102,15.977087,17.924273,73.638466,74.76675,1.0979514,1.0342605,15.825645,16.31537,10.979514,11.602839,34.187935,33.461647,3.4831562,3.6153,8.670031,9.182756 10 | 8,19.726027,93.3817,92.782196,17.433271,18.076166,75.948425,74.706024,1.174851,1.0565081,16.789,16.344309,11.900103,11.604405,31.910467,33.254696,4.0551305,3.6425629,10.76315,9.2580805 11 | 9,19.80822,90.857956,92.86087,18.2854,18.22363,72.572556,74.637245,0.9484129,1.0757217,16.616194,16.379436,11.570638,11.60598,30.576832,33.044415,3.5660326,3.6707044,9.863494,9.323772 12 | 10,19.890411,95.81015,92.927025,18.6076,18.366652,77.20255,74.56037,1.3291142,1.0918926,17.278484,16.420774,11.468357,11.607562,33.531654,32.830788,4.101267,3.6997268,9.835445,9.379805 13 | 11,19.972603,94.158066,92.980606,18.43628,18.505219,75.72179,74.47539,1.2164142,1.1050112,16.953773,16.468338,11.631961,11.609152,33.603443,32.613808,3.7632816,3.729633,9.123107,9.426147 14 | 12,20.054794,92.4646,93.02158,19.634459,18.639317,72.83014,74.38227,1.0273844,1.1150686,17.199177,16.52215,11.263178,11.61075,31.772816,32.393467,3.729025,3.7604258,8.295178,9.46277 15 | 13,20.136986,98.2456,93.09767,20.88386,18.778563,77.36174,74.319115,0.9527308,1.1226311,17.034826,16.590736,11.928189,11.618317,34.29831,32.186256,3.9252508,3.7940538,9.679745,9.494516 16 | 
14,20.219177,94.9604,93.16124,18.473003,18.913473,76.4874,74.24777,1.3740251,1.1271195,16.297464,16.6657,11.87005,11.625911,33.32011,31.975483,4.2365775,3.828613,9.809012,9.516518 17 | 15,20.30137,93.38532,93.212234,18.883484,19.044024,74.50184,74.168205,0.9938675,1.1285197,15.940106,16.747074,12.46157,11.633533,32.759403,31.761126,3.4403107,3.8641078,9.250613,9.528728 18 | 16,20.383562,90.12092,93.25059,18.56782,19.170197,71.5531,74.08039,1.2250932,1.1268175,16.883316,16.834883,11.102407,11.641183,30.589046,31.543165,3.254154,3.9005423,8.9967785,9.531103 19 | 17,20.465754,92.40623,93.276245,19.36313,19.291971,73.04309,73.984276,0.9969136,1.1219985,17.33096,16.92916,10.889364,11.648862,31.28775,31.32159,4.1793685,3.9379208,8.972222,9.523599 20 | 18,20.547945,93.354485,93.289154,19.968874,19.409328,73.38561,73.87982,0.960042,1.1140486,16.743132,17.029932,11.482102,11.6565695,30.298925,31.096382,4.3777914,3.976248,9.83083,9.506168 21 | 19,20.630136,92.22878,93.289246,19.422657,19.522247,72.80612,73.767,0.9615177,1.1029531,15.884273,17.137232,12.345887,11.664306,30.230118,30.867521,3.6537673,4.015528,10.192088,9.478767 22 | 20,20.712328,92.332245,93.276474,20.377037,19.630705,71.95521,73.64577,1.0015179,1.0886977,17.411003,17.25109,10.978177,11.6720705,30.122578,30.634996,3.6593924,4.055765,9.39886,9.441348 23 | 21,20.794521,92.66751,93.25076,19.521132,19.734684,73.14638,73.51608,1.1959587,1.0712676,17.9008,17.37153,11.303739,11.679864,29.74465,30.398787,4.3594623,4.096964,9.220456,9.393866 24 | 22,20.876713,93.3902,93.212074,18.623945,19.834162,74.76625,73.37791,0.96597224,1.0506482,18.160278,17.498592,12.209889,11.687687,30.717916,30.158875,3.9411666,4.139129,9.157417,9.336272 25 | 23,20.958904,94.26991,93.160324,20.200695,19.92912,74.06922,73.23121,1.2383568,1.0268247,17.143503,17.6323,11.880486,11.6955385,30.41714,29.915245,4.4890437,4.182265,9.48117,9.268521 26 | 
26,21.041096,105.26835,102.58908,21.937366,21.104536,83.330986,81.48454,2.5193088,1.7885877,18.332817,17.72146,12.402751,13.520511,36.316807,34.590244,4.224687,4.4496655,10.697372,10.321973 27 | 27,21.123287,101.06651,101.98372,21.121853,21.097532,79.944664,80.88619,1.5919044,1.7310421,17.316814,17.64034,14.83189,13.467395,32.575798,34.198776,4.426271,4.464704,9.978522,10.259112 28 | 28,21.205479,96.96652,101.412476,19.25328,21.102318,77.71324,80.31016,1.3613431,1.6776136,16.880655,17.563643,12.213193,13.416207,33.022293,33.815987,4.745253,4.4810143,10.268416,10.200633 29 | 29,21.287672,102.8269,100.87552,21.508316,21.118961,81.31858,79.75656,1.4416807,1.6283239,17.378096,17.49139,14.4557705,13.36696,35.106873,33.441917,4.1691847,4.498604,9.624192,10.146556 30 | 30,21.369864,100.97943,100.37305,22.366144,21.14752,78.61329,79.225525,1.3271359,1.5831952,17.018566,17.423609,13.583627,13.319661,32.358696,33.07662,4.410775,4.5174794,10.578053,10.096907 31 | 31,21.452055,101.236946,99.905235,21.232779,21.18806,80.00417,78.71717,1.7205198,1.5422496,17.478916,17.360321,13.803261,13.2743225,32.45526,32.72013,4.8487377,4.5376472,10.284016,10.051706 32 | 32,21.534246,99.4976,99.47227,21.27055,21.240648,78.22704,78.23162,1.1359962,1.5055093,18.410973,17.30155,12.848509,13.230954,31.533688,32.372505,4.622329,4.5591145,10.341482,10.010981 33 | 33,21.616438,95.20131,99.07433,22.093296,21.305344,73.10801,77.76899,1.0202943,1.472997,16.599403,17.24732,11.811869,13.189567,30.098682,32.03379,4.8660192,4.581888,9.496586,9.974753 34 | 34,21.69863,97.25825,98.71162,21.267872,21.382215,75.99038,77.32941,1.6117979,1.444735,18.044275,17.197657,13.3268175,13.15017,30.388288,31.704031,4.4029603,4.6059756,8.963169,9.943048 35 | 35,21.780823,98.92873,98.384315,19.454931,21.471325,79.47379,76.91299,1.7328279,1.4207466,17.17075,17.152584,14.76842,13.112777,31.742258,31.383278,5.0803366,4.6313834,9.806231,9.915888 36 | 
36,21.863014,102.26221,98.092606,20.515566,21.572742,81.74664,76.51987,1.3019494,1.4010543,17.16206,17.112127,13.926913,13.077396,34.481934,31.071575,4.8132677,4.6581187,10.573407,9.893299 37 | 37,21.945206,96.833374,97.83669,22.093819,21.686531,74.739555,76.15016,1.1461909,1.3856815,16.718578,17.07631,12.370957,13.04404,31.540012,30.768976,4.2685733,4.6861887,9.446195,9.875306 38 | 38,22.027397,99.50197,97.61676,21.024887,21.812757,78.47707,75.80401,1.3462263,1.3746514,17.223778,17.045156,12.274417,13.012717,31.834293,30.475529,4.9097667,4.7156005,11.482519,9.861933 39 | 39,22.109589,96.07229,97.43374,23.40324,21.951653,72.669044,75.48208,1.6263269,1.3679978,17.016933,17.01882,11.622288,12.983537,29.55155,30.191505,4.4029827,4.746397,9.440629,9.853277 40 | 40,22.19178,95.96821,97.28708,20.624224,22.103125,75.34398,75.18395,1.8279659,1.3657339,16.570908,16.9972,12.597069,12.956413,30.836988,29.916727,4.609653,4.7785506,9.497475,9.849293 41 | 41,22.273973,98.61067,97.177,23.129915,22.267242,75.48076,74.90976,1.6322316,1.3678838,17.198147,16.980322,12.420884,12.931358,29.857895,29.651249,4.9365053,4.8120685,10.0720625,9.850004 42 | 42,22.356165,97.95247,97.10369,23.092215,22.444073,74.86025,74.659615,1.5554341,1.3744715,17.109776,16.96821,13.879258,12.908381,28.75559,29.395117,5.1050143,4.846958,9.372488,9.855436 43 | 43,22.438356,97.41173,97.06736,24.37291,22.633688,73.03882,74.43367,1.6381792,1.3855213,16.701437,16.960892,13.86459,12.887496,27.64927,29.148384,4.4350705,4.883228,9.469475,9.865617 44 | 44,22.520548,97.509476,97.0682,23.49674,22.836157,74.01273,74.23204,1.5611123,1.4010576,17.532492,16.958393,12.809127,12.868712,28.260136,28.911102,4.523223,4.920884,9.887045,9.88057 45 | 45,22.60274,95.56239,97.10642,22.817877,23.051548,72.744514,74.05487,1.2431532,1.4211051,16.040686,16.960741,12.15082,12.852041,29.234152,28.683325,5.0127144,4.9599347,10.025429,9.900323 46 | 
46,22.68493,97.02306,97.18223,21.815123,23.279934,75.20794,73.90229,0.883854,1.4456885,16.793226,16.967962,12.856058,12.837497,29.528757,28.4651,5.4236493,5.0003877,10.204496,9.924901 47 | 47,22.767124,99.37457,97.29582,24.350594,23.521387,75.02398,73.77444,1.2074674,1.4748328,17.226536,16.980082,14.610356,12.8250885,27.44976,28.256483,4.789621,5.042251,10.504967,9.954332 48 | 48,22.849316,94.31531,97.447426,22.94374,23.77598,71.37157,73.67144,1.3709793,1.5085632,16.69369,16.997128,12.338814,12.814829,26.855064,28.057526,5.564563,5.0855317,9.435563,9.988643 49 | 49,22.931507,97.397644,97.63723,24.23832,24.043783,73.159325,73.59345,1.7774768,1.5469048,17.774767,17.019129,11.917173,12.806729,27.389301,27.868282,4.9688554,5.130238,10.099299,10.027859 50 | -------------------------------------------------------------------------------- /causalpy/data/gt_social_media_data.csv: -------------------------------------------------------------------------------- 1 | date,twitter,linkedin,tiktok,instagram 2 | 2022-05-15,55,9,23,59 3 | 2022-05-16,54,18,20,59 4 | 2022-05-17,54,20,23,57 5 | 2022-05-18,54,20,21,55 6 | 2022-05-19,49,23,21,52 7 | 2022-05-20,46,18,22,56 8 | 2022-05-21,51,9,23,58 9 | 2022-05-22,47,9,27,59 10 | 2022-05-23,45,19,21,58 11 | 2022-05-24,49,21,23,53 12 | 2022-05-25,55,21,21,61 13 | 2022-05-26,53,19,22,68 14 | 2022-05-27,52,16,23,52 15 | 2022-05-28,46,8,24,59 16 | 2022-05-29,45,7,22,56 17 | 2022-05-30,45,9,24,61 18 | 2022-05-31,46,19,20,58 19 | 2022-06-01,51,21,22,56 20 | 2022-06-02,47,19,22,54 21 | 2022-06-03,46,17,21,56 22 | 2022-06-04,45,9,23,58 23 | 2022-06-05,47,9,23,60 24 | 2022-06-06,48,20,21,58 25 | 2022-06-07,46,23,21,57 26 | 2022-06-08,48,22,24,56 27 | 2022-06-09,48,20,23,55 28 | 2022-06-10,48,17,22,56 29 | 2022-06-11,47,9,25,54 30 | 2022-06-12,46,9,24,56 31 | 2022-06-13,46,20,22,54 32 | 2022-06-14,46,21,23,58 33 | 2022-06-15,47,21,23,58 34 | 2022-06-16,46,21,24,56 35 | 2022-06-17,51,17,23,57 36 | 2022-06-18,44,9,24,54 37 | 2022-06-19,43,8,26,59 38 
| 2022-06-20,45,16,25,53 39 | 2022-06-21,53,23,24,56 40 | 2022-06-22,48,21,25,58 41 | 2022-06-23,48,22,26,55 42 | 2022-06-24,54,18,24,56 43 | 2022-06-25,54,9,25,57 44 | 2022-06-26,48,8,23,62 45 | 2022-06-27,51,19,24,58 46 | 2022-06-28,54,21,25,56 47 | 2022-06-29,49,20,25,61 48 | 2022-06-30,53,19,27,58 49 | 2022-07-01,56,17,26,55 50 | 2022-07-02,47,8,25,56 51 | 2022-07-03,49,8,28,58 52 | 2022-07-04,47,9,25,58 53 | 2022-07-05,52,20,22,61 54 | 2022-07-06,49,20,27,60 55 | 2022-07-07,49,21,26,56 56 | 2022-07-08,58,17,27,53 57 | 2022-07-09,61,10,30,55 58 | 2022-07-10,56,15,31,55 59 | 2022-07-11,59,27,27,55 60 | 2022-07-12,52,25,28,57 61 | 2022-07-13,54,25,26,55 62 | 2022-07-14,61,25,25,60 63 | 2022-07-15,51,22,25,58 64 | 2022-07-16,50,12,24,58 65 | 2022-07-17,53,13,28,57 66 | 2022-07-18,48,24,25,57 67 | 2022-07-19,50,26,25,57 68 | 2022-07-20,47,26,25,53 69 | 2022-07-21,49,24,26,55 70 | 2022-07-22,50,21,26,57 71 | 2022-07-23,47,9,26,55 72 | 2022-07-24,45,8,27,57 73 | 2022-07-25,47,19,26,59 74 | 2022-07-26,47,20,28,58 75 | 2022-07-27,47,20,27,56 76 | 2022-07-28,48,19,26,57 77 | 2022-07-29,47,18,27,58 78 | 2022-07-30,46,9,28,53 79 | 2022-07-31,48,9,27,58 80 | 2022-08-01,48,19,25,56 81 | 2022-08-02,51,21,25,56 82 | 2022-08-03,49,21,27,55 83 | 2022-08-04,47,20,24,54 84 | 2022-08-05,47,17,25,54 85 | 2022-08-06,47,10,26,57 86 | 2022-08-07,46,9,26,58 87 | 2022-08-08,46,19,24,59 88 | 2022-08-09,62,21,25,59 89 | 2022-08-10,52,20,25,56 90 | 2022-08-11,54,21,24,61 91 | 2022-08-12,60,18,24,59 92 | 2022-08-13,55,9,26,56 93 | 2022-08-14,52,9,25,55 94 | 2022-08-15,48,19,24,54 95 | 2022-08-16,49,21,27,55 96 | 2022-08-17,47,20,24,53 97 | 2022-08-18,45,19,24,53 98 | 2022-08-19,47,18,25,53 99 | 2022-08-20,47,9,25,54 100 | 2022-08-21,49,10,27,56 101 | 2022-08-22,44,20,23,55 102 | 2022-08-23,46,22,23,52 103 | 2022-08-24,49,22,23,52 104 | 2022-08-25,47,21,23,56 105 | 2022-08-26,56,17,23,54 106 | 2022-08-27,52,9,25,55 107 | 2022-08-28,51,9,26,57 108 | 2022-08-29,47,19,22,56 109 | 
2022-08-30,47,21,22,52 110 | 2022-08-31,46,20,21,52 111 | 2022-09-01,47,18,22,55 112 | 2022-09-02,48,16,21,49 113 | 2022-09-03,50,8,24,52 114 | 2022-09-04,47,8,25,56 115 | 2022-09-05,48,10,27,56 116 | 2022-09-06,45,19,24,54 117 | 2022-09-07,49,20,21,54 118 | 2022-09-08,51,19,23,50 119 | 2022-09-09,49,16,21,52 120 | 2022-09-10,49,9,21,52 121 | 2022-09-11,52,9,22,55 122 | 2022-09-12,49,19,21,50 123 | 2022-09-13,54,21,20,52 124 | 2022-09-14,47,21,20,49 125 | 2022-09-15,45,20,30,51 126 | 2022-09-16,42,16,22,48 127 | 2022-09-17,47,9,23,52 128 | 2022-09-18,50,9,23,51 129 | 2022-09-19,48,20,21,48 130 | 2022-09-20,47,21,22,50 131 | 2022-09-21,47,22,23,51 132 | 2022-09-22,46,21,23,55 133 | 2022-09-23,46,18,21,49 134 | 2022-09-24,47,9,22,50 135 | 2022-09-25,48,9,23,53 136 | 2022-09-26,45,17,23,51 137 | 2022-09-27,46,20,23,47 138 | 2022-09-28,49,19,23,50 139 | 2022-09-29,45,18,30,56 140 | 2022-09-30,49,17,32,56 141 | 2022-10-01,47,9,36,64 142 | 2022-10-02,52,9,39,66 143 | 2022-10-03,53,19,31,60 144 | 2022-10-04,64,21,33,66 145 | 2022-10-05,58,20,34,68 146 | 2022-10-06,49,19,33,62 147 | 2022-10-07,51,17,22,49 148 | 2022-10-08,50,9,24,48 149 | 2022-10-09,52,9,22,52 150 | 2022-10-10,52,16,21,50 151 | 2022-10-11,51,20,22,50 152 | 2022-10-12,47,19,22,49 153 | 2022-10-13,46,20,23,48 154 | 2022-10-14,46,18,20,49 155 | 2022-10-15,49,9,22,48 156 | 2022-10-16,52,9,23,52 157 | 2022-10-17,47,19,22,52 158 | 2022-10-18,48,20,24,49 159 | 2022-10-19,46,19,24,50 160 | 2022-10-20,46,20,22,49 161 | 2022-10-21,49,18,22,47 162 | 2022-10-22,50,9,23,49 163 | 2022-10-23,58,8,22,53 164 | 2022-10-24,53,18,22,50 165 | 2022-10-25,53,21,22,50 166 | 2022-10-26,52,19,22,49 167 | 2022-10-27,56,19,28,49 168 | 2022-10-28,100,17,26,48 169 | 2022-10-29,75,8,25,49 170 | 2022-10-30,66,9,23,56 171 | 2022-10-31,69,17,21,83 172 | 2022-11-01,75,19,21,58 173 | 2022-11-02,64,21,23,51 174 | 2022-11-03,61,19,24,49 175 | 2022-11-04,76,17,28,52 176 | 2022-11-05,69,9,23,54 177 | 2022-11-06,62,8,25,51 178 | 
2022-11-07,66,18,23,50 179 | 2022-11-08,60,18,24,47 180 | 2022-11-09,64,18,21,45 181 | 2022-11-10,61,19,22,49 182 | 2022-11-11,69,16,23,47 183 | -------------------------------------------------------------------------------- /causalpy/data/its.csv: -------------------------------------------------------------------------------- 1 | date,month,year,t,y 2 | 2010-01-31,1,2010,0,25.058185688624274 3 | 2010-02-28,2,2010,1,27.18981176621271 4 | 2010-03-31,3,2010,2,26.487551351776656 5 | 2010-04-30,4,2010,3,31.24171632294163 6 | 2010-05-31,5,2010,4,40.75397275191598 7 | 2010-06-30,6,2010,5,48.399764008729115 8 | 2010-07-31,7,2010,6,42.44837366407632 9 | 2010-08-31,8,2010,7,54.68759081682996 10 | 2010-09-30,9,2010,8,39.207389562324934 11 | 2010-10-31,10,2010,9,32.345685348239776 12 | 2010-11-30,11,2010,10,31.094745554351693 13 | 2010-12-31,12,2010,11,26.84734836557912 14 | 2011-01-31,1,2011,12,22.195602683205493 15 | 2011-02-28,2,2011,13,31.16014850868488 16 | 2011-03-31,3,2011,14,28.46148151306288 17 | 2011-04-30,4,2011,15,34.381456501257915 18 | 2011-05-31,5,2011,16,39.03659030307911 19 | 2011-06-30,6,2011,17,48.62390728954829 20 | 2011-07-31,7,2011,18,45.92272444498544 21 | 2011-08-31,8,2011,19,58.073105015427224 22 | 2011-09-30,9,2011,20,41.76572722870496 23 | 2011-10-31,10,2011,21,35.837444489156745 24 | 2011-11-30,11,2011,22,31.659872678017088 25 | 2011-12-31,12,2011,23,26.169559806322034 26 | 2012-01-31,1,2012,24,27.75643347859959 27 | 2012-02-29,2,2012,25,34.56549644786656 28 | 2012-03-31,3,2012,26,27.526994096947618 29 | 2012-04-30,4,2012,27,34.369697996570636 30 | 2012-05-31,5,2012,28,44.828674931484834 31 | 2012-06-30,6,2012,29,53.35299326853688 32 | 2012-07-31,7,2012,30,48.504281900773336 33 | 2012-08-31,8,2012,31,63.844063826881516 34 | 2012-09-30,9,2012,32,46.10634227245801 35 | 2012-10-31,10,2012,33,40.41895847314143 36 | 2012-11-30,11,2012,34,35.798386737928865 37 | 2012-12-31,12,2012,35,29.915864189226642 38 | 2013-01-31,1,2013,36,30.34725688275301 39 | 
2013-02-28,2,2013,37,28.373763151628996 40 | 2013-03-31,3,2013,38,32.185334915708914 41 | 2013-04-30,4,2013,39,36.35673439679912 42 | 2013-05-31,5,2013,40,45.781698735900434 43 | 2013-06-30,6,2013,41,56.44272755212367 44 | 2013-07-31,7,2013,42,50.82281465126571 45 | 2013-08-31,8,2013,43,66.51954064889695 46 | 2013-09-30,9,2013,44,45.36525287708088 47 | 2013-10-31,10,2013,45,43.451014005472054 48 | 2013-11-30,11,2013,46,36.50996653278897 49 | 2013-12-31,12,2013,47,32.17151034089354 50 | 2014-01-31,1,2014,48,35.500423116138876 51 | 2014-02-28,2,2014,49,34.997909325365576 52 | 2014-03-31,3,2014,50,36.66645064194652 53 | 2014-04-30,4,2014,51,42.40356368127723 54 | 2014-05-31,5,2014,52,45.65128780529429 55 | 2014-06-30,6,2014,53,58.776900999483594 56 | 2014-07-31,7,2014,54,52.05254007367343 57 | 2014-08-31,8,2014,55,69.96297892616612 58 | 2014-09-30,9,2014,56,52.21316535759528 59 | 2014-10-31,10,2014,57,44.22456939686441 60 | 2014-11-30,11,2014,58,41.30844911751353 61 | 2014-12-31,12,2014,59,35.407976959959385 62 | 2015-01-31,1,2015,60,33.32944085662502 63 | 2015-02-28,2,2015,61,37.04126984884195 64 | 2015-03-31,3,2015,62,36.315881978487944 65 | 2015-04-30,4,2015,63,43.73663524366684 66 | 2015-05-31,5,2015,64,52.78691068355672 67 | 2015-06-30,6,2015,65,63.85424840070961 68 | 2015-07-31,7,2015,66,55.90935291160097 69 | 2015-08-31,8,2015,67,70.82508414070433 70 | 2015-09-30,9,2015,68,54.10977096452966 71 | 2015-10-31,10,2015,69,44.355785763129035 72 | 2015-11-30,11,2015,70,47.68226099515911 73 | 2015-12-31,12,2015,71,39.785148011882015 74 | 2016-01-31,1,2016,72,37.804854502934084 75 | 2016-02-29,2,2016,73,40.102001211453384 76 | 2016-03-31,3,2016,74,35.428159588572356 77 | 2016-04-30,4,2016,75,43.95028894490062 78 | 2016-05-31,5,2016,76,54.2914501178684 79 | 2016-06-30,6,2016,77,63.43338295825927 80 | 2016-07-31,7,2016,78,52.6637789157146 81 | 2016-08-31,8,2016,79,72.76141520596525 82 | 2016-09-30,9,2016,80,58.72450101312248 83 | 2016-10-31,10,2016,81,48.96097007279306 84 
| 2016-11-30,11,2016,82,46.606094702703885 85 | 2016-12-31,12,2016,83,41.93075780032188 86 | 2017-01-31,1,2017,84,36.400482252201876 87 | 2017-02-28,2,2017,85,43.862391588167114 88 | 2017-03-31,3,2017,86,39.701698017743425 89 | 2017-04-30,4,2017,87,43.79807233996439 90 | 2017-05-31,5,2017,88,56.312237712677394 91 | 2017-06-30,6,2017,89,66.91087568258902 92 | 2017-07-31,7,2017,90,68.13107537818581 93 | 2017-08-31,8,2017,91,85.49829874099522 94 | 2017-09-30,9,2017,92,71.20272582446988 95 | 2017-10-31,10,2017,93,65.26319039224376 96 | 2017-11-30,11,2017,94,58.81884761916778 97 | 2017-12-31,12,2017,95,53.67759540435465 98 | 2018-01-31,1,2018,96,48.888124256011494 99 | 2018-02-28,2,2018,97,50.249538862849704 100 | 2018-03-31,3,2018,98,46.38968292648549 101 | 2018-04-30,4,2018,99,51.88125531494078 102 | 2018-05-31,5,2018,100,60.85718568651798 103 | 2018-06-30,6,2018,101,68.7454132089781 104 | 2018-07-31,7,2018,102,68.1829292617686 105 | 2018-08-31,8,2018,103,77.01103851283908 106 | 2018-09-30,9,2018,104,61.858547642740625 107 | 2018-10-31,10,2018,105,57.090003986629384 108 | 2018-11-30,11,2018,106,50.66438260739964 109 | 2018-12-31,12,2018,107,43.39577134254689 110 | 2019-01-31,1,2019,108,42.48868900217848 111 | 2019-02-28,2,2019,109,48.81234047130436 112 | 2019-03-31,3,2019,110,45.035024384878476 113 | 2019-04-30,4,2019,111,51.16480607814459 114 | 2019-05-31,5,2019,112,60.29670535086518 115 | 2019-06-30,6,2019,113,67.14254688478812 116 | 2019-07-31,7,2019,114,66.03569748603856 117 | 2019-08-31,8,2019,115,78.64090955888433 118 | 2019-09-30,9,2019,116,63.096985530860366 119 | 2019-10-31,10,2019,117,62.685352885118 120 | 2019-11-30,11,2019,118,49.51492562850536 121 | 2019-12-31,12,2019,119,47.71159389472015 122 | -------------------------------------------------------------------------------- /causalpy/data/its_simple.csv: -------------------------------------------------------------------------------- 1 | date,linear_trend,timeseries,causal effect,intercept 2 | 
2010-01-31,0,0.21480503003113205,0,1.0 3 | 2010-02-28,1,-0.3777296459532429,0,1.0 4 | 2010-03-31,2,-0.13867236650683104,0,1.0 5 | 2010-04-30,3,-0.2274185300796163,0,1.0 6 | 2010-05-31,4,-0.20726109513496646,0,1.0 7 | 2010-06-30,5,0.35104891575858055,0,1.0 8 | 2010-07-31,6,-0.1586256525747562,0,1.0 9 | 2010-08-31,7,0.13031695043255306,0,1.0 10 | 2010-09-30,8,-0.07392035205380793,0,1.0 11 | 2010-10-31,9,-0.025987587608871963,0,1.0 12 | 2010-11-30,10,0.24334219255419662,0,1.0 13 | 2010-12-31,11,-0.29755994071885433,0,1.0 14 | 2011-01-31,12,-0.1930799501107326,0,1.0 15 | 2011-02-28,13,-0.22033102146024258,0,1.0 16 | 2011-03-31,14,-0.06169332486577135,0,1.0 17 | 2011-04-30,15,-0.3142827593294249,0,1.0 18 | 2011-05-31,16,-0.1339264209805803,0,1.0 19 | 2011-06-30,17,0.07772459565486825,0,1.0 20 | 2011-07-31,18,0.17834926255672728,0,1.0 21 | 2011-08-31,19,-0.41626700730676286,0,1.0 22 | 2011-09-30,20,0.32627071413723246,0,1.0 23 | 2011-10-31,21,-0.4759069598676782,0,1.0 24 | 2011-11-30,22,-0.5805710897997448,0,1.0 25 | 2011-12-31,23,-0.39176340644424007,0,1.0 26 | 2012-01-31,24,0.2132694719825104,0,1.0 27 | 2012-02-29,25,-0.08911519459888295,0,1.0 28 | 2012-03-31,26,0.007835469535464875,0,1.0 29 | 2012-04-30,27,-0.14048432767497343,0,1.0 30 | 2012-05-31,28,-0.1505117786887522,0,1.0 31 | 2012-06-30,29,0.021772841825501005,0,1.0 32 | 2012-07-31,30,0.1743374088012384,0,1.0 33 | 2012-08-31,31,-0.1623769307263898,0,1.0 34 | 2012-09-30,32,0.1641695366144703,0,1.0 35 | 2012-10-31,33,0.11764791290943934,0,1.0 36 | 2012-11-30,34,-0.34529733183545497,0,1.0 37 | 2012-12-31,35,-0.0708009744584712,0,1.0 38 | 2013-01-31,36,-0.24650654077268203,0,1.0 39 | 2013-02-28,37,0.29239690359549664,0,1.0 40 | 2013-03-31,38,-0.6094705671860218,0,1.0 41 | 2013-04-30,39,-0.28954858107121456,0,1.0 42 | 2013-05-31,40,-0.293204723082857,0,1.0 43 | 2013-06-30,41,0.3491785634260171,0,1.0 44 | 2013-07-31,42,0.3007621533113873,0,1.0 45 | 2013-08-31,43,-0.5597258268129217,0,1.0 46 | 
2013-09-30,44,-0.24391360176528137,0,1.0 47 | 2013-10-31,45,0.14571351711202302,0,1.0 48 | 2013-11-30,46,0.2663614714898447,0,1.0 49 | 2013-12-31,47,0.20042949586871858,0,1.0 50 | 2014-01-31,48,-0.11575131756496734,0,1.0 51 | 2014-02-28,49,0.34031431993656147,0,1.0 52 | 2014-03-31,50,-0.38939175166861845,0,1.0 53 | 2014-04-30,51,-0.34561444091194937,0,1.0 54 | 2014-05-31,52,0.5745778911445035,0,1.0 55 | 2014-06-30,53,0.170491033794666,0,1.0 56 | 2014-07-31,54,0.12544285459809848,0,1.0 57 | 2014-08-31,55,-0.08127795945641932,0,1.0 58 | 2014-09-30,56,-0.40604163049893977,0,1.0 59 | 2014-10-31,57,-0.03405328884499359,0,1.0 60 | 2014-11-30,58,-0.3364326770001927,0,1.0 61 | 2014-12-31,59,-0.3117914538931869,0,1.0 62 | 2015-01-31,60,0.18017573881851917,0,1.0 63 | 2015-02-28,61,0.4850210343192964,0,1.0 64 | 2015-03-31,62,0.12362092184691267,0,1.0 65 | 2015-04-30,63,-0.12694227903653954,0,1.0 66 | 2015-05-31,64,-0.11675762886279556,0,1.0 67 | 2015-06-30,65,0.012931641490223127,0,1.0 68 | 2015-07-31,66,-0.11288437580685105,0,1.0 69 | 2015-08-31,67,0.1304899666842564,0,1.0 70 | 2015-09-30,68,-0.1006332398442942,0,1.0 71 | 2015-10-31,69,-0.13512733989576795,0,1.0 72 | 2015-11-30,70,-0.22483633826860164,0,1.0 73 | 2015-12-31,71,-0.134689911744993,0,1.0 74 | 2016-01-31,72,-0.29494785225097314,0,1.0 75 | 2016-02-29,73,0.2893300751133943,0,1.0 76 | 2016-03-31,74,0.47963761656585,0,1.0 77 | 2016-04-30,75,0.19369701646691656,0,1.0 78 | 2016-05-31,76,-0.01996110657665389,0,1.0 79 | 2016-06-30,77,-0.06720575887435706,0,1.0 80 | 2016-07-31,78,0.15277014844655742,0,1.0 81 | 2016-08-31,79,0.13714717645590996,0,1.0 82 | 2016-09-30,80,0.7645215248749366,0,1.0 83 | 2016-10-31,81,0.24657655271463647,0,1.0 84 | 2016-11-30,82,-0.22941212068273753,0,1.0 85 | 2016-12-31,83,-0.35225831704160865,0,1.0 86 | 2017-01-31,84,1.7571874979027642,2,1.0 87 | 2017-02-28,85,2.217472380797493,2,1.0 88 | 2017-03-31,86,1.7448137450933103,2,1.0 89 | 2017-04-30,87,2.450433616203781,2,1.0 90 | 
2017-05-31,88,1.9363912255699476,2,1.0 91 | 2017-06-30,89,2.1053217862322455,2,1.0 92 | 2017-07-31,90,2.083713797718643,2,1.0 93 | 2017-08-31,91,1.9008839245448768,2,1.0 94 | 2017-09-30,92,1.9184745029224628,2,1.0 95 | 2017-10-31,93,2.020940611311974,2,1.0 96 | 2017-11-30,94,1.7397067074383412,2,1.0 97 | 2017-12-31,95,2.1832834419495,2,1.0 98 | 2018-01-31,96,2.29964000766867,2,1.0 99 | 2018-02-28,97,2.0457247994863335,2,1.0 100 | 2018-03-31,98,1.8700658137747732,2,1.0 101 | 2018-04-30,99,1.7494828654009884,2,1.0 102 | 2018-05-31,100,2.2912946285698412,2,1.0 103 | 2018-06-30,101,2.162605832978686,2,1.0 104 | 2018-07-31,102,2.174801919528005,2,1.0 105 | 2018-08-31,103,2.4337863547275074,2,1.0 106 | 2018-09-30,104,2.0959033165036347,2,1.0 107 | 2018-10-31,105,1.9843593728855553,2,1.0 108 | 2018-11-30,106,2.1726749673036556,2,1.0 109 | 2018-12-31,107,1.9547951058092656,2,1.0 110 | 2019-01-31,108,2.4320780457477773,2,1.0 111 | 2019-02-28,109,2.1710850036019322,2,1.0 112 | 2019-03-31,110,1.766843673760467,2,1.0 113 | 2019-04-30,111,1.9809952251192608,2,1.0 114 | 2019-05-31,112,1.9464777141518568,2,1.0 115 | 2019-06-30,113,1.8893740752664778,2,1.0 116 | 2019-07-31,114,2.0809560142281747,2,1.0 117 | 2019-08-31,115,1.625699163381451,2,1.0 118 | 2019-09-30,116,2.317369477428977,2,1.0 119 | 2019-10-31,117,2.4464877308159165,2,1.0 120 | 2019-11-30,118,2.1528944663725462,2,1.0 121 | 2019-12-31,119,1.6786919166797156,2,1.0 122 | -------------------------------------------------------------------------------- /causalpy/data/regression_discontinuity.csv: -------------------------------------------------------------------------------- 1 | x,y,treated 2 | -0.9327385461258624,-0.09191917023352716,False 3 | -0.9307784317910752,-0.38266334518771755,False 4 | -0.9291102731331999,-0.18178605409990187,False 5 | -0.9074191600776724,-0.2882449528549306,False 6 | -0.8824690868975393,-0.4208105229302404,False 7 | -0.8820758837996223,-0.5549416248850486,False 8 | 
-0.8332766335235537,-0.5662224315011265,False 9 | -0.8256935425965268,-0.48382299305702603,False 10 | -0.8024673938694238,-0.48477611961579603,False 11 | -0.790752285418777,-0.707747814996515,False 12 | -0.7785350116122896,-0.7374712999723909,False 13 | -0.742315412891587,-0.6483035757627679,False 14 | -0.6662225073834396,-0.832778095754855,False 15 | -0.6655667842451465,-0.9989970490441504,False 16 | -0.6509022265133668,-1.098201518293105,False 17 | -0.6493078409319306,-0.9288624468046688,False 18 | -0.6421762113107035,-0.9939703584030359,False 19 | -0.6410725703735278,-0.8925680035173033,False 20 | -0.6309038601840204,-1.123162903010154,False 21 | -0.5853108528774433,-0.8584552817020795,False 22 | -0.5662309825044631,-0.8990431536784658,False 23 | -0.5421530926711697,-0.8914521837968054,False 24 | -0.5345255660968882,-0.7576976802440811,False 25 | -0.5241914096330065,-0.8139792051464783,False 26 | -0.4499646079171833,-0.9007481417862807,False 27 | -0.42444468303758254,-1.0549599737478617,False 28 | -0.42111838094392473,-0.9760074789291727,False 29 | -0.38317230534022007,-0.9103842982737919,False 30 | -0.37371137860022796,-1.0051641606789314,False 31 | -0.36047537362698145,-0.9371669617081703,False 32 | -0.3577661698808785,-1.041576844973374,False 33 | -0.3512132527609937,-0.7759056271726158,False 34 | -0.3476524653221995,-0.7852243849800049,False 35 | -0.34452493301453013,-0.7305267966928453,False 36 | -0.3223785782776811,-0.7937435745851932,False 37 | -0.2586041283423526,-0.6956552041429478,False 38 | -0.23736898669085704,-0.749653158519372,False 39 | -0.21686857598176612,-0.4469798613877832,False 40 | -0.1999263588559459,-0.513059219604664,False 41 | -0.19121601896034512,-0.5897278030501697,False 42 | -0.17266217404023743,-0.5746438793443456,False 43 | -0.13030794537177903,-0.37524196626363915,False 44 | -0.11247215959487789,-0.2601230818907452,False 45 | -0.09447524051128697,-0.37994785186983576,False 46 | -0.09123271185333248,-0.4627236621529374,False 47 | 
-0.08069628976855281,-0.1739999276828995,False 48 | -0.02536718705196228,-0.04907735191681943,False 49 | -0.0077759567114659145,-0.11153035197485056,False 50 | -0.0017950967836366516,-0.12984310619947467,False 51 | 0.013015947428665964,-0.09696035789102539,False 52 | 0.013331548963634532,-0.13485208870644552,False 53 | 0.05728497490602269,0.12095585344305049,False 54 | 0.06680091633212948,0.0980955142625518,False 55 | 0.08157003555575648,0.3102233462804672,False 56 | 0.10862172617948374,0.3614103650465586,False 57 | 0.11130088281530748,0.3211655811431866,False 58 | 0.13668970491321009,0.4919620363535288,False 59 | 0.1818710731517259,0.5746374909539395,False 60 | 0.19715093147837837,0.6747775161300689,False 61 | 0.2718902644475478,0.9175173409558198,False 62 | 0.2727891405874987,0.8623157516949401,False 63 | 0.28427586609395616,0.6549926550108098,False 64 | 0.2927653526355971,0.6500686314741195,False 65 | 0.29455506663295594,0.7954660561997995,False 66 | 0.34319289018706267,0.7632844249615528,False 67 | 0.3525679375558599,0.7744837015602224,False 68 | 0.3646703125551185,0.9697584879284552,False 69 | 0.36863039619535565,0.6267794966520605,False 70 | 0.38665296495344625,0.9304469844634929,False 71 | 0.4202873868743666,0.9864653128031807,False 72 | 0.4405552319322412,1.0194533769866643,False 73 | 0.45092594217108894,1.0282091634456967,False 74 | 0.4521577119994231,0.9389298616096072,False 75 | 0.4529925469899221,0.875817844211295,False 76 | 0.486610571075738,1.072003857873219,False 77 | 0.5607944031274459,1.4534250292863555,True 78 | 0.5897055257756263,1.4390739413971585,True 79 | 0.6062259846468445,1.3025080744755781,True 80 | 0.6289209045910922,1.396959761061237,True 81 | 0.6331009480616656,1.3510080742631283,True 82 | 0.646491360404372,1.4426028469860563,True 83 | 0.6523045310300368,1.3642181273569531,True 84 | 0.6676661670611155,1.251109051453663,True 85 | 0.673009526858835,1.4801657776642538,True 86 | 0.6775438953577584,1.4810395861010783,True 87 | 
0.6992263916443597,1.395165667445156,True 88 | 0.713156426736919,1.3553935461931284,True 89 | 0.7331257342597464,1.3438026004649561,True 90 | 0.7705167864098279,1.256225147032783,True 91 | 0.7708605267154753,1.3038119904586107,True 92 | 0.8181547685748638,1.2207382278626857,True 93 | 0.8197557443167558,1.0599304856869969,True 94 | 0.8784745759091654,0.9167289368568633,True 95 | 0.8926476011192372,0.9802808270766279,True 96 | 0.8927404022617769,0.9570485596222553,True 97 | 0.9036374895996484,0.8737634563951697,True 98 | 0.9328946539516394,0.7486536221645006,True 99 | 0.9500391252052713,0.9073506550779494,True 100 | 0.9639827572280819,0.7995388871883963,True 101 | 0.9718757875035902,0.7397903225941795,True 102 | -------------------------------------------------------------------------------- /causalpy/experiments/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /causalpy/experiments/base.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class BaseExperiment:
    """Base class for quasi experimental designs.

    Subclasses declare which model families they accept through the
    ``supports_bayes`` and ``supports_ols`` class attributes, and implement the
    ``_bayesian_plot`` / ``_ols_plot`` and ``get_plot_data_bayesian`` /
    ``get_plot_data_ols`` hooks that :meth:`plot` and :meth:`get_plot_data`
    dispatch to.

    Note: this class does not use ``ABCMeta``, so the ``@abstractmethod``
    decorators below do not block instantiation; the hooks simply raise
    ``NotImplementedError`` when they are called without being overridden.
    """

    # True when the experiment accepts ``PyMCModel`` instances.
    supports_bayes: bool
    # True when the experiment accepts scikit-learn ``RegressorMixin`` instances.
    supports_ols: bool

    def __init__(self, model=None):
        """Store and validate the model used by the experiment.

        :param model:
            A ``PyMCModel`` or a scikit-learn ``RegressorMixin`` instance. When
            omitted, a ``model`` attribute already defined on the class is used.
        :raises ValueError:
            If no model is available, or if the model family is not supported by
            the experiment.
        """
        if model is None:
            # No model was passed: fall back to a class-level default if one
            # exists. ``getattr`` avoids the AttributeError that a bare
            # ``self.model`` lookup raises when no default is defined.
            if getattr(self, "model", None) is None:
                raise ValueError("model not set or passed.")
        else:
            # Ensure we've made any provided Scikit Learn model (as identified as
            # being type RegressorMixin) compatible with CausalPy by appending our
            # custom methods.
            if isinstance(model, RegressorMixin):
                model = create_causalpy_compatible_class(model)
            self.model = model

        if isinstance(self.model, PyMCModel) and not self.supports_bayes:
            raise ValueError("Bayesian models not supported.")

        if isinstance(self.model, RegressorMixin) and not self.supports_ols:
            raise ValueError("OLS models not supported.")

    @property
    def idata(self):
        """Return the InferenceData object of the model. Only relevant for PyMC models."""
        return self.model.idata

    def print_coefficients(self, round_to=None):
        """Ask the model to print its coefficients.

        Uses ``self.labels``, which the concrete experiment is expected to have
        set before this method is called.

        :param round_to:
            Forwarded unchanged to the model's ``print_coefficients``.
        """
        self.model.print_coefficients(self.labels, round_to)

    def plot(self, *args, **kwargs) -> tuple:
        """Plot the model.

        Internally, this function dispatches to either `_bayesian_plot` or `_ols_plot`
        depending on the model type.

        :raises ValueError: If the model is neither a PyMC nor a scikit-learn model.
        """
        if isinstance(self.model, PyMCModel):
            return self._bayesian_plot(*args, **kwargs)
        elif isinstance(self.model, RegressorMixin):
            return self._ols_plot(*args, **kwargs)
        else:
            raise ValueError("Unsupported model type")

    @abstractmethod
    def _bayesian_plot(self, *args, **kwargs):
        """Abstract method for plotting the model (PyMC models)."""
        raise NotImplementedError("_bayesian_plot method not yet implemented")

    @abstractmethod
    def _ols_plot(self, *args, **kwargs):
        """Abstract method for plotting the model (scikit-learn models)."""
        raise NotImplementedError("_ols_plot method not yet implemented")

    def get_plot_data(self, *args, **kwargs) -> pd.DataFrame:
        """Recover the data of an experiment along with the prediction and causal impact information.

        Internally, this function dispatches to either :func:`get_plot_data_bayesian` or :func:`get_plot_data_ols`
        depending on the model type.

        :raises ValueError: If the model is neither a PyMC nor a scikit-learn model.
        """
        if isinstance(self.model, PyMCModel):
            return self.get_plot_data_bayesian(*args, **kwargs)
        elif isinstance(self.model, RegressorMixin):
            return self.get_plot_data_ols(*args, **kwargs)
        else:
            raise ValueError("Unsupported model type")

    @abstractmethod
    def get_plot_data_bayesian(self, *args, **kwargs):
        """Abstract method for recovering plot data (PyMC models)."""
        raise NotImplementedError("get_plot_data_bayesian method not yet implemented")

    @abstractmethod
    def get_plot_data_ols(self, *args, **kwargs):
        """Abstract method for recovering plot data (scikit-learn models)."""
        raise NotImplementedError("get_plot_data_ols method not yet implemented")
class InstrumentalVariable(BaseExperiment):
    """
    A class to analyse instrumental variable style experiments.

    :param instruments_data: A pandas dataframe of instruments
        for our treatment variable. Should contain
        instruments Z, and treatment t
    :param data: A pandas dataframe of covariates for fitting
        the focal regression of interest. Should contain covariates X
        including treatment t and outcome y
    :param instruments_formula: A statistical model formula for
        the instrumental stage regression
        e.g. t ~ 1 + z1 + z2 + z3
    :param formula: A statistical model formula for the
        focal regression e.g. y ~ 1 + t + x1 + x2 + x3
    :param model: A PyMC model
    :param priors: An optional dictionary of priors for the
        mus and sigmas of both regressions. If priors are not
        specified, the two-stage least squares estimates of the beta
        coefficients are used as the prior means. Greater control can be
        achieved by specifying the priors directly e.g. priors = {
            "mus": [0, 0],
            "sigmas": [1, 1],
            "eta": 2,
            "lkj_sd": 2,
        }
    :param kwargs: Accepted for signature compatibility; not used by this
        constructor.

    Example
    --------
    >>> import pandas as pd
    >>> import causalpy as cp
    >>> from causalpy.pymc_models import InstrumentalVariableRegression
    >>> import numpy as np
    >>> N = 100
    >>> e1 = np.random.normal(0, 3, N)
    >>> e2 = np.random.normal(0, 1, N)
    >>> Z = np.random.uniform(0, 1, N)
    >>> ## Ensure the endogeneity of the treatment variable
    >>> X = -1 + 4 * Z + e2 + 2 * e1
    >>> y = 2 + 3 * X + 3 * e1
    >>> test_data = pd.DataFrame({"y": y, "X": X, "Z": Z})
    >>> sample_kwargs = {
    ...     "tune": 1,
    ...     "draws": 5,
    ...     "chains": 1,
    ...     "cores": 4,
    ...     "target_accept": 0.95,
    ...     "progressbar": False,
    ... }
    >>> instruments_formula = "X  ~ 1 + Z"
    >>> formula = "y ~  1 + X"
    >>> instruments_data = test_data[["X", "Z"]]
    >>> data = test_data[["y", "X"]]
    >>> iv = cp.InstrumentalVariable(
    ...     instruments_data=instruments_data,
    ...     data=data,
    ...     instruments_formula=instruments_formula,
    ...     formula=formula,
    ...     model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
    ... )
    """

    supports_ols = False
    supports_bayes = True

    def __init__(
        self,
        instruments_data: pd.DataFrame,
        data: pd.DataFrame,
        instruments_formula: str,
        formula: str,
        model=None,
        priors=None,
        **kwargs,
    ):
        # The base class validates the model and stores it on ``self.model``.
        # It must not be re-assigned from the raw ``model`` argument afterwards:
        # that would replace a class-level default model with ``None``.
        super().__init__(model=model)
        self.expt_type = "Instrumental Variable Regression"
        self.data = data
        self.instruments_data = instruments_data
        self.formula = formula
        self.instruments_formula = instruments_formula
        self.input_validation()

        # Design matrices for the focal (outcome) regression.
        y, X = dmatrices(formula, self.data)
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # Design matrices for the instrument (first stage) regression.
        t, Z = dmatrices(instruments_formula, self.instruments_data)
        self._t_design_info = t.design_info
        self._z_design_info = Z.design_info
        self.labels_instruments = Z.design_info.column_names
        self.t, self.Z = np.asarray(t), np.asarray(Z)
        # Left-hand side of the instruments formula, i.e. the treatment column.
        self.instrument_variable_name = t.design_info.column_names[0]

        # Frequentist reference fits; the 2SLS estimates seed the default priors.
        self.get_naive_OLS_fit()
        self.get_2SLS_fit()

        # fit the model to the data
        COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
        self.coords = COORDS
        if priors is None:
            priors = {
                "mus": [self.ols_beta_first_params, self.ols_beta_second_params],
                "sigmas": [1, 1],
                "eta": 2,
                "lkj_sd": 1,
            }
        self.priors = priors
        self.model.fit(
            X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
        )

    def input_validation(self):
        """Validate the input data and model formula for correctness.

        :raises DataException:
            If the treatment variable (left-hand side of ``instruments_formula``)
            is not a column of both ``instruments_data`` and ``data``.

        Emits a warning when the treatment variable takes more than two distinct
        values in ``data``.
        """
        treatment = self.instruments_formula.split("~")[0]
        treatment_name = treatment.strip()
        in_both = (treatment_name in self.instruments_data.columns) and (
            treatment_name in self.data.columns
        )
        if not in_both:
            raise DataException(
                f"""
                The treatment variable:
                {treatment} must appear in the instrument_data to be used
                as an outcome variable and in the data object to be used as a covariate.
                """
            )
        treatment_values = self.data[treatment_name]
        is_non_binary = len(np.unique(treatment_values)) > 2
        if is_non_binary:
            warnings.warn(
                """Warning. The treatment variable is not Binary.
                This is not necessarily a problem but it violates
                the assumption of a simple IV experiment.
                The coefficients should be interpreted appropriately."""
            )

    def get_2SLS_fit(self):
        """
        Two Stage Least Squares Fit

        This function is called by the experiment, results are used for
        priors if none are provided.

        Stores ``ols_beta_first_params``, ``ols_beta_second_params``,
        ``first_stage_reg`` and ``second_stage_reg`` on the instance.
        """
        # Stage 1: regress the treatment on the instruments.
        first_stage_reg = sk_lin_reg().fit(self.Z, self.t)
        fitted_Z_values = first_stage_reg.predict(self.Z)
        # Stage 2: replace the treatment column by its first-stage fitted values
        # and re-build the focal design matrix from a copy of the data.
        X2 = self.data.copy(deep=True)
        X2[self.instrument_variable_name] = fitted_Z_values
        _, X2 = dmatrices(self.formula, X2)
        second_stage_reg = sk_lin_reg().fit(X=X2, y=self.y)
        # The first coefficient column is skipped and the regressor's fitted
        # intercept is placed at the front of each parameter list instead.
        betas_first = list(first_stage_reg.coef_[0][1:])
        betas_first.insert(0, first_stage_reg.intercept_[0])
        betas_second = list(second_stage_reg.coef_[0][1:])
        betas_second.insert(0, second_stage_reg.intercept_[0])
        self.ols_beta_first_params = betas_first
        self.ols_beta_second_params = betas_second
        self.first_stage_reg = first_stage_reg
        self.second_stage_reg = second_stage_reg

    def get_naive_OLS_fit(self):
        """
        Naive Ordinary Least Squares

        This function is called by the experiment.

        Stores the fitted regressor as ``ols_reg`` and a mapping from design
        matrix column name to coefficient as ``ols_beta_params``.
        """
        ols_reg = sk_lin_reg().fit(self.X, self.y)
        beta_params = list(ols_reg.coef_[0][1:])
        beta_params.insert(0, ols_reg.intercept_[0])
        self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
        self.ols_reg = ols_reg

    def plot(self, round_to=None):
        """
        Plot the results

        :param round_to:
            Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
        :raises NotImplementedError: Always; plotting is not implemented for this experiment.
        """
        raise NotImplementedError("Plot method not implemented.")

    def summary(self, round_to=None) -> None:
        """Print summary of main results and model coefficients.

        :param round_to:
            Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
        :raises NotImplementedError: Always; summary is not implemented for this experiment.
        """
        raise NotImplementedError("Summary method not implemented.")
class PrePostNEGD(BaseExperiment):
    """
    A class to analyse data from pretest/posttest designs

    :param data:
        A pandas dataframe
    :param formula:
        A statistical model formula
    :param group_variable_name:
        Name of the column in data for the group variable, should be either
        binary or boolean
    :param pretreatment_variable_name:
        Name of the column in data for the pretreatment variable
    :param model:
        A PyMC model

    Example
    --------
    >>> import causalpy as cp
    >>> df = cp.load_data("anova1")
    >>> seed = 42
    >>> result = cp.PrePostNEGD(
    ...     df,
    ...     formula="post ~ 1 + C(group) + pre",
    ...     group_variable_name="group",
    ...     pretreatment_variable_name="pre",
    ...     model=cp.pymc_models.LinearRegression(
    ...         sample_kwargs={
    ...             "target_accept": 0.95,
    ...             "random_seed": seed,
    ...             "progressbar": False,
    ...         }
    ...     ),
    ... )
    >>> result.summary(round_to=1)  # doctest: +NUMBER
    ==================Pretest/posttest Nonequivalent Group Design===================
    Formula: post ~ 1 + C(group) + pre

    Results:
    Causal impact = 2, $CI_{94%}$[2, 2]
    Model coefficients:
        Intercept      -0.5, 94% HDI [-1, 0.2]
        C(group)[T.1]  2, 94% HDI [2, 2]
        pre            1, 94% HDI [1, 1]
        sigma          0.5, 94% HDI [0.5, 0.6]
    """

    supports_ols = False
    supports_bayes = True

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        group_variable_name: str,
        pretreatment_variable_name: str,
        model=None,
        **kwargs,
    ):
        super().__init__(model=model)
        self.data = data
        self.expt_type = "Pretest/posttest Nonequivalent Group Design"
        self.formula = formula
        self.group_variable_name = group_variable_name
        self.pretreatment_variable_name = pretreatment_variable_name
        self.input_validation()

        y, X = dmatrices(formula, self.data)
        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # turn into xarray.DataArray's
        self.X = xr.DataArray(
            self.X,
            dims=["obs_ind", "coeffs"],
            coords={
                "obs_ind": np.arange(self.X.shape[0]),
                "coeffs": self.labels,
            },
        )
        # NOTE(review): y is labelled with the dataframe index while X (and the
        # coords passed to the model) use a 0..n-1 range; these differ when the
        # dataframe does not have a default index. Confirm this is intended.
        self.y = xr.DataArray(
            self.y[:, 0],
            dims=["obs_ind"],
            coords={"obs_ind": self.data.index},
        )

        # fit the model to the observed data
        if isinstance(self.model, PyMCModel):
            COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
            self.model.fit(X=self.X, y=self.y, coords=COORDS)
        elif isinstance(self.model, RegressorMixin):
            raise NotImplementedError("Not implemented for OLS model")
        else:
            raise ValueError("Model type not recognized")

        # Calculate the posterior predictive for the treatment and control for an
        # interpolated set of pretest values
        self.pred_xi = np.linspace(
            np.min(self.data[self.pretreatment_variable_name]),
            np.max(self.data[self.pretreatment_variable_name]),
            200,
        )
        # untreated: group indicator fixed at 0 over the pretest grid
        x_pred_untreated = pd.DataFrame(
            {
                self.pretreatment_variable_name: self.pred_xi,
                self.group_variable_name: np.zeros(self.pred_xi.shape),
            }
        )
        (new_x_untreated,) = build_design_matrices(
            [self._x_design_info], x_pred_untreated
        )
        self.pred_untreated = self.model.predict(X=np.asarray(new_x_untreated))
        # treated: group indicator fixed at 1 over the pretest grid
        x_pred_treated = pd.DataFrame(
            {
                self.pretreatment_variable_name: self.pred_xi,
                self.group_variable_name: np.ones(self.pred_xi.shape),
            }
        )
        (new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated)
        self.pred_treated = self.model.predict(X=np.asarray(new_x_treated))

        # Evaluate causal impact as equal to the treatment effect
        self.causal_impact = self.model.idata.posterior["beta"].sel(
            {"coeffs": self._get_treatment_effect_coeff()}
        )

    def input_validation(self) -> None:
        """Validate the input data and model formula for correctness

        :raises DataException: If the group variable is not dummy coded.
        """
        if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
            raise DataException(
                f"""
                There must be 2 levels of the grouping variable
                {self.group_variable_name}. I.e. the treated and untreated.
                """
            )

    def _get_treatment_effect_coeff(self) -> str:
        """Find the beta regression coefficient corresponding to the
        group (i.e. treatment) effect.
        For example if self.group_variable_name is 'group' and
        the labels are `['Intercept', 'C(group)[T.1]', 'pre']`
        then we want `C(group)[T.1]`.

        The first label that contains the group variable name and is not an
        interaction term (no ``:``) is returned.

        :raises NameError: If no such label exists.
        """
        for label in self.labels:
            if (self.group_variable_name in label) and (":" not in label):
                return label

        raise NameError("Unable to find coefficient name for the treatment effect")

    def _causal_impact_summary_stat(self, round_to) -> str:
        """Computes the mean and 94% credible interval bounds for the causal impact."""
        percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
        ci = (
            r"$CI_{94%}$"
            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
        )
        causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
        return f"Causal impact = {causal_impact + ci}"

    def summary(self, round_to=None) -> None:
        """Print summary of main results and model coefficients.

        :param round_to:
            Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
        """
        print(f"{self.expt_type:=^80}")
        print(f"Formula: {self.formula}")
        print("\nResults:")
        # TODO: extra experiment specific outputs here
        print(self._causal_impact_summary_stat(round_to))
        self.print_coefficients(round_to)

    def _bayesian_plot(
        self, round_to=None, **kwargs
    ) -> tuple[plt.Figure, List[plt.Axes]]:
        """Generate plot for ANOVA-like experiments with non-equivalent group designs.

        The top panel shows the raw data with the posterior predictive mean and
        HDI for the control and treatment groups; the bottom panel shows the
        posterior of the treatment effect.

        :param round_to:
            Forwarded to ``az.plot_posterior`` for the treatment effect panel.
        :return: The figure and its two axes.
        """
        fig, ax = plt.subplots(
            2, 1, figsize=(7, 9), gridspec_kw={"height_ratios": [3, 1]}
        )

        # Plot raw data, using the columns this experiment was configured with
        # rather than assuming they are called "pre", "post" and "group".
        sns.scatterplot(
            x=self.pretreatment_variable_name,
            y=self.outcome_variable_name,
            hue=self.group_variable_name,
            alpha=0.5,
            data=self.data,
            legend=True,
            ax=ax[0],
        )
        ax[0].set(xlabel="Pretest", ylabel="Posttest")

        # plot posterior predictive of untreated
        h_line, h_patch = plot_xY(
            self.pred_xi,
            self.pred_untreated["posterior_predictive"].mu,
            ax=ax[0],
            plot_hdi_kwargs={"color": "C0"},
            label="Control group",
        )
        handles = [(h_line, h_patch)]
        labels = ["Control group"]

        # plot posterior predictive of treated
        h_line, h_patch = plot_xY(
            self.pred_xi,
            self.pred_treated["posterior_predictive"].mu,
            ax=ax[0],
            plot_hdi_kwargs={"color": "C1"},
            label="Treatment group",
        )
        handles.append((h_line, h_patch))
        labels.append("Treatment group")

        # Explicit handles replace the automatic scatter legend so that each
        # entry shows the mean line together with its HDI patch.
        ax[0].legend(
            handles=handles,
            labels=labels,
            fontsize=LEGEND_FONT_SIZE,
        )

        # Plot estimated causal impact / treatment effect
        az.plot_posterior(self.causal_impact, ref_val=0, ax=ax[1], round_to=round_to)
        ax[1].set(title="Estimated treatment effect")
        return fig, ax
class RegressionKink(BaseExperiment):
    """Regression Kink experiment class.

    :param data:
        A pandas dataframe. Must contain a dummy coded ``treated`` column.
    :param formula:
        A statistical model formula. Must include a ``treated`` predictor.
    :param kink_point:
        Value of the running variable at which the kink is evaluated.
    :param model:
        A PyMC model
    :param running_variable_name:
        Name of the running variable column in ``data``. Defaults to ``"x"``.
    :param epsilon:
        Offset either side of the kink point used to evaluate the change in
        gradient. Must be greater than zero.
    :param bandwidth:
        If finite, the model is fitted only to rows whose running variable lies
        within ``kink_point +/- bandwidth``. Must be greater than zero.
    """

    supports_ols = False
    supports_bayes = True

    def __init__(
        self,
        data: pd.DataFrame,
        formula: str,
        kink_point: float,
        model=None,
        running_variable_name: str = "x",
        epsilon: float = 0.001,
        bandwidth: float = np.inf,
        **kwargs,
    ):
        super().__init__(model=model)
        self.expt_type = "Regression Kink"
        self.data = data
        self.formula = formula
        self.running_variable_name = running_variable_name
        self.kink_point = kink_point
        self.epsilon = epsilon
        self.bandwidth = bandwidth
        self.input_validation()

        # Compare by value: an identity test against ``np.inf`` would treat
        # ``float("inf")`` as a finite bandwidth.
        use_bandwidth = bool(np.isfinite(self.bandwidth))

        if use_bandwidth:
            fmin = self.kink_point - self.bandwidth
            fmax = self.kink_point + self.bandwidth
            # Filter on the configured running variable, not a column
            # hard-coded as "x".
            running = self.data[self.running_variable_name]
            filtered_data = self.data[(running >= fmin) & (running <= fmax)]
            if len(filtered_data) <= 10:
                warnings.warn(
                    f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.",  # noqa: E501
                    UserWarning,
                )
            y, X = dmatrices(formula, filtered_data)
        else:
            y, X = dmatrices(formula, self.data)

        self._y_design_info = y.design_info
        self._x_design_info = X.design_info
        self.labels = X.design_info.column_names
        self.y, self.X = np.asarray(y), np.asarray(X)
        self.outcome_variable_name = y.design_info.column_names[0]

        # turn into xarray.DataArray's
        self.X = xr.DataArray(
            self.X,
            dims=["obs_ind", "coeffs"],
            coords={
                "obs_ind": np.arange(self.X.shape[0]),
                "coeffs": self.labels,
            },
        )
        self.y = xr.DataArray(
            self.y[:, 0],
            dims=["obs_ind"],
            coords={"obs_ind": np.arange(self.y.shape[0])},
        )

        COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.X.shape[0])}
        self.model.fit(X=self.X, y=self.y, coords=COORDS)

        # score the goodness of fit on the data the model was fitted to
        self.score = self.model.score(X=self.X, y=self.y)

        # get the model predictions over a grid of running variable values
        if use_bandwidth:
            xi = np.linspace(fmin, fmax, 200)
        else:
            xi = np.linspace(
                np.min(self.data[self.running_variable_name]),
                np.max(self.data[self.running_variable_name]),
                200,
            )
        self.x_pred = pd.DataFrame(
            {self.running_variable_name: xi, "treated": self._is_treated(xi)}
        )
        (new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
        self.pred = self.model.predict(X=np.asarray(new_x))

        # evaluate gradient change around kink point
        mu_kink_left, mu_kink, mu_kink_right = self._probe_kink_point()
        self.gradient_change = self._eval_gradient_change(
            mu_kink_left, mu_kink, mu_kink_right, epsilon
        )

    def input_validation(self):
        """Validate the input data and model formula for correctness

        :raises FormulaException: If ``treated`` does not appear in the formula.
        :raises DataException: If the ``treated`` column is not dummy coded.
        :raises ValueError: If ``bandwidth`` or ``epsilon`` is not greater than zero.
        """
        if "treated" not in self.formula:
            raise FormulaException(
                "A predictor called `treated` should be in the formula"
            )

        if _is_variable_dummy_coded(self.data["treated"]) is False:
            raise DataException(
                """The treated variable should be dummy coded. Consisting of 0's and 1's only."""  # noqa: E501
            )

        if self.bandwidth <= 0:
            raise ValueError("The bandwidth must be greater than zero.")

        if self.epsilon <= 0:
            raise ValueError("Epsilon must be greater than zero.")

    @staticmethod
    def _eval_gradient_change(mu_kink_left, mu_kink, mu_kink_right, epsilon):
        """Evaluate the gradient change at the kink point.
        It works by evaluating the model below the kink point, at the kink point,
        and above the kink point.
        This is a static method for ease of testing.

        :return:
            Right-hand finite-difference gradient minus the left-hand one, each
            taken over a step of ``epsilon``.
        """
        gradient_left = (mu_kink - mu_kink_left) / epsilon
        gradient_right = (mu_kink_right - mu_kink) / epsilon
        gradient_change = gradient_right - gradient_left
        return gradient_change

    def _probe_kink_point(self):
        """Probe the kink point to evaluate the predicted outcome at the kink point and
        either side.

        :return: ``(mu_kink_left, mu_kink, mu_kink_right)``, the predicted ``mu``
            at ``kink_point - epsilon``, ``kink_point`` and ``kink_point + epsilon``.
        """
        # Create a dataframe to evaluate predicted outcome at the kink point and either
        # side
        x_predict = pd.DataFrame(
            {
                self.running_variable_name: np.array(
                    [
                        self.kink_point - self.epsilon,
                        self.kink_point,
                        self.kink_point + self.epsilon,
                    ]
                ),
                # untreated below the kink, treated at and above it
                "treated": np.array([0, 1, 1]),
            }
        )
        (new_x,) = build_design_matrices([self._x_design_info], x_predict)
        predicted = self.model.predict(X=np.asarray(new_x))
        # extract predicted mu values
        mu_kink_left = predicted["posterior_predictive"].sel(obs_ind=0)["mu"]
        mu_kink = predicted["posterior_predictive"].sel(obs_ind=1)["mu"]
        mu_kink_right = predicted["posterior_predictive"].sel(obs_ind=2)["mu"]
        return mu_kink_left, mu_kink, mu_kink_right

    def _is_treated(self, x):
        """Returns ``True`` if `x` is greater than or equal to the treatment threshold."""  # noqa: E501
        return np.greater_equal(x, self.kink_point)

    def summary(self, round_to=None) -> None:
        """Print summary of main results and model coefficients.

        :param round_to:
            Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers
        """
        print(
            f"""
        {self.expt_type:=^80}
        Formula: {self.formula}
        Running variable: {self.running_variable_name}
        Kink point on running variable: {self.kink_point}

        Results:
        Change in slope at kink point = {round_num(self.gradient_change.mean(), round_to)}
        """
        )
        self.print_coefficients(round_to)

    def _bayesian_plot(self, round_to=None, **kwargs) -> tuple[plt.Figure, plt.Axes]:
        """Generate plot for regression kink designs.

        :param round_to:
            Forwarded to ``round_num`` for the numbers shown in the title.
        :return: The figure and axes.
        """
        fig, ax = plt.subplots()
        # Plot raw data
        sns.scatterplot(
            self.data,
            x=self.running_variable_name,
            y=self.outcome_variable_name,
            c="k",  # hue="treated",
            ax=ax,
        )

        # Plot model fit to data
        h_line, h_patch = plot_xY(
            self.x_pred[self.running_variable_name],
            self.pred["posterior_predictive"].mu,
            ax=ax,
            plot_hdi_kwargs={"color": "C1"},
        )
        handles = [(h_line, h_patch)]
        labels = ["Posterior mean"]

        # create strings to compose title
        title_info = f"{round_num(self.score.r2, round_to)} (std = {round_num(self.score.r2_std, round_to)})"
        r2 = f"Bayesian $R^2$ on all data = {title_info}"
        percentiles = self.gradient_change.quantile([0.03, 1 - 0.03]).values
        ci = (
            r"$CI_{94\%}$"
            + f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
        )
        grad_change = f"""
            Change in gradient = {round_num(self.gradient_change.mean(), round_to)},
            """
        ax.set(title=r2 + "\n" + grad_change + ci)
        # Intervention line
        ax.axvline(
            x=self.kink_point,
            ls="-",
            lw=3,
            color="r",
            label="treatment threshold",
        )
        ax.legend(
            handles=handles,
            labels=labels,
            fontsize=LEGEND_FONT_SIZE,
        )
        return fig, ax
def plot_xY(
    x: Union[pd.DatetimeIndex, np.ndarray],
    Y: xr.DataArray,
    ax: plt.Axes,
    plot_hdi_kwargs: Optional[Dict[str, Any]] = None,
    hdi_prob: float = 0.94,
    label: Union[str, None] = None,
) -> Tuple[Line2D, PolyCollection]:
    """
    Utility function to plot the posterior mean line and its HDI band.

    :param x:
        Pandas datetime index or numpy array of x-axis values
    :param Y:
        Xarray data array of y-axis data. It is averaged over its ``chain``
        and ``draw`` dimensions to obtain the mean line.
    :param ax:
        Matplotlib ax object
    :param plot_hdi_kwargs:
        Dictionary of keyword arguments passed to both ``ax.plot()`` and
        ``az.plot_hdi()``
    :param hdi_prob:
        The size of the HDI, default is 0.94
    :param label:
        The plot label for the mean line. When ``None`` the line is left
        unlabelled.
    :return:
        Tuple of the mean line handle and the HDI patch handle.
    """

    if plot_hdi_kwargs is None:
        plot_hdi_kwargs = {}

    (h_line,) = ax.plot(
        x,
        Y.mean(dim=["chain", "draw"]),
        ls="-",
        **plot_hdi_kwargs,
        # Pass the label through unchanged: formatting it with an f-string
        # would turn a missing label into the literal text "None".
        label=label,
    )
    ax_hdi = az.plot_hdi(
        x,
        Y,
        hdi_prob=hdi_prob,
        fill_kwargs={
            "alpha": 0.25,
            "label": " ",
        },
        smooth=False,
        ax=ax,
        **plot_hdi_kwargs,
    )
    # Return handle to patch. We get a list of the children of the axis. Filter for just
    # the PolyCollection objects. Take the last one.
    h_patch = list(
        filter(lambda x: isinstance(x, PolyCollection), ax_hdi.get_children())
    )[-1]
    return (h_line, h_patch)


def get_hdi_to_df(
    x: xr.DataArray,
    hdi_prob: float = 0.94,
) -> pd.DataFrame:
    """
    Utility function to calculate and recover HDI intervals.

    :param x:
        Xarray data array
    :param hdi_prob:
        The size of the HDI, default is 0.94
    :return:
        Dataframe built from ``az.hdi``, with the ``hdi`` level unstacked into
        columns and the outermost column level dropped.
    """
    hdi = (
        az.hdi(x, hdi_prob=hdi_prob)
        .to_dataframe()
        .unstack(level="hdi")
        .droplevel(0, axis=1)
    )
    return hdi
class ScikitLearnAdaptor:
    """Base class for scikit-learn models that can be used for causal inference."""

    def calculate_impact(self, y_true, y_pred):
        """Calculate the causal impact of the intervention.

        Returns the difference ``y_true - y_pred``.
        """
        return y_true - y_pred

    def calculate_cumulative_impact(self, impact):
        """Calculate the cumulative impact intervention.

        Returns the cumulative sum of ``impact`` as computed by ``np.cumsum``.
        """
        return np.cumsum(impact)

    def print_coefficients(self, labels, round_to=None) -> None:
        """Print the coefficients of the model with the corresponding labels.

        :param labels:
            Names of the coefficients, in the same order as the coefficients
            returned by :meth:`get_coeffs`.
        :param round_to:
            Forwarded to ``round_num`` when formatting each value.
        """
        print("Model coefficients:")
        coef_ = self.get_coeffs()
        # Determine the width of the longest label; ``default`` keeps an empty
        # label list from raising ``ValueError``.
        max_label_length = max((len(name) for name in labels), default=0)
        # Print each coefficient with formatted alignment
        for name, val in zip(labels, coef_):
            # Left-align the name
            formatted_name = f"{name:<{max_label_length}}"
            # Right-align the value with width 10
            formatted_val = f"{round_num(val, round_to):>10}"
            print(f" {formatted_name}\t{formatted_val}")

    def get_coeffs(self):
        """Get the coefficients of the model as a numpy array.

        The result always has at least one dimension. Squeezing alone would turn
        a single coefficient into a 0-d array, which cannot be iterated or
        indexed by callers such as :meth:`print_coefficients`.
        """
        return np.atleast_1d(np.squeeze(self.coef_))
def create_causalpy_compatible_class(
    estimator: RegressorMixin,
) -> RegressorMixin:
    """Make a scikit-learn estimator usable with CausalPy.

    The callable, non-dunder attributes of ``ScikitLearnAdaptor`` are attached to
    ``estimator`` in place via ``_add_mixin_methods``, and that same object is
    returned. Despite the function name, no new class is created.

    NOTE(review): ``_add_mixin_methods`` binds the methods to a model *instance*,
    so ``estimator`` is presumably an estimator instance rather than a class --
    confirm against callers.
    """
    _add_mixin_methods(estimator, ScikitLearnAdaptor)
    return estimator
Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /causalpy/tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@pytest.fixture(scope="session")
def rng() -> np.random.Generator:
    """Session-scoped NumPy random generator shared by the tests.

    The seed is derived from the code points of the string ``"causalpy"``, so
    every session starts from the same generator state.
    """
    seed_value: int = sum(ord(letter) for letter in "causalpy")
    generator = np.random.default_rng(seed=seed_value)
    return generator
42 | """ 43 | df = cp.load_data(dataset_name) 44 | assert isinstance(df, pd.DataFrame) 45 | # Check that there are no missing values in any column 46 | assert df.isnull().sum().sum() == 0 47 | -------------------------------------------------------------------------------- /causalpy/tests/test_integration_skl_examples.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import numpy as np 15 | import pandas as pd 16 | import pytest 17 | from matplotlib import pyplot as plt 18 | from sklearn.gaussian_process import GaussianProcessRegressor 19 | from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel 20 | from sklearn.linear_model import LinearRegression 21 | 22 | import causalpy as cp 23 | 24 | 25 | @pytest.mark.integration 26 | def test_did(): 27 | """ 28 | Test Difference in Differences (DID) Sci-Kit Learn experiment. 29 | 30 | Loads data and checks: 31 | 1. data is a dataframe 32 | 2. 
skl_experiements.DifferenceInDifferences returns correct type 33 | """ 34 | data = cp.load_data("did") 35 | result = cp.DifferenceInDifferences( 36 | data, 37 | formula="y ~ 1 + group*post_treatment", 38 | time_variable_name="t", 39 | group_variable_name="group", 40 | treated=1, 41 | untreated=0, 42 | model=LinearRegression(), 43 | ) 44 | assert isinstance(data, pd.DataFrame) 45 | assert isinstance(result, cp.DifferenceInDifferences) 46 | result.summary() 47 | fig, ax = result.plot() 48 | assert isinstance(fig, plt.Figure) 49 | assert isinstance(ax, plt.Axes) 50 | with pytest.raises(NotImplementedError): 51 | result.get_plot_data() 52 | 53 | 54 | @pytest.mark.integration 55 | def test_rd_drinking(): 56 | """ 57 | Test Regression Discontinuity Sci-Kit Learn experiment on drinking age data. 58 | 59 | Loads data and checks: 60 | 1. data is a dataframe 61 | 2. skl_experiements.RegressionDiscontinuity returns correct type 62 | """ 63 | 64 | df = ( 65 | cp.load_data("drinking") 66 | .rename(columns={"agecell": "age"}) 67 | .assign(treated=lambda df_: df_.age > 21) 68 | ) 69 | result = cp.RegressionDiscontinuity( 70 | df, 71 | formula="all ~ 1 + age + treated", 72 | running_variable_name="age", 73 | model=LinearRegression(), 74 | treatment_threshold=21, 75 | epsilon=0.001, 76 | ) 77 | assert isinstance(df, pd.DataFrame) 78 | assert isinstance(result, cp.RegressionDiscontinuity) 79 | result.summary() 80 | fig, ax = result.plot() 81 | assert isinstance(fig, plt.Figure) 82 | assert isinstance(ax, plt.Axes) 83 | with pytest.raises(NotImplementedError): 84 | result.get_plot_data() 85 | 86 | 87 | @pytest.mark.integration 88 | def test_its(): 89 | """ 90 | Test Interrupted Time Series Sci-Kit Learn experiment. 91 | 92 | Loads data and checks: 93 | 1. data is a dataframe 94 | 2. skl_experiements.InterruptedTimeSeries returns correct type 95 | 3. 
@pytest.mark.integration
def test_sc():
    """
    Test Synthetic Control Sci-Kit Learn experiment.

    Loads data and checks:
    1. data is a dataframe
    2. cp.SyntheticControl returns correct type
    3. the plot method returns a figure and an array of axes
    4. the method get_plot_data returns a DataFrame with expected columns
    """
    df = cp.load_data("sc")
    treatment_time = 70
    result = cp.SyntheticControl(
        df,
        treatment_time,
        control_units=["a", "b", "c", "d", "e", "f", "g"],
        treated_units=["actual"],
        model=cp.skl_models.WeightedProportion(),
    )
    assert isinstance(df, pd.DataFrame)
    assert isinstance(result, cp.SyntheticControl)
    result.summary()

    # The plot is checked once; the original repeated this block verbatim.
    fig, ax = result.plot()
    assert isinstance(fig, plt.Figure)
    # For multi-panel plots, ax should be an array of axes
    assert isinstance(ax, np.ndarray) and all(
        isinstance(item, plt.Axes) for item in ax
    ), "ax must be a numpy.ndarray of plt.Axes"

    # Test get_plot_data with default parameters
    plot_data = result.get_plot_data()
    assert isinstance(plot_data, pd.DataFrame), (
        "The returned object is not a pandas DataFrame"
    )
    expected_columns = ["prediction", "impact"]
    assert set(expected_columns).issubset(set(plot_data.columns)), (
        f"DataFrame is missing expected columns {expected_columns}"
    )
skl_experiements.RegressionDiscontinuity returns correct type 185 | """ 186 | data = cp.load_data("rd") 187 | result = cp.RegressionDiscontinuity( 188 | data, 189 | formula="y ~ 1 + x + treated", 190 | model=LinearRegression(), 191 | treatment_threshold=0.5, 192 | epsilon=0.001, 193 | ) 194 | assert isinstance(data, pd.DataFrame) 195 | assert isinstance(result, cp.RegressionDiscontinuity) 196 | result.summary() 197 | fig, ax = result.plot() 198 | assert isinstance(fig, plt.Figure) 199 | assert isinstance(ax, plt.Axes) 200 | 201 | 202 | @pytest.mark.integration 203 | def test_rd_linear_main_effects_bandwidth(): 204 | """ 205 | Test Regression Discontinuity Sci-Kit Learn experiment, main effects with 206 | bandwidth parameter. 207 | 208 | Loads data and checks: 209 | 1. data is a dataframe 210 | 2. skl_experiements.RegressionDiscontinuity returns correct type 211 | """ 212 | data = cp.load_data("rd") 213 | result = cp.RegressionDiscontinuity( 214 | data, 215 | formula="y ~ 1 + x + treated", 216 | model=LinearRegression(), 217 | treatment_threshold=0.5, 218 | epsilon=0.001, 219 | bandwidth=0.3, 220 | ) 221 | assert isinstance(data, pd.DataFrame) 222 | assert isinstance(result, cp.RegressionDiscontinuity) 223 | result.summary() 224 | fig, ax = result.plot() 225 | assert isinstance(fig, plt.Figure) 226 | assert isinstance(ax, plt.Axes) 227 | 228 | 229 | @pytest.mark.integration 230 | def test_rd_linear_with_interaction(): 231 | """ 232 | Test Regression Discontinuity Sci-Kit Learn experiment with interaction. 233 | 234 | Loads data and checks: 235 | 1. data is a dataframe 236 | 2. 
skl_experiements.RegressionDiscontinuity returns correct type 237 | """ 238 | data = cp.load_data("rd") 239 | result = cp.RegressionDiscontinuity( 240 | data, 241 | formula="y ~ 1 + x + treated + x:treated", 242 | model=LinearRegression(), 243 | treatment_threshold=0.5, 244 | epsilon=0.001, 245 | ) 246 | assert isinstance(data, pd.DataFrame) 247 | assert isinstance(result, cp.RegressionDiscontinuity) 248 | result.summary() 249 | fig, ax = result.plot() 250 | assert isinstance(fig, plt.Figure) 251 | assert isinstance(ax, plt.Axes) 252 | 253 | 254 | @pytest.mark.integration 255 | def test_rd_linear_with_gaussian_process(): 256 | """ 257 | Test Regression Discontinuity Sci-Kit Learn experiment with Gaussian process model. 258 | 259 | Loads data and checks: 260 | 1. data is a dataframe 261 | 2. skl_experiements.RegressionDiscontinuity returns correct type 262 | """ 263 | data = cp.load_data("rd") 264 | kernel = 1.0 * ExpSineSquared(1.0, 5.0) + WhiteKernel(1e-1) 265 | result = cp.RegressionDiscontinuity( 266 | data, 267 | formula="y ~ 1 + x + treated", 268 | model=GaussianProcessRegressor(kernel=kernel), 269 | model_kwargs={"kernel": kernel}, 270 | treatment_threshold=0.5, 271 | epsilon=0.001, 272 | ) 273 | assert isinstance(data, pd.DataFrame) 274 | assert isinstance(result, cp.RegressionDiscontinuity) 275 | fig, ax = result.plot() 276 | assert isinstance(fig, plt.Figure) 277 | assert isinstance(ax, plt.Axes) 278 | -------------------------------------------------------------------------------- /causalpy/tests/test_misc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 - 2025 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def test_regression_kink_gradient_change():
    """Check the numerical change-in-gradient helper used by regression kink
    designs against hand-computed values."""
    # Each entry pairs the positional arguments with the expected gradient change.
    cases = [
        # no change in gradient
        ((-1, 0, 1, 1), 0.0),
        ((1, 0, -1, 1), 0.0),
        ((0, 0, 0, 1), 0.0),
        # positive change in gradient
        ((0, 0, 1, 1), 1.0),
        ((0, 0, 2, 1), 2.0),
        ((-1, -1, 2, 1), 3.0),
        ((-1, 0, 2, 1), 1.0),
        # negative change in gradient
        ((0, 0, -1, 1), -1.0),
        ((0, 0, -2, 1), -2.0),
        ((-1, -1, -2, 1), -1.0),
        ((1, 0, -2, 1), -1.0),
    ]
    for arguments, expected in cases:
        assert cp.RegressionKink._eval_gradient_change(*arguments) == expected
def setup_regression_kink_data(kink):
    """Set up data for regression kink design tests.

    :param kink:
        Location of the kink on the running variable ``x``. It is used both to
        generate the outcome and to define the ``treated`` column (``x >= kink``).
    :returns:
        DataFrame with 50 rows and the columns ``x``, ``y`` and ``treated``.
    """
    # define parameters for data generation
    seed = 42
    rng = np.random.default_rng(seed)
    N = 50
    # NOTE: ``kink`` is taken from the argument; it used to be overwritten with a
    # hard-coded 0.5 here, which silently ignored the caller's value.
    beta = [0, -1, 0, 2, 0]
    sigma = 0.05
    # generate data
    x = rng.uniform(-1, 1, N)
    y = reg_kink_function(x, beta, kink) + rng.normal(0, sigma, N)
    return pd.DataFrame({"x": x, "y": y, "treated": x >= kink})


def reg_kink_function(x, beta, kink):
    """Utility function for regression kink design. Returns a piecewise linear function
    evaluated at x with a kink at kink and parameters beta"""
    return (
        beta[0]
        + beta[1] * x
        + beta[2] * x**2
        + beta[3] * (x - kink) * (x >= kink)
        + beta[4] * (x - kink) ** 2 * (x >= kink)
    )
# raised when a ScikitLearnAdaptor is provided to a InversePropensityWeighting object
def test_olsmodel_and_ipw():
    """InversePropensityWeighting does not support OLS models, so a ValueError should be raised"""

    # Load the data outside the ``pytest.raises`` context so that a ValueError
    # raised during setup cannot make the test pass spuriously.
    df = cp.load_data("nhefs")
    with pytest.raises(ValueError):
        _ = cp.InversePropensityWeighting(
            df,
            formula="trt ~ 1 + age + race",
            outcome_variable="outcome",
            weighting_scheme="robust",
            model=LinearRegression(),
        )
class TestPyMCModel:
    """
    Related tests that check aspects of PyMCModel objects.
    """

    def test_init(self):
        """
        Test initialization.

        Creates PyMCModel() object and checks that idata is None and no sample
        kwargs are specified.
        """
        mb = PyMCModel()
        assert mb.idata is None
        assert mb.sample_kwargs == {}

    @pytest.mark.parametrize(
        argnames="coords", argvalues=[{"a": 1}, None], ids=["coords-dict", "coord-None"]
    )
    @pytest.mark.parametrize(
        argnames="y", argvalues=[np.ones(3), None], ids=["y-array", "y-None"]
    )
    @pytest.mark.parametrize(
        argnames="X", argvalues=[np.ones(2), None], ids=["X-array", "X-None"]
    )
    def test_model_builder(self, X, y, coords) -> None:
        """
        Tests that calling .build_model() on a plain PyMCModel() object raises
        NotImplementedError for every combination of X, y and coords.
        """
        with pytest.raises(
            NotImplementedError, match="This method must be implemented by a subclass"
        ):
            PyMCModel().build_model(X=X, y=y, coords=coords)

    def test_fit_build_not_implemented(self):
        """
        Tests that calling .fit() on a plain PyMCModel() object raises the same
        NotImplementedError that .build_model() raises.
        """
        with pytest.raises(
            NotImplementedError, match="This method must be implemented by a subclass"
        ):
            PyMCModel().fit(X=np.ones(2), y=np.ones(3), coords={"a": 1})

    @pytest.mark.parametrize(
        argnames="coords",
        argvalues=[None, {"a": 1}],
        ids=["None-coords", "dict-coords"],
    )
    def test_fit_predict(self, coords, rng) -> None:
        """
        Test fit and predict methods on MyToyModel.

        Generates normal data, fits the model, makes predictions, scores the model
        then:
        1. checks that model.idata is az.InferenceData type
        2. checks that beta, sigma, mu, and y_hat can be extracted from idata
        3. checks score is a pandas series of the correct shape
        4. checks that predictions are az.InferenceData type
        """
        X = rng.normal(loc=0, scale=1, size=(20, 2))
        y = rng.normal(loc=0, scale=1, size=(20,))
        # 2 chains x 2 draws gives the 2 * 2 samples asserted on below
        model = MyToyModel(sample_kwargs={"chains": 2, "draws": 2})
        model.fit(X, y, coords=coords)
        predictions = model.predict(X=X)
        score = model.score(X=X, y=y)
        assert isinstance(model.idata, az.InferenceData)
        assert az.extract(data=model.idata, var_names=["beta"]).shape == (2, 2 * 2)
        assert az.extract(data=model.idata, var_names=["sigma"]).shape == (2 * 2,)
        assert az.extract(data=model.idata, var_names=["mu"]).shape == (20, 2 * 2)
        assert az.extract(
            data=model.idata, group="posterior_predictive", var_names=["y_hat"]
        ).shape == (20, 2 * 2)
        assert isinstance(score, pd.Series)
        assert score.shape == (2,)
        assert isinstance(predictions, az.InferenceData)
@pytest.mark.parametrize("seed", seeds)
def test_result_reproducibility(seed):
    """Test that we can reproduce the results from the model. We could in theory test
    this with all the model and experiment types, but what is being targeted is
    the PyMCModel.fit method, so we should be safe testing with just one model. Here
    we use the DifferenceInDifferences experiment class."""
    # Load the data
    df = cp.load_data("did")
    # Set a random seed on a local copy. Writing into the module-level
    # ``sample_kwargs`` would leak the seed into every other test that uses it.
    seeded_kwargs = {**sample_kwargs, "random_seed": seed}
    # Calculate the result twice
    result1 = cp.DifferenceInDifferences(
        df,
        formula="y ~ 1 + group + t + group:post_treatment",
        time_variable_name="t",
        group_variable_name="group",
        model=cp.pymc_models.LinearRegression(sample_kwargs=seeded_kwargs),
    )
    result2 = cp.DifferenceInDifferences(
        df,
        formula="y ~ 1 + group + t + group:post_treatment",
        time_variable_name="t",
        group_variable_name="group",
        model=cp.pymc_models.LinearRegression(sample_kwargs=seeded_kwargs),
    )
    assert np.all(result1.idata.posterior.mu == result2.idata.posterior.mu)
    assert np.all(result1.idata.prior.mu == result2.idata.prior.mu)
    assert np.all(
        result1.idata.prior_predictive.y_hat == result2.idata.prior_predictive.y_hat
    )
def test_generate_multicell_geolift_data():
    """
    The multi-cell geolift simulator returns a DataFrame with no negative values.
    """
    from causalpy.data.simulate_data import generate_multicell_geolift_data

    simulated = generate_multicell_geolift_data()
    assert isinstance(simulated, pd.DataFrame)
    is_non_negative = simulated >= 0
    assert np.all(is_non_negative), "Found negative values in dataset"


def test_generate_geolift_data():
    """
    The geolift simulator returns a DataFrame with no negative values.
    """
    from causalpy.data.simulate_data import generate_geolift_data

    simulated = generate_geolift_data()
    assert isinstance(simulated, pd.DataFrame)
    is_non_negative = simulated >= 0
    assert np.all(is_non_negative), "Found negative values in dataset"
def test_dummy_coding():
    """Check that ``_is_variable_dummy_coded`` accepts only series of 0/1 values."""
    cases = [
        ([False, True, False, True], True),
        ([False, True, False, "frog"], False),
        ([0, 0, 1, 0, 1], True),
        ([0, 0, 1, 0, 2], False),
        ([0, 0.5, 1, 0, 1], False),
    ]
    for values, expected in cases:
        assert _is_variable_dummy_coded(pd.Series(values)) is expected


def test_2_level_series():
    """Check that ``_series_has_2_levels`` is True only for exactly two levels."""
    cases = [
        (["a", "a", "b"], True),
        (["a", "a", "b", "c"], False),
        (["coffee", "tea", "coffee"], True),
        (["water", "tea", "coffee"], False),
        ([0, 1, 0, 1], True),
        ([0, 1, 0, 2], False),
    ]
    for values, expected in cases:
        assert _series_has_2_levels(pd.Series(values)) is expected
def _is_variable_dummy_coded(series: pd.Series) -> bool:
    """Tell whether every value in ``series`` is one of 0 or 1.

    Booleans pass as well, because ``False == 0`` and ``True == 1``.
    """
    allowed = {0, 1}
    unexpected = set(series) - allowed
    return not unexpected


def _series_has_2_levels(series: pd.Series) -> bool:
    """Tell whether ``series``, viewed as a categorical, has exactly 2 categories."""
    levels = pd.Categorical(series).categories
    return len(levels) == 2


def round_num(n, round_to):
    """
    Format a number as a string with ``round_to`` significant figures.

    The number of significant figures actually used comes from
    ``_format_sig_figs``, so it is never smaller than the count of digits in
    the integer part of ``n``.

    Parameters
    ----------
    n : float
        number to round
    round_to : int
        number of significant figures

    Returns
    -------
    str
        ``n`` rendered with the ``g`` format specifier.
    """
    digits = _format_sig_figs(n, round_to)
    return format(n, f".{digits}g")


def _format_sig_figs(value, default=None):
    """Pick how many significant figures to print for ``value``.

    Gives the number of digits in the integer part or ``default`` (2 when
    omitted), whichever is bigger. Zero always gets 1.

    Examples
    --------
    With the default of 2, ``round_num`` then renders:

    0.1234 --> 0.12
    1.234  --> 1.2
    12.34  --> 12
    123.4  --> 123
    """
    fallback = 2 if default is None else default
    if value == 0:
        return 1
    integer_digits = int(np.log10(np.abs(value))) + 1
    return max(integer_digits, fallback)


def convert_to_string(x: Union[float, xr.DataArray], round_to: int = 2) -> str:
    """Turn a numeric input into a display string.

    Parameters
    ----------
    x : float or xr.DataArray
        A float is printed with two decimal places. A DataArray is summarised
        as its mean (two decimal places) followed by a 94% interval taken from
        its 0.03 and 0.97 quantiles.
    round_to : int
        Significant figures used for the interval bounds (via ``round_num``).
        It is not applied to the float case or to the mean.

    Raises
    ------
    ValueError
        If ``x`` is neither a float nor an xarray DataArray.
    """
    if isinstance(x, float):
        # Plain floats are always shown with two decimal places.
        return f"{x:.2f}"

    if not isinstance(x, xr.DataArray):
        raise ValueError(
            "Type not supported. Please provide a float or an xarray object."
        )

    # Lower and upper bounds of the 94% interval.
    lower, upper = x.quantile([0.03, 1 - 0.03]).values
    interval = f"[{round_num(lower, round_to)}, {round_num(upper, round_to)}]"
    return f"{x.mean().values:.2f}" + r"$CI_{94\%}$" + interval
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """CausalPy Version""" 15 | 16 | __version__ = "0.4.2" 17 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | # basic 6 | target: auto 7 | threshold: 2% 8 | base: auto 9 | paths: 10 | - "causalpy/" 11 | # advanced settings 12 | branches: 13 | - main 14 | if_ci_failed: error #success, failure, error, ignore 15 | informational: false 16 | only_pulls: false 17 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
@ECHO OFF

pushd %~dp0

REM Command file for Sphinx documentation
REM
REM Usage: make.bat ^<target^>   (no target prints the Sphinx help)
REM Runs from the directory holding this script (see pushd above).

if "%SPHINXBUILD%" == "" (
	set SPHINXBUILD=sphinx-build
)
REM conf.py lives in the "source" sub-directory, matching SOURCEDIR in the
REM sibling Makefile; pointing at "." would make sphinx-build miss conf.py.
set SOURCEDIR=source
set BUILDDIR=_build

REM Probe for sphinx-build; errorlevel 9009 means the command was not found.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
	echo.
	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
	echo.installed, then set the SPHINXBUILD environment variable to point
	echo.to the full path of the 'sphinx-build' executable. Alternatively you
	echo.may add the Sphinx directory to PATH.
	echo.
	echo.If you don't have Sphinx installed, grab it from
	echo.https://www.sphinx-doc.org/
	exit /b 1
)

if "%1" == "" goto help

REM Route the requested target to Sphinx "make mode".
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%

:end
popd
def notebook_to_markdown(pattern: str, output_dir: str) -> None:
    """
    Convert every jupyter notebook matching a glob pattern to a markdown file.

    Each notebook is written to ``output_dir`` under its own base name with
    the extension replaced by ``.md``.

    :param pattern:
        str that is a glob appropriate pattern to search (searched recursively)
    :param output_dir:
        str directory to save the markdown files to; created on demand when
        there is at least one notebook to write
    """
    for notebook_path in glob(pattern, recursive=True):
        with open(notebook_path, "r", encoding="utf-8") as source:
            notebook = nbformat.read(source, as_version=4)

        exporter = MarkdownExporter()
        markdown_text, _resources = exporter.from_notebook_node(notebook)

        os.makedirs(output_dir, exist_ok=True)

        # Keep only the notebook's base name and swap its extension for .md
        stem, _extension = os.path.splitext(os.path.basename(notebook_path))
        target_path = os.path.join(output_dir, stem + ".md")

        with open(target_path, "w", encoding="utf-8") as target:
            target.write(markdown_text)


if __name__ == "__main__":
    # Command-line entry point: both options fall back to their defaults.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "-p",
        "--pattern",
        help="the glob appropriate pattern to search for jupyter notebooks",
        default="docs/**/*.ipynb",
    )
    cli.add_argument(
        "-t",
        "--tempdir",
        help="temporary directory to save the converted notebooks",
        default="tmp_markdown",
    )
    options = cli.parse_args()
    notebook_to_markdown(options.pattern, options.tempdir)
21 | ] 22 | } 23 | ], 24 | "metadata": { 25 | "language_info": { 26 | "name": "python" 27 | } 28 | }, 29 | "nbformat": 4, 30 | "nbformat_minor": 2 31 | } 32 | -------------------------------------------------------------------------------- /docs/source/.codespell/test_notebook_to_markdown.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The PyMC Labs Developers 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@pytest.fixture
def data_dir() -> str:
    """Path of the ``test_data`` directory that sits next to this test file."""
    here = os.path.dirname(os.path.realpath(__file__))
    return os.path.join(here, "test_data")


def test_notebook_to_markdown_empty_pattern(data_dir: str) -> None:
    """A pattern that matches no file must leave the output directory empty."""
    glob_pattern = "*.missing"
    with TemporaryDirectory() as tmp_dir:
        notebook_to_markdown(f"{data_dir}/{glob_pattern}", tmp_dir)
        produced = os.listdir(tmp_dir)
        assert len(produced) == 0


def test_notebook_to_markdown(data_dir: str) -> None:
    """A matching pattern must yield exactly one file, ``test_notebook.md``."""
    glob_pattern = "*.ipynb"
    with TemporaryDirectory() as tmp_dir:
        notebook_to_markdown(f"{data_dir}/{glob_pattern}", tmp_dir)
        assert len(os.listdir(tmp_dir)) == 1
        assert "test_notebook.md" in os.listdir(tmp_dir)
https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/flat_logo.png -------------------------------------------------------------------------------- /docs/source/_static/flat_logo_darkmode.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/flat_logo_darkmode.png -------------------------------------------------------------------------------- /docs/source/_static/iv_reg1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/iv_reg1.png -------------------------------------------------------------------------------- /docs/source/_static/iv_reg2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/iv_reg2.png -------------------------------------------------------------------------------- /docs/source/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/logo.png -------------------------------------------------------------------------------- /docs/source/_static/packages.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/packages.png -------------------------------------------------------------------------------- /docs/source/_static/pymc-labs-log.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/pymc-labs/CausalPy/bb7c2bbea029c3563251d6416d020543a00ed2b1/docs/source/_static/pymc-labs-log.png -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/base.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. autoclass:: {{ objname }} 6 | 7 | {% block methods %} 8 | {% if methods %} 9 | 10 | .. rubric:: Methods 11 | 12 | .. autosummary:: 13 | :toctree: 14 | 15 | {% for item in methods %} 16 | {{ objname }}.{{ item }} 17 | {%- endfor %} 18 | {% endif %} 19 | {% endblock %} 20 | 21 | {% block attributes %} 22 | {% if attributes %} 23 | .. rubric:: Attributes 24 | 25 | .. autosummary:: 26 | {% for item in attributes %} 27 | ~{{ name }}.{{ item }} 28 | {%- endfor %} 29 | {% endif %} 30 | {% endblock %} 31 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/method.rst: -------------------------------------------------------------------------------- 1 | {{ objname | escape | underline}} 2 | 3 | .. currentmodule:: {{ module }} 4 | 5 | .. auto{{ objtype }}:: {{ objname }} 6 | -------------------------------------------------------------------------------- /docs/source/_templates/autosummary/module.rst: -------------------------------------------------------------------------------- 1 | {{ name | escape | underline}} 2 | 3 | .. automodule:: {{ fullname }} 4 | 5 | {% block attributes %} 6 | {% if attributes %} 7 | .. rubric:: {{ _('Module Attributes') }} 8 | 9 | .. 
autosummary:: 10 | {% for item in attributes %} 11 | {{ item }} 12 | {%- endfor %} 13 | {% endif %} 14 | {% endblock %} 15 | 16 | {% block functions %} 17 | {% if functions %} 18 | .. rubric:: {{ _('Functions') }} 19 | 20 | .. autosummary:: 21 | :toctree: 22 | 23 | {% for item in functions %} 24 | {{ item }} 25 | {%- endfor %} 26 | {% endif %} 27 | {% endblock %} 28 | 29 | {% block classes %} 30 | {% if classes %} 31 | .. rubric:: {{ _('Classes') }} 32 | 33 | .. autosummary:: 34 | :toctree: 35 | 36 | {% for item in classes %} 37 | {{ item }} 38 | {%- endfor %} 39 | {% endif %} 40 | {% endblock %} 41 | 42 | {% block exceptions %} 43 | {% if exceptions %} 44 | .. rubric:: {{ _('Exceptions') }} 45 | 46 | .. autosummary:: 47 | {% for item in exceptions %} 48 | {{ item }} 49 | {%- endfor %} 50 | {% endif %} 51 | {% endblock %} 52 | 53 | {% block modules %} 54 | {% if modules %} 55 | .. rubric:: Modules 56 | 57 | .. autosummary:: 58 | :toctree: 59 | :recursive: 60 | {% for item in modules %} 61 | {{ item }} 62 | {%- endfor %} 63 | {% endif %} 64 | {% endblock %} 65 | -------------------------------------------------------------------------------- /docs/source/api/index.md: -------------------------------------------------------------------------------- 1 | # API 2 | 3 | ## Modules 4 | ```{eval-rst} 5 | .. currentmodule:: causalpy 6 | .. autosummary:: 7 | :recursive: 8 | :toctree: generated/ 9 | 10 | data 11 | pymc_models 12 | skl_models 13 | experiments 14 | ``` 15 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html


# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

import os
import sys

# NOTE(review): this import runs before the sys.path tweak below, so the tweak
# cannot help locate `causalpy`; presumably the package is installed in the
# docs build environment -- confirm.
from causalpy.version import __version__

sys.path.insert(0, os.path.abspath("../"))

# autodoc_mock_imports
# This avoids autodoc breaking when it can't find packages imported in the code.
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#confval-autodoc_mock_imports  # noqa: E501
autodoc_mock_imports = [
    "arviz",
    "matplotlib",
    "numpy",
    "pandas",
    "patsy",
    "pymc",
    "scipy",
    "seaborn",
    "sklearn",
    "xarray",
]

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
project = "CausalPy"
author = "PyMC Labs"
copyright = f"2024, {author}"


# Both the full release string and the short version come from the package.
release = __version__
version = release

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

# Add any Sphinx extension module names here, as strings
extensions = [
    # extensions from sphinx base
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
    "sphinx.ext.viewcode",
    "sphinx.ext.mathjax",
    "sphinx.ext.intersphinx",
    "sphinx.ext.napoleon",
    "sphinx_autodoc_typehints",
    # extensions provided by other packages
    "sphinxcontrib.bibtex",
    "matplotlib.sphinxext.plot_directive",  # needed to plot in docstrings
    "myst_nb",
    "notfound.extension",
    "sphinx_copybutton",
    "sphinx_design",
]

nb_execution_mode = "off"

# configure copy button to avoid copying sphinx or console characters
copybutton_exclude = ".linenos, .gp"
copybutton_prompt_text = r">>> |\.\.\. |\$ |In \[\d*\]: | {2,5}\.\.\.: | {5,8}: "
copybutton_prompt_is_regexp = True

source_suffix = {
    ".rst": "restructuredtext",
    ".ipynb": "myst-nb",
    ".myst": "myst-nb",
}
templates_path = ["_templates"]
exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
master_doc = "index"

# bibtex config
bibtex_bibfiles = ["references.bib"]
bibtex_default_style = "unsrt"
bibtex_reference_style = "author_year"


# numpydoc and autodoc typehints config
numpydoc_show_class_members = False
numpydoc_xref_param_type = True
# fmt: off
numpydoc_xref_ignore = {
    "of", "or", "optional", "default", "numeric", "type", "scalar", "1D", "2D", "3D", "nD", "array",
    "instance", "M", "N"
}
# fmt: on
numpydoc_xref_aliases = {
    "TensorVariable": ":class:`~pytensor.tensor.TensorVariable`",
    "RandomVariable": ":class:`~pytensor.tensor.random.RandomVariable`",
    "ndarray": ":class:`~numpy.ndarray`",
    "InferenceData": ":class:`~arviz.InferenceData`",
    "Model": ":class:`~pymc.Model`",
    "tensor_like": ":term:`tensor_like`",
    "unnamed_distribution": ":term:`unnamed_distribution`",
}
# don't add a return type section, use standard return with type info
typehints_document_rtype = False

# -- intersphinx config -------------------------------------------------------
intersphinx_mapping = {
    "arviz": ("https://python.arviz.org/en/stable/", None),
    "examples": ("https://www.pymc.io/projects/examples/en/latest/", None),
    "mpl": ("https://matplotlib.org/stable", None),
    "numpy": ("https://numpy.org/doc/stable/", None),
    "pandas": ("https://pandas.pydata.org/pandas-docs/stable/", None),
    "pymc": ("https://www.pymc.io/projects/docs/en/stable/", None),
    "python": ("https://docs.python.org/3", None),
    "scikit-learn": ("https://scikit-learn.org/stable/", None),
    "scipy": ("https://docs.scipy.org/doc/scipy/", None),
    "xarray": ("https://docs.xarray.dev/en/stable/", None),
}

# MyST options for working with markdown files.
# Info about extensions here https://myst-parser.readthedocs.io/en/latest/syntax/optional.html?highlight=math#admonition-directives  # noqa: E501
myst_enable_extensions = [
    "dollarmath",
    "amsmath",
    "colon_fence",
    "linkify",
    "html_admonition",
]

# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_theme = "labs_sphinx_theme"
html_static_path = ["_static"]
html_favicon = "_static/favicon_logo.png"
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
html_theme_options = {
    "logo": {
        "image_light": "_static/flat_logo.png",
        "image_dark": "_static/flat_logo_darkmode.png",
    },
    "analytics": {"google_analytics_id": "G-3MCDG3M7X6"},
}

# Template context, including the keys behind the "Edit on GitHub" link (which
# replaces "view page source"). This used to be assigned twice in this file;
# the second assignment replaced the first wholesale, dropping `doc_path` and
# `default_mode` and pointing `github_version` at "master". Everything now
# lives in a single dict so no key is lost.
html_context = {
    "display_github": True,  # Integrate GitHub
    "github_user": "pymc-labs",  # Username
    "github_repo": "CausalPy",  # Repo name
    "github_version": "main",  # Branch the edit links point at
    "doc_path": "docs/source/",
    "conf_py_path": "/docs/source/",  # Path in the checkout to the docs root
    "default_mode": "light",
}

# -- Options for autodoc ----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html#configuration

# Automatically extract typehints when specified and place them in
# descriptions of the relevant function/method.
autodoc_typehints = "description"

# Don't show class signature with the class' name.
autodoc_class_signature = "separated"
But users can also use more traditional [Ordinary Least Squares](https://en.wikipedia.org/wiki/Ordinary_least_squares) estimation methods via [scikit-learn](https://scikit-learn.org) models. 10 | 11 | ## Installation 12 | 13 | To get the latest release you can use pip: 14 | 15 | ```bash 16 | pip install CausalPy 17 | ``` 18 | 19 | or conda: 20 | 21 | ```bash 22 | conda install causalpy -c conda-forge 23 | ``` 24 | 25 | Alternatively, if you want the very latest version of the package you can install from GitHub: 26 | 27 | ```bash 28 | pip install git+https://github.com/pymc-labs/CausalPy.git 29 | ``` 30 | 31 | ## Quickstart 32 | 33 | ```python 34 | 35 | import causalpy as cp 36 | import matplotlib.pyplot as plt 37 | 38 | 39 | # Import and process data 40 | df = ( 41 | cp.load_data("drinking") 42 | .rename(columns={"agecell": "age"}) 43 | .assign(treated=lambda df_: df_.age > 21) 44 | ) 45 | 46 | # Run the analysis 47 | result = cp.RegressionDiscontinuity( 48 | df, 49 | formula="all ~ 1 + age + treated", 50 | running_variable_name="age", 51 | model=cp.pymc_models.LinearRegression(), 52 | treatment_threshold=21, 53 | ) 54 | 55 | # Visualize outputs 56 | fig, ax = result.plot(); 57 | 58 | # Get a results summary 59 | result.summary() 60 | 61 | plt.show() 62 | ``` 63 | 64 | ## Videos 65 | 66 | 85 | 86 |
87 | 88 |
89 | 90 | ## Features 91 | CausalPy has a broad range of quasi-experimental methods for causal inference: 92 | 93 | | Method | Description | 94 | |-|-| 95 | | Synthetic control | Constructs a synthetic version of the treatment group from a weighted combination of control units. Used for causal inference in comparative case studies when a single unit is treated, and there are multiple control units.| 96 | | Geographical lift | Measures the impact of an intervention in a specific geographic area by comparing it to similar areas without the intervention. Commonly used in marketing to assess regional campaigns. | 97 | | ANCOVA | Analysis of Covariance combines ANOVA and regression to control for the effects of one or more quantitative covariates. Used when comparing group means while controlling for other variables. | 98 | | Differences in Differences | Compares the changes in outcomes over time between a treatment group and a control group. Used in observational studies to estimate causal effects by accounting for time trends. | 99 | |Regression discontinuity | Identifies causal effects by exploiting a sharp cutoff or threshold in an assignment variable. Used when treatment is assigned based on a threshold value of an observed variable, allowing comparison just above and below the cutoff. | 100 | | Regression kink designs | Focuses on changes in the slope (kinks) of the relationship between variables rather than jumps at cutoff points. Used to identify causal effects when treatment intensity changes at a threshold. | 101 | | Interrupted time series | Analyzes the effect of an intervention by comparing time series data before and after the intervention. Used when data is collected over time and an intervention occurs at a known point, allowing assessment of changes in level or trend. 
| 102 | | Instrumental variable regression | Addresses endogeneity by using an instrumental variable that is correlated with the endogenous explanatory variable but uncorrelated with the error term. Used when explanatory variables are correlated with the error term, providing consistent estimates of causal effects. | 103 | | Inverse Propensity Score Weighting | Weights observations by the inverse of the probability of receiving the treatment. Used in causal inference to create a synthetic sample where the treatment assignment is independent of measured covariates, helping to adjust for confounding variables in observational studies. | 104 | 105 | ## Support 106 | 107 | This repository is supported by [PyMC Labs](https://www.pymc-labs.io). 108 | 109 | For companies that want to use CausalPy in production, [PyMC Labs](https://www.pymc-labs.com) is available for consulting and training. We can help you build and deploy your models in production. We have experience with cutting edge Bayesian and causal modelling techniques which we have applied to a range of business domains. 110 | 111 |

112 | 113 | PyMC Labs Logo 114 | 115 |

116 | 117 | :::{toctree} 118 | :hidden: 119 | 120 | knowledgebase/index 121 | api/index 122 | notebooks/index 123 | ::: 124 | -------------------------------------------------------------------------------- /docs/source/knowledgebase/causal_video_resources.md: -------------------------------------------------------------------------------- 1 | # Causal video resources 2 | 3 | 4 | 23 | 24 | ## What if? Causal reasoning meets Bayesian Inference 25 | 26 |
27 | 28 |
29 | 30 | ## Combining Bayes and Graph-based Causal Inference 31 | 32 |
33 | 34 |
35 | 36 | ## Bayesian Causal Modeling 37 | 38 |
39 | 40 |
41 | 42 | ## Guide To Causal Inference Using PyMC 43 | 44 |
45 | 46 |
47 | -------------------------------------------------------------------------------- /docs/source/knowledgebase/causal_written_resources.md: -------------------------------------------------------------------------------- 1 | # Written resources on causal inference 2 | 3 | Below is a list of written resources (books, blog posts, etc.) that are useful for learning about causal inference. 4 | 5 | ## Quasi-experiment resources 6 | 7 | * Angrist, J. D., & Pischke, J. S. (2009). [Mostly harmless econometrics: An empiricist's companion](https://www.mostlyharmlesseconometrics.com). Princeton University Press. 8 | * Angrist, J. D., & Pischke, J. S. (2014). [Mastering 'Metrics: The path from cause to effect](https://www.masteringmetrics.com). Princeton University Press. 9 | * Cunningham, S. (2021). [Causal inference: The Mixtape](https://mixtape.scunning.com). Yale University Press. 10 | * Huntington-Klein, N. (2021). [The effect: An introduction to research design and causality](https://theeffectbook.net). Chapman and Hall/CRC. 11 | * Reichardt, C. S. (2019). Quasi-experimentation: A guide to design and analysis. Guilford Publications. 12 | 13 | ## Bayesian causal inference resources 14 | * The official [PyMC examples gallery](https://www.pymc.io/projects/examples/en/latest/gallery.html) has a set of examples specifically relating to causal inference. 15 | 16 | ## General causal inference resources 17 | 18 | * [Awesome Causal Inference](https://github.com/matteocourthoud/awesome-causal-inference), a curated list of resources on causal inference, including books, blogs, and tutorials.
19 | -------------------------------------------------------------------------------- /docs/source/knowledgebase/design_notation.md: -------------------------------------------------------------------------------- 1 | # Quasi-experimental design notation 2 | 3 | This page provides a concise summary of the tabular notation used by {cite:t}`shadish_cook_cambell_2002` and {cite:t}`reichardt2019quasi`. This notation provides a compact description of various experimental designs. While it is possible to describe randomised designs using this notation, we focus purely on {term}`quasi-experimental` designs here, with non-random allocation (abbreviated as `NR`). Observations are denoted by $O$. Time proceeds from left to right, so observations made through time are labelled as $O_1$, $O_2$, etc. The treatment is denoted by `X`. Rows represent different groups of units. Remember, a unit is a person, place, or thing that is the subject of the study. 4 | 5 | ## Pretest-posttest designs 6 | 7 | One of the simplest designs is the pretest-posttest design. Here we have one row, denoting a single group of units. There is an `X` which means all are treated. The pretest is denoted by $O_1$ and the posttest by $O_2$. See p99 of {cite:t}`reichardt2019quasi`. 8 | 9 | | | | | 10 | |----|---|----| 11 | $O_1$ | X | $O_2$ | 12 | 13 | Informally, if we think about drawing conclusions about the {term}`causal impact` of the treatment based on the change from $O_1$ to $O_2$, we might say that the treatment caused the change. However, this is a tenuous conclusion because we have no way of knowing what would have happened in the ({term}`counterfactual`) absence of the treatment. 14 | 15 | A variation of this design which may (slightly) improve this situation from the perspective of making causal claims, would be to take multiple pretest measures. This is shown below, see p107 of {cite:t}`reichardt2019quasi`. 
16 | 17 | | | | | | 18 | |----|--|---|----| 19 | $O_1$ | $O_2$ | X | $O_3$ | 20 | 21 | This would allow us to estimate how the group was changing over time before the treatment was introduced. This could be used to make a stronger causal claim about the impact of the treatment. We could use {term}`interrupted time series` analysis to help here. 22 | 23 | ## Nonequivalent group designs 24 | 25 | In randomized experiments, with large enough groups, the randomization process should ensure that the treatment and control groups are approximately equivalent in terms of their attributes. This is positive for causal inference as we can be more sure that differences between control and test groups are due to treatment exposure, not because of differences in attributes of the groups. 26 | 27 | However, in quasi-experimental designs, with non-random (`NR`) allocation, we could expect there to be differences between the treatment and control groups' attributes. This poses some challenges in making strong causal claims about the impact of the treatment - we can't be sure that differences between the groups at the posttest are due to the treatment, or due to pre-existing differences between the groups. 28 | 29 | In the simplest {term}`nonequivalent group design`, we have two groups, one treated and one not treated, and just one posttest. See p114 of {cite:t}`reichardt2019quasi`. 30 | 31 | | | | | 32 | |-----|---|----| 33 | | NR: | X | $O_1$ | 34 | | NR: | | $O_1$ | 35 | 36 | The above design would be considered weak - the lack of a pre-test measure makes it hard to know whether differences between the groups at $O_1$ are due to the treatment or to pre-existing differences between the groups. 37 | 38 | This limitation can be addressed by adding a pretest measure. See p115 of {cite:t}`reichardt2019quasi`. 
39 | 40 | | | | | | 41 | |-----|----|---|----| 42 | | NR: | $O_1$ | X | $O_2$ | 43 | | NR: | $O_1$ | | $O_2$ | 44 | 45 | Non-equivalent group designs like this, with a pretest and a posttest measure could be analysed in a number of ways: 46 | 1. **{term}`ANCOVA`:** Here, the group would be a categorical predictor (e.g. treated/untreated), the pretest measure would be a covariate (though there could be more than one), and the posttest measure would be the outcome. 47 | 2. **{term}`Difference in differences`:** We can apply linear modeling approaches such as `y ~ group + time + group:time` to estimate the treatment effect. Here, `y` is the outcome measure, `group` is a binary variable indicating treatment or control group, and `time` is a binary variable indicating pretest or posttest. Note that this approach has a strong assumption of [parallel trends](https://en.wikipedia.org/wiki/Difference_in_differences#Assumptions) - that the treatment and control groups would have changed in the same way in the absence of the treatment. 48 | 49 | A limitation of the nonequivalent group designs with single pre and posttest measures is that we don't know how the groups were changing over time before the treatment was introduced. This can be addressed by adding multiple pretest measures and can help in assessing if the parallel trends assumption is reasonable. See p154 of {cite:t}`reichardt2019quasi`. 50 | 51 | | | | | | | 52 | |-----|----|---|-|----| 53 | | NR: | $O_1$ | $O_2$ | X | $O_3$ | 54 | | NR: | $O_1$ | $O_2$ | | $O_3$ | 55 | 56 | Again, this design could be analysed using the difference-in-differences approach. 57 | 58 | ## Interrupted time series designs 59 | 60 | While there is no control group, the {term}`interrupted time series design` is a powerful quasi-experimental design that can be used to estimate the causal impact of a treatment. The design involves multiple pretest and posttest measures. 
The treatment is introduced at a specific point in time, denoted by `X`. The design can be used to estimate the causal impact of the treatment by comparing the trajectory of the outcome variable before and after the treatment. See p203 of {cite:t}`reichardt2019quasi`. 61 | 62 | | | | | | | | | | | 63 | |-----|----|---|----|---|----|----|----|----| 64 | | $O_1$ | $O_2$ | $O_3$ | $O_4$ | X | $O_5$ | $O_6$ | $O_7$ | $O_8$ | 65 | 66 | You can see that this is an example of a pretest-posttest design with multiple pre and posttest measures. 67 | 68 | ## Comparative interrupted time series designs 69 | 70 | The {term}`comparative interrupted time-series` design incorporates aspects of **interrupted time series** (with only a treatment group), and **nonequivalent group designs** (with a treatment and control group). This design can be used to estimate the causal impact of a treatment by comparing the trajectory of the outcome variable before and after the treatment in the treatment group, and comparing this to the trajectory of the outcome variable in the control group. See p226 of {cite:t}`reichardt2019quasi`. 71 | 72 | | | | | | | | | | | | 73 | |-----|----|---|----|---|----|----|----|----|-| 74 | | NR: | $O_1$ | $O_2$ | $O_3$ | $O_4$ | X | $O_5$ | $O_6$ | $O_7$ | $O_8$ | 75 | | NR: | $O_1$ | $O_2$ | $O_3$ | $O_4$ | | $O_5$ | $O_6$ | $O_7$ | $O_8$ | 76 | 77 | 78 | Because this design is very similar to the nonequivalent group design, simply with multiple pre and posttest measures, it is well-suited to analysis under the difference-in-differences approach. 79 | 80 | However, if we have many untreated units and one treated unit, then this design could be analysed with the {term}`synthetic control` approach. 81 | 82 | ## Regression discontinuity designs 83 | 84 | The design notation for {term}`regression discontinuity designs` is different from the others and takes a bit of getting used to.
We have two groups, but allocation to the groups is determined by a unit's relation to a cutoff point `C` along a {term}`running variable`. Also, $O_1$ now represents the value of the running variable, and $O_2$ represents the outcome variable. See p169 of {cite:t}`reichardt2019quasi`. This will make more sense if you consider the design notation alongside one of the example notebooks. 85 | 86 | | | | | | 87 | |-----|----|---|----| 88 | | C: | $O_1$ | X | $O_2$ | 89 | | C: | $O_1$ | | $O_2$ | 90 | 91 | From an analysis perspective, regression discontinuity designs are very similar to interrupted time series designs. The key difference is that treatment is determined by a cutoff point along a running variable, rather than by time. 92 | 93 | ## Summary 94 | This page has offered a brief overview of the tabular notation used to describe quasi-experimental designs. The notation is a useful tool for summarizing the design of a study, and can be used to help identify the strengths and limitations of a study design. But readers are strongly encouraged to consult the original sources when assessing the relative strengths and limitations of making causal claims under different quasi-experimental designs. 95 | 96 | ## References 97 | :::{bibliography} 98 | :filter: docname in docnames 99 | ::: 100 | -------------------------------------------------------------------------------- /docs/source/knowledgebase/glossary.rst: -------------------------------------------------------------------------------- 1 | Glossary 2 | ======== 3 | 4 | .. glossary:: 5 | :sorted: 6 | 7 | ANCOVA 8 | Analysis of covariance is a simple linear model, typically with one continuous predictor (the covariate) and a categorical variable (which may correspond to treatment or control group). In the context of this package, ANCOVA could be useful in pre-post treatment designs, either with or without random assignment.
This is similar to the approach of difference in differences, but only applicable with a single pre and post treatment measure. 9 | 10 | Average treatment effect 11 | ATE 12 | The average treatment effect across all units. 13 | 14 | Average treatment effect on the treated 15 | ATT 16 | The average effect of the treatment on the units that received it. Also called Treatment on the treated. 17 | 18 | Change score analysis 19 | A statistical procedure where the outcome variable is the difference between the posttest and pretest scores. 20 | 21 | Causal impact 22 | An umbrella term for the estimated effect of a treatment on an outcome. 23 | 24 | Comparative interrupted time-series 25 | CITS 26 | An interrupted time series design with added comparison time series observations. 27 | 28 | Confound 29 | Anything besides the treatment which varies across the treatment and control conditions. 30 | 31 | Counterfactual 32 | A hypothetical outcome that could or will occur under specific hypothetical circumstances. 33 | 34 | Difference in differences 35 | DiD 36 | Analysis where the treatment effect is estimated as a difference between treatment conditions in the differences between pre-treatment to post treatment observations. 37 | 38 | Interrupted time series design 39 | ITS 40 | A quasi-experimental design to estimate a treatment effect where a series of observations are collected before and after a treatment. No control group is present. 41 | 42 | Instrumental Variable regression 43 | IV 44 | A quasi-experimental design to estimate a treatment effect where there is a risk of confounding between the treatment and the outcome due to endogeneity. 45 | 46 | Endogenous Variable 47 | An endogenous variable is a variable in a regression equation such that the variable is correlated with the error term of the equation i.e. correlated with the outcome variable (in the system).
This is a problem for OLS regression estimation techniques because endogeneity violates the assumptions of the Gauss-Markov theorem. 48 | 49 | Local Average Treatment effect 50 | LATE 51 | Also known as the complier average causal effect (CACE), this is the effect of a treatment for subjects who comply with the experimental treatment assigned to their sample group. It is the quantity we're estimating in IV designs. 52 | 53 | Non-equivalent group designs 54 | NEGD 55 | A quasi-experimental design where units are assigned to conditions non-randomly, and not according to a running variable (see Regression discontinuity design). This can be problematic when assigning causal influence of the treatment - differences in outcomes between groups could be due to the treatment or due to differences in the group attributes themselves. 56 | 57 | One-group posttest-only design 58 | A design where a single group is exposed to a treatment and assessed on an outcome measure. There is no pretest measure or comparison group. 59 | 60 | Parallel trends assumption 61 | An assumption made in difference in differences designs that the trends (over time) in the outcome variable would have been the same between the treatment and control groups in the absence of the treatment. 62 | 63 | Panel data 64 | Time series data collected on multiple units where the same units are observed at each time point. 65 | 66 | Pretest-posttest design 67 | A quasi-experimental design where the treatment effect is estimated by comparing an outcome measure before and after treatment. 68 | 69 | Propensity scores 70 | An estimate of the probability of adopting a treatment status. Used in re-weighting schemes to balance observational data. 71 | 72 | Quasi-experiment 73 | An empirical comparison used to estimate the effects of a treatment where units are not assigned to conditions at random. 74 | 75 | Random assignment 76 | Where units are assigned to conditions at random.
77 | 78 | Randomized experiment 79 | An empirical comparison used to estimate the effects of treatments where units are assigned to treatment conditions randomly. 80 | 81 | Regression discontinuity design 82 | RDD 83 | A quasi–experimental comparison to estimate a treatment effect where units are assigned to treatment conditions based on a cut-off score on a quantitative assignment variable (aka running variable). 84 | 85 | Regression kink design 86 | A quasi-experimental research design that estimates treatment effects by analyzing the impact of a treatment or intervention precisely at a defined threshold or "kink" point in a quantitative assignment variable (running variable). Unlike traditional regression discontinuity designs, regression kink design looks for a change in the slope of an outcome variable at the kink, instead of a discontinuity. This is useful when the assignment variable is not discrete, jumping from 0 to 1 at a threshold. Instead, regression kink designs are appropriate when there is a change in the first derivative of the assignment function at the kink point. 87 | 88 | Running variable 89 | In regression discontinuity designs, the running variable is the variable that determines the assignment of units to treatment or control conditions. This is typically a continuous variable. Examples could include a test score, age, income, or spatial location. But the running variable would not be time, which is the case in interrupted time series designs. 90 | 91 | Sharp regression discontinuity design 92 | A Regression discontinuity design where allocation to treatment or control is determined by a sharp threshold / step function. 93 | 94 | Synthetic control 95 | The synthetic control method is a statistical method used to evaluate the effect of an intervention in comparative case studies. It involves the construction of a weighted combination of groups used as controls, to which the treatment group is compared. 
96 | 97 | Treatment on the treated effect 98 | TOT 99 | The average effect of the treatment on the units that received it. Also called the average treatment effect on the treated (ATT). 100 | 101 | Treatment effect 102 | The difference in outcomes between what happened after a treatment is implemented and what would have happened (see Counterfactual) if the treatment had not been implemented, assuming everything else had been the same. 103 | 104 | Wilkinson notation 105 | A notation for describing statistical models :footcite:p:`wilkinson1973symbolic`. 106 | 107 | Two Stage Least Squares 108 | 2SLS 109 | An estimation technique for estimating the parameters of an IV regression. It takes its name from the fact that it uses two OLS regressions - a first and second stage. 110 | 111 | 112 | 113 | References 114 | ---------- 115 | .. footbibliography:: 116 | -------------------------------------------------------------------------------- /docs/source/knowledgebase/index.md: -------------------------------------------------------------------------------- 1 | # Knowledge base 2 | 3 | :::{toctree} 4 | :maxdepth: 1 5 | 6 | glossary 7 | design_notation 8 | quasi_dags.ipynb 9 | causal_video_resources 10 | causal_written_resources 11 | ::: 12 | -------------------------------------------------------------------------------- /docs/source/notebooks/index.md: -------------------------------------------------------------------------------- 1 | # How-to 2 | 3 | :::{toctree} 4 | :caption: ANCOVA 5 | :maxdepth: 1 6 | 7 | ancova_pymc.ipynb 8 | ::: 9 | 10 | :::{toctree} 11 | :caption: Synthetic Control 12 | :maxdepth: 1 13 | 14 | sc_skl.ipynb 15 | sc_pymc.ipynb 16 | sc_pymc_brexit.ipynb 17 | ::: 18 | 19 | :::{toctree} 20 | :caption: Geographical lift testing 21 | :maxdepth: 1 22 | 23 | geolift1.ipynb 24 | multi_cell_geolift.ipynb 25 | ::: 26 | 27 | :::{toctree} 28 | :caption: Difference in Differences 29 | :maxdepth: 1 30 | 31 | did_skl.ipynb 32 | did_pymc.ipynb 33 | 
did_pymc_banks.ipynb 34 | ::: 35 | 36 | :::{toctree} 37 | :caption: Interrupted Time Series 38 | :maxdepth: 1 39 | 40 | its_skl.ipynb 41 | its_pymc.ipynb 42 | its_covid.ipynb 43 | ::: 44 | 45 | :::{toctree} 46 | :caption: Regression Discontinuity 47 | :maxdepth: 1 48 | 49 | rd_skl.ipynb 50 | rd_pymc.ipynb 51 | rd_pymc_drinking.ipynb 52 | ::: 53 | 54 | :::{toctree} 55 | :caption: Regression Kink Design 56 | :maxdepth: 1 57 | 58 | rkink_pymc.ipynb 59 | ::: 60 | 61 | :::{toctree} 62 | :caption: Instrumental Variables Regression 63 | :maxdepth: 1 64 | 65 | iv_pymc.ipynb 66 | iv_weak_instruments.ipynb 67 | ::: 68 | 69 | :::{toctree} 70 | :caption: Inverse Propensity Score Weighting 71 | :maxdepth: 1 72 | 73 | inv_prop_pymc.ipynb 74 | ::: 75 | -------------------------------------------------------------------------------- /docs/source/references.bib: -------------------------------------------------------------------------------- 1 | @online{brexit2022policybrief, 2 | year={2022}, 3 | title={What can we know about the cost of Brexit so far?}, 4 | url={https://www.cer.eu/publications/archive/policy-brief/2022/cost-brexit-so-far}, 5 | author={Springford, John} 6 | } 7 | 8 | @book{reichardt2019quasi, 9 | title={Quasi-experimentation: A guide to design and analysis}, 10 | author={Reichardt, Charles S}, 11 | year={2019}, 12 | publisher={Guilford Publications} 13 | } 14 | 15 | @article{richardson2009monetary, 16 | title={Monetary intervention mitigated banking panics during the great depression: quasi-experimental evidence from a federal reserve district border, 1929--1933}, 17 | author={Richardson, Gary and Troost, William}, 18 | journal={Journal of Political Economy}, 19 | volume={117}, 20 | number={6}, 21 | pages={1031--1073}, 22 | year={2009}, 23 | publisher={The University of Chicago Press} 24 | } 25 | 26 | @book{angrist2014mastering, 27 | title={Mastering 'Metrics: The path from cause to effect}, 28 | author={Angrist, Joshua D and Pischke, J{\"o}rn-Steffen}, 29 | 
year={2014}, 30 | publisher={Princeton University Press} 31 | } 32 | 33 | @article{carpenter2009effect, 34 | title={The effect of alcohol consumption on mortality: regression discontinuity evidence from the minimum drinking age}, 35 | author={Carpenter, Christopher and Dobkin, Carlos}, 36 | journal={American Economic Journal: Applied Economics}, 37 | volume={1}, 38 | number={1}, 39 | pages={164--182}, 40 | year={2009}, 41 | publisher={American Economic Association} 42 | } 43 | 44 | @article{wilkinson1973symbolic, 45 | title={Symbolic description of factorial models for analysis of variance}, 46 | author={Wilkinson, GN and Rogers, CE}, 47 | journal={Journal of the Royal Statistical Society Series C: Applied Statistics}, 48 | volume={22}, 49 | number={3}, 50 | pages={392--399}, 51 | year={1973}, 52 | publisher={Oxford University Press} 53 | } 54 | 55 | @book{hansenEconometrics, 56 | title={Econometrics}, 57 | author={Hansen, Bruce E}, 58 | year={2022}, 59 | publisher={Princeton} 60 | } 61 | 62 | @book{aronowFoundations, 63 | author={Aronow, P and Miller, B}, 64 | title={Foundations of Agnostic Statistics}, 65 | publisher={Cambridge University Press}, 66 | year={2019} 67 | } 68 | 69 | @article{acemoglu2001colonial, 70 | title={The Colonial Origins of Comparative Development: An Empirical Investigation}, 71 | author={Acemoglu, D and Johnson, S and Robinson, J}, 72 | journal={American Economic Review}, 73 | volume={91}, 74 | number={5}, 75 | pages={1369--1401}, 76 | year={2001} 77 | } 78 | 79 | @incollection{card1995returns, 80 | author={Card, David}, 81 | title={Using Geographical Variation in College Proximity to Estimate the Return to Schooling}, 82 | editor={Christofides, L.N. and Grant, E.K. 
and Swidinsky, R.}, 83 | booktitle={Aspects of Labour Market Behaviour: Essays in Honour of John Vanderkamp}, 84 | year={1995}, 85 | publisher={University of Toronto Press} 86 | } 87 | 88 | @incollection{forde2024nonparam, 89 | author = {Forde, Nathaniel}, 90 | title = {Bayesian Non-parametric Causal Inference}, 91 | editor = {PyMC Team}, 92 | booktitle = {PyMC examples}, 93 | doi = {10.5281/zenodo.5654871}, 94 | year = {2024} 95 | } 96 | 97 | @book{shadish_cook_cambell_2002, 98 | title={Experimental and quasi-experimental designs for generalized causal inference}, 99 | author={Cook, Thomas D and Campbell, Donald Thomas and Shadish, William}, 100 | volume={1195}, 101 | year={2002}, 102 | publisher={Houghton Mifflin Boston, MA} 103 | } 104 | 105 | @article{steiner2017graphical, 106 | title={Graphical models for quasi-experimental designs}, 107 | author={Steiner, Peter M and Kim, Yongnam and Hall, Courtney E and Su, Dan}, 108 | journal={Sociological methods \& research}, 109 | volume={46}, 110 | number={2}, 111 | pages={155--188}, 112 | year={2017}, 113 | publisher={SAGE Publications Sage CA: Los Angeles, CA} 114 | } 115 | 116 | @book{cunningham2021causal, 117 | title={Causal inference: The mixtape}, 118 | author={Cunningham, Scott}, 119 | year={2021}, 120 | publisher={Yale university press} 121 | } 122 | 123 | @book{huntington2021effect, 124 | title={The effect: An introduction to research design and causality}, 125 | author={Huntington-Klein, Nick}, 126 | year={2021}, 127 | publisher={Chapman and Hall/CRC} 128 | } 129 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CausalPy 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - arviz>=0.14.0 6 | - graphviz 7 | - ipython!=8.7.0 8 | - matplotlib>=3.5.3 9 | - numpy 10 | - pandas 11 | - patsy 12 | - pymc>=5.15.1 13 | - scikit-learn>=1 14 | - scipy 15 | - seaborn>=0.11.2 16 | - 
statsmodels 17 | - xarray>=v2022.11.0 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | # Minimum requirements for the build system to execute. 3 | build-backend = "setuptools.build_meta" 4 | requires = ["setuptools>=61.0"] 5 | 6 | # This is configuration specific to the `setuptools` build backend. 7 | # If you are using a different build backend, you will need to change this. 8 | [tool.setuptools.packages.find] 9 | exclude = ["causalpy.test*", "docs*"] 10 | 11 | [tool.setuptools.package-data] 12 | "causalpy.data" = ["*.csv"] 13 | 14 | [project] 15 | name = "CausalPy" 16 | version = "0.4.2" 17 | description = "Causal inference for quasi-experiments in Python" 18 | readme = "README.md" 19 | license = { file = "LICENSE" } 20 | authors = [{ name = "Ben Vincent", email = "ben.vincent@pymc-labs.com" }] 21 | requires-python = ">=3.10" 22 | 23 | # This field lists other packages that your project depends on to run. 24 | # Any package you put here will be installed by pip when your project is 25 | # installed, so they must be valid existing projects. 26 | # 27 | # For an analysis of this field vs pip's requirements files see: 28 | # https://packaging.python.org/discussions/install-requires-vs-requirements/ 29 | dependencies = [ 30 | "arviz>=0.14.0", 31 | "graphviz", 32 | "ipython!=8.7.0", 33 | "matplotlib>=3.5.3", 34 | "numpy", 35 | "pandas", 36 | "patsy", 37 | "pymc>=5.15.1", 38 | "scikit-learn>=1", 39 | "scipy", 40 | "seaborn>=0.11.2", 41 | "statsmodels", 42 | "xarray>=v2022.11.0", 43 | ] 44 | 45 | # List additional groups of dependencies here (e.g. development dependencies). Users 46 | # will be able to install these using the "extras" syntax, for example: 47 | # 48 | # $ pip install causalpy[dev] 49 | # 50 | # Similar to `dependencies` above, these must be valid existing projects. 
51 | [project.optional-dependencies] 52 | dev = ["pathlib", "pre-commit", "twine", "interrogate", "codespell", "nbformat", "nbconvert"] 53 | docs = [ 54 | "ipykernel", 55 | "daft", 56 | "linkify-it-py", 57 | "myst-nb<=1.0.0", 58 | "pathlib", 59 | "pylint", 60 | "sphinx", 61 | "sphinx-autodoc-typehints", 62 | "sphinx_autodoc_defaultargs", 63 | "labs-sphinx-theme", 64 | "sphinx-copybutton", 65 | "sphinx-rtd-theme", 66 | "statsmodels", 67 | "sphinxcontrib-bibtex", 68 | "sphinx-notfound-page", 69 | "ipywidgets", 70 | "sphinx-design", 71 | ] 72 | lint = ["interrogate", "pre-commit", "ruff"] 73 | test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert"] 74 | 75 | [metadata] 76 | description-file = 'README.md' 77 | license_files = 'LICENSE' 78 | 79 | [project.urls] 80 | Homepage = "https://github.com/pymc-labs/CausalPy" 81 | "Bug Reports" = "https://github.com/pymc-labs/CausalPy/issues" 82 | "Source" = "https://github.com/pymc-labs/CausalPy" 83 | 84 | [tool.pytest.ini_options] 85 | addopts = [ 86 | "-vv", 87 | "--strict-markers", 88 | "--strict-config", 89 | "--cov=causalpy", 90 | "--cov-report=term-missing", 91 | "--doctest-modules", 92 | ] 93 | testpaths = "causalpy/tests" 94 | markers = [ 95 | "integration: mark as an integration test.", 96 | "slow: mark test as slow.", 97 | ] 98 | 99 | [tool.interrogate] 100 | ignore-init-method = true 101 | ignore-init-module = true 102 | ignore-magic = false 103 | ignore-semiprivate = false 104 | ignore-private = false 105 | ignore-property-decorators = false 106 | ignore-module = false 107 | ignore-nested-functions = false 108 | ignore-nested-classes = true 109 | ignore-setters = false 110 | fail-under = 85 111 | exclude = ["setup.py", "docs", "build", "dist"] 112 | ignore-regex = ["^get$", "^mock_.*", ".*BaseClass.*"] 113 | # possible values: 0 (minimal output), 1 (-v), 2 (-vv) 114 | verbose = 1 115 | quiet = false 116 | whitelist-regex = [] 117 | color = true 118 | omit-covered-files = false 119 | generate-badge = 
"docs/source/_static/" 120 | badge-format = "svg" 121 | 122 | [tool.ruff.format] 123 | docstring-code-format = true 124 | 125 | [tool.ruff.lint] 126 | extend-select = [ 127 | "I", # isort 128 | ] 129 | 130 | [tool.codespell] 131 | ignore-words = "./docs/source/.codespell/codespell-whitelist.txt" 132 | skip = "*.ipynb,*.csv,pyproject.toml,docs/source/.codespell/codespell-whitelist.txt" 133 | --------------------------------------------------------------------------------