├── .github ├── .stale.yml ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── black.yml │ ├── python-publish.yml │ ├── python-test.yaml │ ├── test-build-from-source.yml │ └── test-pypi-install.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── ANTITRUST.md ├── CHARTER.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── GOVERNANCE.md ├── LICENSE ├── MAINTAINERS.md ├── MANIFEST.in ├── Makefile ├── README.md ├── SECURITY.md ├── STEERING_COMMITTEE.md ├── TRADEMARKS.md ├── causalml ├── __init__.py ├── dataset │ ├── __init__.py │ ├── classification.py │ ├── regression.py │ └── synthetic.py ├── feature_selection │ ├── __init__.py │ └── filters.py ├── features.py ├── inference │ ├── __init__.py │ ├── iv │ │ ├── __init__.py │ │ ├── drivlearner.py │ │ └── iv_regression.py │ ├── meta │ │ ├── __init__.py │ │ ├── base.py │ │ ├── drlearner.py │ │ ├── explainer.py │ │ ├── rlearner.py │ │ ├── slearner.py │ │ ├── tlearner.py │ │ ├── tmle.py │ │ ├── utils.py │ │ └── xlearner.py │ ├── tf │ │ ├── __init__.py │ │ ├── dragonnet.py │ │ └── utils.py │ ├── torch │ │ ├── __init__.py │ │ └── cevae.py │ └── tree │ │ ├── __init__.py │ │ ├── _tree │ │ ├── __init__.py │ │ ├── _classes.py │ │ ├── _criterion.pxd │ │ ├── _criterion.pyx │ │ ├── _splitter.pxd │ │ ├── _splitter.pyx │ │ ├── _tree.pxd │ │ ├── _tree.pyx │ │ ├── _typedefs.pxd │ │ ├── _typedefs.pyx │ │ ├── _utils.pxd │ │ └── _utils.pyx │ │ ├── causal │ │ ├── __init__.py │ │ ├── _builder.pxd │ │ ├── _builder.pyx │ │ ├── _criterion.pxd │ │ ├── _criterion.pyx │ │ ├── _tree.py │ │ ├── causalforest.py │ │ └── causaltree.py │ │ ├── plot.py │ │ ├── uplift.pyx │ │ └── utils.py ├── match.py ├── metrics │ ├── __init__.py │ ├── classification.py │ ├── const.py │ ├── regression.py │ ├── sensitivity.py │ └── visualize.py ├── optimize │ ├── __init__.py │ ├── pns.py │ ├── policylearner.py │ ├── unit_selection.py │ ├── utils.py │ └── value_optimization.py └── propensity.py ├── docs ├── Makefile ├── _static │ └── img │ │ ├── auuc_table_vis.png │ │ ├── auuc_vis.png │ │ ├── counterfactual_value_optimization.png │ │ ├── logo │ │ ├── android-chrome-192x192.png │ │ ├── android-chrome-512x512.png │ │ ├── apple-touch-icon.png │ │ ├── causalml_logo.png │ │ ├── causalml_logo.svg │ │ ├── causalml_logo_square.png │ │ ├── causalml_logo_square_transparent.png │ │ ├── causalml_logo_transparent.png │ │ ├── favicon-16x16.png │ │ ├── favicon-32x32.png │ │ └── favicon.ico │ │ ├── meta_feature_imp_vis.png │ │ ├── meta_shap_dependence_vis.png │ │ ├── meta_shap_vis.png │ │ ├── sensitivity_selection_bias_r2.png │ │ ├── shap_vis.png │ │ ├── synthetic_dgp_bar_plot_multiple.png │ │ ├── synthetic_dgp_scatter_plot.png │ │ ├── synthetic_dgp_scatter_plot_multiple.png │ │ ├── uplift_tree_feature_imp_vis.png │ │ └── uplift_tree_vis.png ├── about.rst ├── causalml.rst ├── changelog.rst ├── conf.py ├── environment-py39-rtd.yml ├── examples.rst ├── examples │ ├── benchmark_simulation_studies.ipynb │ ├── binary_policy_learner_example.ipynb │ ├── calibration.ipynb │ ├── causal_trees_interpretation.ipynb │ ├── causal_trees_with_synthetic_data.ipynb │ ├── cevae_example.ipynb │ ├── counterfactual_unit_selection.ipynb │ ├── counterfactual_value_optimization.ipynb │ ├── data │ │ ├── card.csv │ │ ├── ihdp_npci_1.csv │ │ ├── ihdp_npci_2.csv │ │ ├── ihdp_npci_3.csv │ │ ├── ihdp_npci_4.csv │ │ ├── ihdp_npci_5.csv │ │ ├── ihdp_npci_6.csv │ │ ├── ihdp_npci_7.csv │ │ ├── ihdp_npci_8.csv │ │ └── ihdp_npci_9.csv │ ├── 
dr_learner_with_synthetic_data.ipynb │ ├── dragonnet_example.ipynb │ ├── feature_interpretations_example.ipynb │ ├── feature_selection.ipynb │ ├── iv_nlsym_synthetic_data.ipynb │ ├── logistic_regression_based_data_generation_for_uplift_classification.ipynb │ ├── meta_learners_with_synthetic_data.ipynb │ ├── meta_learners_with_synthetic_data_multiple_treatment.ipynb │ ├── necessary_and_sufficient.ipynb │ ├── qini_curves_for_costly_treatment_arms.ipynb │ ├── sensitivity_example_with_synthetic_data.ipynb │ ├── uplift_tree_visualization.ipynb │ ├── uplift_trees_with_synthetic_data.ipynb │ └── validation_with_tmle.ipynb ├── index.rst ├── installation.rst ├── interpretation.rst ├── methodology.rst ├── quickstart.rst ├── references.rst ├── refs.bib ├── requirements.txt └── validation.rst ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── conftest.py ├── const.py ├── test_causal_trees.py ├── test_cevae.py ├── test_counterfactual_unit_selection.py ├── test_datasets.py ├── test_dragonnet.py ├── test_feature_selection.py ├── test_features.py ├── test_ivlearner.py ├── test_match.py ├── test_meta_learners.py ├── test_metrics.py ├── test_propensity.py ├── test_sensitivity.py ├── test_uplift_trees.py ├── test_utils.py ├── test_value_optimization.py └── test_visualize.py └── tox.ini /.github/.stale.yml: -------------------------------------------------------------------------------- 1 | # Number of days of inactivity before an issue becomes stale 2 | daysUntilStale: 60 3 | # Number of days of inactivity before a stale issue is closed 4 | daysUntilClose: 7 5 | # Issues with these labels will never be considered stale 6 | exemptLabels: 7 | - pinned 8 | - security 9 | # Label to use when marking an issue as stale 10 | staleLabel: wontfix 11 | # Comment to post when marking an issue as stale. Set to `false` to disable 12 | markComment: > 13 | This issue has been automatically marked as stale because it has not had 14 | recent activity. It will be closed if no further activity occurs. Thank you 15 | for your contributions. 16 | # Comment to post when closing a stale issue. Set to `false` to disable 17 | closeComment: false -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 16 | **Expected behavior** 17 | A clear and concise description of what you expected to happen. 18 | 19 | **Screenshots** 20 | If applicable, add screenshots to help explain your problem. 21 | 22 | **Environment (please complete the following information):** 23 | - OS: [e.g. macOS, Windows, Ubuntu] 24 | - Python Version: [e.g. 3.6, 3.7] 25 | - Versions of Major Dependencies (`pandas`, `scikit-learn`, `cython`): [e.g. `pandas==0.25`, `scikit-learn==0.22`, `cython==0.28`] 26 | 27 | **Additional context** 28 | Add any other context about the problem here. 
29 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: '' 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Proposed changes 2 | 3 | Describe the big picture of your changes here to communicate to the maintainers why we should accept this pull request. If it fixes a bug or resolves a feature request, be sure to link to that issue. 4 | 5 | ## Types of changes 6 | 7 | What types of changes does your code introduce to **CausalML**? 8 | _Put an `x` in the boxes that apply_ 9 | 10 | - [ ] Bugfix (non-breaking change which fixes an issue) 11 | - [ ] New feature (non-breaking change which adds functionality) 12 | - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) 13 | - [ ] Documentation Update (if none of the other choices apply) 14 | 15 | ## Checklist 16 | 17 | _Put an `x` in the boxes that apply. You can also fill these out after creating the PR. If you're unsure about any of them, don't hesitate to ask. We're here to help! This is simply a reminder of what we are going to look for before merging your code._ 18 | 19 | - [ ] I have read the [CONTRIBUTING](https://github.com/uber/causalml/blob/master/CONTRIBUTING.md) doc 20 | - [ ] I have signed the [CLA](https://cla-assistant.io/uber/causalml) 21 | - [ ] Lint and unit tests pass locally with my changes 22 | - [ ] I have added tests that prove my fix is effective or that my feature works 23 | - [ ] I have added necessary documentation (if appropriate) 24 | - [ ] Any dependent changes have been merged and published in downstream modules 25 | 26 | ## Further comments 27 | 28 | If this is a relatively large or complex change, kick off the discussion by explaining why you chose the solution you did and what alternatives you considered, etc. This PR template is adopted from [appium](https://github.com/appium/appium). 
29 | -------------------------------------------------------------------------------- /.github/workflows/black.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | lint: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: psf/black@stable -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | # Updated with cibuildwheel: https://learn.scientific-python.org/development/guides/gha-wheels/ 4 | 5 | # This workflow uses actions that are not certified by GitHub. 6 | # They are provided by a third-party and are governed by 7 | # separate terms of service, privacy policy, and support 8 | # documentation. 9 | 10 | name: Upload Python Package 11 | 12 | on: 13 | workflow_dispatch: 14 | release: 15 | types: 16 | - published 17 | 18 | jobs: 19 | make_sdist: 20 | name: Make SDist 21 | runs-on: ubuntu-latest 22 | steps: 23 | - uses: actions/checkout@v4 24 | with: 25 | fetch-depth: 0 # Optional, use if you use setuptools_scm 26 | submodules: false # Optional, use if you have submodules 27 | 28 | - name: Build SDist 29 | run: pipx run build --sdist 30 | 31 | - uses: actions/upload-artifact@v4 32 | with: 33 | name: cibw-sdist 34 | path: dist/*.tar.gz 35 | 36 | build_wheels: 37 | name: Wheel on ${{ matrix.os }} 38 | runs-on: ${{ matrix.os }} 39 | strategy: 40 | fail-fast: false 41 | matrix: 42 | os: [ubuntu-latest, windows-latest, macos-13, macos-14] 43 | 44 | steps: 45 | - uses: actions/checkout@v4 46 | with: 47 | fetch-depth: 0 48 | submodules: true 49 | 50 | - uses: pypa/cibuildwheel@v2.22 51 | 52 | - name: Upload wheels 53 | uses: actions/upload-artifact@v4 54 | with: 55 | name: cibw-wheels-${{ matrix.os }} 56 | path: wheelhouse/*.whl 57 | 58 | upload_all: 59 | needs: [build_wheels, make_sdist] 60 | runs-on: ubuntu-latest 61 | if: github.event_name == 'release' && github.event.action == 'published' 62 | steps: 63 | - uses: actions/download-artifact@v4 64 | with: 65 | pattern: cibw-* 66 | path: dist 67 | merge-multiple: true 68 | 69 | - uses: pypa/gh-action-pypi-publish@release/v1 70 | with: 71 | password: ${{ secrets.PYPI_API_TOKEN }} 72 | -------------------------------------------------------------------------------- /.github/workflows/python-test.yaml: -------------------------------------------------------------------------------- 1 | name: build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | # [macos-latest, macos-latest, windows-latest] 8 | runs-on: ubuntu-latest 9 | strategy: 10 | # You can use PyPy versions in python-version. 
11 | # For example, pypy3.10 12 | matrix: 13 | python-version: ["3.9", "3.10", "3.11", "3.12"] 14 | 15 | steps: 16 | - uses: actions/checkout@v4 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | allow-prereleases: true 22 | # You can test your matrix by printing the current Python version 23 | - name: Display Python version 24 | run: python --version 25 | - name: Install dependencies 26 | run: | 27 | sudo apt install graphviz 28 | pip install --upgrade pip 29 | pip install --upgrade setuptools 30 | python -m pip install -e ".[test]" 31 | - name: Test with pytest 32 | run: pytest -vs tests/ --cov causalml/ 33 | -------------------------------------------------------------------------------- /.github/workflows/test-build-from-source.yml: -------------------------------------------------------------------------------- 1 | name: Test build from source install 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'envs/*.yml' 7 | - '.github/workflows/test-build-from-source.yml' 8 | 9 | jobs: 10 | build: 11 | name: ${{ matrix.os }}${{ matrix.tf-label }}-py${{ matrix.python-version }} 12 | runs-on: ${{ matrix.os }} 13 | 14 | defaults: 15 | run: 16 | shell: bash -l {0} 17 | 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | os: [ubuntu-latest] 22 | python-version: ["3.9", "3.10", "3.11", "3.12"] 23 | tf-label: ['', '-tf'] 24 | include: 25 | - python-version: "3.9" 26 | python-version-nd: 39 27 | - python-version: "3.10" 28 | python-version-nd: 310 29 | - python-version: "3.11" 30 | python-version-nd: 311 31 | - python-version: "3.12" 32 | python-version-nd: 312 33 | - tf-label: '-tf' 34 | tf-label-pip: ',tf' 35 | 36 | steps: 37 | - name: checkout repository 38 | uses: actions/checkout@v2 39 | 40 | - name: Set up Python ${{ matrix.python-version }} 41 | uses: actions/setup-python@v4 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | 45 | - name: Display Python version 46 | run: python -c "import sys; print(sys.version)" 47 | 48 | - name: create environment 49 | uses: conda-incubator/setup-miniconda@v2 50 | with: 51 | activate-environment: causalml${{ matrix.tf-label }}-py${{ matrix.python-version-nd }} 52 | python-version: ${{ matrix.python-version }} 53 | channels: defaults 54 | 55 | - name: install cxx-compiler 56 | run: | 57 | conda install -c conda-forge cxx-compiler 58 | conda install python-graphviz 59 | conda install -c conda-forge xorg-libxrender 60 | conda install -c conda-forge libxcrypt 61 | 62 | - name: echo conda config 63 | run: | 64 | conda info 65 | conda list 66 | conda config --show-sources 67 | conda config --show 68 | 69 | - name: Build 70 | run: | 71 | pip install -U pip 72 | pip install -U setuptools 73 | python -m pip install -e ".[test${{ matrix.tf-label-pip}}]" 74 | 75 | - name: Test with pytest 76 | run: pytest -vs tests/ --cov causalml/ 77 | 78 | - name: echo conda env 79 | run: | 80 | conda env export 81 | 82 | 83 | -------------------------------------------------------------------------------- /.github/workflows/test-pypi-install.yml: -------------------------------------------------------------------------------- 1 | name: Test PyPI install 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 1 * *' 6 | 7 | jobs: 8 | build: 9 | name: ${{ matrix.os }}-py${{ matrix.python-version }} 10 | runs-on: ${{ matrix.os }} 11 | strategy: 12 | fail-fast: false 13 | matrix: 14 | os: [ubuntu-latest] 15 | python-version: ["3.9", "3.10", "3.11", "3.12"] 16 | include: 17 | - 
python-version: "3.9" 18 | python-version-nd: 39 19 | - python-version: "3.10" 20 | python-version-nd: 310 21 | - python-version: "3.11" 22 | python-version-nd: 311 23 | - python-version: "3.12" 24 | python-version-nd: 312 25 | 26 | steps: 27 | - uses: actions/checkout@v2 28 | - name: Set up Python ${{ matrix.python-version }} 29 | uses: actions/setup-python@v2 30 | with: 31 | python-version: ${{ matrix.python-version }} 32 | - name: Display Python version 33 | run: python -c "import sys; print(sys.version)" 34 | 35 | - name: create environment 36 | uses: conda-incubator/setup-miniconda@v2 37 | with: 38 | activate-environment: causalml-py${{ matrix.python-version-nd }} 39 | python-version: ${{ matrix.python-version }} 40 | 41 | - name: Install using pip 42 | run: | 43 | pip install causalml 44 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .eggs/ 2 | *.egg-info/ 3 | env_docs/ 4 | .idea 5 | .pytest_cache 6 | .vscode 7 | *.DS_Store 8 | __pycache__ 9 | build 10 | dist 11 | wheelhouse 12 | *.pyc 13 | _build/ 14 | .ipynb_checkpoints/ 15 | *.c 16 | *.cpp 17 | *.so 18 | .coverage* 19 | *.html 20 | *.prof 21 | .venv/ 22 | .python-version 23 | uv.lock -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v2.3.0 4 | hooks: 5 | - id: check-yaml 6 | - id: end-of-file-fixer 7 | - id: trailing-whitespace 8 | - repo: https://github.com/psf/black 9 | rev: 22.10.0 10 | hooks: 11 | - id: black 12 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Set the OS, Python version and other tools you might need 5 | build: 6 | os: ubuntu-22.04 7 | tools: 8 | python: "miniconda3-4.7" 9 | 10 | conda: 11 | environment: docs/environment-py39-rtd.yml 12 | 13 | python: 14 | install: 15 | - method: pip 16 | path: . 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | sphinx: 20 | configuration: docs/conf.py 21 | 22 | # Optionally build your docs in additional formats such as PDF and ePub 23 | formats: all 24 | 25 | # Optionally set the version of Python and requirements required to build your docs 26 | -------------------------------------------------------------------------------- /ANTITRUST.md: -------------------------------------------------------------------------------- 1 | # Antitrust Policy 2 | 3 | Participants acknowledge that they may compete with other participants in various lines of business and that it is therefore imperative that they and their respective representatives act in a manner that does not violate any applicable antitrust laws, competition laws, or associated regulations. This Policy does not restrict any participant from engaging in other similar projects. Each participant may design, develop, manufacture, acquire or market competitive deliverables, products, and services, and conduct its business, in whatever way it chooses. No participant is obligated to announce or market any products or services. 
Without limiting the generality of the foregoing, participants agree not to have any discussion relating to any product pricing, methods or channels of product distribution, contracts with third-parties, division or allocation of markets, geographic territories, or customers, or any other topic that relates in any way to limiting or lessening fair competition. 4 | 5 | --- 6 | Part of [MVG-0.1-beta](https://github.com/github/MVG/tree/v0.1-beta). 7 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 8 | -------------------------------------------------------------------------------- /CHARTER.md: -------------------------------------------------------------------------------- 1 | # Charter for the CausalML Organization 2 | 3 | This is the organizational charter for the CausalML Organization (the "Organization"). By adding their name to the [Steering Committee.md file](./STEERING_COMMITTEE.md), Steering Committee members agree as follows. 4 | 5 | ## 1. Mission 6 | 7 | CausalML is committed to democratizing causal machine learning through accessible, innovative, and well-documented open-source tools that empower data scientists, researchers, and organizations. At our core, we embrace inclusivity and foster a vibrant community where members exchange ideas, share knowledge, and collaboratively shape a future where CausalML drives advancements across diverse domains. 8 | 9 | ## 2. Steering Committee 10 | 11 | **2.1 Purpose**. The Steering Committee will be responsible for all technical oversight, project approval and oversight, policy oversight, and trademark management for the Organization. 12 | 13 | **2.2 Composition**. The Steering Committee voting members are listed in the steering-committee.md file in the repository. 14 | Voting members may be added or removed by no less than 3/4 affirmative vote of the Steering Committee. 15 | The Steering Committee will appoint a Chair responsible for organizing Steering Committee activity. 16 | 17 | ## 3. Voting 18 | 19 | **3.1. Decision Making**. The Steering Committee will strive for all decisions to be made by consensus. While explicit agreement of the entire Steering Committee is preferred, it is not required for consensus. Rather, the Steering Committee will determine consensus based on their good faith consideration of a number of factors, including the dominant view of the Steering Committee and nature of support and objections. The Steering Committee will document evidence of consensus in accordance with these requirements. If consensus cannot be reached, the Steering Committee will make the decision by a vote. 20 | 21 | **3.2. Voting**. The Steering Committee Chair will call a vote with reasonable notice to the Steering Committee, setting out a discussion period and a separate voting period. Any discussion may be conducted in person or electronically by text, voice, or video. The discussion will be open to the public. In any vote, each voting representative will have one vote. Except as specifically noted elsewhere in this Charter, decisions by vote require a simple majority vote of all voting members. 22 | 23 | ## 4. Termination of Membership 24 | 25 | In addition to the method set out in section 2.2, the membership of a Steering Committee member will terminate if any of the following occur: 26 | 27 | **4.1 Resignation**. Written notice of resignation to the Steering Committee. 28 | 29 | **4.2 Unreachable Member**. 
If a member is unresponsive at its listed handle for more than three months the Steering Committee may vote to remove the member. 30 | 31 | ## 5. Trademarks 32 | 33 | Any names, trademarks, service marks, logos, mascots, or similar indicators of source or origin and the goodwill associated with them arising out of the Organization's activities or Organization projects' activities (the "Marks"), are controlled by the Organization. Steering Committee members may only use the Marks in accordance with the Organization's [trademark policy](./TRADEMARKS.md). If a Steering Committee member is terminated or removed from the Steering Committee, any rights the Steering Committee member may have in the Marks revert to the Organization. 34 | 35 | ## 6. Antitrust Policy 36 | 37 | The Steering Committee is bound by the Organization's [antitrust policy](./ANTITRUST.md). 38 | 39 | ## 7. No Confidentiality 40 | 41 | Information disclosed in connection with any of the Organization's activities, including but not limited to meetings, Contributions, and submissions, is not confidential, regardless of any markings or statements to the contrary. 42 | 43 | ## 8. Amendments 44 | 45 | Amendments to this charter, the [antitrust policy](./ANTITRUST.md), the [trademark policy](./TRADEMARKS.md), or the [code of conduct](./CODE_OF_CONDUCT.md) may only be made with at least a 3/4 affirmative vote of the Steering Committee. 46 | 47 | --- 48 | Adapted from [MVG-0.1-beta](https://github.com/github/MVG/tree/v0.1-beta). 49 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 50 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, 8 | body size, disability, ethnicity, gender identity and expression, level of 9 | experience, nationality, personal appearance, race, religion, or sexual 10 | identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies both within project spaces and in public spaces 49 | when an individual is representing the project or its community. Examples of 50 | representing a project or community include using an official project e-mail 51 | address, posting via an official social media account, or acting as an 52 | appointed representative at an online or offline event. Representation of a 53 | project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at oss-conduct@uber.com. The project 59 | team will review and investigate all complaints, and will respond in a way 60 | that it deems appropriate to the circumstances. The project team is obligated 61 | to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 71 | version 1.4, available at 72 | [http://contributor-covenant.org/version/1/4][version]. 73 | 74 | [homepage]: http://contributor-covenant.org 75 | [version]: http://contributor-covenant.org/version/1/4/ 76 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to CausalML 2 | 3 | The **CausalML** project welcomes community contributors. 4 | To contribute to it, please follow the guidelines here. 5 | 6 | The codebase is hosted on GitHub at https://github.com/uber/causalml. 7 | 8 | We use [`black`](https://black.readthedocs.io/en/stable/index.html) as a formatter to keep the coding style and format across all Python files consistent and compliant with [PEP8](https://www.python.org/dev/peps/pep-0008/). We recommend that you add `black` to your IDE as a formatter (see the [instructions](https://black.readthedocs.io/en/stable/integrations/editors.html)) or run `black` on the command line before submitting a PR as follows: 9 | ```bash 10 | # move to the top directory of the causalml repository 11 | $ cd causalml 12 | $ pip install -U black 13 | $ black . 14 | ``` 15 | 16 | Additionally, you can set up black and other tools we use to run before any commit is made via: 17 | ```bash 18 | make setup_local 19 | ``` 20 | 21 | As a start, please check out outstanding [issues](https://github.com/uber/causalml/issues). 22 | If you'd like to contribute to something else, open a new issue for discussion first. 23 | 24 | ## Development Workflow :computer: 25 | 26 | 1. Fork the `causalml` repo. This will create your own copy of the `causalml` repo.
For more details about forks, please check [this guide](https://docs.github.com/en/github/collaborating-with-pull-requests/working-with-forks/about-forks) at GitHub. 27 | 2. Clone the forked repo locally 28 | 3. (optional) Complete local installation by running: 29 | ```bash 30 | make setup_local 31 | ``` 32 | 4. Create a branch for the change: 33 | ```bash 34 | $ git checkout -b branch_name 35 | ``` 36 | 5. Make a change 37 | 6. Test your change as described below in the Test section 38 | 7. Commit the change to your local branch 39 | ```bash 40 | $ git add file1_changed file2_changed 41 | $ git commit -m "Issue number: message to describe the change." 42 | ``` 43 | 8. Push your local branch to remote 44 | ```bash 45 | $ git push origin branch_name 46 | ``` 47 | 9. Go to GitHub and create a PR from your branch in your forked repo to the original `causalml` repo. Instructions for creating a PR from a fork are available [here](https://docs.github.com/en/github/collaborating-with-pull-requests/proposing-changes-to-your-work-with-pull-requests/creating-a-pull-request-from-a-fork) 48 | 49 | ## Documentation :books: 50 | 51 | [**CausalML** documentation](https://causalml.readthedocs.io/) is generated with [Sphinx](https://www.sphinx-doc.org/en/master/) and hosted on [Read the Docs](https://readthedocs.org/). 52 | 53 | ### Docstrings 54 | 55 | All public classes and functions should have docstrings to specify their inputs, outputs, behaviors and/or examples. For docstring conventions in Python, please refer to [PEP257](https://www.python.org/dev/peps/pep-0257/). 56 | 57 | **CausalML** supports the NumPy and Google style docstrings in addition to Python's original docstring with [`sphinx.ext.napoleon`](https://www.sphinx-doc.org/en/master/usage/extensions/napoleon.html). Google style docstrings are recommended for simplicity. You can find examples of Google style docstrings [here](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). 58 | 59 | ### Generating Documentation Locally 60 | 61 | You can generate documentation in HTML locally as follows: 62 | ```bash 63 | $ cd docs/ 64 | $ pip install -r requirements.txt 65 | $ make html 66 | ``` 67 | 68 | Documentation will be available in `docs/_build/html/index.html`. 69 | 70 | ## Test :wrench: 71 | 72 | If you add a new inference method, add test code to the `tests/` folder. 73 | 74 | ### Prerequisites 75 | 76 | **CausalML** uses `pytest` for tests. Install `pytest` and `pytest-cov`, and the package dependencies: 77 | ```bash 78 | $ pip install .[test] 79 | ``` 80 | See details of the test dependencies in `pyproject.toml`. 81 | 82 | ### Building Cython 83 | 84 | In order to run tests, you need to build the Cython modules: 85 | ```bash 86 | $ python setup.py build_ext --inplace 87 | ``` 88 | This is important because during testing causalml modules are imported from the source code. 89 | 90 | ### Testing 91 | 92 | Before submitting a PR, make sure your change passes all tests and that test coverage is at least 70%. 93 | ```bash 94 | $ pytest -vs tests/ --cov causalml/ 95 | ``` 96 | 97 | To run tests that require tensorflow (i.e. DragonNet), make sure tensorflow is installed and include the `--runtf` option with the `pytest` command.
For example: 98 | 99 | ```bash 100 | $ pytest --runtf -vs tests/test_dragonnet.py 101 | ``` 102 | 103 | You can also run tests via make: 104 | ```bash 105 | $ make test 106 | ``` 107 | 108 | 109 | 110 | ## Submission :tada: 111 | 112 | In your PR, please include: 113 | - Changes made 114 | - Links to related issues/PRs 115 | - Tests 116 | - Dependencies 117 | - References 118 | 119 | Please add the core CausalML contributors as reviewers. 120 | 121 | ## Maintain in `conda-forge` :snake: 122 | 123 | We support installing the package through `conda`. In order to maintain the package in conda, we need to keep the package version in conda's recipe repository [here](https://github.com/conda-forge/causalml-feedstock) in sync with `CausalML`. You can follow the [instructions](https://conda-forge.org/#update_recipe) from conda or the steps below: 124 | 125 | 1. After a new release of the package, fork the repo. 126 | 2. Create a new branch from the master branch. 127 | 3. Edit the recipe: 128 | - Update the version number [here](https://github.com/conda-forge/causalml-feedstock/blob/main/recipe/meta.yaml#L2) in `meta.yaml` 129 | - Generate the new sha256 hash and update it [here](https://github.com/conda-forge/causalml-feedstock/blob/main/recipe/meta.yaml#L11): the sha256 hash can be obtained from PyPI; look for the SHA256 link next to the download link on the PyPI package’s files page, e.g. https://pypi.org/project/causalml/#files 130 | - Reset the build number to 0 131 | - Update the dependencies if needed 132 | 4. Submit the PR; the recipe will be built automatically. 133 | 134 | Once the recipe is ready, it will be merged. The recipe will then automatically be built and uploaded to the conda-forge channel. 135 | -------------------------------------------------------------------------------- /GOVERNANCE.md: -------------------------------------------------------------------------------- 1 | # Governance Policy 2 | 3 | This document provides the governance policy for the Project. Maintainers agree to this policy and to abide by all Project policies, including the [code of conduct](./CODE_OF_CONDUCT.md), [trademark policy](./TRADEMARKS.md), and [antitrust policy](./ANTITRUST.md) by adding their name to the [maintainers.md file](./MAINTAINERS.md). 4 | 5 | ## 1. Roles. 6 | 7 | This project may include the following roles. Additional roles may be adopted and documented by the Project. 8 | 9 | **1.1. Maintainers**. Maintainers are responsible for organizing activities around developing, maintaining, and updating the Project. Maintainers are also responsible for determining consensus. This Project may add or remove Maintainers with the approval of the current Maintainers. 10 | 11 | **1.2. Contributors**. Contributors are those that have made contributions to the Project. 12 | 13 | ## 2. Decisions. 14 | 15 | **2.1. Consensus-Based Decision Making**. Projects make decisions through consensus of the Maintainers. While explicit agreement of all Maintainers is preferred, it is not required for consensus. Rather, the Maintainers will determine consensus based on their good faith consideration of a number of factors, including the dominant view of the Contributors and nature of support and objections. The Maintainers will document evidence of consensus in accordance with these requirements. 16 | 17 | **2.2. Appeal Process**. Decisions may be appealed by opening an issue and that appeal will be considered by the Maintainers in good faith, who will respond in writing within a reasonable time.
If the Maintainers deny the appeal, the appeal may be brought before the Organization Steering Committee, who will also respond in writing in a reasonable time. 18 | 19 | 20 | ## 3. Termination of Membership 21 | 22 | The membership of a Maintainer will terminate if any of the following occur: 23 | 24 | **3.1 Resignation**. Written notice of resignation to the Maintainers. 25 | 26 | **3.2 Unreachable Member**. If a member is unresponsive at its listed handle for more than three months the Maintainers may vote to remove the member. 27 | 28 | ## 4. How We Work. 29 | 30 | **4.1. Openness**. Participation is open to anyone who is directly and materially affected by the activity in question. There shall be no undue financial barriers to participation. 31 | 32 | **4.2. Balance**. The development process should balance the interests of Contributors and other stakeholders. Contributors from diverse interest categories shall be sought with the objective of achieving balance. 33 | 34 | **4.3. Coordination and Harmonization**. Good faith efforts shall be made to resolve potential conflicts or incompatibility between releases in this Project. 35 | 36 | **4.4. Consideration of Views and Objections**. Prompt consideration shall be given to the written views and objections of all Contributors. 37 | 38 | **4.5. Written procedures**. This governance document and other materials documenting this project's development process shall be available to any interested person. 39 | 40 | ## 5. No Confidentiality. 41 | 42 | Information disclosed in connection with any Project activity, including but not limited to meetings, contributions, and submissions, is not confidential, regardless of any markings or statements to the contrary. 43 | 44 | ## 6. Trademarks. 45 | 46 | Any names, trademarks, logos, or goodwill developed by and associated with the Project (the "Marks") are controlled by the Organization. Maintainers may only use these Marks in accordance with the Organization's trademark policy. If a Maintainer resigns or is removed, any rights the Maintainer may have in the Marks revert to the Organization. 47 | 48 | ## 7. Amendments. 49 | 50 | Amendments to this governance policy may be made by affirmative vote of 2/3 of all Maintainers, with approval by the Organization's Steering Committee. 51 | 52 | --- 53 | Adapted from [MVG-0.1-beta](https://github.com/github/MVG/tree/v0.1-beta). 54 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 55 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2019 Uber Technology, Inc. 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
-------------------------------------------------------------------------------- /MAINTAINERS.md: -------------------------------------------------------------------------------- 1 | # Maintainers 2 | 3 | This document lists the Maintainers of the Project. Maintainers may be added once approved by the existing maintainers as described in the [Governance document](./GOVERNANCE.md). By adding your name to this list you are agreeing to abide by the Project governance documents and to abide by all of the Organization's polices, including the [code of conduct](./CODE-OF-CONDUCT.md), [trademark policy](./TRADEMARKS.md), and [antitrust policy](./ANTITRUST.md). If you are participating because of your affiliation with another organization (designated below), you represent that you have the authority to bind that organization to these policies. 4 | 5 | | **NAME** | **Handle** | 6 | | --- | --- | 7 | | Huigang Chen | @huigangchen | 8 | | Totte Harinen | @t-tte | 9 | | Jeong-Yoon Lee | @jeongyoonlee | 10 | | Paul Lo | @paullo0106 | 11 | | Jing Pan | @ppstacy | 12 | | Alexander Popkov | @alexander-pv | 13 | | Roland Stevenson | @ras44 | 14 | | Yifeng Wu | @vincewu51 | 15 | | Zhenyu Zhao | @zhenyuz0500 | 16 | 17 | ## Previous Maintainers 18 | 19 | | **NAME** | **Handle** | 20 | | --- | --- | 21 | | Mike Yung | @yungmsh | 22 | | Yuchen Luo | @yluogit | 23 | | Steve Yang | @steveyang90 | 24 | 25 | --- 26 | Adapted from [MVG-0.1-beta](https://github.com/github/MVG/tree/v0.1-beta). 27 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 28 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | # Include the README 2 | include *.txt *.md 3 | recursive-include docs *.txt 4 | recursive-include causalml *.pyx *.pxd *.c *.h 5 | 6 | # Include the license file 7 | include LICENSE 8 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: build_ext 2 | build_ext: clean 3 | python setup.py build_ext --force --inplace 4 | 5 | .PHONY: build 6 | build: build_ext 7 | python setup.py bdist_wheel 8 | 9 | .PHONY: install 10 | install: build_ext 11 | pip install . 12 | 13 | .PHONY: test 14 | test: build_ext 15 | pytest -vs --cov causalml/ 16 | python setup.py clean --all 17 | 18 | .PHONY: clean 19 | clean: 20 | python setup.py clean --all 21 | rm -rf ./build ./dist ./eggs ./causalml.egg-info 22 | find ./causalml -type f \( -name "*.so" -o -name "*.c" -o -name "*.html" \) -delete 23 | 24 | .PHONY: setup_local 25 | setup_local: 26 | pip install pre-commit 27 | pre-commit install 28 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Security Policy 2 | 3 | ## Supported Versions 4 | 5 | | Version | Supported | 6 | | ------- | ------------------ | 7 | | all | :white_check_mark: | 8 | 9 | ## Reporting a Vulnerability 10 | 11 | Please report any vulnerabilities to causalml@uber.com 12 | -------------------------------------------------------------------------------- /STEERING_COMMITTEE.md: -------------------------------------------------------------------------------- 1 | # Steering Committee 2 | 3 | This document lists the members of the Organization's Steering Committee. 
Voting members may be added once approved by the Steering Committee as described in the [charter](./CHARTER.md). By adding your name to this list you are agreeing to abide by all Organization polices, including the [charter](./CHARTER.md), the [code of conduct](./CODE_OF_CONDUCT.md), the [trademark policy](./TRADEMARKS.md), and the [antitrust policy](./ANTITRUST.md). If you are serving on the Steering Committee because of your affiliation with another organization (designated below), you represent that you have authority to bind that organization to these policies. 4 | 5 | | **NAME** | **Handle** | **Affiliated Organization** | 6 | | --- | --- | --- | 7 | | Huigang Chen | @huigangchen | Meta | 8 | | Totte Harinen | @t-tte | AirBnB | 9 | | Jeong-Yoon Lee | @jeongyoonlee | Uber | 10 | | Zhenyu Zhao | @zhenyuz0500 | Tencent | 11 | 12 | --- 13 | Adapted from [MVG-0.1-beta](https://github.com/github/MVG/tree/v0.1-beta). 14 | Made with love by GitHub. Licensed under the [CC-BY 4.0 License](https://creativecommons.org/licenses/by-sa/4.0/). 15 | -------------------------------------------------------------------------------- /TRADEMARKS.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This is the Organization's policy for the use of our trademarks. While our work is available under free and open source software licenses, those licenses do not include a license to use our trademarks. 4 | 5 | This policy describes how you may use our trademarks. Our goal is to strike a balance between: 1) our need to ensure that our trademarks remain reliable indicators of the quality software we release; and 2) our community members' desire to be full participants in our Organization. 6 | 7 | ## Our Trademarks 8 | 9 | This policy covers the name of the Organization and each of the Organization's projects, as well as any associated names, trademarks, service marks, logos, mascots, or similar indicators of source or origin (our "Marks"). 10 | 11 | ## In General 12 | 13 | Whenever you use our Marks, you must always do so in a way that does not mislead anyone about exactly who is the source of the software. For example, you cannot say you are distributing the "Mark" software when you're distributing a modified version of it because people will believe they are getting the same software that they can get directly from us when they aren't. You also cannot use our Marks on your website in a way that suggests that your website is an official Organization website or that we endorse your website. But, if true, you can say you like the "Mark" software, that you participate in the "Mark" community, that you are providing an unmodified version of the "Mark" software, or that you wrote a book describing how to use the "Mark" software. 14 | 15 | This fundamental requirement, that it is always clear to people what they are getting and from whom, is reflected throughout this policy. It should also serve as your guide if you are not sure about how you are using the Marks. 16 | 17 | In addition: 18 | * You may not use or register, in whole or in part, the Marks as part of your own trademark, service mark, domain name, company name, trade name, product name or service name. 19 | * Trademark law does not allow your use of names or trademarks that are too similar to ours. You therefore may not use an obvious variation of any of our Marks or any phonetic equivalent, foreign language equivalent, takeoff, or abbreviation for a similar or compatible product or service. 
20 | * You agree that any goodwill generated by your use of the Marks and participation in our community inures solely to our collective benefit. 21 | 22 | ## Distribution of unmodified source code or unmodified executable code we have compiled 23 | 24 | When you redistribute an unmodified copy of our software, you are not changing the quality or nature of it. Therefore, you may retain the Marks we have placed on the software to identify your redistribution. This kind of use only applies if you are redistributing an official distribution from this Project that has not been changed in any way. 25 | 26 | ## Distribution of executable code that you have compiled, or modified code 27 | 28 | You may use any word marks, but not any Organization logos, to truthfully describe the origin of the software that you are providing, that is, that the code you are distributing is a modification of our software. You may say, for example, that "this software is derived from the source code for 'Mark' software." 29 | 30 | Of course, you can place your own trademarks or logos on versions of the software to which you have made substantive modifications, because by modifying the software, you have become the origin of that exact version. In that case, you should not use our Marks. 31 | 32 | However, you may use our Marks for the distribution of code (source or executable) on the condition that any executable is built from the official Project source code and that any modifications are limited to switching on or off features already included in the software, translations into other languages, and incorporating minor bug-fix patches. Use of our Marks on any further modification is not permitted. 33 | 34 | ## Statements about your software's relation to our software 35 | 36 | You may use the word Marks, but not the Organization's logos, to truthfully describe the relationship between your software and ours. Our Mark should be used after a verb or preposition that describes the relationship between your software and ours. So you may say, for example, "Bob's software for the 'Mark' platform" but may not say "Bob's 'Mark' software." 
Some other examples that may work for you are: 37 | 38 | * [Your software] uses "Mark" software 39 | * [Your software] is powered by "Mark" software 40 | * [Your software] runs on "Mark" software 41 | * [Your software] for use with "Mark" software 42 | * [Your software] for Mark software 43 | 44 | These guidelines are based on the [Model Trademark Guidelines](http://www.modeltrademarkguidelines.org), used under a [Creative Commons Attribution 3.0 Unported license](https://creativecommons.org/licenses/by/3.0/deed.en_US) 45 | -------------------------------------------------------------------------------- /causalml/__init__.py: -------------------------------------------------------------------------------- 1 | __all__ = [ 2 | "dataset", 3 | "features", 4 | "feature_selection", 5 | "inference", 6 | "match", 7 | "metrics", 8 | "optimize", 9 | "propensity", 10 | ] 11 | -------------------------------------------------------------------------------- /causalml/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .regression import synthetic_data 2 | from .regression import simulate_nuisance_and_easy_treatment 3 | from .regression import simulate_randomized_trial 4 | from .regression import simulate_easy_propensity_difficult_baseline 5 | from .regression import simulate_unrelated_treatment_control 6 | from .regression import simulate_hidden_confounder 7 | from .classification import make_uplift_classification 8 | from .classification import make_uplift_classification_logistic 9 | 10 | from .synthetic import get_synthetic_preds, get_synthetic_preds_holdout 11 | from .synthetic import get_synthetic_summary, get_synthetic_summary_holdout 12 | from .synthetic import scatter_plot_summary, scatter_plot_summary_holdout 13 | from .synthetic import bar_plot_summary, bar_plot_summary_holdout 14 | from .synthetic import distr_plot_single_sim 15 | from .synthetic import scatter_plot_single_sim 16 | from .synthetic import get_synthetic_auuc 17 | -------------------------------------------------------------------------------- /causalml/feature_selection/__init__.py: -------------------------------------------------------------------------------- 1 | from .filters import FilterSelect 2 | -------------------------------------------------------------------------------- /causalml/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/causalml/inference/__init__.py -------------------------------------------------------------------------------- /causalml/inference/iv/__init__.py: -------------------------------------------------------------------------------- 1 | from .iv_regression import IVRegressor 2 | from .drivlearner import BaseDRIVLearner, BaseDRIVRegressor, XGBDRIVRegressor 3 | -------------------------------------------------------------------------------- /causalml/inference/iv/iv_regression.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from causalml.inference.meta.utils import convert_pd_to_np 4 | import statsmodels.api as sm 5 | from statsmodels.sandbox.regression.gmm import IV2SLS 6 | 7 | 8 | class IVRegressor: 9 | """A wrapper class that uses IV2SLS from statsmodel 10 | 11 | A linear 2SLS model that estimates the average treatment effect with endogenous treatment variable. 
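    Example (an illustrative usage sketch, not part of the original docstring; the synthetic data below is an assumption made only for demonstration):

        >>> import numpy as np
        >>> from causalml.inference.iv import IVRegressor
        >>> rng = np.random.default_rng(42)
        >>> n = 1000
        >>> X = rng.normal(size=(n, 3))                          # covariates
        >>> w = rng.binomial(1, 0.5, size=n)                     # binary instrument
        >>> treatment = rng.binomial(1, 0.3 + 0.4 * w)           # endogenous treatment driven by the instrument
        >>> y = X[:, 0] + 2.0 * treatment + rng.normal(size=n)   # outcome with a true effect of 2
        >>> iv = IVRegressor()
        >>> iv.fit(X, treatment, y, w)
        >>> ate, ate_se = iv.predict()                           # 2SLS effect estimate and its standard error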
12 | """ 13 | 14 | def __init__(self): 15 | """ 16 | Initializes the class. 17 | """ 18 | 19 | self.method = "2SLS" 20 | 21 | def fit(self, X, treatment, y, w): 22 | """Fits the 2SLS model. 23 | 24 | Args: 25 | X (np.matrix or np.array or pd.Dataframe): a feature matrix 26 | treatment (np.array or pd.Series): a treatment vector 27 | y (np.array or pd.Series): an outcome vector 28 | w (np.array or pd.Series): an instrument vector 29 | """ 30 | 31 | X, treatment, y, w = convert_pd_to_np(X, treatment, y, w) 32 | 33 | exog = sm.add_constant(np.c_[X, treatment]) 34 | endog = y 35 | instrument = sm.add_constant(np.c_[X, w]) 36 | 37 | self.iv_model = IV2SLS(endog=endog, exog=exog, instrument=instrument) 38 | self.iv_fit = self.iv_model.fit() 39 | 40 | def predict(self): 41 | """Returns the average treatment effect and its estimated standard error 42 | 43 | Returns: 44 | (float): average treatment effect 45 | (float): standard error of the estimation 46 | """ 47 | 48 | return self.iv_fit.params[-1], self.iv_fit.bse[-1] 49 | -------------------------------------------------------------------------------- /causalml/inference/meta/__init__.py: -------------------------------------------------------------------------------- 1 | from .slearner import LRSRegressor, BaseSLearner, BaseSRegressor, BaseSClassifier 2 | from .tlearner import ( 3 | XGBTRegressor, 4 | MLPTRegressor, 5 | BaseTLearner, 6 | BaseTRegressor, 7 | BaseTClassifier, 8 | ) 9 | from .xlearner import BaseXLearner, BaseXRegressor, BaseXClassifier 10 | from .rlearner import BaseRLearner, BaseRRegressor, BaseRClassifier, XGBRRegressor 11 | from .tmle import TMLELearner 12 | from .drlearner import BaseDRLearner, BaseDRRegressor, XGBDRRegressor 13 | -------------------------------------------------------------------------------- /causalml/inference/meta/utils.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | from packaging import version 5 | from xgboost import __version__ as xgboost_version 6 | 7 | 8 | def convert_pd_to_np(*args): 9 | output = [obj.to_numpy() if hasattr(obj, "to_numpy") else obj for obj in args] 10 | return output if len(output) > 1 else output[0] 11 | 12 | 13 | def check_treatment_vector(treatment, control_name=None): 14 | n_unique_treatments = np.unique(treatment).shape[0] 15 | assert n_unique_treatments > 1, "Treatment vector must have at least two levels." 16 | if control_name is not None: 17 | assert ( 18 | control_name in treatment 19 | ), "Control group level {} not found in treatment vector.".format(control_name) 20 | 21 | 22 | def check_p_conditions(p, t_groups): 23 | eps = np.finfo(float).eps 24 | assert isinstance( 25 | p, (np.ndarray, pd.Series, dict) 26 | ), "p must be an np.ndarray, pd.Series, or dict type" 27 | if isinstance(p, (np.ndarray, pd.Series)): 28 | assert ( 29 | t_groups.shape[0] == 1 30 | ), "If p is passed as an np.ndarray, there must be only 1 unique non-control group in the treatment vector." 31 | assert (0 + eps < p).all() and ( 32 | p < 1 - eps 33 | ).all(), "The values of p should lie within the (0, 1) interval." 34 | 35 | if isinstance(p, dict): 36 | for t_name in t_groups: 37 | assert (0 + eps < p[t_name]).all() and ( 38 | p[t_name] < 1 - eps 39 | ).all(), "The values of p should lie within the (0, 1) interval." 
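# Illustrative usage sketch (added as a comment, not part of the original module).
# These validators are typically invoked by the meta-learners before estimation;
# the names below (`treatment`, "control", "treatment_a", "treatment_b") are hypothetical:
#
#     check_treatment_vector(treatment, control_name="control")
#     # single non-control group: p may be an array of propensity scores in (0, 1)
#     check_p_conditions(np.array([0.2, 0.5, 0.8]), np.array(["treatment_a"]))
#     # multiple non-control groups: pass a dict keyed by treatment group name
#     check_p_conditions(
#         {"treatment_a": np.array([0.2, 0.5]), "treatment_b": np.array([0.4, 0.6])},
#         np.array(["treatment_a", "treatment_b"]),
#     )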
40 | 41 | 42 | def check_explain_conditions(method, models, X=None, treatment=None, y=None): 43 | valid_methods = ["gini", "permutation", "shapley"] 44 | assert method in valid_methods, "Current supported methods: {}".format( 45 | ", ".join(valid_methods) 46 | ) 47 | 48 | if method in ("gini", "shapley"): 49 | conds = [hasattr(mod, "feature_importances_") for mod in models] 50 | assert all( 51 | conds 52 | ), "Both models must have .feature_importances_ attribute if method = {}".format( 53 | method 54 | ) 55 | 56 | if method in ("permutation", "shapley"): 57 | assert all( 58 | arr is not None for arr in (X, treatment, y) 59 | ), "X, treatment, and y must be provided if method = {}".format(method) 60 | 61 | 62 | def clean_xgboost_objective(objective): 63 | """ 64 | Translate objective to be compatible with loaded xgboost version 65 | 66 | Args 67 | ---- 68 | 69 | objective : string 70 | The objective to translate. 71 | 72 | Returns 73 | ------- 74 | The translated objective, or original if no translation was required. 75 | """ 76 | compat_before_v83 = {"reg:squarederror": "reg:linear"} 77 | compat_v83_or_later = {"reg:linear": "reg:squarederror"} 78 | if version.parse(xgboost_version) < version.parse("0.83"): 79 | if objective in compat_before_v83: 80 | objective = compat_before_v83[objective] 81 | else: 82 | if objective in compat_v83_or_later: 83 | objective = compat_v83_or_later[objective] 84 | return objective 85 | 86 | 87 | def get_xgboost_objective_metric(objective): 88 | """ 89 | Get the xgboost version-compatible objective and evaluation metric from a potentially version-incompatible input. 90 | 91 | Args 92 | ---- 93 | 94 | objective : string 95 | An xgboost objective that may be incompatible with the installed version. 96 | 97 | Returns 98 | ------- 99 | A tuple with the translated objective and evaluation metric. 100 | """ 101 | 102 | def clean_dict_keys(orig): 103 | return {clean_xgboost_objective(k): v for (k, v) in orig.items()} 104 | 105 | metric_mapping = clean_dict_keys( 106 | {"rank:pairwise": "auc", "reg:squarederror": "rmse"} 107 | ) 108 | 109 | objective = clean_xgboost_objective(objective) 110 | 111 | assert ( 112 | objective in metric_mapping 113 | ), "Effect learner objective must be one of: " + ", ".join(metric_mapping) 114 | return objective, metric_mapping[objective] 115 | 116 | 117 | def get_weighted_variance(x, sample_weight): 118 | """ 119 | Calculate the variance of array x with sample_weight. 120 | 121 | Args 122 | ---- 123 | 124 | x : (np.array) 125 | A list of number 126 | 127 | sample_weight (np.array or list): an array of sample weights indicating the 128 | weight of each observation for `effect_learner`. If None, it assumes equal weight. 
129 | 130 | Returns 131 | ------- 132 | The variance of x with sample weight 133 | """ 134 | average = np.average(x, weights=sample_weight) 135 | variance = np.average((x - average) ** 2, weights=sample_weight) 136 | return variance 137 | -------------------------------------------------------------------------------- /causalml/inference/tf/__init__.py: -------------------------------------------------------------------------------- 1 | from .dragonnet import DragonNet 2 | -------------------------------------------------------------------------------- /causalml/inference/tf/utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.keras import backend as K 3 | from tensorflow.keras.layers import Layer 4 | from tensorflow.keras.metrics import binary_accuracy 5 | 6 | 7 | def binary_classification_loss(concat_true, concat_pred): 8 | """ 9 | Implements a classification (binary cross-entropy) loss function for DragonNet architecture. 10 | 11 | Args: 12 | - concat_true (tf.tensor): tensor of true samples, with shape (n_samples, 2) 13 | Each row in concat_true is comprised of (y, treatment) 14 | - concat_pred (tf.tensor): tensor of predictions, with shape (n_samples, 4) 15 | Each row in concat_pred is comprised of (y0, y1, propensity, epsilon) 16 | Returns: 17 | - (float): binary cross-entropy loss 18 | """ 19 | t_true = concat_true[:, 1] 20 | t_pred = concat_pred[:, 2] 21 | t_pred = (t_pred + 0.001) / 1.002 22 | losst = tf.reduce_sum(K.binary_crossentropy(t_true, t_pred)) 23 | 24 | return losst 25 | 26 | 27 | def regression_loss(concat_true, concat_pred): 28 | """ 29 | Implements a regression (squared error) loss function for DragonNet architecture. 30 | 31 | Args: 32 | - concat_true (tf.tensor): tensor of true samples, with shape (n_samples, 2) 33 | Each row in concat_true is comprised of (y, treatment) 34 | - concat_pred (tf.tensor): tensor of predictions, with shape (n_samples, 4) 35 | Each row in concat_pred is comprised of (y0, y1, propensity, epsilon) 36 | Returns: 37 | - (float): aggregated regression loss 38 | """ 39 | y_true = concat_true[:, 0] 40 | t_true = concat_true[:, 1] 41 | 42 | y0_pred = concat_pred[:, 0] 43 | y1_pred = concat_pred[:, 1] 44 | 45 | loss0 = tf.reduce_sum((1.0 - t_true) * tf.square(y_true - y0_pred)) 46 | loss1 = tf.reduce_sum(t_true * tf.square(y_true - y1_pred)) 47 | 48 | return loss0 + loss1 49 | 50 | 51 | def dragonnet_loss_binarycross(concat_true, concat_pred): 52 | """ 53 | Implements regression + classification loss in one wrapper function. 54 | 55 | Args: 56 | - concat_true (tf.tensor): tensor of true samples, with shape (n_samples, 2) 57 | Each row in concat_true is comprised of (y, treatment) 58 | - concat_pred (tf.tensor): tensor of predictions, with shape (n_samples, 4) 59 | Each row in concat_pred is comprised of (y0, y1, propensity, epsilon) 60 | Returns: 61 | - (float): aggregated regression + classification loss 62 | """ 63 | return regression_loss(concat_true, concat_pred) + binary_classification_loss( 64 | concat_true, concat_pred 65 | ) 66 | 67 | 68 | def treatment_accuracy(concat_true, concat_pred): 69 | """ 70 | Returns keras' binary_accuracy between treatment and prediction of propensity. 
71 | 72 | Args: 73 | - concat_true (tf.tensor): tensor of true samples, with shape (n_samples, 2) 74 | Each row in concat_true is comprised of (y, treatment) 75 | - concat_pred (tf.tensor): tensor of predictions, with shape (n_samples, 4) 76 | Each row in concat_pred is comprised of (y0, y1, propensity, epsilon) 77 | Returns: 78 | - (float): binary accuracy 79 | """ 80 | t_true = concat_true[:, 1] 81 | t_pred = concat_pred[:, 2] 82 | return binary_accuracy(t_true, t_pred) 83 | 84 | 85 | def track_epsilon(concat_true, concat_pred): 86 | """ 87 | Tracks the mean absolute value of epsilon. 88 | 89 | Args: 90 | - concat_true (tf.tensor): tensor of true samples, with shape (n_samples, 2) 91 | Each row in concat_true is comprised of (y, treatment) 92 | - concat_pred (tf.tensor): tensor of predictions, with shape (n_samples, 4) 93 | Each row in concat_pred is comprised of (y0, y1, propensity, epsilon) 94 | Returns: 95 | - (float): mean absolute value of epsilon 96 | """ 97 | epsilons = concat_pred[:, 3] 98 | return tf.abs(tf.reduce_mean(epsilons)) 99 | 100 | 101 | def make_tarreg_loss(ratio=1.0, dragonnet_loss=dragonnet_loss_binarycross): 102 | """ 103 | Given a specified loss function, returns the same loss function with targeted regularization. 104 | 105 | Args: 106 | ratio (float): weight assigned to the targeted regularization loss component 107 | dragonnet_loss (function): a loss function 108 | Returns: 109 | (function): loss function with targeted regularization, weighted by specified ratio 110 | """ 111 | 112 | def tarreg_ATE_unbounded_domain_loss(concat_true, concat_pred): 113 | """ 114 | Returns the loss function (specified in outer function) with targeted regularization. 115 | """ 116 | vanilla_loss = dragonnet_loss(concat_true, concat_pred) 117 | 118 | y_true = concat_true[:, 0] 119 | t_true = concat_true[:, 1] 120 | 121 | y0_pred = concat_pred[:, 0] 122 | y1_pred = concat_pred[:, 1] 123 | t_pred = concat_pred[:, 2] 124 | 125 | epsilons = concat_pred[:, 3] 126 | t_pred = (t_pred + 0.01) / 1.02 127 | # t_pred = tf.clip_by_value(t_pred,0.01, 0.99,name='t_pred') 128 | 129 | y_pred = t_true * y1_pred + (1 - t_true) * y0_pred 130 | 131 | h = t_true / t_pred - (1 - t_true) / (1 - t_pred) 132 | 133 | y_pert = y_pred + epsilons * h 134 | targeted_regularization = tf.reduce_sum(tf.square(y_true - y_pert)) 135 | 136 | # final 137 | loss = vanilla_loss + ratio * targeted_regularization 138 | return loss 139 | 140 | return tarreg_ATE_unbounded_domain_loss 141 | 142 | 143 | class EpsilonLayer(Layer): 144 | """ 145 | Custom keras layer to allow epsilon to be learned during training process. 146 | """ 147 | 148 | def __init__(self, **kwargs): 149 | """ 150 | Inherits keras' Layer object. 151 | """ 152 | super(EpsilonLayer, self).__init__(**kwargs) 153 | 154 | def build(self, input_shape): 155 | """ 156 | Creates a trainable weight variable for this layer. 
157 | """ 158 | self.epsilon = self.add_weight( 159 | name="epsilon", shape=[1, 1], initializer="RandomNormal", trainable=True 160 | ) 161 | super(EpsilonLayer, self).build(input_shape) 162 | 163 | def call(self, inputs, **kwargs): 164 | return self.epsilon * tf.ones_like(inputs)[:, 0:1] 165 | 166 | def get_config(self): 167 | config = super().get_config() 168 | return config 169 | 170 | @classmethod 171 | def from_config(cls, config): 172 | return cls(**config) 173 | -------------------------------------------------------------------------------- /causalml/inference/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .cevae import CEVAE 2 | -------------------------------------------------------------------------------- /causalml/inference/torch/cevae.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module calls the CEVAE[1] function implemented by pyro team. CEVAE demonstrates a number of innovations including: 3 | 4 | - A generative model for causal effect inference with hidden confounders; 5 | - A model and guide with twin neural nets to allow imbalanced treatment; and 6 | - A custom training loss that includes both ELBO terms and extra terms needed to train the guide to be able to answer 7 | counterfactual queries. 8 | 9 | Generative model for a causal model with latent confounder z and binary treatment w: 10 | z ~ p(z) # latent confounder 11 | x ~ p(x|z) # partial noisy observation of z 12 | w ~ p(w|z) # treatment, whose application is biased by z 13 | y ~ p(y|t,z) # outcome 14 | Each of these distributions is defined by a neural network. The y distribution is defined by a disjoint pair of neural 15 | networks defining p(y|t=0,z) and p(y|t=1,z); this allows highly imbalanced treatment. 16 | 17 | **References** 18 | 19 | [1] C. Louizos, U. Shalit, J. Mooij, D. Sontag, R. Zemel, M. Welling (2017). 20 | | Causal Effect Inference with Deep Latent-Variable Models. 21 | | http://papers.nips.cc/paper/7223-causal-effect-inference-with-deep-latent-variable-models.pdf 22 | | https://github.com/AMLab-Amsterdam/CEVAE 23 | """ 24 | 25 | import logging 26 | import torch 27 | from pyro.contrib.cevae import CEVAE as CEVAEModel 28 | 29 | from causalml.inference.meta.utils import convert_pd_to_np 30 | 31 | pyro_logger = logging.getLogger("pyro") 32 | pyro_logger.setLevel(logging.DEBUG) 33 | if pyro_logger.handlers: 34 | pyro_logger.handlers[0].setLevel(logging.DEBUG) 35 | 36 | 37 | class CEVAE: 38 | def __init__( 39 | self, 40 | outcome_dist="studentt", 41 | latent_dim=20, 42 | hidden_dim=200, 43 | num_epochs=50, 44 | num_layers=3, 45 | batch_size=100, 46 | learning_rate=1e-3, 47 | learning_rate_decay=0.1, 48 | num_samples=1000, 49 | weight_decay=1e-4, 50 | ): 51 | """ 52 | Initializes CEVAE. 
53 | 54 | Args: 55 | outcome_dist (str): Outcome distribution as one of: "bernoulli" , "exponential", "laplace", "normal", 56 | and "studentt" 57 | latent_dim (int) : Dimension of the latent variable 58 | hidden_dim (int) : Dimension of hidden layers of fully connected networks 59 | num_epochs (int): Number of training epochs 60 | num_layers (int): Number of hidden layers in fully connected networks 61 | batch_size (int): Batch size 62 | learning_rate (int): Learning rate 63 | learning_rate_decay (float/int): Learning rate decay over all epochs; the per-step decay rate will 64 | depend on batch size and number of epochs such that the initial 65 | learning rate will be learning_rate and the 66 | final learning rate will be learning_rate * learning_rate_decay 67 | num_samples (int) : Number of samples to calculate ITE 68 | weight_decay (float) : Weight decay 69 | """ 70 | self.outcome_dist = outcome_dist 71 | self.latent_dim = latent_dim 72 | self.hidden_dim = hidden_dim 73 | self.num_epochs = num_epochs 74 | self.num_layers = num_layers 75 | self.batch_size = batch_size 76 | self.learning_rate = learning_rate 77 | self.learning_rate_decay = learning_rate_decay 78 | self.num_samples = num_samples 79 | self.weight_decay = weight_decay 80 | 81 | def fit(self, X, treatment, y, p=None): 82 | """ 83 | Fits CEVAE. 84 | 85 | Args: 86 | X (np.matrix or np.array or pd.Dataframe): a feature matrix 87 | treatment (np.array or pd.Series): a treatment vector 88 | y (np.array or pd.Series): an outcome vector 89 | """ 90 | X, treatment, y = convert_pd_to_np(X, treatment, y) 91 | 92 | self.cevae = CEVAEModel( 93 | outcome_dist=self.outcome_dist, 94 | feature_dim=X.shape[-1], 95 | latent_dim=self.latent_dim, 96 | hidden_dim=self.hidden_dim, 97 | num_layers=self.num_layers, 98 | ) 99 | 100 | self.cevae.fit( 101 | x=torch.tensor(X, dtype=torch.float), 102 | t=torch.tensor(treatment, dtype=torch.float), 103 | y=torch.tensor(y, dtype=torch.float), 104 | num_epochs=self.num_epochs, 105 | batch_size=self.batch_size, 106 | learning_rate=self.learning_rate, 107 | learning_rate_decay=self.learning_rate_decay, 108 | weight_decay=self.weight_decay, 109 | ) 110 | 111 | def predict(self, X, treatment=None, y=None, p=None): 112 | """ 113 | Calls predict on fitted DragonNet. 114 | 115 | Args: 116 | X (np.matrix or np.array or pd.Dataframe): a feature matrix 117 | Returns: 118 | (np.ndarray): Predictions of treatment effects. 119 | """ 120 | return ( 121 | self.cevae.ite( 122 | torch.tensor(X, dtype=torch.float), 123 | num_samples=self.num_samples, 124 | batch_size=self.batch_size, 125 | ) 126 | .cpu() 127 | .numpy() 128 | ) 129 | 130 | def fit_predict(self, X, treatment, y, p=None): 131 | """ 132 | Fits the CEVAE model and then predicts. 133 | 134 | Args: 135 | X (np.matrix or np.array or pd.Dataframe): a feature matrix 136 | treatment (np.array or pd.Series): a treatment vector 137 | y (np.array or pd.Series): an outcome vector 138 | Returns: 139 | (np.ndarray): Predictions of treatment effects. 
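
        Example (illustrative sketch; it assumes ``causalml.dataset.synthetic_data``
        with its usual ``(y, X, treatment, tau, b, e)`` return order):

            from causalml.dataset import synthetic_data

            y, X, treatment, _, _, _ = synthetic_data(mode=1, n=1000, p=5)
            cevae = CEVAE(outcome_dist="normal", num_epochs=5, batch_size=100)
            ite = cevae.fit_predict(X, treatment, y)  # per-unit treatment effects
            ate = ite.mean()                          # average treatment effect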
140 | """ 141 | self.fit(X, treatment, y) 142 | return self.predict(X) 143 | -------------------------------------------------------------------------------- /causalml/inference/tree/__init__.py: -------------------------------------------------------------------------------- 1 | from .causal.causaltree import CausalTreeRegressor 2 | from .causal.causalforest import CausalRandomForestRegressor 3 | from .plot import uplift_tree_string, uplift_tree_plot, plot_dist_tree_leaves_values 4 | from .uplift import DecisionTree, UpliftTreeClassifier, UpliftRandomForestClassifier 5 | from .utils import ( 6 | cat_group, 7 | cat_transform, 8 | cv_fold_index, 9 | cat_continuous, 10 | kpi_transform, 11 | get_tree_leaves_mask, 12 | ) 13 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This part of tree structures definition was initially borrowed from 3 | https://github.com/scikit-learn/scikit-learn/tree/1.5.2/sklearn/tree 4 | """ 5 | 6 | """Decision tree based models for classification and regression.""" 7 | 8 | from ._classes import ( 9 | BaseDecisionTree, 10 | ) 11 | 12 | __all__ = [ 13 | "BaseDecisionTree", 14 | ] 15 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_criterion.pxd: -------------------------------------------------------------------------------- 1 | # Authors: Gilles Louppe <g.louppe@gmail.com> 2 | # Peter Prettenhofer <peter.prettenhofer@gmail.com> 3 | # Brian Holt <bdholt1@gmail.com> 4 | # Joel Nothman <joel.nothman@gmail.com> 5 | # Arnaud Joly <arnaud.v.joly@gmail.com> 6 | # Jacob Schreiber <jmschreiber91@gmail.com> 7 | # 8 | # License: BSD 3 clause 9 | 10 | # cython: cdivision=True 11 | # cython: boundscheck=False 12 | # cython: wraparound=False 13 | # cython: language_level=3 14 | # cython: linetrace=True 15 | 16 | # See _criterion.pyx for implementation details. 17 | from ._typedefs cimport float64_t, int8_t, int32_t, intp_t 18 | 19 | 20 | cdef class Criterion: 21 | # The criterion computes the impurity of a node and the reduction of 22 | # impurity of a split on that node. It also computes the output statistics 23 | # such as the mean in regression and class probabilities in classification. 
24 | 25 | # Internal structures 26 | cdef const float64_t[:, ::1] y # Values of y 27 | cdef const int32_t[:] treatment # Treatment assignment: 1 for treatment, 0 for control 28 | cdef const float64_t[:] sample_weight # Sample weights 29 | 30 | cdef const intp_t[:] sample_indices # Sample indices in X, y 31 | cdef intp_t start # samples[start:pos] are the samples in the left node 32 | cdef intp_t pos # samples[pos:end] are the samples in the right node 33 | cdef intp_t end 34 | cdef intp_t n_missing # Number of missing values for the feature being evaluated 35 | cdef bint missing_go_to_left # Whether missing values go to the left node 36 | 37 | cdef intp_t n_outputs # Number of outputs 38 | cdef intp_t n_samples # Number of samples 39 | cdef intp_t n_node_samples # Number of samples in the node (end-start) 40 | cdef float64_t weighted_n_samples # Weighted number of samples (in total) 41 | cdef float64_t weighted_n_node_samples # Weighted number of samples in the node 42 | cdef float64_t weighted_n_left # Weighted number of samples in the left node 43 | cdef float64_t weighted_n_right # Weighted number of samples in the right node 44 | cdef float64_t weighted_n_missing # Weighted number of samples that are missing 45 | 46 | # The criterion object is maintained such that left and right collected 47 | # statistics correspond to samples[start:pos] and samples[pos:end]. 48 | 49 | # Methods 50 | cdef int init( 51 | self, 52 | const float64_t[:, ::1] y, 53 | const int32_t[:] treatment, 54 | const float64_t[:] sample_weight, 55 | float64_t weighted_n_samples, 56 | const intp_t[:] sample_indices, 57 | intp_t start, 58 | intp_t end 59 | ) except -1 nogil 60 | cdef void init_sum_missing(self) 61 | cdef void init_missing(self, intp_t n_missing) noexcept nogil 62 | cdef int reset(self) except -1 nogil 63 | cdef int reverse_reset(self) except -1 nogil 64 | cdef int update(self, intp_t new_pos) except -1 nogil 65 | cdef float64_t node_impurity(self) noexcept nogil 66 | cdef void children_impurity( 67 | self, 68 | float64_t* impurity_left, 69 | float64_t* impurity_right 70 | ) noexcept nogil 71 | cdef void node_value( 72 | self, 73 | float64_t* dest 74 | ) noexcept nogil 75 | cdef void clip_node_value( 76 | self, 77 | float64_t* dest, 78 | float64_t lower_bound, 79 | float64_t upper_bound 80 | ) noexcept nogil 81 | cdef float64_t middle_value(self) noexcept nogil 82 | cdef float64_t impurity_improvement( 83 | self, 84 | float64_t impurity_parent, 85 | float64_t impurity_left, 86 | float64_t impurity_right 87 | ) noexcept nogil 88 | cdef float64_t proxy_impurity_improvement(self) noexcept nogil 89 | cdef bint check_monotonicity( 90 | self, 91 | int8_t monotonic_cst, 92 | float64_t lower_bound, 93 | float64_t upper_bound, 94 | ) noexcept nogil 95 | cdef inline bint _check_monotonicity( 96 | self, 97 | int8_t monotonic_cst, 98 | float64_t lower_bound, 99 | float64_t upper_bound, 100 | float64_t sum_left, 101 | float64_t sum_right, 102 | ) noexcept nogil 103 | 104 | cdef class ClassificationCriterion(Criterion): 105 | """Abstract criterion for classification.""" 106 | 107 | cdef intp_t[::1] n_classes 108 | cdef intp_t max_n_classes 109 | 110 | cdef float64_t[:, ::1] sum_total # The sum of the weighted count of each label. 
111 | cdef float64_t[:, ::1] sum_left # Same as above, but for the left side of the split 112 | cdef float64_t[:, ::1] sum_right # Same as above, but for the right side of the split 113 | cdef float64_t[:, ::1] sum_missing # Same as above, but for missing values in X 114 | 115 | cdef class RegressionCriterion(Criterion): 116 | """Abstract regression criterion.""" 117 | 118 | cdef float64_t sq_sum_total 119 | 120 | cdef float64_t[::1] sum_total # The sum of w*y. 121 | cdef float64_t[::1] sum_left # Same as above, but for the left side of the split 122 | cdef float64_t[::1] sum_right # Same as above, but for the right side of the split 123 | cdef float64_t[::1] sum_missing # Same as above, but for missing values in X 124 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_splitter.pxd: -------------------------------------------------------------------------------- 1 | # Authors: Gilles Louppe <g.louppe@gmail.com> 2 | # Peter Prettenhofer <peter.prettenhofer@gmail.com> 3 | # Brian Holt <bdholt1@gmail.com> 4 | # Joel Nothman <joel.nothman@gmail.com> 5 | # Arnaud Joly <arnaud.v.joly@gmail.com> 6 | # Jacob Schreiber <jmschreiber91@gmail.com> 7 | # 8 | # License: BSD 3 clause 9 | 10 | # distutils: language = c++ 11 | # cython: cdivision=True 12 | # cython: boundscheck=False 13 | # cython: wraparound=False 14 | # cython: language_level=3 15 | # cython: linetrace=True 16 | 17 | # See _splitter.pyx for details. 18 | from ._criterion cimport Criterion 19 | from ._tree cimport ParentInfo 20 | 21 | from ._typedefs cimport float32_t, float64_t, intp_t, int8_t, int32_t, uint32_t 22 | 23 | 24 | cdef struct SplitRecord: 25 | # Data to track sample split 26 | intp_t feature # Which feature to split on. 27 | intp_t pos # Split samples array at the given position, 28 | # # i.e. count of samples below threshold for feature. 29 | # # pos is >= end if the node is a leaf. 30 | float64_t threshold # Threshold to split at. 31 | float64_t improvement # Impurity improvement given parent node. 32 | float64_t impurity_left # Impurity of the left split. 33 | float64_t impurity_right # Impurity of the right split. 34 | float64_t lower_bound # Lower bound on value of both children for monotonicity 35 | float64_t upper_bound # Upper bound on value of both children for monotonicity 36 | unsigned char missing_go_to_left # Controls if missing values go to the left node. 37 | intp_t n_missing # Number of missing values for the feature being split on 38 | 39 | cdef class Splitter: 40 | # The splitter searches in the input space for a feature and a threshold 41 | # to split the samples samples[start:end]. 42 | # 43 | # The impurity computations are delegated to a criterion object. 44 | 45 | # Internal structures 46 | cdef public Criterion criterion # Impurity criterion 47 | cdef public intp_t max_features # Number of features to test 48 | cdef public intp_t min_samples_leaf # Min samples in a leaf 49 | cdef public float64_t min_weight_leaf # Minimum weight in a leaf 50 | 51 | cdef object random_state # Random state 52 | cdef uint32_t rand_r_state # sklearn_rand_r random number state 53 | 54 | cdef intp_t[::1] samples # Sample indices in X, y 55 | cdef intp_t n_samples # X.shape[0] 56 | cdef float64_t weighted_n_samples # Weighted number of samples 57 | cdef intp_t[::1] features # Feature indices in X 58 | cdef intp_t[::1] constant_features # Constant features indices 59 | cdef intp_t n_features # X.shape[1] 60 | cdef float32_t[::1] feature_values # temp. 
array holding feature values 61 | 62 | cdef intp_t start # Start position for the current node 63 | cdef intp_t end # End position for the current node 64 | 65 | cdef const float64_t[:, ::1] y 66 | # Monotonicity constraints for each feature. 67 | # The encoding is as follows: 68 | # -1: monotonic decrease 69 | # 0: no constraint 70 | # +1: monotonic increase 71 | cdef const int8_t[:] monotonic_cst 72 | cdef bint with_monotonic_cst 73 | cdef const int32_t[:] treatment 74 | cdef const float64_t[:] sample_weight 75 | 76 | # The samples vector `samples` is maintained by the Splitter object such 77 | # that the samples contained in a node are contiguous. With this setting, 78 | # `node_split` reorganizes the node samples `samples[start:end]` in two 79 | # subsets `samples[start:pos]` and `samples[pos:end]`. 80 | 81 | # The 1-d `features` array of size n_features contains the features 82 | # indices and allows fast sampling without replacement of features. 83 | 84 | # The 1-d `constant_features` array of size n_features holds in 85 | # `constant_features[:n_constant_features]` the feature ids with 86 | # constant values for all the samples that reached a specific node. 87 | # The value `n_constant_features` is given by the parent node to its 88 | # child nodes. The content of the range `[n_constant_features:]` is left 89 | # undefined, but preallocated for performance reasons 90 | # This allows optimization with depth-based tree building. 91 | 92 | # Methods 93 | cdef int init( 94 | self, 95 | object X, 96 | const float64_t[:, ::1] y, 97 | const int32_t[:] treatment, 98 | const float64_t[:] sample_weight, 99 | const unsigned char[::1] missing_values_in_feature_mask, 100 | ) except -1 101 | 102 | cdef int node_reset( 103 | self, 104 | intp_t start, 105 | intp_t end, 106 | float64_t* weighted_n_node_samples 107 | ) except -1 nogil 108 | 109 | cdef int node_split( 110 | self, 111 | ParentInfo* parent, 112 | SplitRecord* split, 113 | ) except -1 nogil 114 | 115 | cdef void node_value(self, float64_t* dest) noexcept nogil 116 | 117 | cdef void clip_node_value(self, float64_t* dest, float64_t lower_bound, float64_t upper_bound) noexcept nogil 118 | 119 | cdef float64_t node_impurity(self) noexcept nogil 120 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_tree.pxd: -------------------------------------------------------------------------------- 1 | # Authors: Gilles Louppe <g.louppe@gmail.com> 2 | # Peter Prettenhofer <peter.prettenhofer@gmail.com> 3 | # Brian Holt <bdholt1@gmail.com> 4 | # Joel Nothman <joel.nothman@gmail.com> 5 | # Arnaud Joly <arnaud.v.joly@gmail.com> 6 | # Jacob Schreiber <jmschreiber91@gmail.com> 7 | # Nelson Liu <nelson@nelsonliu.me> 8 | # 9 | # License: BSD 3 clause 10 | 11 | # distutils: language = c++ 12 | # cython: cdivision=True 13 | # cython: boundscheck=False 14 | # cython: wraparound=False 15 | # cython: language_level=3 16 | # cython: linetrace=True 17 | 18 | # See _tree.pyx for details. 
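# For orientation: this header declares the shared low-level building blocks
# for the causal trees -- the Node and ParentInfo structs, the Tree container,
# the TreeBuilder base class, and the FrontierRecord/StackRecord helpers used
# for best-first and depth-first growing. The main difference from the
# upstream scikit-learn header it was borrowed from is the per-sample
# `treatment` array (int32_t[:]) threaded through TreeBuilder.build() and
# _check_input(), so that splitters and criteria can tell treated and control
# observations apart.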
19 | 20 | import numpy as np 21 | cimport numpy as cnp 22 | 23 | from ._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t 24 | 25 | from ._splitter cimport Splitter 26 | from ._splitter cimport SplitRecord 27 | 28 | cdef struct Node: 29 | # Base storage structure for the nodes in a Tree object 30 | 31 | intp_t left_child # id of the left child of the node 32 | intp_t right_child # id of the right child of the node 33 | intp_t feature # Feature used for splitting the node 34 | float64_t threshold # Threshold value at the node 35 | float64_t impurity # Impurity of the node (i.e., the value of the criterion) 36 | intp_t n_node_samples # Number of samples at the node 37 | float64_t weighted_n_node_samples # Weighted number of samples at the node 38 | unsigned char missing_go_to_left # Whether features have missing values 39 | 40 | cdef void _init_parent_record(ParentInfo* record) noexcept nogil 41 | 42 | cdef struct ParentInfo: 43 | # Structure to store information about the parent of a node 44 | # This is passed to the splitter, to provide information about the previous split 45 | 46 | float64_t lower_bound # the lower bound of the parent's impurity 47 | float64_t upper_bound # the upper bound of the parent's impurity 48 | float64_t impurity # the impurity of the parent 49 | intp_t n_constant_features # the number of constant features found in parent 50 | 51 | cdef class Tree: 52 | # The Tree object is a binary tree structure constructed by the 53 | # TreeBuilder. The tree structure is used for predictions and 54 | # feature importances. 55 | 56 | # Input/Output layout 57 | cdef public intp_t n_features # Number of features in X 58 | cdef intp_t* n_classes # Number of classes in y[:, k] 59 | cdef public intp_t n_outputs # Number of outputs in y 60 | cdef public intp_t max_n_classes # max(n_classes) 61 | 62 | # Inner structures: values are stored separately from node structure, 63 | # since size is determined at runtime. 
64 | cdef public intp_t max_depth # Max depth of the tree 65 | cdef public intp_t node_count # Counter for node IDs 66 | cdef public intp_t capacity # Capacity of tree, in terms of nodes 67 | cdef Node* nodes # Array of nodes 68 | cdef float64_t* value # (capacity, n_outputs, max_n_classes) array of values 69 | cdef intp_t value_stride # = n_outputs * max_n_classes 70 | 71 | # Methods 72 | cdef intp_t _add_node(self, intp_t parent, bint is_left, bint is_leaf, 73 | intp_t feature, float64_t threshold, float64_t impurity, 74 | intp_t n_node_samples, 75 | float64_t weighted_n_node_samples, 76 | unsigned char missing_go_to_left) except -1 nogil 77 | cdef int _resize(self, intp_t capacity) except -1 nogil 78 | cdef int _resize_c(self, intp_t capacity=*) except -1 nogil 79 | 80 | cdef cnp.ndarray _get_value_ndarray(self) 81 | cdef cnp.ndarray _get_node_ndarray(self) 82 | 83 | cpdef cnp.ndarray predict(self, object X) 84 | 85 | cpdef cnp.ndarray apply(self, object X) 86 | cdef cnp.ndarray _apply_dense(self, object X) 87 | cdef cnp.ndarray _apply_sparse_csr(self, object X) 88 | 89 | cpdef object decision_path(self, object X) 90 | cdef object _decision_path_dense(self, object X) 91 | cdef object _decision_path_sparse_csr(self, object X) 92 | 93 | cpdef compute_node_depths(self) 94 | cpdef compute_feature_importances(self, normalize=*) 95 | 96 | 97 | # ============================================================================= 98 | # Tree builder 99 | # ============================================================================= 100 | 101 | cdef class TreeBuilder: 102 | # The TreeBuilder recursively builds a Tree object from training samples, 103 | # using a Splitter object for splitting internal nodes and assigning 104 | # values to leaves. 105 | # 106 | # This class controls the various stopping criteria and the node splitting 107 | # evaluation order, e.g. depth-first or best-first. 108 | 109 | cdef Splitter splitter # Splitting algorithm 110 | 111 | cdef intp_t min_samples_split # Minimum number of samples in an internal node 112 | cdef intp_t min_samples_leaf # Minimum number of samples in a leaf 113 | cdef float64_t min_weight_leaf # Minimum weight in a leaf 114 | cdef intp_t max_depth # Maximal tree depth 115 | cdef float64_t min_impurity_decrease # Impurity threshold for early stopping 116 | 117 | cpdef build( 118 | self, 119 | Tree tree, 120 | object X, 121 | const float64_t[:, ::1] y, 122 | const int32_t[:] treatment, 123 | const float64_t[:] sample_weight=*, 124 | const unsigned char[::1] missing_values_in_feature_mask=*, 125 | ) 126 | 127 | cdef _check_input( 128 | self, 129 | object X, 130 | const float64_t[:, ::1] y, 131 | const int32_t[:] treatment, 132 | const float64_t[:] sample_weight, 133 | ) 134 | 135 | cdef struct FrontierRecord: 136 | # Record of information of a Node, the frontier for a split. Those records are 137 | # maintained in a heap to access the Node with the best improvement in impurity, 138 | # allowing growing trees greedily on this improvement. 
139 | intp_t node_id 140 | intp_t start 141 | intp_t end 142 | intp_t pos 143 | intp_t depth 144 | bint is_leaf 145 | float64_t impurity 146 | float64_t impurity_left 147 | float64_t impurity_right 148 | float64_t improvement 149 | float64_t lower_bound 150 | float64_t upper_bound 151 | float64_t middle_value 152 | 153 | # A record on the stack for depth-first tree growing 154 | cdef struct StackRecord: 155 | intp_t start 156 | intp_t end 157 | intp_t depth 158 | intp_t parent 159 | bint is_left 160 | float64_t impurity 161 | intp_t n_constant_features 162 | float64_t lower_bound 163 | float64_t upper_bound 164 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_typedefs.pxd: -------------------------------------------------------------------------------- 1 | # Commonly used types 2 | # These are redefinitions of the ones defined by numpy in 3 | # https://github.com/numpy/numpy/blob/main/numpy/__init__.pxd. 4 | # It will eventually avoid having to always include the numpy headers even when we 5 | # would only use it for the types. 6 | # 7 | # When used to declare variables that will receive values from numpy arrays, it 8 | # should match the dtype of the array. For example, to declare a variable that will 9 | # receive values from a numpy array of dtype np.float64, the type float64_t must be 10 | # used. 11 | # 12 | # TODO: Stop defining custom types locally or globally like DTYPE_t and friends and 13 | # use these consistently throughout the codebase. 14 | # NOTE: Extend this list as needed when converting more cython extensions. 15 | ctypedef unsigned char uint8_t 16 | ctypedef unsigned int uint32_t 17 | ctypedef unsigned long long uint64_t 18 | # Note: In NumPy 2, indexing always happens with npy_intp which is an alias for 19 | # the Py_ssize_t type, see PEP 353. 20 | # 21 | # Note that on most platforms Py_ssize_t is equivalent to C99's intptr_t, 22 | # but they can differ on architecture with segmented memory (none 23 | # supported by scikit-learn at the time of writing). 24 | # 25 | # intp_t/np.intp should be used to index arrays in a platform dependent way. 26 | # Storing arrays with platform dependent dtypes as attribute on picklable 27 | # objects is not recommended as it requires special care when loading and 28 | # using such datastructures on a host with different bitness. Instead one 29 | # should rather use fixed width integer types such as int32 or uint32 when we know 30 | # that the number of elements to index is not larger to 2 or 4 billions. 31 | ctypedef Py_ssize_t intp_t 32 | ctypedef float float32_t 33 | ctypedef double float64_t 34 | # Sparse matrices indices and indices' pointers arrays must use int32_t over 35 | # intp_t because intp_t is platform dependent. 36 | # When large sparse matrices are supported, indexing must use int64_t. 37 | # See https://github.com/scikit-learn/scikit-learn/issues/23653 which tracks the 38 | # ongoing work to support large sparse matrices. 39 | ctypedef signed char int8_t 40 | ctypedef signed int int32_t 41 | ctypedef signed long long int64_t 42 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_typedefs.pyx: -------------------------------------------------------------------------------- 1 | # _typedefs is a declaration only module 2 | # 3 | # The functions implemented here are for testing purpose only. 
4 | 5 | 6 | import numpy as np 7 | 8 | 9 | ctypedef fused testing_type_t: 10 | float32_t 11 | float64_t 12 | int8_t 13 | int32_t 14 | int64_t 15 | intp_t 16 | uint8_t 17 | uint32_t 18 | uint64_t 19 | 20 | 21 | def testing_make_array_from_typed_val(testing_type_t val): 22 | cdef testing_type_t[:] val_view = <testing_type_t[:1]>&val 23 | return np.asarray(val_view) 24 | -------------------------------------------------------------------------------- /causalml/inference/tree/_tree/_utils.pxd: -------------------------------------------------------------------------------- 1 | # Authors: Gilles Louppe <g.louppe@gmail.com> 2 | # Peter Prettenhofer <peter.prettenhofer@gmail.com> 3 | # Arnaud Joly <arnaud.v.joly@gmail.com> 4 | # Jacob Schreiber <jmschreiber91@gmail.com> 5 | # Nelson Liu <nelson@nelsonliu.me> 6 | # 7 | # License: BSD 3 clause 8 | 9 | # distutils: language = c++ 10 | # cython: cdivision=True 11 | # cython: boundscheck=False 12 | # cython: wraparound=False 13 | # cython: language_level=3 14 | # cython: linetrace=True 15 | 16 | # See _utils.pyx for details. 17 | 18 | cimport numpy as cnp 19 | from ._tree cimport Node 20 | from ._typedefs cimport float32_t, float64_t, intp_t, int32_t, uint32_t 21 | 22 | cdef enum: 23 | # Max value for our rand_r replacement (near the bottom). 24 | # We don't use RAND_MAX because it's different across platforms and 25 | # particularly tiny on Windows/MSVC. 26 | # It corresponds to the maximum representable value for 27 | # 32-bit signed integers (i.e. 2^31 - 1). 28 | RAND_R_MAX = 2147483647 29 | 30 | 31 | # safe_realloc(&p, n) resizes the allocation of p to n * sizeof(*p) bytes or 32 | # raises a MemoryError. It never calls free, since that's __dealloc__'s job. 33 | # cdef float32_t *p = NULL 34 | # safe_realloc(&p, n) 35 | # is equivalent to p = malloc(n * sizeof(*p)) with error checking. 36 | ctypedef fused realloc_ptr: 37 | # Add pointer types here as needed. 
38 | (float32_t*) 39 | (intp_t*) 40 | (unsigned char*) 41 | (WeightedPQueueRecord*) 42 | (float64_t*) 43 | (float64_t**) 44 | (Node*) 45 | (Node**) 46 | 47 | cdef int safe_realloc(realloc_ptr* p, intp_t nelems) except -1 nogil 48 | 49 | 50 | cdef cnp.ndarray sizet_ptr_to_ndarray(intp_t* data, intp_t size) 51 | 52 | 53 | cdef intp_t rand_int(intp_t low, intp_t high, 54 | uint32_t* random_state) noexcept nogil 55 | 56 | 57 | cdef float64_t rand_uniform(float64_t low, float64_t high, 58 | uint32_t* random_state) noexcept nogil 59 | 60 | 61 | cdef float64_t log(float64_t x) noexcept nogil 62 | 63 | # ============================================================================= 64 | # WeightedPQueue data structure 65 | # ============================================================================= 66 | 67 | # A record stored in the WeightedPQueue 68 | cdef struct WeightedPQueueRecord: 69 | float64_t data 70 | float64_t weight 71 | 72 | cdef class WeightedPQueue: 73 | cdef intp_t capacity 74 | cdef intp_t array_ptr 75 | cdef WeightedPQueueRecord* array_ 76 | 77 | cdef bint is_empty(self) noexcept nogil 78 | cdef int reset(self) except -1 nogil 79 | cdef intp_t size(self) noexcept nogil 80 | cdef int push(self, float64_t data, float64_t weight) except -1 nogil 81 | cdef int remove(self, float64_t data, float64_t weight) noexcept nogil 82 | cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil 83 | cdef int peek(self, float64_t* data, float64_t* weight) noexcept nogil 84 | cdef float64_t get_weight_from_index(self, intp_t index) noexcept nogil 85 | cdef float64_t get_value_from_index(self, intp_t index) noexcept nogil 86 | 87 | 88 | # ============================================================================= 89 | # WeightedMedianCalculator data structure 90 | # ============================================================================= 91 | 92 | cdef class WeightedMedianCalculator: 93 | cdef intp_t initial_capacity 94 | cdef WeightedPQueue samples 95 | cdef float64_t total_weight 96 | cdef intp_t k 97 | cdef float64_t sum_w_0_k # represents sum(weights[0:k]) = w[0] + w[1] + ... 
+ w[k-1] 98 | cdef intp_t size(self) noexcept nogil 99 | cdef int push(self, float64_t data, float64_t weight) except -1 nogil 100 | cdef int reset(self) except -1 nogil 101 | cdef int update_median_parameters_post_push( 102 | self, float64_t data, float64_t weight, 103 | float64_t original_median) noexcept nogil 104 | cdef int remove(self, float64_t data, float64_t weight) noexcept nogil 105 | cdef int pop(self, float64_t* data, float64_t* weight) noexcept nogil 106 | cdef int update_median_parameters_post_remove( 107 | self, float64_t data, float64_t weight, 108 | float64_t original_median) noexcept nogil 109 | cdef float64_t get_median(self) noexcept nogil 110 | -------------------------------------------------------------------------------- /causalml/inference/tree/causal/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/causalml/inference/tree/causal/__init__.py -------------------------------------------------------------------------------- /causalml/inference/tree/causal/_builder.pxd: -------------------------------------------------------------------------------- 1 | # cython: cdivision=True 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: language_level=3 5 | # cython: linetrace=True 6 | 7 | from .._tree._tree cimport Node, Tree, TreeBuilder 8 | from .._tree._splitter cimport Splitter, SplitRecord 9 | from .._tree._tree cimport intp_t, int32_t, float64_t 10 | from .._tree._tree cimport FrontierRecord, StackRecord 11 | from .._tree._tree cimport ParentInfo, _init_parent_record 12 | -------------------------------------------------------------------------------- /causalml/inference/tree/causal/_criterion.pxd: -------------------------------------------------------------------------------- 1 | # cython: cdivision=True 2 | # cython: boundscheck=False 3 | # cython: wraparound=False 4 | # cython: language_level=3 5 | # cython: linetrace=True 6 | 7 | 8 | from .._tree._criterion cimport RegressionCriterion 9 | from .._tree._criterion cimport int32_t, intp_t, float64_t 10 | 11 | 12 | cdef struct NodeInfo: 13 | double count # the number of obs 14 | double tr_count # the number of treatment obs 15 | double ct_count # the number of control obs 16 | double tr_y_sum # the sum of outcomes among treatment obs 17 | double ct_y_sum # the sum of outcomes among control obs 18 | double y_sq_sum # the squared sum of outcomes 19 | double tr_y_sq_sum # the squared sum of outcomes among treatment obs 20 | double ct_y_sq_sum # the squared sum of outcomes among control obs 21 | double split_metric # Additional split metric for t-test criterion 22 | 23 | cdef struct SplitState: 24 | NodeInfo node # current node state 25 | NodeInfo right # right split state 26 | NodeInfo left # left split state 27 | -------------------------------------------------------------------------------- /causalml/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .classification import roc_auc_score, logloss, classification_metrics # noqa 2 | from .regression import ( 3 | ape, 4 | mape, 5 | mae, 6 | rmse, 7 | r2_score, 8 | gini, 9 | smape, 10 | regression_metrics, 11 | ) # noqa 12 | from .visualize import ( 13 | plot, 14 | plot_gain, 15 | plot_lift, 16 | plot_qini, 17 | plot_tmlegain, 18 | plot_tmleqini, 19 | ) # noqa 20 | from .visualize import ( 21 | get_cumgain, 22 | get_cumlift, 23 | get_qini, 24 | get_tmlegain, 
25 | get_tmleqini, 26 | ) # noqa 27 | from .visualize import auuc_score, qini_score # noqa 28 | from .sensitivity import Sensitivity, SensitivityPlaceboTreatment # noqa 29 | from .sensitivity import ( 30 | SensitivityRandomCause, 31 | SensitivityRandomReplace, 32 | SensitivitySubsetData, 33 | SensitivitySelectionBias, 34 | ) # noqa 35 | -------------------------------------------------------------------------------- /causalml/metrics/classification.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from sklearn.metrics import log_loss, roc_auc_score 3 | 4 | from .const import EPS 5 | from .regression import regression_metrics 6 | 7 | 8 | logger = logging.getLogger("causalml") 9 | 10 | 11 | def logloss(y, p): 12 | """Bounded log loss error. 13 | Args: 14 | y (numpy.array): target 15 | p (numpy.array): prediction 16 | Returns: 17 | bounded log loss error 18 | """ 19 | 20 | p[p < EPS] = EPS 21 | p[p > 1 - EPS] = 1 - EPS 22 | return log_loss(y, p) 23 | 24 | 25 | def classification_metrics( 26 | y, p, w=None, metrics={"AUC": roc_auc_score, "Log Loss": logloss} 27 | ): 28 | """Log metrics for classifiers. 29 | 30 | Args: 31 | y (numpy.array): target 32 | p (numpy.array): prediction 33 | w (numpy.array, optional): a treatment vector (1 or True: treatment, 0 or False: control). If given, log 34 | metrics for the treatment and control group separately 35 | metrics (dict, optional): a dictionary of the metric names and functions 36 | """ 37 | regression_metrics(y=y, p=p, w=w, metrics=metrics) 38 | -------------------------------------------------------------------------------- /causalml/metrics/const.py: -------------------------------------------------------------------------------- 1 | EPS = 1e-15 2 | -------------------------------------------------------------------------------- /causalml/metrics/regression.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | from sklearn.metrics import mean_squared_error as mse 4 | from sklearn.metrics import mean_absolute_error as mae # noqa 5 | from sklearn.metrics import r2_score # noqa 6 | 7 | from .const import EPS 8 | 9 | 10 | logger = logging.getLogger("causalml") 11 | 12 | 13 | def ape(y, p): 14 | """Absolute Percentage Error (APE). 15 | Args: 16 | y (float): target 17 | p (float): prediction 18 | 19 | Returns: 20 | e (float): APE 21 | """ 22 | 23 | assert np.abs(y) > EPS 24 | return np.abs(1 - p / y) 25 | 26 | 27 | def mape(y, p): 28 | """Mean Absolute Percentage Error (MAPE). 29 | Args: 30 | y (numpy.array): target 31 | p (numpy.array): prediction 32 | 33 | Returns: 34 | e (numpy.float64): MAPE 35 | """ 36 | 37 | filt = np.abs(y) > EPS 38 | return np.mean(np.abs(1 - p[filt] / y[filt])) 39 | 40 | 41 | def smape(y, p): 42 | """Symmetric Mean Absolute Percentage Error (sMAPE). 43 | Args: 44 | y (numpy.array): target 45 | p (numpy.array): prediction 46 | 47 | Returns: 48 | e (numpy.float64): sMAPE 49 | """ 50 | return 2.0 * np.mean(np.abs(y - p) / (np.abs(y) + np.abs(p))) 51 | 52 | 53 | def rmse(y, p): 54 | """Root Mean Squared Error (RMSE). 55 | Args: 56 | y (numpy.array): target 57 | p (numpy.array): prediction 58 | 59 | Returns: 60 | e (numpy.float64): RMSE 61 | """ 62 | 63 | # check and get number of samples 64 | assert y.shape == p.shape 65 | 66 | return np.sqrt(mse(y, p)) 67 | 68 | 69 | def gini(y, p): 70 | """Normalized Gini Coefficient. 
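
    The score is the ratio of two Gini coefficients: the one obtained when the
    targets are ordered by the predictions, divided by the one obtained when
    they are ordered by the targets themselves, so a model that ranks the
    targets perfectly scores 1.0.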
71 | 72 | Args: 73 | y (numpy.array): target 74 | p (numpy.array): prediction 75 | 76 | Returns: 77 | e (numpy.float64): normalized Gini coefficient 78 | """ 79 | 80 | # check and get number of samples 81 | assert y.shape == p.shape 82 | 83 | n_samples = y.shape[0] 84 | 85 | # sort rows on prediction column 86 | # (from largest to smallest) 87 | arr = np.array([y, p]).transpose() 88 | true_order = arr[arr[:, 0].argsort()][::-1, 0] 89 | pred_order = arr[arr[:, 1].argsort()][::-1, 0] 90 | 91 | # get Lorenz curves 92 | l_true = np.cumsum(true_order) / np.sum(true_order) 93 | l_pred = np.cumsum(pred_order) / np.sum(pred_order) 94 | l_ones = np.linspace(1 / n_samples, 1, n_samples) 95 | 96 | # get Gini coefficients (area between curves) 97 | g_true = np.sum(l_ones - l_true) 98 | g_pred = np.sum(l_ones - l_pred) 99 | 100 | # normalize to true Gini coefficient 101 | return g_pred / g_true 102 | 103 | 104 | def regression_metrics( 105 | y, p, w=None, metrics={"RMSE": rmse, "sMAPE": smape, "Gini": gini} 106 | ): 107 | """Log metrics for regressors. 108 | 109 | Args: 110 | y (numpy.array): target 111 | p (numpy.array): prediction 112 | w (numpy.array, optional): a treatment vector (1 or True: treatment, 0 or False: control). If given, log 113 | metrics for the treatment and control group separately 114 | metrics (dict, optional): a dictionary of the metric names and functions 115 | """ 116 | assert metrics 117 | assert y.shape[0] == p.shape[0] 118 | 119 | for name, func in metrics.items(): 120 | if w is not None: 121 | assert y.shape[0] == w.shape[0] 122 | if w.dtype != bool: 123 | w = w == 1 124 | logger.info("{:>8s} (Control): {:10.4f}".format(name, func(y[~w], p[~w]))) 125 | logger.info("{:>8s} (Treatment): {:10.4f}".format(name, func(y[w], p[w]))) 126 | else: 127 | logger.info("{:>8s}: {:10.4f}".format(name, func(y, p))) 128 | -------------------------------------------------------------------------------- /causalml/optimize/__init__.py: -------------------------------------------------------------------------------- 1 | from .policylearner import PolicyLearner 2 | from .unit_selection import CounterfactualUnitSelector 3 | from .utils import get_treatment_costs, get_actual_value, get_uplift_best 4 | from .value_optimization import CounterfactualValueEstimator 5 | from .pns import get_pns_bounds 6 | -------------------------------------------------------------------------------- /causalml/optimize/pns.py: -------------------------------------------------------------------------------- 1 | def get_pns_bounds(data_exp, data_obs, T, Y, type="PNS"): 2 | """ 3 | Args 4 | ---- 5 | data_exp : DataFrame 6 | Data from an experiment. 7 | data_obs : DataFrame 8 | Data from an observational study 9 | T : str 10 | Name of the binary treatment indicator 11 | y : str 12 | Name of the binary outcome indicator 13 | type : str 14 | Type of probability of causation desired. Acceptable args are: 15 | - ``PNS``: Probability of necessary and sufficient causation 16 | - ``PS``: Probability of sufficient causation 17 | - ``PN``: Probability of necessary causation 18 | 19 | Notes 20 | ----- 21 | Based on Equation (24) in `Tian and Pearl (2000) <https://ftp.cs.ucla.edu/pub/stat_ser/r271-A.pdf>`_. 22 | 23 | To capture the counterfactual notation, we use ``1`` and ``0`` to indicate the actual and 24 | counterfactual values of a variable, respectively, and we use ``do`` to indicate the effect 25 | of an intervention. 
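
    Concretely, writing ``Y1doT1`` for :math:`P(Y=1|do(T=1))`, ``Y1doT0`` for
    :math:`P(Y=1|do(T=0))`, ``Y1`` for the observational :math:`P(Y=1)`, and
    ``T1Y1``, ``T1Y0``, ``T0Y1``, ``T0Y0`` for the observational joint
    probabilities :math:`P(T=t, Y=y)`, the PNS bounds computed below are::

        max(0, Y1doT1 - Y1doT0, Y1 - Y1doT0, Y1doT1 - Y1)
            <= PNS <=
        min(Y1doT1, Y0doT0, T1Y1 + T0Y0, Y1doT1 - Y1doT0 + T1Y0 + T0Y1)

    with analogous ratio-based expressions for ``PN`` and ``PS``.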
26 | 27 | The experimental and observational data are either assumed to come to the same population, 28 | or from random samples of the population. If the data are from a sample, the bounds may 29 | be incorrectly calculated because the relevant quantities in the Tian-Pearl equations are 30 | defined e.g. as :math:`P(Y|do(T))`, not :math:`P(Y|do(T), S)` where :math:`S` corresponds to sample selection. 31 | `Bareinboim and Pearl (2016) <https://www.pnas.org/doi/10.1073/pnas.1510507113>`_ discuss conditions 32 | under which :math:`P(Y|do(T))` can be recovered from :math:`P(Y|do(T), S)`. 33 | """ 34 | 35 | # Probabilities calculated from observational data 36 | Y1 = data_obs[Y].mean() 37 | T1Y0 = ( 38 | data_obs.loc[(data_obs[T] == 1) & (data_obs[Y] == 0)].shape[0] 39 | / data_obs.shape[0] 40 | ) 41 | T1Y1 = ( 42 | data_obs.loc[(data_obs[T] == 1) & (data_obs[Y] == 1)].shape[0] 43 | / data_obs.shape[0] 44 | ) 45 | T0Y0 = ( 46 | data_obs.loc[(data_obs[T] == 0) & (data_obs[Y] == 0)].shape[0] 47 | / data_obs.shape[0] 48 | ) 49 | T0Y1 = ( 50 | data_obs.loc[(data_obs[T] == 0) & (data_obs[Y] == 1)].shape[0] 51 | / data_obs.shape[0] 52 | ) 53 | 54 | # Probabilities calculated from experimental data 55 | Y1doT1 = data_exp.loc[data_exp[T] == 1, Y].mean() 56 | Y1doT0 = data_exp.loc[data_exp[T] == 0, Y].mean() 57 | Y0doT0 = 1 - Y1doT0 58 | 59 | if type == "PNS": 60 | lb_args = [0, Y1doT1 - Y1doT0, Y1 - Y1doT0, Y1doT1 - Y1] 61 | 62 | ub_args = [Y1doT1, Y0doT0, T1Y1 + T0Y0, Y1doT1 - Y1doT0 + T1Y0 + T0Y1] 63 | 64 | if type == "PN": 65 | lb_args = [0, (Y1 - Y1doT0) / T1Y1] 66 | ub_args = [1, (Y0doT0 - T0Y0) / T1Y1] 67 | 68 | if type == "PS": 69 | lb_args = [0, (Y1doT1 - Y1) / T0Y0] 70 | ub_args = [1, (Y1doT1 - T1Y1) / T0Y0] 71 | 72 | lower_bound = max(lb_args) 73 | upper_bound = min(ub_args) 74 | 75 | return lower_bound, upper_bound 76 | -------------------------------------------------------------------------------- /causalml/optimize/policylearner.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | from causalml.propensity import compute_propensity_score 5 | from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier 6 | from sklearn.model_selection import KFold 7 | from sklearn.tree import DecisionTreeClassifier 8 | 9 | 10 | logger = logging.getLogger("causalml") 11 | 12 | 13 | class PolicyLearner: 14 | """ 15 | A Learner that learns a treatment assignment policy with observational data using doubly robust estimator of causal 16 | effect for binary treatment. 17 | 18 | Details of the policy learner are available at `Athey and Wager (2018) <https://arxiv.org/abs/1702.02896>`_. 19 | 20 | """ 21 | 22 | def __init__( 23 | self, 24 | outcome_learner=GradientBoostingRegressor(), 25 | treatment_learner=GradientBoostingClassifier(), 26 | policy_learner=DecisionTreeClassifier(), 27 | clip_bounds=(1e-3, 1 - 1e-3), 28 | n_fold=5, 29 | random_state=None, 30 | calibration=False, 31 | ): 32 | """Initialize a treatment assignment policy learner. 33 | 34 | Args: 35 | outcome_learner (optional): a regression model to estimate outcomes 36 | policy_learner (optional): a classification model to estimate treatment assignment. 
It needs to take 37 | `sample_weight` as an input argument for `fit()` 38 | clip_bounds (tuple, optional): lower and upper bounds for clipping propensity scores to avoid division by 39 | zero in PolicyLearner.fit() 40 | n_fold (int, optional): the number of cross validation folds for outcome_learner 41 | random_state (int or RandomState, optional): a seed (int) or random number generator (RandomState) 42 | """ 43 | self.model_mu = outcome_learner 44 | self.model_w = treatment_learner 45 | self.model_pi = policy_learner 46 | self.clip_bounds = clip_bounds 47 | self.cv = KFold(n_splits=n_fold, shuffle=True, random_state=random_state) 48 | self.calibration = calibration 49 | 50 | self._y_pred, self._tau_pred, self._w_pred, self._dr_score = ( 51 | None, 52 | None, 53 | None, 54 | None, 55 | ) 56 | 57 | def __repr__(self): 58 | return ( 59 | "{}(model_mu={},\n" 60 | "\tmodel_w={},\n" 61 | "\tmodel_pi={})".format( 62 | self.__class__.__name__, 63 | self.model_mu.__repr__(), 64 | self.model_w.__repr__(), 65 | self.model_pi.__repr__(), 66 | ) 67 | ) 68 | 69 | def _outcome_estimate(self, X, w, y): 70 | self._y_pred = np.zeros(len(y)) 71 | self._tau_pred = np.zeros(len(y)) 72 | 73 | for train_index, test_index in self.cv.split(y): 74 | X_train, X_test = X[train_index], X[test_index] 75 | w_train, w_test = w[train_index], w[test_index] 76 | y_train, _ = y[train_index], y[test_index] 77 | 78 | self.model_mu.fit( 79 | np.concatenate([X_train, w_train.reshape(-1, 1)], axis=1), y_train 80 | ) 81 | self._y_pred[test_index] = self.model_mu.predict( 82 | np.concatenate([X_test, w_test.reshape(-1, 1)], axis=1) 83 | ) 84 | self._tau_pred[test_index] = self.model_mu.predict( 85 | np.concatenate([X_test, np.ones((len(w_test), 1))], axis=1) 86 | ) - self.model_mu.predict( 87 | np.concatenate([X_test, np.zeros((len(w_test), 1))], axis=1) 88 | ) 89 | 90 | def _treatment_estimate(self, X, w): 91 | self._w_pred = np.zeros(len(w)) 92 | 93 | for train_index, test_index in self.cv.split(w): 94 | X_train, X_test = X[train_index], X[test_index] 95 | w_train, w_test = w[train_index], w[test_index] 96 | 97 | self._w_pred[test_index], _ = compute_propensity_score( 98 | X=X_train, 99 | treatment=w_train, 100 | X_pred=X_test, 101 | treatment_pred=w_test, 102 | calibrate_p=self.calibration, 103 | ) 104 | 105 | self._w_pred = np.clip( 106 | self._w_pred, a_min=self.clip_bounds[0], a_max=self.clip_bounds[1] 107 | ) 108 | 109 | def fit(self, X, treatment, y, p=None, dhat=None): 110 | """Fit the treatment assignment policy learner. 111 | 112 | Args: 113 | X (np.matrix): a feature matrix 114 | treatment (np.array): a treatment vector (1 if treated, otherwise 0) 115 | y (np.array): an outcome vector 116 | p (optional, np.array): user provided propensity score vector between 0 and 1 117 | dhat (optinal, np.array): user provided predicted treatment effect vector 118 | 119 | Returns: 120 | self: returns an instance of self. 
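
        Example (a minimal illustrative sketch; it assumes the
        ``causalml.dataset.synthetic_data`` helper and a binary 0/1 treatment):

            from causalml.dataset import synthetic_data
            from causalml.optimize import PolicyLearner

            y, X, w, _, _, _ = synthetic_data(mode=1, n=1000, p=5)
            learner = PolicyLearner(n_fold=3)
            learner.fit(X, w, y)
            # the policy classifier is trained on the sign of the DR score,
            # so predictions are the recommended assignment per unit
            assignment = learner.predict(X)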
121 | """ 122 | 123 | logger.info( 124 | "generating out-of-fold CV outcome estimates with {}".format(self.model_mu) 125 | ) 126 | self._outcome_estimate(X, treatment, y) 127 | 128 | if dhat is not None: 129 | self._tau_pred = dhat 130 | 131 | if p is None: 132 | self._treatment_estimate(X, treatment) 133 | else: 134 | self._w_pred = np.clip(p, self.clip_bounds[0], self.clip_bounds[1]) 135 | 136 | # Doubly Robust Modification 137 | self._dr_score = self._tau_pred + (treatment - self._w_pred) / self._w_pred / ( 138 | 1 - self._w_pred 139 | ) * (y - self._y_pred) 140 | 141 | target = self._dr_score.copy() 142 | target = np.sign(target) 143 | 144 | logger.info("training the treatment assignment model, {}".format(self.model_pi)) 145 | self.model_pi.fit(X, target, sample_weight=abs(self._dr_score)) 146 | 147 | return self 148 | 149 | def predict(self, X): 150 | """Predict treatment assignment that optimizes the outcome. 151 | 152 | Args: 153 | X (np.matrix): a feature matrix 154 | 155 | Returns: 156 | (numpy.ndarray): predictions of treatment assignment. 157 | """ 158 | 159 | return self.model_pi.predict(X) 160 | 161 | def predict_proba(self, X): 162 | """Predict treatment assignment score that optimizes the outcome. 163 | 164 | Args: 165 | X (np.matrix): a feature matrix 166 | 167 | Returns: 168 | (numpy.ndarray): predictions of treatment assignment score. 169 | """ 170 | 171 | pi_hat = self.model_pi.predict_proba(X)[:, 1] 172 | 173 | return pi_hat 174 | -------------------------------------------------------------------------------- /causalml/optimize/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_treatment_costs(treatment, control_name, cc_dict, ic_dict): 5 | """ 6 | Set the conversion and impression costs based on a dict of parameters. 7 | 8 | Calculate the actual cost of targeting a user with the actual treatment 9 | group using the above parameters. 10 | 11 | Params 12 | ------ 13 | treatment : array, shape = (num_samples, ) 14 | Treatment array. 15 | 16 | control_name, str 17 | Control group name as string. 18 | 19 | cc_dict : dict 20 | Dict containing the conversion cost for each treatment. 21 | 22 | ic_dict 23 | Dict containing the impression cost for each treatment. 24 | 25 | Returns 26 | ------- 27 | conversion_cost : ndarray, shape = (num_samples, num_treatments) 28 | An array of conversion costs for each treatment. 29 | 30 | impression_cost : ndarray, shape = (num_samples, num_treatments) 31 | An array of impression costs for each treatment. 32 | 33 | conditions : list, len = len(set(treatment)) 34 | A list of experimental conditions. 
35 | """ 36 | 37 | # Set the conversion costs of the treatments 38 | conversion_cost = np.zeros((len(treatment), len(cc_dict.keys()))) 39 | for idx, dict_key in enumerate(cc_dict.keys()): 40 | conversion_cost[:, idx] = cc_dict.get(dict_key) 41 | 42 | # Set the impression costs of the treatments 43 | impression_cost = np.zeros((len(treatment), len(ic_dict.keys()))) 44 | for idx, dict_key in enumerate(ic_dict.keys()): 45 | impression_cost[:, idx] = ic_dict.get(dict_key) 46 | 47 | # Get a sorted list of conditions 48 | conditions = list(set(treatment)) 49 | conditions.remove(control_name) 50 | conditions_sorted = sorted(conditions) 51 | conditions_sorted.insert(0, control_name) 52 | 53 | return conversion_cost, impression_cost, conditions_sorted 54 | 55 | 56 | def get_actual_value( 57 | treatment, 58 | observed_outcome, 59 | conversion_value, 60 | conditions, 61 | conversion_cost, 62 | impression_cost, 63 | ): 64 | """ 65 | Set the conversion and impression costs based on a dict of parameters. 66 | 67 | Calculate the actual value of targeting a user with the actual treatment group 68 | using the above parameters. 69 | 70 | Params 71 | ------ 72 | treatment : array, shape = (num_samples, ) 73 | Treatment array. 74 | 75 | observed_outcome : array, shape = (num_samples, ) 76 | Observed outcome array, aka y. 77 | 78 | conversion_value : array, shape = (num_samples, ) 79 | The value of converting a given user. 80 | 81 | conditions : list, len = len(set(treatment)) 82 | List of treatment conditions. 83 | 84 | conversion_cost : array, shape = (num_samples, num_treatment) 85 | Array of conversion costs for each unit in each treatment. 86 | 87 | impression_cost : array, shape = (num_samples, num_treatment) 88 | Array of impression costs for each unit in each treatment. 89 | 90 | Returns 91 | ------- 92 | actual_value : array, shape = (num_samples, ) 93 | Array of actual values of havng a user in their actual treatment group. 94 | 95 | conversion_value : array, shape = (num_samples, ) 96 | Array of payoffs from converting a user. 97 | """ 98 | 99 | cost_filter = [ 100 | actual_group == possible_group 101 | for actual_group in treatment 102 | for possible_group in conditions 103 | ] 104 | 105 | conversion_cost_flat = conversion_cost.flatten() 106 | actual_cc = conversion_cost_flat[cost_filter] 107 | impression_cost_flat = impression_cost.flatten() 108 | actual_ic = impression_cost_flat[cost_filter] 109 | 110 | # Calculate the actual value of having a user in their actual treatment 111 | actual_value = (conversion_value - actual_cc) * observed_outcome - actual_ic 112 | 113 | return actual_value 114 | 115 | 116 | def get_uplift_best(cate, conditions): 117 | """ 118 | Takes the CATE prediction from a learner, adds the control 119 | outcome array and finds the name of the argmax conditon. 120 | 121 | Params 122 | ------ 123 | cate : array, shape = (num_samples, ) 124 | The conditional average treatment effect prediction. 125 | 126 | conditions : list, len = len(set(treatment)) 127 | 128 | Returns 129 | ------- 130 | uplift_recomm_name : array, shape = (num_samples, ) 131 | The experimental group recommended by the learner. 
132 | """ 133 | cate_with_control = np.c_[np.zeros(cate.shape[0]), cate] 134 | uplift_best_idx = np.argmax(cate_with_control, axis=1) 135 | uplift_best_name = [conditions[idx] for idx in uplift_best_idx] 136 | 137 | return uplift_best_name 138 | -------------------------------------------------------------------------------- /causalml/optimize/value_optimization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class CounterfactualValueEstimator: 5 | """ 6 | Args 7 | ---- 8 | treatment : array, shape = (num_samples, ) 9 | An array of treatment group indicator values. 10 | 11 | control_name : string 12 | The name of the control condition as a string. Must be contained in the treatment array. 13 | 14 | treatment_names : list, length = cate.shape[1] 15 | A list of treatment group names. NB: The order of the items in the 16 | list must correspond to the order in which the conditional average 17 | treatment effect estimates are in cate_array. 18 | 19 | y_proba : array, shape = (num_samples, ) 20 | The predicted probability of conversion using the Y ~ X model across 21 | the total sample. 22 | 23 | cate : array, shape = (num_samples, len(set(treatment))) 24 | Conditional average treatment effect estimations from any model. 25 | 26 | value : array, shape = (num_samples, ) 27 | Value of converting each unit. 28 | 29 | conversion_cost : shape = (num_samples, len(set(treatment))) 30 | The cost of a treatment that is triggered if a unit converts after having been in the treatment, such as a 31 | promotion code. 32 | 33 | impression_cost : shape = (num_samples, len(set(treatment))) 34 | The cost of a treatment that is the same for each unit whether or not they convert, such as a cost associated 35 | with a promotion channel. 36 | 37 | 38 | Notes 39 | ----- 40 | Because we get the conditional average treatment effects from 41 | cate-learners relative to the control condition, we subtract the 42 | cate for the unit in their actual treatment group from y_proba for that 43 | unit, in order to recover the control outcome. We then add the cates 44 | to the control outcome to obtain y_proba under each condition. These 45 | outcomes are counterfactual because just one of them is actually 46 | observed. 47 | """ 48 | 49 | def __init__( 50 | self, 51 | treatment, 52 | control_name, 53 | treatment_names, 54 | y_proba, 55 | cate, 56 | value, 57 | conversion_cost, 58 | impression_cost, 59 | *args, 60 | **kwargs, 61 | ): 62 | self.treatment = treatment 63 | self.control_name = control_name 64 | self.treatment_names = treatment_names 65 | self.y_proba = y_proba 66 | self.cate = cate 67 | self.value = value 68 | self.conversion_cost = conversion_cost 69 | self.impression_cost = impression_cost 70 | 71 | def predict_best(self): 72 | """ 73 | Predict the best treatment group based on the highest counterfactual 74 | value for a treatment. 75 | """ 76 | self._get_counterfactuals() 77 | self._get_counterfactual_values() 78 | return self.best_treatment 79 | 80 | def predict_counterfactuals(self): 81 | """ 82 | Predict the counterfactual values for each treatment group. 83 | """ 84 | self._get_counterfactuals() 85 | self._get_counterfactual_values() 86 | return self.expected_values 87 | 88 | def _get_counterfactuals(self): 89 | """ 90 | Get an array of counterfactual outcomes based on control outcome and 91 | the array of conditional average treatment effects. 
92 | """ 93 | conditions = self.treatment_names.copy() 94 | conditions.insert(0, self.control_name) 95 | cates_with_control = np.c_[np.zeros(self.cate.shape[0]), self.cate] 96 | cates_flat = cates_with_control.flatten() 97 | 98 | cates_filt = [ 99 | actual_group == poss_group 100 | for actual_group in self.treatment 101 | for poss_group in conditions 102 | ] 103 | 104 | control_outcome = self.y_proba - cates_flat[cates_filt] 105 | self.counterfactuals = cates_with_control + control_outcome[:, None] 106 | 107 | def _get_counterfactual_values(self): 108 | """ 109 | Calculate the expected value of assigning a unit to each of the 110 | treatment conditions given the value of conversion and the conversion 111 | and impression costs associated with the treatment. 112 | """ 113 | 114 | self.expected_values = ( 115 | self.value[:, None] - self.conversion_cost 116 | ) * self.counterfactuals - self.impression_cost 117 | 118 | self.best_treatment = np.argmax(self.expected_values, axis=1) 119 | -------------------------------------------------------------------------------- /causalml/propensity.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | import logging 3 | import numpy as np 4 | from sklearn.metrics import roc_auc_score as auc 5 | from sklearn.linear_model import LogisticRegressionCV 6 | from sklearn.model_selection import StratifiedKFold, train_test_split 7 | from sklearn.isotonic import IsotonicRegression 8 | import xgboost as xgb 9 | 10 | 11 | logger = logging.getLogger("causalml") 12 | 13 | 14 | class PropensityModel(metaclass=ABCMeta): 15 | def __init__(self, clip_bounds=(1e-3, 1 - 1e-3), calibrate=True, **model_kwargs): 16 | """ 17 | Args: 18 | clip_bounds (tuple): lower and upper bounds for clipping propensity scores. Bounds should be implemented 19 | such that: 0 < lower < upper < 1, to avoid division by zero in BaseRLearner.fit_predict() step. 20 | calibrate (bool): whether calibrate the propensity score 21 | model_kwargs: Keyword arguments to be passed to the underlying classification model. 22 | """ 23 | self.clip_bounds = clip_bounds 24 | self.calibrate = calibrate 25 | self.model_kwargs = model_kwargs 26 | self.model = self._model 27 | self.calibrator = None 28 | 29 | @property 30 | @abstractmethod 31 | def _model(self): 32 | pass 33 | 34 | def __repr__(self): 35 | return self.model.__repr__() 36 | 37 | def fit(self, X, y): 38 | """ 39 | Fit a propensity model. 40 | 41 | Args: 42 | X (numpy.ndarray): a feature matrix 43 | y (numpy.ndarray): a binary target vector 44 | """ 45 | self.model.fit(X, y) 46 | if self.calibrate: 47 | # Fit a calibrator to the propensity scores with IsotonicRegression. 48 | # Ref: https://scikit-learn.org/stable/modules/isotonic.html 49 | self.calibrator = IsotonicRegression( 50 | out_of_bounds="clip", 51 | y_min=self.clip_bounds[0], 52 | y_max=self.clip_bounds[1], 53 | ) 54 | self.calibrator.fit(self.model.predict_proba(X)[:, 1], y) 55 | 56 | def predict(self, X): 57 | """ 58 | Predict propensity scores. 59 | 60 | Args: 61 | X (numpy.ndarray): a feature matrix 62 | 63 | Returns: 64 | (numpy.ndarray): Propensity scores between 0 and 1. 65 | """ 66 | p = self.model.predict_proba(X)[:, 1] 67 | if self.calibrate: 68 | p = self.calibrator.transform(p) 69 | 70 | return np.clip(p, *self.clip_bounds) 71 | 72 | def fit_predict(self, X, y): 73 | """ 74 | Fit a propensity model and predict propensity scores. 
75 | 76 | Args: 77 | X (numpy.ndarray): a feature matrix 78 | y (numpy.ndarray): a binary target vector 79 | 80 | Returns: 81 | (numpy.ndarray): Propensity scores between 0 and 1. 82 | """ 83 | self.fit(X, y) 84 | propensity_scores = self.predict(X) 85 | return propensity_scores 86 | 87 | 88 | class LogisticRegressionPropensityModel(PropensityModel): 89 | """ 90 | Propensity regression model based on the LogisticRegression algorithm. 91 | """ 92 | 93 | @property 94 | def _model(self): 95 | kwargs = { 96 | "penalty": "elasticnet", 97 | "solver": "saga", 98 | "Cs": np.logspace(1e-3, 1 - 1e-3, 4), 99 | "l1_ratios": np.linspace(1e-3, 1 - 1e-3, 4), 100 | "cv": StratifiedKFold( 101 | n_splits=( 102 | self.model_kwargs.pop("n_fold") 103 | if "n_fold" in self.model_kwargs 104 | else 4 105 | ), 106 | shuffle=True, 107 | random_state=self.model_kwargs.get("random_state", 42), 108 | ), 109 | "random_state": 42, 110 | } 111 | kwargs.update(self.model_kwargs) 112 | 113 | return LogisticRegressionCV(**kwargs) 114 | 115 | 116 | class ElasticNetPropensityModel(LogisticRegressionPropensityModel): 117 | pass 118 | 119 | 120 | class GradientBoostedPropensityModel(PropensityModel): 121 | """ 122 | Gradient boosted propensity score model with optional early stopping. 123 | 124 | Notes 125 | ----- 126 | Please see the xgboost documentation for more information on gradient boosting tuning parameters: 127 | https://xgboost.readthedocs.io/en/latest/python/python_api.html 128 | """ 129 | 130 | def __init__( 131 | self, 132 | early_stop=False, 133 | clip_bounds=(1e-3, 1 - 1e-3), 134 | calibrate=True, 135 | **model_kwargs, 136 | ): 137 | self.early_stop = early_stop 138 | super().__init__(clip_bounds, calibrate, **model_kwargs) 139 | 140 | @property 141 | def _model(self): 142 | kwargs = { 143 | "max_depth": 8, 144 | "learning_rate": 0.1, 145 | "n_estimators": 100, 146 | "objective": "binary:logistic", 147 | "nthread": -1, 148 | "colsample_bytree": 0.8, 149 | "random_state": 42, 150 | } 151 | kwargs.update(self.model_kwargs) 152 | 153 | if self.early_stop: 154 | kwargs.update({"early_stopping_rounds": 10}) 155 | 156 | return xgb.XGBClassifier(**kwargs) 157 | 158 | def fit(self, X, y, stop_val_size=0.2): 159 | """ 160 | Fit a propensity model. 161 | 162 | Args: 163 | X (numpy.ndarray): a feature matrix 164 | y (numpy.ndarray): a binary target vector 165 | """ 166 | 167 | if self.early_stop: 168 | X_train, X_val, y_train, y_val = train_test_split( 169 | X, y, test_size=stop_val_size 170 | ) 171 | 172 | self.model.fit( 173 | X_train, 174 | y_train, 175 | eval_set=[(X_val, y_val)], 176 | ) 177 | if self.calibrate: 178 | self.calibrator = IsotonicRegression( 179 | out_of_bounds="clip", 180 | y_min=self.clip_bounds[0], 181 | y_max=self.clip_bounds[1], 182 | ) 183 | self.calibrator.fit(self.model.predict_proba(X)[:, 1], y) 184 | else: 185 | super().fit(X, y) 186 | 187 | 188 | def compute_propensity_score( 189 | X, 190 | treatment, 191 | p_model=None, 192 | X_pred=None, 193 | treatment_pred=None, 194 | calibrate_p=True, 195 | clip_bounds=(1e-3, 1 - 1e-3), 196 | ): 197 | """Generate propensity score if user didn't provide and optionally calibrate. 
198 | 199 | Args: 200 | X (np.matrix): features for training 201 | treatment (np.array or pd.Series): a treatment vector for training 202 | p_model (model object, optional): a binary classifier with either a predict_proba or predict method 203 | X_pred (np.matrix, optional): features for prediction 204 | treatment_pred (np.array or pd.Series, optional): a treatment vector for prediction 205 | calibrate_p (bool, optional): whether to calibrate the propensity score 206 | clip_bounds (tuple, optional): lower and upper bounds for clipping propensity scores. Bounds should be implemented 207 | such that: 0 < lower < upper < 1, to avoid division by zero in BaseRLearner.fit_predict() step. 208 | 209 | Returns: 210 | (tuple) 211 | - p (numpy.ndarray): propensity score 212 | - p_model (PropensityModel): either the original p_model or a trained ElasticNetPropensityModel 213 | """ 214 | if treatment_pred is None: 215 | treatment_pred = treatment.copy() 216 | if p_model is None: 217 | p_model = ElasticNetPropensityModel( 218 | clip_bounds=clip_bounds, calibrate=calibrate_p 219 | ) 220 | 221 | p_model.fit(X, treatment) 222 | 223 | X_pred = X if X_pred is None else X_pred 224 | 225 | try: 226 | p = p_model.predict_proba(X_pred)[:, 1] 227 | except AttributeError: 228 | logger.info("predict_proba not available, using predict instead") 229 | p = p_model.predict(X_pred) 230 | 231 | return p, p_model 232 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | PAPER = 8 | BUILDDIR = _build 9 | 10 | # User-friendly check for sphinx-build 11 | ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $?), 1) 12 | $(error The '$(SPHINXBUILD)' command was not found. Make sure you have Sphinx installed, then set the SPHINXBUILD environment variable to point to the full path of the '$(SPHINXBUILD)' executable. Alternatively you can add the directory with the executable to your PATH. If you don't have Sphinx installed, grab it from http://sphinx-doc.org/) 13 | endif 14 | 15 | # Internal variables. 16 | PAPEROPT_a4 = -D latex_paper_size=a4 17 | PAPEROPT_letter = -D latex_paper_size=letter 18 | ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 19 | # the i18n builder cannot share the environment and doctrees with the others 20 | I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
21 | 22 | .PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext 23 | 24 | help: 25 | @echo "Please use \`make <target>' where <target> is one of" 26 | @echo " html to make standalone HTML files" 27 | @echo " dirhtml to make HTML files named index.html in directories" 28 | @echo " singlehtml to make a single large HTML file" 29 | @echo " pickle to make pickle files" 30 | @echo " json to make JSON files" 31 | @echo " htmlhelp to make HTML files and a HTML help project" 32 | @echo " qthelp to make HTML files and a qthelp project" 33 | @echo " devhelp to make HTML files and a Devhelp project" 34 | @echo " epub to make an epub" 35 | @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" 36 | @echo " latexpdf to make LaTeX files and run them through pdflatex" 37 | @echo " latexpdfja to make LaTeX files and run them through platex/dvipdfmx" 38 | @echo " text to make text files" 39 | @echo " man to make manual pages" 40 | @echo " texinfo to make Texinfo files" 41 | @echo " info to make Texinfo files and run them through makeinfo" 42 | @echo " gettext to make PO message catalogs" 43 | @echo " changes to make an overview of all changed/added/deprecated items" 44 | @echo " xml to make Docutils-native XML files" 45 | @echo " pseudoxml to make pseudoxml-XML files for display purposes" 46 | @echo " linkcheck to check all external links for integrity" 47 | @echo " doctest to run all doctests embedded in the documentation (if enabled)" 48 | 49 | clean: 50 | rm -rf $(BUILDDIR)/* 51 | 52 | html: 53 | $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html 54 | @echo 55 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." 56 | 57 | dirhtml: 58 | $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml 59 | @echo 60 | @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 61 | 62 | singlehtml: 63 | $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml 64 | @echo 65 | @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." 66 | 67 | pickle: 68 | $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle 69 | @echo 70 | @echo "Build finished; now you can process the pickle files." 71 | 72 | json: 73 | $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json 74 | @echo 75 | @echo "Build finished; now you can process the JSON files." 76 | 77 | htmlhelp: 78 | $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp 79 | @echo 80 | @echo "Build finished; now you can run HTML Help Workshop with the" \ 81 | ".hhp project file in $(BUILDDIR)/htmlhelp." 82 | 83 | qthelp: 84 | $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp 85 | @echo 86 | @echo "Build finished; now you can run "qcollectiongenerator" with the" \ 87 | ".qhcp project file in $(BUILDDIR)/qthelp, like this:" 88 | @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/causalml.qhcp" 89 | @echo "To view the help file:" 90 | @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/causalml.qhc" 91 | 92 | devhelp: 93 | $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp 94 | @echo 95 | @echo "Build finished." 96 | @echo "To view the help file:" 97 | @echo "# mkdir -p $HOME/.local/share/devhelp/causalml" 98 | @echo "# ln -s $(BUILDDIR)/devhelp $HOME/.local/share/devhelp/causalml" 99 | @echo "# devhelp" 100 | 101 | epub: 102 | $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub 103 | @echo 104 | @echo "Build finished. The epub file is in $(BUILDDIR)/epub." 
105 | 106 | latex: 107 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 108 | @echo 109 | @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." 110 | @echo "Run \`make' in that directory to run these through (pdf)latex" \ 111 | "(use \`make latexpdf' here to do that automatically)." 112 | 113 | latexpdf: 114 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 115 | @echo "Running LaTeX files through pdflatex..." 116 | $(MAKE) -C $(BUILDDIR)/latex all-pdf 117 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 118 | 119 | latexpdfja: 120 | $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex 121 | @echo "Running LaTeX files through platex and dvipdfmx..." 122 | $(MAKE) -C $(BUILDDIR)/latex all-pdf-ja 123 | @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." 124 | 125 | text: 126 | $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text 127 | @echo 128 | @echo "Build finished. The text files are in $(BUILDDIR)/text." 129 | 130 | man: 131 | $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man 132 | @echo 133 | @echo "Build finished. The manual pages are in $(BUILDDIR)/man." 134 | 135 | texinfo: 136 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 137 | @echo 138 | @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." 139 | @echo "Run \`make' in that directory to run these through makeinfo" \ 140 | "(use \`make info' here to do that automatically)." 141 | 142 | info: 143 | $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo 144 | @echo "Running Texinfo files through makeinfo..." 145 | make -C $(BUILDDIR)/texinfo info 146 | @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." 147 | 148 | gettext: 149 | $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale 150 | @echo 151 | @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." 152 | 153 | changes: 154 | $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes 155 | @echo 156 | @echo "The overview file is in $(BUILDDIR)/changes." 157 | 158 | linkcheck: 159 | $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck 160 | @echo 161 | @echo "Link check complete; look for any errors in the above output " \ 162 | "or in $(BUILDDIR)/linkcheck/output.txt." 163 | 164 | doctest: 165 | $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest 166 | @echo "Testing of doctests in the sources finished, look at the " \ 167 | "results in $(BUILDDIR)/doctest/output.txt." 168 | 169 | xml: 170 | $(SPHINXBUILD) -b xml $(ALLSPHINXOPTS) $(BUILDDIR)/xml 171 | @echo 172 | @echo "Build finished. The XML files are in $(BUILDDIR)/xml." 173 | 174 | pseudoxml: 175 | $(SPHINXBUILD) -b pseudoxml $(ALLSPHINXOPTS) $(BUILDDIR)/pseudoxml 176 | @echo 177 | @echo "Build finished. The pseudo-XML files are in $(BUILDDIR)/pseudoxml." 
178 | -------------------------------------------------------------------------------- /docs/_static/img/auuc_table_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/auuc_table_vis.png -------------------------------------------------------------------------------- /docs/_static/img/auuc_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/auuc_vis.png -------------------------------------------------------------------------------- /docs/_static/img/counterfactual_value_optimization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/counterfactual_value_optimization.png -------------------------------------------------------------------------------- /docs/_static/img/logo/android-chrome-192x192.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/android-chrome-192x192.png -------------------------------------------------------------------------------- /docs/_static/img/logo/android-chrome-512x512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/android-chrome-512x512.png -------------------------------------------------------------------------------- /docs/_static/img/logo/apple-touch-icon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/apple-touch-icon.png -------------------------------------------------------------------------------- /docs/_static/img/logo/causalml_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/causalml_logo.png -------------------------------------------------------------------------------- /docs/_static/img/logo/causalml_logo_square.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/causalml_logo_square.png -------------------------------------------------------------------------------- /docs/_static/img/logo/causalml_logo_square_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/causalml_logo_square_transparent.png -------------------------------------------------------------------------------- /docs/_static/img/logo/causalml_logo_transparent.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/causalml_logo_transparent.png -------------------------------------------------------------------------------- 
/docs/_static/img/logo/favicon-16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/favicon-16x16.png -------------------------------------------------------------------------------- /docs/_static/img/logo/favicon-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/favicon-32x32.png -------------------------------------------------------------------------------- /docs/_static/img/logo/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/logo/favicon.ico -------------------------------------------------------------------------------- /docs/_static/img/meta_feature_imp_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/meta_feature_imp_vis.png -------------------------------------------------------------------------------- /docs/_static/img/meta_shap_dependence_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/meta_shap_dependence_vis.png -------------------------------------------------------------------------------- /docs/_static/img/meta_shap_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/meta_shap_vis.png -------------------------------------------------------------------------------- /docs/_static/img/sensitivity_selection_bias_r2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/sensitivity_selection_bias_r2.png -------------------------------------------------------------------------------- /docs/_static/img/shap_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/shap_vis.png -------------------------------------------------------------------------------- /docs/_static/img/synthetic_dgp_bar_plot_multiple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/synthetic_dgp_bar_plot_multiple.png -------------------------------------------------------------------------------- /docs/_static/img/synthetic_dgp_scatter_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/synthetic_dgp_scatter_plot.png -------------------------------------------------------------------------------- /docs/_static/img/synthetic_dgp_scatter_plot_multiple.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/synthetic_dgp_scatter_plot_multiple.png -------------------------------------------------------------------------------- /docs/_static/img/uplift_tree_feature_imp_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/uplift_tree_feature_imp_vis.png -------------------------------------------------------------------------------- /docs/_static/img/uplift_tree_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/docs/_static/img/uplift_tree_vis.png -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | About CausalML 2 | =========================== 3 | 4 | ``CausalML`` is a Python package that provides a suite of uplift modeling and causal inference methods using machine learning algorithms based on recent research. 5 | It provides a standard interface that allows user to estimate the **Conditional Average Treatment Effect** (CATE), also known as **Individual Treatment Effect** (ITE), from experimental or observational data. 6 | Essentially, it estimates the causal impact of intervention **W** on outcome **Y** for users with observed features **X**, without strong assumptions on the model form. 7 | 8 | GitHub Repo 9 | ----------- 10 | 11 | https://github.com/uber/causalml 12 | 13 | Mission 14 | ------- 15 | 16 | From the CausalML `Charter <https://github.com/uber/causalml/blob/master/CHARTER.md>`_: 17 | 18 | CausalML is committed to democratizing causal machine learning through accessible, innovative, and well-documented open-source tools that empower data scientists, researchers, and organizations. At our core, we embrace inclusivity and foster a vibrant community where members exchange ideas, share knowledge, and collaboratively shape a future where CausalML drives advancements across diverse domains. 19 | 20 | Contributing 21 | ------------ 22 | `Contributing.md <https://github.com/uber/causalml/blob/master/CONTRIBUTING.md>`_ 23 | 24 | Governance 25 | ---------- 26 | * `Charter <https://github.com/uber/causalml/blob/master/CHARTER.md>`_ 27 | * `Contributors <https://github.com/uber/causalml/graphs/contributors>`_ 28 | * `Maintainers <https://github.com/uber/causalml/blob/master/MAINTAINERS.md>`_ 29 | 30 | Intro to Causal Machine Learning 31 | ================================ 32 | 33 | What is Causal Machine Learning? 34 | -------------------------------- 35 | 36 | Causal machine learning is a branch of machine learning that focuses on understanding the cause and effect relationships in data. It goes beyond just predicting outcomes based on patterns in the data, and tries to understand how changing one variable can affect an outcome. 37 | Suppose we are trying to predict a student’s test score based on how many hours they study and how much sleep they get. Traditional machine learning models would find patterns in the data, like students who study more or sleep more tend to get higher scores. 38 | But what if you want to know what would happen if a student studied an extra hour each day? Or slept an extra hour each night? 
Modeling these potential outcomes or counterfactuals is where causal machine learning comes in. It tries to understand cause-and-effect relationships - how much changing one variable (like study hours or sleep hours) will affect the outcome (the test score). 39 | This is useful in many fields, including economics, healthcare, and policy making, where understanding the impact of interventions is crucial. 40 | While traditional machine learning is great for prediction, causal machine learning helps us understand the difference in outcomes due to interventions. 41 | 42 | 43 | 44 | Difference from Traditional Machine Learning 45 | -------------------------------------------- 46 | 47 | Traditional machine learning and causal machine learning are both powerful tools, but they serve different purposes and answer different types of questions. 48 | Traditional Machine Learning is primarily concerned with prediction. Given a set of input features, it learns a function from the data that can predict an outcome. It’s great at finding patterns and correlations in large datasets, but it doesn’t tell us about the cause-and-effect relationships between variables. It answers questions like “Given a patient’s symptoms, what disease are they likely to have?” 49 | On the other hand, Causal Machine Learning is concerned with understanding the cause-and-effect relationships between variables. It goes beyond prediction and tries to answer questions about intervention: “What will happen if we change this variable?” For example, in a medical context, it could help answer questions like “What will happen if a patient takes this medication?” 50 | In essence, while traditional machine learning can tell us “what is”, causal machine learning can help us understand “what if”. This makes causal machine learning particularly useful in fields where we need to make decisions based on data, such as policy making, economics, and healthcare. 51 | 52 | 53 | Measuring Causal Effects 54 | ------------------------ 55 | 56 | **Randomized Control Trials (RCT)** are the gold standard for causal effect measurements. Subjects are randomly exposed to a treatment and the Average Treatment Effect (ATE) is measured as the difference between the mean outcomes in the treatment and control groups. Random assignment removes the effect of any confounders on the treatment. 57 | 58 | If an RCT is available and the treatment effects are heterogeneous across covariates, measuring the conditional average treatment effect (CATE) can be of interest. The CATE is an estimate of the treatment effect conditioned on all available experiment covariates and confounders. We call these Heterogeneous Treatment Effects (HTEs). A minimal code sketch contrasting ATE and CATE estimation is included after the example use cases below. 59 | 60 | 61 | Example Use Cases 62 | ----------------- 63 | 64 | - **Campaign Targeting Optimization**: An important lever to increase ROI in an advertising campaign is to target the ad to the set of customers who will have a favorable response in a given KPI such as engagement or sales. CATE identifies these customers by estimating the effect of ad exposure on the KPI at the individual level from A/B experiment or historical observational data. 65 | 66 | - **Personalized Engagement**: A company might have multiple options to interact with its customers such as different product choices in up-sell or different messaging channels for communications. One can use CATE to estimate the heterogeneous treatment effect for each customer and treatment option combination for an optimal personalized engagement experience.
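
ATE and CATE: A Minimal Sketch
------------------------------

The snippet below is a minimal, illustrative sketch rather than a prescribed workflow: the synthetic-data settings and the choice of XGBoost as the base learner are arbitrary. It simulates a randomized trial with the package's synthetic data generator, estimates the ATE with a simple difference in group means, and estimates per-unit CATEs with an X-learner.

.. code-block:: python

    from xgboost import XGBRegressor
    from causalml.dataset import synthetic_data
    from causalml.inference.meta import BaseXRegressor

    # mode=2 simulates a randomized trial; e holds the true propensity scores
    y, X, treatment, tau, b, e = synthetic_data(mode=2, n=10000, p=8, sigma=1.0)

    # ATE in an RCT: difference between the treated and control outcome means
    ate = y[treatment == 1].mean() - y[treatment == 0].mean()

    # CATE / HTE: per-unit estimates from a meta-learner
    learner_x = BaseXRegressor(learner=XGBRegressor())
    cate = learner_x.fit_predict(X=X, treatment=treatment, y=y, p=e)

In a real application, ``X``, ``treatment``, and ``y`` would come from your own experimental or observational data.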
67 | 68 | -------------------------------------------------------------------------------- /docs/causalml.rst: -------------------------------------------------------------------------------- 1 | causalml package 2 | ================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | causalml.inference.tree module 8 | ------------------------------ 9 | 10 | .. automodule:: causalml.inference.tree 11 | :members: 12 | :imported-members: 13 | :undoc-members: 14 | :show-inheritance: 15 | 16 | causalml.inference.meta module 17 | ------------------------------ 18 | 19 | .. automodule:: causalml.inference.meta 20 | :members: 21 | :imported-members: 22 | :undoc-members: 23 | :show-inheritance: 24 | 25 | causalml.inference.iv module 26 | ---------------------------- 27 | 28 | .. automodule:: causalml.inference.iv 29 | :members: 30 | :imported-members: 31 | :undoc-members: 32 | :show-inheritance: 33 | 34 | causalml.inference.nn module 35 | ---------------------------- 36 | 37 | .. automodule:: causalml.inference.nn 38 | :members: 39 | :imported-members: 40 | :undoc-members: 41 | :show-inheritance: 42 | 43 | causalml.inference.tf module 44 | ---------------------------- 45 | 46 | .. automodule:: causalml.inference.tf 47 | :members: 48 | :imported-members: 49 | :undoc-members: 50 | :show-inheritance: 51 | 52 | causalml.optimize module 53 | ------------------------ 54 | 55 | .. automodule:: causalml.optimize 56 | :members: 57 | :imported-members: 58 | :undoc-members: 59 | :show-inheritance: 60 | 61 | causalml.dataset module 62 | ----------------------- 63 | 64 | .. automodule:: causalml.dataset 65 | :members: 66 | :imported-members: 67 | :undoc-members: 68 | :show-inheritance: 69 | 70 | causalml.match module 71 | --------------------- 72 | 73 | .. automodule:: causalml.match 74 | :members: 75 | :undoc-members: 76 | :show-inheritance: 77 | 78 | causalml.propensity module 79 | -------------------------- 80 | 81 | .. automodule:: causalml.propensity 82 | :members: 83 | :undoc-members: 84 | :show-inheritance: 85 | 86 | causalml.metrics module 87 | ----------------------- 88 | 89 | .. automodule:: causalml.metrics 90 | :members: 91 | :imported-members: 92 | :undoc-members: 93 | :show-inheritance: 94 | 95 | causalml.feature_selection module 96 | --------------------------------- 97 | 98 | .. automodule:: causalml.feature_selection 99 | :members: 100 | :imported-members: 101 | :undoc-members: 102 | :show-inheritance: 103 | 104 | causalml.features module 105 | ------------------------ 106 | 107 | .. automodule:: causalml.features 108 | :members: 109 | :undoc-members: 110 | :show-inheritance: 111 | 112 | 113 | Module contents 114 | --------------- 115 | 116 | .. 
automodule:: causalml 117 | :members: 118 | :undoc-members: 119 | :show-inheritance: 120 | -------------------------------------------------------------------------------- /docs/environment-py39-rtd.yml: -------------------------------------------------------------------------------- 1 | name: causalml-rtd-py39 2 | channels: 3 | - conda-forge 4 | - defaults 5 | dependencies: 6 | - pip=24.0 7 | - python=3.9 8 | - pandoc 9 | - sphinx 10 | - sphinx_rtd_theme 11 | - sphinxcontrib-bibtex<2.0.0 12 | - nbsphinx 13 | - pip: 14 | - cython==0.29.34 15 | - dill==0.3.8 16 | - importlib-metadata==8.5.0 17 | - joblib==1.4.0 18 | - lightgbm==4.5.0 19 | - matplotlib==3.9.2 20 | - multiprocess==0.70.16 21 | - numba==0.60.0 22 | - numpy==1.26.4 23 | - pandas==2.2.2 24 | - pyro-api==0.1.2 25 | - pyro-ppl==1.9.1 26 | - scikit-learn==1.5.2 27 | - scipy==1.11.4 28 | - seaborn==0.13.2 29 | - shap==0.46.0 30 | - statsmodels==0.14.2 31 | - xgboost==2.1.3 32 | -------------------------------------------------------------------------------- /docs/examples.rst: -------------------------------------------------------------------------------- 1 | Examples 2 | ======== 3 | 4 | Working example notebooks are available in the `example folder <https://github.com/uber/causalml/tree/master/docs/examples>`_. 5 | 6 | Follow the below links for an approximate ordering of example tutorials from introductory to advanced features. 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | examples/meta_learners_with_synthetic_data 12 | examples/uplift_trees_with_synthetic_data 13 | examples/meta_learners_with_synthetic_data_multiple_treatment 14 | examples/uplift_tree_visualization 15 | examples/feature_interpretations_example 16 | examples/validation_with_tmle 17 | examples/dragonnet_example 18 | examples/iv_nlsym_synthetic_data 19 | examples/sensitivity_example_with_synthetic_data 20 | examples/counterfactual_unit_selection 21 | examples/counterfactual_value_optimization 22 | examples/feature_selection 23 | examples/binary_policy_learner_example 24 | examples/cevae_example 25 | examples/dr_learner_with_synthetic_data 26 | examples/benchmark_simulation_studies 27 | examples/necessity_sufficiency_example 28 | examples/causal_trees_with_synthetic_data 29 | examples/causal_trees_interpretation 30 | examples/logistic_regression_based_data_generation_for_uplift_classification 31 | examples/qini_curves_for_costly_treatment_arms 32 | examples/calibration 33 | 34 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to Causal ML's documentation 2 | ==================================== 3 | 4 | Contents: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | 9 | about 10 | installation 11 | quickstart 12 | examples 13 | methodology 14 | interpretation 15 | validation 16 | causalml 17 | references 18 | changelog 19 | 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Installation 3 | ============ 4 | 5 | Installation with ``conda`` or ``pip`` is recommended. Developers can follow the **Install from source** instructions below. 
If building from source, consider doing so within a conda environment and then exporting the environment for reproducibility. 6 | 7 | To use models under the ``inference.tf`` or ``inference.torch`` module (e.g. ``DragonNet`` or ``CEVAE``), additional dependency of ``tensorflow`` or ``torch`` is required. For detailed instructions, see below. 8 | 9 | Install using ``conda`` 10 | ----------------------- 11 | 12 | Install ``conda`` 13 | ^^^^^^^^^^^^^^^^^ 14 | 15 | .. code-block:: bash 16 | 17 | wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh 18 | bash Miniconda3-latest-Linux-x86_64.sh -b 19 | source miniconda3/bin/activate 20 | conda init 21 | source ~/.bashrc 22 | 23 | Install from ``conda-forge`` 24 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 25 | 26 | Directly install from the ``conda-forge`` channel using ``conda``. 27 | 28 | .. code-block:: bash 29 | 30 | conda install -c conda-forge causalml 31 | 32 | Install from ``PyPI`` 33 | --------------------- 34 | 35 | .. code-block:: bash 36 | 37 | pip install causalml 38 | 39 | Install ``causalml`` with ``tensorflow`` for ``DragonNet`` from ``PyPI`` 40 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 41 | 42 | .. code-block:: bash 43 | 44 | pip install causalml[tf] 45 | 46 | Install ``causalml`` with ``torch`` for ``CEVAE`` from ``PyPI`` 47 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 48 | 49 | .. code-block:: bash 50 | 51 | pip install causalml[torch] 52 | 53 | 54 | Install using `uv <https://github.com/astral-sh/uv/blob/main/README.md>`_ 55 | --------------------- 56 | 57 | .. code-block:: bash 58 | 59 | uv init 60 | uv add causalml 61 | 62 | Install ``causalml`` with ``tensorflow`` for ``DragonNet`` using `uv <https://github.com/astral-sh/uv/blob/main/README.md>`_ 63 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 64 | 65 | .. code-block:: bash 66 | 67 | uv add "causalml[tf]" 68 | 69 | Install ``causalml`` with ``torch`` for ``CEVAE`` using `uv <https://github.com/astral-sh/uv/blob/main/README.md>`_ 70 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 71 | 72 | .. code-block:: bash 73 | 74 | uv add "causalml[torch]" 75 | 76 | 77 | 78 | 79 | 80 | 81 | Install from source 82 | ------------------- 83 | 84 | [Optional] If you don't have Graphviz installed, you can install it using ``conda``, ``brew`` (on MacOS), or ``apt`` (on Linux). 85 | 86 | .. code-block:: bash 87 | 88 | conda install python-graphviz 89 | brew install graphviz # MacOS 90 | sudo apt-get install graphviz # Linux 91 | 92 | First, clone the repository and install the package: 93 | 94 | .. code-block:: bash 95 | 96 | git clone https://github.com/uber/causalml.git 97 | cd causalml 98 | pip install -e . 99 | 100 | with ``tensorflow`` for ``DragonNet``: 101 | 102 | .. code-block:: bash 103 | 104 | pip install -e ".[tf]" 105 | 106 | with ``torch`` for ``CEVAE``: 107 | 108 | .. code-block:: bash 109 | 110 | pip install -e ".[torch]" 111 | 112 | ======= 113 | 114 | Windows 115 | ------- 116 | 117 | See content in https://github.com/uber/causalml/issues/678 118 | 119 | 120 | Running Tests 121 | ------------- 122 | 123 | Make sure pytest is installed before attempting to run tests. 124 | 125 | .. code-block:: bash 126 | 127 | pip install -e ".[test]" 128 | 129 | Run all tests with: 130 | 131 | .. 
code-block:: bash 132 | 133 | pytest -vs tests/ --cov causalml/ 134 | 135 | Add ``--runtf`` and/or ``--runtorch`` to run optional tensorflow/torch tests which will be skipped by default. 136 | 137 | You can also run tests via make: 138 | 139 | .. code-block:: bash 140 | 141 | make test 142 | -------------------------------------------------------------------------------- /docs/interpretation.rst: -------------------------------------------------------------------------------- 1 | ======================= 2 | Interpretable Causal ML 3 | ======================= 4 | 5 | Causal ML provides methods to interpret the treatment effect models trained, where we provide more sample code in `feature_interpretations_example.ipynb notebook <https://github.com/uber/causalml/blob/master/docs/examples/feature_interpretations_example.ipynb>`_. 6 | 7 | Meta-Learner Feature Importances 8 | -------------------------------- 9 | 10 | .. code-block:: python 11 | 12 | from causalml.inference.meta import BaseSRegressor, BaseTRegressor, BaseXRegressor, BaseRRegressor 13 | 14 | slearner = BaseSRegressor(LGBMRegressor(), control_name='control') 15 | slearner.estimate_ate(X, w_multi, y) 16 | slearner_tau = slearner.fit_predict(X, w_multi, y) 17 | 18 | model_tau_feature = RandomForestRegressor() # specify model for model_tau_feature 19 | 20 | slearner.get_importance(X=X, tau=slearner_tau, model_tau_feature=model_tau_feature, 21 | normalize=True, method='auto', features=feature_names) 22 | 23 | # Using the feature_importances_ method in the base learner (LGBMRegressor() in this example) 24 | slearner.plot_importance(X=X, tau=slearner_tau, normalize=True, method='auto') 25 | 26 | # Using eli5's PermutationImportance 27 | slearner.plot_importance(X=X, tau=slearner_tau, normalize=True, method='permutation') 28 | 29 | # Using SHAP 30 | shap_slearner = slearner.get_shap_values(X=X, tau=slearner_tau) 31 | 32 | # Plot shap values without specifying shap_dict 33 | slearner.plot_shap_values(X=X, tau=slearner_tau) 34 | 35 | # Plot shap values WITH specifying shap_dict 36 | slearner.plot_shap_values(X=X, shap_dict=shap_slearner) 37 | 38 | # interaction_idx set to 'auto' (searches for feature with greatest approximate interaction) 39 | slearner.plot_shap_dependence(treatment_group='treatment_A', 40 | feature_idx=1, 41 | X=X, 42 | tau=slearner_tau, 43 | interaction_idx='auto') 44 | 45 | .. image:: ./_static/img/meta_feature_imp_vis.png 46 | :width: 629 47 | 48 | .. image:: ./_static/img/meta_shap_vis.png 49 | :width: 629 50 | 51 | .. image:: ./_static/img/meta_shap_dependence_vis.png 52 | :width: 629 53 | 54 | Uplift Tree Visualization 55 | ------------------------- 56 | 57 | .. code-block:: python 58 | 59 | from IPython.display import Image 60 | from causalml.inference.tree import UpliftTreeClassifier, UpliftRandomForestClassifier 61 | from causalml.inference.tree import uplift_tree_string, uplift_tree_plot 62 | from causalml.dataset import make_uplift_classification 63 | 64 | df, x_names = make_uplift_classification() 65 | uplift_model = UpliftTreeClassifier(max_depth=5, min_samples_leaf=200, min_samples_treatment=50, 66 | n_reg=100, evaluationFunction='KL', control_name='control') 67 | 68 | uplift_model.fit(df[x_names].values, 69 | treatment=df['treatment_group_key'].values, 70 | y=df['conversion'].values) 71 | 72 | graph = uplift_tree_plot(uplift_model.fitted_uplift_tree, x_names) 73 | Image(graph.create_png()) 74 | 75 | .. 
image:: ./_static/img/uplift_tree_vis.png 76 | :width: 629 77 | 78 | Please see below for how to read the plot; an `uplift_tree_visualization.ipynb example notebook <https://github.com/uber/causalml/blob/master/docs/examples/uplift_tree_visualization.ipynb>`_ is provided in the repo. 79 | 80 | - feature_name > threshold: For a non-leaf node, the first line is an inequality indicating the splitting rule of this node into its child nodes. 81 | - impurity: the impurity is defined as the value of the split criterion function (such as KL, Chi, or ED) evaluated at this node. 82 | - total_sample: the sample size in this node. 83 | - group_sample: the sample sizes by treatment group. 84 | - uplift score: the treatment effect in this node; if there are multiple treatments, it indicates the maximum (signed) treatment effect across all treatment-vs-control pairs. 85 | - uplift p_value: the p-value of the treatment effect in this node. 86 | - validation uplift score: all the information above is static once the tree is trained, while the validation uplift score represents the treatment effect on the testing data when the fill() method is used. This score can be compared to the training uplift score to evaluate whether the tree has an overfitting issue. 87 | 88 | Uplift Tree Feature Importances 89 | ------------------------------- 90 | 91 | .. code-block:: python 92 | 93 | pd.Series(uplift_model.feature_importances_, index=x_names).sort_values().plot(kind='barh', figsize=(12,8)) 94 | 95 | .. image:: ./_static/img/uplift_tree_feature_imp_vis.png 96 | :width: 629 -------------------------------------------------------------------------------- /docs/references.rst: -------------------------------------------------------------------------------- 1 | References 2 | ========== 3 | 4 | Open Source Software Projects 5 | ----------------------------- 6 | 7 | Python Packages 8 | ~~~~~~~~~~~~~~~ 9 | 10 | - `DoWhy <https://github.com/Microsoft/dowhy>`_: a package for causal inference based on causal graphs. 11 | - `CausalLift <https://github.com/Minyus/causallift/>`_: a package for uplift modeling based on T-learner :cite:`kunzel2019metalearners`. 12 | - `PyLift <https://github.com/wayfair/pylift>`_: a package for uplift modeling based on the transformed outcome method in :cite:`athey2016recursive`. 13 | - `EconML <https://github.com/Microsoft/EconML>`_: a package for treatment effect estimation with orthogonal random forest :cite:`oprescu2018orthogonal`, DeepIV :cite:`hartford2017deep` and other ML methods. 14 | 15 | R Packages 16 | ~~~~~~~~~~ 17 | 18 | - `uplift <https://cran.r-project.org/web/packages/uplift/index.html>`_: a package for treatment effect estimation with ML. 19 | - `grf <https://github.com/grf-labs/grf>`_: a package for forest-based honest estimation from :cite:`athey2019generalized`. 20 | 21 | Papers 22 | ------ 23 | 24 | ..
bibliography:: refs.bib 25 | :style: plain 26 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | Cython>=0.28.0 2 | numpy>=0.16.0 3 | scikit-learn 4 | matplotlib 5 | sphinx 6 | sphinx_rtd_theme 7 | sphinxcontrib-bibtex<2.0.0 8 | nbsphinx 9 | -------------------------------------------------------------------------------- /docs/validation.rst: -------------------------------------------------------------------------------- 1 | ========== 2 | Validation 3 | ========== 4 | 5 | Estimation of the treatment effect cannot be validated the same way as regular ML predictions because the true value is not available except for the experimental data. Here we focus on the internal validation methods under the assumption of unconfoundedness of potential outcomes and the treatment status conditioned on the feature set available to us. 6 | 7 | Validation with Multiple Estimates 8 | ---------------------------------- 9 | 10 | We can validate the methodology by comparing the estimates with other approaches and checking the consistency of estimates across different levels and cohorts. 11 | 12 | Model Robustness for Meta Algorithms 13 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 14 | 15 | In meta-algorithms we can assess the quality of user-level treatment effect estimation by comparing estimates from different underlying ML algorithms. We will report MSE, coverage (overlapping 95% confidence intervals), and uplift curves. In addition, we can split the sample within a cohort and compare the result from out-of-sample scoring and within-sample scoring. 16 | 17 | User Level/Segment Level/Cohort Level Consistency 18 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 19 | 20 | We can also evaluate user-level/segment-level/cohort-level estimation consistency by conducting a T-test. 21 | 22 | Stability between Cohorts 23 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 24 | 25 | Treatment effect may vary from cohort to cohort but should not be too volatile. For a given cohort, we will compare the scores generated by a model fit on another cohort with the scores generated by its own model. 26 | 27 | Validation with Synthetic Data Sets 28 | ----------------------------------- 29 | 30 | We can test the methodology with simulations, where we generate data with known causal and non-causal links between the outcome, the treatment, and some of the confounding variables. 31 | 32 | We implemented the following sets of synthetic data generation mechanisms based on :cite:`nie2017quasi`: 33 | 34 | Mechanism 1 35 | ~~~~~~~~~~~ 36 | 37 | | This generates a complex outcome regression model with an easy treatment effect, with input variables :math:`X_i \sim Unif(0, 1)^d`. 38 | | The treatment flag is a binomial variable, whose d.g.p. is: 39 | | 40 | | :math:`P(W_i = 1 | X_i) = trim_{0.1}(sin(\pi X_{i1} X_{i2}))` 41 | | 42 | | With: 43 | | :math:`trim_\eta(x)=\max (\eta,\min (x,1-\eta))` 44 | | 45 | | The outcome variable is: 46 | | 47 | | :math:`y_i = sin(\pi X_{i1} X_{i2}) + 2(X_{i3} - 0.5)^2 + X_{i4} + 0.5 X_{i5} + (W_i - 0.5)(X_{i1} + X_{i2})/ 2 + \epsilon_i` 48 | | 49 | 50 | Mechanism 2 51 | ~~~~~~~~~~~ 52 | 53 | | This simulates a randomized trial.
The input variables are generated by :math:`X_i \sim N(0, I_{d\times d})` 54 | | 55 | | The treatment flag is generated by a fair coin flip: 56 | | 57 | | :math:`P(W_i = 1|X_i) = 0.5` 58 | | 59 | | The outcome variable is 60 | | 61 | | :math:`y_i = max(X_{i1} + X_{i2}, X_{i3}, 0) + max(X_{i4} + X_{i5}, 0) + (W_i - 0.5)(X_{i1} + \log(1 + e^{X_{i2}}))` 62 | | 63 | 64 | Mechanism 3 65 | ~~~~~~~~~~~ 66 | 67 | | This one has an easy propensity score but a difficult control outcome. The input variables follow :math:`X_i \sim N(0, I_{d\times d})` 68 | | 69 | | The treatment flag is a binomial variable, whose d.g.p. is: 70 | | 71 | | :math:`P(W_i = 1 | X_i) = \frac{1}{1+\exp(X_{i2} + X_{i3})}` 72 | | 73 | | The outcome variable is: 74 | | 75 | | :math:`y_i = 2\log(1 + e^{X_{i1} + X_{i2} + X_{i3}}) + (W_i - 0.5)` 76 | | 77 | 78 | Mechanism 4 79 | ~~~~~~~~~~~ 80 | 81 | | This contains an unrelated treatment arm and control arm, with input data generated by :math:`X_i \sim N(0, I_{d\times d})`. 82 | | 83 | | The treatment flag is a binomial variable whose d.g.p. is: 84 | | 85 | | :math:`P(W_i = 1 | X_i) = \frac{1}{1+\exp(-X_{i1}) + \exp(-X_{i2})}` 86 | | 87 | | The outcome variable is: 88 | | 89 | | :math:`y_i = \frac{1}{2}\big(max(X_{i1} + X_{i2} + X_{i3}, 0) + max(X_{i4} + X_{i5}, 0)\big) + (W_i - 0.5)(max(X_{i1} + X_{i2} + X_{i3}, 0) - max(X_{i4}, X_{i5}, 0))` 90 | | 91 | 92 | Validation with Uplift Curve (AUUC) 93 | ----------------------------------- 94 | 95 | We can validate the estimation by evaluating and comparing the uplift gains with AUUC (Area Under the Uplift Curve), which measures cumulative gains. Please find more details in the `meta_learners_with_synthetic_data.ipynb example notebook <https://github.com/uber/causalml/blob/master/docs/examples/meta_learners_with_synthetic_data.ipynb>`_. 96 | 97 | .. code-block:: python 98 | 99 | from causalml.dataset import * 100 | from causalml.metrics import * 101 | # Single simulation 102 | train_preds, valid_preds = get_synthetic_preds_holdout(simulate_nuisance_and_easy_treatment, 103 | n=50000, 104 | valid_size=0.2) 105 | # Cumulative Gain AUUC values for a Single Simulation of Validation Data 106 | get_synthetic_auuc(valid_preds) 107 | 108 | 109 | .. image:: ./_static/img/auuc_table_vis.png 110 | :width: 629 111 | 112 | .. image:: ./_static/img/auuc_vis.png 113 | :width: 629 114 | 115 | For data with skewed treatment, it is sometimes advantageous to use :ref:`Targeted maximum likelihood estimation (TMLE) for ATE` to generate the AUUC curve for validation, as TMLE provides a more accurate estimation of ATE. Please see the `validation_with_tmle.ipynb example notebook <https://github.com/uber/causalml/blob/master/docs/examples/validation_with_tmle.ipynb>`_ for details. 116 | 117 | Validation with Sensitivity Analysis 118 | ------------------------------------ 119 | Sensitivity analysis aims to check the robustness of the unconfoundedness assumption. If there is hidden bias (unobserved confounders), it determines how severe the bias would have to be to change the conclusion, by examining how the estimated average treatment effect shifts. 120 | 121 | We implemented the following methods to conduct sensitivity analysis: 122 | 123 | Placebo Treatment 124 | ~~~~~~~~~~~~~~~~~ 125 | 126 | | Replace treatment with a random variable. 127 | 128 | Irrelevant Additional Confounder 129 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 130 | 131 | | Add a random common cause variable. 132 | 133 | Subset validation 134 | ~~~~~~~~~~~~~~~~~ 135 | 136 | | Remove a random subset of the data.
137 | 138 | Random Replace 139 | ~~~~~~~~~~~~~~ 140 | 141 | | Randomly replace a covariate with an irrelevant variable. 142 | 143 | Selection Bias 144 | ~~~~~~~~~~~~~~ 145 | 146 | | `Blackwell (2013) <https://www.mattblackwell.org/files/papers/sens.pdf>`_ introduced an approach to sensitivity analysis for causal effects that directly models confounding or selection bias. 147 | | 148 | | One Sided Confounding Function: as the name implies, this function can detect sensitivity to one-sided selection bias, but it would fail to detect other deviations from ignorability. That is, it can only determine the bias resulting from the treatment group being on average better off or the control group being on average better off. 149 | | 150 | | Alignment Confounding Function: this type of bias is likely to occur when units select into treatment and control based on their predicted treatment effects. 151 | | 152 | | The sensitivity analysis is rigid in this way because the confounding function is not identified from the data, so that the causal model in the last section is only identified conditional on a specific choice of that function. The goal of the sensitivity analysis is not to choose the “correct” confounding function, since we have no way of evaluating this correctness. By its very nature, unmeasured confounding is unmeasured. Rather, the goal is to identify plausible deviations from ignorability and test sensitivity to those deviations. The main harm that results from the incorrect specification of the confounding function is that hidden biases remain hidden. 153 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "causalml" 3 | version = "0.15.5" 4 | description = "Python Package for Uplift Modeling and Causal Inference with Machine Learning Algorithms" 5 | readme = { file = "README.md", content-type = "text/markdown" } 6 | 7 | authors = [ 8 | { "name" = "Huigang Chen" }, 9 | { "name" = "Totte Harinen" }, 10 | { "name" = "Jeong-Yoon Lee" }, 11 | { "name" = "Jing Pan" }, 12 | { "name" = "Mike Yung" }, 13 | { "name" = "Zhenyu Zhao" } 14 | ] 15 | maintainers = [ 16 | { name = "Jeong-Yoon Lee" } 17 | ] 18 | classifiers = [ 19 | "Programming Language :: Python", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Operating System :: OS Independent", 22 | ] 23 | 24 | requires-python = ">=3.9" 25 | dependencies = [ 26 | "forestci==0.6", 27 | "pathos==0.2.9", 28 | "numpy>=1.18.5", 29 | "scipy>=1.4.1,<1.16.0", 30 | "matplotlib", 31 | "pandas>=0.24.1", 32 | "scikit-learn>=1.6.0", 33 | "statsmodels>=0.9.0", 34 | "seaborn", 35 | "xgboost", 36 | "pydotplus", 37 | "tqdm", 38 | "shap", 39 | "dill", 40 | "lightgbm", 41 | "packaging", 42 | "graphviz", 43 | ] 44 | 45 | [project.optional-dependencies] 46 | test = [ 47 | "pytest>=4.6", 48 | "pytest-cov>=4.0" 49 | ] 50 | tf = [ 51 | "tensorflow>=2.4.0" 52 | ] 53 | torch = [ 54 | "torch", 55 | "pyro-ppl" 56 | ] 57 | 58 | [build-system] 59 | requires = [ 60 | "setuptools>=18.0", 61 | "wheel", 62 | "Cython", 63 | "numpy>=1.18.5", 64 | "scikit-learn>=1.6.0", 65 | ] 66 | 67 | [project.urls] 68 | homepage = "https://github.com/uber/causalml" 69 | 70 | [tool.cibuildwheel] 71 | build = ["cp39-*", "cp310-*", "cp311-*", "cp312-*"] 72 | build-verbosity = 1 73 | # Skip 32-bit builds 74 | skip = ["*-win32", "*-manylinux_i686", "*-musllinux*"] 75 |
-------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | # This includes the license file(s) in the wheel. 3 | # https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file 4 | license_files = LICENSE 5 | 6 | [bdist_wheel] 7 | # This flag says to generate wheels that support both Python 2 and Python 8 | # 3. If your code will not run unchanged on both Python 2 and 3, you will 9 | # need to generate separate wheels for each Python version that you 10 | # support. Removing this line (or setting universal to 0) will prevent 11 | # bdist_wheel from trying to make a universal wheel. For more see: 12 | # https://packaging.python.org/guides/distributing-packages-using-setuptools/#wheels 13 | universal=1 14 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | from setuptools import dist, setup, find_packages 4 | from setuptools.extension import Extension 5 | 6 | try: 7 | from Cython.Build import cythonize 8 | except ImportError: 9 | dist.Distribution().fetch_build_eggs(["cython"]) 10 | from Cython.Build import cythonize 11 | import Cython.Compiler.Options 12 | 13 | Cython.Compiler.Options.annotate = True 14 | 15 | try: 16 | import numpy as np 17 | except ImportError: 18 | dist.Distribution().fetch_build_eggs(["numpy"]) 19 | import numpy as np 20 | 21 | # fmt: off 22 | cython_modules = [ 23 | ("causalml.inference.tree._tree._tree", "causalml/inference/tree/_tree/_tree.pyx"), 24 | ("causalml.inference.tree._tree._criterion", "causalml/inference/tree/_tree/_criterion.pyx"), 25 | ("causalml.inference.tree._tree._splitter", "causalml/inference/tree/_tree/_splitter.pyx"), 26 | ("causalml.inference.tree._tree._utils", "causalml/inference/tree/_tree/_utils.pyx"), 27 | ("causalml.inference.tree.causal._criterion", "causalml/inference/tree/causal/_criterion.pyx"), 28 | ("causalml.inference.tree.causal._builder", "causalml/inference/tree/causal/_builder.pyx"), 29 | ("causalml.inference.tree.uplift", "causalml/inference/tree/uplift.pyx"), 30 | ] 31 | # fmt: on 32 | 33 | extensions = [ 34 | Extension( 35 | name, 36 | [source], 37 | libraries=[], 38 | include_dirs=[np.get_include()], 39 | extra_compile_args=["-O3"], 40 | ) 41 | for name, source in cython_modules 42 | ] 43 | 44 | packages = find_packages(exclude=["tests", "tests.*"]) 45 | 46 | nthreads = mp.cpu_count() 47 | if os.name == "nt": 48 | nthreads = 0 49 | else: 50 | mp.set_start_method("fork", force=True) 51 | 52 | setup( 53 | packages=packages, 54 | ext_modules=cythonize(extensions, annotate=True, nthreads=nthreads), 55 | include_dirs=[np.get_include()], 56 | setup_requires=[ 57 | "cython", 58 | "numpy", 59 | ], 60 | ) 61 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber/causalml/5f5c4fb17cbadc79314e0fbc1dc0bc3f31f8a168/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from causalml.dataset import synthetic_data 5 | from 
causalml.dataset import make_uplift_classification 6 | 7 | from .const import ( 8 | RANDOM_SEED, 9 | N_SAMPLE, 10 | TREATMENT_NAMES, 11 | CONVERSION, 12 | DELTA_UPLIFT_INCREASE_DICT, 13 | N_UPLIFT_INCREASE_DICT, 14 | ) 15 | 16 | 17 | @pytest.fixture(scope="module") 18 | def generate_regression_data(mode: int = 1, p: int = 8, sigma: float = 0.1): 19 | generated = False 20 | 21 | def _generate_data(mode: int = mode, p: int = p, sigma: float = sigma): 22 | if not generated: 23 | np.random.seed(RANDOM_SEED) 24 | data = synthetic_data(mode=mode, n=N_SAMPLE, p=p, sigma=sigma) 25 | 26 | return data 27 | 28 | yield _generate_data 29 | 30 | 31 | @pytest.fixture(scope="module") 32 | def generate_classification_data(): 33 | generated = False 34 | 35 | def _generate_data(): 36 | if not generated: 37 | np.random.seed(RANDOM_SEED) 38 | data = make_uplift_classification( 39 | n_samples=N_SAMPLE, 40 | treatment_name=TREATMENT_NAMES, 41 | y_name=CONVERSION, 42 | random_seed=RANDOM_SEED, 43 | ) 44 | 45 | return data 46 | 47 | yield _generate_data 48 | 49 | 50 | @pytest.fixture(scope="module") 51 | def generate_classification_data_two_treatments(): 52 | generated = False 53 | 54 | def _generate_data(): 55 | if not generated: 56 | np.random.seed(RANDOM_SEED) 57 | data = make_uplift_classification( 58 | n_samples=N_SAMPLE, 59 | treatment_name=TREATMENT_NAMES[0:2], 60 | y_name=CONVERSION, 61 | random_seed=RANDOM_SEED, 62 | delta_uplift_increase_dict=DELTA_UPLIFT_INCREASE_DICT, 63 | n_uplift_increase_dict=N_UPLIFT_INCREASE_DICT, 64 | ) 65 | 66 | return data 67 | 68 | yield _generate_data 69 | 70 | 71 | def pytest_addoption(parser): 72 | parser.addoption("--runtf", action="store_true", default=False, help="run tf tests") 73 | parser.addoption( 74 | "--runtorch", action="store_true", default=False, help="run torch tests" 75 | ) 76 | 77 | 78 | def pytest_configure(config): 79 | config.addinivalue_line("markers", "tf: mark test as tf to run") 80 | config.addinivalue_line("markers", "torch: mark test as torch to run") 81 | 82 | 83 | def pytest_collection_modifyitems(config, items): 84 | 85 | skip_tf = False if config.getoption("--runtf") else True 86 | skip_torch = False if config.getoption("--runtorch") else True 87 | 88 | for item in items: 89 | if "tf" in item.keywords and skip_tf: 90 | item.add_marker(pytest.mark.skip(reason="need --runtf option to run")) 91 | if "torch" in item.keywords and skip_torch: 92 | item.add_marker(pytest.mark.skip(reason="need --runtorch option to run")) 93 | -------------------------------------------------------------------------------- /tests/const.py: -------------------------------------------------------------------------------- 1 | RANDOM_SEED = 42 2 | N_SAMPLE = 1000 3 | ERROR_THRESHOLD = 0.5 4 | NUM_FEATURES = 6 5 | 6 | TREATMENT_COL = "treatment" 7 | SCORE_COL = "score" 8 | GROUP_COL = "group" 9 | OUTCOME_COL = "outcome" 10 | 11 | CONTROL_NAME = "control" 12 | TREATMENT_NAMES = [CONTROL_NAME, "treatment1", "treatment2", "treatment3"] 13 | CONVERSION = "conversion" 14 | DELTA_UPLIFT_INCREASE_DICT = { 15 | "treatment1": 0.25, 16 | } 17 | N_UPLIFT_INCREASE_DICT = {"treatment1": 2} 18 | -------------------------------------------------------------------------------- /tests/test_causal_trees.py: -------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | from abc import abstractmethod 3 | 4 | import pandas as pd 5 | import pytest 6 | from sklearn.model_selection import train_test_split 7 | 8 | from causalml.inference.tree 
import CausalTreeRegressor, CausalRandomForestRegressor 9 | from causalml.metrics import ape 10 | from causalml.metrics import qini_score 11 | from .const import RANDOM_SEED, ERROR_THRESHOLD 12 | 13 | 14 | class CausalTreeBase: 15 | test_size: float = 0.2 16 | control_name: int = 0 17 | 18 | @abstractmethod 19 | def prepare_model(self, *args, **kwargs): 20 | return 21 | 22 | @abstractmethod 23 | def test_fit(self, *args, **kwargs): 24 | return 25 | 26 | @abstractmethod 27 | def test_predict(self, *args, **kwargs): 28 | return 29 | 30 | def prepare_data(self, generate_regression_data) -> tuple: 31 | y, X, treatment, tau, b, e = generate_regression_data(mode=2) 32 | df = pd.DataFrame(X) 33 | feature_names = [f"feature_{i}" for i in range(X.shape[1])] 34 | df.columns = feature_names 35 | df["outcome"] = y 36 | df["treatment"] = treatment 37 | df["treatment_effect"] = tau 38 | self.df_train, self.df_test = train_test_split( 39 | df, test_size=self.test_size, random_state=RANDOM_SEED 40 | ) 41 | X_train, X_test = ( 42 | self.df_train[feature_names].values, 43 | self.df_test[feature_names].values, 44 | ) 45 | y_train, y_test = ( 46 | self.df_train["outcome"].values, 47 | self.df_test["outcome"].values, 48 | ) 49 | treatment_train, treatment_test = ( 50 | self.df_train["treatment"].values, 51 | self.df_test["treatment"].values, 52 | ) 53 | return X_train, X_test, y_train, y_test, treatment_train, treatment_test 54 | 55 | 56 | class TestCausalTreeRegressor(CausalTreeBase): 57 | def prepare_model(self) -> CausalTreeRegressor: 58 | ctree = CausalTreeRegressor( 59 | control_name=self.control_name, groups_cnt=True, random_state=RANDOM_SEED 60 | ) 61 | return ctree 62 | 63 | def test_fit(self, generate_regression_data): 64 | ctree = self.prepare_model() 65 | ( 66 | X_train, 67 | X_test, 68 | y_train, 69 | y_test, 70 | treatment_train, 71 | treatment_test, 72 | ) = self.prepare_data(generate_regression_data) 73 | ctree.fit(X=X_train, treatment=treatment_train, y=y_train) 74 | df_result = pd.DataFrame( 75 | { 76 | "ctree_ite_pred": ctree.predict(X_test), 77 | "outcome": y_test, 78 | "is_treated": treatment_test, 79 | "treatment_effect": self.df_test["treatment_effect"], 80 | } 81 | ) 82 | df_qini = qini_score( 83 | df_result, 84 | outcome_col="outcome", 85 | treatment_col="is_treated", 86 | treatment_effect_col="treatment_effect", 87 | ) 88 | assert df_qini["ctree_ite_pred"] > 0.0 89 | 90 | @pytest.mark.parametrize("return_ci", (False, True)) 91 | @pytest.mark.parametrize("bootstrap_size", (500, 800)) 92 | @pytest.mark.parametrize("n_bootstraps", (1000,)) 93 | def test_fit_predict( 94 | self, generate_regression_data, return_ci, bootstrap_size, n_bootstraps 95 | ): 96 | y, X, treatment, tau, b, e = generate_regression_data(mode=1) 97 | ctree = self.prepare_model() 98 | output = ctree.fit_predict( 99 | X=X, 100 | treatment=treatment, 101 | y=y, 102 | return_ci=return_ci, 103 | n_bootstraps=n_bootstraps, 104 | bootstrap_size=bootstrap_size, 105 | n_jobs=mp.cpu_count() - 1, 106 | verbose=False, 107 | ) 108 | if return_ci: 109 | te, te_lower, te_upper = output 110 | assert len(output) == 3 111 | assert (te_lower <= te).all() and (te_upper >= te).all() 112 | else: 113 | te = output 114 | assert te.shape[0] == y.shape[0] 115 | 116 | def test_predict(self, generate_regression_data): 117 | y, X, treatment, tau, b, e = generate_regression_data(mode=2) 118 | ctree = self.prepare_model() 119 | ctree.fit(X=X, treatment=treatment, y=y) 120 | y_pred = ctree.predict(X[:1, :]) 121 | y_pred_with_outcomes = 
ctree.predict(X[:1, :], with_outcomes=True) 122 | assert y_pred.shape == (1,) 123 | assert y_pred_with_outcomes.shape == (1, 3) 124 | 125 | def test_ate(self, generate_regression_data): 126 | y, X, treatment, tau, b, e = generate_regression_data(mode=2) 127 | ctree = self.prepare_model() 128 | ate, ate_lower, ate_upper = ctree.estimate_ate(X=X, y=y, treatment=treatment) 129 | assert (ate >= ate_lower) and (ate <= ate_upper) 130 | assert ape(tau.mean(), ate) < ERROR_THRESHOLD 131 | 132 | 133 | class TestCausalRandomForestRegressor(CausalTreeBase): 134 | def prepare_model(self, n_estimators: int) -> CausalRandomForestRegressor: 135 | crforest = CausalRandomForestRegressor( 136 | criterion="causal_mse", 137 | control_name=self.control_name, 138 | n_estimators=n_estimators, 139 | n_jobs=mp.cpu_count() - 1, 140 | ) 141 | return crforest 142 | 143 | @pytest.mark.parametrize("n_estimators", (5, 10, 50)) 144 | def test_fit(self, generate_regression_data, n_estimators): 145 | crforest = self.prepare_model(n_estimators=n_estimators) 146 | ( 147 | X_train, 148 | X_test, 149 | y_train, 150 | y_test, 151 | treatment_train, 152 | treatment_test, 153 | ) = self.prepare_data(generate_regression_data) 154 | crforest.fit(X=X_train, treatment=treatment_train, y=y_train) 155 | 156 | df_result = pd.DataFrame( 157 | { 158 | "crforest_ite_pred": crforest.predict(X_test), 159 | "is_treated": treatment_test, 160 | "treatment_effect": self.df_test["treatment_effect"], 161 | } 162 | ) 163 | df_qini = qini_score( 164 | df_result, 165 | outcome_col="outcome", 166 | treatment_col="is_treated", 167 | treatment_effect_col="treatment_effect", 168 | ) 169 | assert df_qini["crforest_ite_pred"] > 0.0 170 | 171 | @pytest.mark.parametrize("n_estimators", (5,)) 172 | def test_predict(self, generate_regression_data, n_estimators): 173 | y, X, treatment, tau, b, e = generate_regression_data(mode=2) 174 | ctree = self.prepare_model(n_estimators=n_estimators) 175 | ctree.fit(X=X, y=y, treatment=treatment) 176 | y_pred = ctree.predict(X[:1, :]) 177 | y_pred_with_outcomes = ctree.predict(X[:1, :], with_outcomes=True) 178 | assert y_pred.shape == (1,) 179 | assert y_pred_with_outcomes.shape == (1, 3) 180 | 181 | @pytest.mark.parametrize("n_estimators", (5,)) 182 | def test_unbiased_sampling_error(self, generate_regression_data, n_estimators): 183 | crforest = self.prepare_model(n_estimators=n_estimators) 184 | ( 185 | X_train, 186 | X_test, 187 | y_train, 188 | y_test, 189 | treatment_train, 190 | treatment_test, 191 | ) = self.prepare_data(generate_regression_data) 192 | crforest.fit(X=X_train, treatment=treatment_train, y=y_train) 193 | crforest_test_var = crforest.calculate_error(X_train=X_train, X_test=X_test) 194 | assert (crforest_test_var > 0).all() 195 | assert crforest_test_var.shape[0] == y_test.shape[0] 196 | -------------------------------------------------------------------------------- /tests/test_cevae.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | 4 | try: 5 | import torch 6 | from causalml.inference.torch import CEVAE 7 | except ImportError: 8 | pass 9 | from causalml.dataset import simulate_hidden_confounder 10 | from causalml.metrics import get_cumgain 11 | 12 | 13 | @pytest.mark.torch 14 | def test_CEVAE(): 15 | y, X, treatment, tau, b, e = simulate_hidden_confounder( 16 | n=10000, p=5, sigma=1.0, adj=0.0 17 | ) 18 | 19 | outcome_dist = "normal" 20 | latent_dim = 20 21 | hidden_dim = 200 22 | num_epochs = 50 23 | batch_size = 100 24 
| learning_rate = 1e-3 25 | learning_rate_decay = 0.1 26 | 27 | cevae = CEVAE( 28 | outcome_dist=outcome_dist, 29 | latent_dim=latent_dim, 30 | hidden_dim=hidden_dim, 31 | num_epochs=num_epochs, 32 | batch_size=batch_size, 33 | learning_rate=learning_rate, 34 | learning_rate_decay=learning_rate_decay, 35 | ) 36 | 37 | cevae.fit( 38 | X=torch.tensor(X, dtype=torch.float), 39 | treatment=torch.tensor(treatment, dtype=torch.float), 40 | y=torch.tensor(y, dtype=torch.float), 41 | ) 42 | 43 | # check the accuracy of the ite accuracy 44 | ite = cevae.predict(X).flatten() 45 | 46 | auuc_metrics = pd.DataFrame( 47 | {"ite": ite, "W": treatment, "y": y, "treatment_effect_col": tau} 48 | ) 49 | 50 | cumgain = get_cumgain( 51 | auuc_metrics, outcome_col="y", treatment_col="W", treatment_effect_col="tau" 52 | ) 53 | 54 | # Check if the cumulative gain when using the model's prediction is 55 | # higher than it would be under random targeting 56 | assert cumgain["ite"].sum() > cumgain["Random"].sum() 57 | -------------------------------------------------------------------------------- /tests/test_counterfactual_unit_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.linear_model import LogisticRegressionCV 6 | 7 | from causalml.dataset import make_uplift_classification 8 | from causalml.optimize.unit_selection import CounterfactualUnitSelector 9 | from causalml.optimize.utils import get_treatment_costs 10 | from causalml.optimize.utils import get_actual_value 11 | 12 | from tests.const import RANDOM_SEED 13 | 14 | 15 | def test_counterfactual_unit_selection(): 16 | df, X_names = make_uplift_classification( 17 | n_samples=2000, treatment_name=["control", "treatment"] 18 | ) 19 | df["treatment_numeric"] = df["treatment_group_key"].replace( 20 | {"control": 0, "treatment": 1} 21 | ) 22 | df_train, df_test = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED) 23 | 24 | train_idx = df_train.index 25 | test_idx = df_test.index 26 | 27 | conversion_cost_dict = {"control": 0, "treatment": 2.5} 28 | impression_cost_dict = {"control": 0, "treatment": 0} 29 | 30 | cc_array, ic_array, conditions = get_treatment_costs( 31 | treatment=df["treatment_group_key"], 32 | control_name="control", 33 | cc_dict=conversion_cost_dict, 34 | ic_dict=impression_cost_dict, 35 | ) 36 | conversion_value_array = np.full(df.shape[0], 20) 37 | 38 | actual_value = get_actual_value( 39 | treatment=df["treatment_group_key"], 40 | observed_outcome=df["conversion"], 41 | conversion_value=conversion_value_array, 42 | conditions=conditions, 43 | conversion_cost=cc_array, 44 | impression_cost=ic_array, 45 | ) 46 | 47 | random_allocation_value = actual_value.loc[test_idx].mean() 48 | 49 | nevertaker_payoff = 0 50 | alwaystaker_payoff = -2.5 51 | complier_payoff = 17.5 52 | defier_payoff = -20 53 | 54 | cus = CounterfactualUnitSelector( 55 | learner=LogisticRegressionCV(), 56 | nevertaker_payoff=nevertaker_payoff, 57 | alwaystaker_payoff=alwaystaker_payoff, 58 | complier_payoff=complier_payoff, 59 | defier_payoff=defier_payoff, 60 | ) 61 | 62 | cus.fit( 63 | data=df_train.drop("treatment_group_key", axis=1), 64 | treatment="treatment_numeric", 65 | outcome="conversion", 66 | ) 67 | 68 | cus_pred = cus.predict( 69 | data=df_test.drop("treatment_group_key", axis=1), 70 | treatment="treatment_numeric", 71 | outcome="conversion", 72 | ) 73 | 74 | best_cus = np.where(cus_pred 
> 0, 1, 0) 75 | actual_is_cus = df_test["treatment_numeric"] == best_cus.ravel() 76 | cus_value = actual_value.loc[test_idx][actual_is_cus].mean() 77 | 78 | assert cus_value > random_allocation_value 79 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from causalml.dataset import ( 4 | simulate_nuisance_and_easy_treatment, 5 | simulate_hidden_confounder, 6 | simulate_randomized_trial, 7 | ) 8 | from causalml.dataset import ( 9 | get_synthetic_preds, 10 | get_synthetic_summary, 11 | get_synthetic_auuc, 12 | ) 13 | from causalml.dataset import get_synthetic_preds_holdout, get_synthetic_summary_holdout 14 | from causalml.inference.meta import LRSRegressor, XGBTRegressor 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "synthetic_data_func", 19 | [ 20 | simulate_nuisance_and_easy_treatment, 21 | simulate_hidden_confounder, 22 | simulate_randomized_trial, 23 | ], 24 | ) 25 | def test_get_synthetic_preds(synthetic_data_func): 26 | preds_dict = get_synthetic_preds( 27 | synthetic_data_func=synthetic_data_func, 28 | n=1000, 29 | estimators={ 30 | "S Learner (LR)": LRSRegressor(), 31 | "T Learner (XGB)": XGBTRegressor(), 32 | }, 33 | ) 34 | 35 | assert ( 36 | preds_dict["S Learner (LR)"].shape[0] == preds_dict["T Learner (XGB)"].shape[0] 37 | ) 38 | 39 | 40 | def test_get_synthetic_summary(): 41 | summary = get_synthetic_summary( 42 | synthetic_data_func=simulate_nuisance_and_easy_treatment, 43 | estimators={ 44 | "S Learner (LR)": LRSRegressor(), 45 | "T Learner (XGB)": XGBTRegressor(), 46 | }, 47 | ) 48 | 49 | print(summary) 50 | 51 | 52 | def test_get_synthetic_preds_holdout(): 53 | preds_train, preds_valid = get_synthetic_preds_holdout( 54 | synthetic_data_func=simulate_nuisance_and_easy_treatment, 55 | n=1000, 56 | estimators={ 57 | "S Learner (LR)": LRSRegressor(), 58 | "T Learner (XGB)": XGBTRegressor(), 59 | }, 60 | ) 61 | 62 | assert ( 63 | preds_train["S Learner (LR)"].shape[0] 64 | == preds_train["T Learner (XGB)"].shape[0] 65 | ) 66 | assert ( 67 | preds_valid["S Learner (LR)"].shape[0] 68 | == preds_valid["T Learner (XGB)"].shape[0] 69 | ) 70 | 71 | 72 | def test_get_synthetic_summary_holdout(): 73 | summary = get_synthetic_summary_holdout( 74 | synthetic_data_func=simulate_nuisance_and_easy_treatment 75 | ) 76 | 77 | print(summary) 78 | 79 | 80 | def test_get_synthetic_auuc(): 81 | preds_dict = get_synthetic_preds( 82 | synthetic_data_func=simulate_nuisance_and_easy_treatment, 83 | n=1000, 84 | estimators={ 85 | "S Learner (LR)": LRSRegressor(), 86 | "T Learner (XGB)": XGBTRegressor(), 87 | }, 88 | ) 89 | 90 | auuc_df = get_synthetic_auuc(preds_dict, plot=False) 91 | print(auuc_df) 92 | -------------------------------------------------------------------------------- /tests/test_dragonnet.py: -------------------------------------------------------------------------------- 1 | try: 2 | from causalml.inference.tf import DragonNet 3 | except ImportError: 4 | pass 5 | from causalml.dataset.regression import simulate_nuisance_and_easy_treatment 6 | import pytest 7 | 8 | 9 | @pytest.mark.tf 10 | def test_save_load_dragonnet(tmp_path): 11 | y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000) 12 | 13 | dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, verbose=False) 14 | dragon_ite = dragon.fit_predict(X, w, y, return_components=False) 15 | dragon_ate = dragon_ite.mean() 16 | 17 | model_file = tmp_path / 
"smaug.h5" 18 | dragon.save(model_file) 19 | 20 | smaug = DragonNet() 21 | smaug.load(model_file) 22 | 23 | assert smaug.predict_tau(X).mean() == dragon_ate 24 | -------------------------------------------------------------------------------- /tests/test_feature_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from causalml.feature_selection.filters import FilterSelect 3 | 4 | from .const import RANDOM_SEED, CONVERSION 5 | 6 | 7 | def test_filter_f(generate_classification_data): 8 | # generate uplift classification data 9 | np.random.seed(RANDOM_SEED) 10 | df, X_names = generate_classification_data() 11 | y_name = CONVERSION 12 | 13 | # test F filter 14 | method = "F" 15 | filter_f = FilterSelect() 16 | f_imp = filter_f.get_importance( 17 | df, X_names, y_name, method, treatment_group="treatment1" 18 | ) 19 | 20 | # each row represents the rank and importance score of each feature 21 | # and spot check if it's sorted properly 22 | assert f_imp.shape[0] == len(X_names) 23 | assert f_imp["rank"].values[0] == 1 24 | assert f_imp["score"].values[0] >= f_imp["score"].values[1] 25 | 26 | 27 | def test_filter_lr(generate_classification_data): 28 | # generate uplift classification data 29 | np.random.seed(RANDOM_SEED) 30 | df, X_names = generate_classification_data() 31 | y_name = CONVERSION 32 | 33 | # test LR filter 34 | method = "LR" 35 | filter_obj = FilterSelect() 36 | imp = filter_obj.get_importance( 37 | df, X_names, y_name, method, treatment_group="treatment1" 38 | ) 39 | 40 | # each row represents the rank and importance score of each feature 41 | # and spot check if it's sorted properly 42 | assert imp.shape[0] == len(X_names) 43 | assert imp["rank"].values[0] == 1 44 | assert imp["score"].values[0] >= imp["score"].values[1] 45 | 46 | 47 | def test_filter_kl(generate_classification_data): 48 | # generate uplift classification data 49 | np.random.seed(RANDOM_SEED) 50 | df, X_names = generate_classification_data() 51 | y_name = CONVERSION 52 | 53 | # test KL filter 54 | method = "KL" 55 | filter_obj = FilterSelect() 56 | imp = filter_obj.get_importance( 57 | df, X_names, y_name, method, treatment_group="treatment1" 58 | ) 59 | 60 | # each row represents the rank and importance score of each feature 61 | # and spot check if it's sorted properly 62 | assert imp.shape[0] == len(X_names) 63 | assert imp["rank"].values[0] == 1 64 | assert imp["score"].values[0] >= imp["score"].values[1] 65 | -------------------------------------------------------------------------------- /tests/test_features.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | from causalml.features import OneHotEncoder, LabelEncoder, load_data 4 | 5 | 6 | @pytest.fixture 7 | def generate_categorical_data(): 8 | generated = False 9 | 10 | def _generate_data(): 11 | if not generated: 12 | df = pd.DataFrame( 13 | { 14 | "cat1": ["a", "a", "b", "a", "c", "b", "d"], 15 | "cat2": ["aa", "aa", "aa", "bb", "bb", "bb", "cc"], 16 | "num1": [1, 2, 1, 2, 1, 1, 1], 17 | } 18 | ) 19 | 20 | return df 21 | 22 | yield _generate_data 23 | 24 | 25 | def test_load_data(generate_categorical_data): 26 | df = generate_categorical_data() 27 | 28 | features = load_data(df, df.columns) 29 | 30 | assert df.shape[0] == features.shape[0] 31 | 32 | 33 | def test_LabelEncoder(generate_categorical_data): 34 | df = generate_categorical_data() 35 | cat_cols = [col for col in df.columns if df[col].dtype == 
"object"] 36 | n_category = 0 37 | for col in cat_cols: 38 | n_category += df[col].nunique() 39 | 40 | lbe = LabelEncoder(min_obs=2) 41 | X_cat = lbe.fit_transform(df[cat_cols]) 42 | n_label = 0 43 | for col in cat_cols: 44 | n_label += X_cat[col].nunique() 45 | 46 | assert df.shape[0] == X_cat.shape[0] and n_label < n_category 47 | 48 | 49 | def test_OneHotEncoder(generate_categorical_data): 50 | df = generate_categorical_data() 51 | cat_cols = [col for col in df.columns if df[col].dtype == "object"] 52 | n_category = 0 53 | for col in cat_cols: 54 | n_category += df[col].nunique() 55 | 56 | ohe = OneHotEncoder(min_obs=2) 57 | X_cat = ohe.fit_transform(df[cat_cols]).todense() 58 | 59 | assert df.shape[0] == X_cat.shape[0] and X_cat.shape[1] < n_category 60 | -------------------------------------------------------------------------------- /tests/test_ivlearner.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from sklearn.linear_model import LinearRegression 4 | from xgboost import XGBRegressor 5 | 6 | from causalml.inference.iv import BaseDRIVLearner 7 | from causalml.metrics import ape, auuc_score 8 | 9 | from .const import RANDOM_SEED, ERROR_THRESHOLD 10 | 11 | 12 | def test_drivlearner(): 13 | np.random.seed(RANDOM_SEED) 14 | n = 1000 15 | p = 8 16 | sigma = 1.0 17 | 18 | X = np.random.uniform(size=n * p).reshape((n, -1)) 19 | b = ( 20 | np.sin(np.pi * X[:, 0] * X[:, 1]) 21 | + 2 * (X[:, 2] - 0.5) ** 2 22 | + X[:, 3] 23 | + 0.5 * X[:, 4] 24 | ) 25 | assignment = (np.random.uniform(size=n) > 0.5).astype(int) 26 | eta = 0.1 27 | e_raw = np.maximum( 28 | np.repeat(eta, n), 29 | np.minimum(np.sin(np.pi * X[:, 0] * X[:, 1]), np.repeat(1 - eta, n)), 30 | ) 31 | e = e_raw.copy() 32 | e[assignment == 0] = 0 33 | tau = (X[:, 0] + X[:, 1]) / 2 34 | 35 | w = np.random.binomial(1, e, size=n) 36 | treatment = w 37 | y = b + (w - 0.5) * tau + sigma * np.random.normal(size=n) 38 | 39 | learner = BaseDRIVLearner( 40 | learner=XGBRegressor(), treatment_effect_learner=LinearRegression() 41 | ) 42 | 43 | # check the accuracy of the ATE estimation 44 | ate_p, lb, ub = learner.estimate_ate( 45 | X=X, 46 | assignment=assignment, 47 | treatment=treatment, 48 | y=y, 49 | p=(np.ones(n) * 1e-6, e_raw), 50 | ) 51 | assert (ate_p >= lb) and (ate_p <= ub) 52 | assert ape(tau.mean(), ate_p) < ERROR_THRESHOLD 53 | 54 | # check the accuracy of the CATE estimation with the bootstrap CI 55 | cate_p, _, _ = learner.fit_predict( 56 | X=X, 57 | assignment=assignment, 58 | treatment=treatment, 59 | y=y, 60 | p=(np.ones(n) * 1e-6, e_raw), 61 | return_ci=True, 62 | n_bootstraps=10, 63 | ) 64 | 65 | auuc_metrics = pd.DataFrame( 66 | { 67 | "cate_p": cate_p.flatten(), 68 | "W": treatment, 69 | "y": y, 70 | "treatment_effect_col": tau, 71 | } 72 | ) 73 | 74 | # Check if the normalized AUUC score of model's prediction is higher than random (0.5). 
75 | auuc = auuc_score( 76 | auuc_metrics, 77 | outcome_col="y", 78 | treatment_col="W", 79 | treatment_effect_col="tau", 80 | normalize=True, 81 | ) 82 | assert auuc["cate_p"] > 0.5 83 | -------------------------------------------------------------------------------- /tests/test_match.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import pytest 4 | 5 | from causalml.match import NearestNeighborMatch, MatchOptimizer 6 | from causalml.propensity import ElasticNetPropensityModel 7 | from .const import RANDOM_SEED, TREATMENT_COL, SCORE_COL, GROUP_COL 8 | 9 | 10 | @pytest.fixture 11 | def generate_unmatched_data(generate_regression_data): 12 | generated = False 13 | 14 | def _generate_data(): 15 | if not generated: 16 | y, X, treatment, tau, b, e = generate_regression_data() 17 | 18 | features = ["x{}".format(i) for i in range(X.shape[1])] 19 | df = pd.DataFrame(X, columns=features) 20 | df[TREATMENT_COL] = treatment 21 | 22 | df_c = df.loc[treatment == 0] 23 | df_t = df.loc[treatment == 1] 24 | 25 | df = pd.concat([df_t, df_c, df_c], axis=0, ignore_index=True) 26 | 27 | pm = ElasticNetPropensityModel(random_state=RANDOM_SEED) 28 | ps = pm.fit_predict(df[features], df[TREATMENT_COL]) 29 | df[SCORE_COL] = ps 30 | df[GROUP_COL] = np.random.randint(0, 2, size=df.shape[0]) 31 | 32 | return df, features 33 | 34 | yield _generate_data 35 | 36 | 37 | def test_nearest_neighbor_match_ratio_2(generate_unmatched_data): 38 | df, features = generate_unmatched_data() 39 | 40 | psm = NearestNeighborMatch(replace=False, ratio=2, random_state=RANDOM_SEED) 41 | matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL]) 42 | assert sum(matched[TREATMENT_COL] == 0) == 2 * sum(matched[TREATMENT_COL] != 0) 43 | 44 | 45 | def test_nearest_neighbor_match_by_group(generate_unmatched_data): 46 | df, features = generate_unmatched_data() 47 | 48 | psm = NearestNeighborMatch(replace=False, ratio=1, random_state=RANDOM_SEED) 49 | 50 | matched = psm.match_by_group( 51 | data=df, 52 | treatment_col=TREATMENT_COL, 53 | score_cols=[SCORE_COL], 54 | groupby_col=GROUP_COL, 55 | ) 56 | 57 | assert sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0) 58 | 59 | 60 | def test_nearest_neighbor_match_control_to_treatment(generate_unmatched_data): 61 | """ 62 | Tests whether control to treatment matching is working. 
Does so 63 | by using: 64 | 65 | replace=True 66 | treatment_to_control=False 67 | ratio=2 68 | 69 | 70 | And testing if we get 2x the number of control matches than treatment 71 | """ 72 | df, features = generate_unmatched_data() 73 | 74 | psm = NearestNeighborMatch( 75 | replace=True, ratio=2, treatment_to_control=False, random_state=RANDOM_SEED 76 | ) 77 | matched = psm.match(data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL]) 78 | assert 2 * sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0) 79 | 80 | 81 | def test_match_optimizer(generate_unmatched_data): 82 | df, features = generate_unmatched_data() 83 | 84 | optimizer = MatchOptimizer( 85 | treatment_col=TREATMENT_COL, 86 | ps_col=SCORE_COL, 87 | matching_covariates=[SCORE_COL], 88 | min_users_per_group=100, 89 | smd_cols=[SCORE_COL], 90 | dev_cols_transformations={SCORE_COL: np.mean}, 91 | ) 92 | 93 | matched = optimizer.search_best_match(df) 94 | 95 | assert sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0) 96 | -------------------------------------------------------------------------------- /tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from numpy import isclose 3 | from causalml.metrics.visualize import qini_score 4 | 5 | 6 | def test_qini_score(): 7 | test_df = pd.DataFrame( 8 | {"y": [0, 0, 0, 0, 1, 0, 0, 1, 1, 1], "w": [0] * 5 + [1] * 5} 9 | ) 10 | 11 | good_uplift = [_ / 10 for _ in range(0, 5)] 12 | bad_uplift = [1] + [0] * 9 13 | test_df["learner_1"] = good_uplift * 2 14 | # learner_2 is a bad model because it gives zero for almost all rows of data 15 | test_df["learner_2"] = bad_uplift 16 | 17 | # get qini score for 2 models in the single calling of qini_score 18 | full_result = qini_score(test_df) 19 | 20 | # get qini score for learner_1 separately 21 | learner_1_result = qini_score(test_df[["y", "w", "learner_1"]]) 22 | 23 | # get qini score for learner_2 separately 24 | learner_2_result = qini_score(test_df[["y", "w", "learner_2"]]) 25 | 26 | # for each learner, its qini score should stay same no matter calling with another model or calling separately 27 | assert isclose(full_result["learner_1"], learner_1_result["learner_1"]) 28 | assert isclose(full_result["learner_2"], learner_2_result["learner_2"]) 29 | -------------------------------------------------------------------------------- /tests/test_propensity.py: -------------------------------------------------------------------------------- 1 | from causalml.propensity import ( 2 | ElasticNetPropensityModel, 3 | GradientBoostedPropensityModel, 4 | LogisticRegressionPropensityModel, 5 | ) 6 | from causalml.metrics import roc_auc_score 7 | 8 | 9 | from .const import RANDOM_SEED 10 | 11 | 12 | def test_logistic_regression_propensity_model(generate_regression_data): 13 | y, X, treatment, tau, b, e = generate_regression_data() 14 | 15 | pm = LogisticRegressionPropensityModel(random_state=RANDOM_SEED) 16 | ps = pm.fit_predict(X, treatment) 17 | 18 | assert roc_auc_score(treatment, ps) > 0.5 19 | 20 | 21 | def test_logistic_regression_propensity_model_model_kwargs(generate_regression_data): 22 | y, X, treatment, tau, b, e = generate_regression_data() 23 | 24 | pm = LogisticRegressionPropensityModel(random_state=123) 25 | 26 | assert pm.model.random_state == 123 27 | 28 | 29 | def test_elasticnet_propensity_model(generate_regression_data): 30 | y, X, treatment, tau, b, e = generate_regression_data() 31 | 32 | pm = 
ElasticNetPropensityModel(random_state=RANDOM_SEED) 33 | ps = pm.fit_predict(X, treatment) 34 | 35 | assert roc_auc_score(treatment, ps) > 0.5 36 | 37 | 38 | def test_gradientboosted_propensity_model(generate_regression_data): 39 | y, X, treatment, tau, b, e = generate_regression_data() 40 | 41 | pm = GradientBoostedPropensityModel(random_state=RANDOM_SEED) 42 | ps = pm.fit_predict(X, treatment) 43 | 44 | assert roc_auc_score(treatment, ps) > 0.5 45 | 46 | 47 | def test_gradientboosted_propensity_model_earlystopping(generate_regression_data): 48 | y, X, treatment, tau, b, e = generate_regression_data() 49 | 50 | pm = GradientBoostedPropensityModel(random_state=RANDOM_SEED, early_stop=True) 51 | ps = pm.fit_predict(X, treatment) 52 | 53 | assert roc_auc_score(treatment, ps) > 0.5 54 | -------------------------------------------------------------------------------- /tests/test_sensitivity.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pytest 3 | import numpy as np 4 | from sklearn.linear_model import LinearRegression 5 | 6 | from causalml.dataset import synthetic_data 7 | from causalml.inference.meta import ( 8 | BaseSLearner, 9 | BaseTLearner, 10 | XGBTRegressor, 11 | BaseXLearner, 12 | BaseRLearner, 13 | ) 14 | from causalml.metrics.sensitivity import Sensitivity 15 | from causalml.metrics.sensitivity import ( 16 | SensitivityPlaceboTreatment, 17 | SensitivityRandomCause, 18 | ) 19 | from causalml.metrics.sensitivity import ( 20 | SensitivityRandomReplace, 21 | SensitivitySelectionBias, 22 | ) 23 | from causalml.metrics.sensitivity import ( 24 | one_sided, 25 | alignment, 26 | one_sided_att, 27 | alignment_att, 28 | ) 29 | 30 | from .const import TREATMENT_COL, SCORE_COL, OUTCOME_COL, NUM_FEATURES 31 | 32 | 33 | @pytest.mark.parametrize( 34 | "learner", 35 | [ 36 | BaseSLearner(LinearRegression()), 37 | BaseTLearner(LinearRegression()), 38 | XGBTRegressor(), 39 | BaseXLearner(LinearRegression()), 40 | BaseRLearner(LinearRegression()), 41 | ], 42 | ) 43 | def test_Sensitivity(learner): 44 | y, X, treatment, tau, b, e = synthetic_data( 45 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 46 | ) 47 | 48 | # generate the dataset format for sensitivity analysis 49 | INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] 50 | df = pd.DataFrame(X, columns=INFERENCE_FEATURES) 51 | df[TREATMENT_COL] = treatment 52 | df[OUTCOME_COL] = y 53 | df[SCORE_COL] = e 54 | 55 | # calling the Base XLearner class and return the sensitivity analysis summary report 56 | sens = Sensitivity( 57 | df=df, 58 | inference_features=INFERENCE_FEATURES, 59 | p_col=SCORE_COL, 60 | treatment_col=TREATMENT_COL, 61 | outcome_col=OUTCOME_COL, 62 | learner=learner, 63 | ) 64 | 65 | # check the sensitivity summary report 66 | sens_summary = sens.sensitivity_analysis( 67 | methods=[ 68 | "Placebo Treatment", 69 | "Random Cause", 70 | "Subset Data", 71 | "Random Replace", 72 | "Selection Bias", 73 | ], 74 | sample_size=0.5, 75 | ) 76 | 77 | print(sens_summary) 78 | 79 | 80 | def test_SensitivityPlaceboTreatment(): 81 | y, X, treatment, tau, b, e = synthetic_data( 82 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 83 | ) 84 | 85 | # generate the dataset format for sensitivity analysis 86 | INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] 87 | df = pd.DataFrame(X, columns=INFERENCE_FEATURES) 88 | df[TREATMENT_COL] = treatment 89 | df[OUTCOME_COL] = y 90 | df[SCORE_COL] = e 91 | 92 | # calling the Base XLearner class and return 
the sensitivity analysis summary report 93 | learner = BaseXLearner(LinearRegression()) 94 | sens = SensitivityPlaceboTreatment( 95 | df=df, 96 | inference_features=INFERENCE_FEATURES, 97 | p_col=SCORE_COL, 98 | treatment_col=TREATMENT_COL, 99 | outcome_col=OUTCOME_COL, 100 | learner=learner, 101 | ) 102 | 103 | sens_summary = sens.summary(method="Random Cause") 104 | print(sens_summary) 105 | 106 | 107 | def test_SensitivityRandomCause(): 108 | y, X, treatment, tau, b, e = synthetic_data( 109 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 110 | ) 111 | 112 | # generate the dataset format for sensitivity analysis 113 | INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] 114 | df = pd.DataFrame(X, columns=INFERENCE_FEATURES) 115 | df[TREATMENT_COL] = treatment 116 | df[OUTCOME_COL] = y 117 | df[SCORE_COL] = e 118 | 119 | # calling the Base XLearner class and return the sensitivity analysis summary report 120 | learner = BaseXLearner(LinearRegression()) 121 | sens = SensitivityRandomCause( 122 | df=df, 123 | inference_features=INFERENCE_FEATURES, 124 | p_col=SCORE_COL, 125 | treatment_col=TREATMENT_COL, 126 | outcome_col=OUTCOME_COL, 127 | learner=learner, 128 | ) 129 | 130 | sens_summary = sens.summary(method="Random Cause") 131 | print(sens_summary) 132 | 133 | 134 | def test_SensitivityRandomReplace(): 135 | y, X, treatment, tau, b, e = synthetic_data( 136 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 137 | ) 138 | 139 | # generate the dataset format for sensitivity analysis 140 | INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] 141 | df = pd.DataFrame(X, columns=INFERENCE_FEATURES) 142 | df[TREATMENT_COL] = treatment 143 | df[OUTCOME_COL] = y 144 | df[SCORE_COL] = e 145 | 146 | # calling the Base XLearner class and return the sensitivity analysis summary report 147 | learner = BaseXLearner(LinearRegression()) 148 | sens = SensitivityRandomReplace( 149 | df=df, 150 | inference_features=INFERENCE_FEATURES, 151 | p_col=SCORE_COL, 152 | treatment_col=TREATMENT_COL, 153 | outcome_col=OUTCOME_COL, 154 | learner=learner, 155 | sample_size=0.9, 156 | replaced_feature="feature_0", 157 | ) 158 | 159 | sens_summary = sens.summary(method="Random Replace") 160 | print(sens_summary) 161 | 162 | 163 | def test_SensitivitySelectionBias(): 164 | y, X, treatment, tau, b, e = synthetic_data( 165 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 166 | ) 167 | 168 | # generate the dataset format for sensitivity analysis 169 | INFERENCE_FEATURES = ["feature_" + str(i) for i in range(NUM_FEATURES)] 170 | df = pd.DataFrame(X, columns=INFERENCE_FEATURES) 171 | df[TREATMENT_COL] = treatment 172 | df[OUTCOME_COL] = y 173 | df[SCORE_COL] = e 174 | 175 | # calling the Base XLearner class and return the sensitivity analysis summary report 176 | learner = BaseXLearner(LinearRegression()) 177 | sens = SensitivitySelectionBias( 178 | df, 179 | INFERENCE_FEATURES, 180 | p_col=SCORE_COL, 181 | treatment_col=TREATMENT_COL, 182 | outcome_col=OUTCOME_COL, 183 | learner=learner, 184 | confound="alignment", 185 | alpha_range=None, 186 | ) 187 | 188 | lls_bias_alignment, partial_rsqs_bias_alignment = sens.causalsens() 189 | print(lls_bias_alignment, partial_rsqs_bias_alignment) 190 | 191 | # Plot the results by confounding vector and plot Confidence Intervals for ATE 192 | sens.plot(lls_bias_alignment, ci=True) 193 | 194 | 195 | def test_one_sided(): 196 | y, X, treatment, tau, b, e = synthetic_data( 197 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 198 | ) 199 | alpha = np.quantile(y, 
0.25) 200 | adj = one_sided(alpha, e, treatment) 201 | 202 | assert y.shape == adj.shape 203 | 204 | 205 | def test_alignment(): 206 | y, X, treatment, tau, b, e = synthetic_data( 207 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 208 | ) 209 | alpha = np.quantile(y, 0.25) 210 | adj = alignment(alpha, e, treatment) 211 | 212 | assert y.shape == adj.shape 213 | 214 | 215 | def test_one_sided_att(): 216 | y, X, treatment, tau, b, e = synthetic_data( 217 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 218 | ) 219 | alpha = np.quantile(y, 0.25) 220 | adj = one_sided_att(alpha, e, treatment) 221 | 222 | assert y.shape == adj.shape 223 | 224 | 225 | def test_alignment_att(): 226 | y, X, treatment, tau, b, e = synthetic_data( 227 | mode=1, n=100000, p=NUM_FEATURES, sigma=1.0 228 | ) 229 | alpha = np.quantile(y, 0.25) 230 | adj = alignment_att(alpha, e, treatment) 231 | 232 | assert y.shape == adj.shape 233 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from causalml.inference.meta.utils import get_weighted_variance 3 | 4 | 5 | def test_weighted_variance(): 6 | x = np.array([1, 2, 3, 4, 5]) 7 | sample_weight_equal = np.ones(len(x)) 8 | 9 | var_x = get_weighted_variance(x, sample_weight_equal) 10 | # should get the same variance with equal sample_weight 11 | assert var_x == x.var() 12 | 13 | x1 = np.array([1, 2, 3, 4, 4, 5, 5]) 14 | sample_weight_equal = np.ones(len(x1)) 15 | sample_weight = [1, 1, 1, 2, 2] 16 | var_x2 = get_weighted_variance(x, sample_weight) 17 | var_x1 = get_weighted_variance(x1, sample_weight_equal) 18 | 19 | # should get the same variance by duplicate the observation based on the sample weight 20 | assert var_x1 == var_x2 21 | -------------------------------------------------------------------------------- /tests/test_value_optimization.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | 4 | from sklearn.model_selection import train_test_split 5 | from sklearn.linear_model import LogisticRegression 6 | 7 | from causalml.dataset import make_uplift_classification 8 | from causalml.inference.meta import BaseTClassifier 9 | from causalml.optimize.value_optimization import CounterfactualValueEstimator 10 | from causalml.optimize.utils import get_treatment_costs 11 | from causalml.optimize.utils import get_actual_value 12 | 13 | 14 | from tests.const import RANDOM_SEED 15 | 16 | 17 | def test_counterfactual_value_optimization(): 18 | df, X_names = make_uplift_classification( 19 | n_samples=2000, treatment_name=["control", "treatment1", "treatment2"] 20 | ) 21 | df_train, df_test = train_test_split(df, test_size=0.2, random_state=RANDOM_SEED) 22 | 23 | train_idx = df_train.index 24 | test_idx = df_test.index 25 | 26 | conversion_cost_dict = {"control": 0, "treatment1": 2.5, "treatment2": 5} 27 | impression_cost_dict = {"control": 0, "treatment1": 0, "treatment2": 0.02} 28 | 29 | cc_array, ic_array, conditions = get_treatment_costs( 30 | treatment=df["treatment_group_key"], 31 | control_name="control", 32 | cc_dict=conversion_cost_dict, 33 | ic_dict=impression_cost_dict, 34 | ) 35 | conversion_value_array = np.full(df.shape[0], 20) 36 | 37 | actual_value = get_actual_value( 38 | treatment=df["treatment_group_key"], 39 | observed_outcome=df["conversion"], 40 | conversion_value=conversion_value_array, 41 | conditions=conditions, 42 | 
conversion_cost=cc_array, 43 | impression_cost=ic_array, 44 | ) 45 | 46 | random_allocation_value = actual_value.loc[test_idx].mean() 47 | 48 | tm = BaseTClassifier(learner=LogisticRegression(), control_name="control") 49 | tm.fit( 50 | df_train[X_names].values, 51 | df_train["treatment_group_key"], 52 | df_train["conversion"], 53 | ) 54 | tm_pred = tm.predict(df_test[X_names].values) 55 | 56 | proba_model = LogisticRegression() 57 | 58 | W_dummies = pd.get_dummies(df["treatment_group_key"]) 59 | XW = np.c_[df[X_names], W_dummies] 60 | proba_model.fit(XW[train_idx], df_train["conversion"]) 61 | y_proba = proba_model.predict_proba(XW[test_idx])[:, 1] 62 | 63 | cve = CounterfactualValueEstimator( 64 | treatment=df_test["treatment_group_key"], 65 | control_name="control", 66 | treatment_names=conditions[1:], 67 | y_proba=y_proba, 68 | cate=tm_pred, 69 | value=conversion_value_array[test_idx], 70 | conversion_cost=cc_array[test_idx], 71 | impression_cost=ic_array[test_idx], 72 | ) 73 | 74 | cve_best_idx = cve.predict_best() 75 | cve_best = [conditions[idx] for idx in cve_best_idx] 76 | actual_is_cve_best = df.loc[test_idx, "treatment_group_key"] == cve_best 77 | cve_value = actual_value.loc[test_idx][actual_is_cve_best].mean() 78 | 79 | assert cve_value > random_allocation_value 80 | -------------------------------------------------------------------------------- /tests/test_visualize.py: -------------------------------------------------------------------------------- 1 | from matplotlib import pyplot as plt 2 | import numpy as np 3 | import pandas as pd 4 | import pytest 5 | from sklearn.model_selection import KFold, train_test_split 6 | 7 | from causalml.metrics.visualize import get_cumlift, plot_tmlegain 8 | from causalml.inference.meta import LRSRegressor 9 | 10 | 11 | def test_visualize_get_cumlift_errors_on_nan(): 12 | df = pd.DataFrame( 13 | [[0, np.nan, 0.5], [1, np.nan, 0.1], [1, 1, 0.4], [0, 1, 0.3], [1, 1, 0.2]], 14 | columns=["w", "y", "pred"], 15 | ) 16 | 17 | with pytest.raises(Exception): 18 | get_cumlift(df) 19 | 20 | 21 | def test_plot_tmlegain(generate_regression_data, monkeypatch): 22 | monkeypatch.setattr(plt, "show", lambda: None) 23 | 24 | y, X, treatment, tau, b, e = generate_regression_data() 25 | 26 | ( 27 | X_train, 28 | X_test, 29 | y_train, 30 | y_test, 31 | e_train, 32 | e_test, 33 | treatment_train, 34 | treatment_test, 35 | tau_train, 36 | tau_test, 37 | b_train, 38 | b_test, 39 | ) = train_test_split(X, y, e, treatment, tau, b, test_size=0.5, random_state=42) 40 | 41 | learner = LRSRegressor() 42 | learner.fit(X_train, treatment_train, y_train) 43 | cate_test = learner.predict(X_test, treatment_test).flatten() 44 | 45 | df = pd.DataFrame( 46 | { 47 | "y": y_test, 48 | "w": treatment_test, 49 | "p": e_test, 50 | "S-Learner": cate_test, 51 | "Actual": tau_test, 52 | } 53 | ) 54 | 55 | inference_cols = [] 56 | for i in range(X_test.shape[1]): 57 | col = "col_" + str(i) 58 | df[col] = X_test[:, i] 59 | inference_cols.append(col) 60 | 61 | n_fold = 3 62 | kf = KFold(n_splits=n_fold) 63 | 64 | plot_tmlegain( 65 | df, 66 | inference_col=inference_cols, 67 | outcome_col="y", 68 | treatment_col="w", 69 | p_col="p", 70 | n_segment=5, 71 | cv=kf, 72 | ci=False, 73 | ) 74 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py38 3 | 4 | [testenv] 5 | deps = pytest 6 | commands = 7 | pytest -sv 8 | 9 | [flake8] 10 | max-line-length 
= 120 11 | ignore = E121, 12 | E123, 13 | E126, 14 | E128, 15 | E129, 16 | E226, 17 | E24, 18 | E704, 19 | E731, 20 | E741, 21 | W503, 22 | W504 23 | 24 | builtins = __builtins__ 25 | --------------------------------------------------------------------------------