├── .github └── workflows │ └── build.yml ├── .gitignore ├── .readthedocs.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── causallib ├── README.md ├── __init__.py ├── analysis │ └── __init__.py ├── contrib │ ├── README.md │ ├── __init__.py │ ├── adversarial_balancing │ │ ├── __init__.py │ │ ├── adversarial_balancing.py │ │ └── classifier_selection.py │ ├── bicause_tree │ │ ├── __init__.py │ │ ├── bicause_tree.py │ │ └── overlap_utils.py │ ├── faissknn.py │ ├── hemm │ │ ├── README.md │ │ ├── __init__.py │ │ ├── gen_synthetic_data.py │ │ ├── hemm.py │ │ ├── hemm_api.py │ │ ├── hemm_metrics.py │ │ ├── hemm_outcome_models.py │ │ ├── hemm_utilities.py │ │ └── load_ihdp_data.py │ ├── requirements.txt │ ├── shared_sparsity_selection │ │ ├── __init__.py │ │ └── shared_sparsity_selection.py │ ├── sklearn_scorer_wrapper │ │ ├── __init__.py │ │ └── sklearn_scorer_wrapper.py │ └── tests │ │ ├── __init__.py │ │ ├── test_adversarial_balancing.py │ │ ├── test_bicause_tree.py │ │ ├── test_hemm.py │ │ ├── test_shared_sparsity_selection.py │ │ └── test_sklearn_scorer_wrapper.py ├── datasets │ ├── README.md │ ├── __init__.py │ ├── data │ │ ├── acic_challenge_2016 │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── x.csv │ │ │ ├── zymu_1.csv │ │ │ ├── zymu_10.csv │ │ │ ├── zymu_2.csv │ │ │ ├── zymu_3.csv │ │ │ ├── zymu_4.csv │ │ │ ├── zymu_5.csv │ │ │ ├── zymu_6.csv │ │ │ ├── zymu_7.csv │ │ │ ├── zymu_8.csv │ │ │ └── zymu_9.csv │ │ └── nhefs │ │ │ ├── LICENSE │ │ │ ├── NHEFS.csv │ │ │ └── NHEFS_codebook.csv │ └── data_loader.py ├── estimation │ ├── README.md │ ├── __init__.py │ ├── base_estimator.py │ ├── base_weight.py │ ├── doubly_robust.py │ ├── ipw.py │ ├── marginal_outcome.py │ ├── matching.py │ ├── overlap_weights.py │ ├── rlearner.py │ ├── standardization.py │ ├── tmle.py │ └── xlearner.py ├── evaluation │ ├── README.md │ ├── __init__.py │ ├── evaluator.py │ ├── metrics.py │ ├── plots │ │ ├── __init__.py │ │ ├── curve_data_makers.py │ │ ├── data_extractors.py │ │ ├── mixins.py │ │ └── plots.py │ ├── predictions.py │ ├── predictor.py │ ├── results.py │ └── scoring.py ├── metrics │ ├── __init__.py │ ├── outcome_metrics.py │ ├── propensity_metrics.py │ ├── scorers.py │ └── weight_metrics.py ├── model_selection │ ├── __init__.py │ ├── search.py │ └── split.py ├── positivity │ ├── __init__.py │ ├── base_positivity.py │ ├── datasets │ │ ├── __init__.py │ │ ├── pizza_data_simulator.py │ │ └── positivity_data_simulator.py │ ├── matching.py │ ├── metrics │ │ ├── __init__.py │ │ └── metrics.py │ ├── multiple_treatment_positivity.py │ ├── trimming.py │ └── univariate_bbox.py ├── preprocessing │ ├── README.md │ ├── __init__.py │ ├── confounder_selection.py │ ├── filters.py │ └── transformers.py ├── simulation │ ├── CausalSimulator3.py │ └── __init__.py ├── survival │ ├── README.md │ ├── __init__.py │ ├── base_survival.py │ ├── marginal_survival.py │ ├── regression_curve_fitter.py │ ├── standardized_survival.py │ ├── survival_utils.py │ ├── univariate_curve_fitter.py │ ├── weighted_standardized_survival.py │ └── weighted_survival.py ├── tests │ ├── __init__.py │ ├── test_base_weight.py │ ├── test_causal_simulator3.py │ ├── test_confounder_selection.py │ ├── test_datasets.py │ ├── test_doublyrobust.py │ ├── test_evaluation.py │ ├── test_ipw.py │ ├── test_marginal_outcome.py │ ├── test_matching.py │ ├── test_metrics.py │ ├── test_overlap_weights.py │ ├── test_plots.py │ ├── test_positivity_data.py │ ├── test_positivity_metrics.py │ ├── test_positivity_models.py │ ├── 
test_positivity_models_multitreatment.py │ ├── test_rlearner.py │ ├── test_scorers.py │ ├── test_search.py │ ├── test_split.py │ ├── test_standardization.py │ ├── test_survival.py │ ├── test_tmle.py │ ├── test_transformers.py │ ├── test_utils.py │ └── test_xlearner.py └── utils │ ├── __init__.py │ ├── crossfit.py │ ├── exceptions.py │ ├── general_tools.py │ └── stat_utils.py ├── docs ├── Makefile ├── README.md ├── make.bat ├── requirements.txt └── source │ ├── causallib.analysis.rst │ ├── causallib.contrib.adversarial_balancing.adversarial_balancing.rst │ ├── causallib.contrib.adversarial_balancing.classifier_selection.rst │ ├── causallib.contrib.adversarial_balancing.rst │ ├── causallib.contrib.faissknn.rst │ ├── causallib.contrib.hemm.gen_synthetic_data.rst │ ├── causallib.contrib.hemm.hemm.rst │ ├── causallib.contrib.hemm.hemm_api.rst │ ├── causallib.contrib.hemm.hemm_metrics.rst │ ├── causallib.contrib.hemm.hemm_outcome_models.rst │ ├── causallib.contrib.hemm.hemm_utilities.rst │ ├── causallib.contrib.hemm.load_ihdp_data.rst │ ├── causallib.contrib.hemm.rst │ ├── causallib.contrib.rst │ ├── causallib.contrib.shared_sparsity_selection.rst │ ├── causallib.contrib.shared_sparsity_selection.shared_sparsity_selection.rst │ ├── causallib.contrib.tests.rst │ ├── causallib.contrib.tests.test_adversarial_balancing.rst │ ├── causallib.contrib.tests.test_hemm.rst │ ├── causallib.contrib.tests.test_shared_sparsity_selection.rst │ ├── causallib.datasets.data_loader.rst │ ├── causallib.datasets.rst │ ├── causallib.estimation.base_estimator.rst │ ├── causallib.estimation.base_weight.rst │ ├── causallib.estimation.doubly_robust.rst │ ├── causallib.estimation.ipw.rst │ ├── causallib.estimation.marginal_outcome.rst │ ├── causallib.estimation.matching.rst │ ├── causallib.estimation.overlap_weights.rst │ ├── causallib.estimation.rlearner.rst │ ├── causallib.estimation.rst │ ├── causallib.estimation.standardization.rst │ ├── causallib.estimation.tmle.rst │ ├── causallib.estimation.xlearner.rst │ ├── causallib.evaluation.evaluator.rst │ ├── causallib.evaluation.metrics.rst │ ├── causallib.evaluation.plots.curve_data_makers.rst │ ├── causallib.evaluation.plots.data_extractors.rst │ ├── causallib.evaluation.plots.mixins.rst │ ├── causallib.evaluation.plots.plots.rst │ ├── causallib.evaluation.plots.rst │ ├── causallib.evaluation.predictions.rst │ ├── causallib.evaluation.predictor.rst │ ├── causallib.evaluation.results.rst │ ├── causallib.evaluation.rst │ ├── causallib.evaluation.scoring.rst │ ├── causallib.preprocessing.confounder_selection.rst │ ├── causallib.preprocessing.filters.rst │ ├── causallib.preprocessing.rst │ ├── causallib.preprocessing.transformers.rst │ ├── causallib.rst │ ├── causallib.simulation.CausalSimulator3.rst │ ├── causallib.simulation.rst │ ├── causallib.survival.base_survival.rst │ ├── causallib.survival.marginal_survival.rst │ ├── causallib.survival.regression_curve_fitter.rst │ ├── causallib.survival.rst │ ├── causallib.survival.standardized_survival.rst │ ├── causallib.survival.survival_utils.rst │ ├── causallib.survival.univariate_curve_fitter.rst │ ├── causallib.survival.weighted_standardized_survival.rst │ ├── causallib.survival.weighted_survival.rst │ ├── causallib.tests.rst │ ├── causallib.tests.test_base_weight.rst │ ├── causallib.tests.test_causal_simulator3.rst │ ├── causallib.tests.test_confounder_selection.rst │ ├── causallib.tests.test_datasets.rst │ ├── causallib.tests.test_doublyrobust.rst │ ├── causallib.tests.test_evaluation.rst │ ├── causallib.tests.test_ipw.rst │ 
├── causallib.tests.test_marginal_outcome.rst │ ├── causallib.tests.test_matching.rst │ ├── causallib.tests.test_overlap_weights.rst │ ├── causallib.tests.test_plots.rst │ ├── causallib.tests.test_rlearner.rst │ ├── causallib.tests.test_standardization.rst │ ├── causallib.tests.test_survival.rst │ ├── causallib.tests.test_tmle.rst │ ├── causallib.tests.test_transformers.rst │ ├── causallib.tests.test_utils.rst │ ├── causallib.tests.test_xlearner.rst │ ├── causallib.utils.crossfit.rst │ ├── causallib.utils.general_tools.rst │ ├── causallib.utils.rst │ ├── causallib.utils.stat_utils.rst │ ├── conf.py │ ├── index.rst │ └── modules.rst ├── examples ├── Bank-Marketing.ipynb ├── Dehejia_Wahba_replication.ipynb ├── MANAGE agricultural data.ipynb ├── TMLE.ipynb ├── causal_inference_vs_descriptive_statistics.ipynb ├── causal_simulator.ipynb ├── causal_survival_analysis.ipynb ├── doubly_robust.ipynb ├── evaluation_plots_overview.ipynb ├── fast_food_employment_card_krueger.ipynb ├── hemm_demo.ipynb ├── ipw.ipynb ├── lalonde.ipynb ├── lalonde_matching.ipynb ├── matching.ipynb ├── matching_with_custom_backends.ipynb ├── nhefs.ipynb ├── positivity.ipynb ├── rlearner.ipynb ├── standardization.ipynb └── xlearner.ipynb ├── requirements.txt └── setup.py /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build_and_test: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | fail-fast: false # Don't cancel entire run if one python-version fails 10 | matrix: 11 | python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] 12 | name: Build and test on Python ${{ matrix.python-version }} 13 | steps: 14 | - uses: actions/checkout@v4 15 | 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | # cache: 'pip' 21 | # cache-dependency-path: setup.py 22 | 23 | - name: Install and upgrade latest CI dependencies 24 | run: | 25 | pip install --upgrade pip 26 | pip install --upgrade pytest coverage # pytest-cov 27 | pip install --upgrade importlib-metadata # Solves a python 3.7 install bug 28 | 29 | - name: Install local causallib 30 | run: | 31 | pip install . 32 | pip install .[contrib] # Optional requirements for contrib module 33 | 34 | - name: Show environment's final dependencies 35 | run: pip freeze --all 36 | 37 | - name: Test with pytest 38 | run: | 39 | coverage run --source=. --omit=*__init__.py,setup.py -m pytest 40 | coverage xml 41 | # pytest tests.py --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html 42 | 43 | - name: Publish to CodeClimate 44 | uses: paambaati/codeclimate-action@v8.0.0 45 | env: 46 | CC_TEST_REPORTER_ID: ${{ secrets.CODECLIMATE_REPORTER_ID }} 47 | # Forked PRs have no access to secrets, so uploading a coverage report to Code Climate fails. 
48 | # To avoid that failing the entire workflow, continue on error: 49 | continue-on-error: true 50 | 51 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | #project-specific 2 | .DS_store 3 | .project 4 | .pydevproject 5 | .vscode 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | env/ 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *,cover 52 | .hypothesis/ 53 | .pytest_cache 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | 63 | # Flask instance folder 64 | instance/ 65 | 66 | # Scrapy stuff: 67 | .scrapy 68 | 69 | # Sphinx documentation 70 | docs/_build/ 71 | 72 | # PyBuilder 73 | target/ 74 | 75 | # IPython Notebook 76 | .ipynb_checkpoints 77 | 78 | # pyenv 79 | .python-version 80 | 81 | # celery beat schedule file 82 | celerybeat-schedule 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | 94 | ################### 95 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 96 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 97 | 98 | # User-specific stuff: 99 | .idea/ 100 | 101 | # Sensitive or high-churn files: 102 | .idea/dataSources.ids 103 | .idea/dataSources.xml 104 | .idea/dataSources.local.xml 105 | .idea/sqlDataSources.xml 106 | .idea/dynamic.xml 107 | .idea/uiDesigner.xml 108 | 109 | # Gradle: 110 | .idea/gradle.xml 111 | .idea/libraries 112 | 113 | # Mongo Explorer plugin: 114 | .idea/mongoSettings.xml 115 | 116 | ## File-based project format: 117 | *.iws 118 | 119 | ## Plugin-specific files: 120 | 121 | # IntelliJ 122 | /out/ 123 | 124 | # mpeltonen/sbt-idea plugin 125 | .idea_modules/ 126 | 127 | # JIRA plugin 128 | atlassian-ide-plugin.xml 129 | 130 | # Crashlytics plugin (for Android Studio and IntelliJ) 131 | com_crashlytics_export_strings.xml 132 | crashlytics.properties 133 | crashlytics-build.properties 134 | fabric.properties 135 | 136 | 137 | #### 138 | # linux 139 | *~ 140 | 141 | # temporary files which can be created if a process still has a handle open of a deleted file 142 | .fuse_hidden* 143 | 144 | # KDE directory preferences 145 | .directory 146 | 147 | # Linux trash folder which might appear on any partition or disk 148 | .Trash-* 149 | 150 | #### 151 | #emacs: 152 | # -*- mode: gitignore; -*- 153 | *~ 154 | \#*\# 155 | /.emacs.desktop 156 | /.emacs.desktop.lock 157 | *.elc 158 | auto-save-list 159 | tramp 160 | .\#* 161 | 162 | # Org-mode 163 | .org-id-locations 164 | *_archive 165 | 166 | # flymake-mode 167 | *_flymake.* 168 | 169 | # eshell files 170 | /eshell/history 171 | /eshell/lastdir 172 | 173 | 
# elpa packages 174 | /elpa/ 175 | 176 | # reftex files 177 | *.rel 178 | 179 | # AUCTeX auto folder 180 | /auto/ 181 | 182 | # cask packages 183 | .cask/ 184 | dist/ 185 | 186 | # Flycheck 187 | flycheck_*.el 188 | 189 | # server auth directory 190 | /server/ 191 | 192 | # projectiles files 193 | .projectile 194 | 195 | ############ 196 | # latex 197 | *.log 198 | *.aux 199 | *.out 200 | *tex.gz 201 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Documentation: https://docs.readthedocs.io/en/stable/config-file/v2.html 2 | 3 | # Configuration file version: 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-20.04 8 | tools: 9 | python: "3.9" 10 | 11 | # Build documentation in the docs/ directory with Sphinx 12 | sphinx: 13 | configuration: docs/source/conf.py 14 | builder: html 15 | 16 | # Additional documentation formats (pdf, etc.) 17 | formats: all 18 | 19 | # Python and requirements required to build your docs 20 | python: 21 | install: 22 | - requirements: requirements.txt 23 | - requirements: docs/requirements.txt 24 | # - method: pip # Default 25 | # path: . 26 | # extra_requirements: 27 | # - docs 28 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution guidelines 2 | 3 | Causallib welcomes community contributions to this repository. 4 | This file provides the guidelines for contributing to this project. 5 | 6 | ## Contributions 7 | We welcome a wide range of contributions: 8 | - estimation models 9 | - preprocessors 10 | - plots 11 | - improvements to the overall design 12 | - causal analysis examples using causallib in Jupyter Notebooks 13 | - documentation 14 | - bug reports 15 | - bug fixes 16 | - and more 17 | 18 | ## Prerequisites 19 | Causallib follows the [Github contribution workflow](https://git-scm.com/book/sv/v2/GitHub-Contributing-to-a-Project): 20 | forking the repository, cloning it, branching out a feature branch, developing, 21 | opening a pull request back to the causallib upstream once you are done, 22 | and performing an iterative review process. 23 | If your changes require a lot of work, it is better to first make sure they are 24 | aligned with the plans for the package. 25 | Therefore, it is recommended that you first open an issue describing 26 | what changes you think should be made and why. 27 | After a discussion with the core maintainers, we will decide whether the suggestion 28 | is welcome. 29 | If so, you are encouraged to link your pull request to its corresponding issue. 30 | 31 | ### Tests 32 | Contributions of new code are best accompanied by corresponding testing code. 33 | Unittests should be located in the `causallib/tests/` directory and run with `pytest`. 34 | 35 | Bug fixes, too, should ideally be coupled with tests replicating the bug, 36 | ensuring it will not recur in the future. 37 | 38 | ### Documentation 39 | New code should also be well documented. 40 | Causallib uses the [Google docstring format](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html), 41 | and docstrings should include input and output typing 42 | (if not [specified in the code](https://docs.python.org/3/library/typing.html)). 43 | If there are relevant academic papers on which the contribution is based, 44 | please cite and link to them in the docstring.
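For instance, here is a minimal sketch of a conforming docstring (the function, its arguments, and the placeholder citation are all hypothetical):

```python
import pandas as pd


def naive_effect(outcomes: pd.Series, treatment: pd.Series) -> float:
    """Estimate the effect as the difference in mean outcome between groups.

    Based on <cite and link the relevant paper here, if applicable>.

    Args:
        outcomes (pd.Series): Observed outcome for each individual.
        treatment (pd.Series): Binary treatment assignment for each individual.

    Returns:
        float: Mean outcome of the treated minus mean outcome of the untreated.
    """
    return outcomes[treatment == 1].mean() - outcomes[treatment == 0].mean()
```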
45 | 46 | ### Style 47 | The ultimate goal is to enhance the readability of the code. 48 | Causallib does not currently adhere to any strict style guideline. 49 | It follows the general guidance of the PEP8 specification, 50 | but encourages contributors to diverge from it if they see fit. 51 | 52 | Whenever in doubt, follow [the _Black_ code style guide](https://black.readthedocs.io/en/stable/the_black_code_style/current_style.html). 53 | 54 | ### `Contrib` module 55 | The `contrib` module is designated for more state-of-the-art methods that are not 56 | yet well-established but nonetheless may benefit the community. 57 | Ideally, models should still adhere to causallib's API 58 | (namely, `IndividualOutcomeEstimator`, `PopulationOutcomeEstimator`, `WeightEstimator`). 59 | This module has its own requirements file and tests. 60 | 61 | ### Contributor License Agreement 62 | The Causallib developer team works for IBM. 63 | To accept contributions outside of IBM, 64 | we need a signed Contributor License Agreement (CLA) 65 | from you before code contributions can be reviewed and merged. 66 | By signing a contributor license agreement (CLA), 67 | you attest that 68 | you are the author of the contribution and that you are freely 69 | contributing it under the terms of the Apache-2.0 license. 70 | 71 | When you contribute to the Causallib project with a new pull request, 72 | a bot will evaluate whether you have signed the CLA. If required, the 73 | bot will comment on the pull request, including a link to accept the 74 | agreement. 75 | You can review the [individual CLA document as a PDF](https://www.apache.org/licenses/icla.pdf). 76 | 77 | **Note**: 78 | > If your contribution is part of your employment or your contribution 79 | > is the property of your employer, then you will likely need to sign a 80 | > [corporate CLA](https://www.apache.org/licenses/cla-corporate.txt) too and 81 | > email it to us. 82 | 83 | ## Contributors 84 | Ehud Karavani 85 | Yishai Shimoni 86 | Michael Danziger 87 | Lior Ness 88 | Itay Manes 89 | Yoav Kan-Tor 90 | Chirag Nagpal 91 | Tal Kozlovski 92 | Liran Szlak 93 | Onkar Bhardwaj 94 | Dennis Wei 95 | -------------------------------------------------------------------------------- /causallib/README.md: -------------------------------------------------------------------------------- 1 | # Package `causallib` 2 | A package for estimating causal effects and counterfactual outcomes from observational data. 3 | 4 | `causallib` provides various causal inference methods with a distinct paradigm: 5 | * Every causal model has some machine learning model at its core. 6 | This allows mixing and matching causal models with powerful machine learning tools, 7 | simply by plugging them into the causal model. 8 | * Inspired by the scikit-learn design, once trained, causal models can be 9 | applied to out-of-bag samples. 10 | 11 | `causallib` also provides a performance evaluation scheme for the causal model 12 | by evaluating the machine learning core model in a causal inference context. 13 | 14 | Accompanying datasets, both real and simulated, are also available. 15 | The READMEs of the various modules detail the specific usage of each part. 16 | 17 | ## Structure 18 | The package comprises several modules, 19 | each providing different functionality 20 | related to the causal inference models. 21 | 22 | ### `estimation` 23 | This module includes the estimator classes, 24 | where multiple popular estimators are implemented.
25 | Specifically, this includes: 26 | - Inverse probability weighting (IPW). 27 | - Standardization. 28 | - Three versions of doubly-robust methods. 29 | 30 | Each of these methods receives one or more machine learning models that 31 | can be trained (fit), and then used to estimate (predict) the relevant outcome 32 | of interest. 33 | 34 | ### `evaluation` 35 | This module provides the classes to evaluate the performance of methods 36 | defined in the estimation module. 37 | Evaluations are tailored to the type of method that is used. 38 | For example, weight estimators such as IPW can be evaluated for how well 39 | they remove bias from the data, 40 | while outcome models can be evaluated for their precision. 41 | 42 | ### `preprocessing` 43 | This module provides several enhancements to the filters and transformers 44 | provided by scikit-learn. 45 | These can be used within a pipeline framework together with the models. 46 | 47 | ### `datasets` 48 | Several datasets are provided within the package in the `datasets` module: 49 | * NHEFS study data on the effect of smoking cessation on weight gain. 50 | Adapted from [Hernán and Robins' Causal Inference Book](https://www.hsph.harvard.edu/miguel-hernan/causal-inference-book/) 51 | * A handful of simulation sets from the [2016 Atlantic Causal Inference 52 | Conference (ACIC) data challenge](https://jenniferhill7.wixsite.com/acic-2016/competition). 53 | * A simulation module that allows creating simulated data based on a causal graph 54 | depicting the connections between covariates, treatment assignment and outcomes. 55 | 56 | ### Additional folders 57 | Several additional folders exist under the package and hold 58 | internal utilities. 59 | They should only be used as part of development. 60 | These folders include `analysis`, `simulation`, `utils`, and `tests`. 61 | 62 | ## Usage 63 | The examples folder contains several notebooks exemplifying the use of the 64 | package. 65 | -------------------------------------------------------------------------------- /causallib/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.10.0" 2 | -------------------------------------------------------------------------------- /causallib/analysis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/analysis/__init__.py -------------------------------------------------------------------------------- /causallib/contrib/README.md: -------------------------------------------------------------------------------- 1 | # Module `causallib.contrib` 2 | This module currently includes additional causal methods contributed to the package 3 | by causal inference researchers other than `causallib`'s core developers. 4 | 5 | The causal models in this module can be slightly more novel than the ones in the `estimation` module. 6 | However, they should largely adhere to the `causallib` API 7 | (e.g., `IndividualOutcomeEstimator` or `WeightEstimator`). 8 | Since code here is more experimental, 9 | models might also require additional (and less trivial) package dependencies, 10 | or have less test coverage. 11 | Well-integrated models could be transferred into the main `estimation` module in the future. 12 | 13 | ## Contributed Methods 14 | Currently contributed methods are: 15 | 16 | 1. 
Adversarial Balancing: implementing the algorithm described in 17 | [Adversarial Balancing for Causal Inference](https://arxiv.org/abs/1810.07406). 18 | ```python 19 | from causallib.contrib.adversarial_balancing import AdversarialBalancing 20 | ``` 21 | 1. Interpretable Subgroup Discovery in Treatment Effect Estimation: 22 | implementing the heterogeneous effect mixture model (HEMM) presented in 23 | [Interpretable Subgroup Discovery in Treatment Effect Estimation with Application to Opioid Prescribing Guidelines](https://arxiv.org/pdf/1905.03297.pdf) 24 | ```python 25 | from causallib.contrib.hemm import HEMM 26 | ``` 27 | 1. Matching Estimation/Transform using `faiss`. 28 | 29 | Implements a nearest neighbors search with an API that matches scikit-learn's `NearestNeighbors`, 30 | but is powered by [faiss](https://github.com/facebookresearch/faiss) for GPU 31 | support and much faster search on CPU as well. 32 | 33 | ```python 34 | from causallib.contrib.faissknn import FaissNearestNeighbors 35 | ``` 36 | 37 | ## Dependencies 38 | Each model might have slightly different requirements. 39 | Refer to the documentation of each model for the additional packages it requires. 40 | 41 | Requirements for `contrib` models are concentrated in `contrib/requirements.txt` 42 | and can be automatically installed using the extra-requirements `contrib` flag: 43 | ```shell script 44 | pip install causallib[contrib] -f https://download.pytorch.org/whl/torch_stable.html 45 | ``` 46 | The `-f` find-links option is required to install the PyTorch dependency. 47 | 48 | ## References 49 | 50 | Ozery-Flato, M., Thodoroff, P., Ninio, M., Rosen-Zvi, M., & El-Hay, T. (2018). [Adversarial balancing for causal inference.](https://arxiv.org/abs/1810.07406) arXiv preprint arXiv:1810.07406. 51 | 52 | Nagpal, C., Wei, D., Vinzamuri, B., Shekhar, M., Berger, S. E., Das, S., & Varshney, K. R. (2020, April). [Interpretable subgroup discovery in treatment effect estimation with application to opioid prescribing guidelines.](https://arxiv.org/pdf/1905.03297.pdf) In Proceedings of the ACM Conference on Health, Inference, and Learning (pp. 19-29). 53 | -------------------------------------------------------------------------------- /causallib/contrib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/contrib/__init__.py -------------------------------------------------------------------------------- /causallib/contrib/adversarial_balancing/__init__.py: -------------------------------------------------------------------------------- 1 | from .adversarial_balancing import AdversarialBalancing 2 | -------------------------------------------------------------------------------- /causallib/contrib/bicause_tree/__init__.py: -------------------------------------------------------------------------------- 1 | from .bicause_tree import BICauseTree, PropensityBICauseTree 2 | -------------------------------------------------------------------------------- /causallib/contrib/faissknn.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2021 IBM Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import faiss 17 | 18 | 19 | class FaissNearestNeighbors: 20 | 21 | def __init__(self, 22 | metric="mahalanobis", 23 | index_type="flatl2", n_cells=100, n_probes=10): 24 | """NearestNeighbors object utilizing the faiss library for speed 25 | 26 | Implements the same API as sklearn but runs 5-10x faster. Utilizes the 27 | `faiss` library https://github.com/facebookresearch/faiss . Tested with 28 | version 1.7.0. If `faiss-gpu` is installed from pypi, GPU acceleration 29 | will be used if available. 30 | 31 | Args: 32 | metric (str) : Distance metric for finding nearest neighbors 33 | (default: "mahalanobis") 34 | index_type (str) : Index type within faiss to use 35 | (supported: "flatl2" and "ivfflat") 36 | n_cells (int) : Number of voronoi cells (only used for "ivfflat", 37 | default: 100) 38 | n_probes (int) : Number of voronoi cells to search in 39 | (only used for "ivfflat", default: 10) 40 | Attributes (after running `fit`): 41 | index_ : the faiss index fit from the data. For details about 42 | faiss indices, see the faiss documentation at 43 | https://github.com/facebookresearch/faiss/wiki/Faiss-indexes . 44 | """ 45 | self.metric = metric 46 | self.n_cells = n_cells 47 | self.n_probes = n_probes 48 | self.index_type = index_type 49 | 50 | def fit(self, X): 51 | """Create faiss index and train with data. 52 | 53 | Args: 54 | X (np.array): Array of N samples of shape (NxM) 55 | 56 | Returns: 57 | self: Fitted object 58 | """ 59 | X = self._transform_covariates(X) 60 | if self.index_type == "flatl2": 61 | self.index_ = faiss.IndexFlatL2(X.shape[1]) 62 | self.index_.add(X) 63 | elif self.index_type == "ivfflat": 64 | quantizer = faiss.IndexFlatL2(X.shape[1]) 65 | n_cells = max(1, min(self.n_cells, X.shape[0]//200)) 66 | n_probes = min(self.n_probes, n_cells) 67 | self.index_ = faiss.IndexIVFFlat( 68 | quantizer, X.shape[1], n_cells) 69 | self.index_.train(X) 70 | self.index_.nprobe = n_probes 71 | self.index_.add(X) 72 | else: 73 | raise NotImplementedError( 74 | "Index type {} not implemented. Please select" 75 | "one of [\"flatl2\", \"ivfflat\"]".format(self.index_type)) 76 | return self 77 | 78 | def kneighbors(self, X, n_neighbors=1): 79 | """Find the k nearest neighbors of each sample in X 80 | 81 | Args: 82 | X (np.array): Array of shape (N,M) of samples to search 83 | for neighbors of. M must be the same as the fit data. 84 | n_neighbors (int, optional): Number of neighbors to find. 85 | Defaults to 1. 86 | 87 | Returns: 88 | (distances, indices): Two np.array objects of shape (N,n_neighbors) 89 | containing the distances and indices of the closest neighbors. 
90 | """ 91 | X = self._transform_covariates(X) 92 | distances, indices = self.index_.search(X, n_neighbors) 93 | # faiss returns euclidean distance squared 94 | return np.sqrt(distances), indices 95 | 96 | def _transform_covariates(self, X): 97 | if self.metric == "mahalanobis": 98 | if not hasattr(self, "VI"): 99 | raise AttributeError("Set inverse covariance VI first.") 100 | X = np.dot(X, self.VI.T) 101 | return np.ascontiguousarray(X).astype("float32") 102 | 103 | def set_params(self, **parameters): 104 | for parameter, value in parameters.items(): 105 | if parameter == "metric_params": 106 | self.set_params(**value) 107 | else: 108 | self._setattr(parameter, value) 109 | return self 110 | 111 | def get_params(self, deep=True): 112 | # `deep` plays no role because there are no sublearners 113 | params_to_return = ["metric", "n_cells", "n_probes", "index_type"] 114 | return {i: self.__getattribute__(i) for i in params_to_return} 115 | 116 | def _setattr(self, parameter, value): 117 | # based on faiss docs https://github.com/facebookresearch/faiss/wiki/MetricType-and-distances 118 | if parameter == "VI": 119 | value = np.linalg.inv(value) 120 | chol = np.linalg.cholesky(value) 121 | cholvi = np.linalg.inv(chol) 122 | value = cholvi 123 | setattr(self, parameter, value) 124 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/README.md: -------------------------------------------------------------------------------- 1 | # Module `causallib.contrib.hemm` 2 | 3 | Implementation of the heterogeneous effect mixture model (HEMM) presented in the [_Interpretable Subgroup Discovery in Treatment Effect Estimation with Application to Opioid Prescribing Guidelines_](https://arxiv.org/abs/1905.03297) paper. 4 | 5 | HEMM is used for discovering subgroups with enhanced and diminished treatment effects in a potential outcomes causal inference framework, using sparsity to enhance interpretability. The HEMM’s outcome model is extended to include neural networks to better adjust for confounding and develop a joint inference procedure for the overall graphical model and neural networks. The model has two parts: 6 | 7 | 1. The subgroup discovery component. 8 | 2. The outcome prediction from the subgroup assignment and the interaction with confounders through an MLP. 9 | 10 | The model can be initialized with any of the following outcome models: 11 | * **Balanced Net**: A torch.model class that is used as a component of the HEMM module to determine the outcome as a function of confounders. The balanced net consists of two different neural networks for the two potential outcomes (under treatment and under control). 12 | * **MLP model**: An MLP with an ELU activation. This allows for a single neural network to have two heads, one for each of the potential outcomes. 13 | * **Linear model**: Linear model with two separate linear functions of the input covariates. 14 | 15 | The balanced net outcome model relies on utility functions that are to be used with the balanced net outcome model based on [_Estimating individual treatment effect: generalization bounds and algorithms_](https://arxiv.org/abs/1606.03976), Shalit et al., ICML (2017). The utility functions mainly consist of IPM metrics to calculate the imbalance between the control and treated population. 
16 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/__init__.py: -------------------------------------------------------------------------------- 1 | from causallib.contrib.hemm.hemm_api import HEMM 2 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/hemm_metrics.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # (C) Copyright 2019 IBM Corp. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Created on Sept 25, 2019 18 | # Big thanks to Akanksha Atrey for original 19 | # implementation of this module. 20 | 21 | import torch 22 | import numpy as np 23 | 24 | 25 | def pdist2sq(X,Y): 26 | """ 27 | Computes the squared Euclidean distance between all pairs x in X, y in Y. 28 | """ 29 | C = -2*torch.matmul(X,torch.transpose(Y,0,1)) 30 | nx = torch.sum(torch.pow(X,2),dim=1,keepdim=True) 31 | ny = torch.sum(torch.pow(Y,2),dim=1,keepdim=True) 32 | D = (C + torch.transpose(ny,0,1)) + nx 33 | 34 | return D 35 | 36 | 37 | def mmd2_lin(X, t, p): 38 | """ 39 | Computes linear maximum mean discrepancy (MMD) metric. 40 | """ 41 | it = np.where(t==1)[0] 42 | ic = np.where(t==0)[0] 43 | 44 | Xc = X[ic] 45 | Xt = X[it] 46 | 47 | mean_control = torch.mean(Xc) 48 | mean_treated = torch.mean(Xt) 49 | 50 | mmd = torch.sum(torch.pow(2.0*p*mean_treated - 2.0*(1.0-p)*mean_control, 2)) 51 | 52 | return mmd 53 | 54 | 55 | def mmd2_rbf(X, t, p, sig=0.1): 56 | """ 57 | Computes the l2-RBF maximum mean discrepancy (MMD) for X given t. 58 | http://www.jmlr.org/papers/volume13/gretton12a/gretton12a.pdf -- Eq3 59 | """ 60 | it = np.where(t==1)[0] 61 | ic = np.where(t==0)[0] 62 | 63 | Xc = X[ic] 64 | Xt = X[it] 65 | 66 | if list(Xc.shape)[0] == 0.0 or list(Xt.shape)[0] == 0.0: 67 | return torch.tensor(float('nan')) # pylint: disable=E1102 68 | 69 | Kcc = torch.exp(-pdist2sq(Xc,Xc)/np.square(sig)) 70 | Kct = torch.exp(-pdist2sq(Xc,Xt)/np.square(sig)) 71 | Ktt = torch.exp(-pdist2sq(Xt,Xt)/np.square(sig)) 72 | 73 | m = float(list(Xc.shape)[0]) 74 | n = float(list(Xt.shape)[0]) 75 | 76 | mmd = np.square(1.0-p)/(m*(m-1.0))*(torch.sum(Kcc)-m) 77 | mmd = mmd + np.square(p)/(n*(n-1.0))*(torch.sum(Ktt)-n) 78 | mmd = mmd - 2.0*p*(1.0-p)/(m*n)*torch.sum(Kct) 79 | mmd = 4.0*mmd 80 | 81 | return mmd 82 | 83 | 84 | def wass(X,t,p,lam=10.0,its=20,sq=False,backpropT=False): 85 | """ 86 | Computes the Wasserstein metric. 87 | 88 | Algorithm 3 from "Fast Computation of Wasserstein Barycenters", Cuturi and Doucet (2014) (https://arxiv.org/pdf/1310.4375.pdf). 89 | See supplement B.1 from Shalit et al. (2017) for more details (https://arxiv.org/abs/1606.03976). 
90 | """ 91 | it = np.where(t==1)[0] 92 | ic = np.where(t==0)[0] 93 | Xc = X[ic] 94 | Xt = X[it] 95 | nc = float(list(Xc.shape)[0]) 96 | nt = float(list(Xt.shape)[0]) 97 | 98 | if list(Xc.shape)[0] == 0.0 or list(Xt.shape)[0] == 0.0: 99 | return torch.tensor(float('nan')), torch.tensor(float('nan')) # pylint: disable=E1102 100 | 101 | ''' Compute distance matrix''' 102 | if sq: 103 | M = pdist2sq(Xt,Xc) 104 | else: 105 | M = torch.sqrt(pdist2sq(Xt,Xc)) 106 | 107 | ''' Estimate lambda and delta ''' 108 | M_mean = torch.mean(M) 109 | M_drop = torch.nn.Dropout(1/(nc*nt))(M) 110 | delta = (torch.max(M)).detach() 111 | eff_lam = (lam/M_mean).detach() 112 | 113 | ''' Compute new distance matrix ''' 114 | Mt = M 115 | row = (delta*torch.ones((M[0:1,:]).shape)).type(torch.float64) 116 | col = torch.cat((delta*torch.ones((M[:,0:1]).shape),torch.zeros((1,1))), 0).type(torch.float64) 117 | Mt = torch.cat((M,row), 0) 118 | Mt = torch.cat((Mt,col), 1) 119 | 120 | ''' Compute marginal vectors ''' 121 | a = torch.cat((p*torch.ones((np.where(t>0)[0].reshape(-1,1)).shape)/nt, (1-p)*torch.ones((1,1))), 0).type(torch.float64) 122 | b = torch.cat(((1-p)*torch.ones((np.where(t<1)[0].reshape(-1,1)).shape)/nc, p*torch.ones((1,1))), 0).type(torch.float64) 123 | 124 | ''' Compute kernel matrix''' 125 | Mlam = eff_lam*Mt 126 | K = torch.exp(-Mlam) + 1e-6 # added constant to avoid nan 127 | U = K*Mt 128 | ainvK = K/a 129 | 130 | u = a 131 | for i in range(0,its): 132 | u = 1.0/(torch.matmul(ainvK,(b/torch.transpose(torch.matmul(torch.transpose(u,0,1),K),0,1)))) 133 | v = b/(torch.transpose(torch.matmul(torch.transpose(u,0,1),K),0,1)) 134 | 135 | T = u*(torch.transpose(v,0,1)*K) 136 | 137 | if not backpropT: 138 | T = T.detach() 139 | 140 | E = T*Mt 141 | D = 2*torch.sum(E) 142 | 143 | return D, Mlam 144 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/hemm_outcome_models.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # (C) Copyright 2019 IBM Corp. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Created on Sept 25, 2019 18 | 19 | import torch.nn.functional as F 20 | import torch.nn as nn 21 | import torch 22 | 23 | 24 | class BalancedNet(nn.Module): 25 | """A torch.model used as a component of the HEMM module to determine the outcome as a function of confounders. 26 | The balanced net consists of two different neural networks for the outcome and counteractual. 27 | """ 28 | 29 | def __init__(self, D_in, H, D_out): 30 | """Instantiate two nn.Linear modules and assign them as member variables. 
31 | 32 | Args: 33 | D_in: input dimension 34 | H: dimension of hidden layer 35 | D_out: output dimension 36 | """ 37 | 38 | super(BalancedNet, self).__init__() 39 | 40 | self.f1 = nn.Linear(D_in, H) 41 | self.f2 = nn.Linear(H, D_out) 42 | 43 | self.cf1 = nn.Linear(D_in, H) 44 | self.cf2 = nn.Linear(H, D_out) 45 | 46 | def forward(self, x): 47 | """Accept a Variable of input data and return a Variable of output data. 48 | 49 | We can use Modules defined in the constructor as well as arbitrary operators on Variables. 50 | """ 51 | h_relu = F.elu(self.f1(x)) 52 | f = self.f2(h_relu) 53 | 54 | h_relu = F.elu(self.cf1(x)) 55 | cf = self.cf2(h_relu) 56 | 57 | out = torch.cat((f, cf), dim=1) 58 | 59 | return out 60 | 61 | 62 | def genMLPModule(D_in, H, out=1): 63 | """Build an MLP with an ELU activation. 64 | 65 | This allows for a single neural network to have two heads for the outcome and counterfactual. 66 | """ 67 | if type(H) is int: # H must be an int; otherwise no model is built and None is returned 68 | model = torch.nn.Sequential(torch.nn.Linear(D_in, H), torch.nn.ELU(), torch.nn.Linear(H, out)) 69 | return model.double() 70 | 71 | 72 | def genLinearModule(D_in, out=1): 73 | """Two separate linear functions of the input covariates.""" 74 | model = torch.nn.Sequential(torch.nn.Linear(D_in, out),) 75 | return model.double() 76 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/hemm_utilities.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | 3 | # (C) Copyright 2019 IBM Corp. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | # 17 | # Created on Sept 25, 2019 18 | 19 | import numpy as np 20 | from collections import Counter 21 | 22 | 23 | def KFoldStratifiedMultiClass(labels, n, seed=0): 24 | """Returns a stratified n-fold in a dictionary, trying to distribute classes across folds as well as possible.""" 25 | np.random.seed(seed) 26 | freqs = np.sum(labels, axis=0) # class frequencies 27 | tups = zip(freqs, range(19)) # NOTE: supports at most 19 classes; zip truncates to the shorter input 28 | tups = sorted(tups) 29 | freq_order = map(lambda i: i[1], tups) # classes sorted by frequency 30 | folds = {} 31 | 32 | # folds is a dictionary indexed by integer, which contains 33 | # the indexes of the samples in each fold 34 | for i in range(n): 35 | folds[i] = [] 36 | 37 | # visited is a set to keep the samples already sorted into folds 38 | visited = set() 39 | for j in freq_order: 40 | occs = np.where(labels[:, j] == 1)[0] # find the occurrences of the class 41 | for i, k in enumerate(occs): 42 | # for each occurrence check if it's been sorted 43 | if k not in visited: 44 | # and in case not append it to a given fold 45 | folds[i % n].append(k) 46 | # after sorting each class update visited set 47 | for i in folds.keys(): 48 | visited = visited.union(set(folds[i])) 49 | 50 | # sort remaining examples into folds, trying to keep all folds 51 | # equally populated. 
52 | for i in range(labels.shape[0]): 53 | if i not in visited: 54 | f = np.argmin([len(folds[i]) for i in range(n)]) # a list, not a lazy map: np.argmin cannot reduce a map object 55 | folds[f].append(i) 56 | 57 | return folds 58 | 59 | 60 | def getMeanandStd(X): 61 | """Takes the features and computes the mean and std dev of each column.""" 62 | mu = np.mean(X, axis=0) 63 | std = np.std(X, axis=0) 64 | 65 | return np.array([mu]).astype('float64'), np.array([std]).astype('float64') 66 | 67 | 68 | def genSplits(T, Y, k=5): 69 | """Returns k-fold indices of the input data.""" 70 | star = ((T * 10) + (Y ** 1)).tolist() 71 | star = np.vstack((T, Y)).T 72 | starsum = star.sum(axis=1) 73 | starsum = ((T * 10) + (Y ** 1)) 74 | 75 | Counter(starsum) 76 | 77 | splits = KFoldStratifiedMultiClass(star, k) 78 | # for split in splits: 79 | # print(split,len(splits[split]), T[splits[split]].sum(), Y[splits[split]].sum(), Counter(starsum[splits[split]])) 80 | 81 | overlap = set(range(len(T))) 82 | for split in splits: 83 | overlap &= set(splits[split]) 84 | # print("Total Overlap:", len(overlap)) 85 | 86 | return splits 87 | 88 | 89 | def returnIndices(splits): 90 | """Takes the splits and returns the train and dev splits.""" 91 | test = [] 92 | train = [] 93 | dev = [] 94 | 95 | for split in splits.keys(): 96 | if split in [0]: 97 | dev += splits[split] 98 | else: 99 | train += splits[split] 100 | 101 | return train, dev 102 | 103 | 104 | def returnIndicesTest(splits): 105 | """Takes the splits and returns the train, test and dev splits.""" 106 | test = [] 107 | train = [] 108 | dev = [] 109 | 110 | for split in splits.keys(): 111 | if split in [0, 1]: 112 | test += splits[split] 113 | if split in [2]: 114 | dev += splits[split] 115 | else: 116 | train += splits[split] 117 | 118 | return train, test, dev 119 | -------------------------------------------------------------------------------- /causallib/contrib/hemm/load_ihdp_data.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2019 IBM Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # Created on Sept 25, 2019 16 | 17 | """ 18 | IHDP Data Downloading, Unzipping and Loading 19 | 20 | This module provides data suitable for testing the HEMM estimator and for 21 | writing example notebooks. 22 | """ 23 | 24 | import numpy as np 25 | import os 26 | import urllib.request 27 | import zipfile 28 | 29 | 30 | def __download_data(url_path, local_path, verbose=0): 31 | if not os.path.exists(local_path): 32 | if verbose: 33 | print(f"Downloading data from {url_path} to {local_path}") 34 | req = urllib.request.urlretrieve(url=url_path, filename=local_path) 35 | return req[0] 36 | return local_path 37 | 38 | 39 | def loadIHDPData(cache_dir=None, verbose=0, delete_extracted=True): 40 | """Downloads and loads IHDP-1000 dataset. 
41 | Taken From Fredrik Johansson's website: http://www.fredjo.com/ 42 | 43 | Args: 44 | cache_dir (str): Directory to which files will be downloaded 45 | If None: files will be downloaded to ~/causallib-data/. 46 | verbose (int): Controls the verbosity: the higher, the more messages. 47 | delete_extracted (bool): Delete extracted files from disk once loaded 48 | 49 | Returns: 50 | dict[str, dict[str, np.ndarray]]: "TRAIN" and "TEST" sets as keys. 51 | Values are dictionaries with `'x', 't', 'yf', 'ycf', 'mu0', 'mu1'` keys standing for 52 | covariates, treatment, factual outcome, counterfactual outcome, and noiseless potential outcomes 53 | 54 | Notes: 55 | Requires internet connection in case local data files do not already exist. 56 | Will save a local copy of the download 57 | 58 | """ 59 | base_remote_url = "http://www.fredjo.com/files/" 60 | file_name = "ihdp_npci_1-1000.{phase}.npz.zip" 61 | 62 | # Set local download location: 63 | if cache_dir is None: 64 | cache_dir = os.path.join("~", 'causallib-data') 65 | cache_dir = os.path.expanduser(cache_dir) # Expand ~ component to full path 66 | cache_dir = os.path.join(cache_dir, "IHDP") 67 | # cache_dir = cache_dir.replace("/", os.sep) 68 | os.makedirs(cache_dir, exist_ok=True) 69 | 70 | data = {} 71 | for phase in ["train", "test"]: 72 | # Obtain local copy of the data: 73 | phase_file_name = file_name.format(phase=phase) 74 | file_path = __download_data( 75 | url_path=base_remote_url + phase_file_name, 76 | local_path=os.path.join(cache_dir, phase_file_name), 77 | verbose=verbose 78 | ) 79 | 80 | # Extract zipped data: 81 | npz_file_path = file_path.rsplit(".", maxsplit=1)[0] # Remove ".zip" extension 82 | if not os.path.exists(npz_file_path): 83 | with zipfile.ZipFile(file_path) as zf: 84 | if verbose: 85 | print(f"Extracting file into {npz_file_path}") 86 | zf.extractall(path=cache_dir) 87 | 88 | # Load data: 89 | phase_data = np.load(npz_file_path) 90 | phase_data = dict(phase_data) # Load into memory, avoid lazy-loading 91 | data[phase.upper()] = phase_data 92 | 93 | # # In-memory extraction, works only in python>=3.7 https://github.com/python/cpython/pull/4966 94 | # with zipfile.ZipFile(file_path) as zf: 95 | # internal_file_name = phase_file_name.rsplit(".", maxsplit=1)[0] # Remove ".zip" extension 96 | # with zf.open(internal_file_name, 'r') as npz_file: 97 | # data[phase.upper()] = dict(np.load(npz_file)) 98 | 99 | if delete_extracted: 100 | if verbose: 101 | print(f"Deleting extracted file {npz_file_path}") 102 | os.remove(npz_file_path) 103 | 104 | return data 105 | -------------------------------------------------------------------------------- /causallib/contrib/requirements.txt: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/cpu/ # To support cpu torch installation 2 | torch>=1.2.0 3 | faiss-cpu~=1.7.0;python_version < '3.12' # Can also use gpu for some Python versions 4 | faiss-cpu~=1.8.0;python_version >= '3.12' -------------------------------------------------------------------------------- /causallib/contrib/shared_sparsity_selection/__init__.py: -------------------------------------------------------------------------------- 1 | from .shared_sparsity_selection import SharedSparsityConfounderSelection 2 | -------------------------------------------------------------------------------- /causallib/contrib/sklearn_scorer_wrapper/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.sklearn_scorer_wrapper import SKLearnScorerWrapper 2 | -------------------------------------------------------------------------------- /causallib/contrib/sklearn_scorer_wrapper/sklearn_scorer_wrapper.py: -------------------------------------------------------------------------------- 1 | from causallib.metrics.scorers import PropensityScorerBase 2 | 3 | 4 | class SKLearnScorerWrapper(PropensityScorerBase): 5 | def __init__(self, score_func, sign=None, **kwargs): # `sign` is accepted for API compatibility but ignored 6 | super().__init__( 7 | score_func=score_func, 8 | sign=1, # This keeps original scorer sign 9 | **kwargs 10 | ) 11 | 12 | def _score(self, estimator, X, a, y=None, sample_weight=None, **kwargs): 13 | learner = self._extract_sklearn_estimator(estimator) 14 | score = self._score_func(learner, X, a, sample_weight=sample_weight) 15 | return score 16 | 17 | @staticmethod 18 | def _extract_sklearn_estimator(estimator): 19 | if hasattr(estimator, "best_estimator_"): 20 | # Causallib's wrapper around GridSearchCV 21 | return estimator.best_estimator_.learner 22 | if hasattr(estimator, "learner"): 23 | return estimator.learner 24 | raise AttributeError( 25 | f"Could not extract an sklearn estimator from {estimator}," 26 | f"which has the following attributes:\n" 27 | f"{list(estimator.__dict__.keys())}" 28 | ) -------------------------------------------------------------------------------- /causallib/contrib/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/contrib/tests/__init__.py -------------------------------------------------------------------------------- /causallib/contrib/tests/test_hemm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import torch 4 | import unittest 5 | 6 | from causallib.contrib.hemm import HEMM 7 | from causallib.contrib.hemm.hemm_utilities import genSplits, returnIndices, getMeanandStd 8 | from causallib.contrib.hemm.hemm_outcome_models import genMLPModule, genLinearModule, BalancedNet 9 | from causallib.contrib.hemm.gen_synthetic_data import gen_montecarlo 10 | 11 | 12 | class TestHemmEstimator(unittest.TestCase): 13 | def experiment(self, data, i, comp, response, outcome_model, lr, batch_size): 14 | np.random.seed(0) 15 | 16 | Xtr = data['TRAIN']['x'][:, :, i] 17 | Ttr = data['TRAIN']['t'][:, i] 18 | Ytr = data['TRAIN']['yf'][:, i] 19 | 20 | Ytr_ = np.ones_like(Ytr) 21 | splits = genSplits(Ttr, Ytr_) 22 | train, dev = returnIndices(splits) 23 | n = Xtr.shape[0] 24 | 25 | Xte = data['TEST']['x'][:, :, i] 26 | Tte = data['TEST']['t'][:, i] 27 | Yte = data['TEST']['yf'][:, i] 28 | 29 | mu, std = getMeanandStd(Xtr) 30 | 31 | Xdev = Xtr[dev] # Numpy array 32 | Ydev = torch.from_numpy(Ytr[dev].astype('float64')) 33 | Tdev = Ttr[dev] # Numpy array 34 | 35 | Xtr = pd.DataFrame(Xtr[train]) # Train covariates as a data frame 36 | Ytr = torch.from_numpy(Ytr[train].astype('float64')) 37 | Ttr = pd.Series(Ttr[train]) # Train treatment assignments as a series 38 | 39 | Xte = torch.from_numpy(Xte.astype('float64')) 40 | Yte = torch.from_numpy(Yte.astype('float64')) 41 | Tte = torch.from_numpy(Tte.astype('float64')) 42 | 43 | if outcome_model == 'MLP': 44 | outcome_model = genMLPModule(Xte.shape[1], Xte.shape[1] // 5, 2) # integer division: genMLPModule requires an int hidden size 45 | elif outcome_model == 'linear': 46 | outcome_model = genLinearModule(Xte.shape[1], 2) 47 | elif outcome_model == 'CF': 48 | outcome_model = 
BalancedNet(Xte.shape[1], Xte.shape[1], 1) 49 | 50 | estimator = HEMM( 51 | Xte.shape[1], 52 | comp, 53 | mu=mu, 54 | std=std, 55 | bc=6, 56 | lamb=0., 57 | spread=0., 58 | outcome_model=outcome_model, 59 | epochs=500, 60 | batch_size=batch_size, 61 | learning_rate=lr, 62 | weight_decay=1e-4, 63 | metric='LL', 64 | response=response, 65 | imb_fun='wass' 66 | ) 67 | estimator.fit(Xtr, Ttr, Ytr, validation_data=(Xdev, Tdev, Ydev)) 68 | 69 | Xtr = data['TRAIN']['x'][:, :, i] 70 | Ttr = data['TRAIN']['t'][:, i] 71 | Ytr = data['TRAIN']['yf'][:, i] 72 | 73 | Xtr = torch.from_numpy(Xtr.astype('float64')) 74 | Ytr = torch.from_numpy(Ytr.astype('float64')) 75 | Ttr = torch.from_numpy(Ttr.astype('float64')) 76 | 77 | in_estimations = estimator.estimate_individual_outcome(Xtr, Ttr) 78 | out_estimations = estimator.estimate_individual_outcome(Xte, Tte) 79 | 80 | group_proba = estimator.get_groups_proba(Xte) 81 | self.assertEqual(group_proba.shape, (data['TEST']['x'][:, :, 1].shape[0], comp)) 82 | 83 | group_assignment = estimator.get_groups(Xte) 84 | pd.testing.assert_series_equal(group_assignment, group_proba.idxmax(axis="columns")) 85 | 86 | group_effect = estimator.get_groups_effect(Xte, Tte) 87 | self.assertEqual(group_effect.shape, (comp,)) 88 | 89 | group_sizes = estimator.group_sizes(Xte) 90 | self.assertEqual(len(group_sizes.keys()), comp) 91 | 92 | return in_estimations, out_estimations 93 | 94 | def test_hemm_estimator(self): 95 | data = gen_montecarlo(1000, 24, 2) 96 | in_estimations, out_estimations = self.experiment(data, 1, 3, 'cont', 'CF', 1e-3, 10) 97 | self.assertEqual(in_estimations.shape, (data['TRAIN']['x'][:, :, 1].shape[0], 2)) 98 | self.assertEqual(out_estimations.shape, (data['TEST']['x'][:, :, 1].shape[0], 2)) 99 | -------------------------------------------------------------------------------- /causallib/contrib/tests/test_sklearn_scorer_wrapper.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import pandas as pd 4 | 5 | from sklearn.linear_model import LogisticRegression 6 | from sklearn.datasets import make_classification 7 | from sklearn.utils import Bunch 8 | from sklearn.metrics import get_scorer 9 | 10 | from causallib.estimation import IPW 11 | from causallib.model_selection import GridSearchCV 12 | 13 | from causallib.contrib.sklearn_scorer_wrapper import SKLearnScorerWrapper 14 | 15 | 16 | class TestSKLearnScorerWrapper(unittest.TestCase): 17 | @classmethod 18 | def setUpClass(cls): 19 | N = 500 20 | X, a = make_classification( 21 | n_samples=N, 22 | n_features=5, 23 | n_informative=5, 24 | n_redundant=0, 25 | random_state=42, 26 | ) 27 | X = pd.DataFrame(X) 28 | a = pd.Series(a) 29 | cls.data = Bunch(X=X, a=a, y=a) 30 | 31 | learner = LogisticRegression() 32 | ipw = IPW(learner) 33 | ipw.fit(X, a) 34 | # cls.learner = learner 35 | cls.estimator = ipw 36 | 37 | def test_agreement_with_sklearn(self): 38 | scorer_names = [ 39 | "accuracy", 40 | "average_precision", 41 | "neg_brier_score", 42 | "f1", 43 | "neg_log_loss", 44 | "precision", 45 | "recall", 46 | "roc_auc", 47 | ] 48 | for scorer_name in scorer_names: 49 | with self.subTest(f"Test scorer {scorer_name}"): 50 | scorer = get_scorer(scorer_name) 51 | score = scorer(self.estimator.learner, self.data.X, self.data.a) 52 | 53 | causallib_adapted_scorer = SKLearnScorerWrapper(scorer) 54 | causallib_score = causallib_adapted_scorer( 55 | self.estimator, self.data.X, self.data.a, self.data.y 56 | ) 57 | 58 | self.assertAlmostEqual(causallib_score, score) 
59 | 60 | def test_hyperparameter_search_model(self): 61 | scorer = SKLearnScorerWrapper(get_scorer("roc_auc")) 62 | param_grid = dict( 63 | clip_min=[0.2, 0.3], 64 | learner__C=[0.1, 1], 65 | ) 66 | model = GridSearchCV( 67 | self.estimator, 68 | param_grid=param_grid, 69 | scoring=scorer, 70 | cv=3, 71 | ) 72 | model.fit(self.data.X, self.data.a, self.data.y) 73 | 74 | score = scorer(model, self.data.X, self.data.a, self.data.y) 75 | self.assertGreaterEqual(score, model.best_score_) 76 | -------------------------------------------------------------------------------- /causallib/datasets/README.md: -------------------------------------------------------------------------------- 1 | 18 | 19 | # Module `causallib.datasets` 20 | 21 | This module contains example datasets, 22 | and a simulator to create a dataset. 23 | 24 | ## Datasets 25 | Currently two datasets are included: the ACIC 2016 challenge dataset and 26 | the National Health and Nutrition Examination Survey Data I Epidemiologic Follow-up Study (NHEFS) dataset. 27 | The NHEFS dataset was adapted from the data available at 28 | <https://www.hsph.harvard.edu/miguel-hernan/causal-inference-book/>. 29 | 30 | It can be loaded using 31 | ```Python 32 | from causallib.datasets.data_loader import load_nhefs 33 | data = load_nhefs() 34 | covariates = data.X 35 | treatment_assignment = data.a 36 | observed_outcome = data.y 37 | ``` 38 | 39 | This loads an object in which `data.X`, `data.a`, and `data.y` 40 | respectively hold the features for each individual, 41 | whether they stopped smoking, 42 | and their observed difference in weight between 1971 and 1982. 43 | 44 | ## Simulator 45 | This module also implements a simulator and some related functions 46 | (e.g., creating random graph topologies). 47 | 48 | CausalSimulator is based on an explicit graphical model connecting 49 | the feature data with several special nodes for treatment assignment, outcome, 50 | and censoring. 51 | CausalSimulator can generate the feature data randomly, 52 | or it can use a given dataset.
53 | The approach without input data is exhibited below, 54 | and the approach based on existing data is exemplified in the 55 | notebook [`CasualSimulator_example.ipynb`](CasualSimulator_example.ipynb). 56 | 57 | ### With no given data 58 | 59 | To initialize the simulator you need to state all the arguments 60 | regarding the graph's structure and variable-related information: 61 | 62 | ```Python 63 | import numpy as np 64 | from causallib.datasets import CausalSimulator 65 | topology = np.zeros((4, 4), dtype=bool)  # topology[i,j] iff node j is a parent of node i 66 | topology[1, 0] = topology[2, 0] = topology[2, 1] = topology[3, 1] = topology[3, 2] = True 67 | var_types = ["hidden", "covariate", "treatment", "outcome"] 68 | link_types = ['linear', 'linear', 'linear', 'linear'] 69 | prob_categories = [[0.25, 0.25, 0.5], None, [0.5, 0.5], None] 70 | treatment_methods = "gaussian" 71 | snr = 0.9 72 | treatment_importance = 0.8 73 | effect_sizes = None 74 | outcome_types = "binary" 75 | 76 | sim = CausalSimulator(topology=topology, prob_categories=prob_categories, 77 | link_types=link_types, snr=snr, var_types=var_types, 78 | treatment_importances=treatment_importance, 79 | outcome_types=outcome_types, 80 | treatment_methods=treatment_methods, 81 | effect_sizes=effect_sizes) 82 | X, prop, (y0, y1) = sim.generate_data(num_samples=100) 83 | ``` 84 | 85 | ```dot 86 | digraph CausalGraph { 87 | hidden -> covariate 88 | hidden -> treatment 89 | covariate -> treatment 90 | covariate -> outcome 91 | treatment -> outcome 92 | } 93 | ``` 94 | 95 | * This creates a graph `topology` of 4 variables, as depicted in the graph above: 96 | 1 hidden var (i.e. latent 97 | covariate), 1 regular covariate, 1 treatment variable and 1 outcome. 98 | * `link_types` determines that all variables will have linear 99 | dependencies on their predecessors. 100 | * `var_types`, together with `prob_categories`, define: 101 | * Variable 0 (hidden) is categorical with categories 102 | distributed by the multinomial distribution `[0.25, 0.25, 0.5]`. 103 | * Variable 1 (covariate) is continuous (since its 104 | corresponding prob_category is None). 105 | * Variable 2 (treatment) is categorical and treatment assignment is equal 106 | between the treatment groups. 107 | * `treatment_methods` means that treatment will be assigned by percentiles using a 108 | Gaussian distribution. 109 | * All variables have a signal-to-noise ratio of 110 | *signal* / (*signal*+*noise*) = 0.9. 111 | * `treatment_importance = 0.8` indicates that the outcome will be affected 80% 112 | by treatment and 20% by all other predecessors. 113 | * Effect size won't be manipulated into a specific desired value (since 114 | it is None). 115 | * Outcome will be binary. 116 | 117 | The data that is generated contains: 118 | * `X` contains all the data generated (including latent variables, 119 | treatment assignments and outcome) 120 | * `prop` contains the propensities 121 | * `y0` and `y1` hold the counterfactual outcomes without and with treatment, 122 | respectively. 123 | 124 | ### Additional examples 125 | A more elaborate example that includes using existing data 126 | is available in the example notebook. 127 | 128 | ## License 129 | Datasets are provided under [Community Data License Agreement (CDLA)](https://cdla.io/). 130 | The ACIC16 dataset is provided under the [CDLA-sharing](https://cdla.io/sharing-1-0/) license. 131 | The NHEFS dataset is provided under the [CDLA-permissive](https://cdla.io/permissive-1-0/) license.
132 | Please see the full corresponding license within each directory. 133 | 134 | We thank the authors for sharing their data within this package. 135 | -------------------------------------------------------------------------------- /causallib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_loader import load_nhefs, load_nhefs_survival, load_acic16 2 | from ..simulation.CausalSimulator3 import CausalSimulator3 as CausalSimulator, generate_random_topology 3 | -------------------------------------------------------------------------------- /causallib/datasets/data/acic_challenge_2016/README.md: -------------------------------------------------------------------------------- 1 | # ACIC 2016 challenge data 2 | 3 | This folder contains covariates, simulated treatment, and simulated response variables 4 | for the causal inference challenge in the 2016 Atlantic Causal Inference Conference. 5 | 6 | For each of 20 conditions, treatment and response data were simulated from real-world data 7 | corresponding to 4802 individuals and 58 covariates. 8 | 9 | 10 | #### Files: 11 | * x.csv - matrix of covariates; categorical variables are coded as A/B/C/..., 12 | binary variables as 0/1, and real numbers are left alone 13 | * zymu_##.csv - the twenty sets of treatment and response variables corresponding to various simulation settings; 14 | * treatment is column `z`, 15 | * two noisy potential outcomes: 16 | the observed response under control (`y0`) and under treatment (`y1`), 17 | * two expected potential outcomes (`mu0, mu1`) 18 | 19 | #### Cite: 20 | If used for academic purposes, please consider citing the competition organizers: 21 | ```bibtex 22 | @article{dorie2019automated, 23 | title={Automated versus do-it-yourself methods for causal inference: Lessons learned from a data analysis competition}, 24 | author={Dorie, Vincent and Hill, Jennifer and Shalit, Uri and Scott, Marc and Cervone, Dan}, 25 | journal={Statistical Science}, 26 | volume={34}, 27 | number={1}, 28 | pages={43--68}, 29 | year={2019}, 30 | publisher={Institute of Mathematical Statistics} 31 | } 32 | ``` 33 | -------------------------------------------------------------------------------- /causallib/datasets/data/nhefs/NHEFS_codebook.csv: -------------------------------------------------------------------------------- 1 | Variable name,Description 2 | active,"IN YOUR USUAL DAY, HOW ACTIVE ARE YOU? IN 1971, 0:very active, 1:moderately active, 2:inactive" 3 | age,AGE IN 1971 4 | alcoholfreq,"HOW OFTEN DO YOU DRINK? IN 1971 0: Almost every day, 1: 2-3 times/week, 2: 1-4 times/month, 3: < 12 times/year, 4: No alcohol last year, 5: Unknown" 5 | alcoholhowmuch,"WHEN YOU DRINK, HOW MUCH DO YOU DRINK? IN 1971" 6 | alcoholpy,"HAVE YOU HAD 1 DRINK PAST YEAR? IN 1971, 1:EVER, 0:NEVER; 2:MISSING" 7 | alcoholtype,"WHICH DO YOU MOST FREQUENTLY DRINK? IN 1971 1: BEER, 2: WINE, 3: LIQUOR, 4: OTHER/UNKNOWN" 8 | allergies,"USE ALLERGIES MEDICATION IN 1971, 1:EVER, 0:NEVER" 9 | asthma,"DX ASTHMA IN 1971, 1:EVER, 0:NEVER" 10 | bithcontrol,"BIRTH CONTROL PILLS PAST 6 MONTHS? 
IN 1971 1:YES, 0:NO, 2:MISSING" 11 | birthplace,CHECK STATE CODE - SECOND PAGE 12 | boweltrouble,"USE BOWEL TROUBLE MEDICATION IN 1971, 1:EVER, 0:NEVER, ; 2:MISSING" 13 | bronch,"DX CHRONIC BRONCHITIS/EMPHYSEMA IN 1971, 1:EVER, 0:NEVER" 14 | cholesterol,SERUM CHOLESTEROL (MG/100ML) IN 1971 15 | chroniccough,"DX CHRONIC COUGH IN 1971, 1:EVER, 0:NEVER" 16 | colitis,"DX COLITIS IN 1971, 1:EVER, 0:NEVER" 17 | dadth,DAY OF DEATH 18 | dbp,DIASTOLIC BLOOD PRESSURE IN 1982 19 | death,"DEATH BY 1992, 1:YES, 0:NO" 20 | diabetes,"DX DIABETES IN 1971, 1:EVER, 0:NEVER, 2:MISSING" 21 | education,"AMOUNT OF EDUCATION BY 1971: 1: 8TH GRADE OR LESS, 2: HS DROPOUT, 3: HS, 4:COLLEGE DROPOUT, 5: COLLEGE OR MORE" 22 | exercise,"IN RECREATION, HOW MUCH EXERCISE? IN 1971, 0:much exercise,1:moderate exercise,2:little or no exercise" 23 | hayfever,"DX HAY FEVER IN 1971, 1:EVER, 0:NEVER" 24 | hbp,"DX HIGH BLOOD PRESSURE IN 1971, 1:EVER, 0:NEVER, 2:MISSING" 25 | hbpmed,"USE HIGH BLOOD PRESSURE MEDICATION IN 1971, 1:EVER, 0:NEVER, ; 2:MISSING" 26 | headache,"USE HEADACHE MEDICATION IN 1971, 1:EVER, 0:NEVER" 27 | hepatitis,"DX HEPATITIS IN 1971, 1:EVER, 0:NEVER" 28 | hf,"DX HEART FAILURE IN 1971, 1:EVER, 0:NEVER" 29 | hightax82,"LIVING IN A HIGHLY TAXED STATE IN 1982, High taxed state of residence=1, 0 otherwise" 30 | ht,HEIGHT IN CENTIMETERS IN 1971 31 | income,"TOTAL FAMILY INCOME IN 1971 11:<$1000, 12: 1000-1999, 13: 2000-2999, 14: 3000-3999, 15: 4000-4999, 16: 5000-5999, 17: 6000-6999, 18: 7000-9999, 19: 10000-14999, 20: 15000-19999, 21: 20000-24999, 22: 25000+" 32 | infection,"USE INFECTION MEDICATION IN 1971, 1:EVER, 0:NEVER" 33 | lackpep,"USELACK OF PEP MEDICATION IN 1971, 1:EVER, 0:NEVER" 34 | marital,"MARITAL STATUS IN 1971 1: Under 17, 2: Married, 3: Widowed, 4: Never married, 5: Divorced, 6: Separated, 8: Unknown" 35 | modth,MONTH OF DEATH 36 | nerves,"USE NERVES MEDICATION IN 1971, 1:EVER, 0:NEVER" 37 | nervousbreak,"DX NERVOUS BREAKDOWN IN 1971, 1:EVER, 0:NEVER" 38 | otherpain,"USE OTHER PAINS MEDICATION IN 1971, 1:EVER, 0:NEVER" 39 | pepticulcer,"DX PEPTIC ULCER IN 1971, 1:EVER, 0:NEVER" 40 | pica,"DO YOU EAT DIRT OR CLAY, STARCH OR OTHER NON STANDARD FOOD? IN 1971 1:EVER, 0:NEVER; 2:MISSING" 41 | polio,"DX POLIO IN 1971, 1:EVER, 0:NEVER" 42 | pregnancies,TOTAL NUMBER OF PREGNANCIES? 
IN 1971 43 | price71,AVG TOBACCO PRICE IN STATE OF RESIDENCE 1971 (US$2008) 44 | price71_82,DIFFERENCE IN AVG TOBACCO PRICE IN STATE OF RESIDENCE 1971-1982 (US$2008) 45 | price82,AVG TOBACCO PRICE IN STATE OF RESIDENCE 1982 (US$2008) 46 | qsmk,"QUIT SMOKING BETWEEN 1ST QUESTIONNAIRE AND 1982, 1:YES, 0:NO" 47 | race,0: WHITE 1: BLACK OR OTHER IN 1971 48 | sbp,SYSTOLIC BLOOD PRESSURE IN 1982 49 | school,HIGHEST GRADE OF REGULAR SCHOOL EVER IN 1971 50 | seqn,UNIQUE PERSONAL IDENTIFIER 51 | sex,0: MALE 1: FEMALE 52 | smokeintensity,NUMBER OF CIGARETTES SMOKED PER DAY IN 1971 53 | smkintensity 82_71,INCREASE IN NUMBER OF CIGARETTES/DAY BETWEEN 1971 and 1982 54 | smokeyrs,YEARS OF SMOKING 55 | tax71,TOBACCO TAX IN STATE OF RESIDENCE 1971 (US$2008) 56 | tax71_82,DIFFERENCE IN TOBACCO TAX IN STATE OF RESIDENCE 1971-1982 (US$2008) 57 | tax82,TOBACCO TAX IN STATE OF RESIDENCE 1982 (US$2008) 58 | tb,"DX TUBERCULOSIS IN 1971, 1:EVER, 0:NEVER" 59 | tumor,"DX MALIGNANT TUMOR/GROWTH IN 1971, 1:EVER, 0:NEVER" 60 | weakheart,"USE WEAK HEART MEDICATION IN 1971, 1:EVER, 0:NEVER" 61 | wt71,WEIGHT IN KILOGRAMS IN 1971 62 | wt82,WEIGHT IN KILOGRAMS IN 1982 63 | wt82_71,WEIGHT CHANGE IN KILOGRAMS 64 | wtloss,"USE WEIGHT LOSS MEDICATION IN 1971, 1:EVER, 0:NEVER" 65 | yrdth,YEAR OF DEATH 66 | -------------------------------------------------------------------------------- /causallib/estimation/README.md: -------------------------------------------------------------------------------- 1 | # Module `causallib.estimation` 2 | This module allows estimating counterfactual outcomes and effect of treatment 3 | using a variety of common causal inference methods, as detailed below. 4 | Each of these methods can use an underlying machine learning model of choice. 5 | These models must have an interface similar to the one defined by 6 | scikit-learn. 7 | Namely, they must have `fit()` and `predict()` functions implemented, 8 | and `predict_proba()` implemented for models that predict categorical outcomes. 9 | 10 | Additional methods will be added incrementally. 11 | 12 | ## Available Methods 13 | The methods that are currently available are: 14 | 15 | 1. Inverse probability weighting (with minimal value cutoff): 16 | `causallib.estimation.IPW` 17 | 1. Standardization 18 | 1. As a single model depending on treatment: 19 | `causallib.estimation.Standardization` 20 | 1. Stratified by treatment value (similar to pooled regression): 21 | `causallib.estimation.StratifiedStandardization` 22 | 1. Doubly robust methods, as explained 23 | [here](https://www4.stat.ncsu.edu/~davidian/double.pdf) 24 | 1. Using the weighting as an additional feature: 25 | `causallib.estimation.PropensityFeatureStandardization` 26 | 1. Using the weighting for training the standardization model: 27 | `causallib.estimation.WeightedStandardization` 28 | 1. Using the original formula for doubly robust estimation: 29 | `causallib.estimation.AIPW` 30 | 31 | 32 | ### Example: Inverse Probability Weighting (IPW) 33 | An IPW model can be run, for example, using 34 | ```Python 35 | from sklearn.linear_model import LogisticRegression 36 | from causallib.estimation import IPW 37 | from causallib.datasets.data_loader import load_nhefs 38 | 39 | model = LogisticRegression() 40 | ipw = IPW(learner=model) 41 | data = load_nhefs() 42 | ipw.fit(data.X, data.a) 43 | ipw.estimate_population_outcome(data.X, data.a, data.y) 44 | ``` 45 | Note that `model` can be replaced by any machine learning model 46 | as explained above.
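The same pattern applies to any estimator in this module. For instance, a minimal sketch (assuming the NHEFS data loaded above, and that the generic `estimate_effect` helper is available on the fitted model) that swaps in gradient boosting and derives an effect estimate:

```Python
from sklearn.ensemble import GradientBoostingClassifier
from causallib.estimation import IPW
from causallib.datasets.data_loader import load_nhefs

data = load_nhefs()
# Any learner exposing fit()/predict_proba() can serve as the propensity model:
ipw = IPW(learner=GradientBoostingClassifier())
ipw.fit(data.X, data.a)
outcomes = ipw.estimate_population_outcome(data.X, data.a, data.y)
effect = ipw.estimate_effect(outcomes[1], outcomes[0])  # treated vs. untreated
```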
47 | -------------------------------------------------------------------------------- /causallib/estimation/__init__.py: -------------------------------------------------------------------------------- 1 | from .doubly_robust import AIPW, PropensityFeatureStandardization, WeightedStandardization 2 | from .ipw import IPW 3 | from .overlap_weights import OverlapWeights 4 | from .standardization import Standardization, StratifiedStandardization 5 | from .marginal_outcome import MarginalOutcomeEstimator 6 | from .matching import Matching, PropensityMatching 7 | from .rlearner import RLearner 8 | from .xlearner import XLearner 9 | from .tmle import TMLE 10 | 11 | -------------------------------------------------------------------------------- /causallib/estimation/marginal_outcome.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright 2019 IBM Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 15 | 16 | Created on Apr 25, 2018 17 | 18 | """ 19 | 20 | import pandas as pd 21 | 22 | from .base_weight import WeightEstimator 23 | from .base_estimator import PopulationOutcomeEstimator 24 | 25 | 26 | class MarginalOutcomeEstimator(WeightEstimator, PopulationOutcomeEstimator): 27 | """ 28 | A marginal outcome predictor. 29 | Assumes the sample is marginally exchangeable, and therefore does not correct (adjust, control) for covariates. 30 | Predicts the outcome/effect as if the sample came from a randomized control trial: $\\Pr[Y|A]$. 31 | """ 32 | 33 | def compute_weight_matrix(self, X, a, use_stabilized=None, **kwargs): 34 | # Another way to view this is that Uncorrected is basically an IPW-like with all individuals equally weighted. 35 | treatment_values = a.unique() 36 | treatment_values.sort()  # in-place sort; `ndarray.sort()` returns None, so avoid re-assignment 37 | weights = pd.DataFrame(data=1, index=a.index, columns=treatment_values) 38 | return weights 39 | 40 | def compute_weights(self, X, a, treatment_values=None, use_stabilized=None, **kwargs): 41 | # Another way to view this is that Uncorrected is basically an IPW-like with all individuals equally weighted. 42 | weights = pd.Series(data=1, index=a.index) 43 | return weights 44 | 45 | def fit(self, X=None, a=None, y=None): 46 | """ 47 | Dummy implementation to match the API. 48 | MarginalOutcomeEstimator acts as a WeightEstimator that weights each sample as 1. 49 | 50 | Args: 51 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 52 | a (pd.Series): Treatment assignment of size (num_subjects,). 53 | y (pd.Series): Observed outcome of size (num_subjects,). 54 | 55 | Returns: 56 | MarginalOutcomeEstimator: a fitted model. 57 | """ 58 | return self 59 | 60 | def estimate_population_outcome(self, X, a, y, w=None, treatment_values=None): 61 | """ 62 | Calculates potential population outcome for each treatment value. 63 | 64 | Args: 65 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 66 | a (pd.Series): Treatment assignment of size (num_subjects,).
67 | y (pd.Series): Observed outcome of size (num_subjects,). 68 | w (pd.Series | None): Individual (sample) weights calculated. Used to achieve an unbiased average outcome. 69 | If not provided, will be calculated on the data. 70 | treatment_values (Any): Desired treatment value/s to stratify upon before aggregating individual outcomes into 71 | a population outcome. 72 | If not supplied, calculates for all available treatment values. 73 | 74 | Returns: 75 | pd.Series[Any, float]: Series whose index holds treatment values, and whose values are numbers - the 76 | aggregated outcome for the strata of people whose assigned treatment is the key. 77 | """ 78 | if w is None: 79 | w = self.compute_weights(X, a) 80 | res = self._compute_stratified_weighted_aggregate(y, sample_weight=w, stratify_by=a, 81 | treatment_values=treatment_values) 82 | return res 83 | 84 | -------------------------------------------------------------------------------- /causallib/evaluation/README.md: -------------------------------------------------------------------------------- 1 | # Module `causallib.evaluation` 2 | 3 | This submodule allows evaluating the performance of the estimation models defined 4 | in `causallib.estimation`. 5 | 6 | The intended usage is to use `evaluate` from `causallib.evaluation` to generate `EvaluationResults` objects. 7 | If the cross-validation parameter `cv` is not supplied, a simple evaluation without cross-validation will 8 | be performed, and an object will be returned that can generate various plots, accessible by name (see the docs) 9 | or all at once via `plot_all()`. 10 | The object also includes the model's predictions, evaluated metrics, the fitted models as `models` 11 | and a copy of the original data as (`X`, `a`, and `y`). 12 | 13 | If the `cv` parameter is set to `"auto"`, `evaluate` generates a k-fold cross-validation with train and validation 14 | phases, refitting the model `k` times, with `k=5`. Other options are also supported 15 | for customizing cross-validation, see the docs. 16 | The `EvaluationResults` will also contain a list of train/test split indices used by cross-validation in `cv`. 17 | 18 | 19 | ## Example: Inverse probability weighting 20 | 21 | An IPW method with logistic regression can be evaluated 22 | in cross-validation using 23 | 24 | ```Python 25 | from sklearn.linear_model import LogisticRegression 26 | from causallib.estimation import IPW 27 | from causallib.datasets.data_loader import load_nhefs 28 | from causallib.evaluation import evaluate 29 | 30 | data = load_nhefs() 31 | 32 | model = LogisticRegression() 33 | ipw = IPW(learner=model) 34 | ipw.fit(data.X, data.a) 35 | res = evaluate(ipw, data.X, data.a, data.y, cv="auto") 36 | 37 | res.plot_all() 38 | ``` 39 | 40 | This will train the models and create evaluation plots 41 | showing the performance on both the training and validation data. 42 | 43 | ```python 44 | print(res.all_plot_names) 45 | # {'weight_distribution', 'pr_curve', 'covariate_balance_love', 'roc_curve', 'calibration', 'covariate_balance_slope'} 46 | res.plot_covariate_balance(kind="love", phase="valid") 47 | res.plot_weight_distribution() 48 | res.plot_roc_curve() 49 | res.plot_calibration_curve() 50 | ``` 51 | 52 | ## Submodule structure 53 | 54 | *This section is intended for future contributors and those seeking to customize the evaluation logic.* 55 | 56 | The `evaluate` function is defined in `evaluator.py`. To generate predictions 57 | it instantiates a `Predictor` object as defined in `predictor.py`.
This handles 58 | refitting and generating the necessary predictions for the different models. 59 | The predictions objects are defined in `predictions.py`. 60 | Metrics are defined in `metrics.py`. These are simple functions and do not depend 61 | on the structure of the objects. 62 | The metrics are applied to the individual predictions via the scoring functions 63 | defined in `scoring.py`. 64 | The results of the predictors and scorers across multiple phases and folds are 65 | combined in the `EvaluationResults` object which is defined in `results.py`. 66 | 67 | ### evaluation.plots submodule structure 68 | 69 | In order to generate the correct plots from the `EvaluationResults` objects, we 70 | need `PlotDataExtractor` objects. The responsibility of these objects is to extract 71 | the correct data for a given plot from `EvaluationResults`, and they are defined 72 | in `plots/data_extractors.py`. 73 | Enabling plotting as member functions for `EvaluationResults` objects is accomplished 74 | using the plotter mixins, which are defined in `plots/mixins.py`. 75 | When an `EvaluationResults` object is produced by `evaluate`, the `EvaluationResults.make` 76 | factory ensures that it has the correct extractors and plotting mixins. 77 | 78 | Finally, `plots/curve_data_makers.py` contains a number of methods for aggregating and 79 | combining data to produce curves for ROC, PR and calibration plots, 80 | and `plots/plots.py` contains the individual plotting functions. 81 | 82 | ## How to add a new plot 83 | 84 | If there is a model evaluation plot that you would like to add to the codebase, 85 | you must first determine for which models it would be relevant. For example, 86 | a confusion matrix makes sense for a classification task but not for continuous 87 | outcome prediction or sample weight calculation. 88 | 89 | Currently, the types of models are: 90 | 91 | * Individual outcome predictions (continuous outcome) 92 | * Individual outcome predictions (binary outcome) 93 | * Sample weight predictions 94 | * Propensity predictions 95 | 96 | Propensity predictions combine binary individual outcome predictions (because 97 | "is treated" is a binary feature) with sample weight predictions. Something like 98 | a confusion matrix would make sense for binary outcome predictions and for propensity 99 | predictions, but not for the other categories. In that sense it would behave like 100 | the ROC and PR curves, which are already implemented. 101 | 102 | Assuming you want to add a new plot, you would add the basic plotting 103 | function to `plots/plots.py`. Then, in `plots/data_extractors.py`, you would add a case to the relevant extractors' 104 | `get_data_for_plot` methods to extract the data for the plot based on its name. You would also add the name as an available plot in the relevant 105 | `frozenset` and in the `lookup_name` function, both in `plots/plots.py`. At this point, the plot should be drawn automatically when you run `plot_all` on the relevant `EvaluationResults` object. 106 | To expose the plot as a member `plot_my_new_plot`, you must add it to the correct mixin in 107 | `plots/mixins.py`.
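To make the recipe concrete, here is a rough sketch of the first step, the basic plotting function. The name `plot_my_new_plot` and the data it receives are hypothetical placeholders; the extractor and mixin hooks described above would still need to be wired up to match:

```python
import matplotlib.pyplot as plt

# Hypothetical plot name; it would also be registered in the relevant
# frozenset and in `lookup_name`, both in plots/plots.py:
MY_NEW_PLOT = "my_new_plot"

def plot_my_new_plot(y_true, y_pred, ax=None):
    """Sketch of a basic plotting function in the style of plots/plots.py."""
    ax = ax or plt.gca()
    ax.scatter(y_true, y_pred, alpha=0.5)
    ax.set_xlabel("observed")
    ax.set_ylabel("predicted")
    ax.set_title(MY_NEW_PLOT)
    return ax
```

The corresponding `get_data_for_plot` case would return the `(y_true, y_pred)` pair for this name, and the mixin method would simply look up the data and forward it to this function.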
108 | -------------------------------------------------------------------------------- /causallib/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | """Objects and methods to evaluate accuracy of causal models.""" 2 | from .evaluator import evaluate, evaluate_bootstrap 3 | 4 | __all__ = ["evaluate", "evaluate_bootstrap"] 5 | -------------------------------------------------------------------------------- /causallib/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | """Apply machine learning metrics to causal models for evaluation.""" 2 | import warnings 3 | 4 | import numpy as np 5 | import pandas as pd 6 | from sklearn import metrics 7 | 8 | 9 | NUMERICAL_CLASSIFICATION_METRICS = { 10 | "accuracy": metrics.accuracy_score, 11 | "precision": metrics.precision_score, 12 | "recall": metrics.recall_score, 13 | "f1": metrics.f1_score, 14 | "roc_auc": metrics.roc_auc_score, 15 | "avg_precision": metrics.average_precision_score, 16 | "hinge": metrics.hinge_loss, 17 | "matthews": metrics.matthews_corrcoef, 18 | "0_1": metrics.zero_one_loss, 19 | "brier": metrics.brier_score_loss, 20 | } 21 | NONNUMERICAL_CLASSIFICATION_METRICS = { 22 | "confusion_matrix": metrics.confusion_matrix, 23 | "roc_curve": metrics.roc_curve, 24 | "pr_curve": metrics.precision_recall_curve, 25 | } 26 | CLASSIFICATION_METRICS = { 27 | **NUMERICAL_CLASSIFICATION_METRICS, 28 | **NONNUMERICAL_CLASSIFICATION_METRICS, 29 | } 30 | 31 | REGRESSION_METRICS = { 32 | "expvar": metrics.explained_variance_score, 33 | "mae": metrics.mean_absolute_error, 34 | "mse": metrics.mean_squared_error, 35 | # "msle": metrics.mean_squared_log_error,  # uncomment if predictions are all positive 36 | # Allow mdae to receive a sample_weight argument but ignore it. This unifies the interface: 37 | "mdae": lambda y_true, y_pred, **kwargs: metrics.median_absolute_error( 38 | y_true, y_pred 39 | ), 40 | "r2": metrics.r2_score, 41 | } 42 | 43 | 44 | def get_default_binary_metrics(only_numeric_metric=False): 45 | """Get default metrics for evaluating binary models. 46 | 47 | Args: 48 | only_numeric_metric (bool): Whether to include only numerical metrics 49 | (non-numerical metrics, such as roc_curve, return 50 | vectors rather than scalars). 51 | Returns: 52 | dict [str, callable]: metrics dict with key: metric's name, value: callable that receives 53 | true labels, prediction and sample_weights (the latter is allowed to be ignored). 54 | """ 55 | if only_numeric_metric: 56 | return NUMERICAL_CLASSIFICATION_METRICS 57 | 58 | return CLASSIFICATION_METRICS 59 | 60 | 61 | def get_default_regression_metrics(): 62 | """Get default metrics for evaluating continuous prediction models. 63 | 64 | Returns: 65 | dict [str, callable]: metrics dict with key: metric's name, value: callable that receives 66 | true labels, prediction and sample_weights (the latter is allowed to be ignored). 67 | """ 68 | return REGRESSION_METRICS 69 | 70 | 71 | def evaluate_metrics( 72 | metrics_to_evaluate, 73 | y_true, 74 | y_pred=None, 75 | y_pred_proba=None, 76 | sample_weight=None, 77 | ): 78 | """Evaluates the metrics against the supplied predictions and labels. 79 | 80 | Note that some metrics operate on proba predictions (`y_pred_proba`) and others on 81 | direct predictions.
The function will select the correct input based on the name of the metric, 82 | if it knows about the metric. 83 | Otherwise it defaults to using the direct prediction (`y_pred`). 84 | 85 | Args: 86 | metrics_to_evaluate (dict): key: metric's name, value: callable that receives 87 | true labels, prediction and sample_weights (the latter is allowed to be ignored). 88 | y_true (pd.Series): True labels. 89 | y_pred_proba (pd.Series): continuous output of predictor, 90 | as in `predict_proba` or `decision_function`. 91 | y_pred (pd.Series): label (i.e., categories, decisions) predictions. 92 | sample_weight (pd.Series | None): weight of each sample. 93 | 94 | Returns: 95 | pd.Series: name of metric as index and the evaluated score as value. 96 | """ 97 | evaluated_metrics = {} 98 | for metric_name, metric_func in metrics_to_evaluate.items(): 99 | prediction = y_pred_proba if _metric_needs_proba(metric_name) else y_pred 100 | if prediction is None: 101 | continue 102 | 103 | try: 104 | metric_value = metric_func(y_true, prediction, sample_weight=sample_weight) 105 | except ValueError as v:  # e.g., if y_true has a single value 106 | warnings.warn(f"metric {metric_name} could not be evaluated") 107 | warnings.warn(str(v)) 108 | metric_value = np.nan 109 | evaluated_metrics[metric_name] = metric_value 110 | 111 | all_scalars = all(np.isscalar(v) for v in evaluated_metrics.values()) 112 | dtype = float if all_scalars else np.dtype(object) 113 | 114 | 115 | return pd.Series(evaluated_metrics, dtype=dtype) 116 | 117 | 118 | def _metric_needs_proba(metric_name): 119 | use_proba = metric_name in { 120 | "hinge", 121 | "brier", 122 | "roc_curve", 123 | "roc_auc", 124 | "pr_curve", 125 | "avg_precision", 126 | } 127 | 128 | return use_proba 129 | -------------------------------------------------------------------------------- /causallib/evaluation/plots/__init__.py: -------------------------------------------------------------------------------- 1 | """Causal model evaluation plotting functions.""" 2 | -------------------------------------------------------------------------------- /causallib/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .propensity_metrics import weighted_roc_auc_error, expected_roc_auc_error 2 | from .propensity_metrics import weighted_roc_curve_error, expected_roc_curve_error 3 | from .propensity_metrics import ici_error 4 | from .weight_metrics import covariate_balancing_error 5 | from .weight_metrics import covariate_imbalance_count_error 6 | from .outcome_metrics import balanced_residuals_error 7 | 8 | from .scorers import get_scorer, get_scorer_names 9 | -------------------------------------------------------------------------------- /causallib/metrics/outcome_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics to assess performance of direct-outcome (counterfactual prediction) models. 3 | Functions named ``*_error`` return a scalar value to minimize: 4 | the lower, the better. 5 | 6 | Metrics' interface doesn't strictly follow sklearn's metrics interface. 7 | The outcome prediction is expected to be a full potential outcome prediction 8 | (one column for each treatment value), and the true treatment assignment is also expected.
9 | """ 10 | 11 | import numpy as np 12 | import pandas as pd 13 | 14 | from causallib.utils.stat_utils import calc_weighted_standardized_mean_differences 15 | from causallib.utils.stat_utils import robust_lookup 16 | 17 | 18 | def abs_standardized_mean_difference(a, b, **kwargs): 19 | asmd = calc_weighted_standardized_mean_differences( 20 | a, b, 21 | wx=np.ones_like(a), 22 | wy=np.ones_like(b), 23 | ) 24 | asmd = np.abs(asmd) 25 | return asmd 26 | 27 | 28 | def _get_observed_outcome_prediction(potential_outcomes, a): 29 | # TODO: duplicated throughout causallib. move to utils or standardization module 30 | is_predict_proba_classification_result = isinstance(potential_outcomes.columns, pd.MultiIndex) 31 | if is_predict_proba_classification_result: 32 | # Classification `outcome_model` with `predict_proba=True` returns 33 | # a MultiIndex treatment-values (`a`) over outcome-values (`y`) 34 | # Extract the prediction for the maximal outcome class 35 | # (probably class `1` in binary classification): 36 | outcome_values = potential_outcomes.columns.get_level_values(level=-1) 37 | potential_outcomes = potential_outcomes.xs( 38 | outcome_values.max(), axis="columns", level=-1, drop_level=True, 39 | ) 40 | potential_outcomes = robust_lookup(potential_outcomes, a) 41 | return potential_outcomes 42 | 43 | 44 | def balanced_residuals_error( 45 | y_true, y_pred, a_true, 46 | distance_metric=abs_standardized_mean_difference, 47 | distance_metric_kwargs=None, 48 | **kwargs, 49 | ): 50 | """Computes how different is the residuals distribution of the control group 51 | from that of the treatment group. 52 | Residuals are based on the observed (factual) outcome prediction. 53 | 54 | Can plug in any uni-variate two-sample test function. 55 | 56 | Args: 57 | y_true (pd.Series): The true observed outcomes. 58 | y_pred (pd.DataFrame): Potential outcome prediction, the output of `estimate_individual_outcome()`. 59 | A matrix of (n_samples, n_treatments), with column names as the treatment values. 60 | a_true (pd.Series): A vector of observed treatment assignment. 61 | distance_metric (callable): A two sample test function. 62 | First argument is the residual values of the treatment group, second is for the control group. 63 | Defaults to absolute standardized mean difference. 64 | distance_metric_kwargs (dict): Additional keyword arguments needed for the `distance_metric` function. 65 | 66 | Returns: 67 | score (float): 68 | """ 69 | if distance_metric_kwargs is None: 70 | distance_metric_kwargs = {} 71 | 72 | y_pred = _get_observed_outcome_prediction(y_pred, a_true) 73 | 74 | residuals = y_true - y_pred 75 | treatment_mask = a_true == a_true.max() 76 | control_mask = a_true == a_true.min() 77 | 78 | score = distance_metric( 79 | residuals[treatment_mask], 80 | residuals[control_mask], 81 | **distance_metric_kwargs, 82 | ) 83 | return score 84 | 85 | -------------------------------------------------------------------------------- /causallib/metrics/scorers.py: -------------------------------------------------------------------------------- 1 | """ 2 | This submodule implements a scorer interface for the 3 | various causal metrics implemented. 4 | These scorers can then be incorporated into model selection objects 5 | to select the models (or hyperparameters) that optimizes these scores. 6 | For example, as a `scoring` parameter to a `causallib.model_selection.GridSearchCV 7 | object. 
8 | 9 | The signature of the call is ``(estimator, X, a, y)`` where ``estimator`` 10 | is the causal model to be evaluated, ``X`` are the covariates, 11 | `a` is the treatment assignment, and ``y`` is the ground truth target 12 | """ 13 | # from sklearn.metrics._scorer import _BaseScorer 14 | import abc 15 | 16 | from . import ( 17 | propensity_metrics, 18 | weight_metrics, 19 | outcome_metrics, 20 | ) 21 | 22 | 23 | class _BaseCausalScorer: 24 | def __init__(self, score_func, sign, **kwargs): 25 | self._score_func = score_func 26 | self._sign = sign 27 | self._kwargs = kwargs 28 | 29 | def __call__(self, estimator, X, a_true, y_true, sample_weight=None, **kwargs): 30 | return self._sign * self._score( 31 | estimator, 32 | X, 33 | a_true, 34 | y_true, 35 | sample_weight=sample_weight, 36 | **{**self._kwargs, **kwargs}, 37 | ) 38 | 39 | @abc.abstractmethod 40 | def _score(self, estimator, X, a, y, sample_weight=None, **kwargs): 41 | raise NotImplementedError 42 | 43 | 44 | class PropensityScorerBase(_BaseCausalScorer): 45 | def _score(self, estimator, X, a, y, sample_weight=None, **kwargs): 46 | propensities = estimator.compute_propensity(X, a) 47 | weights = estimator.compute_weights(X, a) 48 | score = self._score_func( 49 | a, propensities, sample_weight=weights, 50 | **kwargs 51 | ) 52 | return score 53 | 54 | 55 | weighted_roc_auc_error_scorer = PropensityScorerBase( 56 | propensity_metrics.weighted_roc_auc_error, -1, 57 | ) 58 | weighted_roc_curve_error_scorer = PropensityScorerBase( 59 | propensity_metrics.weighted_roc_curve_error, -1, 60 | ) 61 | expected_roc_auc_error_scorer = PropensityScorerBase( 62 | propensity_metrics.expected_roc_auc_error, -1, 63 | ) 64 | expected_roc_curve_error_scorer = PropensityScorerBase( 65 | propensity_metrics.expected_roc_curve_error, -1 66 | ) 67 | ici_error_scorer = PropensityScorerBase( 68 | propensity_metrics.ici_error, -1, 69 | ) 70 | 71 | _PROPENSITY_SCORERS = dict( 72 | weighted_roc_auc_error=weighted_roc_auc_error_scorer, 73 | weighted_roc_curve_error=weighted_roc_curve_error_scorer, 74 | expected_roc_auc_error=expected_roc_auc_error_scorer, 75 | expected_roc_curve_error=expected_roc_curve_error_scorer, 76 | ici_error=ici_error_scorer, 77 | 78 | ) 79 | 80 | 81 | class WeightScorerBase(_BaseCausalScorer): 82 | def _score(self, estimator, X, a, y, sample_weight=None, **kwargs): 83 | weights = estimator.compute_weights(X, a) 84 | score = self._score_func( 85 | X, a, sample_weight=weights, 86 | **kwargs 87 | ) 88 | return score 89 | 90 | 91 | covariate_balancing_error_scorer = WeightScorerBase( 92 | weight_metrics.covariate_balancing_error, -1, 93 | ) 94 | 95 | covariate_imbalance_count_error_scorer = WeightScorerBase( 96 | weight_metrics.covariate_imbalance_count_error, -1, 97 | ) 98 | 99 | _WEIGHT_SCORERS = dict( 100 | covariate_balancing_error=covariate_balancing_error_scorer, 101 | covariate_imbalance_count_error=covariate_imbalance_count_error_scorer, 102 | ) 103 | 104 | 105 | class OutcomeScorerBase(_BaseCausalScorer): 106 | def _score(self, estimator, X, a, y, sample_weight=None, **kwargs): 107 | potential_outcomes_pred = estimator.estimate_individual_outcome(X, a) 108 | score = self._score_func( 109 | y, potential_outcomes_pred, a, 110 | **kwargs 111 | # Is this a good generic API call to the outcome metrics? 
112 | ) 113 | return score 114 | 115 | 116 | balanced_residuals_error_scorer = OutcomeScorerBase( 117 | outcome_metrics.balanced_residuals_error, -1, 118 | ) 119 | 120 | _OUTCOME_SCORERS = dict( 121 | balanced_residuals_error=balanced_residuals_error_scorer, 122 | ) 123 | 124 | _SCORERS = { 125 | **_PROPENSITY_SCORERS, 126 | **_WEIGHT_SCORERS, 127 | **_OUTCOME_SCORERS 128 | } 129 | 130 | 131 | def get_scorer(scoring): 132 | """Gets a scorer callable from string. 133 | See `causallib.metrics.get_scorer_names` to retrieve available score names. 134 | """ 135 | if callable(scoring): 136 | return scoring 137 | try: 138 | return _SCORERS[scoring]  # direct lookup so that missing names raise KeyError 139 | except KeyError: 140 | raise ValueError( 141 | f"Scoring name {scoring} is not a valid scoring name. " 142 | f"Use `causallib.metrics.get_scorer_names` to get all valid names." 143 | ) 144 | 145 | 146 | def get_scorer_names(score_type="all"): 147 | """Get the names of all available scorers. 148 | These names can be passed to `causallib.metrics.get_scorer` to retrieve a scorer object. 149 | 150 | Args: 151 | score_type (str): any of {"all", "propensity", "weight", "outcome"}. 152 | Returns only scorers relevant to the `score_type` type of model. 153 | 154 | Returns: 155 | list: sorted names of the scorers relevant to the requested `score_type`. 156 | """ 157 | scores_types_map = { 158 | "all": _SCORERS, 159 | "propensity": _PROPENSITY_SCORERS, 160 | "weight": _WEIGHT_SCORERS, 161 | "outcome": _OUTCOME_SCORERS, 162 | } 163 | try: 164 | return sorted(scores_types_map[score_type]) 165 | except KeyError: 166 | raise ValueError( 167 | f"`score_type` {score_type} is not valid. " 168 | f"Please use one of {scores_types_map.keys()}." 169 | ) 170 | 171 | -------------------------------------------------------------------------------- /causallib/metrics/weight_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics to assess performance of weight models. 3 | Functions named ``*_error`` return a scalar value to minimize: 4 | the lower, the better. 5 | 6 | Metrics' interface doesn't strictly follow sklearn's metrics interface. 7 | """ 8 | 9 | import pandas as pd 10 | import numpy as np 11 | 12 | from ..utils.stat_utils import ( 13 | calc_weighted_ks2samp, 14 | calc_weighted_standardized_mean_differences, 15 | ) 16 | 17 | DISTRIBUTION_DISTANCE_METRICS = { 18 | "smd": lambda x, y, wx, wy: calc_weighted_standardized_mean_differences( 19 | x, y, wx, wy 20 | ), 21 | "abs_smd": lambda x, y, wx, wy: abs( 22 | calc_weighted_standardized_mean_differences(x, y, wx, wy) 23 | ), 24 | "ks": lambda x, y, wx, wy: calc_weighted_ks2samp(x, y, wx, wy), 25 | } 26 | 27 | 28 | def calculate_covariate_balance(X, a, w, metric="abs_smd"): 29 | """Calculate covariate balance table ("table 1") 30 | 31 | Args: 32 | X (pd.DataFrame): Covariates. 33 | a (pd.Series): Group assignment of each sample. 34 | w (pd.Series): sample weights for balancing between groups in `a`. 35 | metric (str | callable): Either a key from DISTRIBUTION_DISTANCE_METRICS or a metric with 36 | the signature weighted_distance(x, y, wx, wy) calculating distance between the weighted 37 | sample x and weighted sample y (weights by wx and wy respectively). 38 | 39 | Returns: 40 | pd.DataFrame: the index is the covariate names (columns) from X, and the columns are 41 | "weighted" / "unweighted" results of applying `metric` on each covariate 42 | to compare the two groups.
43 | """ 44 | treatment_values = np.sort(np.unique(a)) 45 | results = {} 46 | for treatment_value in treatment_values: 47 | distribution_distance_of_cur_treatment = pd.DataFrame( 48 | index=X.columns, columns=["weighted", "unweighted"], dtype=float 49 | ) 50 | for col_name, col_data in X.items(): 51 | weighted_distance = calculate_distribution_distance_for_single_feature( 52 | col_data, w, a, treatment_value, metric 53 | ) 54 | unweighted_distance = calculate_distribution_distance_for_single_feature( 55 | col_data, pd.Series(1, index=w.index), a, treatment_value, metric 56 | ) 57 | distribution_distance_of_cur_treatment.loc[ 58 | col_name, ["weighted", "unweighted"] 59 | ] = [weighted_distance, unweighted_distance] 60 | results[treatment_value] = distribution_distance_of_cur_treatment 61 | results = pd.concat( 62 | results, axis="columns", names=[a.name or "a", metric] 63 | ) # type: pd.DataFrame 64 | results.index.name = "covariate" 65 | if len(treatment_values) == 2: 66 | # If there are only two treatments, the results for both are identical. 67 | # Therefore, we can get rid of one of them. 68 | # We keep the results for the higher valued treatment group (assumed treated, typically 1): 69 | results = results.xs(treatment_values.max(), axis="columns", level=0) 70 | # TODO: is there a neat expansion for multi-treatment case? 71 | # maybe not current_treatment vs. the rest. 72 | return results 73 | 74 | 75 | def calculate_distribution_distance_for_single_feature( 76 | x, w, a, group_level, metric="abs_smd" 77 | ): 78 | """ 79 | 80 | Args: 81 | x (pd.Series): A single feature to check balancing. 82 | a (pd.Series): Group assignment of each sample. 83 | w (pd.Series): sample weights for balancing between groups in `a`. 84 | group_level: Value from `a` in order to divide the sample into one vs. rest. 85 | metric (str | callable): Either a key from DISTRIBUTION_DISTANCE_METRICS or a metric with 86 | the signature weighted_distance(x, y, wx, wy) calculating distance between the weighted 87 | sample x and weighted sample y (weights by wx and wy respectively). 88 | 89 | Returns: 90 | float: weighted distance between the samples assigned to `group_level` 91 | and the rest of the samples. 92 | """ 93 | if not callable(metric): 94 | metric = DISTRIBUTION_DISTANCE_METRICS[metric] 95 | cur_treated_mask = a == group_level 96 | x_treated = x.loc[cur_treated_mask] 97 | w_treated = w.loc[cur_treated_mask] 98 | x_untreated = x.loc[~cur_treated_mask] 99 | w_untreated = w.loc[~cur_treated_mask] 100 | distribution_distance = metric(x_treated, x_untreated, w_treated, w_untreated) 101 | return distribution_distance 102 | 103 | 104 | def covariate_balancing_error(X, a, sample_weight, agg=max, **kwargs): 105 | """Computes the weighted (i.e. balanced) absolute standardized mean difference 106 | of every covariate in X. 107 | 108 | Args: 109 | X (pd.DataFrame): Covariate matrix. 110 | a (pd.Series): Treatment assignment vector. 111 | sample_weight (pd.Series): Weights balancing between the treatment groups. 112 | agg (callable): A function to aggregate a vector of absolute differences 113 | between the curves' points. Default is max. 
114 | 115 | Returns: 116 | score (float): the aggregated weighted absolute standardized mean difference. 117 | """ 118 | asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd") 119 | weighted_asmds = asmds["weighted"] 120 | score = agg(weighted_asmds) 121 | return score 122 | 123 | 124 | def covariate_imbalance_count_error( 125 | X, a, sample_weight, threshold=0.1, fraction=True, **kwargs 126 | ) -> float: 127 | asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd") 128 | weighted_asmds = asmds["weighted"] 129 | is_violating = weighted_asmds > threshold 130 | score = sum(is_violating) 131 | if fraction: 132 | score /= is_violating.shape[0] 133 | return score 134 | -------------------------------------------------------------------------------- /causallib/model_selection/__init__.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import GridSearchCV as skGridSearchCV 2 | from sklearn.model_selection import RandomizedSearchCV as skRandomizedSearchCV 3 | 4 | from .search import causalize_searcher 5 | from .split import TreatmentOutcomeStratifiedKFold 6 | from .split import TreatmentStratifiedKFold 7 | 8 | GridSearchCV = causalize_searcher(skGridSearchCV) 9 | RandomizedSearchCV = causalize_searcher(skRandomizedSearchCV) 10 | -------------------------------------------------------------------------------- /causallib/model_selection/split.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | 3 | import numpy as np 4 | import pandas as pd 5 | 6 | from sklearn.model_selection import StratifiedKFold 7 | from sklearn.utils.multiclass import type_of_target 8 | 9 | 10 | class TreatmentOutcomeStratifiedKFold(StratifiedKFold): 11 | """Creates stratified folds based on both the treatment assignment 12 | and the outcome. 13 | That is, every fold preserves both the treatment prevalence and 14 | outcome prevalence within each treatment. 15 | 16 | The outcome must be discrete (binary or multiclass); otherwise a ValueError is raised. 17 | 18 | """ 19 | __doc__ += StratifiedKFold.__doc__ 20 | 21 | @staticmethod 22 | def _combine_treatment_outcome_labels(a, y): 23 | """combines every `a` x `y` values as a unique label""" 24 | # Assuming n_a < 10, n_y < 10: labels = a*10+y. Implements a generic version. 25 | a_unique = np.unique(a) 26 | y_unique = np.unique(y) 27 | combinations = product(a_unique, y_unique) 28 | combinations_mapping = {c: i for i, c in enumerate(combinations)} 29 | combined_labels = [combinations_mapping[(ai, yi)] for ai, yi in zip(a, y)] 30 | combined_labels = pd.Series(combined_labels, index=a.index) 31 | return combined_labels 32 | 33 | def _get_labels_for_split(self, a, y): 34 | target_type = type_of_target(y) 35 | if target_type not in ("binary", "multiclass"): 36 | # `y` is incompatible with stratification 37 | raise ValueError( 38 | f"Outcome type should either be 'binary' or 'multiclass'. " 39 | f"Received {target_type} instead."
40 | ) 41 | labels = self._combine_treatment_outcome_labels(a, y) 42 | return labels 43 | 44 | def split(self, joinedXa, y, groups=None): 45 | X = joinedXa.iloc[:, :-1] 46 | a = joinedXa.iloc[:, -1] 47 | splits = self._split(X, a, y, groups=groups) 48 | # labels = self._get_labels_for_split(a, y) 49 | # splits = super().split(X, labels, groups=groups) 50 | return splits 51 | 52 | def _split(self, X, a, y, groups=None): 53 | """A causallib-like `X, a, y` interface for split""" 54 | labels = self._get_labels_for_split(a, y) 55 | splits = super().split(X, labels, groups=groups) 56 | return splits 57 | 58 | 59 | class TreatmentStratifiedKFold(StratifiedKFold): 60 | """Creates stratified folds based on the treatment assignment. 61 | That is, every fold preserves the treatment prevalence. 62 | """ 63 | __doc__ += StratifiedKFold.__doc__ 64 | 65 | def split(self, joinedXa, y=None, groups=None): 66 | X = joinedXa.iloc[:, :-1] 67 | a = joinedXa.iloc[:, -1] 68 | splits = self._split(X, a, y, groups=groups) 69 | return splits 70 | 71 | def _split(self, X, a, y=None, groups=None): 72 | """A causallib-like `X, a, y` interface for split""" 73 | splits = super().split(X, a, groups=groups) 74 | return splits 75 | -------------------------------------------------------------------------------- /causallib/positivity/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_positivity import BasePositivity 2 | from .trimming import Trimming 3 | from .matching import Matching 4 | from .univariate_bbox import UnivariateBoundingBox 5 | __version__ = "0.0.1" -------------------------------------------------------------------------------- /causallib/positivity/base_positivity.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright 2021 IBM Corp. 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | http://www.apache.org/licenses/LICENSE-2.0 7 | Unless required by applicable law or agreed to in writing, software 8 | distributed under the License is distributed on an "AS IS" BASIS, 9 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 10 | See the License for the specific language governing permissions and 11 | limitations under the License. 12 | Created on March 2, 2021 13 | """ 14 | from __future__ import annotations 15 | from abc import ABC, abstractmethod 16 | import pandas as pd 17 | from typing import Tuple 18 | from sklearn.base import BaseEstimator 19 | 20 | 21 | class BasePositivity(ABC, BaseEstimator): 22 | 23 | @abstractmethod 24 | def fit(self, 25 | X: pd.DataFrame, a: pd.Series) -> BasePositivity: 26 | """Fit positivity checker. 27 | 28 | Args: 29 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 30 | a (pd.Series): Treatment assignment of size (num_subjects,). 31 | """ 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def predict(self, X: pd.DataFrame, a: pd.Series) -> pd.Series: 36 | """Predict whether a sample is in the overlap of treatments. 37 | 38 | Args: 39 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features) 40 | a (pd.Series): Treatment assignment of size (num_subjects,). 
41 | 42 | Returns: 43 | pd.Series: a Series of length `X.shape[0]` with the same index as 44 | `X` and only boolean values 45 | """ 46 | raise NotImplementedError 47 | 48 | def transform(self, 49 | X: pd.DataFrame, a: pd.Series, *args: pd.Series 50 | ) -> Tuple[pd.DataFrame, pd.Series, pd.Series]: 51 | """Transform the input data to remove positivity violations. 52 | 53 | Args: 54 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 55 | a (pd.Series): Treatment assignment of size (num_subjects,). 56 | *args (pd.Series): Zero or more pd.Series objects corresponding to 57 | outcomes. Each argument must be indexed the same as the other 58 | arguments and have size (num_subjects,). 59 | 60 | Returns: 61 | Tuple[pd.DataFrame, pd.Series, pd.Series]: Subsets of `X`, `a` and 62 | the output series objects of `args` corresponding to the samples 63 | which do not violate the positivity assumption. 64 | """ 65 | indices_to_keep = self.predict(X, a) 66 | return_list = [X.loc[indices_to_keep], a.loc[indices_to_keep]] 67 | for output in args: 68 | return_list.append(output.loc[indices_to_keep]) 69 | return return_list 70 | 71 | def fit_predict(self, X: pd.DataFrame, a: pd.Series) -> pd.Series: 72 | """Fit positivity checker and predict overlap membership. 73 | 74 | This is a convenience function that calls `fit` and `predict`. 75 | 76 | Args: 77 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features) 78 | a (pd.Series): Treatment assignment of size (num_subjects,). 79 | 80 | Returns: 81 | pd.Series: a Series of length `X.shape[0]` with the same index as 82 | `X` and only boolean values 83 | """ 84 | self.fit(X, a) 85 | return self.predict(X, a) 86 | 87 | def fit_transform(self, 88 | X: pd.DataFrame, a: pd.Series, *args: pd.Series 89 | ) -> Tuple[pd.DataFrame, pd.Series, pd.Series]: 90 | """Fit and transform data by removing positivity violations. 91 | 92 | Args: 93 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 94 | a (pd.Series): Treatment assignment of size (num_subjects,). 95 | *args (pd.Series): Zero or more pd.Series objects corresponding to 96 | outcomes. Each argument must be indexed the same as the other 97 | arguments and have size (num_subjects,). 98 | 99 | Returns: 100 | Tuple[pd.DataFrame, pd.Series, pd.Series]: Subsets of `X`, `a` and 101 | the output series objects of `args` corresponding to the samples 102 | which do not violate the positivity assumption. 103 | """ 104 | self.fit(X, a) 105 | return self.transform(X, a, *args) 106 | 107 | def score(self, 108 | X: pd.DataFrame, a: pd.Series, 109 | **kwargs): 110 | """Score the positivity violation. 111 | This is a generic function, but right now it receives 112 | only one kind of scorer - `cross_covariance_score`. 113 | Args: 114 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 115 | a (pd.Series): Treatment assignment of size (num_subjects,). 116 | **kwargs: Keyword arguments corresponding to the scoring metric.
117 | 118 | Returns: 119 | float: a non-negative score that quantifies the violation 120 | of positivity. 121 | """ 122 | from .metrics.metrics import cross_covariance_score 123 | X_trans, a_trans = self.transform(X, a) 124 | return cross_covariance_score(X_trans, a_trans, **kwargs) 125 | 126 | -------------------------------------------------------------------------------- /causallib/positivity/datasets/__init__.py: https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/positivity/datasets/__init__.py -------------------------------------------------------------------------------- /causallib/positivity/datasets/pizza_data_simulator.py: -------------------------------------------------------------------------------- 1 | # todo: 2 | # adding more datasets as positivity benchmarks 3 | # and maybe order it nicely in a Class of Simulated_data 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | def _rotate_data(X, angle_rotation): 10 | """ rotate only the first 2d """ 11 | angle = np.deg2rad(angle_rotation) 12 | rot_2d = np.array([[np.cos(angle), -np.sin(angle)], 13 | [np.sin(angle), np.cos(angle)]]) 14 | rot_mat = np.eye(X.shape[1]) 15 | rot_mat[:2, :2] = rot_2d 16 | X_rotated = np.dot(X, rot_mat) 17 | return pd.DataFrame(X_rotated) 18 | 19 | 20 | def _probability_of_being_treated(X, angle_slice): 21 | """ based on the amount of slicing in the data """ 22 | angle = np.deg2rad(angle_slice) 23 | slope = np.sin(angle)/np.cos(angle) 24 | slicing_cond = (0 < X.iloc[:, 1]) & ((X.iloc[:, 1] / slope) < X.iloc[:, 0]) 25 | p = 0.5 + 0.5 * slicing_cond 26 | return p 27 | 28 | 29 | def pizza(n_dim=2, n_samples=1000, 30 | angle_rotation=45, angle_slice=90, seed=0): 31 | """ 32 | Rotate the data and create a slice with strict non-overlap, 33 | i.e. the propensity score given this subspace is equal to 1. 34 | 35 | If n_dim > 2, the covariates are uniformly distributed over the covariate space, 36 | the non-overlapping area is the same as for n_dim=2, and the 37 | other covariates are completely overlapping. 38 | Args: 39 | n_dim (int): number of dimensions; must be at least 2. 40 | n_samples (int): number of samples 41 | angle_rotation (float): the angle of rotation in degrees 42 | angle_slice (float): the angle of the sliced-out area; the bigger the angle, 43 | the wider the non-overlapping area. 44 | Ranges over [0, 180]. 45 | seed (None | int): random seed passed to `np.random.seed`. 46 | 47 | Returns: 48 | X (pd.DataFrame), a (pd.Series): rotated covariates and treatment assignment. 49 | """ 50 | np.random.seed(seed) 51 | X = np.random.uniform(-np.ones(n_dim), 52 | np.ones(n_dim), 53 | size=(n_samples, n_dim)) / np.sqrt(n_dim) 54 | X_rotated = _rotate_data(X, angle_rotation) 55 | p = _probability_of_being_treated(X_rotated, angle_slice) 56 | a = pd.Series(np.random.binomial(1, p), name='a') 57 | return X_rotated, a 58 | -------------------------------------------------------------------------------- /causallib/positivity/datasets/positivity_data_simulator.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | 4 | 5 | def make_1d_overlap_data(treatment_bounds=(0, 75), 6 | control_bounds=(25, 100)): 7 | """Generate 1d overlap data with integer covariates 8 | 9 | Args: 10 | treatment_bounds (tuple, optional): Bounds for covariates in treatment 11 | group. Defaults to (0, 75). 12 | control_bounds (tuple, optional): Bounds for covariates in control 13 | group. Defaults to (25, 100).
14 | 15 | Returns: 16 | X (pd.DataFrame), a (pd.Series): covariate and treatment assignment 17 | """ 18 | 19 | X_treatment = np.arange(*treatment_bounds) 20 | a_treatment = np.ones_like(X_treatment) 21 | 22 | X_control = np.arange(*control_bounds) 23 | a_control = np.zeros_like(X_control) 24 | 25 | X = pd.DataFrame(data=np.hstack((X_treatment, X_control)), columns=["X1"]) 26 | a = pd.Series(data=np.hstack((a_treatment, a_control)), name="treatment") 27 | 28 | return X, a 29 | 30 | 31 | def make_1d_normal_distribution_overlap_data(treatment_params=(0, 1), 32 | control_params=(0, 1), 33 | probability_treated=0.5, 34 | n_samples=400, 35 | random_seed=1234): 36 | """ 37 | Args: 38 | treatment_params (tuple): loc and scale parameters of the normal distribution 39 | control_params (tuple): loc and scale parameters of the normal distribution 40 | probability_treated (float): the probability to be in the treatment group 41 | n_samples (int): total number of samples 42 | random_seed (int): base random seed; each group receives a different seed 43 | Returns: 44 | X (pd.DataFrame), a (pd.Series): covariate and treatment assignment 45 | """ 46 | n_treated = int(round(n_samples * probability_treated)) 47 | n_control = n_samples - n_treated 48 | 49 | a_control = np.zeros(n_control) 50 | np.random.seed(random_seed) 51 | X_control = np.random.normal( 52 | loc=control_params[0], scale=control_params[1], size=n_control) 53 | 54 | a_treatment = np.ones(n_treated) 55 | np.random.seed(random_seed + 1) 56 | X_treatment = np.random.normal( 57 | loc=treatment_params[0], scale=treatment_params[1], size=n_treated) 58 | 59 | X = pd.DataFrame(data=np.hstack((X_treatment, X_control)), columns=["X"]) 60 | a = pd.Series(data=np.hstack((a_treatment, a_control)), name="treatment") 61 | return X, a 62 | 63 | 64 | def make_multivariate_normal_data( 65 | treatment_params=([0, 0], [[2, 1], [1, 2]]), 66 | control_params=([2, 2], [[1, 0], [0, 1]]), 67 | probability_treated=0.5, 68 | n_samples=400, 69 | random_seed=1234): 70 | """ 71 | Args: 72 | treatment_params (tuple): mean vector and covariance matrix of the multivariate normal distribution 73 | control_params (tuple): mean vector and covariance matrix of the multivariate normal distribution 74 | probability_treated (float): the probability to be in the treatment group 75 | n_samples (int): number of samples for the full data set 76 | random_seed (int): base random seed; each group receives a different seed 77 | Returns: 78 | X (pd.DataFrame), a (pd.Series): covariate and treatment assignment 79 | """ 80 | n_treated = int(round(n_samples * probability_treated)) 81 | n_control = n_samples - n_treated 82 | 83 | a_control = np.zeros(n_control) 84 | np.random.seed(random_seed) 85 | X_control = np.random.multivariate_normal(control_params[0], 86 | control_params[1], 87 | n_control) 88 | a_treatment = np.ones(n_treated) 89 | np.random.seed(random_seed + 1) 90 | X_treatment = np.random.multivariate_normal(treatment_params[0], 91 | treatment_params[1], 92 | n_treated) 93 | 94 | X = pd.DataFrame(data=np.vstack((X_treatment, X_control))) 95 | a = pd.Series(data=np.hstack((a_treatment, a_control)), name="treatment") 96 | return X, a 97 | 98 | 99 | def make_random_y_like(a, random_seed=1234): 100 | np.random.seed(random_seed) 101 | y = pd.Series(data=np.random.random_sample(a.shape), name="outcome") 102 | return y 103 | -------------------------------------------------------------------------------- /causallib/positivity/matching.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import pandas as pd 3 | from causallib.preprocessing.transformers import MatchingTransformer 4 | from causallib.positivity import BasePositivity 5 | from typing import Union,
Optional 6 | import sklearn.neighbors 7 | import sklearn.base 8 | 9 | 10 | class Matching(BasePositivity): 11 | 12 | def __init__( 13 | self, 14 | propensity_transform: Optional[sklearn.base.TransformerMixin] = None, 15 | caliper: Optional[float] = None, 16 | with_replacement: bool = True, 17 | n_neighbors: int = 1, 18 | matching_mode: str = "both", 19 | metric: str = "mahalanobis", 20 | knn_backend: Union[str, 21 | sklearn.neighbors.NearestNeighbors] = "sklearn", 22 | ): 23 | """Fix positivity by matching. 24 | 25 | Args: 26 | propensity_transform (sklearn.TransformerMixin): an object for data 27 | preprocessing which implements `fit` and `transform` 28 | (default: None) 29 | caliper (float) : maximal distance for a match to be accepted. If 30 | not defined, all matches will be accepted. If defined, some 31 | samples may not be matched and their outcomes will not be 32 | estimated. (default: None) 33 | with_replacement (bool): whether samples can be used multiple times 34 | for matching. If set to False, the matching process will optimize 35 | the linear sum of distances between pairs of treatment and 36 | control samples and only `min(N_treatment, N_control)` samples 37 | will be estimated. Matching with no replacement does not make 38 | use of the `fit` data and is therefore not implemented for 39 | out-of-sample data (default: True) 40 | n_neighbors (int) : number of nearest neighbors to include in match. 41 | Must be 1 if `with_replacement` is `False`. If larger than 1, the 42 | estimate is calculated using the `regress_agg_function` or 43 | `classify_agg_function` across the `n_neighbors`. Note that when 44 | the `caliper` variable is set, some samples will have fewer than 45 | `n_neighbors` matches. (default: 1). 46 | matching_mode (str) : Direction of matching: `treatment_to_control`, 47 | `control_to_treatment` or `both` to indicate which set should 48 | be matched to which. All sets are cross-matched in `match`, 49 | and when `with_replacement` is `False` all matching modes 50 | coincide; with replacement, they differ. 51 | metric (str) : Distance metric string for calculating distance 52 | between samples. Note: if an externally built `knn_backend` 53 | object with a different metric is supplied, `metric` needs to 54 | be changed to reflect that, because `Matching` will set its 55 | inverse covariance matrix if "mahalanobis" is set. (default: 56 | "mahalanobis", also supported: "euclidean") 57 | knn_backend (str or callable) : Backend to use for nearest neighbor 58 | search. Options are "sklearn" or a callable which returns an 59 | object implementing `fit`, `kneighbors` and `set_params` 60 | like the sklearn `NearestNeighbors` object. (default: "sklearn"). 61 | 62 | """ 63 | self.matching_transformer = MatchingTransformer( 64 | propensity_transform=propensity_transform, 65 | caliper=caliper, 66 | with_replacement=with_replacement, 67 | n_neighbors=n_neighbors, 68 | matching_mode=matching_mode, 69 | metric=metric, 70 | knn_backend=knn_backend, 71 | ) 72 | 73 | def fit(self, X: pd.DataFrame, a: pd.Series) -> Matching: 74 | """Fit matching positivity checker. 75 | 76 | Args: 77 | X (pd.DataFrame): samples 78 | a (pd.Series): treatment assignment 79 | """ 80 | self.matching_transformer.fit(X, a, pd.Series()) 81 | return self 82 | 83 | def predict(self, X: pd.DataFrame, a: pd.Series) -> pd.Series: 84 | """Predict whether or not a sample is in the overlap region.
85 | 86 | Find samples of treatment and control that successfully match and 87 | return a boolean indexer which is `True` if they matched and `False` if 88 | they did not. This function calls the `match` method of the underlying 89 | `Matching` object. 90 | 91 | Args: 92 | X (pd.DataFrame): samples 93 | a (pd.Series): treatment assignment 94 | 95 | Returns: 96 | pd.Series: a Series of length `X.shape[0]` with the same index as 97 | `X` and only boolean values 98 | """ 99 | self.matching_transformer.matching.match(X, a) 100 | matching_indices = self.matching_transformer.find_indices_of_matched_samples( 101 | X, a) 102 | return matching_indices 103 | -------------------------------------------------------------------------------- /causallib/positivity/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import cross_covariance 2 | from .metrics import cross_covariance_score 3 | -------------------------------------------------------------------------------- /causallib/positivity/metrics/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Metrics to assess violations of positivity 3 | """ 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | 9 | def _check_mean_group(mean_group): 10 | mean_group = int(mean_group) 11 | if mean_group not in {0, 1}: 12 | raise ValueError('mean_group needs to be equal to 0 or 1') 13 | return mean_group 14 | 15 | 16 | def cross_covariance(X, a, mean_group=0): 17 | """ 18 | Compute the covariance where the mean is taken from the counter group. 19 | Suitable only for binary treatment, i.e., a in {0, 1}. 20 | Args: 21 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 22 | a (pd.Series): Treatment assignment of size (num_subjects,). 23 | mean_group (0 or 1): the treatment group from which the mean is taken 24 | 25 | Returns: 26 | np.array: cross-covariance square matrix 27 | """ 28 | mean_group = _check_mean_group(mean_group) 29 | avg = X.loc[a == mean_group, :].mean(axis=0) 30 | X_counter_group = X.loc[a == int(not mean_group)] 31 | X_counter_group = (X_counter_group - avg).astype(float) 32 | e_xx = np.dot(X_counter_group.T, X_counter_group) 33 | 34 | cross_cov = e_xx / (X_counter_group.shape[0] - 1)  # unbiased cov estimate 35 | return cross_cov 36 | 37 | 38 | def cross_covariance_score(X, a, normalize=False, sum_scores=True, 39 | func=np.max, off_diagonal_only=False): 40 | """ 41 | Reduce the cross-covariance matrix into a single score 42 | Args: 43 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 44 | a (pd.Series): Treatment assignment of size (num_subjects,). 45 | normalize (bool): If True, standardize features by removing the mean 46 | and scaling to unit variance. 47 | sum_scores (bool): If True, sum the scores of the different groups, 48 | otherwise return a list of scores. 49 | func (callable): function that reduces the cross-covariance matrix 50 | to a single score 51 | off_diagonal_only (bool): The diagonal of the cross-covariance matrix 52 | represents variances, and by definition 53 | max{var(x), var(y)} >= cov(x, y). 54 | If True, zero out the diagonal of the 55 | cross-covariance matrix, focusing on the 56 | maximum of the covariance elements; 57 | otherwise, the maximum may be attained 58 | by a variance element.
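Example:
    An illustrative sketch on toy data (the values below are hypothetical)::

        X = pd.DataFrame({"x1": [0., 1., 2., 3.], "x2": [1., 0., 1., 0.]})
        a = pd.Series([0, 0, 1, 1])
        score = cross_covariance_score(X, a)  # sum over both groups of the matrix maximum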
59 | Returns: 60 | float|list: a non-negative score, or a list of per-group scores if `sum_scores` is False 61 | """ 62 | treatment_values = sorted(pd.unique(a)) 63 | if normalize: 64 | X = pd.DataFrame((X - X.mean(axis=0)) / X.std(axis=0), index=X.index) 65 | 66 | scores = list() 67 | for treatment in treatment_values: 68 | cross_cov = np.abs(cross_covariance(X, a, mean_group=treatment)) 69 | if off_diagonal_only: 70 | cross_cov -= np.diag(np.diag(cross_cov)) 71 | scores.append(func(cross_cov)) 72 | return np.sum(scores) if sum_scores else scores 73 | -------------------------------------------------------------------------------- /causallib/positivity/trimming.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from causallib.positivity import BasePositivity 4 | from sklearn.linear_model import LogisticRegression 5 | 6 | OPTIMAL_THRESHOLD_ACCURACY = 5e-6 7 | 8 | 9 | def _check_is_valid_threshold_value(threshold_value): 10 | if not isinstance(threshold_value, (float, type(None))): 11 | raise ValueError("invalid threshold_value") 12 | return threshold_value 13 | 14 | 15 | def _check_is_valid_threshold_method(threshold_method): 16 | threshold_method = "crump" if threshold_method == "auto" else threshold_method 17 | if threshold_method not in cutoff_optimizers: 18 | raise ValueError("invalid threshold_method") 19 | return threshold_method 20 | 21 | 22 | def _check_propensities(prob): 23 | """Check that the treatment assignment is binary.""" 24 | if prob.shape[1] > 2: 25 | raise ValueError('This threshold selection method is applicable only ' 26 | 'for binary treatment assignment') 27 | else: 28 | propensities = prob.iloc[:, 1] 29 | return propensities 30 | 31 | 32 | def crump_cutoff(prob, segments=10000): 33 | """ 34 | A systematic approach to find the optimal trimming cutoff, based on the 35 | marginal distribution of the propensity score, 36 | and according to a variance minimization criterion. 37 | 38 | "Crump, R. K., Hotz, V. J., Imbens, G. W., & Mitnik, O. A. (2009). 39 | Dealing with limited overlap in estimation of average treatment effects." 40 | Args: 41 | prob (pd.DataFrame): probabilities of group assignment, 42 | of shape (n_samples, n_classes) 43 | segments (int): number of exclusive segments of the interval (0, 0.5]. 44 | More segments yield a more precise cutoff 45 | 46 | Returns: 47 | float: the optimal cutoff, 48 | i.e. the smallest value that satisfies the criterion. 49 | """ 50 | propensities = _check_propensities(prob) 51 | alphas = np.linspace(1e-7, 0.5, segments) 52 | alphas_weights = alphas * (1 - alphas) 53 | overlap_weights = propensities * (1 - propensities) 54 | for i in range(segments): 55 | obs_meets_criterion = overlap_weights >= alphas_weights[i] 56 | criterion = 2 * (np.sum(obs_meets_criterion / overlap_weights) / 57 | np.maximum(np.sum(obs_meets_criterion), 1e-7)) 58 | if (1 / alphas_weights[i]) <= criterion: 59 | break 60 | return alphas[i] 61 | 62 | 63 | cutoff_optimizers = {'crump': crump_cutoff} 64 | 65 | 66 | def _lookup_method(threshold_method): 67 | if threshold_method in cutoff_optimizers: 68 | return cutoff_optimizers[threshold_method] 69 | else: 70 | raise Exception("Method %s does not exist" % threshold_method) 71 | 72 | 73 | class Trimming(BasePositivity): 74 | def __init__(self, 75 | learner=LogisticRegression(), 76 | threshold="auto"): 77 | """ 78 | Fix positivity by trimming samples with extreme propensity scores. 79 | Args: 80 | learner (sklearn object): Initialized sklearn model 81 | threshold (str | float) : The threshold method or value.
82 | - if "auto": find the optimal threshold in a principled way (currently Crump's method). 83 | - if float: a hard-coded value between 0 and 0.5 used 84 | to clip the propensity estimates. 85 | """ 86 | self.learner = learner 87 | if not hasattr(self.learner, "predict_proba"): 88 | raise AttributeError("Propensity estimator must be a machine " 89 | "learning model that can predict probabilities " 90 | "(i.e., have a predict_proba method)") 91 | 92 | if isinstance(threshold, str): 93 | self.threshold = _check_is_valid_threshold_method(threshold) 94 | else: 95 | self.threshold_ = _check_is_valid_threshold_value(threshold) 96 | 97 | def _fit_threshold(self, X): 98 | """Fit threshold in a principled way""" 99 | prob = self.learner.predict_proba(X) 100 | prob = pd.DataFrame(prob, index=X.index, columns=self.learner.classes_) 101 | method = _lookup_method(self.threshold) 102 | threshold = method(prob) 103 | return threshold 104 | 105 | def fit(self, X, a): 106 | """Fit propensity model for positivity. 107 | 108 | Args: 109 | X (pd.DataFrame): covariate matrix of size 110 | (num_subjects, num_features) 111 | a (pd.Series): treatment assignment of size (num_subjects,) 112 | """ 113 | self.learner.fit(X, a) 114 | if hasattr(self, 'threshold'): 115 | self.threshold_ = self._fit_threshold(X) 116 | return self 117 | 118 | def predict(self, X, a, threshold=None): 119 | """Predict whether or not a sample is in the overlap region. 120 | Find samples whose probability of assignment to either treatment 121 | group is larger than the cutoff threshold. 122 | 123 | Returns a boolean indexer which is `True` if a sample's probabilities are 124 | higher than the cutoff threshold and `False` otherwise. 125 | Args: 126 | X (pd.DataFrame): covariate matrix of size 127 | (num_subjects, num_features) 128 | a (pd.Series): treatment assignment of size (num_subjects,) 129 | threshold (float|None): The cutoff threshold. 130 | - if float, an optional value between 0 and 0.5 to clip the 131 | propensity estimation. 132 | - if None, use the threshold fitted during `fit`. 133 | 134 | Returns: 135 | pd.Series: a Series of length `X.shape[0]` with the same index as 136 | `X` and only boolean values 137 | """ 138 | prob = self.learner.predict_proba(X) 139 | prob = pd.DataFrame(prob, index=X.index, columns=self.learner.classes_) 140 | 141 | threshold_value = _check_is_valid_threshold_value(threshold) 142 | threshold_to_use = (self.threshold_ if threshold_value is None 143 | else threshold_value) 144 | 145 | untrimmed_indices = (prob >= threshold_to_use).all(axis=1) 146 | return untrimmed_indices 147 | -------------------------------------------------------------------------------- /causallib/preprocessing/README.md: -------------------------------------------------------------------------------- 1 | # Module `preprocessing` 2 | This module provides several useful filters and transformers to augment 3 | the ones provided by scikit-learn. 4 | 5 | Specifically, the various filters remove features according to the following criteria: 6 | - Features that are almost constant (not by variance but by actual value). 7 | - Features that are highly correlated with other features. 8 | - Features that have a low variance (can deal with NaN values). 9 | - Features that are mostly NaN. 10 | - Features that are highly associated with the outcome (not just correlation). 11 | 12 | Various transformers are provided: 13 | - A standard scaler that deals with NaN values. 14 | - A min/max scaler.
15 | 16 | A transformer that accepts numpy arrays and turns them into pandas DataFrames will be added soon. 17 | 18 | These filters and transformers can be used as part of a scikit-learn pipeline. 19 | 20 | ### Example: 21 | This example combines a scikit-learn filter with a causallib scaler. 22 | The pipeline scales the data, then removes covariates with low variance, 23 | and then applies IPW with logistic regression. 24 | 25 | ```Python 26 | from sklearn.linear_model import LogisticRegression 27 | from sklearn.feature_selection import VarianceThreshold 28 | from sklearn.pipeline import make_pipeline 29 | from causallib.estimation import IPW 30 | from causallib.datasets import load_nhefs 31 | from causallib.preprocessing.transformers import MinMaxScaler 32 | 33 | pipeline = make_pipeline(MinMaxScaler(), VarianceThreshold(0.1), LogisticRegression()) 34 | data = load_nhefs() 35 | ipw = IPW(pipeline) 36 | ipw.fit(data.X, data.a) 37 | ipw.estimate_population_outcome(data.X, data.a, data.y) 38 | ``` -------------------------------------------------------------------------------- /causallib/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/preprocessing/__init__.py -------------------------------------------------------------------------------- /causallib/simulation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/simulation/__init__.py -------------------------------------------------------------------------------- /causallib/survival/README.md: -------------------------------------------------------------------------------- 1 | # Module `causallib.survival` 2 | This module allows estimating counterfactual outcomes in a setting of right-censored data 3 | (also known as survival analysis, or time-to-event modeling). 4 | In addition to the standard inputs of `X` - baseline covariates, `a` - treatment assignment and `y` - outcome indicator, 5 | a new variable `t` is introduced, measuring the time from the beginning of the observation period to the occurrence of an event. 6 | An event may be right-censoring (where `y=0`) or the outcome of interest, or "death" (where `y=1`), 7 | which is terminal and therefore also censors any further observation. 8 | Each of these methods uses an underlying machine learning model of choice, and can also integrate with the 9 | [`lifelines`](https://github.com/CamDavidsonPilon/lifelines) survival analysis Python package. 10 | 11 | Additional methods will be added incrementally. 12 | 13 | ## Available Methods 14 | The methods that are currently available are: 15 | 1. Weighting: `causallib.survival.WeightedSurvival` - uses `causallib`'s `WeightEstimator` (e.g., `IPW`) to generate weighted pseudo-population for survival analysis. 16 | 2. Standardization (parametric g-formula): `causallib.survival.StandardizedSurvival` - fits a parametric hazards model that includes baseline covariates. 17 | 3. Weighted Standardization: `causallib.survival.WeightedStandardizedSurvival` - combines the two above-mentioned methods (a minimal sketch is shown directly below).
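### Example: Weighted standardized survival
A minimal sketch of the combined estimator; it assumes `X`, `a`, `t`, `y` are loaded as in the weighting example below, and follows the constructor and `fit` signatures of `WeightedStandardizedSurvival` defined in this module:
```python
from sklearn.linear_model import LogisticRegression
from causallib.survival import WeightedStandardizedSurvival
from causallib.estimation import IPW

ipw = IPW(learner=LogisticRegression())
weighted_std_survival = WeightedStandardizedSurvival(weight_model=ipw,
                                                     survival_model=LogisticRegression())
weighted_std_survival.fit(X, a, t, y)
population_averaged_survival_curves = weighted_std_survival.estimate_population_outcome(X, a, t)
```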
18 | 19 | ### Example: Weighted survival analysis with Inverse Probability Weighting 20 | ```python 21 | from sklearn.linear_model import LogisticRegression 22 | from causallib.survival import WeightedSurvival 23 | from causallib.estimation import IPW 24 | from causallib.datasets import load_nhefs_survival 25 | 26 | ipw = IPW(learner=LogisticRegression()) 27 | weighted_survival_estimator = WeightedSurvival(weight_model=ipw) 28 | X, a, t, y = load_nhefs_survival() 29 | 30 | weighted_survival_estimator.fit(X, a) 31 | population_averaged_survival_curves = weighted_survival_estimator.estimate_population_outcome(X, a, t, y) 32 | ``` 33 | 34 | ### Example: Standardized survival (parametric g-formula) 35 | ```python 36 | from sklearn.linear_model import LogisticRegression 37 | from causallib.survival import StandardizedSurvival 38 | standardized_survival = StandardizedSurvival(survival_model=LogisticRegression()) 39 | standardized_survival.fit(X, a, t, y) 40 | population_averaged_survival_curves = standardized_survival.estimate_population_outcome(X, a, t) 41 | individual_survival_curves = standardized_survival.estimate_individual_outcome(X, a, t) 42 | ``` -------------------------------------------------------------------------------- /causallib/survival/__init__.py: -------------------------------------------------------------------------------- 1 | """Causal Survival Analysis Models""" 2 | 3 | from .univariate_curve_fitter import UnivariateCurveFitter 4 | from .regression_curve_fitter import RegressionCurveFitter 5 | from .marginal_survival import MarginalSurvival 6 | from .weighted_survival import WeightedSurvival 7 | from .standardized_survival import StandardizedSurvival 8 | from .weighted_standardized_survival import WeightedStandardizedSurvival 9 | 10 | __all__ = [ 11 | "UnivariateCurveFitter", 12 | "RegressionCurveFitter", 13 | "MarginalSurvival", 14 | "WeightedSurvival", 15 | "StandardizedSurvival", 16 | "WeightedStandardizedSurvival", 17 | ] 18 | -------------------------------------------------------------------------------- /causallib/survival/base_survival.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from abc import ABC, abstractmethod 3 | 4 | 5 | class SurvivalBase(ABC): 6 | """ 7 | Interface class for causal survival analysis with fixed baseline covariates. 8 | """ 9 | @abstractmethod 10 | def fit(self, 11 | X: pd.DataFrame, 12 | a: pd.Series, 13 | t: pd.Series, 14 | y: pd.Series): 15 | """ 16 | Fits internal learner(s). 17 | 18 | Args: 19 | X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features). 20 | a (pd.Series): Treatment assignment of size (num_subjects,). 21 | t (pd.Series): Followup duration, size (num_subjects,). 22 | y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,). 23 | 24 | Returns: 25 | self 26 | """ 27 | raise NotImplementedError 28 | 29 | @abstractmethod 30 | def estimate_population_outcome(self, 31 | **kwargs) -> pd.DataFrame: 32 | """ 33 | Returns population averaged survival curves. 34 | 35 | Returns: 36 | pd.DataFrame: with time-step index, treatment values as columns and survival as entries 37 | """ 38 | raise NotImplementedError 39 | 40 | 41 | class SurvivalTimeVaryingBase(SurvivalBase): 42 | """ 43 | Interface class for causal survival analysis estimators that support time-varying followup covariates. 44 | Followup covariates matrix (XF) needs to have a 'time' column, and be indexed by subject IDs that correspond to 45 | the other inputs (X, a, y, t).
All columns other than 'time' will be used for time-varying adjustments. 46 | 47 | Example XF format: 48 | +----+------+------+------+ 49 | | id | t | var1 | var2 | 50 | +----+------+------+------+ 51 | | 1 | 0 | 1.4 | 22 | 52 | | 1 | 4 | 1.2 | 22 | 53 | | 1 | 8 | 1.5 | NaN | 54 | | 2 | 0 | 1.6 | 10 | 55 | | 2 | 11 | 1.6 | 11 | 56 | +----+------+------+------+ 57 | """ 58 | 59 | @abstractmethod 60 | def fit(self, 61 | X: pd.DataFrame, 62 | a: pd.Series, 63 | t: pd.Series, 64 | y: pd.Series, 65 | XF: pd.DataFrame = None) -> None: 66 | """ 67 | Fits internal survival functions. 68 | 69 | Args: 70 | X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features). 71 | a (pd.Series): Treatment assignment of size (num_subjects,). 72 | t (pd.Series): Followup duration, size (num_subjects,). 73 | y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,). 74 | XF (pd.DataFrame): Time-varying followup covariate matrix 75 | 76 | Returns: 77 | A fitted estimator with precalculated survival functions. 78 | """ 79 | -------------------------------------------------------------------------------- /causallib/survival/marginal_survival.py: -------------------------------------------------------------------------------- 1 | from .weighted_survival import WeightedSurvival 2 | from typing import Any 3 | 4 | 5 | class MarginalSurvival(WeightedSurvival): 6 | """ 7 | Marginal (un-adjusted) survival estimator. 8 | Essentially a degenerate WeightedSurvival instance without a weight model. 9 | """ 10 | def __init__(self, 11 | survival_model: Any = None): 12 | """ 13 | Marginal (un-adjusted) survival estimator. 14 | Args: 15 | survival_model: Three alternatives: 16 | 1. None - compute non-parametric Kaplan-Meier survival curve 17 | 2. Scikit-Learn estimator (needs to implement `predict_proba`) - compute parametric curve by fitting a time-varying hazards model 18 | 3. lifelines UnivariateFitter - use lifelines fitter to compute survival curves from events and durations 19 | """ 20 | 21 | super().__init__(weight_model=None, survival_model=survival_model) 22 | 23 | -------------------------------------------------------------------------------- /causallib/survival/univariate_curve_fitter.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from typing import Optional 4 | from sklearn.base import BaseEstimator as SKLearnBaseEstimator 5 | from .survival_utils import safe_join 6 | from .regression_curve_fitter import RegressionCurveFitter 7 | 8 | 9 | class UnivariateCurveFitter: 10 | def __init__(self, learner: Optional[SKLearnBaseEstimator] = None): 11 | """ 12 | Default implementation of a univariate survival curve fitter. 13 | Construct a curve fitter, either non-parametric (Kaplan-Meier) or parametric. 14 | API follows 'lifelines' convention for univariate models, see here for example: 15 | https://lifelines.readthedocs.io/en/latest/fitters/univariate/KaplanMeierFitter.html#lifelines.fitters.kaplan_meier_fitter.KaplanMeierFitter.fit 16 | Args: 17 | learner: optional scikit-learn estimator (needs to implement `predict_proba`). If provided, will 18 | compute parametric curve by fitting a time-varying hazards model. If None, will compute 19 | non-parametric Kaplan-Meier estimator.
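Example:
    An illustrative sketch (here `t` is assumed to be a pd.Series of durations
    and `y` a 0/1 pd.Series of event indicators; both names are hypothetical)::

        fitter = UnivariateCurveFitter()  # no learner -> non-parametric Kaplan-Meier
        fitter.fit(durations=t, event_observed=y)
        survival = fitter.predict()  # pd.Series of survival probabilities, indexed by time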
20 | """ 21 | self.learner = learner 22 | 23 | def fit(self, durations, event_observed=None, weights=None): 24 | """ 25 | Fits a univariate survival curve (Kaplan-Meier or parametric, if a learner was provided in constructor) 26 | 27 | Args: 28 | durations (Iterable): Duration subject was observed 29 | event_observed (Optional[Iterable]): Boolean or 0/1 iterable, where True means 'outcome event' and False 30 | means 'right censoring'. If unspecified, assumes that all events are 31 | 'outcome' (no censoring). 32 | weights (Optional[Iterable]): Optional subject weights 33 | 34 | Returns: 35 | Self 36 | """ 37 | # If 'event_observed' is unspecified, assumes that all events are 'outcome' (no censoring). 38 | if event_observed is None: 39 | event_observed = pd.Series(data=1, index=durations.index) 40 | 41 | if weights is None: 42 | weights = pd.Series(data=1, index=durations.index, name='weights') 43 | else: 44 | weights = pd.Series(data=weights, index=durations.index, name='weights') 45 | self.timeline_ = np.sort(np.unique(durations)) 46 | 47 | # If sklearn classifier is provided, fit parametric curve 48 | if self.learner is not None: 49 | self.curve_fitter_ = RegressionCurveFitter(learner=self.learner) 50 | fit_data, (duration_col_name, event_col_name, weights_col_name) = safe_join( 51 | df=None, list_of_series=[durations, event_observed, weights], return_series_names=True 52 | ) 53 | self.curve_fitter_.fit(df=fit_data, duration_col=duration_col_name, event_col=event_col_name, 54 | weights_col=weights_col_name) 55 | 56 | # Else, compute Kaplan Meier estimator non parametrically 57 | else: 58 | # Code inspired by lifelines KaplanMeierFitter 59 | df = pd.DataFrame({ 60 | 't': durations, 61 | 'removed': weights.to_numpy(), 62 | 'observed': weights.to_numpy() * (event_observed.to_numpy(dtype=bool)) 63 | }) 64 | 65 | death_table = df.groupby("t").sum() 66 | death_table['censored'] = (death_table['removed'] - death_table['observed']).astype(int) 67 | 68 | births = pd.DataFrame(np.zeros(durations.shape[0]), columns=["t"]) 69 | births['entrance'] = np.asarray(weights) 70 | births_table = births.groupby("t").sum() 71 | event_table = death_table.join(births_table, how="outer", sort=True).fillna( 72 | 0) # http://wesmckinney.com/blog/?p=414 73 | event_table['at_risk'] = event_table['entrance'].cumsum() - event_table['removed'].cumsum().shift(1).fillna( 74 | 0) 75 | self.event_table_ = event_table 76 | 77 | return self 78 | 79 | def predict(self, times=None, interpolate=False): 80 | """ 81 | Compute survival curve for time points given in 'times' param. 82 | Args: 83 | times: sequence of time points for prediction 84 | interpolate: if True, linearly interpolate non-observed times. Otherwise, repeat last observed time point. 
85 | 86 | Returns: 87 | pd.Series: indexed by `times`, with survival values as entries 88 | 89 | """ 90 | if times is None: 91 | times = self.timeline_ 92 | else: 93 | times = sorted(times) 94 | 95 | if self.learner is not None: 96 | # Predict parametric survival curve 97 | survival = self.curve_fitter_.predict_survival_function(X=None, times=pd.Series(times)) 98 | else: 99 | # Compute hazard at each time step 100 | hazard = self.event_table_['observed'] / self.event_table_['at_risk'] 101 | timeline = hazard.index  # if computed non-parametrically, timeline is all observed data points 102 | # Compute survival from hazards 103 | survival = pd.Series(data=np.cumprod(1 - hazard), index=timeline, name='survival') 104 | 105 | if interpolate: 106 | survival = pd.Series(data=np.interp(times, survival.index.values, survival.values), 107 | index=pd.Index(data=times, name='t'), name='survival') 108 | else: 109 | survival = survival.asof(times).squeeze() 110 | 111 | # Round near-zero values (may occur when using weights and all observed subjects "died" at some point) 112 | survival[np.abs(survival) < np.finfo(float).resolution] = 0 113 | 114 | return survival 115 | -------------------------------------------------------------------------------- /causallib/survival/weighted_standardized_survival.py: -------------------------------------------------------------------------------- 1 | from .survival_utils import canonize_dtypes_and_names 2 | from .standardized_survival import StandardizedSurvival 3 | from causallib.estimation.base_weight import WeightEstimator 4 | import pandas as pd 5 | from typing import Any, Optional 6 | 7 | 8 | class WeightedStandardizedSurvival(StandardizedSurvival): 9 | def __init__( 10 | self, 11 | weight_model: WeightEstimator, 12 | survival_model: Any, 13 | stratify: bool = True, 14 | outcome_covariates=None, 15 | weight_covariates=None, 16 | ): 17 | """ 18 | Combines WeightedSurvival and StandardizedSurvival: 19 | 1. Adjusts for treatment assignment by creating a weighted pseudo-population (e.g., inverse propensity weighting). 20 | 2. Computes parametric curve by fitting a time-varying hazards model that includes baseline covariates. 21 | 22 | Args: 23 | weight_model: causallib compatible weight model (e.g., IPW) 24 | survival_model: Two alternatives: 25 | 1. Scikit-Learn estimator (needs to implement `predict_proba`) - compute parametric curve by fitting a 26 | time-varying hazards model that includes baseline covariates. Note that the model is fitted on a 27 | person-time table with all covariates, and might be computationally and memory expensive. 28 | 2. lifelines RegressionFitter - use lifelines fitter to compute survival curves from baseline covariates, 29 | events and durations 30 | stratify (bool): if True, fit a separate model per treatment group 31 | outcome_covariates (array): Covariates to use for outcome model. 32 | If None - all covariates passed will be used. 33 | Either list of column names or boolean mask. 34 | weight_covariates (array): Covariates to use for weight model. 35 | If None - all covariates passed will be used. 36 | Either list of column names or boolean mask.
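Example:
    An illustrative sketch of covariate subsetting (the column names below are
    hypothetical)::

        model = WeightedStandardizedSurvival(
            weight_model=IPW(LogisticRegression()),
            survival_model=LogisticRegression(),
            outcome_covariates=["age", "sex", "smoking"],  # outcome model uses all three
            weight_covariates=["age", "sex"],  # weight model uses a subset
        )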
37 | """ 38 | self.weight_model = weight_model 39 | super().__init__(survival_model=survival_model, stratify=stratify) 40 | self.outcome_covariates = outcome_covariates 41 | self.weight_covariates = weight_covariates 42 | 43 | def _prepare_data(self, X, *args, **kwargs): 44 | """ 45 | Extract the relevant parts for outcome model and weight model for the entire data matrix 46 | 47 | Args: 48 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 49 | a (pd.Series): Treatment assignment of size (num_subjects,). 50 | 51 | Returns: 52 | (pd.DataFrame, pd.DataFrame): X_outcome, X_weight 53 | Data matrix for outcome model and data matrix weight model 54 | """ 55 | outcome_covariates = X.columns if self.outcome_covariates is None else self.outcome_covariates 56 | X_outcome = X[outcome_covariates] 57 | weight_covariates = X.columns if self.weight_covariates is None else self.weight_covariates 58 | X_weight = X[weight_covariates] 59 | return X_outcome, X_weight 60 | 61 | def fit(self, 62 | X: pd.DataFrame, 63 | a: pd.Series, 64 | t: pd.Series, 65 | y: pd.Series, 66 | w: Optional[pd.Series] = None, 67 | fit_kwargs: Optional[dict] = None): 68 | """ 69 | Fits parametric models and calculates internal survival functions. 70 | 71 | Args: 72 | X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features). 73 | a (pd.Series): Treatment assignment of size (num_subjects,). 74 | t (pd.Series): Followup duration, size (num_subjects,). 75 | y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,). 76 | w (pd.Series): NOT USED (for compatibility only) optional subject weights. 77 | fit_kwargs (dict): Optional kwargs for fit call of survival model 78 | 79 | Returns: 80 | self 81 | """ 82 | a, t, y, _, X = canonize_dtypes_and_names(a=a, t=t, y=y, w=None, X=X) 83 | X_outcome, X_weight = self._prepare_data(X) 84 | 85 | self.weight_model.fit(X=X_weight, a=a, y=y) 86 | iptw_weights = self.weight_model.compute_weights(X_weight, a) 87 | 88 | # Call fit from StandardizedSurvival, with added ipt weights 89 | super().fit(X=X_outcome, a=a, t=t, y=y, w=iptw_weights, fit_kwargs=fit_kwargs) 90 | return self 91 | 92 | def estimate_individual_outcome( 93 | self, 94 | X: pd.DataFrame, 95 | a: pd.Series, 96 | t: pd.Series, 97 | y: Optional[Any] = None, 98 | timeline_start: Optional[int] = None, 99 | timeline_end: Optional[int] = None 100 | ) -> pd.DataFrame: 101 | X_outcome, _ = self._prepare_data(X) 102 | potential_outcomes = super().estimate_individual_outcome( 103 | X_outcome, 104 | a, t, y, 105 | timeline_start=timeline_start, 106 | timeline_end=timeline_end, 107 | ) 108 | return potential_outcomes 109 | -------------------------------------------------------------------------------- /causallib/survival/weighted_survival.py: -------------------------------------------------------------------------------- 1 | from causallib.estimation.base_weight import WeightEstimator 2 | from .univariate_curve_fitter import UnivariateCurveFitter 3 | from sklearn.base import BaseEstimator as SKLearnBaseEstimator 4 | from typing import Any 5 | import pandas as pd 6 | from copy import deepcopy 7 | from .survival_utils import canonize_dtypes_and_names 8 | from .base_survival import SurvivalBase 9 | from typing import Optional 10 | 11 | 12 | class WeightedSurvival(SurvivalBase): 13 | """ 14 | Weighted survival estimator 15 | """ 16 | 17 | def __init__(self, 18 | weight_model: WeightEstimator = None, 19 | survival_model: Any = None): 20 | """ 21 | Weighted survival estimator. 
22 | Args: 23 | weight_model: causallib compatible weight model (e.g., IPW) 24 | survival_model: Three alternatives: 25 | 1. None - compute non-parametric KaplanMeier survival curve 26 | 2. Scikit-Learn estimator (needs to implement `predict_proba`) - compute parametric curve by fitting a 27 | time-varying hazards model 28 | 3. lifelines UnivariateFitter - use lifelines fitter to compute survival curves from events and durations 29 | """ 30 | self.weight_model = weight_model 31 | 32 | # Construct default curve fitter, non parametric estimation (Kaplan-Meier) 33 | if survival_model is None: 34 | self.survival_model = UnivariateCurveFitter() 35 | # Construct default curve fitter, parametric with a scikit-learn estimator 36 | elif isinstance(survival_model, SKLearnBaseEstimator): 37 | self.survival_model = UnivariateCurveFitter(survival_model) 38 | # Initialized lifelines univariate fitter (or any implementation with a compatible API) 39 | else: 40 | self.survival_model = survival_model 41 | 42 | def fit(self, 43 | X: pd.DataFrame, 44 | a: pd.Series, 45 | t: pd.Series = None, 46 | y: pd.Series = None, 47 | fit_kwargs: Optional[dict] = None): 48 | """ 49 | Fits internal weight module (e.g. IPW module, adversarial weighting, etc). 50 | 51 | Args: 52 | X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features). 53 | a (pd.Series): Treatment assignment of size (num_subjects,). 54 | t (pd.Series): NOT USED (for compatibility only) 55 | y (pd.Series): NOT USED (for compatibility only) 56 | fit_kwargs (dict): Optional kwargs for fit call of survival model (NOT USED, since fit 57 | call of survival model occurs in 'estimate_population_outcome' rather than here) 58 | 59 | Returns: 60 | self 61 | """ 62 | a, _, y, _, X = canonize_dtypes_and_names(a=a, t=None, y=y, w=None, X=X) 63 | if self.weight_model is not None: 64 | self.weight_model.fit(X=X, a=a, y=y) 65 | 66 | return self 67 | 68 | def estimate_population_outcome(self, 69 | X: pd.DataFrame, 70 | a: pd.Series, 71 | t: pd.Series, 72 | y: pd.Series, 73 | timeline_start: Optional[int] = None, 74 | timeline_end: Optional[int] = None 75 | ) -> pd.DataFrame: 76 | """ 77 | Returns population averaged survival curves. 78 | 79 | Args: 80 | X (pd.DataFrame): Baseline covariate matrix of size (num_subjects, num_features). 81 | a (pd.Series): Treatment assignment of size (num_subjects,). 82 | t (pd.Series|int): Followup durations, size (num_subjects,). 83 | y (pd.Series): Observed outcome (1) or right censoring event (0), size (num_subjects,). 84 | timeline_start (int): Common start time-step. If provided, will generate survival curves starting 85 | from 'timeline_start' for all patients. If None, will predict from first observed event. 86 | timeline_end (int): Common end time-step. If provided, will generate survival curves up to 'timeline_end' 87 | for all patients. If None, will predict up to last observed event. 
88 | 89 | Returns: 90 | pd.DataFrame: with timestep index, treatment values as columns and survival as entries 91 | """ 92 | self.stratified_curve_fitters_ = {} 93 | a, t, y, _, X = canonize_dtypes_and_names(a=a, t=t, y=y, w=None, X=X) 94 | min_time = timeline_start if timeline_start is not None else int(t.min()) 95 | max_time = timeline_end if timeline_end is not None else int(t.max()) 96 | 97 | if self.weight_model is not None: 98 | # Generate inverse propensity for treatment weights (IPTW) 99 | iptw_weights = self.weight_model.compute_weights(X, a) 100 | iptw_weights.name = 'w' 101 | else: 102 | iptw_weights = None 103 | 104 | # Fit or compute survival curves 105 | treatment_values = a.unique() 106 | survival_curves = [] 107 | for treatment_value in treatment_values: 108 | stratum_indices = a == treatment_value 109 | stratum_curve_fitter = deepcopy(self.survival_model) 110 | 111 | # Fit curve model 112 | stratum_curve_fitter.fit(durations=t[stratum_indices], event_observed=y[stratum_indices], 113 | weights=iptw_weights[stratum_indices] if iptw_weights is not None else None) 114 | self.stratified_curve_fitters_[treatment_value] = stratum_curve_fitter 115 | 116 | # Predict curve model 117 | curve = stratum_curve_fitter.predict(times=range(min_time, max_time + 1)) 118 | curve.rename(treatment_value, inplace=True) 119 | survival_curves.append(curve) 120 | 121 | res = pd.concat(survival_curves, axis=1) 122 | 123 | # Setting index/column names 124 | res.index.name = t.name 125 | res.columns.name = a.name 126 | return res 127 | -------------------------------------------------------------------------------- /causallib/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/tests/__init__.py -------------------------------------------------------------------------------- /causallib/tests/test_base_weight.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) IBM Corp, 2019, All rights reserved 3 | Created on Aug 25, 2019 4 | 5 | @author: EHUD KARAVANI 6 | """ 7 | 8 | import unittest 9 | import pandas as pd 10 | from causallib.estimation.base_weight import WeightEstimator 11 | 12 | 13 | class TestBaseWeight(unittest.TestCase): 14 | @classmethod 15 | def setUpClass(cls): 16 | cls.y = pd.Series([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) 17 | cls.w = pd.Series([1, 1, 1, 1, 1, 1, 1, 1, 0.5, 0.5]) 18 | cls.a = pd.Series([0, 0, 0, 0, 1, 0, 1, 1, 1, 1]) 19 | 20 | def setUp(self): 21 | self.model = WeightEstimator(learner=None) 22 | 23 | def test_no_weighting_no_stratification(self): 24 | result = self.model._compute_stratified_weighted_aggregate(self.y, None, None) 25 | truth = pd.Series(5/10, index=[0]) 26 | pd.testing.assert_series_equal(truth, result) 27 | 28 | def test_weighting_no_stratification(self): 29 | result = self.model._compute_stratified_weighted_aggregate(self.y, self.w, None) 30 | truth = pd.Series(4/9, index=[0]) 31 | pd.testing.assert_series_equal(truth, result) 32 | 33 | def test_no_weighting_stratification(self): 34 | result = self.model._compute_stratified_weighted_aggregate(self.y, None, self.a) 35 | truth = pd.Series([1/5, 4/5], index=[0, 1]) 36 | pd.testing.assert_series_equal(truth, result) 37 | 38 | def test_weighting_stratification(self): 39 | result = self.model._compute_stratified_weighted_aggregate(self.y, self.w, self.a) 40 | truth = pd.Series([1/5, 3/4], index=[0, 1]) 41 | 
pd.testing.assert_series_equal(truth, result) 42 | 43 | def test_subset_treatment_values(self): 44 | with self.subTest("Subset of treatment values exist in treatment"): 45 | result = self.model._compute_stratified_weighted_aggregate(self.y, None, self.a, [0]) 46 | truth = pd.Series([1/5], index=[0]) 47 | pd.testing.assert_series_equal(truth, result) 48 | 49 | with self.subTest("Subset of treatment values does not exist in treatment"): 50 | with self.assertRaises(ZeroDivisionError): # Since the group is empty its weights' sum is zero 51 | self.model._compute_stratified_weighted_aggregate(self.y, None, self.a, [3]) 52 | 53 | 54 | -------------------------------------------------------------------------------- /causallib/tests/test_evaluation.py: -------------------------------------------------------------------------------- 1 | # (C) Copyright 2020 IBM Corp. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Created on Nov 12, 2020 16 | import unittest 17 | 18 | import matplotlib 19 | import matplotlib.axes 20 | import numpy as np 21 | import pandas as pd 22 | 23 | from sklearn.linear_model import LogisticRegression, LinearRegression 24 | 25 | from causallib.evaluation import evaluate 26 | from causallib.evaluation.evaluator import evaluate_bootstrap 27 | from causallib.evaluation.metrics import ( 28 | get_default_binary_metrics, 29 | get_default_regression_metrics, 30 | ) 31 | from causallib.evaluation.scoring import PropensityEvaluatorScores 32 | from causallib.estimation import AIPW, IPW, StratifiedStandardization, Matching 33 | from causallib.datasets import load_nhefs 34 | 35 | 36 | matplotlib.use("Agg") 37 | 38 | 39 | def binarize(cts_output: pd.Series) -> pd.Series: 40 | """Turn continuous outcome into binary by applying sigmoid. 
41 | 42 | Args: 43 | cts_output (pd.Series): outcomes as continuous variables 44 | 45 | Returns: 46 | pd.Series: outcomes as binary variables 47 | """ 48 | 49 | y = 1 / (1 + np.exp(-cts_output)) 50 | y = np.random.binomial(1, y) 51 | y = pd.Series(y, index=cts_output.index) 52 | return y 53 | 54 | 55 | class TestEvaluations(unittest.TestCase): 56 | @classmethod 57 | def setUpClass(self): 58 | data = load_nhefs() 59 | self.X, self.a, self.y = data.X, data.a, data.y 60 | self.y_bin = binarize(data.y) 61 | ipw = IPW(LogisticRegression(solver="liblinear"), clip_min=0.05, clip_max=0.95) 62 | std = StratifiedStandardization(LinearRegression()) 63 | self.dr = AIPW(std, ipw) 64 | self.dr.fit(self.X, self.a, self.y) 65 | self.std_bin = StratifiedStandardization(LogisticRegression(solver="liblinear")) 66 | self.std_bin.fit(self.X, self.a, self.y_bin) 67 | 68 | def test_evaluate_bootstrap_with_refit_works(self): 69 | ipw = IPW(LogisticRegression(solver="liblinear"), clip_min=0.05, clip_max=0.95) 70 | evaluate_bootstrap(ipw, self.X, self.a, self.y, n_bootstrap=5, refit=True) 71 | 72 | def test_evaluate_cv_works_with_unfit_models(self): 73 | ipw = IPW(LogisticRegression(solver="liblinear"), clip_min=0.05, clip_max=0.95) 74 | evaluate(ipw, self.X, self.a, self.y, cv="auto") 75 | 76 | def test_metrics_to_evaluate_is_none_means_no_metrics_evaluated(self): 77 | for model in (self.dr.outcome_model, self.dr.weight_model): 78 | self.ensure_metrics_are_none(model) 79 | 80 | def ensure_metrics_are_none(self, model): 81 | results = evaluate(model, self.X, self.a, self.y, metrics_to_evaluate=None) 82 | self.assertIsNone(results.evaluated_metrics) 83 | 84 | def test_default_evaluation_metrics_weights(self): 85 | model = self.dr.weight_model 86 | results = evaluate(model, self.X, self.a, self.y) 87 | self.assertEqual( 88 | set(results.evaluated_metrics.prediction_scores.columns), 89 | set(get_default_binary_metrics().keys()), 90 | ) 91 | 92 | def test_default_evaluation_metrics_continuous_outcome(self): 93 | model = self.dr.outcome_model 94 | results = evaluate(model, self.X, self.a, self.y) 95 | self.assertEqual( 96 | set(results.evaluated_metrics.columns), 97 | set(get_default_regression_metrics().keys()), 98 | ) 99 | 100 | def test_default_evaluation_metrics_binary_outcome(self): 101 | model = self.std_bin 102 | results = evaluate(model, self.X, self.a, self.y_bin) 103 | self.assertEqual( 104 | set(results.evaluated_metrics.columns), 105 | set(get_default_binary_metrics().keys()), 106 | ) 107 | 108 | def test_outcome_weight_propensity_evaluated_metrics(self): 109 | matching = Matching(matching_mode="control_to_treatment").fit(self.X, self.a, self.y) 110 | ipw = IPW(LogisticRegression(max_iter=4000)).fit(self.X, self.a, self.y) 111 | std = StratifiedStandardization(LinearRegression()).fit(self.X, self.a, self.y) 112 | 113 | matching_res = evaluate(matching, self.X, self.a, self.y).evaluated_metrics 114 | ipw_res = evaluate(ipw, self.X, self.a, self.y).evaluated_metrics 115 | std_res = evaluate(std, self.X, self.a, self.y).evaluated_metrics 116 | 117 | covariate_balance_df_shape = (self.X.columns.size, 2) 118 | 119 | with self.subTest("Matching evaluated metrics"): 120 | self.assertIsInstance(matching_res, pd.DataFrame) 121 | self.assertTupleEqual(matching_res.shape, covariate_balance_df_shape) 122 | 123 | with self.subTest("IPW evaluated metrics"): 124 | self.assertIsInstance(ipw_res, PropensityEvaluatorScores) 125 | self.assertIsInstance(ipw_res.covariate_balance, pd.DataFrame) 126 | 
self.assertTupleEqual(ipw_res.covariate_balance.shape, covariate_balance_df_shape) 127 | self.assertIsInstance(ipw_res.prediction_scores, pd.DataFrame) 128 | propensity_scores_shape = (1, len(get_default_binary_metrics())) 129 | self.assertTupleEqual(ipw_res.prediction_scores.shape, propensity_scores_shape) 130 | 131 | with self.subTest("Standardization evaluated metrics"): 132 | self.assertIsInstance(std_res, pd.DataFrame) 133 | outcome_scores_shape = (3, len(get_default_regression_metrics())) # 3 = treated, control, overall 134 | self.assertTupleEqual(std_res.shape, outcome_scores_shape) 135 | -------------------------------------------------------------------------------- /causallib/tests/test_marginal_outcome.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) IBM Corp, 2019, All rights reserved 3 | Created on Aug 25, 2019 4 | 5 | @author: EHUD KARAVANI 6 | """ 7 | 8 | import unittest 9 | import pandas as pd 10 | import numpy as np 11 | from causallib.estimation import MarginalOutcomeEstimator 12 | 13 | 14 | class TestMarginalOutcomeEstimator(unittest.TestCase): 15 | @classmethod 16 | def setUpClass(cls): 17 | cls.X = pd.DataFrame([[1, 1, 0, 0, 1, 0, 0, 0, 1, 1]]) 18 | cls.y = pd.Series([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) 19 | cls.a = pd.Series([0, 0, 0, 0, 1, 0, 1, 1, 1, 1]) 20 | 21 | def setUp(self): 22 | self.model = MarginalOutcomeEstimator(learner=None) 23 | 24 | def test_fit_return(self): 25 | model = self.model.fit(self.X, self.a, self.y) 26 | self.assertTrue(isinstance(model, MarginalOutcomeEstimator)) 27 | 28 | def test_outcome_estimation(self): 29 | self.model.fit(self.X, self.a, self.y) 30 | outcomes = self.model.estimate_population_outcome(self.X, self.a, self.y) 31 | truth = pd.Series([1 / 5, 4 / 5], index=[0, 1]) 32 | pd.testing.assert_series_equal(truth, outcomes) 33 | 34 | with self.subTest("Change covariate and see no change in estimation"): 35 | X = pd.DataFrame(np.arange(20).reshape(4, 5)) # Different values and shape 36 | outcomes = self.model.estimate_population_outcome(X, self.a, self.y) 37 | truth = pd.Series([1/5, 4/5], index=[0, 1]) 38 | pd.testing.assert_series_equal(truth, outcomes) 39 | 40 | 41 | 42 | -------------------------------------------------------------------------------- /causallib/tests/test_overlap_weights.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright 2021 IBM Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | 16 | Created on Jun 09, 2021 17 | 18 | """ 19 | 20 | import unittest 21 | # from unittest import TestCase 22 | 23 | # import numpy as np 24 | import pandas as pd 25 | from sklearn.datasets import make_classification 26 | from sklearn.linear_model import LogisticRegression 27 | from causallib.estimation.overlap_weights import OverlapWeights 28 | 29 | 30 | class TestOverlapWeights(unittest.TestCase): 31 | @classmethod 32 | def setUpClass(cls): 33 | # Data: 34 | X, a = make_classification(n_features=1, n_informative=1, n_redundant=0, n_repeated=0, n_classes=2, 35 | n_clusters_per_class=1, flip_y=0.0, class_sep=10.0) 36 | cls.data_r_100 = {"X": pd.DataFrame(X), "a": pd.Series(a)} 37 | X, a = make_classification(n_features=1, n_informative=1, n_redundant=0, n_repeated=0, n_classes=2, 38 | n_clusters_per_class=1, flip_y=0.2, class_sep=10.0) 39 | cls.data_r_80 = {"X": pd.DataFrame(X), "a": pd.Series(a)} 40 | 41 | # Data that maps x=0->a=0 and x=1->a=1: 42 | X = pd.Series([0] * 50 + [1] * 50) 43 | cls.data_cat_r_100 = {"X": X.to_frame(), "a": X} 44 | 45 | # Data that maps x=0->a=0 and x=1->a=1, but 10% of x=0->a=1 and 10% of x=1->a=0: 46 | X = pd.Series([0] * 40 + [1] * 10 + [1] * 40 + [0] * 10).to_frame() 47 | a = pd.Series([0] * 50 + [1] * 50) 48 | cls.data_cat_r_80 = {"X": X, "a": a} 49 | 50 | # Avoids regularization of the model: 51 | cls.estimator = OverlapWeights(LogisticRegression(C=1e6, solver='lbfgs'), use_stabilized=False) 52 | 53 | def setUp(self): 54 | self.estimator.fit(self.data_r_100["X"], self.data_r_100["a"]) 55 | 56 | def test_classes_number_is_two(self): 57 | with self.subTest("OW check error arises if single class"): 58 | a = pd.Series(0, index=self.data_r_100["X"].index) 59 | with self.assertRaises(AssertionError): 60 | self.estimator.compute_weight_matrix(self.data_r_100["X"], a) 61 | with self.subTest("OW check error arises if more than two classes"): 62 | a = pd.Series([0] * 30 + [1] * 30 + [2] * 40, index=self.data_r_100["X"].index) 63 | with self.assertRaises(AssertionError): 64 | self.estimator.compute_weight_matrix(self.data_r_100["X"], a) 65 | 66 | def test_truncate_values_not_none(self): 67 | with self.assertWarns(RuntimeWarning): 68 | self.estimator.compute_weight_matrix( 69 | self.data_r_100["X"], self.data_r_100["a"], 70 | clip_min=0.2, clip_max=0.8, use_stabilized=None) 71 | 72 | def test_categorical_classes_df_col_names(self): 73 | a = pd.Series(["a"] * 50 + ["b"] * 50, index=self.data_r_100["X"].index) 74 | w = self.estimator.compute_weight_matrix(self.data_r_100["X"], a) 75 | cols_w = w.columns.values.tolist() 76 | self.assertListEqual(cols_w, ["a", "b"])  # actually compare the column names 77 | 78 | def test_ow_weights_reversed_to_propensity(self): 79 | propensity = self.estimator.learner.predict_proba(self.data_r_100["X"]) 80 | propensity = pd.DataFrame(propensity) 81 | ow_weights = self.estimator.compute_weight_matrix(self.data_r_100["X"], self.data_r_100["a"], 82 | clip_min=None, clip_max=None) 83 | propensity.columns = propensity.columns.astype(ow_weights.columns.dtype)  # Avoid column dtype assert 84 | pd.testing.assert_series_equal(propensity.loc[:, 0], ow_weights.loc[:, 1], check_names=False) 85 | pd.testing.assert_series_equal(propensity.loc[:, 1], ow_weights.loc[:, 0], check_names=False) 86 | pd.testing.assert_index_equal(propensity.columns, ow_weights.columns) 87 | -------------------------------------------------------------------------------- /causallib/tests/test_positivity_data.py: -------------------------------------------------------------------------------- 1 | import unittest 2 |
from causallib.positivity.datasets.positivity_data_simulator import make_1d_overlap_data 3 | from causallib.positivity.datasets.pizza_data_simulator import pizza 4 | 5 | 6 | class TestPositivityDataSim(unittest.TestCase): 7 | 8 | def test_get_1d_data(self): 9 | treatment_bounds = (0, 75) 10 | control_bounds = (25, 100) 11 | X, a = make_1d_overlap_data( 12 | treatment_bounds=treatment_bounds, control_bounds=control_bounds) 13 | self.assertEqual(X[a == 0].values.min(), control_bounds[0]) 14 | self.assertEqual(X[a == 0].values.max(), control_bounds[1]-1) 15 | self.assertEqual(X[a == 1].values.min(), treatment_bounds[0]) 16 | self.assertEqual(X[a == 1].values.max(), treatment_bounds[1]-1) 17 | 18 | 19 | class TestPositivityPizzaData(unittest.TestCase): 20 | def test_pizza_data(self): 21 | X, a = pizza(seed=0, n_samples=10000) 22 | self.assertAlmostEqual( 23 | X.loc[a == 0, 0].values.min(), X.loc[a == 1, 0].values.min(), 24 | delta=0.1, # depends on the density of the points 25 | ) 26 | self.assertAlmostEqual( 27 | X.loc[a == 0, 1].values.min(), X.loc[a == 1, 1].values.min(), 28 | delta=0.1, # depends on the density of the points 29 | ) 30 | self.assertAlmostEqual( 31 | X.loc[a == 0, 0].values.max(), X.loc[a == 1, 0].values.max(), 32 | delta=0.1, # depends on the density of the points 33 | ) 34 | self.assertAlmostEqual( 35 | X.loc[a == 0, 1].values.max(), X.loc[a == 1, 1].values.max(), 36 | delta=0.1, # depends on the density of the points 37 | ) 38 | 39 | def test_dimensions(self): 40 | X, a = pizza(n_dim=3, n_samples=10) 41 | self.assertEqual(3, X.shape[1]) 42 | self.assertEqual(10, X.shape[0]) 43 | self.assertEqual(10, a.shape[0]) 44 | -------------------------------------------------------------------------------- /causallib/tests/test_transformers.py: -------------------------------------------------------------------------------- 1 | """ 2 | (C) Copyright 2019 IBM Corp. 3 | 4 | Licensed under the Apache License, Version 2.0 (the "License"); 5 | you may not use this file except in compliance with the License. 6 | You may obtain a copy of the License at 7 | 8 | http://www.apache.org/licenses/LICENSE-2.0 9 | 10 | Unless required by applicable law or agreed to in writing, software 11 | distributed under the License is distributed on an "AS IS" BASIS, 12 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | See the License for the specific language governing permissions and 14 | limitations under the License. 
15 | 16 | Created on Aug 29, 2018 17 | 18 | """ 19 | import abc 20 | import unittest 21 | 22 | import pandas as pd 23 | import numpy as np 24 | 25 | from causallib.preprocessing.transformers import StandardScaler, MinMaxScaler 26 | 27 | 28 | class TestTransformers: 29 | @abc.abstractmethod 30 | def test_fit(self): 31 | pass 32 | 33 | @abc.abstractmethod 34 | def test_fit_technical(self): 35 | pass 36 | 37 | @abc.abstractmethod 38 | def test_transform(self): 39 | pass 40 | 41 | @abc.abstractmethod 42 | def test_inverse_transform(self): 43 | pass 44 | 45 | 46 | class TestStandardScaler(TestTransformers, unittest.TestCase): 47 | @classmethod 48 | def setUpClass(cls): 49 | cls.data = pd.DataFrame({"binary": [4, 5, 5, np.nan], 50 | "continuous": [0, 2, 4, np.nan]}) 51 | cls.transformer = StandardScaler(with_mean=True, with_std=True, ignore_nans=True) 52 | cls.transformer.fit(cls.data) 53 | 54 | def test_fit_technical(self): 55 | with self.subTest("Has mean_ attribute"): 56 | self.assertTrue(hasattr(self.transformer, "mean_")) 57 | 58 | with self.subTest("Has scale_ attribute"): 59 | self.assertTrue(hasattr(self.transformer, "scale_")) 60 | 61 | with self.subTest("Applied on the right number of columns"): 62 | # Should only be applied on the "continuous" column 63 | self.assertEqual(1, len(self.transformer.mean_)) 64 | self.assertEqual(1, len(self.transformer.scale_)) 65 | 66 | def test_fit(self): 67 | with self.subTest("Test mean is right"): 68 | self.assertEqual(2.0, self.transformer.mean_["continuous"]) 69 | 70 | with self.subTest("Test scale is correct"): 71 | self.assertEqual(2.0, self.transformer.scale_["continuous"]) 72 | 73 | def test_transform(self): 74 | transformed = self.transformer.transform(self.data) 75 | 76 | with self.subTest("Was not applied on binary column"): 77 | pd.testing.assert_series_equal(self.data["binary"], transformed["binary"]) 78 | 79 | with self.subTest("Result is right on the transformed column"): 80 | pd.testing.assert_series_equal(transformed["continuous"], pd.Series([-1.0, 0.0, 1.0, np.nan]), 81 | check_names=False) 82 | 83 | def test_inverse_transform(self): 84 | untransformed = self.transformer.inverse_transform(self.transformer.transform(self.data)) 85 | pd.testing.assert_frame_equal(self.data, untransformed) 86 | 87 | 88 | class TestMinMaxScaler(TestTransformers, unittest.TestCase): 89 | @classmethod 90 | def setUpClass(cls): 91 | cls.data = pd.DataFrame({"binary": [4, 5, 5, np.nan], 92 | "continuous": [0, 2, 4, np.nan]}) 93 | cls.transformer = MinMaxScaler(only_binary_features=True, ignore_nans=True) 94 | cls.transformer.fit(cls.data) 95 | 96 | def test_fit_technical(self): 97 | with self.subTest("Has min_ attribute"): 98 | self.assertTrue(hasattr(self.transformer, "min_")) 99 | 100 | with self.subTest("Has max_ attribute"): 101 | self.assertTrue(hasattr(self.transformer, "max_")) 102 | 103 | with self.subTest("Has scale_ attribute"): 104 | self.assertTrue(hasattr(self.transformer, "scale_")) 105 | 106 | with self.subTest("Applied on the right number of columns"): 107 | # Should only be applied on the "binary" column 108 | self.assertEqual(1, len(self.transformer.min_)) 109 | self.assertEqual(1, len(self.transformer.max_)) 110 | self.assertEqual(1, len(self.transformer.scale_)) 111 | 112 | def test_fit(self): 113 | with self.subTest("Test min is right"): 114 | self.assertEqual(4.0, self.transformer.min_["binary"]) 115 | 116 | with self.subTest("Test max is right"): 117 | self.assertEqual(5.0, self.transformer.max_["binary"]) 118 | 119 | with
self.subTest("Test scale is correct"): 120 | self.assertEqual(1.0, self.transformer.scale_["binary"]) 121 | 122 | def test_transform(self): 123 | transformed = self.transformer.transform(self.data) 124 | 125 | with self.subTest("Was not applied on binary column"): 126 | pd.testing.assert_series_equal(self.data["continuous"], transformed["continuous"]) 127 | 128 | with self.subTest("Result is right on the transformed column"): 129 | pd.testing.assert_series_equal(transformed["binary"], pd.Series([0.0, 1.0, 1.0, np.nan]), 130 | check_names=False) 131 | 132 | def test_inverse_transform(self): 133 | untransformed = self.transformer.inverse_transform(self.transformer.transform(self.data)) 134 | pd.testing.assert_frame_equal(self.data, untransformed) 135 | -------------------------------------------------------------------------------- /causallib/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BiomedSciAI/causallib/803a4d34eaf09980258b498631d6af15017528dc/causallib/utils/__init__.py -------------------------------------------------------------------------------- /causallib/utils/crossfit.py: -------------------------------------------------------------------------------- 1 | from sklearn.model_selection import KFold, StratifiedKFold 2 | from sklearn.utils.metaestimators import _safe_split 3 | import pandas as pd 4 | from sklearn.base import clone 5 | 6 | 7 | def cross_fitting(estimator, X, y, n_splits=5, predict_proba=False, 8 | return_estimator=True): 9 | """ 10 | 11 | Args: 12 | estimator(object): sklearn object 13 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 14 | y (pd.Series): Observed outcome of size (num_subjects,). 15 | n_splits (int): number of folds 16 | predict_proba (bool): If True, the treatment model is a classifier 17 | and use 'predict_proba', 18 | If False, use 'predict'. 19 | return_estimator (bool): If true return fitted estimators of each fold 20 | 21 | Returns: 22 | array of held-out prediction, 23 | if return estimator: 24 | a tuple of estimators on held-out-data 25 | """ 26 | 27 | cv = StratifiedKFold(n_splits=n_splits) if predict_proba else KFold( 28 | n_splits=n_splits) 29 | ret = [_fit_and_predict(clone(estimator), X, y, train, test, 30 | predict_proba=predict_proba) 31 | for train, test in cv.split(X, y)] 32 | zipped_ret = list(zip(*ret)) 33 | if return_estimator: 34 | return pd.concat(zipped_ret[0]), zipped_ret[1] 35 | else: 36 | return pd.concat(zipped_ret[0]) 37 | 38 | 39 | def _fit_and_predict(estimator, X, y, train, test, predict_proba): 40 | """ 41 | fit the estimator with the train samples and make prediction with the test data 42 | Args: 43 | estimator(object): sklearn object 44 | X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features). 45 | y (pd.Series): Observed outcome of size (num_subjects,). 46 | train: 47 | test: 48 | predict_proba (bool): If True, the treatment model is a classifier 49 | and use 'predict_proba', 50 | If False, use 'predict'. 
51 | 52 | """ 53 | X_train, y_train = _safe_split(estimator, X, y, train) 54 | X_test, _ = _safe_split(estimator, X, y, test, train) 55 | estimator.fit(X_train, y_train) 56 | if predict_proba: 57 | pred = estimator.predict_proba(X_test)[:, 1] 58 | else: 59 | pred = estimator.predict(X_test) 60 | 61 | return pd.Series(pred, index=X_test.index), estimator 62 | -------------------------------------------------------------------------------- /causallib/utils/exceptions.py: -------------------------------------------------------------------------------- 1 | 2 | class ColumnNameChangeWarning(UserWarning): 3 | """Warning that causallib renamed input name 4 | to ensure all columns are of a single type 5 | so scikit-learn>=1.2.0 is happy. 6 | 7 | See array validation: 8 | https://github.com/scikit-learn/scikit-learn/blob/8133ecaacca77f06a8c4c560f5dbbfd654f1990f/sklearn/utils/validation.py#L2271-L2280""" 9 | pass 10 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | To generate the `source/causallib.*` files run: 2 | ```bash 3 | sphinx-apidoc -o source ../causallib --separate --force 4 | ``` 5 | (from within this directory) 6 | 7 | To generate html build run: 8 | ```bash 9 | make html 10 | ``` 11 | 12 | #### requirements 13 | * [sphinx v2.1.0](http://www.sphinx-doc.org/en/master/): to generate documentation 14 | * [m2r v0.2.1](https://github.com/miyakogi/m2r): to support inline inclusion of the modules' README markdown files 15 | * [nbsphinx v0.4.2](https://nbsphinx.readthedocs.io): to support inclusion of Jupyter Notebooks inside the html 16 | documentation 17 | 18 | `requirement.txt` is a requirement file necessary for [readthedocs.org](readthedocs.org) build. 19 | Pointed by `../.readthedocs.yml` configuration file. 20 | 21 | The `source/conf.py` file also includes some arbitrary code for the automatic 22 | inclusion of README files within the documentation. 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. 
Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==4.4.0 2 | m2r2 3 | # nbsphinx 4 | sphinx-rtd-theme -------------------------------------------------------------------------------- /docs/source/causallib.analysis.rst: -------------------------------------------------------------------------------- 1 | causallib.analysis package 2 | ========================== 3 | 4 | Module contents 5 | --------------- 6 | 7 | .. automodule:: causallib.analysis 8 | :members: 9 | :undoc-members: 10 | :show-inheritance: 11 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.adversarial_balancing.adversarial_balancing.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.adversarial\_balancing.adversarial\_balancing module 2 | ====================================================================== 3 | 4 | .. automodule:: causallib.contrib.adversarial_balancing.adversarial_balancing 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.adversarial_balancing.classifier_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.adversarial\_balancing.classifier\_selection module 2 | ===================================================================== 3 | 4 | .. automodule:: causallib.contrib.adversarial_balancing.classifier_selection 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.adversarial_balancing.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.adversarial\_balancing package 2 | ================================================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.contrib.adversarial_balancing.adversarial_balancing 11 | causallib.contrib.adversarial_balancing.classifier_selection 12 | 13 | Module contents 14 | --------------- 15 | 16 | .. automodule:: causallib.contrib.adversarial_balancing 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.faissknn.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.faissknn module 2 | ================================= 3 | 4 | .. 
automodule:: causallib.contrib.faissknn 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.gen_synthetic_data.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.gen\_synthetic\_data module 2 | ================================================== 3 | 4 | .. automodule:: causallib.contrib.hemm.gen_synthetic_data 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.hemm.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.hemm module 2 | ================================== 3 | 4 | .. automodule:: causallib.contrib.hemm.hemm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.hemm_api.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.hemm\_api module 2 | ======================================= 3 | 4 | .. automodule:: causallib.contrib.hemm.hemm_api 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.hemm_metrics.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.hemm\_metrics module 2 | =========================================== 3 | 4 | .. automodule:: causallib.contrib.hemm.hemm_metrics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.hemm_outcome_models.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.hemm\_outcome\_models module 2 | =================================================== 3 | 4 | .. automodule:: causallib.contrib.hemm.hemm_outcome_models 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.hemm_utilities.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.hemm\_utilities module 2 | ============================================= 3 | 4 | .. automodule:: causallib.contrib.hemm.hemm_utilities 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.load_ihdp_data.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.hemm.load\_ihdp\_data module 2 | ============================================== 3 | 4 | .. automodule:: causallib.contrib.hemm.load_ihdp_data 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.hemm.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/contrib/hemm/README.md 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. 
toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.contrib.hemm.gen_synthetic_data 11 | causallib.contrib.hemm.hemm 12 | causallib.contrib.hemm.hemm_api 13 | causallib.contrib.hemm.hemm_metrics 14 | causallib.contrib.hemm.hemm_outcome_models 15 | causallib.contrib.hemm.hemm_utilities 16 | causallib.contrib.hemm.load_ihdp_data 17 | 18 | Module contents 19 | --------------- 20 | 21 | .. automodule:: causallib.contrib.hemm 22 | :members: 23 | :undoc-members: 24 | :show-inheritance: 25 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/contrib/README.md 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.contrib.adversarial_balancing 11 | causallib.contrib.hemm 12 | causallib.contrib.shared_sparsity_selection 13 | causallib.contrib.tests 14 | 15 | Submodules 16 | ---------- 17 | 18 | .. toctree:: 19 | :maxdepth: 4 20 | 21 | causallib.contrib.faissknn 22 | 23 | Module contents 24 | --------------- 25 | 26 | .. automodule:: causallib.contrib 27 | :members: 28 | :undoc-members: 29 | :show-inheritance: 30 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.shared_sparsity_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.shared\_sparsity\_selection package 2 | ===================================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.contrib.shared_sparsity_selection.shared_sparsity_selection 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: causallib.contrib.shared_sparsity_selection 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.shared_sparsity_selection.shared_sparsity_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.shared\_sparsity\_selection.shared\_sparsity\_selection module 2 | ================================================================================ 3 | 4 | .. automodule:: causallib.contrib.shared_sparsity_selection.shared_sparsity_selection 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.tests.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.tests package 2 | =============================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.contrib.tests.test_adversarial_balancing 11 | causallib.contrib.tests.test_hemm 12 | causallib.contrib.tests.test_shared_sparsity_selection 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: causallib.contrib.tests 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.tests.test_adversarial_balancing.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.tests.test\_adversarial\_balancing module 2 | =========================================================== 3 | 4 | .. 
automodule:: causallib.contrib.tests.test_adversarial_balancing 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.tests.test_hemm.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.tests.test\_hemm module 2 | ========================================= 3 | 4 | .. automodule:: causallib.contrib.tests.test_hemm 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.contrib.tests.test_shared_sparsity_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.contrib.tests.test\_shared\_sparsity\_selection module 2 | ================================================================ 3 | 4 | .. automodule:: causallib.contrib.tests.test_shared_sparsity_selection 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.datasets.data_loader.rst: -------------------------------------------------------------------------------- 1 | causallib.datasets.data\_loader module 2 | ====================================== 3 | 4 | .. automodule:: causallib.datasets.data_loader 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.datasets.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/datasets/README.md 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.datasets.data_loader 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: causallib.datasets 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.base_estimator.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.base\_estimator module 2 | =========================================== 3 | 4 | .. automodule:: causallib.estimation.base_estimator 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.base_weight.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.base\_weight module 2 | ======================================== 3 | 4 | .. automodule:: causallib.estimation.base_weight 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.doubly_robust.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.doubly\_robust module 2 | ========================================== 3 | 4 | .. 
automodule:: causallib.estimation.doubly_robust 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.ipw.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.ipw module 2 | =============================== 3 | 4 | .. automodule:: causallib.estimation.ipw 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.marginal_outcome.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.marginal\_outcome module 2 | ============================================= 3 | 4 | .. automodule:: causallib.estimation.marginal_outcome 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.matching.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.matching module 2 | ==================================== 3 | 4 | .. automodule:: causallib.estimation.matching 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.overlap_weights.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.overlap\_weights module 2 | ============================================ 3 | 4 | .. automodule:: causallib.estimation.overlap_weights 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.rlearner.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.rlearner module 2 | ==================================== 3 | 4 | .. automodule:: causallib.estimation.rlearner 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/estimation/README.md 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.estimation.base_estimator 11 | causallib.estimation.base_weight 12 | causallib.estimation.doubly_robust 13 | causallib.estimation.ipw 14 | causallib.estimation.marginal_outcome 15 | causallib.estimation.matching 16 | causallib.estimation.overlap_weights 17 | causallib.estimation.rlearner 18 | causallib.estimation.standardization 19 | causallib.estimation.tmle 20 | causallib.estimation.xlearner 21 | 22 | Module contents 23 | --------------- 24 | 25 | .. automodule:: causallib.estimation 26 | :members: 27 | :undoc-members: 28 | :show-inheritance: 29 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.standardization.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.standardization module 2 | =========================================== 3 | 4 | .. 
automodule:: causallib.estimation.standardization 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.tmle.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.tmle module 2 | ================================ 3 | 4 | .. automodule:: causallib.estimation.tmle 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.estimation.xlearner.rst: -------------------------------------------------------------------------------- 1 | causallib.estimation.xlearner module 2 | ==================================== 3 | 4 | .. automodule:: causallib.estimation.xlearner 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.evaluator.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.evaluator module 2 | ===================================== 3 | 4 | .. automodule:: causallib.evaluation.evaluator 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.metrics.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.metrics module 2 | =================================== 3 | 4 | .. automodule:: causallib.evaluation.metrics 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.plots.curve_data_makers.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.plots.curve\_data\_makers module 2 | ===================================================== 3 | 4 | .. automodule:: causallib.evaluation.plots.curve_data_makers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.plots.data_extractors.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.plots.data\_extractors module 2 | ================================================== 3 | 4 | .. automodule:: causallib.evaluation.plots.data_extractors 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.plots.mixins.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.plots.mixins module 2 | ======================================== 3 | 4 | .. automodule:: causallib.evaluation.plots.mixins 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.plots.plots.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.plots.plots module 2 | ======================================= 3 | 4 | .. 
automodule:: causallib.evaluation.plots.plots 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.plots.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.plots package 2 | ================================== 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.evaluation.plots.curve_data_makers 11 | causallib.evaluation.plots.data_extractors 12 | causallib.evaluation.plots.mixins 13 | causallib.evaluation.plots.plots 14 | 15 | Module contents 16 | --------------- 17 | 18 | .. automodule:: causallib.evaluation.plots 19 | :members: 20 | :undoc-members: 21 | :show-inheritance: 22 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.predictions.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.predictions module 2 | ======================================= 3 | 4 | .. automodule:: causallib.evaluation.predictions 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.predictor.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.predictor module 2 | ===================================== 3 | 4 | .. automodule:: causallib.evaluation.predictor 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.results.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.results module 2 | =================================== 3 | 4 | .. automodule:: causallib.evaluation.results 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/evaluation/README.md 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.evaluation.plots 11 | 12 | Submodules 13 | ---------- 14 | 15 | .. toctree:: 16 | :maxdepth: 4 17 | 18 | causallib.evaluation.evaluator 19 | causallib.evaluation.metrics 20 | causallib.evaluation.predictions 21 | causallib.evaluation.predictor 22 | causallib.evaluation.results 23 | causallib.evaluation.scoring 24 | 25 | Module contents 26 | --------------- 27 | 28 | .. automodule:: causallib.evaluation 29 | :members: 30 | :undoc-members: 31 | :show-inheritance: 32 | -------------------------------------------------------------------------------- /docs/source/causallib.evaluation.scoring.rst: -------------------------------------------------------------------------------- 1 | causallib.evaluation.scoring module 2 | =================================== 3 | 4 | .. 
automodule:: causallib.evaluation.scoring 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.preprocessing.confounder_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.preprocessing.confounder\_selection module 2 | ==================================================== 3 | 4 | .. automodule:: causallib.preprocessing.confounder_selection 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.preprocessing.filters.rst: -------------------------------------------------------------------------------- 1 | causallib.preprocessing.filters module 2 | ====================================== 3 | 4 | .. automodule:: causallib.preprocessing.filters 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.preprocessing.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/preprocessing/README.md 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.preprocessing.confounder_selection 11 | causallib.preprocessing.filters 12 | causallib.preprocessing.transformers 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: causallib.preprocessing 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/causallib.preprocessing.transformers.rst: -------------------------------------------------------------------------------- 1 | causallib.preprocessing.transformers module 2 | =========================================== 3 | 4 | .. automodule:: causallib.preprocessing.transformers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/README.md 3 | 4 | Subpackages 5 | ----------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.analysis 11 | causallib.contrib 12 | causallib.datasets 13 | causallib.estimation 14 | causallib.evaluation 15 | causallib.preprocessing 16 | causallib.simulation 17 | causallib.survival 18 | causallib.tests 19 | causallib.utils 20 | 21 | Module contents 22 | --------------- 23 | 24 | .. automodule:: causallib 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | -------------------------------------------------------------------------------- /docs/source/causallib.simulation.CausalSimulator3.rst: -------------------------------------------------------------------------------- 1 | causallib.simulation.CausalSimulator3 module 2 | ============================================ 3 | 4 | .. automodule:: causallib.simulation.CausalSimulator3 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.simulation.rst: -------------------------------------------------------------------------------- 1 | causallib.simulation package 2 | ============================ 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. 
toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.simulation.CausalSimulator3 11 | 12 | Module contents 13 | --------------- 14 | 15 | .. automodule:: causallib.simulation 16 | :members: 17 | :undoc-members: 18 | :show-inheritance: 19 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.base_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.base\_survival module 2 | ======================================== 3 | 4 | .. automodule:: causallib.survival.base_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.marginal_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.marginal\_survival module 2 | ============================================ 3 | 4 | .. automodule:: causallib.survival.marginal_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.regression_curve_fitter.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.regression\_curve\_fitter module 2 | =================================================== 3 | 4 | .. automodule:: causallib.survival.regression_curve_fitter 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.rst: -------------------------------------------------------------------------------- 1 | 2 | .. mdinclude:: ../../causallib/survival/README.md 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.survival.base_survival 11 | causallib.survival.marginal_survival 12 | causallib.survival.regression_curve_fitter 13 | causallib.survival.standardized_survival 14 | causallib.survival.survival_utils 15 | causallib.survival.univariate_curve_fitter 16 | causallib.survival.weighted_standardized_survival 17 | causallib.survival.weighted_survival 18 | 19 | Module contents 20 | --------------- 21 | 22 | .. automodule:: causallib.survival 23 | :members: 24 | :undoc-members: 25 | :show-inheritance: 26 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.standardized_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.standardized\_survival module 2 | ================================================ 3 | 4 | .. automodule:: causallib.survival.standardized_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.survival_utils.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.survival\_utils module 2 | ========================================= 3 | 4 | .. 
automodule:: causallib.survival.survival_utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.univariate_curve_fitter.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.univariate\_curve\_fitter module 2 | =================================================== 3 | 4 | .. automodule:: causallib.survival.univariate_curve_fitter 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.weighted_standardized_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.weighted\_standardized\_survival module 2 | ========================================================== 3 | 4 | .. automodule:: causallib.survival.weighted_standardized_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.survival.weighted_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.survival.weighted\_survival module 2 | ============================================ 3 | 4 | .. automodule:: causallib.survival.weighted_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.rst: -------------------------------------------------------------------------------- 1 | causallib.tests package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.tests.test_base_weight 11 | causallib.tests.test_causal_simulator3 12 | causallib.tests.test_confounder_selection 13 | causallib.tests.test_datasets 14 | causallib.tests.test_doublyrobust 15 | causallib.tests.test_evaluation 16 | causallib.tests.test_ipw 17 | causallib.tests.test_marginal_outcome 18 | causallib.tests.test_matching 19 | causallib.tests.test_overlap_weights 20 | causallib.tests.test_plots 21 | causallib.tests.test_rlearner 22 | causallib.tests.test_standardization 23 | causallib.tests.test_survival 24 | causallib.tests.test_tmle 25 | causallib.tests.test_transformers 26 | causallib.tests.test_utils 27 | causallib.tests.test_xlearner 28 | 29 | Module contents 30 | --------------- 31 | 32 | .. automodule:: causallib.tests 33 | :members: 34 | :undoc-members: 35 | :show-inheritance: 36 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_base_weight.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_base\_weight module 2 | ========================================= 3 | 4 | .. automodule:: causallib.tests.test_base_weight 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_causal_simulator3.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_causal\_simulator3 module 2 | =============================================== 3 | 4 | .. 
automodule:: causallib.tests.test_causal_simulator3 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_confounder_selection.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_confounder\_selection module 2 | ================================================== 3 | 4 | .. automodule:: causallib.tests.test_confounder_selection 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_datasets.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_datasets module 2 | ===================================== 3 | 4 | .. automodule:: causallib.tests.test_datasets 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_doublyrobust.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_doublyrobust module 2 | ========================================= 3 | 4 | .. automodule:: causallib.tests.test_doublyrobust 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_evaluation.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_evaluation module 2 | ======================================= 3 | 4 | .. automodule:: causallib.tests.test_evaluation 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_ipw.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_ipw module 2 | ================================ 3 | 4 | .. automodule:: causallib.tests.test_ipw 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_marginal_outcome.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_marginal\_outcome module 2 | ============================================== 3 | 4 | .. automodule:: causallib.tests.test_marginal_outcome 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_matching.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_matching module 2 | ===================================== 3 | 4 | .. automodule:: causallib.tests.test_matching 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_overlap_weights.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_overlap\_weights module 2 | ============================================= 3 | 4 | .. 
automodule:: causallib.tests.test_overlap_weights 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_plots.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_plots module 2 | ================================== 3 | 4 | .. automodule:: causallib.tests.test_plots 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_rlearner.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_rlearner module 2 | ===================================== 3 | 4 | .. automodule:: causallib.tests.test_rlearner 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_standardization.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_standardization module 2 | ============================================ 3 | 4 | .. automodule:: causallib.tests.test_standardization 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_survival.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_survival module 2 | ===================================== 3 | 4 | .. automodule:: causallib.tests.test_survival 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_tmle.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_tmle module 2 | ================================= 3 | 4 | .. automodule:: causallib.tests.test_tmle 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_transformers.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_transformers module 2 | ========================================= 3 | 4 | .. automodule:: causallib.tests.test_transformers 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_utils.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_utils module 2 | ================================== 3 | 4 | .. automodule:: causallib.tests.test_utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.tests.test_xlearner.rst: -------------------------------------------------------------------------------- 1 | causallib.tests.test\_xlearner module 2 | ===================================== 3 | 4 | .. 
automodule:: causallib.tests.test_xlearner 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.utils.crossfit.rst: -------------------------------------------------------------------------------- 1 | causallib.utils.crossfit module 2 | =============================== 3 | 4 | .. automodule:: causallib.utils.crossfit 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.utils.general_tools.rst: -------------------------------------------------------------------------------- 1 | causallib.utils.general\_tools module 2 | ===================================== 3 | 4 | .. automodule:: causallib.utils.general_tools 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/causallib.utils.rst: -------------------------------------------------------------------------------- 1 | causallib.utils package 2 | ======================= 3 | 4 | Submodules 5 | ---------- 6 | 7 | .. toctree:: 8 | :maxdepth: 4 9 | 10 | causallib.utils.crossfit 11 | causallib.utils.general_tools 12 | causallib.utils.stat_utils 13 | 14 | Module contents 15 | --------------- 16 | 17 | .. automodule:: causallib.utils 18 | :members: 19 | :undoc-members: 20 | :show-inheritance: 21 | -------------------------------------------------------------------------------- /docs/source/causallib.utils.stat_utils.rst: -------------------------------------------------------------------------------- 1 | causallib.utils.stat\_utils module 2 | ================================== 3 | 4 | .. automodule:: causallib.utils.stat_utils 5 | :members: 6 | :undoc-members: 7 | :show-inheritance: 8 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. causallib documentation master file, created by 2 | sphinx-quickstart on Thu Jun 6 11:47:18 2019. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to causallib's documentation! 7 | ===================================== 8 | .. mdinclude:: ../../README.md 9 | 10 | 11 | Examples 12 | ======== 13 | Comprehensive Jupyter Notebooks examples can be found in the `examples directory on GitHub `_. 14 | 15 | 16 | Documentation 17 | ============= 18 | .. toctree:: 19 | :titlesonly: 20 | :maxdepth: 5 21 | :glob: 22 | 23 | causallib 24 | 25 | 26 | 27 | Indices and tables 28 | ================== 29 | 30 | * :ref:`genindex` 31 | * :ref:`modindex` 32 | * :ref:`search` 33 | -------------------------------------------------------------------------------- /docs/source/modules.rst: -------------------------------------------------------------------------------- 1 | causallib 2 | ========= 3 | 4 | .. 
toctree:: 5 | :maxdepth: 4 6 | 7 | causallib 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas>=0.25.2,<3 2 | scipy>=0.19,<2 3 | statsmodels>=0.9,<1 4 | networkx>=1.1,<4 5 | numpy>=1.13,<3 6 | scikit-learn>=0.20,<2 7 | matplotlib>=2.2,<4 8 | dataclasses>=0.8;python_version < '3.7' 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | 5 | GIT_URL = "https://github.com/BiomedSciAI/causallib" 6 | 7 | this_directory = os.path.abspath(os.path.dirname(__file__)) 8 | with open(os.path.join(this_directory, 'README.md'), encoding='utf-8') as f: 9 | long_description = f.read() 10 | 11 | 12 | def get_lines(filename): 13 | with open(filename, 'r') as f: 14 | return f.read().splitlines() 15 | 16 | 17 | def get_valid_packages_from_requirement_file(file_path): 18 | lines = get_lines(file_path) 19 | # Filter out non-package lines that are legal for `pip install -r` but fail for setuptools' `require`: 20 | pkg_list = [p for p in lines if p.strip() and p.lstrip()[0].isalnum()]  # guard against empty lines 21 | return pkg_list 22 | 23 | 24 | def get_version(filename): 25 | with open(filename, 'r') as f: 26 | for line in f.readlines(): 27 | if line.startswith('__version__'): 28 | quotes_type = '"' if '"' in line else "'" 29 | version = line.split(quotes_type)[1] 30 | return version 31 | raise RuntimeError("Unable to find version string.") 32 | 33 | 34 | setup(name='causallib', 35 | version=get_version(os.path.join('causallib', '__init__.py')), 36 | # packages=find_packages(exclude=['scripts', 'data', 'tests']), 37 | packages=find_packages(), 38 | description='A Python package for flexible and modular causal inference modeling', 39 | long_description=long_description, 40 | long_description_content_type='text/markdown', 41 | url=GIT_URL, 42 | author='Causal Machine Learning for Healthcare and Life Sciences, IBM Research Israel', 43 | # author_email=None, 44 | license="Apache License 2.0", 45 | keywords="causal inference effect estimation causality", 46 | install_requires=get_valid_packages_from_requirement_file("requirements.txt"), 47 | extras_require={ 48 | 'contrib': get_valid_packages_from_requirement_file(os.path.join("causallib", "contrib", "requirements.txt")), 49 | 'docs': get_valid_packages_from_requirement_file(os.path.join("docs", "requirements.txt")) 50 | }, 51 | # include_package_data=True, 52 | package_data={ 53 | 'causallib': [os.path.join('datasets', 'data', '*/*.csv')] 54 | }, 55 | project_urls={ 56 | 'Documentation': 'https://causallib.readthedocs.io/en/latest/', 57 | 'Source Code': GIT_URL, 58 | 'Bug Tracker': GIT_URL + '/issues', 59 | }, 60 | classifiers=[ 61 | "Programming Language :: Python :: 3.6", 62 | "Programming Language :: Python :: 3.7", 63 | "Programming Language :: Python :: 3.8", 64 | "Programming Language :: Python :: 3.9", 65 | "License :: OSI Approved :: Apache Software License", 66 | "Development Status :: 4 - Beta", 67 | "Topic :: Scientific/Engineering", 68 | "Intended Audience :: Science/Research" 69 | ] 70 | ) 71 | --------------------------------------------------------------------------------
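As a quick sanity check of the two `setup.py` helpers above, here is a minimal sketch. It assumes `get_version` and `get_valid_packages_from_requirement_file` have been pasted into a standalone session (importing `setup.py` directly would execute `setup()`); the temporary file names and contents below are illustrative, not repository files:

```python
import os
import tempfile

# Hypothetical sample files, created only to exercise the helpers above.
with tempfile.TemporaryDirectory() as tmp:
    init_path = os.path.join(tmp, "__init__.py")
    with open(init_path, "w") as f:
        f.write('__version__ = "1.2.3"\n')  # hypothetical version string

    req_path = os.path.join(tmp, "requirements.txt")
    with open(req_path, "w") as f:
        # Comment and blank lines are legal for `pip install -r`
        # but should be filtered out before reaching setuptools:
        f.write("# a comment\n\npandas>=0.25.2,<3\nscikit-learn>=0.20,<2\n")

    print(get_version(init_path))
    # -> 1.2.3
    print(get_valid_packages_from_requirement_file(req_path))
    # -> ['pandas>=0.25.2,<3', 'scikit-learn>=0.20,<2']
```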