├── .coveragerc ├── .deepsource.toml ├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .flake8 ├── .github └── workflows │ ├── codeql-analysis.yml │ └── main.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .prettierignore ├── .prettierrc.json ├── AUTHORS.md ├── CHANGELOG.md ├── CONTRIBUTING.md ├── GettingStarted.ipynb ├── LICENSE.md ├── README.md ├── environment.yml ├── package-lock.json ├── package.json ├── pyproject.toml ├── renovate.json ├── requirements.in ├── requirements.txt ├── setup.cfg ├── setup.py ├── src └── causalimpact │ ├── __init__.py │ ├── analysis.py │ ├── inferences.py │ ├── misc.py │ └── model.py ├── tests ├── README.md ├── __init__.py ├── conftest.py ├── fixtures │ └── analysis │ │ └── summary_report_output.txt ├── test_analysis.py ├── test_inferences.py ├── test_misc.py └── test_model.py └── tox.ini /.coveragerc: -------------------------------------------------------------------------------- 1 | # .coveragerc to control coverage.py 2 | [run] 3 | branch = True 4 | source = causalimpact 5 | # omit = bad_file.py 6 | 7 | [paths] 8 | source = 9 | src/ 10 | */site-packages/ 11 | 12 | [report] 13 | # Regexes for lines to exclude from consideration 14 | exclude_lines = 15 | # Have to re-enable the standard pragma 16 | pragma: no cover 17 | 18 | # Don't complain about missing debug-only code: 19 | def __repr__ 20 | if self\.debug 21 | 22 | # Don't complain if tests don't hit defensive assertion code: 23 | raise AssertionError 24 | raise NotImplementedError 25 | 26 | # Don't complain if non-runnable code isn't run: 27 | if 0: 28 | if __name__ == .__main__.: 29 | -------------------------------------------------------------------------------- /.deepsource.toml: -------------------------------------------------------------------------------- 1 | version = 1 2 | 3 | test_patterns = [ 4 | "tests/**" 5 | ] 6 | 7 | [[analyzers]] 8 | name = "python" 9 | enabled = true 10 | 11 | [analyzers.meta] 12 | runtime_version = "3.x.x" 13 | 14 | [[transformers]] 15 | name = "black" 16 | enabled = true 17 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # See here for image contents: https://github.com/microsoft/vscode-dev-containers/tree/v0.231.6/containers/python-3-miniconda/.devcontainer/base.Dockerfile 2 | 3 | FROM mcr.microsoft.com/vscode/devcontainers/miniconda:0-3 4 | # [Choice] Python version: 3, 3.9, 3.8, 3.7, 3.6 5 | ARG VARIANT="3.9" 6 | 7 | # [Option] Install Node.js 8 | ARG INSTALL_NODE="true" 9 | ARG NODE_VERSION="lts/*" 10 | RUN if [ "${INSTALL_NODE}" = "true" ]; then su vscode -c "umask 0002 && . /usr/local/share/nvm/nvm.sh && nvm install ${NODE_VERSION} 2>&1"; fi 11 | 12 | # [Optional] If your pip requirements rarely change, uncomment this section to add them to the image. 13 | COPY requirements.txt /tmp/pip-tmp/ 14 | RUN conda install mamba -n base -c conda-forge 15 | # Install package dependencies, then clean up the copied requirements. 16 | RUN conda install --file /tmp/pip-tmp/requirements.txt -c conda-forge && rm -rf /tmp/pip-tmp 17 | 18 | # Install dev dependencies. 19 | RUN mamba install pytest pytest-cov tox mock commitizen pre-commit pip-tools jupyter -c conda-forge 20 | 21 | # [Optional] Uncomment this section to install additional OS packages.
22 | # RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ 23 | # && apt-get -y install --no-install-recommends 24 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the README at: 2 | // https://github.com/microsoft/vscode-dev-containers/tree/v0.187.0/containers/python-3 3 | { 4 | "name": "Python 3", 5 | "build": { 6 | "dockerfile": "Dockerfile", 7 | "context": "..", 8 | "args": { 9 | // Update 'VARIANT' to pick a Python version: 3, 3.6, 3.7, 3.8, 3.9 10 | "VARIANT": "3.9", 11 | // Options 12 | "INSTALL_NODE": "true", 13 | "NODE_VERSION": "lts/*" 14 | } 15 | }, 16 | 17 | // Set *default* container specific settings.json values on container create. 18 | "settings": { 19 | "python.pythonPath": "/usr/local/bin/python", 20 | "python.languageServer": "Pylance", 21 | "python.linting.enabled": true, 22 | "python.linting.pylintEnabled": true, 23 | "python.formatting.autopep8Path": "/usr/local/py-utils/bin/autopep8", 24 | "python.formatting.blackPath": "/usr/local/py-utils/bin/black", 25 | "python.formatting.yapfPath": "/usr/local/py-utils/bin/yapf", 26 | "python.linting.banditPath": "/usr/local/py-utils/bin/bandit", 27 | "python.linting.flake8Path": "/usr/local/py-utils/bin/flake8", 28 | "python.linting.mypyPath": "/usr/local/py-utils/bin/mypy", 29 | "python.linting.pycodestylePath": "/usr/local/py-utils/bin/pycodestyle", 30 | "python.linting.pydocstylePath": "/usr/local/py-utils/bin/pydocstyle", 31 | "python.linting.pylintPath": "/usr/local/py-utils/bin/pylint" 32 | }, 33 | 34 | // Add the IDs of extensions you want installed when the container is created. 35 | "extensions": ["ms-python.python", "ms-python.vscode-pylance"], 36 | 37 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 38 | "forwardPorts": [8888], 39 | 40 | // Use 'postCreateCommand' to run commands after the container is created. 41 | "postCreateCommand": "pre-commit install", 42 | 43 | // Comment out to connect as root instead. More info: https://aka.ms/vscode-remote/containers/non-root. 44 | "remoteUser": "vscode" 45 | } 46 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | extend-ignore = E203 4 | -------------------------------------------------------------------------------- /.github/workflows/codeql-analysis.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages.
11 | # 12 | name: "CodeQL" 13 | 14 | on: 15 | push: 16 | branches: [master] 17 | pull_request: 18 | # The branches below must be a subset of the branches above 19 | branches: [master] 20 | schedule: 21 | - cron: "44 3 * * 1" 22 | 23 | jobs: 24 | analyze: 25 | name: Analyze 26 | runs-on: ubuntu-latest 27 | permissions: 28 | actions: read 29 | contents: read 30 | security-events: write 31 | 32 | strategy: 33 | fail-fast: false 34 | matrix: 35 | language: ["python"] 36 | # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python', 'ruby' ] 37 | # Learn more about CodeQL language support at https://aka.ms/codeql-docs/language-support 38 | 39 | steps: 40 | - name: Checkout repository 41 | uses: actions/checkout@v3 42 | 43 | # Initializes the CodeQL tools for scanning. 44 | - name: Initialize CodeQL 45 | uses: github/codeql-action/init@v2 46 | with: 47 | languages: ${{ matrix.language }} 48 | # If you wish to specify custom queries, you can do so here or in a config file. 49 | # By default, queries listed here will override any specified in a config file. 50 | # Prefix the list here with "+" to use these queries and those in the config file. 51 | # queries: ./path/to/local/query, your-org/your-repo/queries@main 52 | 53 | # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). 54 | # If this step fails, then you should remove it and run the build manually (see below) 55 | - name: Autobuild 56 | uses: github/codeql-action/autobuild@v2 57 | 58 | # ℹ️ Command-line programs to run using the OS shell. 59 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 60 | 61 | # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines 62 | # and modify them (or add more) to build your code if your project 63 | # uses a compiled language 64 | 65 | #- run: | 66 | # make bootstrap 67 | # make release 68 | 69 | - name: Perform CodeQL Analysis 70 | uses: github/codeql-action/analyze@v2 71 | -------------------------------------------------------------------------------- /.github/workflows/main.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: main 5 | 6 | on: 7 | push: 8 | branches: [master] 9 | pull_request: 10 | branches: [master] 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | strategy: 16 | fail-fast: false 17 | matrix: 18 | python-version: ["3.8", "3.9", "3.10", "3.11"] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v3 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | python -m pip install tox 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Lint with flake8 32 | run: | 33 | tox -e lint 34 | - name: Test with tox 35 | run: | 36 | tox -- --cov-report=xml 37 | - uses: codecov/codecov-action@v3 38 | with: 39 | token: ${{ secrets.CODE_COV_TOKEN }} 40 | - name: build 41 | run: | 42 | tox -e build 43 | - name: Publish distribution 📦 to PyPI 44 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 45 | uses: pypa/gh-action-pypi-publish@release/v1 46 | with: 47 | 
password: ${{ secrets.PYPI_API_TOKEN }} 48 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | 27 | # PyInstaller 28 | # Usually these files are written by a python script from a template 29 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 30 | *.manifest 31 | *.spec 32 | 33 | # Installer logs 34 | pip-log.txt 35 | pip-delete-this-directory.txt 36 | 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | .hypothesis/ 47 | .pytest_cache 48 | 49 | # Translations 50 | *.mo 51 | *.pot 52 | 53 | # Django stuff: 54 | *.log 55 | local_settings.py 56 | 57 | # Flask instance folder 58 | instance/ 59 | 60 | # Scrapy stuff: 61 | .scrapy 62 | 63 | # Sphinx documentation 64 | docs/_build/ 65 | 66 | # PyBuilder 67 | target/ 68 | 69 | # IPython Notebook 70 | .ipynb_checkpoints 71 | 72 | # pyenv 73 | .python-version 74 | 75 | # celery beat schedule file 76 | celerybeat-schedule 77 | 78 | # dotenv 79 | .env 80 | .condarc 81 | .conda 82 | .python_history 83 | .bash_history 84 | 85 | # virtualenv 86 | venv/ 87 | ENV/ 88 | 89 | # Spyder project settings 90 | .spyderproject 91 | 92 | #swap files 93 | *swp 94 | node_modules 95 | temp 96 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.2.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 22.3.0 10 | hooks: 11 | - id: black 12 | - repo: https://github.com/commitizen-tools/commitizen 13 | rev: v2.24.0 14 | hooks: 15 | - id: commitizen 16 | stages: [commit-msg] 17 | - repo: https://github.com/pycqa/flake8 18 | rev: "4.0.1" 19 | hooks: 20 | - id: flake8 21 | - repo: https://github.com/pre-commit/mirrors-prettier 22 | rev: "" # Use the sha / tag you want to point at 23 | hooks: 24 | - id: prettier 25 | -------------------------------------------------------------------------------- /.prettierignore: -------------------------------------------------------------------------------- 1 | *.html 2 | *.css 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | **/.tox/ 42 | .coverage 43 | .coverage.* 44 | .cache 45 | nosetests.xml 46 | coverage.xml 47 | *,cover 48 | .hypothesis/ 49 | .pytest_cache 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask instance folder 60 | instance/ 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # IPython Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # dotenv 81 | .env 82 | .condarc 83 | .conda 84 | .python_history 85 | .bash_history 86 | 87 | # virtualenv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | #swap files 95 | *swp 96 | -------------------------------------------------------------------------------- /.prettierrc.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /AUTHORS.md: -------------------------------------------------------------------------------- 1 | # Contributors 2 | 3 | - Jamal Senouci 4 | - Alex Roy 5 | -------------------------------------------------------------------------------- /CHANGELOG.md: -------------------------------------------------------------------------------- 1 | ## v0.2.5 (2023-01-07) 2 | 3 | ### Fix 4 | 5 | - update pypi in actions on new tag 6 | - fixes #43 7 | - update to pymc5 8 | 9 | ## v0.2.4 (2022-06-06) 10 | 11 | ### Fix 12 | 13 | - updated dependency in setup.cfg 14 | 15 | ## v0.2.3 (2022-06-05) 16 | 17 | ### Fix 18 | 19 | - updated arviz version to avoid UnsatisfiableRequirements 20 | 21 | ## v0.2.2 (2022-05-08) 22 | 23 | ## v0.2.1 (2022-05-06) 24 | 25 | ### Fix 26 | 27 | - fixed error in report text + formatting 28 | 29 | ## v0.2.0 (2022-05-06) 30 | 31 | ### Feat 32 | 33 | - #10 added pymc bayesian estimation 34 | 35 | ### Fix 36 | 37 | - fixes error in check for ucm_model arg 38 | - check ucm_arg 39 | - help flake understand unused import 40 | - removed unused assignment 41 | 42 | ## v0.1.6 (2022-04-30) 43 | 44 | ### Fix 45 | 46 | - fixes #24 using pandas dtype checks 47 | - #18 move to use pandas .any() method for frames/series 48 | - #21 fix for pvalue scaling 49 | 50 | ## v0.1.5 (2020-03-12) 51 | 52 | ## v0.1.3 (2017-12-23) 53 | 54 | ## v0.1.2 (2017-12-20) 55 | 56 | ## v0.1.1 (2017-11-29) 57 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Welcome to the `causalimpact` contributor\'s guide. 4 | 5 | This document focuses on getting any potential contributor familiarized 6 | with the development processes, but [other kinds of 7 | contributions](https://opensource.guide/how-to-contribute) are also 8 | appreciated. 9 | 10 | If you are new to using [git](https://git-scm.com) or have never 11 | collaborated on a project previously, please have a look at 12 | [contribution-guide.org](https://www.contribution-guide.org/). Other 13 | resources are also listed in the excellent [guide created by 14 | FreeCodeCamp](https://github.com/FreeCodeCamp/how-to-contribute-to-open-source)[^1].
15 | 16 | Please note that all users and contributors are expected to be **open, 17 | considerate, reasonable, and respectful**. When in doubt, the [Python 18 | Software Foundation\'s Code of 19 | Conduct](https://www.python.org/psf/conduct/) is a good reference in 20 | terms of behavior guidelines. 21 | 22 | ## Issue Reports 23 | 24 | If you experience bugs or general issues with `causalimpact`, please 25 | have a look at the [issue 26 | tracker](https://github.com/jamalsenouci/causalimpact/issues). If you 27 | don\'t see anything useful there, please feel free to file an issue 28 | report. 29 | 30 | > Please don\'t forget to include the closed issues in your search. 31 | > Sometimes a solution was already reported, and the problem is considered 32 | > **solved**. 33 | 34 | New issue reports should include information about your programming 35 | environment (e.g., operating system, Python version) and steps to 36 | reproduce the problem. Please also try to simplify the reproduction 37 | steps to a very minimal example that still illustrates the problem you 38 | are facing. By removing other factors, you help us to identify the root 39 | cause of the issue. 40 | 41 | ## Documentation Improvements 42 | 43 | You can help improve the `causalimpact` docs by making them more readable 44 | and coherent, or by adding missing information and correcting mistakes. 45 | 46 | The `causalimpact` documentation is kept in the same 47 | repository as the project code, which means that any 48 | documentation update is 49 | done in the same way as a code contribution. 50 | 51 | Documentation is written in Markdown, conforming to the 52 | [CommonMark](https://commonmark.org/) spec. 53 | 54 | ### Using the GitHub web editor 55 | 56 | Please note that the [GitHub web 57 | interface](https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files) 58 | provides a quick way to propose changes to `causalimpact`\'s files. 59 | While this mechanism can be tricky for normal code contributions, it 60 | works perfectly fine for contributing to the docs, and can be quite 61 | handy. 62 | 63 | If you are interested in trying this method out, please navigate to the 64 | `docs` folder in the source 65 | [repository](https://github.com/jamalsenouci/causalimpact), find the 66 | file you would like to change and click on the little pencil 67 | icon at the top to open [GitHub\'s code 68 | editor](https://docs.github.com/en/repositories/working-with-files/managing-files/editing-files). 69 | Once you finish editing the file, please write a message in the form at 70 | the bottom of the page describing the changes you have made and the 71 | motivations behind them, then submit your proposal. 72 | 73 | ### Using github.dev 74 | 75 | [github.dev](https://github.com/github/dev) also provides a convenient way to spin up a VS Code editor in your browser for small changes. 76 | 77 | ### Working locally 78 | 79 | When working on documentation changes on your local machine, you can 80 | preview them using your IDE's markdown preview. 81 | 82 | Example: [vscode guide](https://code.visualstudio.com/docs/languages/markdown) 83 | 84 | ## Code Contributions 85 | 86 | ### Internals 87 | 88 | The package exports the `CausalImpact` class, which encapsulates the full range of functionality exposed to the user. This class is defined in src/causalimpact/analysis.py, which is responsible for orchestrating the causalimpact workflow, as sketched below.
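To make the orchestration concrete, here is a minimal sketch that drives the class end to end. The data, column names, seed, and periods are all invented for illustration; the `run`/`summary`/`plot` calls follow the public API as exercised in GettingStarted.ipynb.

```python
import numpy as np
import pandas as pd

from causalimpact import CausalImpact

# Invented toy input: the response is the first column; the
# remaining columns are treated as control series.
rng = np.random.default_rng(42)
x = 100 + rng.standard_normal(100).cumsum()
y = 1.2 * x + rng.standard_normal(100)
y[70:] += 10  # synthetic intervention effect
data = pd.DataFrame({"y": y, "x": x})

impact = CausalImpact(data, pre_period=[0, 69], post_period=[70, 99])
impact.run()      # input checks, model fit and predictions (analysis/model/inferences)
impact.summary()  # formatted summary table (analysis.py)
impact.plot()     # matplotlib visualisation (analysis.py)
```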
89 | 90 | The causal impact workflow is fairly linear and can be broadly represented as 91 | 92 | 1. check user-provided inputs (happens in analysis.py) 93 | 2. fit the model (happens in model.py) 94 | 3. make the predictions (happens in inferences.py) 95 | 4. format, visualise and summarise the output (happens back in analysis.py) 96 | 97 | The model fitting is handled by statsmodels.tsa.statespace.structural.UnobservedComponents. 98 | The plotting is handled using matplotlib. 99 | 100 | ### Submit an issue 101 | 102 | Before you work on any non-trivial code contribution, it\'s best to first 103 | create a report in the [issue 104 | tracker](https://github.com/jamalsenouci/causalimpact/issues) to start 105 | a discussion on the subject. This often provides additional 106 | considerations and avoids unnecessary work. 107 | 108 | ### Create an environment 109 | 110 | Before you start coding, we recommend creating an isolated development environment 111 | to avoid any problems with your installed Python packages. 112 | We recommend using VS Code's [devcontainers](https://code.visualstudio.com/docs/remote/containers). 113 | 114 | ### Clone the repository 115 | 116 | 1. Create a user account on GitHub if you do not already have one. 117 | 118 | 2. Fork the project 119 | [repository](https://github.com/jamalsenouci/causalimpact): click 120 | on the _Fork_ button near the top of the page. This creates a copy 121 | of the code under your account on GitHub. 122 | 123 | 3. Clone this copy to your local disk: 124 | 125 | git clone git@github.com:YourLogin/causalimpact.git 126 | cd causalimpact 127 | 128 | 4. You should run: 129 | 130 | pip install -e . 131 | 132 | to be able to import the package under development in the Python REPL. 133 | 134 | 5. Install `pre-commit`: 135 | 136 | pip install pre-commit 137 | pre-commit install 138 | 139 | `causalimpact` comes with a lot of hooks configured to automatically 140 | help the developer check the code being written. 141 | 142 | ### Implement your changes 143 | 144 | 1. Create a branch to hold your changes: 145 | 146 | git checkout -b my-feature 147 | 148 | and start making changes. Never work on the main branch! 149 | 150 | 2. Start your work on this branch. Don\'t forget to add 151 | [docstrings](https://google.github.io/styleguide/pyguide.html#381-docstrings) 152 | to new functions, modules and classes, especially if they are part 153 | of public APIs. 154 | 155 | 3. Add yourself to the list of contributors in `AUTHORS.md`. 156 | 157 | 4. When you're done editing, do: 158 | 159 | git add <modified files> 160 | git commit 161 | 162 | to record your changes in [git](https://git-scm.com). 163 | 164 | Please make sure to see the validation messages from `pre-commit` 165 | and fix any issues. This should automatically use 166 | [flake8](https://flake8.pycqa.org/en/stable/)/[black](https://pypi.org/project/black/) 167 | to check/fix the code style in a way that is compatible with the 168 | project. 169 | 170 | > **Important** 171 | > Don\'t forget to add unit tests and documentation if your 172 | > contribution adds a new feature and is not just a bugfix. 173 | 174 | Writing a [descriptive commit 175 | message](https://chris.beams.io/posts/git-commit) is highly 176 | recommended. 177 | 178 | 5. Please check that your changes don\'t break any unit tests with: 179 | 180 | tox 181 | 182 | (after having installed `tox` with `pip install tox` or `pipx`). 183 | 184 | You can also use `tox` to run several other pre-configured tasks 185 | in the repository.
Try `tox -av` to see a list of the available 186 | checks. 187 | 188 | ### Submit your contribution 189 | 190 | 1. If everything works fine, push your local branch to GitHub with: 191 | 192 | git push -u origin my-feature 193 | 194 | 2. Go to the web page of your fork and click \"Create pull request\" to 195 | send your changes for review. 196 | 197 | Find more detailed information in [creating a 198 | PR](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request). 199 | You might also want to open the PR as a draft first and mark it as 200 | ready for review after feedback from the continuous 201 | integration (CI) system and any required fixes. 202 | 203 | 204 | ### Troubleshooting 205 | 206 | The following tips can be used when facing problems building or testing the 207 | package: 208 | 209 | 1. Make sure to fetch all the tags from the upstream 210 | [repository](https://github.com/jamalsenouci/causalimpact). The 211 | command `git describe --abbrev=0 --tags` should return the version 212 | you are expecting. If you are trying to run CI scripts in a fork 213 | repository, make sure to push all the tags. You can also try to 214 | remove all the egg files or the complete egg folder, i.e., `.eggs`, 215 | as well as the `*.egg-info` folders in the `src` folder or 216 | potentially in the root of your project. 217 | 218 | 2. Sometimes `tox` misses out when new dependencies are added, 219 | especially to `setup.cfg`. If you find 220 | any problems with missing dependencies when running a command with 221 | `tox`, try to recreate the `tox` environment using the `-r` flag. 222 | For example, instead of: 223 | 224 | tox -e build 225 | 226 | try running: 227 | 228 | tox -r -e build 229 | 230 | 3. Make sure to have a reliable `tox` installation that uses the 231 | correct Python version (e.g., 3.7+). When in doubt you can run: 232 | 233 | tox --version 234 | # OR 235 | which tox 236 | 237 | If you have trouble and are seeing weird errors upon running 238 | `tox`, you can also try to create a dedicated [virtual 239 | environment](https://realpython.com/python-virtual-environments-a-primer/) 240 | with a `tox` binary freshly installed. For example: 241 | 242 | virtualenv .venv 243 | source .venv/bin/activate 244 | .venv/bin/pip install tox 245 | .venv/bin/tox -e all 246 | 247 | 4. [Pytest can drop 248 | you](https://docs.pytest.org/en/stable/how-to/failures.html#using-python-library-pdb-with-pytest) 249 | in an interactive session in case an error occurs. In order to 250 | do that, you need to pass a `--pdb` option (for example by running 251 | `tox -- -k <name of the failing test> --pdb`). You can also set up 252 | breakpoints manually instead of using the `--pdb` option. 253 | 254 | ## Maintainer tasks 255 | 256 | ### Releases 257 | 258 | If you are part of the group of maintainers and have the correct user 259 | permissions on [PyPI](https://pypi.org/), the following steps can be 260 | used to release a new version of `causalimpact`: 261 | 262 | 1. Make sure all unit tests are successful locally and on CI. 263 | 2. Run `cz bump --changelog` to generate a new tag and an updated CHANGELOG.md file. 264 | 3. Push the new tag to the upstream 265 | [repository](https://github.com/jamalsenouci/causalimpact), e.g., 266 | `git push upstream v1.2.3` 267 | 4.
The GitHub action should detect the new tag and publish to PyPI. 268 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution."
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## CausalImpact 2 | 3 | [![Python package](https://github.com/jamalsenouci/causalimpact/actions/workflows/main.yml/badge.svg)](https://github.com/jamalsenouci/causalimpact/actions/workflows/main.yml) 4 | [![codecov](https://codecov.io/gh/jamalsenouci/causalimpact/branch/master/graph/badge.svg?token=EIPC36VQHS)](https://codecov.io/gh/jamalsenouci/causalimpact) 5 | ![monthly downloads](https://pepy.tech/badge/causalimpact/month) 6 | [![DeepSource](https://deepsource.io/gh/jamalsenouci/causalimpact.svg/?label=active+issues&show_trend=true&token=R5aIDSkIId_5THWTAPKccjcH)](https://deepsource.io/gh/jamalsenouci/causalimpact/?ref=repository-badge) 7 | 8 | #### A Python package for causal inference using Bayesian structural time-series models 9 | 10 | This is a port of the R package CausalImpact, see: https://github.com/google/CausalImpact. 11 | 12 | This package implements an approach to estimating the causal effect of a designed intervention on a time series. For example, how many additional daily clicks were generated by an advertising campaign? Answering a question like this can be difficult when a randomized experiment is not available. The package aims to address this difficulty using a structural Bayesian time-series model to estimate how the response metric might have evolved after the intervention if the intervention had not occurred. 13 | 14 | As with all approaches to causal inference on non-experimental data, valid conclusions require strong assumptions. The CausalImpact package, in particular, assumes that the outcome time series can be explained in terms of a set of control time series that were themselves not affected by the intervention. Furthermore, the relation between treated series and control series is assumed to be stable during the post-intervention period. Understanding and checking these assumptions for any given application is critical for obtaining valid conclusions. 
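#### Quick example

A minimal, self-contained sketch; the data, column names, and periods below are invented for illustration, and the `run`/`summary`/`plot` calls follow the GettingStarted notebook, which remains the authoritative walkthrough:

```python
import numpy as np
import pandas as pd

from causalimpact import CausalImpact

# Toy dataset: the response y tracks a control series x, with a synthetic
# lift of +5 applied after the intervention point at t=70.
np.random.seed(1)
x = 100 + np.random.randn(100).cumsum()
y = 1.2 * x + np.random.randn(100)
y[70:] += 5
data = pd.DataFrame({"y": y, "x": x})  # response goes in the first column

pre_period = [0, 69]    # period used to fit the counterfactual model
post_period = [70, 99]  # period over which the effect is estimated

impact = CausalImpact(data, pre_period, post_period)
impact.run()
impact.summary()  # point estimates and intervals for the causal effect
impact.plot()     # observed vs. counterfactual, pointwise and cumulative effects
```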
15 | 16 | #### Try it out in the browser 17 | 18 | [![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/jamalsenouci/causalimpact/HEAD?labpath=GettingStarted.ipynb) 19 | 20 | #### Installation 21 | 22 | install the latest release via pip 23 | 24 | ```bash 25 | pip install causalimpact 26 | ``` 27 | 28 | #### Getting started 29 | 30 | [Documentation and examples](https://nbviewer.org/github/jamalsenouci/causalimpact/blob/master/GettingStarted.ipynb) 31 | 32 | #### Further resources 33 | 34 | - Manuscript: [Brodersen et al., Annals of Applied Statistics (2015)](http://research.google.com/pubs/pub41854.html) 35 | 36 | #### Bugs 37 | 38 | The issue tracker is at https://github.com/jamalsenouci/causalimpact/issues. Please report any bugs that you find. Or, even better, fork the repository on GitHub and create a pull request. 39 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: binder 2 | dependencies: 3 | - python 4 | - compilers 5 | - pandas 6 | - numpy 7 | - statsmodels 8 | - matplotlib 9 | - pymc 10 | - pytensor 11 | -------------------------------------------------------------------------------- /package-lock.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "causalimpact", 3 | "version": "1.0.0", 4 | "lockfileVersion": 2, 5 | "requires": true, 6 | "packages": { 7 | "": { 8 | "name": "causalimpact", 9 | "version": "1.0.0", 10 | "license": "ISC", 11 | "dependencies": { 12 | "onchange": "^7.1.0" 13 | }, 14 | "devDependencies": { 15 | "prettier": "^2.6.2" 16 | } 17 | }, 18 | "node_modules/@blakeembrey/deque": { 19 | "version": "1.0.5", 20 | "resolved": "https://registry.npmjs.org/@blakeembrey/deque/-/deque-1.0.5.tgz", 21 | "integrity": "sha512-6xnwtvp9DY1EINIKdTfvfeAtCYw4OqBZJhtiqkT3ivjnEfa25VQ3TsKvaFfKm8MyGIEfE95qLe+bNEt3nB0Ylg==" 22 | }, 23 | "node_modules/@blakeembrey/template": { 24 | "version": "1.1.0", 25 | "resolved": "https://registry.npmjs.org/@blakeembrey/template/-/template-1.1.0.tgz", 26 | "integrity": "sha512-iZf+UWfL+DogJVpd/xMQyP6X6McYd6ArdYoPMiv/zlOTzeXXfQbYxBNJJBF6tThvsjLMbA8tLjkCdm9RWMFCCw==" 27 | }, 28 | "node_modules/anymatch": { 29 | "version": "3.1.2", 30 | "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.2.tgz", 31 | "integrity": "sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg==", 32 | "dependencies": { 33 | "normalize-path": "^3.0.0", 34 | "picomatch": "^2.0.4" 35 | }, 36 | "engines": { 37 | "node": ">= 8" 38 | } 39 | }, 40 | "node_modules/arg": { 41 | "version": "4.1.3", 42 | "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", 43 | "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==" 44 | }, 45 | "node_modules/binary-extensions": { 46 | "version": "2.2.0", 47 | "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", 48 | "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==", 49 | "engines": { 50 | "node": ">=8" 51 | } 52 | }, 53 | "node_modules/braces": { 54 | "version": "3.0.2", 55 | "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", 56 | "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", 57 | "dependencies": { 58 | "fill-range": "^7.0.1" 59 | }, 60 | 
"engines": { 61 | "node": ">=8" 62 | } 63 | }, 64 | "node_modules/chokidar": { 65 | "version": "3.5.3", 66 | "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", 67 | "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", 68 | "funding": [ 69 | { 70 | "type": "individual", 71 | "url": "https://paulmillr.com/funding/" 72 | } 73 | ], 74 | "dependencies": { 75 | "anymatch": "~3.1.2", 76 | "braces": "~3.0.2", 77 | "glob-parent": "~5.1.2", 78 | "is-binary-path": "~2.1.0", 79 | "is-glob": "~4.0.1", 80 | "normalize-path": "~3.0.0", 81 | "readdirp": "~3.6.0" 82 | }, 83 | "engines": { 84 | "node": ">= 8.10.0" 85 | }, 86 | "optionalDependencies": { 87 | "fsevents": "~2.3.2" 88 | } 89 | }, 90 | "node_modules/cross-spawn": { 91 | "version": "7.0.3", 92 | "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", 93 | "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", 94 | "dependencies": { 95 | "path-key": "^3.1.0", 96 | "shebang-command": "^2.0.0", 97 | "which": "^2.0.1" 98 | }, 99 | "engines": { 100 | "node": ">= 8" 101 | } 102 | }, 103 | "node_modules/fill-range": { 104 | "version": "7.0.1", 105 | "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", 106 | "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", 107 | "dependencies": { 108 | "to-regex-range": "^5.0.1" 109 | }, 110 | "engines": { 111 | "node": ">=8" 112 | } 113 | }, 114 | "node_modules/fsevents": { 115 | "version": "2.3.2", 116 | "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", 117 | "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", 118 | "hasInstallScript": true, 119 | "optional": true, 120 | "os": [ 121 | "darwin" 122 | ], 123 | "engines": { 124 | "node": "^8.16.0 || ^10.6.0 || >=11.0.0" 125 | } 126 | }, 127 | "node_modules/glob-parent": { 128 | "version": "5.1.2", 129 | "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", 130 | "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", 131 | "dependencies": { 132 | "is-glob": "^4.0.1" 133 | }, 134 | "engines": { 135 | "node": ">= 6" 136 | } 137 | }, 138 | "node_modules/ignore": { 139 | "version": "5.2.0", 140 | "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.0.tgz", 141 | "integrity": "sha512-CmxgYGiEPCLhfLnpPp1MoRmifwEIOgjcHXxOBjv7mY96c+eWScsOP9c112ZyLdWHi0FxHjI+4uVhKYp/gcdRmQ==", 142 | "engines": { 143 | "node": ">= 4" 144 | } 145 | }, 146 | "node_modules/is-binary-path": { 147 | "version": "2.1.0", 148 | "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", 149 | "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", 150 | "dependencies": { 151 | "binary-extensions": "^2.0.0" 152 | }, 153 | "engines": { 154 | "node": ">=8" 155 | } 156 | }, 157 | "node_modules/is-extglob": { 158 | "version": "2.1.1", 159 | "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", 160 | "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=", 161 | "engines": { 162 | "node": ">=0.10.0" 163 | } 164 | }, 165 | "node_modules/is-glob": { 166 | "version": "4.0.3", 167 | "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", 168 | "integrity": 
"sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", 169 | "dependencies": { 170 | "is-extglob": "^2.1.1" 171 | }, 172 | "engines": { 173 | "node": ">=0.10.0" 174 | } 175 | }, 176 | "node_modules/is-number": { 177 | "version": "7.0.0", 178 | "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", 179 | "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", 180 | "engines": { 181 | "node": ">=0.12.0" 182 | } 183 | }, 184 | "node_modules/isexe": { 185 | "version": "2.0.0", 186 | "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", 187 | "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=" 188 | }, 189 | "node_modules/normalize-path": { 190 | "version": "3.0.0", 191 | "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", 192 | "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==", 193 | "engines": { 194 | "node": ">=0.10.0" 195 | } 196 | }, 197 | "node_modules/onchange": { 198 | "version": "7.1.0", 199 | "resolved": "https://registry.npmjs.org/onchange/-/onchange-7.1.0.tgz", 200 | "integrity": "sha512-ZJcqsPiWUAUpvmnJri5TPBooqJOPmC0ttN65juhN15Q8xA+Nbg3BaxBHXQ45EistKKlKElb0edmbPWnKSBkvMg==", 201 | "dependencies": { 202 | "@blakeembrey/deque": "^1.0.5", 203 | "@blakeembrey/template": "^1.0.0", 204 | "arg": "^4.1.3", 205 | "chokidar": "^3.3.1", 206 | "cross-spawn": "^7.0.1", 207 | "ignore": "^5.1.4", 208 | "tree-kill": "^1.2.2" 209 | }, 210 | "bin": { 211 | "onchange": "dist/bin.js" 212 | } 213 | }, 214 | "node_modules/path-key": { 215 | "version": "3.1.1", 216 | "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", 217 | "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", 218 | "engines": { 219 | "node": ">=8" 220 | } 221 | }, 222 | "node_modules/picomatch": { 223 | "version": "2.3.1", 224 | "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", 225 | "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", 226 | "engines": { 227 | "node": ">=8.6" 228 | }, 229 | "funding": { 230 | "url": "https://github.com/sponsors/jonschlinkert" 231 | } 232 | }, 233 | "node_modules/prettier": { 234 | "version": "2.6.2", 235 | "resolved": "https://registry.npmjs.org/prettier/-/prettier-2.6.2.tgz", 236 | "integrity": "sha512-PkUpF+qoXTqhOeWL9fu7As8LXsIUZ1WYaJiY/a7McAQzxjk82OF0tibkFXVCDImZtWxbvojFjerkiLb0/q8mew==", 237 | "dev": true, 238 | "bin": { 239 | "prettier": "bin-prettier.js" 240 | }, 241 | "engines": { 242 | "node": ">=10.13.0" 243 | }, 244 | "funding": { 245 | "url": "https://github.com/prettier/prettier?sponsor=1" 246 | } 247 | }, 248 | "node_modules/readdirp": { 249 | "version": "3.6.0", 250 | "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", 251 | "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", 252 | "dependencies": { 253 | "picomatch": "^2.2.1" 254 | }, 255 | "engines": { 256 | "node": ">=8.10.0" 257 | } 258 | }, 259 | "node_modules/shebang-command": { 260 | "version": "2.0.0", 261 | "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", 262 | "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", 263 | "dependencies": { 264 | "shebang-regex": "^3.0.0" 265 | }, 266 
| "engines": { 267 | "node": ">=8" 268 | } 269 | }, 270 | "node_modules/shebang-regex": { 271 | "version": "3.0.0", 272 | "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", 273 | "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", 274 | "engines": { 275 | "node": ">=8" 276 | } 277 | }, 278 | "node_modules/to-regex-range": { 279 | "version": "5.0.1", 280 | "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", 281 | "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", 282 | "dependencies": { 283 | "is-number": "^7.0.0" 284 | }, 285 | "engines": { 286 | "node": ">=8.0" 287 | } 288 | }, 289 | "node_modules/tree-kill": { 290 | "version": "1.2.2", 291 | "resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz", 292 | "integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==", 293 | "bin": { 294 | "tree-kill": "cli.js" 295 | } 296 | }, 297 | "node_modules/which": { 298 | "version": "2.0.2", 299 | "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", 300 | "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", 301 | "dependencies": { 302 | "isexe": "^2.0.0" 303 | }, 304 | "bin": { 305 | "node-which": "bin/node-which" 306 | }, 307 | "engines": { 308 | "node": ">= 8" 309 | } 310 | } 311 | }, 312 | "dependencies": { 313 | "@blakeembrey/deque": { 314 | "version": "1.0.5", 315 | "resolved": "https://registry.npmjs.org/@blakeembrey/deque/-/deque-1.0.5.tgz", 316 | "integrity": "sha512-6xnwtvp9DY1EINIKdTfvfeAtCYw4OqBZJhtiqkT3ivjnEfa25VQ3TsKvaFfKm8MyGIEfE95qLe+bNEt3nB0Ylg==" 317 | }, 318 | "@blakeembrey/template": { 319 | "version": "1.1.0", 320 | "resolved": "https://registry.npmjs.org/@blakeembrey/template/-/template-1.1.0.tgz", 321 | "integrity": "sha512-iZf+UWfL+DogJVpd/xMQyP6X6McYd6ArdYoPMiv/zlOTzeXXfQbYxBNJJBF6tThvsjLMbA8tLjkCdm9RWMFCCw==" 322 | }, 323 | "anymatch": { 324 | "version": "3.1.2", 325 | "resolved": "https://registry.npmjs.org/anymatch/-/anymatch-3.1.2.tgz", 326 | "integrity": "sha512-P43ePfOAIupkguHUycrc4qJ9kz8ZiuOUijaETwX7THt0Y/GNK7v0aa8rY816xWjZ7rJdA5XdMcpVFTKMq+RvWg==", 327 | "requires": { 328 | "normalize-path": "^3.0.0", 329 | "picomatch": "^2.0.4" 330 | } 331 | }, 332 | "arg": { 333 | "version": "4.1.3", 334 | "resolved": "https://registry.npmjs.org/arg/-/arg-4.1.3.tgz", 335 | "integrity": "sha512-58S9QDqG0Xx27YwPSt9fJxivjYl432YCwfDMfZ+71RAqUrZef7LrKQZ3LHLOwCS4FLNBplP533Zx895SeOCHvA==" 336 | }, 337 | "binary-extensions": { 338 | "version": "2.2.0", 339 | "resolved": "https://registry.npmjs.org/binary-extensions/-/binary-extensions-2.2.0.tgz", 340 | "integrity": "sha512-jDctJ/IVQbZoJykoeHbhXpOlNBqGNcwXJKJog42E5HDPUwQTSdjCHdihjj0DlnheQ7blbT6dHOafNAiS8ooQKA==" 341 | }, 342 | "braces": { 343 | "version": "3.0.2", 344 | "resolved": "https://registry.npmjs.org/braces/-/braces-3.0.2.tgz", 345 | "integrity": "sha512-b8um+L1RzM3WDSzvhm6gIz1yfTbBt6YTlcEKAvsmqCZZFw46z626lVj9j1yEPW33H5H+lBQpZMP1k8l+78Ha0A==", 346 | "requires": { 347 | "fill-range": "^7.0.1" 348 | } 349 | }, 350 | "chokidar": { 351 | "version": "3.5.3", 352 | "resolved": "https://registry.npmjs.org/chokidar/-/chokidar-3.5.3.tgz", 353 | "integrity": "sha512-Dr3sfKRP6oTcjf2JmUmFJfeVMvXBdegxB0iVQ5eb2V10uFJUCAS8OByZdVAyVb8xXNz3GjjTgj9kLWsZTqE6kw==", 354 | "requires": { 355 | "anymatch": "~3.1.2", 356 | "braces": "~3.0.2", 
357 | "fsevents": "~2.3.2", 358 | "glob-parent": "~5.1.2", 359 | "is-binary-path": "~2.1.0", 360 | "is-glob": "~4.0.1", 361 | "normalize-path": "~3.0.0", 362 | "readdirp": "~3.6.0" 363 | } 364 | }, 365 | "cross-spawn": { 366 | "version": "7.0.3", 367 | "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.3.tgz", 368 | "integrity": "sha512-iRDPJKUPVEND7dHPO8rkbOnPpyDygcDFtWjpeWNCgy8WP2rXcxXL8TskReQl6OrB2G7+UJrags1q15Fudc7G6w==", 369 | "requires": { 370 | "path-key": "^3.1.0", 371 | "shebang-command": "^2.0.0", 372 | "which": "^2.0.1" 373 | } 374 | }, 375 | "fill-range": { 376 | "version": "7.0.1", 377 | "resolved": "https://registry.npmjs.org/fill-range/-/fill-range-7.0.1.tgz", 378 | "integrity": "sha512-qOo9F+dMUmC2Lcb4BbVvnKJxTPjCm+RRpe4gDuGrzkL7mEVl/djYSu2OdQ2Pa302N4oqkSg9ir6jaLWJ2USVpQ==", 379 | "requires": { 380 | "to-regex-range": "^5.0.1" 381 | } 382 | }, 383 | "fsevents": { 384 | "version": "2.3.2", 385 | "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", 386 | "integrity": "sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==", 387 | "optional": true 388 | }, 389 | "glob-parent": { 390 | "version": "5.1.2", 391 | "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", 392 | "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", 393 | "requires": { 394 | "is-glob": "^4.0.1" 395 | } 396 | }, 397 | "ignore": { 398 | "version": "5.2.0", 399 | "resolved": "https://registry.npmjs.org/ignore/-/ignore-5.2.0.tgz", 400 | "integrity": "sha512-CmxgYGiEPCLhfLnpPp1MoRmifwEIOgjcHXxOBjv7mY96c+eWScsOP9c112ZyLdWHi0FxHjI+4uVhKYp/gcdRmQ==" 401 | }, 402 | "is-binary-path": { 403 | "version": "2.1.0", 404 | "resolved": "https://registry.npmjs.org/is-binary-path/-/is-binary-path-2.1.0.tgz", 405 | "integrity": "sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==", 406 | "requires": { 407 | "binary-extensions": "^2.0.0" 408 | } 409 | }, 410 | "is-extglob": { 411 | "version": "2.1.1", 412 | "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", 413 | "integrity": "sha1-qIwCU1eR8C7TfHahueqXc8gz+MI=" 414 | }, 415 | "is-glob": { 416 | "version": "4.0.3", 417 | "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", 418 | "integrity": "sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==", 419 | "requires": { 420 | "is-extglob": "^2.1.1" 421 | } 422 | }, 423 | "is-number": { 424 | "version": "7.0.0", 425 | "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", 426 | "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==" 427 | }, 428 | "isexe": { 429 | "version": "2.0.0", 430 | "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", 431 | "integrity": "sha1-6PvzdNxVb/iUehDcsFctYz8s+hA=" 432 | }, 433 | "normalize-path": { 434 | "version": "3.0.0", 435 | "resolved": "https://registry.npmjs.org/normalize-path/-/normalize-path-3.0.0.tgz", 436 | "integrity": "sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==" 437 | }, 438 | "onchange": { 439 | "version": "7.1.0", 440 | "resolved": "https://registry.npmjs.org/onchange/-/onchange-7.1.0.tgz", 441 | "integrity": "sha512-ZJcqsPiWUAUpvmnJri5TPBooqJOPmC0ttN65juhN15Q8xA+Nbg3BaxBHXQ45EistKKlKElb0edmbPWnKSBkvMg==", 442 | "requires": { 443 | "@blakeembrey/deque": "^1.0.5", 
444 | "@blakeembrey/template": "^1.0.0", 445 | "arg": "^4.1.3", 446 | "chokidar": "^3.3.1", 447 | "cross-spawn": "^7.0.1", 448 | "ignore": "^5.1.4", 449 | "tree-kill": "^1.2.2" 450 | } 451 | }, 452 | "path-key": { 453 | "version": "3.1.1", 454 | "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", 455 | "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==" 456 | }, 457 | "picomatch": { 458 | "version": "2.3.1", 459 | "resolved": "https://registry.npmjs.org/picomatch/-/picomatch-2.3.1.tgz", 460 | "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==" 461 | }, 462 | "prettier": { 463 | "version": "2.6.2", 464 | "resolved": "https://registry.npmjs.org/prettier/-/prettier-2.6.2.tgz", 465 | "integrity": "sha512-PkUpF+qoXTqhOeWL9fu7As8LXsIUZ1WYaJiY/a7McAQzxjk82OF0tibkFXVCDImZtWxbvojFjerkiLb0/q8mew==", 466 | "dev": true 467 | }, 468 | "readdirp": { 469 | "version": "3.6.0", 470 | "resolved": "https://registry.npmjs.org/readdirp/-/readdirp-3.6.0.tgz", 471 | "integrity": "sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==", 472 | "requires": { 473 | "picomatch": "^2.2.1" 474 | } 475 | }, 476 | "shebang-command": { 477 | "version": "2.0.0", 478 | "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", 479 | "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", 480 | "requires": { 481 | "shebang-regex": "^3.0.0" 482 | } 483 | }, 484 | "shebang-regex": { 485 | "version": "3.0.0", 486 | "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", 487 | "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==" 488 | }, 489 | "to-regex-range": { 490 | "version": "5.0.1", 491 | "resolved": "https://registry.npmjs.org/to-regex-range/-/to-regex-range-5.0.1.tgz", 492 | "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", 493 | "requires": { 494 | "is-number": "^7.0.0" 495 | } 496 | }, 497 | "tree-kill": { 498 | "version": "1.2.2", 499 | "resolved": "https://registry.npmjs.org/tree-kill/-/tree-kill-1.2.2.tgz", 500 | "integrity": "sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==" 501 | }, 502 | "which": { 503 | "version": "2.0.2", 504 | "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", 505 | "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", 506 | "requires": { 507 | "isexe": "^2.0.0" 508 | } 509 | } 510 | } 511 | } 512 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "causalimpact", 3 | "version": "1.0.0", 4 | "description": "", 5 | "main": "index.js", 6 | "scripts": { 7 | "test": "pytest", 8 | "format": "prettier *.md --write", 9 | "watch": "onchange 'src/**/*.py' 'tests/**/*.py' -- pytest" 10 | }, 11 | "keywords": [], 12 | "author": "", 13 | "license": "ISC", 14 | "devDependencies": { 15 | "prettier": "^2.6.2", 16 | "onchange": "^7.1.0" 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 
2 | # AVOID CHANGING REQUIRES: IT WILL BE UPDATED BY PYSCAFFOLD! 3 | requires = ["setuptools>=46.1.0", "setuptools_scm[toml]>=5", "wheel"] 4 | build-backend = "setuptools.build_meta" 5 | 6 | [tool.setuptools_scm] 7 | # For smarter version schemes and other configuration options, 8 | # check out https://github.com/pypa/setuptools_scm 9 | version_scheme = "no-guess-dev" 10 | 11 | [tool.commitizen] 12 | name = "cz_conventional_commits" 13 | version = "0.2.5" 14 | tag_format = "v$version" 15 | -------------------------------------------------------------------------------- /renovate.json: -------------------------------------------------------------------------------- 1 | { 2 | "extends": ["config:base"] 3 | } 4 | -------------------------------------------------------------------------------- /requirements.in: -------------------------------------------------------------------------------- 1 | pandas 2 | numpy 3 | statsmodels 4 | matplotlib 5 | pymc 6 | pytensor 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.10 3 | # by the following command: 4 | # 5 | # pip-compile 6 | # 7 | arviz==0.14.0 8 | # via pymc 9 | cachetools==5.2.0 10 | # via pymc 11 | cftime==1.6.2 12 | # via netcdf4 13 | cloudpickle==2.2.0 14 | # via pymc 15 | cons==0.4.5 16 | # via 17 | # etuples 18 | # minikanren 19 | # pytensor 20 | contourpy==1.0.6 21 | # via matplotlib 22 | cycler==0.11.0 23 | # via matplotlib 24 | etuples==0.3.8 25 | # via 26 | # minikanren 27 | # pytensor 28 | fastprogress==1.0.3 29 | # via pymc 30 | filelock==3.9.0 31 | # via pytensor 32 | fonttools==4.38.0 33 | # via matplotlib 34 | kiwisolver==1.4.4 35 | # via matplotlib 36 | logical-unification==0.4.5 37 | # via 38 | # cons 39 | # minikanren 40 | # pytensor 41 | matplotlib==3.6.2 42 | # via 43 | # -r requirements.in 44 | # arviz 45 | minikanren==1.0.3 46 | # via pytensor 47 | multipledispatch==0.6.0 48 | # via 49 | # etuples 50 | # logical-unification 51 | # minikanren 52 | netcdf4==1.6.2 53 | # via arviz 54 | numpy==1.24.1 55 | # via 56 | # -r requirements.in 57 | # arviz 58 | # cftime 59 | # contourpy 60 | # matplotlib 61 | # netcdf4 62 | # pandas 63 | # patsy 64 | # pymc 65 | # pytensor 66 | # scipy 67 | # statsmodels 68 | # xarray 69 | # xarray-einstats 70 | packaging==22.0 71 | # via 72 | # arviz 73 | # matplotlib 74 | # statsmodels 75 | # xarray 76 | pandas==1.5.2 77 | # via 78 | # -r requirements.in 79 | # arviz 80 | # pymc 81 | # statsmodels 82 | # xarray 83 | patsy==0.5.3 84 | # via statsmodels 85 | pillow==10.0.1 86 | # via matplotlib 87 | pymc==5.0.1 88 | # via -r requirements.in 89 | pyparsing==3.0.9 90 | # via matplotlib 91 | pytensor==2.8.11 92 | # via 93 | # -r requirements.in 94 | # pymc 95 | python-dateutil==2.8.2 96 | # via 97 | # matplotlib 98 | # pandas 99 | pytz==2022.7 100 | # via pandas 101 | scipy==1.10.0 102 | # via 103 | # arviz 104 | # pymc 105 | # pytensor 106 | # statsmodels 107 | # xarray-einstats 108 | six==1.16.0 109 | # via 110 | # multipledispatch 111 | # patsy 112 | # python-dateutil 113 | statsmodels==0.13.5 114 | # via -r requirements.in 115 | toolz==0.12.0 116 | # via 117 | # logical-unification 118 | # minikanren 119 | typing-extensions==4.4.0 120 | # via 121 | # arviz 122 | # pymc 123 | # pytensor 124 | xarray==2022.12.0 125 | # via 126 | # arviz 127 | # xarray-einstats 128 | xarray-einstats==0.4.0 129 | # via 
arviz 130 | 131 | # The following packages are considered to be unsafe in a requirements file: 132 | # setuptools 133 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = causalimpact 3 | description = Python Package for causal inference using Bayesian structural time-series models 4 | author = Jamal Senouci 5 | author_email = jamalsenouci@gmail.com 6 | license = MIT 7 | license_files = LICENSE.md 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown; charset=UTF-8 10 | url = https://github.com/jamalsenouci/causalimpact/ 11 | project_urls = 12 | Documentation = https://nbviewer.org/github/jamalsenouci/causalimpact/blob/master/GettingStarted.ipynb 13 | Source = https://github.com/jamalsenouci/causalimpact/ 14 | Changelog = https://github.com/jamalsenouci/causalimpact/CHANGELOG.md 15 | Download = https://pypi.python.org/pypi/causalimpact/ 16 | platforms = any 17 | classifiers = 18 | Development Status :: 4 - Beta 19 | Programming Language :: Python 20 | 21 | [options] 22 | zip_safe = False 23 | packages = find_namespace: 24 | include_package_data = True 25 | package_dir = 26 | =src 27 | python_requires = >=3.7 28 | install_requires = 29 | pandas 30 | numpy 31 | statsmodels 32 | matplotlib 33 | pymc 34 | pytensor 35 | importlib-metadata; python_version<"3.8" 36 | 37 | [options.packages.find] 38 | where = src 39 | exclude = 40 | tests 41 | 42 | [options.extras_require] 43 | testing = 44 | setuptools 45 | pytest 46 | pytest-cov 47 | 48 | [options.entry_points] 49 | 50 | [tool:pytest] 51 | addopts = 52 | --cov causalimpact --cov-report term-missing 53 | --verbose 54 | norecursedirs = 55 | dist 56 | build 57 | .tox 58 | testpaths = tests 59 | 60 | [devpi:upload] 61 | no_vcs = 1 62 | formats = bdist_wheel 63 | 64 | [flake8] 65 | max_line_length = 88 66 | extend_ignore = E203, W503 67 | exclude = 68 | .tox 69 | build 70 | dist 71 | .eggs 72 | docs/conf.py 73 | 74 | [pyscaffold] 75 | version = 4.2.1 76 | package = causalimpact 77 | extensions = 78 | no_skeleton 79 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Setup file for causalimpact. 3 | Use setup.cfg to configure your project. 4 | This file was generated with PyScaffold 4.1.1. 5 | PyScaffold helps you to put up the scaffold of your new Python project. 6 | Learn more under: https://pyscaffold.org/ 7 | """ 8 | 9 | from setuptools import setup 10 | 11 | if __name__ == "__main__": 12 | try: 13 | setup(use_scm_version={"version_scheme": "no-guess-dev"}) 14 | except: # noqa 15 | print( 16 | "\n\nAn error occurred while building the project, " 17 | "please ensure you have the most updated version of setuptools, " 18 | "setuptools_scm and wheel with:\n" 19 | " pip install -U setuptools setuptools_scm wheel\n\n" 20 | ) 21 | raise 22 | -------------------------------------------------------------------------------- /src/causalimpact/__init__.py: -------------------------------------------------------------------------------- 1 | """Causal Impact. 2 | 3 | A Python package for causal inference using Bayesian structural time-series 4 | models.
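A minimal, illustrative session (the period bounds are placeholders and
``data`` is any DataFrame with the response in its first column and any
covariates in the columns after it):

    from causalimpact import CausalImpact

    impact = CausalImpact(data, pre_period=[0, 69], post_period=[70, 99])
    impact.run()
    impact.summary()          # tabular summary
    impact.summary("report")  # natural-language description
    impact.plot()
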
It's a port of the R package CausalImpact, 5 | see https://github.com/google/CausalImpact 6 | 7 | """ 8 | 9 | import sys 10 | 11 | if sys.version_info[:2] >= (3, 8): 12 | # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8` 13 | from importlib.metadata import PackageNotFoundError, version # pragma: no cover 14 | else: 15 | from importlib_metadata import PackageNotFoundError, version # pragma: no cover 16 | 17 | try: 18 | # Change here if project is renamed and does not equal the package name 19 | dist_name = __name__ 20 | __version__ = version(dist_name) 21 | except PackageNotFoundError: # pragma: no cover 22 | __version__ = "unknown" 23 | finally: 24 | del version, PackageNotFoundError 25 | 26 | from causalimpact.analysis import CausalImpact # noqa 27 | 28 | __all__ = [ 29 | "CausalImpact", 30 | ] 31 | -------------------------------------------------------------------------------- /src/causalimpact/analysis.py: -------------------------------------------------------------------------------- 1 | import textwrap 2 | import numpy as np 3 | import pandas as pd 4 | from pandas.api.types import is_list_like, is_datetime64_dtype 5 | 6 | from causalimpact.misc import standardize_all_variables, df_print, get_matplotlib 7 | from causalimpact.model import construct_model, model_fit 8 | from causalimpact.inferences import compile_inferences 9 | import scipy.stats as st 10 | 11 | 12 | class CausalImpact: 13 | """CausalImpact() performs causal inference through counterfactual 14 | predictions using a Bayesian structural time-series model. 15 | 16 | Parameters 17 | ---------- 18 | data : pandas dataframe 19 | the response variable must be in the first column, and any covariates 20 | in subsequent columns. 21 | pre_period : list 22 | A list specifying the first and the last time point of the 23 | pre-intervention period in the response column. This period can be 24 | thought of as a training period, used to determine the relationship 25 | between the response variable and the covariates. 26 | post_period : list 27 | A list specifying the first and the last time point of the post-intervention 28 | period we wish to study. This is the period after the intervention has 29 | begun whose effect we are interested in. The relationship between 30 | response variable and covariates, as determined during the pre-period, 31 | will be used to predict how the response variable should have evolved 32 | during the post-period had no intervention taken place. 33 | model_args : dict 34 | Optional arguments that can be used to adjust the default construction 35 | of the state-space model used for inference. 36 | For full control over the model, you can construct your own model using 37 | the statsmodels package and feed the model into CausalImpact(). 38 | ucm_model : statsmodels.tsa.statespace.structural.UnobservedComponents 39 | Instead of passing in data and having CausalImpact construct a 40 | model, it is possible to construct a model yourself using the 41 | statsmodels package. In this case, omit data, pre_period, and 42 | post_period. Instead only pass in ucm_model, post_period_response, 43 | alpha (optional). 44 | The model must have been fitted on data where the response variable was 45 | set to np.nan during the post-treatment period. The actual observed data 46 | during this period must then be passed to the function in 47 | post_period_response. 48 | post_period_response : list | pd.Series | np.ndarray 49 | Actual observed data during the post-intervention period.
This is required 48 | if and only if a fitted ucm_model is passed instead of data. 49 | alpha : float 50 | Desired tail-area probability for posterior intervals. Defaults to 0.05, 51 | which will produce central 95% intervals. 52 | 53 | 54 | Returns 55 | ------- 56 | CausalImpact Object 57 | 58 | """ 59 | 60 | def __init__( 61 | self, 62 | data=None, 63 | pre_period=None, 64 | post_period=None, 65 | model_args=None, 66 | ucm_model=None, 67 | post_period_response=None, 68 | alpha=0.05, 69 | estimation="MLE", 70 | ): 71 | self.series = None 72 | self.model = {} 73 | if isinstance(data, pd.DataFrame): 74 | self.data = data.copy() 75 | else: 76 | self.data = data 77 | self.params = { 78 | "data": data, 79 | "pre_period": pre_period, 80 | "post_period": post_period, 81 | "model_args": model_args, 82 | "ucm_model": ucm_model, 83 | "post_period_response": post_period_response, 84 | "alpha": alpha, 85 | "estimation": estimation, 86 | } 87 | self.inferences = None 88 | self.results = None 89 | 90 | def run(self): 91 | kwargs = self._format_input( 92 | self.params["data"], 93 | self.params["pre_period"], 94 | self.params["post_period"], 95 | self.params["model_args"], 96 | self.params["ucm_model"], 97 | self.params["post_period_response"], 98 | self.params["alpha"], 99 | ) 100 | 101 | # Depending on input, dispatch to the appropriate Run* method() 102 | if self.data is not None: 103 | self._run_with_data( 104 | kwargs["data"], 105 | kwargs["pre_period"], 106 | kwargs["post_period"], 107 | kwargs["model_args"], 108 | kwargs["alpha"], 109 | self.params["estimation"], 110 | ) 111 | else: 112 | self._run_with_ucm( 113 | kwargs["ucm_model"], 114 | kwargs["post_period_response"], 115 | kwargs["alpha"], 116 | kwargs["model_args"], 117 | self.params["estimation"], 118 | ) 119 | 120 | @staticmethod 121 | def _format_input_data(data): 122 | """Check and format the data argument provided to CausalImpact(). 
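Accepts anything convertible to a pandas DataFrame; a first column named
'date' or 'time' is promoted to the index.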
123 | 124 | Args: 125 | data: Pandas DataFrame 126 | 127 | Returns: 128 | correctly formatted Pandas DataFrame 129 | """ 130 | # If <data> is a Pandas DataFrame and the first column is 'date', 131 | # try to convert 132 | 133 | if ( 134 | isinstance(data, pd.DataFrame) 135 | and isinstance(data.columns[0], str) 136 | and data.columns[0].lower() in ["date", "time"] 137 | ): 138 | data = data.set_index(data.columns[0]) 139 | 140 | # Try to convert to Pandas DataFrame 141 | try: 142 | data = pd.DataFrame(data) 143 | except ValueError: 144 | raise ValueError("could not convert input data to Pandas " + "DataFrame") 145 | 146 | # Must have at least 3 time points 147 | if len(data.index) < 3: 148 | raise ValueError("data must have at least 3 time points") 149 | 150 | # Must not have NA in covariates (if any) 151 | if len(data.columns) >= 2 and pd.isnull(data.iloc[:, 1:]).any(axis=None): 152 | raise ValueError("covariates must not contain null values") 153 | 154 | return data 155 | 156 | @staticmethod 157 | def _check_periods_are_valid(pre_period, post_period): 158 | if not isinstance(pre_period, list) or not isinstance(post_period, list): 159 | raise ValueError("pre_period and post_period must both be lists") 160 | if len(pre_period) != 2 or len(post_period) != 2: 161 | raise ValueError("pre_period and post_period must both be of " + "length 2") 162 | if pd.isnull(pre_period).any(axis=None) or pd.isnull(post_period).any( 163 | axis=None 164 | ): 165 | raise ValueError( 166 | "pre_period and post period must not contain " + "null values" 167 | ) 168 | 169 | @staticmethod 170 | def _align_periods_dtypes(pre_period, post_period, data): 171 | """Align the dtypes of the pre_period and post_period to the data index. 172 | 173 | Args: 174 | pre_period: two-element list 175 | post_period: two-element list 176 | data: already-checked Pandas DataFrame, for reference only 177 | """ 178 | pre_dtype = np.array(pre_period).dtype 179 | post_dtype = np.array(post_period).dtype 180 | # if index is datetime then convert pre and post to datetimes 181 | if isinstance(data.index, pd.core.indexes.datetimes.DatetimeIndex): 182 | pre_period = [pd.to_datetime(date) for date in pre_period] 183 | post_period = [pd.to_datetime(date) for date in post_period] 184 | 185 | # if index is not datetime then error if datetime pre and post is passed 186 | elif is_datetime64_dtype(pd.Series(pre_period)) or is_datetime64_dtype( 187 | pd.Series(post_period) 188 | ): 189 | raise ValueError( 190 | "pre_period (" 191 | + pre_dtype.name 192 | + ") and post_period (" 193 | + post_dtype.name 194 | + ") should have the same class as the " 195 | + "time points in the data (" 196 | + data.index.dtype.name 197 | + ")" 198 | ) 199 | # if index is int 200 | elif pd.api.types.is_int64_dtype(data.index): 201 | pre_period = [int(elem) for elem in pre_period] 202 | post_period = [int(elem) for elem in post_period] 203 | # if index is float 204 | elif pd.api.types.is_float_dtype(data.index): 205 | pre_period = [float(elem) for elem in pre_period] 206 | post_period = [float(elem) for elem in post_period] 207 | # if index is string 208 | elif pd.api.types.is_string_dtype(data.index): 209 | if pd.api.types.is_numeric_dtype( 210 | np.array(pre_period) 211 | ) or pd.api.types.is_numeric_dtype(np.array(post_period)): 212 | raise ValueError( 213 | "pre_period (" 214 | + pre_dtype.name 215 | + ") and post_period (" 216 | + post_dtype.name 217 | + ") should have the same class as the " 218 | + "time points in the data (" 219 | +
data.index.dtype.name 220 | + ")" 221 | ) 222 | else: 223 | pre_period = [str(idx) for idx in pre_period] 224 | post_period = [str(idx) for idx in post_period] 225 | else: 226 | raise ValueError( 227 | "pre_period (" 228 | + pre_dtype.name 229 | + ") and post_period (" 230 | + post_dtype.name 231 | + ") should have the same class as the " 232 | + "time points in the data (" 233 | + data.index.dtype.name 234 | + ")" 235 | ) 236 | return [pre_period, post_period] 237 | 238 | def _format_input_prepost(self, pre_period, post_period, data): 239 | """Check and format the pre_period and post_period input arguments. 240 | 241 | Args: 242 | pre_period: two-element list 243 | post_period: two-element list 244 | data: already-checked Pandas DataFrame, for reference only 245 | """ 246 | self._check_periods_are_valid(pre_period, post_period) 247 | 248 | pre_period, post_period = self._align_periods_dtypes( 249 | pre_period, post_period, data 250 | ) 251 | 252 | if pre_period[1] > post_period[0]: 253 | raise ValueError( 254 | "post period must start at least 1 observation" 255 | + " after the end of the pre_period" 256 | ) 257 | 258 | if isinstance(data.index, pd.RangeIndex): 259 | loc3 = post_period[0] 260 | loc4 = post_period[1] 261 | else: 262 | loc3 = data.index.get_loc(post_period[0]) 263 | loc4 = data.index.get_loc(post_period[1]) 264 | 265 | if loc4 < loc3: 266 | raise ValueError( 267 | "post_period[1] must not be earlier than " + "post_period[0]" 268 | ) 269 | 270 | if pre_period[0] < data.index.min(): 271 | pre_period[0] = data.index.min() 272 | if post_period[1] > data.index.max(): 273 | post_period[1] = data.index.max() 274 | return {"pre_period": pre_period, "post_period": post_period} 275 | 276 | @staticmethod 277 | def _check_valid_args_combo(args): 278 | data_model_args = [True, True, True, False, False] 279 | ucm_model_args = [False, False, False, True, True] 280 | 281 | if np.any(pd.isnull(args) != data_model_args) and np.any( 282 | pd.isnull(args) != ucm_model_args 283 | ): 284 | raise SyntaxError( 285 | "Must either provide ``data``, ``pre_period``," 286 | + " ``post_period``, ``model_args``" 287 | " or ``ucm_model``" + " and ``post_period_response``" 288 | ) 289 | 290 | @staticmethod 291 | def _check_valid_alpha(alpha): 292 | if alpha is None: 293 | raise ValueError("alpha must not be None") 294 | if not np.isreal(alpha): 295 | raise ValueError("alpha must be a real number") 296 | if np.isnan(alpha): 297 | raise ValueError("alpha must not be NA") 298 | if alpha <= 0 or alpha >= 1: 299 | raise ValueError("alpha must be between 0 and 1") 300 | 301 | def _format_input( 302 | self, 303 | data, 304 | pre_period, 305 | post_period, 306 | model_args, 307 | ucm_model, 308 | post_period_response, 309 | alpha, 310 | ): 311 | """Check and format all input arguments supplied to CausalImpact().
312 | See the documentation of CausalImpact() for details. 313 | 314 | Args: 315 | data: Pandas DataFrame or data frame 316 | pre_period: beginning and end of pre-period 317 | post_period: beginning and end of post-period 318 | model_args: dict of additional arguments for the model 319 | ucm_model: UnobservedComponents model (instead of data) 320 | post_period_response: observed response in the post-period 321 | alpha: tail-area for posterior intervals 322 | 323 | 324 | Returns: 325 | dict of checked (and possibly reformatted) input arguments 326 | """ 327 | from statsmodels.tsa.statespace.structural import UnobservedComponents 328 | 329 | # Check that a consistent set of variables has been provided 330 | args = [data, pre_period, post_period, ucm_model, post_period_response] 331 | 332 | self._check_valid_args_combo(args) 333 | 334 | # Check and convert <data> to Pandas DataFrame, with rows 335 | # representing time points 336 | if data is not None: 337 | data = self._format_input_data(data) 338 | 339 | # Check <pre_period> and <post_period> 340 | if data is not None: 341 | checked = self._format_input_prepost(pre_period, post_period, data) 342 | pre_period = checked["pre_period"] 343 | post_period = checked["post_period"] 344 | self.params["pre_period"] = pre_period 345 | self.params["post_period"] = post_period 346 | 347 | # Parse <model_args>, fill gaps using <_defaults> 348 | 349 | _defaults = { 350 | "ndraws": 1000, 351 | "nburn": 100, 352 | "niter": 1000, 353 | "standardize_data": True, 354 | "prior_level_sd": 0.01, 355 | "nseasons": 1, 356 | "season_duration": 1, 357 | "dynamic_regression": False, 358 | } 359 | 360 | if model_args is None: 361 | model_args = _defaults 362 | else: 363 | missing = [key for key in _defaults if key not in model_args] 364 | for arg in missing: 365 | model_args[arg] = _defaults[arg] 366 | 367 | # Check <standardize_data> 368 | if not isinstance(model_args["standardize_data"], bool): 369 | raise ValueError("model_args.standardize_data must be a" + " boolean value") 370 | 371 | # Check <ucm_model> 372 | if ucm_model is not None and not isinstance(ucm_model, UnobservedComponents): 373 | raise ValueError( 374 | "ucm_model must be an object of class " 375 | "statsmodels.tsa.statespace.structural.UnobservedComponents " 376 | "instead received " + str(type(ucm_model))[8:-2] 377 | ) 378 | 379 | # Check <post_period_response> 380 | if ucm_model is not None: 381 | if not is_list_like(post_period_response): 382 | raise ValueError("post_period_response must be list-like") 383 | if np.array(post_period_response).dtype.num == 17: 384 | raise ValueError( 385 | "post_period_response should not be" + " datetime values" 386 | ) 387 | if not np.all(np.isreal(post_period_response)): 388 | raise ValueError( 389 | "post_period_response must contain all" + " real values" 390 | ) 391 | 392 | # Check <alpha> 393 | self._check_valid_alpha(alpha) 394 | 395 | # Return updated arguments 396 | kwargs = { 397 | "data": data, 398 | "pre_period": pre_period, 399 | "post_period": post_period, 400 | "model_args": model_args, 401 | "ucm_model": ucm_model, 402 | "post_period_response": post_period_response, 403 | "alpha": alpha, 404 | } 405 | return kwargs 406 | 407 | def _run_with_data( 408 | self, data, pre_period, post_period, model_args, alpha, estimation 409 | ): 410 | # Zoom in on data in modeling range 411 | if data.shape[1] == 1: # no exogenous values provided 412 | raise ValueError("data contains no exogenous variables") 413 | data_modeling = data.copy() 414 | 415 | df_pre = data_modeling.loc[pre_period[0] : pre_period[1], :] 416 | df_post
= data_modeling.loc[post_period[0] : post_period[1], :] 417 | 418 | # Standardize all variables 419 | orig_std_params = (0, 1) 420 | if model_args["standardize_data"]: 421 | sd_results = standardize_all_variables( 422 | data_modeling, pre_period, post_period 423 | ) 424 | df_pre = sd_results["data_pre"] 425 | df_post = sd_results["data_post"] 426 | orig_std_params = sd_results["orig_std_params"] 427 | 428 | # Construct model and perform inference 429 | model = construct_model(df_pre, model_args) 430 | self.model = model 431 | 432 | model_results = model_fit(model, estimation, model_args) 433 | 434 | inferences = compile_inferences( 435 | model_results, 436 | data, 437 | df_pre, 438 | df_post, 439 | None, 440 | alpha, 441 | orig_std_params, 442 | estimation, 443 | ) 444 | 445 | # "append" to 'CausalImpact' object 446 | self.inferences = inferences["series"] 447 | self.results = model_results 448 | 449 | def _run_with_ucm( 450 | self, ucm_model, post_period_response, alpha, model_args, estimation 451 | ): 452 | """Runs an impact analysis on top of a ucm model. 453 | 454 | Args: 455 | ucm_model: Model as returned by UnobservedComponents(), 456 | in which the data during the post-period was set to NA 457 | post_period_response: observed data during the post-intervention 458 | period 459 | alpha: tail-probabilities of posterior intervals""" 460 | 461 | df_pre = ucm_model.data.orig_endog[: -len(post_period_response)] 462 | df_pre = pd.DataFrame(df_pre) 463 | 464 | post_period_response = pd.DataFrame(post_period_response) 465 | 466 | data = pd.DataFrame( 467 | np.concatenate([df_pre.values, post_period_response.values]) 468 | ) 469 | 470 | orig_std_params = (0, 1) 471 | 472 | model_results = model_fit(ucm_model, estimation, model_args) 473 | 474 | # Compile posterior inferences 475 | inferences = compile_inferences( 476 | model_results, 477 | data, 478 | df_pre, 479 | None, 480 | post_period_response, 481 | alpha, 482 | orig_std_params, 483 | estimation, 484 | ) 485 | 486 | obs_inter = model_results.model_nobs - len(post_period_response) 487 | 488 | self.params["pre_period"] = [0, obs_inter - 1] 489 | self.params["post_period"] = [obs_inter, -1] 490 | self.data = pd.concat([df_pre, post_period_response]) 491 | self.inferences = inferences["series"] 492 | self.results = model_results 493 | 494 | @staticmethod 495 | def _print_report( 496 | mean_pred_fmt, 497 | mean_resp_fmt, 498 | mean_lower_fmt, 499 | mean_upper_fmt, 500 | abs_effect_fmt, 501 | abs_effect_upper_fmt, 502 | abs_effect_lower_fmt, 503 | rel_effect_fmt, 504 | rel_effect_upper_fmt, 505 | rel_effect_lower_fmt, 506 | cum_resp_fmt, 507 | cum_pred_fmt, 508 | cum_lower_fmt, 509 | cum_upper_fmt, 510 | confidence, 511 | cum_rel_effect_lower, 512 | cum_rel_effect_upper, 513 | cum_rel_effect, 514 | width, 515 | p_value, 516 | alpha, 517 | ): 518 | sig = not (cum_rel_effect_lower < 0 < cum_rel_effect_upper) 519 | pos = cum_rel_effect > 0 520 | # Summarize averages 521 | stmt = textwrap.dedent( 522 | """During the post-intervention period, the response 523 | variable had an average value of 524 | approx. {mean_resp}. 525 | """.format( 526 | mean_resp=mean_resp_fmt 527 | ) 528 | ) 529 | if sig: 530 | stmt += " By contrast, in " 531 | else: 532 | stmt += " In " 533 | 534 | stmt += textwrap.dedent( 535 | """ 536 | the absence of an intervention, we would have 537 | expected an average response of {mean_pred}. The 538 | {confidence} interval of this counterfactual 539 | prediction is [{mean_lower}, {mean_upper}]. 
540 | Subtracting this prediction from the observed 541 | response yields an estimate of the causal effect 542 | the intervention had on the response variable. 543 | This effect is {abs_effect} with a 544 | {confidence} interval of [{abs_lower}, 545 | {abs_upper}]. For a discussion of the 546 | significance of this effect, 547 | see below. 548 | """.format( 549 | mean_pred=mean_pred_fmt, 550 | confidence=confidence, 551 | mean_lower=mean_lower_fmt, 552 | mean_upper=mean_upper_fmt, 553 | abs_effect=abs_effect_fmt, 554 | abs_upper=abs_effect_upper_fmt, 555 | abs_lower=abs_effect_lower_fmt, 556 | ) 557 | ) 558 | # Summarize sums 559 | stmt2 = textwrap.dedent( 560 | """ 561 | Summing up the individual data points during the 562 | post-intervention period (which can only sometimes be 563 | meaningfully interpreted), the response variable had an 564 | overall value of {cum_resp}. 565 | """.format( 566 | cum_resp=cum_resp_fmt 567 | ) 568 | ) 569 | if sig: 570 | stmt2 += " By contrast, had " 571 | else: 572 | stmt2 += " Had " 573 | 574 | stmt2 += textwrap.dedent( 575 | """ 576 | the intervention not taken place, we would have expected 577 | a sum of {cum_pred}. The {confidence} interval of this 578 | prediction is [{cum_pred_lower}, {cum_pred_upper}] 579 | """.format( 580 | cum_pred=cum_pred_fmt, 581 | confidence=confidence, 582 | cum_pred_lower=cum_lower_fmt, 583 | cum_pred_upper=cum_upper_fmt, 584 | ) 585 | ) 586 | 587 | # Summarize relative numbers (in which case row [1] = row [2]) 588 | stmt3 = textwrap.dedent( 589 | """ 590 | The above results are given in terms 591 | of absolute numbers. In relative terms, the 592 | response variable showed 593 | """ 594 | ) 595 | if pos: 596 | stmt3 += " an increase of " 597 | else: 598 | stmt3 += " a decrease of " 599 | 600 | stmt3 += textwrap.dedent( 601 | """ 602 | {rel_effect}. The {confidence} interval of this 603 | percentage is [{rel_effect_lower}, 604 | {rel_effect_upper}] 605 | """.format( 606 | confidence=confidence, 607 | rel_effect=rel_effect_fmt, 608 | rel_effect_lower=rel_effect_lower_fmt, 609 | rel_effect_upper=rel_effect_upper_fmt, 610 | ) 611 | ) 612 | 613 | # Comment on significance 614 | if sig and pos: 615 | stmt4 = textwrap.dedent( 616 | """ 617 | This means that the positive effect observed 618 | during the intervention period is statistically 619 | significant and unlikely to be due to random 620 | fluctuations. It should be noted, however, that 621 | the question of whether this increase also bears 622 | substantive significance can only be answered by 623 | comparing the absolute effect {abs_effect} to 624 | the original goal of the underlying 625 | intervention. 626 | """.format( 627 | abs_effect=abs_effect_fmt 628 | ) 629 | ) 630 | elif sig and not pos: 631 | stmt4 = textwrap.dedent( 632 | """ 633 | This means that the negative effect observed 634 | during the intervention period is statistically 635 | significant. If the experimenter had expected a 636 | positive effect, it is recommended to double-check 637 | whether anomalies in the control variables may have 638 | caused an overly optimistic expectation of what 639 | should have happened in the response variable in the 640 | absence of the intervention. 641 | """ 642 | ) 643 | elif not sig and pos: 644 | stmt4 = textwrap.dedent( 645 | """ 646 | This means that, although the intervention 647 | appears to have caused a positive effect, this 648 | effect is not statistically significant when 649 | considering the post-intervention period as a whole. 
650 | Individual days or shorter stretches within the 651 | intervention period may of course still have had a 652 | significant effect, as indicated whenever the lower 653 | limit of the impact time series (lower plot) was 654 | above zero. 655 | """ 656 | ) 657 | elif not sig and not pos: 658 | stmt4 = textwrap.dedent( 659 | """ 660 | This means that, although it may look as though 661 | the intervention has exerted a negative effect on 662 | the response variable when considering the 663 | intervention period as a whole, this effect is not 664 | statistically significant, and so cannot be 665 | meaningfully interpreted. 666 | """ 667 | ) 668 | if not sig: 669 | stmt4 += textwrap.dedent( 670 | """ 671 | The apparent effect could be the result of random 672 | fluctuations that are unrelated to the intervention. 673 | This is often the case when the intervention period 674 | is very long and includes much of the time when the 675 | effect has already worn off. It can also be the case 676 | when the intervention period is too short to 677 | distinguish the signal from the noise. Finally, 678 | failing to find a significant effect can happen when 679 | there are not enough control variables or when these 680 | variables do not correlate well with the response 681 | variable during the learning period.""" 682 | ) 683 | if p_value < alpha: 684 | stmt5 = textwrap.dedent( 685 | """The probability of obtaining this effect by 686 | chance is very small (Bayesian one-sided tail-area 687 | probability {p}). This means the 688 | causal effect can be considered statistically 689 | significant.""".format( 690 | p=np.round(p_value, 3) 691 | ) 692 | ) 693 | else: 694 | stmt5 = textwrap.dedent( 695 | """The probability of obtaining this effect by 696 | chance is p = {p}. This means the effect may 697 | be spurious and would generally not be considered 698 | statistically significant.""".format(p=np.round(p_value, 3) 699 | ) 700 | ) 701 | 702 | print(textwrap.fill(stmt, width=width)) 703 | print("\n") 704 | print(textwrap.fill(stmt2, width=width)) 705 | print("\n") 706 | print(textwrap.fill(stmt3, width=width)) 707 | print("\n") 708 | print(textwrap.fill(stmt4, width=width)) 709 | print("\n") 710 | print(textwrap.fill(stmt5, width=width)) 711 | 712 | def summary(self, output="summary", width=120, path=None): 713 | """reports a summary of the results 714 | 715 | Parameters 716 | ---------- 717 | output: str 718 | can be summary or report. summary outputs a table. 719 | report outputs a natural language description of the 720 | findings 721 | width : int 722 | line width of the output. Only relevant if output == report 723 | path : str 724 | path to output summary to csv. 
Only relevant if output == summary 725 | 726 | """ 727 | alpha = self.params["alpha"] 728 | confidence = "{}%".format(int((1 - alpha) * 100)) 729 | post_period = self.params["post_period"] 730 | post_inf = self.inferences.loc[post_period[0] : post_period[1], :] 731 | post_point_resp = post_inf.loc[:, "response"] 732 | post_point_pred = post_inf.loc[:, "point_pred"] 733 | post_point_upper = post_inf.loc[:, "point_pred_upper"] 734 | post_point_lower = post_inf.loc[:, "point_pred_lower"] 735 | 736 | mean_resp = post_point_resp.mean() 737 | mean_resp_fmt = int(mean_resp) 738 | cum_resp = post_point_resp.sum() 739 | cum_resp_fmt = int(cum_resp) 740 | mean_pred = post_point_pred.mean() 741 | mean_pred_fmt = int(post_point_pred.mean()) 742 | cum_pred = post_point_pred.sum() 743 | cum_pred_fmt = int(cum_pred) 744 | mean_lower = post_point_lower.mean() 745 | mean_lower_fmt = int(mean_lower) 746 | mean_upper = post_point_upper.mean() 747 | mean_upper_fmt = int(mean_upper) 748 | mean_ci_fmt = [mean_lower_fmt, mean_upper_fmt] 749 | cum_lower = post_point_lower.sum() 750 | cum_lower_fmt = int(cum_lower) 751 | cum_upper = post_point_upper.sum() 752 | cum_upper_fmt = int(cum_upper) 753 | cum_ci_fmt = [cum_lower_fmt, cum_upper_fmt] 754 | 755 | abs_effect = (post_point_resp - post_point_pred).mean() 756 | abs_effect_fmt = int(abs_effect) 757 | cum_abs_effect = (post_point_resp - post_point_pred).sum() 758 | cum_abs_effect_fmt = int(cum_abs_effect) 759 | abs_effect_lower = (post_point_resp - post_point_lower).mean() 760 | abs_effect_lower_fmt = int(abs_effect_lower) 761 | abs_effect_upper = (post_point_resp - post_point_upper).mean() 762 | abs_effect_upper_fmt = int(abs_effect_upper) 763 | abs_effect_ci_fmt = [abs_effect_lower_fmt, abs_effect_upper_fmt] 764 | cum_abs_lower = (post_point_resp - post_point_lower).sum() 765 | cum_abs_lower_fmt = int(cum_abs_lower) 766 | cum_abs_upper = (post_point_resp - post_point_upper).sum() 767 | cum_abs_upper_fmt = int(cum_abs_upper) 768 | cum_abs_effect_ci_fmt = [cum_abs_lower_fmt, cum_abs_upper_fmt] 769 | 770 | rel_effect = abs_effect / mean_pred * 100 771 | rel_effect_fmt = "{:.1f}%".format(rel_effect) 772 | cum_rel_effect = cum_abs_effect / cum_pred * 100 773 | cum_rel_effect_fmt = "{:.1f}%".format(cum_rel_effect) 774 | rel_effect_lower = abs_effect_lower / mean_pred * 100 775 | rel_effect_lower_fmt = "{:.1f}%".format(rel_effect_lower) 776 | rel_effect_upper = abs_effect_upper / mean_pred * 100 777 | rel_effect_upper_fmt = "{:.1f}%".format(rel_effect_upper) 778 | rel_effect_ci_fmt = [rel_effect_lower_fmt, rel_effect_upper_fmt] 779 | cum_rel_effect_lower = cum_abs_lower / cum_pred * 100 780 | cum_rel_effect_lower_fmt = "{:.1f}%".format(cum_rel_effect_lower) 781 | cum_rel_effect_upper = cum_abs_upper / cum_pred * 100 782 | cum_rel_effect_upper_fmt = "{:.1f}%".format(cum_rel_effect_upper) 783 | cum_rel_effect_ci_fmt = [cum_rel_effect_lower_fmt, cum_rel_effect_upper_fmt] 784 | 785 | # assuming approximately normal distribution 786 | # calculate standard deviation from the 95% conf interval 787 | std_pred = ( 788 | mean_upper - mean_pred 789 | ) / 1.96 # from mean_upper = mean_pred + 1.96 * std 790 | # calculate z score 791 | z_score = (0 - mean_pred) / std_pred 792 | # pvalue from zscore 793 | p_value = st.norm.cdf(z_score) 794 | prob_causal = 1 - p_value 795 | p_value_perc = p_value * 100 796 | prob_causal_perc = prob_causal * 100 797 | 798 | if output == "summary": 799 | # Posterior inference {CausalImpact} 800 | summary = [ 801 | [mean_resp_fmt, cum_resp_fmt], 802 | 
[mean_pred_fmt, cum_pred_fmt], 803 | [mean_ci_fmt, cum_ci_fmt], 804 | [" ", " "], 805 | [abs_effect_fmt, cum_abs_effect_fmt], 806 | [abs_effect_ci_fmt, cum_abs_effect_ci_fmt], 807 | [" ", " "], 808 | [rel_effect_fmt, cum_rel_effect_fmt], 809 | [rel_effect_ci_fmt, cum_rel_effect_ci_fmt], 810 | [" ", " "], 811 | ["{:.1f}%".format(p_value_perc), " "], 812 | ["{:.1f}%".format(prob_causal_perc), " "], 813 | ] 814 | summary = pd.DataFrame( 815 | summary, 816 | columns=["Average", "Cumulative"], 817 | index=[ 818 | "Actual", 819 | "Predicted", 820 | "95% CI", 821 | " ", 822 | "Absolute Effect", 823 | "95% CI", 824 | " ", 825 | "Relative Effect", 826 | "95% CI", 827 | " ", 828 | "P-value", 829 | "Prob. of Causal Effect", 830 | ], 831 | ) 832 | df_print(summary, path) 833 | elif output == "report": 834 | self._print_report( 835 | mean_pred_fmt, 836 | mean_resp_fmt, 837 | mean_lower_fmt, 838 | mean_upper_fmt, 839 | abs_effect_fmt, 840 | abs_effect_upper_fmt, 841 | abs_effect_lower_fmt, 842 | rel_effect_fmt, 843 | rel_effect_upper_fmt, 844 | rel_effect_lower_fmt, 845 | cum_resp_fmt, 846 | cum_pred_fmt, 847 | cum_lower_fmt, 848 | cum_upper_fmt, 849 | confidence, 850 | cum_rel_effect_lower, 851 | cum_rel_effect_upper, 852 | cum_rel_effect, 853 | width, 854 | p_value, 855 | alpha, 856 | ) 857 | else: 858 | raise ValueError( 859 | "Output argument must be either 'summary' " + "or 'report'" 860 | ) 861 | 862 | def plot( 863 | self, 864 | panels=None, 865 | figsize=(15, 12), 866 | fname=None, 867 | ): 868 | if panels is None: 869 | panels = ["original", "pointwise", "cumulative"] 870 | plt = get_matplotlib() 871 | fig = plt.figure(figsize=figsize) 872 | 873 | data_inter = self.params["pre_period"][1] 874 | if isinstance(data_inter, pd.DatetimeIndex): 875 | data_inter = pd.Timestamp(data_inter) 876 | 877 | inferences = self.inferences.iloc[1:, :] 878 | 879 | # Observation and regression components 880 | if "original" in panels: 881 | ax1 = plt.subplot(3, 1, 1) 882 | plt.plot(inferences.point_pred, "r--", linewidth=2, label="model") 883 | plt.plot(inferences.response, "k", linewidth=2, label="endog") 884 | 885 | plt.axvline(data_inter, c="k", linestyle="--") 886 | 887 | plt.fill_between( 888 | inferences.index, 889 | inferences.point_pred_lower, 890 | inferences.point_pred_upper, 891 | facecolor="gray", 892 | interpolate=True, 893 | alpha=0.25, 894 | ) 895 | plt.setp(ax1.get_xticklabels(), visible=False) 896 | plt.legend(loc="upper left") 897 | plt.title("Observation vs prediction") 898 | 899 | if "pointwise" in panels: 900 | # Pointwise difference 901 | if "ax1" in locals(): 902 | ax2 = plt.subplot(312, sharex=ax1) 903 | else: 904 | ax2 = plt.subplot(312) 905 | lift = inferences.point_effect 906 | plt.plot(lift, "r--", linewidth=2) 907 | plt.plot(self.data.index, np.zeros(self.data.shape[0]), "g-", linewidth=2) 908 | plt.axvline(data_inter, c="k", linestyle="--") 909 | 910 | lift_lower = inferences.point_effect_lower 911 | lift_upper = inferences.point_effect_upper 912 | 913 | plt.fill_between( 914 | inferences.index, 915 | lift_lower, 916 | lift_upper, 917 | facecolor="gray", 918 | interpolate=True, 919 | alpha=0.25, 920 | ) 921 | plt.setp(ax2.get_xticklabels(), visible=False) 922 | plt.title("Difference") 923 | 924 | # Cumulative impact 925 | if "cumulative" in panels: 926 | if "ax1" in locals(): 927 | plt.subplot(313, sharex=ax1) 928 | elif "ax2" in locals(): 929 | plt.subplot(313, sharex=ax2) 930 | else: 931 | plt.subplot(313) 932 | plt.plot( 933 | inferences.index, 934 | inferences.cum_effect, 935 | 
"r--", 936 | linewidth=2, 937 | ) 938 | 939 | plt.plot(self.data.index, np.zeros(self.data.shape[0]), "g-", linewidth=2) 940 | plt.axvline(data_inter, c="k", linestyle="--") 941 | 942 | plt.fill_between( 943 | inferences.index, 944 | inferences.cum_effect_lower, 945 | inferences.cum_effect_upper, 946 | facecolor="gray", 947 | interpolate=True, 948 | alpha=0.25, 949 | ) 950 | plt.axis([inferences.index[0], inferences.index[-1], None, None]) 951 | 952 | plt.title("Cumulative Impact") 953 | plt.xlabel("$T$") 954 | if fname is None: 955 | plt.show() 956 | else: 957 | fig.savefig(fname, bbox_inches="tight") 958 | plt.close(fig) 959 | -------------------------------------------------------------------------------- /src/causalimpact/inferences.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | from causalimpact.misc import unstandardize 4 | 5 | 6 | def compile_inferences( 7 | results, 8 | data, 9 | df_pre, 10 | df_post, 11 | post_period_response, 12 | alpha, 13 | orig_std_params, 14 | estimation, 15 | ): 16 | """Compiles inferences to make predictions for post intervention 17 | period. 18 | 19 | Args: 20 | results: trained UnobservedComponents model from statsmodels package. 21 | data: pd.DataFrame pre and post-intervention data containing y and X. 22 | df_pre: pd.DataFrame pre intervention data 23 | df_post: pd.DataFrame post intervention data 24 | post_period_response: pd.DataFrame used when the model trained is not 25 | default one but a customized instead. In this case, 26 | ``df_post`` is None. 27 | alpha: float significance level for confidence intervals. 28 | orig_std_params: tuple of floats where first value is the mean and 29 | second value is standard deviation used for standardizing data. 30 | estimation: str to choose fitting method. "MLE" as default 31 | 32 | Returns: 33 | dict containing all data related to the inference process. 
34 | """ 35 | # Compute point predictions of counterfactual (in standardized space) 36 | if df_post is not None: 37 | # returns pre-period predictions (in-sample) 38 | predict = results.get_prediction() 39 | # returns post-period forecast 40 | forecast = results.get_forecast(df_post, alpha=alpha) 41 | else: 42 | pre_len = results.model_nobs - len(post_period_response) 43 | 44 | predict = results.get_prediction(end=pre_len - 1) 45 | forecast = results.get_prediction(start=pre_len) 46 | 47 | df_post = post_period_response 48 | df_post.index = pd.core.indexes.range.RangeIndex( 49 | start=pre_len, stop=pre_len + len(df_post), step=1 50 | ) 51 | 52 | # Compile summary statistics (in original space) 53 | pre_pred = unstandardize(predict.predicted_mean, orig_std_params) 54 | pre_pred.index = df_pre.index 55 | 56 | post_pred = unstandardize(forecast.predicted_mean, orig_std_params) 57 | post_pred.index = df_post.index 58 | 59 | point_pred = pd.concat([pre_pred, post_pred]) 60 | 61 | pre_ci = unstandardize(predict.conf_int(alpha=alpha), orig_std_params) 62 | pre_ci.index = df_pre.index 63 | 64 | post_ci = unstandardize(forecast.conf_int(alpha=alpha), orig_std_params) 65 | 66 | post_ci.index = df_post.index 67 | ci = pd.concat([pre_ci, post_ci]) 68 | point_pred_lower = ci.iloc[:, 0].to_frame() 69 | point_pred_upper = ci.iloc[:, 1].to_frame() 70 | 71 | response = data.iloc[:, 0] 72 | response_index = data.index 73 | 74 | response = pd.DataFrame(response) 75 | 76 | cum_response = np.cumsum(response) 77 | cum_pred = np.cumsum(point_pred) 78 | cum_pred_lower = np.cumsum(point_pred_lower) 79 | cum_pred_upper = np.cumsum(point_pred_upper) 80 | 81 | data = pd.concat( 82 | [ 83 | point_pred, 84 | point_pred_lower, 85 | point_pred_upper, 86 | cum_pred, 87 | cum_pred_lower, 88 | cum_pred_upper, 89 | ], 90 | axis=1, 91 | ) 92 | 93 | data = pd.concat([response, cum_response], axis=1).join(data, lsuffix="l") 94 | 95 | data.columns = [ 96 | "response", 97 | "cum_response", 98 | "point_pred", 99 | "point_pred_lower", 100 | "point_pred_upper", 101 | "cum_pred", 102 | "cum_pred_lower", 103 | "cum_pred_upper", 104 | ] 105 | 106 | point_effect = (data.response - data.point_pred).to_frame() 107 | point_effect_lower = (data.response - data.point_pred_lower).to_frame() 108 | point_effect_upper = (data.response - data.point_pred_upper).to_frame() 109 | 110 | cum_effect = point_effect.copy() 111 | cum_effect.loc[df_pre.index[0] : df_pre.index[-1]] = 0 112 | cum_effect = np.cumsum(cum_effect) 113 | 114 | cum_effect_lower = point_effect_lower.copy() 115 | cum_effect_lower.loc[df_pre.index[0] : df_pre.index[-1]] = 0 116 | cum_effect_lower = np.cumsum(cum_effect_lower) 117 | 118 | cum_effect_upper = point_effect_upper.copy() 119 | cum_effect_upper.loc[df_pre.index[0] : df_pre.index[-1]] = 0 120 | cum_effect_upper = np.cumsum(cum_effect_upper) 121 | 122 | data = pd.concat( 123 | [ 124 | data, 125 | point_effect, 126 | point_effect_lower, 127 | point_effect_upper, 128 | cum_effect, 129 | cum_effect_lower, 130 | cum_effect_upper, 131 | ], 132 | axis=1, 133 | ) 134 | 135 | # Create DataFrame of results 136 | data.columns = [ 137 | "response", 138 | "cum_response", 139 | "point_pred", 140 | "point_pred_lower", 141 | "point_pred_upper", 142 | "cum_pred", 143 | "cum_pred_lower", 144 | "cum_pred_upper", 145 | "point_effect", 146 | "point_effect_lower", 147 | "point_effect_upper", 148 | "cum_effect", 149 | "cum_effect_lower", 150 | "cum_effect_upper", 151 | ] 152 | 153 | data.index = response_index 154 | 155 | series = data 156 | 157 | 
inferences = {"series": series} 158 | return inferences 159 | -------------------------------------------------------------------------------- /src/causalimpact/misc.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | 3 | 4 | def standardize_all_variables(data, pre_period, post_period): 5 | """Standardize all columns of a given time series. 6 | Args: 7 | data: Pandas DataFrame with one or more columns 8 | pre_period, post_period: two-element lists bounding each period 9 | Returns: 10 | dict: 11 | data_pre: standardized data in the pre-period 12 | data_post: standardized data in the post-period 13 | orig_std_params: (mean, std) of the first column, for ``unstandardize`` 14 | """ 15 | if not isinstance(data, pd.DataFrame): 16 | raise ValueError("``data`` must be of type `pandas.DataFrame`") 17 | 18 | if not ( 19 | pd.api.types.is_list_like(pre_period) and pd.api.types.is_list_like(post_period) 20 | ): 21 | raise ValueError("``pre_period`` and ``post_period`` must be list-like") 22 | 23 | data_mu = data.loc[pre_period[0] : pre_period[1], :].mean(skipna=True) 24 | data_sd = data.loc[pre_period[0] : pre_period[1], :].std(skipna=True, ddof=0) 25 | data = data - data_mu 26 | data_sd = data_sd.fillna(1) 27 | 28 | data[data != 0] = data[data != 0] / data_sd 29 | y_mu = data_mu[0] 30 | y_sd = data_sd[0] 31 | 32 | data_pre = data.loc[pre_period[0] : pre_period[1], :] 33 | data_post = data.loc[post_period[0] : post_period[1], :] 34 | 35 | return { 36 | "data_pre": data_pre, 37 | "data_post": data_post, 38 | "orig_std_params": (y_mu, y_sd), 39 | } 40 | 41 | 42 | def unstandardize(data, orig_std_params): 43 | """Function for reversing the standardization of the first column in the 44 | provided data. 45 | """ 46 | data = pd.DataFrame(data) 47 | y_mu = orig_std_params[0] 48 | y_sd = orig_std_params[1] 49 | data = data.mul(y_sd, axis=1) 50 | data = data.add(y_mu, axis=1) 51 | return data 52 | 53 | 54 | def df_print(data, path=None): 55 | if path: 56 | data.to_csv(path) 57 | print(data) 58 | 59 | 60 | def get_matplotlib(): # pragma: no cover 61 | """Wrapper function to facilitate unit testing the `plot` tool by removing 62 | the strong dependencies of matplotlib. 63 | 64 | Returns: 65 | module matplotlib.pyplot 66 | """ 67 | import matplotlib.pyplot as plt 68 | 69 | return plt 70 | -------------------------------------------------------------------------------- /src/causalimpact/model.py: -------------------------------------------------------------------------------- 1 | """Constructs and fits the statespace model. 2 | 3 | Contains the construct_model and model_fit functions that are called in analysis.py. 4 | """ 5 | 6 | import numpy as np 7 | import pandas as pd 8 | import pymc as pm 9 | import pytensor.tensor as at 10 | from pytensor.graph.op import Op 11 | 12 | 13 | def observations_ill_conditioned(y): 14 | """Checks whether the response variable (i.e., the series of observations 15 | for the dependent variable y) is ill-conditioned. For example, the series 16 | might contain too few non-NA values. In such cases, inference will be 17 | aborted. 18 | 19 | Args: 20 | y: observed series (Pandas Series) 21 | 22 | Returns: 23 | False if the observations pass all checks; raises ValueError otherwise. 24 | """ 25 | 26 | if y is None: 27 | raise ValueError("y cannot be None") 28 | if not (len(y) > 1): 29 | raise ValueError("y must have len > 1") 30 | 31 | # All NA? 32 | if np.all(pd.isnull(y)): 33 | raise ValueError("Aborting inference due to input series being all " "null.") 34 | elif len(y[pd.notnull(y)]) < 3: 35 | # Fewer than 3 non-NA values?
36 | raise ValueError( 37 | "Aborting inference due to fewer than 3 nonnull " "values in input." 38 | ) 39 | # Constant series? 40 | elif y.std(skipna=True) == 0: 41 | raise ValueError("Aborting inference due to input series being " "constant") 42 | return False 43 | 44 | 45 | def construct_model(data, model_args=None): 46 | """Specifies the state-space model. Fitting it (see ``model_fit``) combines 47 | Kalman filtering with maximum-likelihood or MCMC estimation to find the 48 | parameters that best explain the observed data. 49 | 50 | Args: 51 | data: time series of response variable and optional covariates 52 | model_args: optional dict of additional model arguments 53 | 54 | Returns: 55 | An Unobserved Components Model, as returned by UnobservedComponents() 56 | """ 57 | if model_args is None: 58 | model_args = {} 59 | from statsmodels.tsa.statespace.structural import UnobservedComponents 60 | 61 | y = data.iloc[:, 0] 62 | 63 | observations_ill_conditioned(y) 64 | 65 | # LocalLevel specification of statespace 66 | ss = {"endog": y.values, "level": "llevel"} 67 | 68 | # Add covariates as a regression component? 69 | if len(data.columns) > 1: 70 | # Static regression 71 | if not model_args.get("dynamic_regression"): 72 | ss["exog"] = data.iloc[:, 1:].values 73 | # Dynamic regression 74 | else: 75 | raise NotImplementedError() 76 | mod = UnobservedComponents(**ss) 77 | return mod 78 | 79 | 80 | class Loglike(Op): 81 | """PyTensor log-likelihood wrapper that allows PyMC to compute the likelihood 82 | and Jacobian in a way that it can make use of.""" 83 | 84 | itypes = [at.dvector] # expects a vector of parameter values when called 85 | otypes = [at.dscalar] # outputs a single scalar value (the log likelihood) 86 | 87 | def __init__(self, model): 88 | self.model = model 89 | self.score = Score(self.model) 90 | 91 | def perform(self, node, inputs, outputs): 92 | (theta,) = inputs # contains the vector of parameters 93 | llf = self.model.loglike(theta) 94 | outputs[0][0] = np.array(llf) # output the log-likelihood 95 | 96 | def grad(self, inputs, g): 97 | # the method that calculates the gradients - it actually returns the 98 | # vector-Jacobian product - g[0] is a vector of parameter values 99 | (theta,) = inputs # our parameters 100 | out = [g[0] * self.score(theta)] 101 | return out 102 | 103 | 104 | class Score(Op): 105 | """PyTensor score wrapper that allows PyMC to compute the gradient of the 106 | log-likelihood in a way that it can make use of.""" 107 | 108 | itypes = [at.dvector] 109 | otypes = [at.dvector] 110 | 111 | def __init__(self, model): 112 | self.model = model 113 | 114 | def perform(self, node, inputs, outputs): 115 | (theta,) = inputs 116 | outputs[0][0] = self.model.score(theta) 117 | 118 | 119 | class ModelResults: 120 | """ModelResults class containing everything needed for inference, 121 | intended to allow extension to other models (e.g. tensorflow) 122 | 123 | Parameters 124 | ---------- 125 | ucm_model : statsmodels.tsa.statespace.structural.UnobservedComponents 126 | The constructed UCM model being fit 127 | results : statsmodels results object, from fitting or smoothing 128 | estimation : string 129 | The estimation method.
Options are "MLE" or "pymc" 130 | 131 | """ 132 | 133 | def __init__(self, ucm_model, results, estimation) -> None: 134 | self.results = results 135 | self.estimation = estimation 136 | self.model_nobs = ucm_model.nobs 137 | 138 | def get_prediction(self, start=None, end=None): 139 | """ 140 | In-sample prediction and out-of-sample forecasting 141 | 142 | 143 | Parameters 144 | ---------- 145 | start : int 146 | Zero-indexed observation number at which to start forecasting, 147 | i.e., the first forecast is start. Can also be a date string to 148 | parse or a datetime type. Default is the zeroth observation. 149 | end : int 150 | Zero-indexed observation number at which to end forecasting, 151 | i.e., the last forecast is end. Can also be a date string to 152 | parse or a datetime type. However, if the dates index does not 153 | have a fixed frequency, end must be an integer index if you want 154 | out of sample prediction. Default is the last observation in the 155 | sample. 156 | 157 | Returns 158 | ------- 159 | ModelPredictions 160 | """ 161 | predictions = self.results.get_prediction(start=start, end=end) 162 | return predictions 163 | 164 | def get_forecast(self, df_post, alpha): 165 | forecast = self.results.get_forecast( 166 | steps=len(df_post), exog=df_post.iloc[:, 1:], alpha=alpha 167 | ) 168 | return forecast 169 | 170 | def summary(self): 171 | return self.results.summary() 172 | 173 | 174 | def model_fit(model, estimation, model_args): 175 | """Fits the model and returns a ModelResults object. 176 | 177 | Uses the chosen estimation option to fit the model and 178 | return a ModelResults object that is agnostic of 179 | estimation approach. 180 | 181 | Parameters 182 | ----------- 183 | model : statsmodels.tsa.statespace.structural.UnobservedComponents 184 | estimation : str 185 | Either 'MLE' or 'pymc'.
186 | model_args : dict 187 | possible args for MLE are: 188 | niter: int 189 | possible args for pymc are: 190 | ndraws: int 191 | number of draws from the distribution 192 | nburn: int 193 | number of "burn-in points" (which will be discarded) 194 | 195 | """ 196 | if estimation == "MLE": 197 | trained_model = model.fit(maxiter=model_args["niter"]) 198 | model_results = ModelResults(model, trained_model, estimation) 199 | return model_results 200 | elif estimation == "pymc": 201 | loglike = Loglike(model) 202 | with pm.Model(): 203 | # Priors 204 | sigma2irregular = pm.InverseGamma("sigma2.irregular", 1, 1) 205 | sigma2level = pm.InverseGamma("sigma2.level", 1, 1) 206 | if model.exog is None: 207 | # convert variables to tensor vectors 208 | theta = at.as_tensor_variable([sigma2irregular, sigma2level]) 209 | else: 210 | # prior for regressors 211 | betax1 = pm.Laplace("beta.x1", mu=0, b=1.0 / 0.7) 212 | # convert variables to tensor vectors 213 | theta = at.as_tensor_variable([sigma2irregular, sigma2level, betax1]) 214 | # add the custom log-likelihood term to the model via a Potential on the Op 215 | pm.Potential("likelihood", loglike(theta)) 216 | 217 | # Draw samples 218 | trace = pm.sample( 219 | model_args["ndraws"], 220 | tune=model_args["nburn"], 221 | return_inferencedata=True, 222 | cores=4, 223 | compute_convergence_checks=False, 224 | ) 225 | # Retrieve the posterior means 226 | params = pm.summary(trace)["mean"].values 227 | 228 | # Construct results using these posterior means as parameter values 229 | results = model.smooth(params) 230 | model_results = ModelResults(model, results, estimation) 231 | return model_results 232 | -------------------------------------------------------------------------------- /tests/README.md: -------------------------------------------------------------------------------- 1 | These files are for testing the methods and functions in CausalImpact 2 | 3 | The tests can be run using tox: 4 | 5 | ```bash 6 | tox 7 | ``` 8 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jamalsenouci/causalimpact/abf1eea7072f5b3e29fb8aae7be275fc25271d29/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pytest 3 | 4 | here = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | 7 | @pytest.fixture 8 | def FIXTURES_FOLDER(): 9 | return os.path.join(here, "fixtures") 10 | -------------------------------------------------------------------------------- /tests/fixtures/analysis/summary_report_output.txt: -------------------------------------------------------------------------------- 1 | During the post-intervention period, the response variable had an average value of approx. 3.the absence of an intervention, we would have expected an average response of 3. The 95% interval of this counterfactual prediction is [3, 3]. Subtracting this prediction from the observed response yields an estimate of the causal effect the intervention had on the response variable. This effect is 0 with a 95% interval of [0, 0].
For a discussion of the significance of this effect, see below.Summing up the individual data points during the post-intervention period (which can only sometimes be meaningfully interpreted), the response variable had an overall value of 7.the intervention not taken place, we would have expected a sum of 7. The 95% interval of this prediction is [7, 7]The above results are given in terms of absolute numbers. In relative terms, the response variable showed-2.8%. The 95% interval of this percentage is [0.0%, -11.1%]This means that the negative effect observed during the intervention period is statistically significant. If the experimenter had expected a positive effect, it is recommended to double-check whether anomalies in the control variables may have caused an overly optimistic expectation of what should have happened in the response variable in the absence of the intervention.The probability of obtaining this effect by chance is very small (Bayesian one-sided tail-area probability 0.0). This 2 | means the causal effect can be considered statistically significant. 3 | -------------------------------------------------------------------------------- /tests/test_analysis.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for analysis module.""" 2 | 3 | import os 4 | import re 5 | from tempfile import mkdtemp 6 | import unittest.mock as mock 7 | import pytest 8 | import numpy as np 9 | from numpy.testing import assert_array_equal 10 | import pandas as pd 11 | from pandas.testing import assert_frame_equal 12 | from statsmodels.tsa.statespace.structural import UnobservedComponents as UCM 13 | from causalimpact import CausalImpact 14 | from unittest import TestCase 15 | 16 | 17 | @pytest.fixture() 18 | def data(): 19 | return pd.DataFrame(np.random.randn(200, 3), columns=["y", "x1", "x2"]) 20 | 21 | 22 | @pytest.fixture() 23 | def summary_report_filename(FIXTURES_FOLDER): 24 | return os.path.join(FIXTURES_FOLDER, "analysis", "summary_report_output.txt") 25 | 26 | 27 | @pytest.fixture() 28 | def expected_columns(): 29 | return [ 30 | "response", 31 | "cum_response", 32 | "point_pred", 33 | "point_pred_lower", 34 | "point_pred_upper", 35 | "cum_pred", 36 | "cum_pred_lower", 37 | "cum_pred_upper", 38 | "point_effect", 39 | "point_effect_lower", 40 | "point_effect_upper", 41 | "cum_effect", 42 | "cum_effect_lower", 43 | "cum_effect_upper", 44 | ] 45 | 46 | 47 | @pytest.fixture() 48 | def inference_input(): 49 | return { 50 | "response": np.array([1.0, 2.0, 3.0, 4.0]), 51 | "point_pred": np.array([1.1, 2.2, 3.1, 4.1]), 52 | "point_pred_upper": np.array([1.5, 2.6, 3.4, 4.4]), 53 | "point_pred_lower": np.array([1.0, 2.0, 3.0, 4.0]), 54 | } 55 | 56 | 57 | @pytest.fixture() 58 | def pre_period(): 59 | return [0, 100] 60 | 61 | 62 | @pytest.fixture() 63 | def post_period(): 64 | return [101, 199] 65 | 66 | 67 | @pytest.fixture() 68 | def ucm_model(data, post_period): 69 | data_modeling = data.copy() 70 | data_modeling[post_period[0] : post_period[1] + 1] = np.nan 71 | return UCM(endog=data_modeling.iloc[:, 0].values, level="llevel") 72 | 73 | 74 | @pytest.fixture() 75 | def impact_ucm(ucm_model): 76 | post_period_response = np.random.randn(100) 77 | return CausalImpact(ucm_model=ucm_model, post_period_response=post_period_response) 78 | 79 | 80 | @pytest.fixture() 81 | def causal_impact(data, pre_period, post_period): 82 | model_args = {"niter": 123} 83 | return CausalImpact(data, pre_period, post_period, model_args) 84 | 85 | 86 | class 
TestFormatInput: 87 | @staticmethod 88 | def test_input_default_model(causal_impact): 89 | expected = { 90 | "data": causal_impact.params["data"], 91 | "pre_period": causal_impact.params["pre_period"], 92 | "post_period": causal_impact.params["post_period"], 93 | "model_args": causal_impact.params["model_args"], 94 | "ucm_model": None, 95 | "post_period_response": None, 96 | "alpha": causal_impact.params["alpha"], 97 | } 98 | 99 | result = causal_impact._format_input( 100 | causal_impact.params["data"], 101 | causal_impact.params["pre_period"], 102 | causal_impact.params["post_period"], 103 | causal_impact.params["model_args"], 104 | None, 105 | None, 106 | causal_impact.params["alpha"], 107 | ) 108 | 109 | result_data = result["data"] 110 | expected_data = expected["data"] 111 | assert_frame_equal(result_data, expected_data) 112 | 113 | result_model_args = result["model_args"] 114 | expected_model_args = expected["model_args"] 115 | assert result_model_args == expected_model_args 116 | 117 | result_other = { 118 | key: result[key] for key in result if key not in {"model_args", "data"} 119 | } 120 | 121 | expected_other = { 122 | key: expected[key] for key in expected if key not in {"model_args", "data"} 123 | } 124 | assert result_other == expected_other 125 | 126 | @staticmethod 127 | def test_input_raises_w_data_and_ucm_model(causal_impact, ucm_model): 128 | # Test inconsistent input (must not provide both data and ucm_model) 129 | with pytest.raises(SyntaxError) as excinfo: 130 | causal_impact._format_input( 131 | causal_impact.params["data"], 132 | causal_impact.params["pre_period"], 133 | causal_impact.params["post_period"], 134 | causal_impact.params["model_args"], 135 | ucm_model, 136 | [1, 2, 3], 137 | causal_impact.params["alpha"], 138 | ) 139 | assert str(excinfo.value) == ( 140 | "Must either provide ``data``, " 141 | "``pre_period`` ,``post_period``, ``model_args`` or " 142 | "``ucm_modeland ``post_period_response``" 143 | ) 144 | 145 | @staticmethod 146 | def test_input_w_ucm_input(ucm_model, impact_ucm): 147 | expected = { 148 | "data": None, 149 | "pre_period": None, 150 | "post_period": None, 151 | "ucm_model": impact_ucm.params["ucm_model"], 152 | "post_period_response": impact_ucm.params["post_period_response"], 153 | "alpha": impact_ucm.params["alpha"], 154 | } 155 | 156 | result = impact_ucm._format_input( 157 | None, 158 | None, 159 | None, 160 | None, 161 | impact_ucm.params["ucm_model"], 162 | impact_ucm.params["post_period_response"], 163 | impact_ucm.params["alpha"], 164 | ) 165 | result.pop("model_args") 166 | assert result == expected 167 | 168 | @staticmethod 169 | def test_format_output_is_df(causal_impact): 170 | # Test that the input is converted to a pandas DataFrame 171 | expected_data = pd.DataFrame(np.arange(0, 8).reshape(4, 2), index=[0, 1, 2, 3]) 172 | 173 | funny_datas = [ 174 | pd.DataFrame([[0, 1], [2, 3], [4, 5], [6, 7]]), 175 | pd.DataFrame(data=[[0, 1], [2, 3], [4, 5], [6, 7]], index=[0, 1, 2, 3]), 176 | [[0, 1], [2, 3], [4, 5], [6, 7]], 177 | np.array([[0, 1], [2, 3], [4, 5], [6, 7]]), 178 | ] 179 | 180 | for funny_data in funny_datas: 181 | result = causal_impact._format_input( 182 | funny_data, [0, 2], [3, 3], {}, None, None, 0.05 183 | ) 184 | assert_array_equal(result["data"].values, expected_data.values) 185 | assert isinstance(result["data"], pd.DataFrame) 186 | 187 | @staticmethod 188 | def test_input_w_bad_data(causal_impact): 189 | text_data = "foo" 190 | with pytest.raises(ValueError) as excinfo: 191 | causal_impact._format_input(text_data, [0, 
3], [3, 3], {}, None, None, 0.05) 192 | assert str(excinfo.value) == ( 193 | "could not convert input data to " "Pandas DataFrame" 194 | ) 195 | 196 | @staticmethod 197 | def test_input_w_bad_pre_period(data, causal_impact): 198 | bad_pre_periods = [ 199 | 1, 200 | [], 201 | [1, 2, 3], 202 | [np.nan, 2], 203 | [pd.to_datetime(date) for date in ["2011-01-01", "2011-12-31"]], 204 | ] 205 | 206 | errors_list = [ 207 | "pre_period and post_period must both be lists", 208 | "pre_period and post_period must both be of length 2", 209 | "pre_period and post_period must both be of length 2", 210 | "pre_period and post period must not contain null values", 211 | ( 212 | "pre_period (object) and post_period (int64) should have the same" 213 | " class as the time points in the data (int64)" 214 | ), 215 | ] 216 | 217 | for idx, bad_pre_period in enumerate(bad_pre_periods): 218 | with pytest.raises(ValueError) as excinfo: 219 | causal_impact._format_input( 220 | data, bad_pre_period, [1, 2], None, None, None, 0.05 221 | ) 222 | assert str(excinfo.value) == errors_list[idx] 223 | 224 | @staticmethod 225 | def test_input_w_bad_post_period(data, causal_impact): 226 | bad_post_periods = [ 227 | 1, 228 | [], 229 | [1, 2, 3], 230 | [np.nan, 2], 231 | [pd.to_datetime(date) for date in ["2011-01-01", "2011-12-31"]], 232 | ] 233 | 234 | errors_list = [ 235 | "pre_period and post_period must both be lists", 236 | "pre_period and post_period must both be of length 2", 237 | "pre_period and post_period must both be of length 2", 238 | "pre_period and post period must not contain null values", 239 | ( 240 | "pre_period (int64) and post_period (object) should have the same" 241 | " class as the time points in the data (int64)" 242 | ), 243 | ] 244 | 245 | for idx, bad_post_period in enumerate(bad_post_periods): 246 | with pytest.raises(ValueError) as excinfo: 247 | causal_impact._format_input( 248 | data, [1, 2], bad_post_period, None, None, None, 0.05 249 | ) 250 | assert str(excinfo.value) == errors_list[idx] 251 | 252 | @staticmethod 253 | def test_input_w_pre_and_post_periods_having_distinct_classes(causal_impact): 254 | # Test what happens when pre_period/post_period has a different class 255 | # than the timestamps in the data 256 | bad_data = pd.DataFrame( 257 | data=[1, 2, 3, 4], 258 | index=["2014-01-01", "2014-01-02", "2014-01-03", "2014-01-04"], 259 | ) 260 | 261 | bad_pre_period = [0.0, 3.0] # float 262 | bad_post_period = [3, 3] 263 | 264 | with pytest.raises(ValueError) as excinfo: 265 | causal_impact._format_input( 266 | bad_data, bad_pre_period, bad_post_period, None, None, None, 0.05 267 | ) 268 | assert str(excinfo.value) == ( 269 | "pre_period (float64) and post_period (" 270 | "int64) should have the same class as the time points in the data" 271 | " (object)" 272 | ) 273 | 274 | bad_pre_period = [0, 2] # integer 275 | bad_post_period = [3, 3] 276 | with pytest.raises(ValueError) as excinfo: 277 | causal_impact._format_input( 278 | bad_data, bad_pre_period, bad_post_period, None, None, None, 0.05 279 | ) 280 | assert str(excinfo.value) == ( 281 | "pre_period (int64) and post_period (" 282 | "int64) should have the same class as the time points in the data" 283 | " (object)" 284 | ) 285 | 286 | @staticmethod 287 | def test_bad_model_args(data, causal_impact): 288 | with pytest.raises(TypeError) as excinfo: 289 | causal_impact._format_input(data, [0, 3], [3, 10], 1000, None, None, 0.05) 290 | assert str(excinfo.value) == "argument of type 'int' is not iterable" 291 | with pytest.raises(TypeError) as 
excinfo: 292 | causal_impact._format_input( 293 | data, [0, 3], [3, 10], "ninter=1000", None, None, 0.05 294 | ) 295 | assert str(excinfo.value) == ("'str' object does not support item assignment") 296 | 297 | @staticmethod 298 | def test_bad_standardize(data, causal_impact): 299 | bad_standardize_data = [np.nan, 123, "foo", [True, False]] 300 | for bad_standardize in bad_standardize_data: 301 | bad_model_args = {"standardize_data": bad_standardize} 302 | with pytest.raises(ValueError) as excinfo: 303 | causal_impact._format_input( 304 | data, [0, 3], [3, 10], bad_model_args, None, None, 0.05 305 | ) 306 | assert str(excinfo.value) == ( 307 | "model_args.standardize_data must be a boolean value" 308 | ) 309 | 310 | @staticmethod 311 | def test_bad_post_period_response(causal_impact, ucm_model): 312 | with pytest.raises(ValueError) as excinfo: 313 | causal_impact._format_input( 314 | None, None, None, None, ucm_model, pd.to_datetime("2011-01-01"), 0.05 315 | ) 316 | assert str(excinfo.value) == ("post_period_response must be list-like") 317 | 318 | with pytest.raises(ValueError) as excinfo: 319 | causal_impact._format_input(None, None, None, None, ucm_model, True, 0.05) 320 | assert str(excinfo.value) == ("post_period_response must be list-like") 321 | 322 | with pytest.raises(ValueError) as excinfo: 323 | causal_impact._format_input( 324 | None, None, None, None, ucm_model, [pd.to_datetime("2018-01-01")], 0.05 325 | ) 326 | assert str(excinfo.value) == ( 327 | "post_period_response should not be datetime values" 328 | ) 329 | 330 | with pytest.raises(ValueError) as excinfo: 331 | causal_impact._format_input(None, None, None, None, ucm_model, [2j], 0.05) 332 | assert str(excinfo.value) == ( 333 | "post_period_response must contain all real values" 334 | ) 335 | 336 | @staticmethod 337 | def test_bad_alpha(data, causal_impact): 338 | bad_alphas = [None, np.nan, -1, 0, 1, [0.8, 0.9], "0.1"] 339 | for bad_alpha in bad_alphas: 340 | with pytest.raises(ValueError) as excinfo: 341 | causal_impact._format_input( 342 | data, [0, 3], [3, 10], {}, None, None, bad_alpha 343 | ) 344 | assert str(excinfo.value) == "alpha must be a real number" 345 | 346 | @staticmethod 347 | def test_bad_ucm(data, causal_impact): 348 | from statsmodels.tsa.statespace.sarimax import SARIMAX 349 | 350 | bad_ucm = SARIMAX(endog=[1, 2, 3, 4]) 351 | 352 | with pytest.raises(ValueError) as excinfo: 353 | causal_impact._format_input(None, None, None, None, bad_ucm, [2], 0.05) 354 | assert ( 355 | str(excinfo.value) 356 | == "ucm_model must be an object of class " 357 | + "statsmodels.tsa.statespace.structural.UnobservedComponents " 358 | + "instead received statsmodels.tsa.statespace.sarimax.SARIMAX" 359 | ) 360 | 361 | @staticmethod 362 | def test_input_w_date_column(): 363 | data = pd.DataFrame(np.random.randn(100, 2), columns=["x1", "x2"]) 364 | data["date"] = pd.date_range(start="2018-01-01", periods=100) 365 | data = data[["date", "x1", "x2"]] 366 | pre_period = ["2018-01-01", "2018-02-10"] 367 | post_period = ["2018-02-11", "2018-4-10"] 368 | causal_impact = CausalImpact(data, pre_period, post_period, {}) 369 | data = data.set_index("date") 370 | pre_period = [pd.to_datetime(e) for e in pre_period] 371 | post_period = [pd.to_datetime(e) for e in post_period] 372 | 373 | expected = { 374 | "data": data, 375 | "pre_period": pre_period, 376 | "post_period": post_period, 377 | "model_args": causal_impact.params["model_args"], 378 | "ucm_model": None, 379 | "post_period_response": None, 380 | "alpha": 
causal_impact.params["alpha"], 381 | } 382 | result = causal_impact._format_input( 383 | causal_impact.params["data"], 384 | causal_impact.params["pre_period"], 385 | causal_impact.params["post_period"], 386 | causal_impact.params["model_args"], 387 | None, 388 | None, 389 | causal_impact.params["alpha"], 390 | ) 391 | 392 | result_data = result["data"] 393 | expected_data = expected["data"] 394 | assert_frame_equal(result_data, expected_data) 395 | 396 | result_model_args = result["model_args"] 397 | expected_model_args = expected["model_args"] 398 | assert result_model_args == expected_model_args 399 | 400 | result_other = { 401 | key: result[key] for key in result if key not in {"model_args", "data"} 402 | } 403 | 404 | expected_other = { 405 | key: expected[key] for key in expected if key not in {"model_args", "data"} 406 | } 407 | assert result_other == expected_other 408 | 409 | @staticmethod 410 | def test_input_w_time_column(): 411 | data = pd.DataFrame(np.random.randn(100, 2), columns=["x1", "x2"]) 412 | data["time"] = pd.date_range(start="2018-01-01", periods=100) 413 | data = data[["time", "x1", "x2"]] 414 | pre_period = ["2018-01-01", "2018-02-10"] 415 | post_period = ["2018-02-11", "2018-4-10"] 416 | 417 | causal_impact = CausalImpact(data, pre_period, post_period, {}) 418 | 419 | data = data.set_index("time") 420 | pre_period = [pd.to_datetime(e) for e in pre_period] 421 | post_period = [pd.to_datetime(e) for e in post_period] 422 | 423 | expected = { 424 | "data": data, 425 | "pre_period": pre_period, 426 | "post_period": post_period, 427 | "model_args": causal_impact.params["model_args"], 428 | "ucm_model": None, 429 | "post_period_response": None, 430 | "alpha": causal_impact.params["alpha"], 431 | } 432 | result = causal_impact._format_input( 433 | causal_impact.params["data"], 434 | causal_impact.params["pre_period"], 435 | causal_impact.params["post_period"], 436 | causal_impact.params["model_args"], 437 | None, 438 | None, 439 | causal_impact.params["alpha"], 440 | ) 441 | 442 | result_data = result["data"] 443 | expected_data = expected["data"] 444 | assert_frame_equal(result_data, expected_data) 445 | 446 | result_model_args = result["model_args"] 447 | expected_model_args = expected["model_args"] 448 | assert result_model_args == expected_model_args 449 | 450 | result_other = { 451 | key: result[key] for key in result if key not in {"model_args", "data"} 452 | } 453 | 454 | expected_other = { 455 | key: expected[key] for key in expected if key not in {"model_args", "data"} 456 | } 457 | assert result_other == expected_other 458 | 459 | @staticmethod 460 | def test_input_w_just_2_points_raises_exception(): 461 | data = pd.DataFrame(np.random.randn(2, 2), columns=["x1", "x2"]) 462 | causal_impact = CausalImpact(data, [0, 0], [1, 1], {}) 463 | 464 | with pytest.raises(ValueError) as excinfo: 465 | causal_impact._format_input( 466 | causal_impact.params["data"], 467 | causal_impact.params["pre_period"], 468 | causal_impact.params["post_period"], 469 | causal_impact.params["model_args"], 470 | None, 471 | None, 472 | causal_impact.params["alpha"], 473 | ) 474 | assert str(excinfo.value) == "data must have at least 3 time points" 475 | 476 | @staticmethod 477 | def test_input_covariates_w_nan_value_raises(): 478 | data = np.array([[1, 1, 2], [1, 2, 3], [1, 3, 4], [1, np.nan, 5], [1, 6, 7]]) 479 | data = pd.DataFrame(data, columns=["y", "x1", "x2"]) 480 | causal_impact = CausalImpact(data, [0, 3], [3, 4], {}) 481 | 482 | with pytest.raises(ValueError) as excinfo: 483 | 
causal_impact._format_input( 484 | causal_impact.params["data"], 485 | causal_impact.params["pre_period"], 486 | causal_impact.params["post_period"], 487 | causal_impact.params["model_args"], 488 | None, 489 | None, 490 | causal_impact.params["alpha"], 491 | ) 492 | assert str(excinfo.value) == "covariates must not contain null values" 493 | 494 | @staticmethod 495 | def test_int_index_pre_period_contains_float(causal_impact, pre_period): 496 | expected = { 497 | "data": causal_impact.params["data"], 498 | "pre_period": causal_impact.params["pre_period"], 499 | "post_period": causal_impact.params["post_period"], 500 | "model_args": causal_impact.params["model_args"], 501 | "ucm_model": None, 502 | "post_period_response": None, 503 | "alpha": causal_impact.params["alpha"], 504 | } 505 | result = causal_impact._format_input( 506 | causal_impact.params["data"], 507 | [float(pre_period[0]), pre_period[1]], 508 | causal_impact.params["post_period"], 509 | causal_impact.params["model_args"], 510 | None, 511 | None, 512 | causal_impact.params["alpha"], 513 | ) 514 | 515 | result_data = result["data"] 516 | expected_data = expected["data"] 517 | assert_frame_equal(result_data, expected_data) 518 | 519 | result_model_args = result["model_args"] 520 | expected_model_args = expected["model_args"] 521 | assert result_model_args == expected_model_args 522 | 523 | result_other = { 524 | key: result[key] for key in result if key not in {"model_args", "data"} 525 | } 526 | 527 | expected_other = { 528 | key: expected[key] for key in expected if key not in {"model_args", "data"} 529 | } 530 | assert result_other == expected_other 531 | 532 | @staticmethod 533 | def test_float_index_pre_period_contains_int(): 534 | data = np.random.randn(200, 3) 535 | data = pd.DataFrame(data, columns=["y", "x1", "x2"]) 536 | data = data.set_index(np.array([float(i) for i in range(200)])) 537 | causal_impact = CausalImpact(data, [0, 3], [3, 4], {}) 538 | 539 | expected = { 540 | "data": causal_impact.params["data"], 541 | "pre_period": causal_impact.params["pre_period"], 542 | "post_period": causal_impact.params["post_period"], 543 | "model_args": causal_impact.params["model_args"], 544 | "ucm_model": None, 545 | "post_period_response": None, 546 | "alpha": causal_impact.params["alpha"], 547 | } 548 | result = causal_impact._format_input( 549 | causal_impact.params["data"], 550 | causal_impact.params["pre_period"], 551 | causal_impact.params["post_period"], 552 | causal_impact.params["model_args"], 553 | None, 554 | None, 555 | causal_impact.params["alpha"], 556 | ) 557 | 558 | result_data = result["data"] 559 | expected_data = expected["data"] 560 | assert_frame_equal(result_data, expected_data) 561 | 562 | result_model_args = result["model_args"] 563 | expected_model_args = expected["model_args"] 564 | assert result_model_args == expected_model_args 565 | 566 | result_other = { 567 | key: result[key] for key in result if key not in {"model_args", "data"} 568 | } 569 | expected_other = { 570 | key: expected[key] for key in expected if key not in {"model_args", "data"} 571 | } 572 | assert result_other == expected_other 573 | 574 | @staticmethod 575 | def test_pre_period_in_conflict_w_post_period(): 576 | data = pd.DataFrame(np.random.randn(20, 2), columns=["x1", "x2"]) 577 | causal_impact = CausalImpact(data, [0, 10], [9, 20], {}) 578 | 579 | with pytest.raises(ValueError) as excinfo: 580 | causal_impact._format_input( 581 | causal_impact.params["data"], 582 | causal_impact.params["pre_period"], 583 | 
causal_impact.params["post_period"], 584 | causal_impact.params["model_args"], 585 | None, 586 | None, 587 | causal_impact.params["alpha"], 588 | ) 589 | assert str(excinfo.value) == ( 590 | "post period must start at least 1 observation after the end of " 591 | "the pre_period" 592 | ) 593 | 594 | causal_impact = CausalImpact(data, [0, 10], [11, 9], {}) 595 | with pytest.raises(ValueError) as excinfo: 596 | causal_impact._format_input( 597 | causal_impact.params["data"], 598 | causal_impact.params["pre_period"], 599 | causal_impact.params["post_period"], 600 | causal_impact.params["model_args"], 601 | None, 602 | None, 603 | causal_impact.params["alpha"], 604 | ) 605 | assert str(excinfo.value) == ( 606 | "post_period[1] must not be earlier than post_period[0]" 607 | ) 608 | 609 | causal_impact = CausalImpact(data, [0, 10], [11, 9], {}) 610 | with pytest.raises(ValueError) as excinfo: 611 | causal_impact._format_input( 612 | causal_impact.params["data"], 613 | causal_impact.params["pre_period"], 614 | causal_impact.params["post_period"], 615 | causal_impact.params["model_args"], 616 | None, 617 | None, 618 | causal_impact.params["alpha"], 619 | ) 620 | assert str(excinfo.value) == ( 621 | "post_period[1] must not be earlier than post_period[0]" 622 | ) 623 | 624 | 625 | class TestRunWithData: 626 | @staticmethod 627 | def test_missing_input(): 628 | with pytest.raises(SyntaxError): 629 | impact = CausalImpact() 630 | impact.run() 631 | 632 | @staticmethod 633 | def test_unlabelled_pandas_series(expected_columns, pre_period, post_period): 634 | model_args = {"niter": 123, "standardize_data": False} 635 | alpha = 0.05 636 | data = pd.DataFrame(np.random.randn(200, 3)) 637 | causal_impact = CausalImpact( 638 | data.values, pre_period, post_period, model_args, None, None, alpha, "MLE" 639 | ) 640 | 641 | causal_impact.run() 642 | actual_columns = list(causal_impact.inferences.columns) 643 | assert actual_columns == expected_columns 644 | 645 | @staticmethod 646 | def test_other_formats(expected_columns, pre_period, post_period): 647 | # Test other data formats 648 | model_args = {"niter": 100, "standardize_data": True} 649 | 650 | # labelled dataframe 651 | data = pd.DataFrame(np.random.randn(200, 3), columns=["a", "b", "c"]) 652 | impact = CausalImpact(data, pre_period, post_period, model_args) 653 | impact.run() 654 | actual_columns = list(impact.inferences.columns) 655 | assert actual_columns == expected_columns 656 | 657 | # numpy array 658 | data = np.random.randn(200, 3) 659 | impact = CausalImpact(data, pre_period, post_period, model_args) 660 | impact.run() 661 | actual_columns = list(impact.inferences.columns) 662 | assert actual_columns == expected_columns 663 | 664 | # list of lists 665 | data = np.random.randn(200, 2).tolist() 666 | impact = CausalImpact(data, pre_period, post_period, model_args) 667 | impact.run() 668 | actual_columns = list(impact.inferences.columns) 669 | assert actual_columns == expected_columns 670 | 671 | @staticmethod 672 | def test_frame_w_no_exog(pre_period, post_period): 673 | data = np.random.randn(200) 674 | impact = CausalImpact(data, pre_period, post_period, {}) 675 | with pytest.raises(ValueError) as excinfo: 676 | impact.run() 677 | assert str(excinfo.value) == "data contains no exogenous variables" 678 | 679 | @staticmethod 680 | def test_missing_pre_period_data(data, pre_period, post_period): 681 | model_data = data.copy() 682 | model_data.iloc[3:5, 0] = np.nan 683 | impact = CausalImpact(model_data, pre_period, post_period) 684 | impact.run() 
685 | assert len(impact.inferences) == len(model_data) 686 | 687 | @staticmethod 688 | def test_pre_period_starts_after_beginning_of_data(data): 689 | pre_period = [3, 100] 690 | impact = CausalImpact(data, pre_period, [101, 199]) 691 | impact.run() 692 | np.testing.assert_array_equal(impact.inferences.response.values, data.y.values) 693 | assert np.all(pd.isnull(impact.inferences.iloc[0 : pre_period[0], 2:])) 694 | 695 | @staticmethod 696 | def test_post_period_finishes_before_end_of_data(data, pre_period): 697 | post_period = [101, 197] 698 | impact = CausalImpact(data, pre_period, post_period) 699 | impact.run() 700 | np.testing.assert_array_equal(impact.inferences.response.values, data.y.values) 701 | assert np.all(pd.isnull(impact.inferences.iloc[-2:, 2:])) 702 | 703 | @staticmethod 704 | def test_gap_between_pre_and_post_periods(data, pre_period): 705 | post_period = [120, 199] 706 | impact = CausalImpact(data, pre_period, post_period) 707 | impact.run() 708 | assert np.all( 709 | pd.isnull(impact.inferences.loc[101:119, impact.inferences.columns[2:]]) 710 | ) 711 | 712 | @staticmethod 713 | def test_late_start_early_finish_and_gap_between_periods(data): 714 | pre_period = [3, 80] 715 | post_period = [120, 197] 716 | impact = CausalImpact(data, pre_period, post_period) 717 | impact.run() 718 | assert np.all( 719 | pd.isnull(impact.inferences.loc[:2, impact.inferences.columns[2:]]) 720 | ) 721 | assert np.all( 722 | pd.isnull(impact.inferences.loc[81:119, impact.inferences.columns[2:]]) 723 | ) 724 | assert np.all( 725 | pd.isnull(impact.inferences.loc[198:, impact.inferences.columns[2:]]) 726 | ) 727 | 728 | @staticmethod 729 | def test_pre_period_lower_than_data_index_min(data): 730 | pre_period = [-1, 100] 731 | post_period = [101, 199] 732 | impact = CausalImpact(data, pre_period, post_period) 733 | impact.run() 734 | assert impact.params["pre_period"] == [0, 100] 735 | 736 | @staticmethod 737 | def test_post_period_bigger_than_data_index_max(data): 738 | pre_period = [0, 100] 739 | post_period = [101, 300] 740 | impact = CausalImpact(data, pre_period, post_period) 741 | impact.run() 742 | assert impact.params["post_period"] == [101, 199] 743 | 744 | @staticmethod 745 | def test_missing_values_in_pre_period_y(pre_period, post_period): 746 | """Test that all columns in the inferences frame except the 747 | point predictions have missing values at the time points where the 748 | response series has missing values.""" 749 | data = pd.DataFrame(np.random.randn(200, 3), columns=["y", "x1", "x2"]) 750 | data.iloc[95:100, 0] = np.nan 751 | 752 | impact = CausalImpact(data, pre_period, post_period) 753 | impact.run() 754 | 755 | predicted_cols = [ 756 | impact.inferences.columns.get_loc(col) 757 | for col in impact.inferences.columns 758 | if ("response" not in col and "point_effect" not in col) 759 | ] 760 | 761 | effect_cols = [ 762 | impact.inferences.columns.get_loc(col) 763 | for col in impact.inferences.columns 764 | if "point_effect" in col 765 | ] 766 | 767 | response_cols = [ 768 | impact.inferences.columns.get_loc(col) 769 | for col in impact.inferences.columns 770 | if "response" in col 771 | ] 772 | 773 | assert np.all(np.isnan(impact.inferences.iloc[95:100, response_cols])) 774 | assert np.all(np.isnan(impact.inferences.iloc[95:100, effect_cols])) 775 | TestCase().assertFalse( 776 | np.any(np.isnan(impact.inferences.iloc[95:100, predicted_cols])) 777 | ) 778 | TestCase().assertFalse(np.any(np.isnan(impact.inferences.iloc[:95, :]))) 779 | 
TestCase().assertFalse(np.any(np.isnan(impact.inferences.iloc[101:, :]))) 780 | 781 | 782 | class TestRunWithUCM: 783 | @staticmethod 784 | def test_regular_run(expected_columns, impact_ucm): 785 | impact_ucm.run() 786 | actual_columns = list(impact_ucm.inferences.columns) 787 | assert actual_columns == expected_columns 788 | 789 | 790 | class TestSummary: 791 | @staticmethod 792 | def test_summary(inference_input): 793 | inferences_df = pd.DataFrame(inference_input) 794 | causal = CausalImpact() 795 | 796 | params = {"alpha": 0.05, "post_period": [2, 4]} 797 | 798 | causal.params = params 799 | causal.inferences = inferences_df 800 | 801 | expected = [ 802 | [3, 7], 803 | [3, 7], 804 | [[3, 3], [7, 7]], 805 | [" ", " "], 806 | [0, 0], 807 | [[0, 0], [0, 0]], 808 | [" ", " "], 809 | ["-2.8%", "-2.8%"], 810 | [["0.0%", "-11.1%"], ["0.0%", "-11.1%"]], 811 | [" ", " "], 812 | ["0.0%", " "], 813 | ["100.0%", " "], 814 | ] 815 | 816 | expected = pd.DataFrame( 817 | expected, 818 | columns=["Average", "Cumulative"], 819 | index=[ 820 | "Actual", 821 | "Predicted", 822 | "95% CI", 823 | " ", 824 | "Absolute Effect", 825 | "95% CI", 826 | " ", 827 | "Relative Effect", 828 | "95% CI", 829 | " ", 830 | "P-value", 831 | "Prob. of Causal Effect", 832 | ], 833 | ) 834 | 835 | tmpdir = mkdtemp() 836 | tmp_expected = "tmp_expected" 837 | tmp_result = "tmp_test_summary" 838 | 839 | result_file = os.path.join(tmpdir, tmp_result) 840 | expected_file = os.path.join(tmpdir, tmp_expected) 841 | 842 | expected.to_csv(expected_file) 843 | expected_str = open(expected_file).read() 844 | 845 | causal.summary(path=result_file) 846 | 847 | result = open(result_file).read() 848 | assert result == expected_str 849 | 850 | @staticmethod 851 | def test_summary_w_report_output( 852 | monkeypatch, inference_input, summary_report_filename 853 | ): 854 | inferences_df = pd.DataFrame(inference_input) 855 | causal = CausalImpact() 856 | 857 | params = {"alpha": 0.05, "post_period": [2, 4]} 858 | 859 | causal.params = params 860 | causal.inferences = inferences_df 861 | 862 | dedent_mock = mock.Mock() 863 | 864 | expected = open(summary_report_filename).read() 865 | expected = re.sub(r"\s+", " ", expected) 866 | expected = expected.strip() 867 | 868 | tmpdir = mkdtemp() 869 | tmp_file = os.path.join(tmpdir, "summary_test") 870 | 871 | def dedent_side_effect(msg): 872 | with open(tmp_file, "a") as file_obj: 873 | msg = re.sub(r"\s+", " ", msg) 874 | msg = msg.strip() 875 | file_obj.write(msg) 876 | return msg 877 | 878 | dedent_mock.side_effect = dedent_side_effect 879 | monkeypatch.setattr("textwrap.dedent", dedent_mock) 880 | 881 | causal.summary(output="report") 882 | result_str = open(tmp_file, "r").read() 883 | assert result_str == expected 884 | 885 | @staticmethod 886 | def test_summary_wrong_argument_raises(inference_input): 887 | inferences_df = pd.DataFrame(inference_input) 888 | causal = CausalImpact() 889 | 890 | params = {"alpha": 0.05, "post_period": [2, 4]} 891 | 892 | causal.params = params 893 | causal.inferences = inferences_df 894 | 895 | with pytest.raises(ValueError): 896 | causal.summary(output="wrong_argument") 897 | 898 | 899 | class TestPlot: 900 | # @patch('causalimpact.data') 901 | @staticmethod 902 | def test_plot(monkeypatch): 903 | 904 | params = {"alpha": 0.05, "post_period": [2, 4], "pre_period": [0, 1]} 905 | inferences_mock = mock.MagicMock() 906 | 907 | class EnhancedDict: 908 | @property 909 | def index(self): 910 | return [0, 1] 911 | 912 | @property 913 | def response(self): 914 | return "y 
obs" 915 | 916 | @property 917 | def point_pred(self): 918 | return "points predicted" 919 | 920 | @property 921 | def point_pred_lower(self): 922 | return "lower predictions" 923 | 924 | @property 925 | def point_pred_upper(self): 926 | return "upper predictions" 927 | 928 | @property 929 | def point_effect(self): 930 | return "lift" 931 | 932 | @property 933 | def point_effect_lower(self): 934 | return "point effect lower" 935 | 936 | @property 937 | def point_effect_upper(self): 938 | return "point effect upper" 939 | 940 | @property 941 | def cum_effect(self): 942 | return "cum effect" 943 | 944 | @property 945 | def cum_effect_lower(self): 946 | return "cum effect lower" 947 | 948 | @property 949 | def cum_effect_upper(self): 950 | return "cum effect upper" 951 | 952 | def getitem(name): 953 | return EnhancedDict() 954 | 955 | inferences_mock.iloc.__getitem__.side_effect = getitem 956 | 957 | class Data: 958 | @property 959 | def index(self): 960 | return "index" 961 | 962 | @property 963 | def shape(self): 964 | return [(1, 2)] 965 | 966 | data_mock = Data() 967 | 968 | plot_mock = mock.Mock() 969 | np_zeros_mock = mock.Mock() 970 | np_zeros_mock.side_effect = lambda x: [0, 0] 971 | 972 | get_lib_mock = mock.Mock(return_value=plot_mock) 973 | monkeypatch.setattr("causalimpact.analysis.get_matplotlib", get_lib_mock) 974 | 975 | monkeypatch.setattr("numpy.zeros", np_zeros_mock) 976 | 977 | causal = CausalImpact() 978 | causal.params = params 979 | causal.inferences = inferences_mock 980 | causal.data = data_mock 981 | 982 | causal.plot(panels=["original", "pointwise", "cumulative"]) 983 | causal.plot(panels=["pointwise", "cumulative"]) 984 | 985 | causal.plot(panels=["original"]) 986 | plot_mock.plot.assert_any_call("y obs", "k", label="endog", linewidth=2) 987 | plot_mock.plot.assert_any_call( 988 | "points predicted", "r--", label="model", linewidth=2 989 | ) 990 | 991 | plot_mock.fill_between.assert_any_call( 992 | [0, 1], 993 | "lower predictions", 994 | "upper predictions", 995 | facecolor="gray", 996 | interpolate=True, 997 | alpha=0.25, 998 | ) 999 | 1000 | causal.plot(panels=["pointwise"]) 1001 | 1002 | plot_mock.plot.assert_any_call("lift", "r--", linewidth=2) 1003 | plot_mock.plot.assert_any_call("index", [0, 0], "g-", linewidth=2) 1004 | 1005 | causal.plot(panels=["cumulative"]) 1006 | 1007 | plot_mock.plot.assert_any_call([0, 1], "cum effect", "r--", linewidth=2) 1008 | plot_mock.plot.assert_any_call("index", [0, 0], "g-", linewidth=2) 1009 | -------------------------------------------------------------------------------- /tests/test_inferences.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for inferences module""" 2 | 3 | import pytest 4 | import numpy as np 5 | import pandas as pd 6 | from pandas.testing import assert_series_equal 7 | from statsmodels.tsa.arima_process import ArmaProcess 8 | 9 | import causalimpact 10 | from unittest.mock import Mock 11 | 12 | from causalimpact.model import ModelResults 13 | 14 | compile_inferences = causalimpact.inferences.compile_inferences 15 | np.random.seed(1) 16 | 17 | 18 | @pytest.fixture 19 | def data(): 20 | ar = np.r_[1, 0.9] 21 | ma = np.array([1]) 22 | arma_process = ArmaProcess(ar, ma) 23 | 24 | X = 1 + arma_process.generate_sample(nsample=100) 25 | X = X.reshape(-1, 1) 26 | y = 1.2 * X + np.random.normal(size=(100, 1)) 27 | data = np.concatenate((y, X), axis=1) 28 | data = pd.DataFrame(data) 29 | return data 30 | 31 | 32 | @pytest.fixture 33 | def pre_period(): 34 | 
return [0, 49] 35 | 36 | 37 | @pytest.fixture 38 | def post_period(): 39 | return [50, 100] 40 | 41 | 42 | @pytest.fixture 43 | def trained_model(post_period): 44 | trained_model = Mock(spec=ModelResults) 45 | trained_model.model_nobs = 100 46 | # the mocked methods on these two classes 47 | # only work because the pre-period and 48 | # post-period are both of length 50 49 | 50 | class PredictionResultsMock: 51 | def __init__(self, n): 52 | self.predicted_mean = np.ones(50) 53 | 54 | def conf_int(self, alpha=None): 55 | return np.ones(100).reshape(50, 2) 56 | 57 | class ForecastResultsMock: 58 | def __init__(self): 59 | self.predicted_mean = np.ones(50) 60 | 61 | def conf_int(self, alpha=None): 62 | return np.ones(100).reshape(50, 2) 63 | 64 | def get_prediction_mock(start=None, end=None): 65 | if start is not None or end is not None: 66 | return PredictionResultsMock(n=50) 67 | else: 68 | return PredictionResultsMock(n=100) 69 | 70 | def get_forecast_mock(df_post, alpha=None): 71 | return ForecastResultsMock() 72 | 73 | trained_model.get_prediction = get_prediction_mock 74 | trained_model.get_forecast = get_forecast_mock 75 | return trained_model 76 | 77 | 78 | def test_compile_inferences_w_data(data, pre_period, post_period, trained_model): 79 | 80 | df_pre = data.loc[pre_period[0] : pre_period[1], :] 81 | df_post = data.loc[post_period[0] : post_period[1], :] 82 | 83 | post_period_response = None 84 | alpha = 0.05 85 | orig_std_params = (0.0, 1.0) 86 | 87 | estimation = "MLE" 88 | 89 | inferences = compile_inferences( 90 | trained_model, 91 | data, 92 | df_pre, 93 | df_post, 94 | post_period_response, 95 | alpha, 96 | orig_std_params, 97 | estimation, 98 | ) 99 | 100 | expected_response = pd.Series(data.iloc[:, 0], name="response") 101 | assert_series_equal(expected_response, inferences["series"]["response"]) 102 | 103 | expected_cumsum = pd.Series(np.cumsum(expected_response), name="cum_response") 104 | 105 | assert_series_equal(expected_cumsum, inferences["series"]["cum_response"]) 106 | 107 | predictor = trained_model.get_prediction() 108 | forecaster = trained_model.get_forecast(df_post, alpha=alpha) 109 | 110 | pre_pred = predictor.predicted_mean 111 | post_pred = forecaster.predicted_mean 112 | 113 | point_pred = np.concatenate([pre_pred, post_pred]) 114 | 115 | expected_point_pred = pd.Series(point_pred, name="point_pred") 116 | assert_series_equal(expected_point_pred, inferences["series"]["point_pred"]) 117 | 118 | pre_ci = pd.DataFrame(predictor.conf_int(alpha=alpha)) 119 | pre_ci.index = df_pre.index 120 | post_ci = pd.DataFrame(forecaster.conf_int(alpha=alpha)) 121 | post_ci.index = df_post.index 122 | 123 | ci = pd.concat([pre_ci, post_ci]) 124 | 125 | expected_pred_upper = ci.iloc[:, 1] 126 | expected_pred_upper = expected_pred_upper.rename("point_pred_upper") 127 | expected_pred_lower = ci.iloc[:, 0] 128 | expected_pred_lower = expected_pred_lower.rename("point_pred_lower") 129 | 130 | assert_series_equal(expected_pred_upper, inferences["series"]["point_pred_upper"]) 131 | assert_series_equal(expected_pred_lower, inferences["series"]["point_pred_lower"]) 132 | 133 | expected_cum_pred = pd.Series(np.cumsum(point_pred), name="cum_pred") 134 | assert_series_equal(expected_cum_pred, inferences["series"]["cum_pred"]) 135 | 136 | expected_cum_pred_lower = pd.Series( 137 | np.cumsum(expected_pred_lower), name="cum_pred_lower" 138 | ) 139 | assert_series_equal(expected_cum_pred_lower, inferences["series"]["cum_pred_lower"]) 140 | 141 | expected_cum_pred_upper = pd.Series( 142 | 
np.cumsum(expected_pred_upper), name="cum_pred_upper" 143 | ) 144 | assert_series_equal(expected_cum_pred_upper, inferences["series"]["cum_pred_upper"]) 145 | 146 | expected_point_effect = pd.Series( 147 | expected_response - expected_point_pred, name="point_effect" 148 | ) 149 | assert_series_equal(expected_point_effect, inferences["series"]["point_effect"]) 150 | 151 | expected_point_effect_lower = pd.Series( 152 | expected_response - expected_pred_lower, name="point_effect_lower" 153 | ) 154 | assert_series_equal( 155 | expected_point_effect_lower, inferences["series"]["point_effect_lower"] 156 | ) 157 | 158 | expected_point_effect_upper = pd.Series( 159 | expected_response - expected_pred_upper, name="point_effect_upper" 160 | ) 161 | assert_series_equal( 162 | expected_point_effect_upper, inferences["series"]["point_effect_upper"] 163 | ) 164 | 165 | expected_cum_effect = pd.Series( 166 | np.concatenate( 167 | ( 168 | np.zeros(len(df_pre)), 169 | np.cumsum(expected_point_effect.iloc[len(df_pre) :]), 170 | ) 171 | ), 172 | name="cum_effect", 173 | ) 174 | assert_series_equal(expected_cum_effect, inferences["series"]["cum_effect"]) 175 | 176 | expected_cum_effect_lower = pd.Series( 177 | np.concatenate( 178 | ( 179 | np.zeros(len(df_pre)), 180 | np.cumsum(expected_point_effect_lower.iloc[len(df_pre) :]), 181 | ) 182 | ), 183 | name="cum_effect_lower", 184 | ) 185 | assert_series_equal( 186 | expected_cum_effect_lower, inferences["series"]["cum_effect_lower"] 187 | ) 188 | 189 | expected_cum_effect_upper = pd.Series( 190 | np.concatenate( 191 | ( 192 | np.zeros(len(df_pre)), 193 | np.cumsum(expected_point_effect_upper.iloc[len(df_pre) :]), 194 | ) 195 | ), 196 | name="cum_effect_upper", 197 | ) 198 | assert_series_equal( 199 | expected_cum_effect_upper, inferences["series"]["cum_effect_upper"] 200 | ) 201 | 202 | 203 | def test_compile_inferences_w_post_period_response( 204 | data, pre_period, post_period, trained_model 205 | ): 206 | 207 | df_pre = data.loc[pre_period[0] : pre_period[1], :] 208 | df_post = data.loc[post_period[0] : post_period[1], :] 209 | 210 | post_period_response = df_post.loc[post_period[0] : post_period[1]] 211 | 212 | X = df_post.iloc[:, 1:] 213 | y = X.copy() 214 | y[:] = np.nan 215 | 216 | df_post = pd.DataFrame(np.concatenate([y, X], axis=1)) 217 | data_index = data.index 218 | data = pd.concat([df_pre, df_post], axis=0) 219 | data.index = data_index 220 | 221 | alpha = 0.05 222 | orig_std_params = (0.0, 1.0) 223 | estimation = "MLE" 224 | 225 | inferences = compile_inferences( 226 | trained_model, 227 | data, 228 | df_pre, 229 | None, 230 | post_period_response, 231 | alpha, 232 | orig_std_params, 233 | estimation, 234 | ) 235 | 236 | expected_response = pd.Series(data.iloc[:, 0], name="response") 237 | assert_series_equal(expected_response, inferences["series"]["response"]) 238 | 239 | expected_cumsum = pd.Series(np.cumsum(expected_response), name="cum_response") 240 | 241 | assert_series_equal(expected_cumsum, inferences["series"]["cum_response"]) 242 | 243 | predictor = trained_model.get_prediction(end=len(df_pre) - 1) 244 | forecaster = trained_model.get_prediction(start=len(df_pre)) 245 | 246 | pre_pred = predictor.predicted_mean 247 | post_pred = forecaster.predicted_mean 248 | 249 | point_pred = np.concatenate([pre_pred, post_pred]) 250 | 251 | expected_point_pred = pd.Series(point_pred, name="point_pred") 252 | assert_series_equal(expected_point_pred, inferences["series"]["point_pred"]) 253 | 254 | pre_ci = pd.DataFrame(predictor.conf_int(alpha=alpha)) 
255 | pre_ci.index = df_pre.index 256 | post_ci = pd.DataFrame(forecaster.conf_int(alpha=alpha)) 257 | post_ci.index = df_post.index 258 | 259 | ci = pd.concat([pre_ci, post_ci]) 260 | 261 | expected_pred_upper = ci.iloc[:, 1] 262 | expected_pred_upper = expected_pred_upper.rename("point_pred_upper") 263 | expected_pred_upper.index = data.index 264 | 265 | expected_pred_lower = ci.iloc[:, 0] 266 | expected_pred_lower = expected_pred_lower.rename("point_pred_lower") 267 | expected_pred_lower.index = data.index 268 | 269 | assert_series_equal(expected_pred_upper, inferences["series"]["point_pred_upper"]) 270 | assert_series_equal(expected_pred_lower, inferences["series"]["point_pred_lower"]) 271 | 272 | expected_cum_pred = pd.Series(np.cumsum(point_pred), name="cum_pred") 273 | assert_series_equal(expected_cum_pred, inferences["series"]["cum_pred"]) 274 | 275 | expected_cum_pred_lower = pd.Series( 276 | np.cumsum(expected_pred_lower), name="cum_pred_lower" 277 | ) 278 | assert_series_equal(expected_cum_pred_lower, inferences["series"]["cum_pred_lower"]) 279 | 280 | expected_cum_pred_upper = pd.Series( 281 | np.cumsum(expected_pred_upper), name="cum_pred_upper" 282 | ) 283 | assert_series_equal(expected_cum_pred_upper, inferences["series"]["cum_pred_upper"]) 284 | 285 | expected_point_effect = pd.Series( 286 | expected_response - expected_point_pred, name="point_effect" 287 | ) 288 | assert_series_equal(expected_point_effect, inferences["series"]["point_effect"]) 289 | 290 | expected_point_effect_lower = pd.Series( 291 | expected_response - expected_pred_lower, name="point_effect_lower" 292 | ) 293 | assert_series_equal( 294 | expected_point_effect_lower, inferences["series"]["point_effect_lower"] 295 | ) 296 | 297 | expected_point_effect_upper = pd.Series( 298 | expected_response - expected_pred_upper, name="point_effect_upper" 299 | ) 300 | assert_series_equal( 301 | expected_point_effect_upper, inferences["series"]["point_effect_upper"] 302 | ) 303 | 304 | expected_cum_effect = pd.Series( 305 | np.concatenate( 306 | ( 307 | np.zeros(len(df_pre)), 308 | np.cumsum(expected_point_effect.iloc[len(df_pre) :]), 309 | ) 310 | ), 311 | name="cum_effect", 312 | ) 313 | assert_series_equal(expected_cum_effect, inferences["series"]["cum_effect"]) 314 | 315 | expected_cum_effect_lower = pd.Series( 316 | np.concatenate( 317 | ( 318 | np.zeros(len(df_pre)), 319 | np.cumsum(expected_point_effect_lower.iloc[len(df_pre) :]), 320 | ) 321 | ), 322 | name="cum_effect_lower", 323 | ) 324 | assert_series_equal( 325 | expected_cum_effect_lower, inferences["series"]["cum_effect_lower"] 326 | ) 327 | 328 | expected_cum_effect_upper = pd.Series( 329 | np.concatenate( 330 | ( 331 | np.zeros(len(df_pre)), 332 | np.cumsum(expected_point_effect_upper.iloc[len(df_pre) :]), 333 | ) 334 | ), 335 | name="cum_effect_upper", 336 | ) 337 | assert_series_equal( 338 | expected_cum_effect_upper, inferences["series"]["cum_effect_upper"] 339 | ) 340 | -------------------------------------------------------------------------------- /tests/test_misc.py: -------------------------------------------------------------------------------- 1 | """Tests for misc module.""" 2 | 3 | import unittest.mock as mock 4 | import numpy as np 5 | import pandas as pd 6 | from pandas.testing import assert_frame_equal 7 | from numpy.testing import assert_almost_equal 8 | import pytest 9 | 10 | import causalimpact 11 | 12 | 13 | standardize = causalimpact.misc.standardize_all_variables 14 | unstandardize = causalimpact.misc.unstandardize 15 | df_print = 
causalimpact.misc.df_print 16 | 17 | 18 | def test_basic_standardize(): 19 | pre_period = [0, 2] 20 | post_period = [3, 4] 21 | 22 | data = {"c1": [1, 4, 8, 9, 10], "c2": [4, 8, 12, 16, 20]} 23 | data = pd.DataFrame(data) 24 | 25 | result = standardize(data, pre_period, post_period) 26 | assert_almost_equal(np.zeros((2)), np.mean(result["data_pre"].values, axis=0)) 27 | 28 | assert_almost_equal(np.ones((2)), np.std(result["data_pre"].values, axis=0)) 29 | assert len(result["data_pre"]) == pre_period[-1] + 1 30 | 31 | 32 | def test_standardize_returns_expected_types(): 33 | pre_period = [0, 4] 34 | post_period = [5, 5] 35 | 36 | data = [-1, 0.1, 1, 2, np.nan, 3] 37 | data = pd.DataFrame(data) 38 | 39 | result = standardize(data, pre_period, post_period) 40 | 41 | assert isinstance(result, dict) 42 | assert set(result.keys()) == {"data_pre", "data_post", "orig_std_params"} 43 | 44 | assert len(result["data_pre"]) == pre_period[-1] + 1 45 | assert_frame_equal( 46 | unstandardize(result["data_pre"], result["orig_std_params"]), 47 | pd.DataFrame(data[:5]), 48 | ) 49 | 50 | 51 | def test_standardize_w_distinct_inputs(): 52 | test_data = [[1], [1, 1, 1], [1, np.nan, 3], pd.DataFrame([10, 20, 30])] 53 | 54 | test_data = [pd.DataFrame(data, dtype="float") for data in test_data] 55 | 56 | for data in test_data: 57 | result = standardize( 58 | data, 59 | pre_period=[0, len(data) + 1], 60 | post_period=[len(data) + 1, len(data) + 1], 61 | ) 62 | 63 | assert_frame_equal( 64 | unstandardize(result["data_pre"], result["orig_std_params"]), data 65 | ) 66 | 67 | 68 | def test_standardize_raises_w_bad_input(): 69 | with pytest.raises(ValueError): 70 | standardize("text", 1, 2) 71 | 72 | with pytest.raises(ValueError): 73 | standardize(pd.DataFrame([1, 2]), 1, 2) 74 | 75 | 76 | def test_unstandardize(): 77 | data = np.array([-1.16247639, -0.11624764, 1.27872403]) 78 | orig_std_params = (4.3333333, 2.8674417556) 79 | original_data = unstandardize(data, orig_std_params) 80 | 81 | assert_almost_equal(original_data.values, np.array([[1.0, 4.0, 8.0]]).T) 82 | 83 | 84 | def test_df_print(): 85 | data_mock = mock.Mock() 86 | df_print(data_mock) 87 | data_mock.assert_not_called() 88 | 89 | df_print(data_mock, path="path") 90 | data_mock.to_csv.assert_called_once_with("path") 91 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | """Unit Tests for model module""" 2 | 3 | import pytest 4 | import unittest.mock as mock 5 | import numpy as np 6 | from numpy.testing import assert_array_equal 7 | import pandas as pd 8 | 9 | import causalimpact 10 | from causalimpact.model import ModelResults 11 | from statsmodels.tsa.statespace.structural import UnobservedComponents 12 | 13 | observations_validate = causalimpact.model.observations_ill_conditioned 14 | construct_model = causalimpact.model.construct_model 15 | model_fit = causalimpact.model.model_fit 16 | 17 | 18 | def test_raises_when_y_is_None(): 19 | with pytest.raises(ValueError) as excinfo: 20 | observations_validate(None) 21 | assert str(excinfo.value) == "y cannot be None" 22 | 23 | 24 | def test_raises_when_y_has_len_1(): 25 | with pytest.raises(ValueError) as excinfo: 26 | observations_validate([1]) 27 | assert str(excinfo.value) == "y must have len > 1" 28 | 29 | 30 | def test_raises_when_y_is_all_nan(): 31 | with pytest.raises(ValueError) as excinfo: 32 | observations_validate([np.nan, np.nan]) 33 | assert 
str(excinfo.value) == ( 34 | "Aborting inference due to input series " "being all null." 35 | ) 36 | 37 | 38 | def test_raises_when_y_have_just_2_values(): 39 | with pytest.raises(ValueError) as excinfo: 40 | observations_validate(pd.DataFrame([1, 2])) 41 | assert str(excinfo.value) == ( 42 | "Aborting inference due to fewer than 3 " "nonnull values in input." 43 | ) 44 | 45 | 46 | def test_raises_when_y_is_constant(): 47 | with pytest.raises(ValueError) as excinfo: 48 | observations_validate(pd.Series([1, 1, 1, 1, 1])) 49 | assert str(excinfo.value) == ( 50 | "Aborting inference due to input series " "being constant" 51 | ) 52 | 53 | 54 | def test_model_constructor(): 55 | data = pd.DataFrame(np.random.randn(200, 2)) 56 | model = construct_model(data) 57 | assert_array_equal(model.data.endog, data.iloc[:, 0].values) 58 | assert model.irregular 59 | assert model.k_exog == data.shape[1] - 1 60 | assert model.level 61 | assert_array_equal( 62 | model.exog, data.iloc[:, 1].values.reshape(-1, data.shape[1] - 1) 63 | ) 64 | 65 | 66 | def test_model_constructor_w_just_endog(): 67 | data = pd.DataFrame(np.random.randn(200, 1)) 68 | model = construct_model(data) 69 | assert_array_equal(model.data.endog, data.iloc[:, 0].values) 70 | assert model.irregular 71 | assert model.k_exog == data.shape[1] - 1 72 | assert model.level 73 | assert not model.exog 74 | 75 | 76 | def test_model_fit_with_mle(): 77 | model_mock = mock.Mock() 78 | 79 | model_results = model_fit(model_mock, "MLE", {"niter": 50}) 80 | model_mock.fit.assert_called_once_with(maxiter=50) 81 | assert isinstance(model_results, ModelResults) 82 | 83 | 84 | def test_model_fit_with_pymc(): 85 | model_mock = UnobservedComponents([1, 2, 3, 4], level="llevel") 86 | model_results = model_fit(model_mock, "pymc", {"ndraws": 2, "nburn": 1}) 87 | assert isinstance(model_results, ModelResults) 88 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox configuration file 2 | # Read more under https://tox.wiki/ 3 | # THIS SCRIPT IS SUPPOSED TO BE AN EXAMPLE. MODIFY IT ACCORDING TO YOUR NEEDS! 4 | 5 | [tox] 6 | minversion = 3.24 7 | envlist = default 8 | isolated_build = True 9 | 10 | [flake8] 11 | max-line-length = 88 12 | extend-ignore = E203 13 | 14 | [testenv] 15 | description = Invoke pytest to run automated tests 16 | deps = 17 | numpy 18 | pandas 19 | statsmodels 20 | pymc 21 | pytest 22 | pytest-cov 23 | mock 24 | setenv = 25 | TOXINIDIR = {toxinidir} 26 | passenv = 27 | HOME 28 | extras = 29 | testing 30 | commands = 31 | pytest {posargs} 32 | 33 | 34 | # # To run `tox -e lint` you need to make sure you have a 35 | # # `.pre-commit-config.yaml` file. 
See https://pre-commit.com 36 | [testenv:lint] 37 | description = Perform static analysis and style checks 38 | skip_install = True 39 | deps = pre-commit 40 | passenv = 41 | HOMEPATH 42 | PROGRAMDATA 43 | commands = 44 | pre-commit run --all-files {posargs:--show-diff-on-failure} 45 | 46 | 47 | [testenv:{build,clean}] 48 | description = 49 | build: Build the package in isolation according to PEP517, see https://github.com/pypa/build 50 | clean: Remove old distribution files and temporary build artifacts (./build and ./dist) 51 | # https://setuptools.pypa.io/en/stable/build_meta.html#how-to-use-it 52 | skip_install = True 53 | changedir = {toxinidir} 54 | deps = 55 | build: build[virtualenv] 56 | commands = 57 | clean: python -c 'import shutil; [shutil.rmtree(p, True) for p in ("build", "dist", "docs/_build")]' 58 | clean: python -c 'import pathlib, shutil; [shutil.rmtree(p, True) for p in pathlib.Path("src").glob("*.egg-info")]' 59 | build: python -m build {posargs} 60 | 61 | 62 | [testenv:{docs,doctests,linkcheck}] 63 | description = 64 | docs: Invoke sphinx-build to build the docs 65 | doctests: Invoke sphinx-build to run doctests 66 | linkcheck: Check for broken links in the documentation 67 | setenv = 68 | DOCSDIR = {toxinidir}/docs 69 | BUILDDIR = {toxinidir}/docs/_build 70 | docs: BUILD = html 71 | doctests: BUILD = doctest 72 | linkcheck: BUILD = linkcheck 73 | deps = 74 | -r {toxinidir}/docs/requirements.txt 75 | # ^ requirements.txt shared with Read The Docs 76 | commands = 77 | sphinx-build --color -b {env:BUILD} -d "{env:BUILDDIR}/doctrees" "{env:DOCSDIR}" "{env:BUILDDIR}/{env:BUILD}" {posargs} 78 | 79 | 80 | [testenv:publish] 81 | description = 82 | Publish the package you have been developing to a package index server. 83 | By default, it uses testpypi. If you really want to publish your package 84 | to be publicly accessible in PyPI, use the `-- --repository pypi` option. 85 | skip_install = True 86 | changedir = {toxinidir} 87 | passenv = 88 | # See: https://twine.readthedocs.io/en/latest/ 89 | TWINE_USERNAME 90 | TWINE_PASSWORD 91 | TWINE_REPOSITORY 92 | deps = twine 93 | commands = 94 | python -m twine check dist/* 95 | python -m twine upload {posargs:--repository {env:TWINE_REPOSITORY:testpypi}} dist/* 96 | --------------------------------------------------------------------------------
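The environments defined above are invoked by name. A few illustrative invocations, assuming tox (>= 3.24, per the minversion setting) is installed; the publish example follows the `-- --repository pypi` note in its description:

```bash
tox                                   # default environment: run the pytest suite
tox -e lint                           # static analysis and style checks via pre-commit
tox -e clean,build                    # remove old artifacts, then build the package
tox -e docs                           # build the Sphinx documentation
tox -e publish -- --repository pypi   # upload to PyPI instead of the default testpypi
```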