├── tests ├── __init__.py ├── unit │ ├── __init__.py │ ├── test_tmle_core.py │ ├── test_gps_classifier.py │ ├── test_core.py │ ├── test_gps_core.py │ ├── test_mediation.py │ ├── test_tmle_regressor.py │ └── test_gps_regressor.py ├── integration │ ├── __init__.py │ ├── test_tmle.py │ ├── test_mediation.py │ └── test_gps.py ├── test_helpers.py └── conftest.py ├── imgs ├── curves.png ├── cdrc │ └── CDRC.png ├── tmle_plot.png ├── welcome_plot.png ├── binary_OR_fig.png ├── full_example │ ├── BLL_dist.png │ ├── test_dist.png │ ├── lead_paint_can.jpeg │ ├── mediation_curve.png │ └── test_causal_curves.png └── mediation │ ├── diabetes_DAG.png │ └── mediation_effect.png ├── paper ├── welcome_plot.png ├── paper.bib └── paper.md ├── docs ├── modules.rst ├── requirements.txt ├── citation.rst ├── Makefile ├── make.bat ├── install.rst ├── causal_curve.rst ├── conf.py ├── GPS_Classifier.rst ├── TMLE_Regressor.rst ├── index.rst ├── GPS_Regressor.rst ├── Mediation_example.rst ├── changelog.rst ├── intro.rst ├── contribute.rst └── full_example.rst ├── codecov.yml ├── requirements.txt ├── causal_curve ├── __init__.py ├── core.py ├── gps_classifier.py ├── tmle_regressor.py ├── gps_regressor.py ├── tmle_core.py └── mediation.py ├── .travis.yml ├── LICENSE ├── .gitignore ├── setup.py ├── README.md └── .pylintrc /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/unit/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /imgs/curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/curves.png -------------------------------------------------------------------------------- /imgs/cdrc/CDRC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/cdrc/CDRC.png -------------------------------------------------------------------------------- /imgs/tmle_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/tmle_plot.png -------------------------------------------------------------------------------- /imgs/welcome_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/welcome_plot.png -------------------------------------------------------------------------------- /imgs/binary_OR_fig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/binary_OR_fig.png -------------------------------------------------------------------------------- /paper/welcome_plot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/paper/welcome_plot.png -------------------------------------------------------------------------------- /docs/modules.rst: 
-------------------------------------------------------------------------------- 1 | causal_curve 2 | ============ 3 | 4 | .. toctree:: 5 | :maxdepth: 4 6 | 7 | causal_curve 8 | -------------------------------------------------------------------------------- /imgs/full_example/BLL_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/BLL_dist.png -------------------------------------------------------------------------------- /imgs/full_example/test_dist.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/test_dist.png -------------------------------------------------------------------------------- /imgs/mediation/diabetes_DAG.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/mediation/diabetes_DAG.png -------------------------------------------------------------------------------- /imgs/mediation/mediation_effect.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/mediation/mediation_effect.png -------------------------------------------------------------------------------- /imgs/full_example/lead_paint_can.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/lead_paint_can.jpeg -------------------------------------------------------------------------------- /imgs/full_example/mediation_curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/mediation_curve.png -------------------------------------------------------------------------------- /imgs/full_example/test_causal_curves.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/test_causal_curves.png -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | status: 3 | project: 4 | default: 5 | threshold: 0% 6 | patch: 7 | default: 8 | threshold: 0% 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | coverage 3 | future 4 | joblib 5 | numpy 6 | numpydoc 7 | pandas 8 | patsy 9 | progressbar2 10 | pygam 11 | pytest 12 | python-dateutil 13 | python-utils 14 | pytz 15 | scikit-learn 16 | scipy 17 | six 18 | sphinx_rtd_theme 19 | statsmodels 20 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | coverage 3 | future 4 | joblib 5 | numpy 6 | numpydoc 7 | pandas 8 | patsy 9 | progressbar2 10 | pygam 11 | pytest 12 | python-dateutil 13 | python-utils 14 | pytz 15 | scikit-learn 16 | scipy 17 | six 18 | sphinx_rtd_theme 19 | statsmodels 20 | -------------------------------------------------------------------------------- 
/docs/citation.rst: -------------------------------------------------------------------------------- 1 | Citation 2 | ======== 3 | 4 | Please consider citing us in your academic or industry project. 5 | 6 | Kobrosly, R. W., (2020). causal-curve: A Python Causal Inference Package to Estimate 7 | Causal Dose-Response Curves. Journal of Open Source Software, 5(52), 2523, https://doi.org/10.21105/joss.02523 8 | -------------------------------------------------------------------------------- /tests/test_helpers.py: -------------------------------------------------------------------------------- 1 | """Misc helper functions for tests""" 2 | 3 | from pandas.testing import assert_frame_equal 4 | 5 | 6 | def assert_df_equal(observed_frame, expected_frame, check_less_precise): 7 | """Assert that two pandas dataframes are equal, ignoring ordering of columns.""" 8 | assert_frame_equal( 9 | observed_frame.sort_index(axis=1), 10 | expected_frame.sort_index(axis=1), 11 | check_names=True, 12 | check_less_precise=check_less_precise, 13 | ) 14 | -------------------------------------------------------------------------------- /causal_curve/__init__.py: -------------------------------------------------------------------------------- 1 | """causal_curve module""" 2 | 3 | import warnings 4 | 5 | from statsmodels.genmod.generalized_linear_model import DomainWarning 6 | 7 | from causal_curve.gps_classifier import GPS_Classifier 8 | from causal_curve.gps_regressor import GPS_Regressor 9 | 10 | from causal_curve.tmle_regressor import TMLE_Regressor 11 | from causal_curve.mediation import Mediation 12 | 13 | 14 | # Suppress statsmodel warning for gamma family GLM 15 | warnings.filterwarnings("ignore", category=DomainWarning) 16 | warnings.filterwarnings("ignore", category=UserWarning) 17 | -------------------------------------------------------------------------------- /.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | 3 | # sudo false implies containerized builds 4 | sudo: false 5 | 6 | python: 7 | - 3.6 8 | 9 | env: 10 | global: 11 | # test invocation 12 | - TESTFOLDER="tests" 13 | 14 | before_install: 15 | # Install the package dependencies with pip 16 | - pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six sphinx_rtd_theme statsmodels 17 | 18 | install: 19 | - python setup.py install 20 | 21 | script: 22 | - coverage run -m pytest $TESTFOLDER 23 | 24 | after_success: 25 | - bash <(curl -s https://codecov.io/bash) 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /tests/integration/test_tmle.py: -------------------------------------------------------------------------------- 1 | """ Integration tests of the tmle.py module """ 2 | 3 | import pandas as pd 4 | 5 | from causal_curve import TMLE_Regressor 6 | 7 | 8 | def test_full_tmle_flow(continuous_dataset_fixture): 9 | """ 10 | Tests the full flow of the TMLE tool 11 | """ 12 | 13 | tmle = TMLE_Regressor( 14 | random_seed=100, 15 | verbose=True, 16 | ) 17 | tmle.fit( 18 | T=continuous_dataset_fixture["treatment"], 19 | X=continuous_dataset_fixture[["x1", "x2"]], 20 | y=continuous_dataset_fixture["outcome"], 21 | ) 22 | tmle_results = tmle.calculate_CDRC(0.95) 23 | 24 | assert isinstance(tmle_results, pd.DataFrame) 25 | check = tmle_results.columns == [ 26 | "Treatment", 27 | "Causal_Dose_Response", 28 | "Lower_CI", 29 | "Upper_CI", 30 | ] 31 | assert check.all() 32 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Roni Kobrosly 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /tests/integration/test_mediation.py: -------------------------------------------------------------------------------- 1 | """ Integration tests of the mediation.py module """ 2 | 3 | import pandas as pd 4 | 5 | from causal_curve import Mediation 6 | 7 | 8 | def test_full_mediation_flow(mediation_fixture): 9 | """ 10 | Tests the full flow of the Mediation tool 11 | """ 12 | 13 | med = Mediation( 14 | treatment_grid_num=10, 15 | lower_grid_constraint=0.01, 16 | upper_grid_constraint=0.99, 17 | bootstrap_draws=100, 18 | bootstrap_replicates=50, 19 | spline_order=3, 20 | n_splines=5, 21 | lambda_=0.5, 22 | max_iter=20, 23 | random_seed=None, 24 | verbose=True, 25 | ) 26 | med.fit( 27 | T=mediation_fixture["treatment"], 28 | M=mediation_fixture["mediator"], 29 | y=mediation_fixture["outcome"], 30 | ) 31 | 32 | med_results = med.calculate_mediation(0.95) 33 | 34 | assert isinstance(med_results, pd.DataFrame) 35 | check = med_results.columns == [ 36 | "Treatment_Value", 37 | "Proportion_Direct_Effect", 38 | "Proportion_Indirect_Effect", 39 | ] 40 | assert check.all() 41 | -------------------------------------------------------------------------------- /tests/unit/test_tmle_core.py: -------------------------------------------------------------------------------- 1 | """ Unit tests of the tmle.py module """ 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from causal_curve.tmle_core import TMLE_Core 7 | 8 | 9 | def test_tmle_fit(continuous_dataset_fixture): 10 | """ 11 | Tests the fit method of the TMLE tool 12 | """ 13 | 14 | tmle = TMLE_Core( 15 | random_seed=100, 16 | verbose=True, 17 | ) 18 | tmle.fit( 19 | T=continuous_dataset_fixture["treatment"], 20 | X=continuous_dataset_fixture[["x1", "x2"]], 21 | y=continuous_dataset_fixture["outcome"], 22 | ) 23 | 24 | assert tmle.num_rows == 500 25 | assert tmle.fully_expanded_t_and_x.shape == (50500, 3) 26 | 27 | 28 | def test_bad_param_calculate_CDRC_TMLE(TMLE_fitted_model_continuous_fixture): 29 | """ 30 | Tests the TMLE `calculate_CDRC` when the `ci` param is bad 31 | """ 32 | 33 | with pytest.raises(Exception) as bad: 34 | observed_result = TMLE_fitted_model_continuous_fixture.calculate_CDRC( 35 | np.array([50]), ci={"param": 0.95} 36 | ) 37 | 38 | with pytest.raises(Exception) as bad: 39 | observed_result = TMLE_fitted_model_continuous_fixture.calculate_CDRC( 40 | np.array([50]), ci=1.05 41 | ) 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.DS_Store 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | pip-wheel-metadata/ 20 | share/python-wheels/ 21 | *.egg-info/ 22 | .installed.cfg 23 | *.egg 24 | MANIFEST 25 | *.manifest 26 | *.spec 27 | pip-log.txt 28 | pip-delete-this-directory.txt 29 | htmlcov/ 30 | .tox/ 31 | .nox/ 32 | .coverage 33 | .coverage.* 34 | .cache 35 | nosetests.xml 36 | coverage.xml 37 | *.cover 38 | *.py,cover 39 | .hypothesis/ 40 | .pytest_cache/ 41 | *.mo 42 | *.pot 43 | *.log 44 | local_settings.py 45 | db.sqlite3 46 | db.sqlite3-journal 47 | instance/ 48 | .webassets-cache 49 | .scrapy 50 | docs/_build/ 51 | target/ 52 | .ipynb_checkpoints 53 | profile_default/ 54 | ipython_config.py 55 | .python-version 56 | __pypackages__/ 57 | celerybeat-schedule 58 |
celerybeat.pid 59 | *.sage.py 60 | .env 61 | .venv 62 | env/ 63 | venv/ 64 | ENV/ 65 | env.bak/ 66 | venv.bak/ 67 | .spyderproject 68 | .spyproject 69 | .ropeproject 70 | /site 71 | .mypy_cache/ 72 | .dmypy.json 73 | dmypy.json 74 | .pyre/ 75 | .idea/ 76 | .idea_modules/ 77 | workdir 78 | reports 79 | tests/integration/api/snapshots 80 | tests/integration/api/webcache 81 | -------------------------------------------------------------------------------- /tests/unit/test_gps_classifier.py: -------------------------------------------------------------------------------- 1 | """ Unit tests for the GPS_Core class """ 2 | 3 | import numpy as np 4 | from pygam import LinearGAM 5 | import pytest 6 | 7 | from causal_curve.gps_core import GPS_Core 8 | from tests.conftest import full_continuous_example_dataset 9 | 10 | 11 | def test_predict_log_odds_method_good(GPS_fitted_model_binary_fixture): 12 | """ 13 | Tests the GPS `estimate_log_odds` method using appropriate data (with a binary outcome) 14 | """ 15 | observed_result = GPS_fitted_model_binary_fixture.estimate_log_odds(np.array([0.5])) 16 | assert isinstance(observed_result[0][0], float) 17 | assert len(observed_result[0]) == 1 18 | 19 | observed_result = GPS_fitted_model_binary_fixture.estimate_log_odds( 20 | np.array([0.5, 0.6, 0.7]) 21 | ) 22 | assert isinstance(observed_result[0][0], float) 23 | assert len(observed_result[0]) == 3 24 | 25 | 26 | def test_predict_log_odds_method_bad(GPS_fitted_model_continuous_fixture): 27 | """ 28 | Tests the GPS `estimate_log_odds` method using inappropriate data (with a continuous outcome), which should raise an exception 29 | """ 30 | with pytest.raises(Exception) as bad: 31 | observed_result = GPS_fitted_model_continuous_fixture.estimate_log_odds( 32 | np.array([50]) 33 | ) 34 | -------------------------------------------------------------------------------- /docs/install.rst: -------------------------------------------------------------------------------- 1 | .. _install: 2 | 3 | ===================================== 4 | Installation, testing and development 5 | ===================================== 6 | 7 | Dependencies 8 | ------------ 9 | 10 | causal-curve requires: 11 | 12 | - black 13 | - coverage 14 | - future 15 | - joblib 16 | - numpy 17 | - numpydoc 18 | - pandas 19 | - patsy 20 | - progressbar2 21 | - pygam 22 | - pytest 23 | - python-dateutil 24 | - python-utils 25 | - pytz 26 | - scikit-learn 27 | - scipy 28 | - six 29 | - sphinx_rtd_theme 30 | - statsmodels 31 | 32 | 33 | 34 | User installation 35 | ----------------- 36 | 37 | If you already have a working installation of numpy, pandas, pygam, scipy, and statsmodels, 38 | you can easily install causal-curve using ``pip``:: 39 | 40 | pip install causal-curve 41 | 42 | 43 | You can also get the latest version of causal-curve by cloning the repository:: 44 | 45 | git clone https://github.com/ronikobrosly/causal-curve.git 46 | cd causal-curve 47 | pip install . 48 | 49 | 50 | Testing 51 | ------- 52 | 53 | After installation, you can launch the test suite from outside the source 54 | directory using ``pytest``:: 55 | 56 | pytest 57 | 58 | 59 | Development 60 | ----------- 61 | 62 | Please reach out if you are interested in adding additional tools, 63 | or have ideas on how to improve the package!
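For contributors: to mirror the project's CI run (see ``.travis.yml``), you can also collect test
coverage while running the suite. A minimal sketch, assuming it is invoked from the repository root::

    coverage run -m pytest tests
    coverage report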
64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | with open("README.md", "r") as fh: 4 | long_description = fh.read() 5 | 6 | setuptools.setup( 7 | name="causal-curve", 8 | version="1.0.6", 9 | author="Roni Kobrosly", 10 | author_email="roni.kobrosly@gmail.com", 11 | description="A python library with tools to perform causal inference using \ 12 | observational data when the treatment of interest is continuous.", 13 | long_description=long_description, 14 | long_description_content_type="text/markdown", 15 | url="https://github.com/ronikobrosly/causal-curve", 16 | packages=setuptools.find_packages(include=['causal_curve']), 17 | classifiers=[ 18 | "Programming Language :: Python :: 3", 19 | "License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | python_requires='>=3.6', 23 | install_requires=[ 24 | 'black', 25 | 'coverage', 26 | 'future', 27 | 'joblib', 28 | 'numpy', 29 | 'numpydoc', 30 | 'pandas', 31 | 'patsy', 32 | 'progressbar2', 33 | 'pygam', 34 | 'pytest', 35 | 'python-dateutil', 36 | 'python-utils', 37 | 'pytz', 38 | 'scikit-learn', 39 | 'scipy', 40 | 'six', 41 | 'sphinx_rtd_theme', 42 | 'statsmodels' 43 | ] 44 | ) 45 | -------------------------------------------------------------------------------- /docs/causal_curve.rst: -------------------------------------------------------------------------------- 1 | causal\_curve package 2 | ===================== 3 | 4 | 5 | causal\_curve.core module 6 | ------------------------- 7 | 8 | .. automodule:: causal_curve.core 9 | :members: 10 | :undoc-members: 11 | :show-inheritance: 12 | 13 | causal\_curve.gps_core module 14 | ----------------------------- 15 | 16 | .. automodule:: causal_curve.gps_core 17 | :members: 18 | :undoc-members: 19 | :show-inheritance: 20 | 21 | causal\_curve.gps_regressor module 22 | ---------------------------------- 23 | 24 | .. automodule:: causal_curve.gps_regressor 25 | :members: 26 | :undoc-members: 27 | :show-inheritance: 28 | 29 | 30 | causal\_curve.gps_classifier module 31 | ----------------------------------- 32 | 33 | .. automodule:: causal_curve.gps_classifier 34 | :members: 35 | :undoc-members: 36 | :show-inheritance: 37 | 38 | 39 | causal\_curve.tmle_core module 40 | ------------------------------ 41 | 42 | .. automodule:: causal_curve.tmle_core 43 | :members: 44 | :undoc-members: 45 | :show-inheritance: 46 | 47 | causal\_curve.tmle_regressor module 48 | ----------------------------------- 49 | 50 | .. automodule:: causal_curve.tmle_regressor 51 | :members: 52 | :undoc-members: 53 | :show-inheritance: 54 | 55 | causal\_curve.mediation module 56 | ------------------------------ 57 | 58 | .. automodule:: causal_curve.mediation 59 | :members: 60 | :undoc-members: 61 | :show-inheritance: 62 | 63 | 64 | 65 | 66 | 67 | Module contents 68 | --------------- 69 | 70 | .. 
automodule:: causal_curve 71 | :members: 72 | :undoc-members: 73 | :show-inheritance: 74 | -------------------------------------------------------------------------------- /tests/unit/test_core.py: -------------------------------------------------------------------------------- 1 | """ General unit tests of the causal-curve package """ 2 | 3 | import numpy as np 4 | 5 | from causal_curve.core import Core 6 | 7 | 8 | def test_get_params(): 9 | """ 10 | Tests the `get_params` method of the Core base class 11 | """ 12 | 13 | core = Core() 14 | core.a = 5 15 | core.b = 10 16 | 17 | observed_results = core.get_params() 18 | 19 | assert observed_results == {"a": 5, "b": 10} 20 | 21 | 22 | def test_if_verbose_print(capfd): 23 | """ 24 | Tests the `if_verbose_print` method of the Core base class 25 | """ 26 | 27 | core = Core() 28 | core.verbose = True 29 | 30 | core.if_verbose_print("This is a test") 31 | out, err = capfd.readouterr() 32 | 33 | assert out == "This is a test\n" 34 | 35 | core.verbose = False 36 | 37 | core.if_verbose_print("This is a test") 38 | out, err = capfd.readouterr() 39 | 40 | assert out == "" 41 | 42 | 43 | def test_rand_seed_wrapper(): 44 | """ 45 | Tests the `rand_seed_wrapper` method of the Core base class 46 | """ 47 | 48 | core = Core() 49 | core.rand_seed_wrapper(123) 50 | 51 | assert np.random.get_state()[1][0] == 123 52 | 53 | 54 | def test_calculate_z_score(): 55 | """ 56 | Tests the `calculate_z_score` method of the Core base class 57 | """ 58 | 59 | core = Core() 60 | assert round(core.calculate_z_score(0.95), 3) == 1.960 61 | assert round(core.calculate_z_score(0.90), 3) == 1.645 62 | 63 | 64 | def test_clip_negatives(): 65 | """ 66 | Tests the `clip_negatives` method of the Core base class 67 | """ 68 | 69 | core = Core() 70 | assert core.clip_negatives(0.5) == 0.5 71 | assert core.clip_negatives(-1.5) == 0 72 | -------------------------------------------------------------------------------- /tests/unit/test_gps_core.py: -------------------------------------------------------------------------------- 1 | """ Unit tests for the GPS_Core class """ 2 | 3 | import numpy as np 4 | from pygam import LinearGAM 5 | import pytest 6 | 7 | from causal_curve.gps_core import GPS_Core 8 | from tests.conftest import full_continuous_example_dataset 9 | 10 | 11 | @pytest.mark.parametrize( 12 | ("df_fixture", "family"), 13 | [ 14 | (full_continuous_example_dataset, "normal"), 15 | (full_continuous_example_dataset, "lognormal"), 16 | (full_continuous_example_dataset, "gamma"), 17 | (full_continuous_example_dataset, None), 18 | ], 19 | ) 20 | def test_gps_fit(df_fixture, family): 21 | """ 22 | Tests the fit method of the GPS_Core tool 23 | """ 24 | 25 | gps = GPS_Core( 26 | gps_family=family, 27 | treatment_grid_num=10, 28 | lower_grid_constraint=0.0, 29 | upper_grid_constraint=1.0, 30 | spline_order=3, 31 | n_splines=10, 32 | max_iter=100, 33 | random_seed=100, 34 | verbose=True, 35 | ) 36 | gps.fit( 37 | T=df_fixture()["treatment"], 38 | X=df_fixture()["x1"], 39 | y=df_fixture()["outcome"], 40 | ) 41 | 42 | assert isinstance(gps.gam_results, LinearGAM) 43 | assert gps.gps.shape == (500,) 44 | 45 | 46 | def test_bad_param_calculate_CDRC(GPS_fitted_model_continuous_fixture): 47 | """ 48 | Tests the GPS `calculate_CDRC` when the `ci` param is bad 49 | """ 50 | 51 | with pytest.raises(Exception) as bad: 52 | observed_result = GPS_fitted_model_continuous_fixture.calculate_CDRC( 53 | np.array([50]), ci={"param": 0.95} 54 | ) 55 | 56 | with pytest.raises(Exception) as bad: 57 | 
observed_result = GPS_fitted_model_continuous_fixture.calculate_CDRC( 58 | np.array([50]), ci=1.05 59 | ) 60 | -------------------------------------------------------------------------------- /tests/integration/test_gps.py: -------------------------------------------------------------------------------- 1 | """ Integration tests of the gps.py module """ 2 | 3 | import pandas as pd 4 | 5 | from causal_curve import GPS_Regressor, GPS_Classifier 6 | 7 | 8 | def test_full_continuous_gps_flow(continuous_dataset_fixture): 9 | """ 10 | Tests the full flow of the GPS tool when used with a continuous outcome 11 | """ 12 | 13 | gps = GPS_Regressor( 14 | treatment_grid_num=10, 15 | lower_grid_constraint=0.0, 16 | upper_grid_constraint=1.0, 17 | spline_order=3, 18 | n_splines=10, 19 | max_iter=100, 20 | random_seed=100, 21 | verbose=True, 22 | ) 23 | gps.fit( 24 | T=continuous_dataset_fixture["treatment"], 25 | X=continuous_dataset_fixture[["x1", "x2"]], 26 | y=continuous_dataset_fixture["outcome"], 27 | ) 28 | gps_results = gps.calculate_CDRC(0.95) 29 | 30 | assert isinstance(gps_results, pd.DataFrame) 31 | check = gps_results.columns == [ 32 | "Treatment", 33 | "Causal_Dose_Response", 34 | "Lower_CI", 35 | "Upper_CI", 36 | ] 37 | assert check.all() 38 | 39 | 40 | def test_binary_gps_flow(binary_dataset_fixture): 41 | """ 42 | Tests the full flow of the GPS tool when used with a binary outcome 43 | """ 44 | 45 | gps = GPS_Classifier( 46 | gps_family="normal", 47 | treatment_grid_num=10, 48 | lower_grid_constraint=0.0, 49 | upper_grid_constraint=1.0, 50 | spline_order=3, 51 | n_splines=10, 52 | max_iter=100, 53 | random_seed=100, 54 | verbose=True, 55 | ) 56 | gps.fit( 57 | T=binary_dataset_fixture["treatment"], 58 | X=binary_dataset_fixture["x1"], 59 | y=binary_dataset_fixture["outcome"], 60 | ) 61 | gps_results = gps.calculate_CDRC(0.95) 62 | 63 | assert isinstance(gps_results, pd.DataFrame) 64 | check = gps_results.columns == [ 65 | "Treatment", 66 | "Causal_Odds_Ratio", 67 | "Lower_CI", 68 | "Upper_CI", 69 | ] 70 | assert check.all() 71 | -------------------------------------------------------------------------------- /paper/paper.bib: -------------------------------------------------------------------------------- 1 | 2 | 3 | @book{Galagate:2016, 4 | Adsurl = {https://drum.lib.umd.edu/handle/1903/18170}, 5 | Author = {{Galagate}, D.}, 6 | Title = {Causal Inference with a Continuous Treatment and Outcome: Alternative Estimators for Parametric Dose-Response function with Applications.}, 7 | Booktitle = {Causal Inference with a Continuous Treatment and Outcome: Alternative Estimators for Parametric Dose-Response function with Applications.}, 8 | Publisher = {Digital Repository at the University of Maryland}, 9 | Year = 2016 10 | } 11 | 12 | @article{Moodie:2010, 13 | author = {{Moodie}, E. and {Stephens}, D.}, 14 | title = "{Estimation of dose–response functions for longitudinal data using the generalised propensity score.}", 15 | journal = {Statistical Methods in Medical Research}, 16 | year = 2010, 17 | volume = 21, 18 | doi = {10.1177/0962280209340213}, 19 | } 20 | 21 | @book{Hirano:2004, 22 | Author = {{Hirano}, K. and {Imbens}, G.}, 23 | Booktitle = {Applied bayesian modeling and causal inference from incomplete-data perspectives, by Gelman A and Meng XL. ~Published by Wiley, Oxford, UK.}, 24 | Publisher = {Wiley}, 25 | Title = {{The propensity score with continuous treatments}}, 26 | Year = 2004 27 | } 28 | 29 | @article{van_der_Laan:2010, 30 | author = {{van der Laan}, M. 
and {Gruber}, S.}, 31 | title = "{Collaborative double robust penalized targeted maximum likelihood estimation.}", 32 | journal = {The International Journal of Biostatistics}, 33 | year = 2010, 34 | volume = 6, 35 | doi = {10.2202/1557-4679.1181}, 36 | } 37 | 38 | @article{Imai:2010, 39 | author = {{Imai}, K. and {Keele}, L. and {Tingley}, D.}, 40 | title = "{A General Approach to Causal Mediation Analysis.}", 41 | journal = {Psychological Methods}, 42 | year = 2010, 43 | volume = 15, 44 | doi = {10.1037/a0020761} 45 | } 46 | 47 | @book{Hernán:2020, 48 | Author = {{Hernán}, M. and {Robins}, J.}, 49 | Booktitle = {Causal Inference: What If.}, 50 | Publisher = {Chapman \& Hall}, 51 | Title = {{Causal Inference: What If.}}, 52 | Year = 2020 53 | } 54 | -------------------------------------------------------------------------------- /causal_curve/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Core classes (with basic methods) that will be invoked when other, model classes are defined 3 | """ 4 | 5 | import numpy as np 6 | from scipy.stats import norm 7 | 8 | 9 | class Core: 10 | """Base class for causal_curve module""" 11 | 12 | def __init__(self): 13 | pass 14 | 15 | __version__ = "1.0.6" 16 | 17 | def get_params(self): 18 | """Returns a dict of all of the object's user-facing parameters 19 | 20 | Parameters 21 | ---------- 22 | None 23 | 24 | Returns 25 | ------- 26 | dict, the object's user-facing parameter names and their values 27 | """ 28 | attrs = self.__dict__ 29 | return dict( 30 | [(k, v) for k, v in list(attrs.items()) if (k[0] != "_") and (k[-1] != "_")] 31 | ) 32 | 33 | def if_verbose_print(self, string): 34 | """Prints the input statement if verbose is set to True 35 | 36 | Parameters 37 | ---------- 38 | string: str, some string to be printed 39 | 40 | Returns 41 | ---------- 42 | None 43 | 44 | """ 45 | if self.verbose: 46 | print(string) 47 | 48 | @staticmethod 49 | def rand_seed_wrapper(random_seed=None): 50 | """Sets the random seed using numpy 51 | 52 | Parameters 53 | ---------- 54 | random_seed: int, random seed number 55 | 56 | Returns 57 | ---------- 58 | None 59 | """ 60 | if random_seed is None: 61 | pass 62 | else: 63 | np.random.seed(random_seed) 64 | 65 | @staticmethod 66 | def calculate_z_score(ci): 67 | """Calculates the critical z-score for a desired two-sided, 68 | confidence interval width. 69 | 70 | Parameters 71 | ---------- 72 | ci: float, the confidence interval width (e.g. 0.95) 73 | 74 | Returns 75 | ------- 76 | Float, critical z-score value 77 | """ 78 | return norm.ppf((1 + ci) / 2) 79 | 80 | @staticmethod 81 | def clip_negatives(number): 82 | """Helper function to clip negative numbers to zero 83 | 84 | Parameters 85 | ---------- 86 | number: int or float, any number that needs a floor at zero 87 | 88 | Returns 89 | ------- 90 | Int or float of modified value 91 | 92 | """ 93 | if number < 0: 94 | return 0 95 | return number 96 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. 
For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | 13 | import os 14 | import sys 15 | 16 | sys.path.insert(0, os.path.abspath("../")) 17 | 18 | 19 | # -- Project information ----------------------------------------------------- 20 | 21 | project = "causal_curve" 22 | copyright = "2020, Roni Kobrosly" 23 | author = "Roni Kobrosly" 24 | 25 | # The full version, including alpha/beta/rc tags 26 | release = "1.0.6" 27 | 28 | # -- General configuration --------------------------------------------------- 29 | 30 | # Add any Sphinx extension module names here, as strings. They can be 31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 32 | # ones. 33 | extensions = [ 34 | "sphinx.ext.autodoc", 35 | "sphinx.ext.autosummary", 36 | "numpydoc", 37 | ] 38 | 39 | # Add any paths that contain templates here, relative to this directory. 40 | templates_path = ["_templates"] 41 | 42 | # List of patterns, relative to source directory, that match files and 43 | # directories to ignore when looking for source files. 44 | # This pattern also affects html_static_path and html_extra_path. 45 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 46 | 47 | # The name of the Pygments (syntax highlighting) style to use. 48 | pygments_style = "sphinx" 49 | 50 | # this is needed for some reason... 51 | # see https://github.com/numpy/numpydoc/issues/69 52 | numpydoc_show_class_members = False 53 | 54 | # generate autosummary even if no references 55 | autosummary_generate = True 56 | 57 | master_doc = "index" 58 | 59 | # -- Options for HTML output ------------------------------------------------- 60 | 61 | # The theme to use for HTML and HTML Help pages. See the documentation for 62 | # a list of builtin themes. 63 | # 64 | html_theme = "sphinx_rtd_theme" 65 | 66 | # Add any paths that contain custom static files (such as style sheets) here, 67 | # relative to this directory. They are copied after the builtin static files, 68 | # so a file named "default.css" will overwrite the builtin "default.css". 69 | html_static_path = [] 70 | -------------------------------------------------------------------------------- /docs/GPS_Classifier.rst: -------------------------------------------------------------------------------- 1 | .. _GPS_Classifier: 2 | 3 | ============================================================ 4 | GPS_Classifier Tool (continuous treatments, binary outcomes) 5 | ============================================================ 6 | 7 | As with the other GPS tool, we calculate generalized propensity scores (GPS), but 8 | with the classifier we can estimate the point-by-point causal contribution of 9 | a continuous treatment to a binary outcome. The GPS_Classifier does this by 10 | estimating the log odds of a positive outcome (the odds being the probability of a positive outcome over the probability of a negative one) and the corresponding odds ratio along 11 | the entire range of treatment values: 12 | 13 | .. 
image:: ../imgs/binary_OR_fig.png 14 | 15 | 16 | Currently, the causal-curve package does not contain a TMLE implementation that is appropriate for a binary outcome, 17 | so the GPS_Classifier tool will have to suffice for this sort of outcome. 18 | 19 | This tool works much like the GPS_Regressor tool; as long as the outcome series in your dataframe contains 20 | binary integer values (e.g. 0's and 1's), the ``fit()`` method will work as expected: 21 | 22 | >>> df.head(5) # a pandas dataframe with your data 23 | X_1 X_2 Treatment Outcome 24 | 0 0.596685 0.162688 0.000039 1 25 | 1 1.014187 0.916101 0.000197 0 26 | 2 0.932859 1.328576 0.000223 0 27 | 3 1.140052 0.555203 0.000339 0 28 | 4 1.613471 0.340886 0.000438 1 29 | 30 | With this dataframe, we can now calculate the GPS to estimate the causal relationship between 31 | treatment and outcome. Let's use the default settings of the GPS_Classifier tool: 32 | 33 | >>> from causal_curve import GPS_Classifier 34 | >>> gps = GPS_Classifier() 35 | >>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome']) 36 | >>> gps_results = gps.calculate_CDRC(0.95) 37 | 38 | The ``gps_results`` object (a dataframe) now contains all of the data to produce the above plot. 39 | 40 | If you'd like to estimate the log odds at a specific point on the curve, use the 41 | ``estimate_log_odds`` method to do so. 42 | 43 | References 44 | ---------- 45 | 46 | Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative 47 | Estimators for Parametric Dose-Response function with Applications. PhD thesis, 2016. 48 | 49 | Moodie E and Stephens DA. Estimation of dose–response functions for 50 | longitudinal data using the generalised propensity score. In: Statistical Methods in 51 | Medical Research 21(2), 2010, pp.149–166. 52 | 53 | Hirano K and Imbens GW. The propensity score with continuous treatments. 54 | In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference 55 | from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84. 56 | -------------------------------------------------------------------------------- /docs/TMLE_Regressor.rst: -------------------------------------------------------------------------------- 1 | .. _TMLE_Regressor: 2 | 3 | ================================================================ 4 | TMLE_Regressor Tool (continuous treatments, continuous outcomes) 5 | ================================================================ 6 | 7 | 8 | In this example, we use this package's Targeted Maximum Likelihood Estimation (TMLE) 9 | tool to estimate the marginal causal curve of some continuous treatment on a continuous outcome, 10 | accounting for some mild confounding effects. 11 | 12 | The TMLE algorithm is doubly robust, meaning that as long as one of the two models contained 13 | within the tool (the ``g`` or ``q`` models) performs well, then the overall tool will correctly 14 | estimate the causal curve. 15 | 16 | Compared with the package's GPS methods, the TMLE tool incorporates more powerful machine learning techniques internally (gradient boosting) 17 | and produces significantly smaller confidence intervals. However, it is less computationally efficient 18 | and will take longer to run. In addition, **the treatment values provided should 19 | be roughly normally-distributed**, otherwise you may encounter internal math errors.
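Given that caveat, a quick distributional check on the treatment before fitting can save some debugging.
The following is only an illustrative sketch (it assumes your treatment lives in a dataframe column named
``df['treatment']``, as in the example below, and the 0.05 cutoff is arbitrary):

>>> from scipy.stats import normaltest
>>> _, p_value = normaltest(df['treatment'])
>>> if p_value < 0.05:
...     print("Treatment looks non-normal; consider transforming it before fitting.")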
20 | 21 | Let's first generate some simple toy data: 22 | 23 | 24 | >>> import matplotlib.pyplot as plt 25 | import numpy as np 26 | import pandas as pd 27 | from causal_curve import TMLE_Regressor 28 | np.random.seed(200) 29 | 30 | >>> def generate_data(t, A, sigma, omega, noise=0, n_outliers=0, random_state=0): 31 | y = A * np.exp(-sigma * t) * np.sin(omega * t) 32 | rnd = np.random.RandomState(random_state) 33 | error = noise * rnd.randn(t.size) 34 | outliers = rnd.randint(0, t.size, n_outliers) 35 | error[outliers] *= 35 36 | return y + error 37 | 38 | >>> treatment = np.linspace(0, 10, 1000) 39 | outcome = generate_data( 40 | t = treatment, 41 | A = 2, 42 | sigma = 0.1, 43 | omega = (0.1 * 2 * np.pi), 44 | noise = 0.1, 45 | n_outliers = 5 46 | ) 47 | x1 = np.random.uniform(0,10,1000) 48 | x2 = (np.random.uniform(0,10,1000) * 3) 49 | 50 | >>> df = pd.DataFrame( 51 | { 52 | 'x1': x1, 53 | 'x2': x2, 54 | 'treatment': treatment, 55 | 'outcome': outcome 56 | } 57 | ) 58 | 59 | 60 | All we do now is employ the TMLE_Regressor class, with mostly default settings: 61 | 62 | 63 | >>> from causal_curve import TMLE_Regressor 64 | tmle = TMLE_Regressor( 65 | random_seed=111, 66 | bandwidth=10 67 | ) 68 | 69 | >>> tmle.fit(T = df['treatment'], X = df[['x1', 'x2']], y = df['outcome']) 70 | tmle_results = tmle.calculate_CDRC(0.95) 71 | 72 | The resulting ``tmle_results`` dataframe contains all of the data you need to generate the following plot: 73 | 74 | .. image:: ../imgs/tmle_plot.png 75 | 76 | To generate user-specified points along the curve, use the ``point_estimate`` and ``point_estimate_interval`` methods: 77 | 78 | >>> tmle.point_estimate(np.array([5.5])) 79 | tmle.point_estimate_interval(np.array([5.5])) 80 | 81 | 82 | References 83 | ---------- 84 | 85 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation 86 | of continuous treatment effects. Journal of the Royal Statistical Society, Series B. 79(4), 2017, pp.1229-1245. 87 | 88 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: U.C. Berkeley Division of 89 | Biostatistics Working Paper Series, 2006. 90 | 91 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted 92 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010. 93 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to causal-curve's documentation! 2 | ======================================== 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :hidden: 8 | :caption: Getting Started 9 | 10 | intro 11 | install 12 | contribute 13 | 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | :hidden: 18 | :caption: End-to-end demonstration 19 | 20 | full_example 21 | 22 | 23 | .. toctree:: 24 | :maxdepth: 1 25 | :hidden: 26 | :caption: Package Tools 27 | 28 | GPS_Regressor 29 | GPS_Classifier 30 | TMLE_Regressor 31 | Mediation_example 32 | 33 | .. toctree:: 34 | :maxdepth: 1 35 | :hidden: 36 | :caption: Module details 37 | 38 | modules 39 | 40 | 41 | .. toctree:: 42 | :maxdepth: 1 43 | :hidden: 44 | :caption: Additional Information 45 | 46 | changelog 47 | citation 48 | 49 | 50 | .. toctree:: 51 | :maxdepth: 2 52 | :caption: Contents: 53 | 54 | 55 | **causal-curve** is a Python package with tools to perform causal inference 56 | when the treatment of interest is continuous. 57 | 58 | .. 
image:: ../imgs/welcome_plot.png 59 | 60 | 61 | Summary 62 | ------- 63 | 64 | (**Version 1.0.0 released in Jan 2021!**) 65 | 66 | There are many available methods to perform causal inference when your intervention of interest is binary, 67 | but few methods exist to handle continuous treatments. This is unfortunate because there are many 68 | scenarios (in industry and research) where these methods would be useful. This library attempts to 69 | address this gap, providing tools to estimate causal curves (AKA causal dose-response curves). 70 | Both continuous and binary outcomes can be modeled with this package. 71 | 72 | 73 | Quick example (of the ``GPS_Regressor`` tool) 74 | --------------------------------------------- 75 | 76 | **causal-curve** uses a sklearn-like API that should feel familiar to Python machine learning users. 77 | This includes ``_Regressor`` and ``_Classifier`` models, and ``fit()`` methods. 78 | 79 | The following example estimates the causal dose-response curve (CDRC) by calculating 80 | generalized propensity scores. 81 | 82 | >>> from causal_curve import GPS_Regressor 83 | >>> import numpy as np 84 | 85 | >>> gps = GPS_Regressor(treatment_grid_num = 200, random_seed = 512) 86 | 87 | >>> df # a pandas dataframe with your data 88 | X_1 X_2 Treatment Outcome 89 | 0 0.596685 0.162688 0.000039 -0.270533 90 | 1 1.014187 0.916101 0.000197 -0.266979 91 | 2 0.932859 1.328576 0.000223 1.921979 92 | 3 1.140052 0.555203 0.000339 1.461526 93 | 4 1.613471 0.340886 0.000438 2.064511 94 | 95 | >>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome']) 96 | >>> gps_results = gps.calculate_CDRC(ci = 0.95) 97 | >>> gps_point = gps.point_estimate(np.array([0.0003])) 98 | >>> gps_point_interval = gps.point_estimate_interval(np.array([0.0003]), ci = 0.95) 99 | 100 | 1. First we import the `GPS_Regressor` class. 101 | 102 | 2. Then we instantiate the class, providing any of the optional parameters. 103 | 104 | 3. Prepare and organize your treatment, covariate, and outcome data into a pandas dataframe. 105 | 106 | 4. Fit the model to your data by calling the ``.fit()`` method. 107 | 108 | 5. Estimate the points of the causal curve (along with 95% confidence interval bounds) with the ``.calculate_CDRC()`` method. 109 | 110 | 6. Generate point estimates along the causal curve with the ``.point_estimate()``, ``.point_estimate_interval()``, and ``.estimate_log_odds()`` methods. 111 | 112 | 7. Explore or plot your results! 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # causal-curve 2 | 3 | [![build status](http://img.shields.io/travis/ronikobrosly/causal-curve/main.svg?style=flat)](https://travis-ci.org/ronikobrosly/causal-curve) 4 | [![codecov](https://codecov.io/gh/ronikobrosly/causal-curve/branch/main/graph/badge.svg)](https://codecov.io/gh/ronikobrosly/causal-curve) 5 | [![DOI](https://zenodo.org/badge/256017107.svg)](https://zenodo.org/badge/latestdoi/256017107) 6 | 7 | Python tools to perform causal inference when the treatment of interest is continuous. 8 | 9 | 10 |

11 | 12 |

13 | 14 | 15 | 16 | ## Table of Contents 17 | 18 | - [Overview](#overview) 19 | - [Installation](#installation) 20 | - [Documentation](#documentation) 21 | - [Contributing](#contributing) 22 | - [Citation](#citation) 23 | - [References](#references) 24 | 25 | ## Overview 26 | 27 | (**Version 1.0.0 released in January 2021!**) 28 | 29 | There are many implemented methods to perform causal inference when your intervention of interest is binary, 30 | but few methods exist to handle continuous treatments. 31 | 32 | This is unfortunate because there are many scenarios (in industry and research) where these methods would be useful. 33 | For example, when you would like to: 34 | 35 | * Estimate the causal response to increasing or decreasing the price of a product across a wide range. 36 | * Understand how the number of minutes per week of aerobic exercise causes positive health outcomes. 37 | * Estimate how decreasing order wait time will impact customer satisfaction, after controlling for confounding effects. 38 | * Estimate how changing neighborhood income inequality (Gini index) could be causally related to neighborhood crime rate. 39 | 40 | This library attempts to address this gap, providing tools to estimate causal curves (AKA causal dose-response curves). 41 | Both continuous and binary outcomes can be modeled against a continuous treatment. 42 | 43 | ## Installation 44 | 45 | Available via PyPI: 46 | 47 | `pip install causal-curve` 48 | 49 | You can also get the latest version of causal-curve by cloning the repository: 50 | 51 | ``` 52 | git clone -b main https://github.com/ronikobrosly/causal-curve.git 53 | cd causal-curve 54 | pip install . 55 | ``` 56 | 57 | ## Documentation 58 | 59 | [Documentation, tutorials, and examples are available at readthedocs.org](https://causal-curve.readthedocs.io/en/latest/) 60 | 61 | 62 | ## Contributing 63 | 64 | Your help is absolutely welcome! Please do reach out or create a feature branch! 65 | 66 | ## Citation 67 | 68 | Kobrosly, R. W., (2020). causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves. Journal of Open Source Software, 5(52), 2523, [https://doi.org/10.21105/joss.02523](https://doi.org/10.21105/joss.02523) 69 | 70 | ## References 71 | 72 | Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative 73 | Estimators for Parametric Dose-Response function with Applications. PhD thesis, 2016. 74 | 75 | Hirano K and Imbens GW. The propensity score with continuous treatments. 76 | In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference 77 | from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84. 78 | 79 | Imai K, Keele L, Tingley D. A General Approach to Causal Mediation Analysis. Psychological 80 | Methods. 15(4), 2010, pp.309–334. 81 | 82 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation 83 | of continuous treatment effects. Journal of the Royal Statistical Society, Series B. 79(4), 2017, pp.1229-1245. 84 | 85 | Moodie E and Stephens DA. Estimation of dose–response functions for 86 | longitudinal data using the generalised propensity score. In: Statistical Methods in 87 | Medical Research 21(2), 2010, pp.149–166. 88 | 89 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted 90 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010. 91 | 92 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: U.C.
Berkeley Division of 93 | Biostatistics Working Paper Series, 2006. 94 | -------------------------------------------------------------------------------- /tests/unit/test_mediation.py: -------------------------------------------------------------------------------- 1 | """ Unit tests of the Mediation.py module """ 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from causal_curve import Mediation 7 | 8 | 9 | def test_mediation_fit(mediation_fixture): 10 | """ 11 | Tests the fit method of Mediation tool 12 | """ 13 | 14 | med = Mediation( 15 | treatment_grid_num=10, 16 | lower_grid_constraint=0.01, 17 | upper_grid_constraint=0.99, 18 | bootstrap_draws=100, 19 | bootstrap_replicates=50, 20 | spline_order=3, 21 | n_splines=5, 22 | lambda_=0.5, 23 | max_iter=20, 24 | random_seed=None, 25 | verbose=True, 26 | ) 27 | med.fit( 28 | T=mediation_fixture["treatment"], 29 | M=mediation_fixture["mediator"], 30 | y=mediation_fixture["outcome"], 31 | ) 32 | 33 | assert len(med.final_bootstrap_results) == 9 34 | 35 | 36 | @pytest.mark.parametrize( 37 | ( 38 | "treatment_grid_num", 39 | "lower_grid_constraint", 40 | "upper_grid_constraint", 41 | "bootstrap_draws", 42 | "bootstrap_replicates", 43 | "spline_order", 44 | "n_splines", 45 | "lambda_", 46 | "max_iter", 47 | "random_seed", 48 | "verbose", 49 | ), 50 | [ 51 | (10.5, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 52 | (0, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 53 | (1e6, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 54 | (10, "hehe", 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 55 | (10, -1.0, 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 56 | (10, 1.5, 0.99, 10, 10, 3, 5, 0.5, 100, None, True), 57 | (10, 0.1, "hehe", 10, 10, 3, 5, 0.5, 100, None, True), 58 | (10, 0.1, -1.0, 10, 10, 3, 5, 0.5, 100, None, True), 59 | (10, 0.1, 1.5, 10, 10, 3, 5, 0.5, 100, None, True), 60 | (10, 0.1, 0.9, 10.5, 10, 3, 5, 0.5, 100, None, True), 61 | (10, 0.1, 0.9, -2, 10, 3, 5, 0.5, 100, None, True), 62 | (10, 0.1, 0.9, 10000000, 10, 3, 5, 0.5, 100, None, True), 63 | (10, 0.1, 0.9, 100, "10", 3, 5, 0.5, 100, None, True), 64 | (10, 0.1, 0.9, 100, -1, 3, 5, 0.5, 100, None, True), 65 | (10, 0.1, 0.9, 100, 10000000, 3, 5, 0.5, 100, None, True), 66 | (10, 0.1, 0.9, 100, 200, "3", 5, 0.5, 100, None, True), 67 | (10, 0.1, 0.9, 100, 200, 1, 5, 0.5, 100, None, True), 68 | (10, 0.1, 0.9, 100, 200, 1e6, 5, 0.5, 100, None, True), 69 | (10, 0.1, 0.9, 100, 200, 5, "10", 0.5, 100, None, True), 70 | (10, 0.1, 0.9, 100, 200, 5, 1, 0.5, 100, None, True), 71 | (10, 0.1, 0.9, 100, 200, 5, 10, "0.5", 100, None, True), 72 | (10, 0.1, 0.9, 100, 200, 5, 10, -0.5, 100, None, True), 73 | (10, 0.1, 0.9, 100, 200, 5, 10, 1e7, 100, None, True), 74 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, "100", None, True), 75 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 1, None, True), 76 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 1e8, None, True), 77 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, "None", True), 78 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, -5, True), 79 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, 123, "True"), 80 | (10, 30.0, 10.0, 100, 200, 5, 10, 1, 100, None, True), 81 | ], 82 | ) 83 | def test_bad_mediation_instantiation( 84 | treatment_grid_num, 85 | lower_grid_constraint, 86 | upper_grid_constraint, 87 | bootstrap_draws, 88 | bootstrap_replicates, 89 | spline_order, 90 | n_splines, 91 | lambda_, 92 | max_iter, 93 | random_seed, 94 | verbose, 95 | ): 96 | with pytest.raises(Exception) as bad: 97 | Mediation( 98 | treatment_grid_num=treatment_grid_num, 99 | lower_grid_constraint=lower_grid_constraint, 
100 | upper_grid_constraint=upper_grid_constraint, 101 | bootstrap_draws=bootstrap_draws, 102 | bootstrap_replicates=bootstrap_replicates, 103 | spline_order=spline_order, 104 | n_splines=n_splines, 105 | lambda_=lambda_, 106 | max_iter=max_iter, 107 | random_seed=random_seed, 108 | verbose=verbose, 109 | ) 110 | -------------------------------------------------------------------------------- /tests/unit/test_tmle_regressor.py: -------------------------------------------------------------------------------- 1 | """ Unit tests of the tmle.py module """ 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from causal_curve import TMLE_Regressor 7 | 8 | 9 | def test_point_estimate_method_good(TMLE_fitted_model_continuous_fixture): 10 | """ 11 | Tests the TMLE `point_estimate` method using appropriate data (with a continuous outcome) 12 | """ 13 | 14 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate( 15 | np.array([50]) 16 | ) 17 | assert isinstance(observed_result[0][0], float) 18 | assert len(observed_result[0]) == 1 19 | 20 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate( 21 | np.array([40, 50, 60]) 22 | ) 23 | assert isinstance(observed_result[0][0], float) 24 | assert len(observed_result[0]) == 3 25 | 26 | 27 | def test_point_estimate_interval_method_good(TMLE_fitted_model_continuous_fixture): 28 | """ 29 | Tests the TMLE `point_estimate_interval` method using appropriate data (with a continuous outcome) 30 | """ 31 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate_interval( 32 | np.array([50]) 33 | ) 34 | assert isinstance(observed_result[0][0], float) 35 | assert observed_result.shape == (1, 2) 36 | 37 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate_interval( 38 | np.array([40, 50, 60]) 39 | ) 40 | assert isinstance(observed_result[0][0], float) 41 | assert observed_result.shape == (3, 2) 42 | 43 | 44 | @pytest.mark.parametrize( 45 | ( 46 | "treatment_grid_num", 47 | "lower_grid_constraint", 48 | "upper_grid_constraint", 49 | "n_estimators", 50 | "learning_rate", 51 | "max_depth", 52 | "bandwidth", 53 | "random_seed", 54 | "verbose", 55 | ), 56 | [ 57 | # treatment_grid_num 58 | (100.0, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False), 59 | ("100.0", 0.01, 0.99, 200, 0.01, 3, 0.5, None, False), 60 | (2, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False), 61 | (500000, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False), 62 | # lower_grid_constraint 63 | (100, {0.01: "a"}, 0.99, 200, 0.01, 3, 0.5, None, False), 64 | (100, -0.01, 0.99, 200, 0.01, 3, 0.5, None, False), 65 | (100, 6.05, 0.99, 200, 0.01, 3, 0.5, None, False), 66 | # upper_grid_constraint 67 | (100, 0.01, [1, 2, 3], 200, 0.01, 3, 0.5, None, False), 68 | (100, 0.01, -0.05, 200, 0.01, 3, 0.5, None, False), 69 | (100, 0.01, 5.99, 200, 0.01, 3, 0.5, None, False), 70 | (100, 0.9, 0.2, 200, 0.01, 3, 0.5, None, False), 71 | # n_estimators 72 | (100, 0.01, 0.99, "3.0", 0.01, 3, 0.5, None, False), 73 | (100, 0.01, 0.99, -5, 0.01, 3, 0.5, None, False), 74 | (100, 0.01, 0.99, 10000000, 0.01, 3, 0.5, None, False), 75 | # learning_rate 76 | (100, 0.01, 0.99, 200, ["a", "b"], 3, 0.5, None, False), 77 | (100, 0.01, 0.99, 200, 5000000, 3, 0.5, None, False), 78 | # max_depth 79 | (100, 0.01, 0.99, 200, 0.01, "a", 0.5, None, False), 80 | (100, 0.01, 0.99, 200, 0.01, -6, 0.5, None, False), 81 | # bandwidth 82 | (100, 0.01, 0.99, 200, 0.01, 3, "b", None, False), 83 | (100, 0.01, 0.99, 200, 0.01, 3, -10, None, False), 84 | # random seed 85 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, 
"b", False), 86 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, -10, False), 87 | # verbose 88 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, None, "Verbose"), 89 | ], 90 | ) 91 | def test_bad_tmle_instantiation( 92 | treatment_grid_num, 93 | lower_grid_constraint, 94 | upper_grid_constraint, 95 | n_estimators, 96 | learning_rate, 97 | max_depth, 98 | bandwidth, 99 | random_seed, 100 | verbose, 101 | ): 102 | with pytest.raises(Exception) as bad: 103 | TMLE_Regressor( 104 | treatment_grid_num=treatment_grid_num, 105 | lower_grid_constraint=lower_grid_constraint, 106 | upper_grid_constraint=upper_grid_constraint, 107 | n_estimators=n_estimators, 108 | learning_rate=learning_rate, 109 | max_depth=max_depth, 110 | bandwidth=bandwidth, 111 | random_seed=random_seed, 112 | verbose=verbose, 113 | ) 114 | -------------------------------------------------------------------------------- /docs/GPS_Regressor.rst: -------------------------------------------------------------------------------- 1 | .. _GPS_Regressor: 2 | 3 | ================================================================ 4 | GPS_Regressor Tool (continuous treatments, continuous outcomes) 5 | ================================================================ 6 | 7 | 8 | In this example, we use this package's GPS_Regressor tool to estimate the marginal causal curve of some 9 | continuous treatment on a continuous outcome, accounting for some mild confounding effects. 10 | To put this differently, the result of this will be an estimate of the average 11 | of each individual's dose-response to the treatment. To do this we calculate 12 | generalized propensity scores (GPS) to correct the treatment prediction of the outcome. 13 | 14 | Compared with the package's TMLE method, the GPS methods are more computationally efficient, 15 | better suited for large datasets, but produces wider confidence intervals. 16 | 17 | In this example we use simulated data originally developed by Hirano and Imbens but adapted by others 18 | (see references). The advantage of this simulated data is it allows us 19 | to compare the estimate we produce against the true, analytically-derived causal curve. 20 | 21 | Let :math:`t_i` be the treatment for the i-th unit, let :math:`x_1` and :math:`x_2` be the 22 | confounding covariates, and let :math:`y_i` be the outcome measure. We assume that the covariates 23 | and treatment are exponentially-distributed, and the treatment variable is associated with the 24 | covariates in the following way: 25 | 26 | >>> import numpy as np 27 | >>> import pandas as pd 28 | >>> from scipy.stats import expon 29 | 30 | >>> np.random.seed(333) 31 | >>> n = 5000 32 | >>> x_1 = expon.rvs(size=n, scale = 1) 33 | >>> x_2 = expon.rvs(size=n, scale = 1) 34 | >>> treatment = expon.rvs(size=n, scale = (1/(x_1 + x_2))) 35 | 36 | The GPS is given by 37 | 38 | .. math:: 39 | 40 | f(t, x_1, x_2) = (x_1 + x_2) * e^{-(x_1 + x_2) * t} 41 | 42 | If we generate the outcome variable by summing the treatment and GPS, the true causal 43 | curve is derived analytically to be: 44 | 45 | .. 
.. math::

    f(t) = t + \frac{2}{(1 + t)^3}


The following code completes the data generation:

>>> gps = ((x_1 + x_2) * np.exp(-(x_1 + x_2) * treatment))
>>> outcome = treatment + gps + np.random.normal(size = n, scale = 1)

>>> truth_func = lambda treatment: (treatment + (2/(1 + treatment)**3))
>>> vfunc = np.vectorize(truth_func)
>>> true_outcome = vfunc(treatment)

>>> df = pd.DataFrame(
>>>     {
>>>         'X_1': x_1,
>>>         'X_2': x_2,
>>>         'Treatment': treatment,
>>>         'GPS': gps,
>>>         'Outcome': outcome,
>>>         'True_outcome': true_outcome
>>>     }
>>> ).sort_values('Treatment', ascending = True)

With this dataframe, we can now calculate the GPS to estimate the causal relationship between
treatment and outcome. Let's use the default settings of the GPS_Regressor tool:

>>> from causal_curve import GPS_Regressor
>>> gps = GPS_Regressor()
>>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome'])
>>> gps_results = gps.calculate_CDRC(0.95)

You now have everything you need to produce the following plot with matplotlib. In this example, with only
mild confounding, the GPS-based estimate of the true causal curve has approximately
half the error of a simple LOESS estimate that uses only the treatment and the outcome.

.. image:: ../imgs/cdrc/CDRC.png

The GPS_Regressor tool also allows you to estimate a specific set of points along the causal curve.
Use the `point_estimate` and `point_estimate_interval` methods to produce a point estimate
and prediction interval, respectively.
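For example, continuing with the fitted model above (a minimal sketch; the treatment values
below are arbitrary, and the returned numbers will depend on your data):

>>> gps.point_estimate(np.array([0.5]))
>>> gps.point_estimate(np.array([0.5, 1.0, 2.0]))
>>> gps.point_estimate_interval(np.array([0.5, 1.0, 2.0]), ci = 0.95)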
References
----------

Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative
Estimators for Parametric Dose-Response function with Applications. PhD thesis, 2016.

Moodie E and Stephens DA. Estimation of dose–response functions for
longitudinal data using the generalised propensity score. In: Statistical Methods in
Medical Research 21(2), 2010, pp.149–166.

Hirano K and Imbens GW. The propensity score with continuous treatments.
In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference
from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84.
--------------------------------------------------------------------------------
/docs/Mediation_example.rst:
--------------------------------------------------------------------------------
.. _Mediation_example:

============================================================
Mediation Tool (continuous treatment, mediator, and outcome)
============================================================


In trying to explore the causal relationships between various elements, oftentimes you'll use
your domain knowledge to sketch out your initial ideas about the causal connections.
See the following causal DAG of the expected relationships between smoking, diabetes, obesity, age,
and mortality (Havumaki et al.):

.. image:: ../imgs/mediation/diabetes_DAG.png

At some point though, it's helpful to validate these ideas with empirical tests.
This tool provides a test that can estimate the amount of mediation that occurs between
a treatment, a purported mediator, and an outcome. In keeping with the causal curve theme,
this tool uses a test developed by Imai et al., adapted to handle a continuous treatment and
mediator.

In this example we use the following simulated data, and assume that the `mediator`
variable has been identified as a mediator through expert judgment.

>>> import numpy as np
>>> import pandas as pd

>>> np.random.seed(132)
>>> n_obs = 500

>>> treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
>>> mediator = np.random.normal(loc=70.0 + treatment, scale=8.0, size=n_obs)
>>> outcome = np.random.normal(loc=(treatment + mediator - 50), scale=10.0, size=n_obs)

>>> df = pd.DataFrame(
>>>     {
>>>         "treatment": treatment,
>>>         "mediator": mediator,
>>>         "outcome": outcome
>>>     }
>>> )


Now we can instantiate the Mediation class:

>>> from causal_curve import Mediation
>>> med = Mediation(
>>>     bootstrap_draws=100,
>>>     bootstrap_replicates=100,
>>>     spline_order=3,
>>>     n_splines=5,
>>>     verbose=True,
>>> )


We then fit the data to the `med` object:

>>> med.fit(
>>>     T=df["treatment"],
>>>     M=df["mediator"],
>>>     y=df["outcome"],
>>> )

With the internal models of the mediation test fit with data, we can now run the
`calculate_mediation` method to produce the final report:

>>> med.calculate_mediation(ci = 0.95)
>>>
>>> ----------------------------------
>>> Mean indirect effect proportion: 0.5238 (0.5141 - 0.5344)
>>>
>>> Treatment_Value   Proportion_Direct_Effect   Proportion_Indirect_Effect
>>> 35.1874           0.4743                     0.5257
>>> 41.6870           0.4638                     0.5362
>>> 44.6997           0.4611                     0.5389
>>> 47.5672           0.4745                     0.5255
>>> 50.1900           0.4701                     0.5299
>>> 52.7526           0.4775                     0.5225
>>> 56.0204           0.4727                     0.5273
>>> 60.5174           0.4940                     0.5060
>>> 66.7243           0.4982                     0.5018

The final analysis tells us that overall, the mediator is estimated to account for
around 52% (+/- 1%) of the effect of the treatment on the outcome. This indicates that
moderate mediation is occurring here. The remaining 48% occurs through a direct effect of the
treatment on the outcome.

So long as we are confident that the mediator doesn't play another role in the causal graph
(it isn't a confounder of the treatment and outcome association), this supports the idea that
the mediator is in fact a mediator.

The report also shows how this mediation effect varies as a function of the continuous treatment.
In this case, it looks like the effect is relatively flat (as expected). With a little processing
and some basic interpolation, we can plot this mediation effect:

.. image:: ../imgs/mediation/mediation_effect.png
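A minimal sketch of that processing (the `report_df` DataFrame below is hypothetical — it simply
holds the `Treatment_Value` and `Proportion_Indirect_Effect` columns of the printed report — and
matplotlib is assumed to be installed):

>>> import matplotlib.pyplot as plt
>>> from scipy.interpolate import interp1d
>>> interp = interp1d(
>>>     report_df["Treatment_Value"],
>>>     report_df["Proportion_Indirect_Effect"],
>>>     kind="cubic"
>>> )
>>> grid = np.linspace(
>>>     report_df["Treatment_Value"].min(), report_df["Treatment_Value"].max(), 100
>>> )
>>> plt.plot(grid, interp(grid))
>>> plt.xlabel("Treatment value")
>>> plt.ylabel("Proportion indirect effect")
>>> plt.show()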


References
----------

Imai K., Keele L., Tingley D. A General Approach to Causal Mediation Analysis. Psychological
Methods. 15(4), 2010, pp.309–334.

Havumaki J., Eisenberg M.C. Mathematical modeling of directed acyclic graphs to explore
competing causal mechanisms underlying epidemiological study data. medRxiv preprint.
doi: https://doi.org/10.1101/19007922. Accessed June 23, 2020.
--------------------------------------------------------------------------------
/causal_curve/gps_classifier.py:
--------------------------------------------------------------------------------
"""
Defines the Generalized Propensity Score (GPS) classifier model class
"""
from pprint import pprint

import numpy as np
from scipy.special import logit

from causal_curve.gps_core import GPS_Core


class GPS_Classifier(GPS_Core):
    """
    A GPS tool that handles binary outcomes. Inherits from the GPS_Core
    base class. See that base class's code and docstring for more details.


    Methods
    ----------

    estimate_log_odds: (self, T)
        Calculates the predicted log odds of the highest integer class. Can
        only be used when the outcome is binary.

    """

    def __init__(
        self,
        gps_family=None,
        treatment_grid_num=100,
        lower_grid_constraint=0.01,
        upper_grid_constraint=0.99,
        spline_order=3,
        n_splines=30,
        lambda_=0.5,
        max_iter=100,
        random_seed=None,
        verbose=False,
    ):

        self.gps_family = gps_family
        self.treatment_grid_num = treatment_grid_num
        self.lower_grid_constraint = lower_grid_constraint
        self.upper_grid_constraint = upper_grid_constraint
        self.spline_order = spline_order
        self.n_splines = n_splines
        self.lambda_ = lambda_
        self.max_iter = max_iter
        self.random_seed = random_seed
        self.verbose = verbose

        # Validate the params
        self._validate_init_params()
        self.rand_seed_wrapper()

        self.if_verbose_print("Using the following params for GPS model:")
        if self.verbose:
            pprint(self.get_params(), indent=4)

    def _cdrc_predictions_binary(self, ci):
        """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
        we're making predictions using the original treatment and gps_at_grid.
        To be used when the outcome of interest is binary.
        """
        # To keep track of cdrc predictions, we create an empty 3d array of shape
        # (n_samples, treatment_grid_num, 2). The last dimension is of length 2 because
        # we are going to keep track of the point estimate (log-odds) of the prediction, as well as
        # the standard error of the prediction interval (again, this is for the log odds)
        cdrc_preds = np.zeros((len(self.T), self.treatment_grid_num, 2), dtype=float)

        # Loop through each of the grid values, predict point estimate and get prediction interval
        for i in range(0, self.treatment_grid_num):

            temp_T = np.repeat(self.grid_values[i], repeats=len(self.T))
            temp_gps = self.gps_at_grid[:, i]

            temp_cdrc_preds = logit(
                self.gam_results.predict_proba(np.column_stack((temp_T, temp_gps)))
            )

            temp_cdrc_interval = logit(
                self.gam_results.confidence_intervals(
                    np.column_stack((temp_T, temp_gps)), width=ci
                )
            )

            standard_error = (
                temp_cdrc_interval[:, 1] - temp_cdrc_preds
            ) / self.calculate_z_score(ci)

            cdrc_preds[:, i, 0] = temp_cdrc_preds
            cdrc_preds[:, i, 1] = standard_error

        return np.round(cdrc_preds, 3)
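    # Note on the array returned above (an explanatory comment, not library behavior):
    # cdrc_preds[:, i, 0] holds the log-odds point estimate at grid value i, and
    # cdrc_preds[:, i, 1] holds its standard error. Because the standard error was
    # derived by dividing out the z-score, a 95% interval at grid value i can be
    # rebuilt with the usual z of ~1.96, e.g.:
    #
    #   lower = cdrc_preds[:, i, 0] - 1.96 * cdrc_preds[:, i, 1]
    #   upper = cdrc_preds[:, i, 0] + 1.96 * cdrc_preds[:, i, 1]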
    def estimate_log_odds(self, T):
        """Calculates the estimated log odds of the highest integer class. Can
        only be used when the outcome is binary. Can be estimated for a single
        data point or run in batch for many observations. Extrapolation will produce
        untrustworthy results; the provided treatment should be within
        the range of the training data.

        Parameters
        ----------
        T: Numpy array, shape (n_samples,)
            A continuous treatment variable.

        Returns
        ----------
        array: Numpy array
            Contains a set of log odds
        """
        if self.outcome_type != "binary":
            raise TypeError("Your outcome must be binary to use this function!")

        return np.apply_along_axis(self._create_log_odds, 0, T.reshape(1, -1))

    def _create_log_odds(self, T):
        """Takes a single treatment value and produces the log odds of the higher
        integer class, in the case of a binary outcome.
        """
        return logit(
            self.gam_results.predict_proba(
                np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1)
            )
        )
--------------------------------------------------------------------------------
/docs/changelog.rst:
--------------------------------------------------------------------------------
.. _changelog:

==========
Change Log
==========

Version 1.0.6
-------------
- The latest version of python black can now run. Linted tmle_core.py.

Version 1.0.5
-------------
- Removed `master` branch, replaced with `main`
- Removed all mention of the `master` branch from documentation

Version 1.0.4
-------------
- Fixed TMLE plot and code errors in documentation

Version 1.0.3
-------------
- Fixed bug with `random_seed` functionality in all tools

Version 1.0.2
-------------
- Updated end-to-end example notebook in `/examples` folder
- Fixed various class docstrings that still referenced the old v0.5.2 API
- Fixed bug where custom class input parameters weren't being used


Version 1.0.1
-------------
- Added to the TMLE overview in the docs (including a plot)


Version 1.0.0: **Major Update**
-------------------------------
- Overhaul of the TMLE tool to make it dramatically more accurate and user-friendly
- Improved TMLE example documentation
- Much like with `scikit-learn`, there are now separate model classes for predicting binary or continuous outcomes
- Updated documentation to reflect API changes
- Added more tests
- Linted with `pylint` (added `.pylintrc` file)


Version 0.5.2
-------------
- Fixed bug that prevented `causal-curve` modules from being shown in Sphinx documentation
- Augmented tests to capture more error states and improve code coverage


Version 0.5.1
-------------
- Removed working test file


Version 0.5.0
-------------
- Added new `predict`, `predict_interval`, and `predict_log_odds` methods to GPS tool
- Slight updates to docs to reflect new features


Version 0.4.1
-------------
- When using the GPS tool with a treatment with negative values, only the normal GLM family can be picked
- Added 'sphinx_rtd_theme' to dependency list in `.travis.yml` and `install.rst`
- core.py base class now has a __version__ attribute


Version 0.4.0
-------------
- Added support for binary outcomes in GPS tool
- Small changes to repo README


Version 0.3.8
-------------
- Added citation (yay!)


Version 0.3.7
-------------
- Bumped version for PyPI


Version 0.3.6
-------------
- Fixed bug in Mediation.calculate_mediation that would clip treatments < 0 or > 1
- Fixed incorrect horizontal axis labels in lead example
- Fixed typos in documentation
- Added links to resources so users could learn more about causal inference theory


Version 0.3.5
-------------
- Re-organized documentation
- Added `Introduction` section to explain the purpose of and need for the package


Version 0.3.4
-------------
- Removed XGBoost as a dependency
- Now using sklearn's gradient boosting implementation


Version 0.3.3
-------------
- Misc edits to paper and bibliography


Version 0.3.2
-------------
- Fixed random seed issue with Mediation tool
- Fixed Mediation bootstrap issue. Confidence interval bounded [0,1]
- Fixed issue with all classes not accepting non-sequential indices in pandas DataFrames/Series
- Class init checks for all classes now print cleaner errors on bad input


Version 0.3.1
-------------
- Small fixes to end-to-end example documentation
- Enlarged image in paper


Version 0.3.0
-------------
- Added full, end-to-end example of package usage to documentation
- Cleaned up documentation
- Added example folder with end-to-end notebook
- Added manuscript to paper folder


Version 0.2.4
-------------
- Strengthened unit tests


Version 0.2.3
-------------
- codecov integration


Version 0.2.2
-------------
- Travis CI integration


Version 0.2.1
-------------
- Fixed Mediation tool error / removed `tqdm` from requirements
- Misc documentation cleanup / revisions


Version 0.2.0
-------------
- Added new Mediation class
- Updated documentation to reflect this
- Added unit and integration tests for Mediation methods


Version 0.1.3
-------------
- Simplified unit and integration tests


Version 0.1.2
-------------

- Added unit and integration tests


Version 0.1.1
-------------

- setup.py fix


Version 0.1.0
-------------

- Added new TMLE class
- Updated documentation to reflect the new TMLE method
- Renamed CDRC method to the more appropriate `GPS` method
- Small docstring corrections to GPS method


Version 0.0.10
--------------

- Bug fix in GPS estimation method


Version 0.0.9
-------------

- Project created
--------------------------------------------------------------------------------
/docs/intro.rst:
--------------------------------------------------------------------------------
.. _intro:

============================
Introduction to causal-curve
============================

In academia and industry, randomized controlled experiments (or simply experiments or "A/B tests") are considered the gold standard approach for assessing the true, causal impact
of a treatment or intervention. For example:

* We want to increase the number of times per day new customers log into our business's website.
  Will it help if we send daily emails out to our customers? We take a group of 2000 new business customers; half are randomly chosen to receive daily emails, while the other half receive one email per week. We follow both groups forward in time for a month and compare each group's average number of logins per day.

However, for ethical or financial reasons, experiments may not always be feasible to carry out.

* It's not ethical to randomly assign some people to receive a possible carcinogen in pill form while others receive a sugar pill, and then see which group is more likely to develop cancer.
* It's not feasible to increase the household incomes of some New York neighborhoods, while leaving others unchanged, to see if changing a neighborhood's income inequality would improve the local crime rate.

"Causal inference" methods are a set of approaches that attempt to estimate causal effects
from observational rather than experimental data, correcting for the biases that are inherent
to analyzing observational data (e.g. confounding and selection bias) [@Hernán:2020].

As long as you have varying observational data on some treatment, your outcome of interest,
and potentially confounding variables across your units of analysis (in addition to meeting the assumptions described below),
you can essentially estimate the results of a proper experiment and make causal claims.


Interpreting the causal curve
------------------------------

Two of the methods contained within this package produce causal curves for continuous treatments
(see the GPS and TMLE methods). Both continuous and binary outcomes can be modeled
(only the `GPS_Classifier` tool can handle binary outcomes).

**Continuous outcome:**

.. image:: ../imgs/welcome_plot.png

Using the above causal curve as an example, we see that employing a treatment value between 50 and 60
causally produces the highest outcome values. We also see that
the treatment produces a smaller effect at values below or above that range. The confidence
intervals become wider on the parts of the curve where we have fewer data points (near the minimum and
maximum treatment values).

This curve differs from a simple bivariate plot of the treatment and outcome, or even a similar-looking plot
generated through standard multivariable regression modeling, in a few important ways:

* This curve represents the estimated causal effect of a treatment on an outcome, not the association between treatment and outcome.
* This curve represents a population-level effect, and should not be used to infer effects at the individual level (or whatever the unit of analysis is).
* To generate a similar-looking plot using multivariable regression, you would have to hold covariates constant, and any treatment effect that is inferred occurs within the levels of the covariates specified in the model. The causal curve averages out across all of these strata and gives us the population marginal effect.

**Binary outcome:**

.. image:: ../imgs/binary_OR_fig.png

In the case of a binary outcome, the `GPS_Classifier` tool can be used to estimate a curve of odds ratios. Every
point on the curve is relative to the lowest treatment value. The highest effect (relative to the lowest treatment value)
is around a treatment value of -1.2. At this point in the treatment, the odds of the positive class
occurring are 5.6 times higher compared with the lowest treatment value. This curve is always on
the relative scale. This is why the odds ratio for the lowest point is always 1.0: it is
relative to itself. Odds ratios are bounded [0, inf] and cannot take on a negative value. Note that
the confidence intervals at any given point on the curve aren't symmetric.
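To make the relative scale concrete, here is a minimal sketch of how such odds ratios can be
computed (the fitted `GPS_Classifier` instance `clf` and the treatment values are hypothetical;
the first value should be the lowest treatment value so it serves as the reference):

>>> import numpy as np
>>> log_odds = clf.estimate_log_odds(np.array([-2.0, -1.2])).flatten()
>>> odds_ratios = np.exp(log_odds - log_odds[0])  # first entry is 1.0 by construction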

A caution about causal inference assumptions
--------------------------------------------

There is a well-documented set of assumptions one must make to infer causal effects from
observational data. These are covered elsewhere in more detail, but briefly:

- Causes always occur before effects: The treatment variable needs to have occurred before the outcome.
- SUTVA: The treatment status of a given individual does not affect the potential outcomes of any other individuals.
- Positivity: Any individual has a positive probability of receiving all values of the treatment variable.
- Ignorability: All major confounding variables are included in the data you provide.

Violations of these assumptions will lead to biased results and incorrect conclusions!

In addition, any covariates that are included in `causal-curve` models are assumed to only
be **confounding** variables.

None of the methods provided in causal-curve rely on inference via instrumental variables; they
rely only on the data from the observed treatment, confounders, and the outcome of interest (like
the GPS example above).


References
----------

Hernán M. and Robins J. Causal Inference: What If. Chapman & Hall, 2020.

Ahern J, Hubbard A, and Galea S. Estimating the Effects of Potential Public Health Interventions
on Population Disease Burden: A Step-by-Step Illustration of Causal Inference Methods. American Journal of Epidemiology.
169(9), 2009. pp.1140–1147.
--------------------------------------------------------------------------------
/causal_curve/tmle_regressor.py:
--------------------------------------------------------------------------------
"""
Defines the Targeted Maximum Likelihood Estimation (TMLE) regressor model class
"""
from pprint import pprint

import numpy as np

from causal_curve.tmle_core import TMLE_Core


class TMLE_Regressor(TMLE_Core):
    """
    A TMLE tool that handles continuous outcomes. Inherits from the TMLE_Core
    base class. See that base class's code and docstring for more details.

    Methods
    ----------

    point_estimate: (self, T)
        Calculates point estimate within the CDRC given treatment values.
        Can only be used when outcome is continuous.

    point_estimate_interval: (self, T, ci)
        Calculates the prediction confidence interval associated with a point estimate
        within the CDRC given treatment values. Can only be used when outcome is continuous.

    """
    def __init__(
        self,
        treatment_grid_num=100,
        lower_grid_constraint=0.01,
        upper_grid_constraint=0.99,
        n_estimators=200,
        learning_rate=0.01,
        max_depth=3,
        bandwidth=0.5,
        random_seed=None,
        verbose=False,
    ):

        self.treatment_grid_num = treatment_grid_num
        self.lower_grid_constraint = lower_grid_constraint
        self.upper_grid_constraint = upper_grid_constraint
        self.n_estimators = n_estimators
        self.learning_rate = learning_rate
        self.max_depth = max_depth
        self.bandwidth = bandwidth
        self.random_seed = random_seed
        self.verbose = verbose

        # Validate the params
        self._validate_init_params()
        self.rand_seed_wrapper()

        self.if_verbose_print("Using the following params for TMLE model:")
        if self.verbose:
            pprint(self.get_params(), indent=4)

    def _cdrc_predictions_continuous(self, ci):
        """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
        we're making predictions using the original treatment against the pseudo-outcome.
        To be used when the outcome of interest is continuous.
        """

        # To keep track of cdrc predictions, we create an empty 2d array of shape
        # (treatment_grid_num, 4). The second dimension is of length 4 because
        # we are going to keep track of the treatment, the point estimate of the prediction,
        # and the lower and upper bounds of the prediction interval
        cdrc_preds = np.zeros((self.treatment_grid_num, 4), dtype=float)

        # Loop through each of the grid values, get point estimate and get estimate interval
        for i in range(0, self.treatment_grid_num):
            temp_T = self.grid_values[i]
            temp_cdrc_preds = self.final_gam.predict(temp_T)
            temp_cdrc_interval = self.final_gam.confidence_intervals(temp_T, width=ci)
            temp_cdrc_lower_bound = temp_cdrc_interval[:, 0]
            temp_cdrc_upper_bound = temp_cdrc_interval[:, 1]
            cdrc_preds[i, 0] = temp_T
            cdrc_preds[i, 1] = temp_cdrc_preds
            cdrc_preds[i, 2] = temp_cdrc_lower_bound
            cdrc_preds[i, 3] = temp_cdrc_upper_bound

        return cdrc_preds

    def point_estimate(self, T):
        """Calculates point estimate within the CDRC given treatment values.
        Can only be used when outcome is continuous. Can be estimated for a single
        data point or run in batch for many observations. Extrapolation will produce
        untrustworthy results; the provided treatment should be within
        the range of the training data.

        Parameters
        ----------
        T: Numpy array, shape (n_samples,)
            A continuous treatment variable.

        Returns
        ----------
        array: Numpy array
            Contains a set of CDRC point estimates

        """
        return np.apply_along_axis(self._create_point_estimate, 0, T.reshape(1, -1))

    def _create_point_estimate(self, T):
        """Takes a single treatment value and produces a single point estimate
        in the case of a continuous outcome.
        """
        return self.final_gam.predict(np.array([T]).reshape(1, -1))
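    # --- Usage sketch (illustrative comment, not part of the library) -----------
    # Assuming a TMLE_Regressor instance `tmle` that has been fit on data with a
    # continuous outcome:
    #
    #   import numpy as np
    #   tmle.point_estimate(np.array([40, 50, 60]))             # three CDRC point estimates
    #   tmle.point_estimate_interval(np.array([50]), ci=0.95)   # -> array of shape (1, 2)
    #
    # (These shapes mirror the assertions in tests/unit/test_tmle_regressor.py.)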
    def point_estimate_interval(self, T, ci=0.95):
        """Calculates the prediction confidence interval associated with a point estimate
        within the CDRC given treatment values. Can only be used
        when outcome is continuous. Can be estimated for a single data point
        or run in batch for many observations. Extrapolation will produce
        untrustworthy results; the provided treatment should be within
        the range of the training data.

        Parameters
        ----------
        T: Numpy array, shape (n_samples,)
            A continuous treatment variable.
        ci: float (default = 0.95)
            The desired confidence interval to produce. Default value is 0.95, corresponding
            to 95% confidence intervals. Bounded (0, 1.0).

        Returns
        ----------
        array: Numpy array
            Contains a set of CDRC prediction intervals ([lower bound, upper bound])

        """
        return np.apply_along_axis(
            self._create_point_estimate_interval, 0, T.reshape(1, -1), width=ci
        ).T.reshape(-1, 2)

    def _create_point_estimate_interval(self, T, width):
        """Takes a single treatment value and produces the confidence interval
        associated with a point estimate in the case of a continuous outcome.
        """
        return self.final_gam.prediction_intervals(
            np.array([T]).reshape(1, -1), width=width
        )
--------------------------------------------------------------------------------
/tests/unit/test_gps_regressor.py:
--------------------------------------------------------------------------------
""" Unit tests of the GPS_Regressor class """

import numpy as np
from pygam import LinearGAM
import pytest

from causal_curve import GPS_Regressor


def test_point_estimate_method_good(GPS_fitted_model_continuous_fixture):
    """
    Tests the GPS `point_estimate` method using appropriate data (with a continuous outcome)
    """

    observed_result = GPS_fitted_model_continuous_fixture.point_estimate(np.array([50]))
    assert isinstance(observed_result[0][0], float)
    assert len(observed_result[0]) == 1

    observed_result = GPS_fitted_model_continuous_fixture.point_estimate(
        np.array([40, 50, 60])
    )
    assert isinstance(observed_result[0][0], float)
    assert len(observed_result[0]) == 3


def test_point_estimate_interval_method_good(GPS_fitted_model_continuous_fixture):
    """
    Tests the GPS `point_estimate_interval` method using appropriate data
    (with a continuous outcome)
    """
    observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
        np.array([50])
    )
    assert isinstance(observed_result[0][0], float)
    assert observed_result.shape == (1, 2)

    observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
        np.array([40, 50, 60])
    )
    assert isinstance(observed_result[0][0], float)
    assert observed_result.shape == (3, 2)


def test_point_estimate_method_bad(GPS_fitted_model_continuous_fixture):
    """
    Tests that the GPS `point_estimate` method raises an exception when used
    with inappropriate data (with a binary outcome)
    """

    GPS_fitted_model_continuous_fixture.outcome_type = "binary"

    with pytest.raises(Exception) as bad:
        observed_result = GPS_fitted_model_continuous_fixture.point_estimate(
            np.array([50])
        )


def test_point_estimate_interval_method_bad(GPS_fitted_model_continuous_fixture):
    """
    Tests that the GPS `point_estimate_interval` method raises an exception when used
    with inappropriate data (with a binary outcome)
    """

    GPS_fitted_model_continuous_fixture.outcome_type = "binary"

    with pytest.raises(Exception) as bad:
        observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
            np.array([50])
        )

@pytest.mark.parametrize(
    (
        "gps_family",
        "treatment_grid_num",
        "lower_grid_constraint",
        "upper_grid_constraint",
        "spline_order",
        "n_splines",
        "lambda_",
        "max_iter",
        "random_seed",
        "verbose",
    ),
    [
        (546, 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
        ("linear", 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, "hehe", 0, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 2, 0, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 100000, 0, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 10, "hehe", 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 10, -1.0, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 10, 1.5, 1.0, 3, 10, 0.5, 100, 100, True),
        (None, 10, 0, "hehe", 3, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1.5, 3, 10, 0.5, 100, 100, True),
        (None, 100, -3.0, 0.99, 3, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 1, 3, 30, 0.5, 100, None, True),
        (None, 100, 0.01, -4.5, 3, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 5.5, 3, 30, 0.5, 100, None, True),
        (None, 100, 0.99, 0.01, 3, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 3.0, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, -2, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 30, 30, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 3, 30.0, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 3, -2, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 3, 500, 0.5, 100, None, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, 100.0, None, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, -100, None, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, 10000000000, None, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, 234.5, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, -5, True),
        (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, None, 4.0),
        (None, 10, 0, -1, 3, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, "splines", 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, 0, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, 200, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, 3, 0, 0.5, 100, 100, True),
        (None, 10, 0, 1, 3, 1e6, 0.5, 100, 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, "many", 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, 5, 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, 1e7, 100, True),
        (None, 10, 0, 1, 3, 10, 0.5, 100, "random", True),
        (None, 10, 0, 1, 3, 10, 0.5, 100, -1.5, True),
        (None, 10, 0, 1, 3, 10, 0.5, 100, 111, "True"),
        (None, 100, 0.01, 0.99, 3, 30, "lambda", 100, None, True),
        (None, 100, 0.01, 0.99, 3, 30, -1.0, 100, None, True),
        (None, 100, 0.01, 0.99, 3, 30, 2000.0, 100, None, True),
    ],
)
def test_bad_gps_instantiation(
    gps_family,
    treatment_grid_num,
    lower_grid_constraint,
    upper_grid_constraint,
    spline_order,
    n_splines,
    lambda_,
    max_iter,
    random_seed,
    verbose,
):
    """
    Tests for exceptions when the GPS class is called with bad inputs.
    """
    with pytest.raises(Exception) as bad:
        GPS_Regressor(
            gps_family=gps_family,
            treatment_grid_num=treatment_grid_num,
            lower_grid_constraint=lower_grid_constraint,
            upper_grid_constraint=upper_grid_constraint,
            spline_order=spline_order,
            n_splines=n_splines,
            lambda_=lambda_,
            max_iter=max_iter,
            random_seed=random_seed,
            verbose=verbose,
        )
--------------------------------------------------------------------------------
/causal_curve/gps_regressor.py:
--------------------------------------------------------------------------------
"""
Defines the Generalized Propensity Score (GPS) regressor model class
"""
from pprint import pprint

import numpy as np

from causal_curve.gps_core import GPS_Core


class GPS_Regressor(GPS_Core):
    """
    A GPS tool that handles continuous outcomes. Inherits from the GPS_Core
    base class. See that base class's code and docstring for more details.

    Methods
    ----------

    point_estimate: (self, T)
        Calculates point estimate within the CDRC given treatment values.
        Can only be used when outcome is continuous.

    point_estimate_interval: (self, T, ci)
        Calculates the prediction confidence interval associated with a point estimate
        within the CDRC given treatment values. Can only be used when outcome is continuous.

    """

    def __init__(
        self,
        gps_family=None,
        treatment_grid_num=100,
        lower_grid_constraint=0.01,
        upper_grid_constraint=0.99,
        spline_order=3,
        n_splines=30,
        lambda_=0.5,
        max_iter=100,
        random_seed=None,
        verbose=False,
    ):

        self.gps_family = gps_family
        self.treatment_grid_num = treatment_grid_num
        self.lower_grid_constraint = lower_grid_constraint
        self.upper_grid_constraint = upper_grid_constraint
        self.spline_order = spline_order
        self.n_splines = n_splines
        self.lambda_ = lambda_
        self.max_iter = max_iter
        self.random_seed = random_seed
        self.verbose = verbose

        # Validate the params
        self._validate_init_params()
        self.rand_seed_wrapper()

        self.if_verbose_print("Using the following params for GPS model:")
        if self.verbose:
            pprint(self.get_params(), indent=4)

    def _cdrc_predictions_continuous(self, ci):
        """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
        we're making predictions using the original treatment and gps_at_grid.
        To be used when the outcome of interest is continuous.
        """

        # To keep track of cdrc predictions, we create an empty 3d array of shape
        # (n_samples, treatment_grid_num, 3).
        # The last dimension is of length 3 because
        # we are going to keep track of the point estimate of the prediction, as well as
        # the lower and upper bounds of the prediction interval
        cdrc_preds = np.zeros((len(self.T), self.treatment_grid_num, 3), dtype=float)

        # Loop through each of the grid values, predict point estimate and get prediction interval
        for i in range(0, self.treatment_grid_num):
            temp_T = np.repeat(self.grid_values[i], repeats=len(self.T))
            temp_gps = self.gps_at_grid[:, i]
            temp_cdrc_preds = self.gam_results.predict(
                np.column_stack((temp_T, temp_gps))
            )
            temp_cdrc_interval = self.gam_results.confidence_intervals(
                np.column_stack((temp_T, temp_gps)), width=ci
            )
            temp_cdrc_lower_bound = temp_cdrc_interval[:, 0]
            temp_cdrc_upper_bound = temp_cdrc_interval[:, 1]
            cdrc_preds[:, i, 0] = temp_cdrc_preds
            cdrc_preds[:, i, 1] = temp_cdrc_lower_bound
            cdrc_preds[:, i, 2] = temp_cdrc_upper_bound

        return np.round(cdrc_preds, 3)

    def point_estimate(self, T):
        """Calculates point estimate within the CDRC given treatment values.
        Can only be used when outcome is continuous. Can be estimated for a single
        data point or run in batch for many observations. Extrapolation will produce
        untrustworthy results; the provided treatment should be within
        the range of the training data.

        Parameters
        ----------
        T: Numpy array, shape (n_samples,)
            A continuous treatment variable.

        Returns
        ----------
        array: Numpy array
            Contains a set of CDRC point estimates

        """
        if self.outcome_type != "continuous":
            raise TypeError("Your outcome must be continuous to use this function!")

        return np.apply_along_axis(self._create_point_estimate, 0, T.reshape(1, -1))

    def _create_point_estimate(self, T):
        """Takes a single treatment value and produces a single point estimate
        in the case of a continuous outcome.
        """
        return self.gam_results.predict(
            np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1)
        )

    def point_estimate_interval(self, T, ci=0.95):
        """Calculates the prediction confidence interval associated with a point estimate
        within the CDRC given treatment values. Can only be used
        when outcome is continuous. Can be estimated for a single data point
        or run in batch for many observations. Extrapolation will produce
        untrustworthy results; the provided treatment should be within
        the range of the training data.

        Parameters
        ----------
        T: Numpy array, shape (n_samples,)
            A continuous treatment variable.
        ci: float (default = 0.95)
            The desired confidence interval to produce. Default value is 0.95, corresponding
            to 95% confidence intervals. Bounded (0, 1.0).

        Returns
        ----------
        array: Numpy array
            Contains a set of CDRC prediction intervals ([lower bound, upper bound])

        """
        if self.outcome_type != "continuous":
            raise TypeError("Your outcome must be continuous to use this function!")

        return np.apply_along_axis(
            self._create_point_estimate_interval, 0, T.reshape(1, -1), width=ci
        ).T.reshape(-1, 2)

    def _create_point_estimate_interval(self, T, width):
        """Takes a single treatment value and produces the confidence interval
        associated with a point estimate in the case of a continuous outcome.
        """
        return self.gam_results.prediction_intervals(
            np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1), width=width
        )
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
"""Common fixtures for tests using pytest framework"""

import numpy as np
import pandas as pd
import pytest
from scipy.stats import norm

from causal_curve import GPS_Regressor, GPS_Classifier, TMLE_Regressor


@pytest.fixture(scope="module")
def continuous_dataset_fixture():
    """Returns full_continuous_example_dataset (with a continuous outcome)"""
    return full_continuous_example_dataset()


def full_continuous_example_dataset():
    """Example dataset with a treatment, two covariates, and a continuous outcome variable"""

    np.random.seed(500)

    n_obs = 500

    treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
    x_1 = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
    x_2 = np.random.normal(loc=0, scale=10.0, size=n_obs)
    outcome = treatment + x_1 + x_2 + np.random.normal(loc=50.0, scale=3.0, size=n_obs)

    fixture = pd.DataFrame(
        {"treatment": treatment, "x1": x_1, "x2": x_2, "outcome": outcome}
    )
    fixture.reset_index(drop=True, inplace=True)

    return fixture


@pytest.fixture(scope="module")
def binary_dataset_fixture():
    """Returns full_binary_example_dataset (with a binary outcome)"""
    return full_binary_example_dataset()


def full_binary_example_dataset():
    """Example dataset with a treatment, one covariate, and a binary outcome variable"""

    np.random.seed(500)
    treatment = np.linspace(
        start=0,
        stop=100,
        num=100,
    )
    x_1 = norm.rvs(size=100, loc=50, scale=5)
    outcome = [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        1, 1, 1, 1, 1, 0, 1, 1, 1, 1,
        1, 0, 1, 1, 1, 1, 1, 0, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    ]

    fixture = pd.DataFrame({"treatment": treatment, "x1": x_1, "outcome": outcome})
    fixture.reset_index(drop=True, inplace=True)

    return fixture


@pytest.fixture(scope="module")
def mediation_fixture():
    """Returns mediation_dataset"""
    return mediation_dataset()


def mediation_dataset():
    """Example dataset to test / demonstrate mediation with a treatment,
    a mediator, and an outcome variable"""

    np.random.seed(500)

    n_obs = 500

    treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
    mediator = np.random.normal(loc=70.0 + treatment, scale=8.0, size=n_obs)
    outcome = np.random.normal(loc=(treatment + mediator - 50), scale=10.0, size=n_obs)

    fixture = pd.DataFrame(
        {"treatment": treatment, "mediator": mediator, "outcome": outcome}
    )

    fixture.reset_index(drop=True, inplace=True)

    return fixture


@pytest.fixture(scope="module")
def GPS_fitted_model_continuous_fixture():
    """Returns a GPS model that is already fit with data with a continuous outcome"""
    return GPS_fitted_model_continuous()


def GPS_fitted_model_continuous():
    """Example GPS model that is fit with data including a continuous outcome"""

    df = full_continuous_example_dataset()

    fixture = GPS_Regressor(
        treatment_grid_num=10,
        lower_grid_constraint=0.0,
        upper_grid_constraint=1.0,
        spline_order=3,
        n_splines=10,
        max_iter=100,
        random_seed=100,
        verbose=True,
    )
    fixture.fit(
        T=df["treatment"],
        X=df["x1"],
        y=df["outcome"],
    )

    return fixture


@pytest.fixture(scope="module")
def GPS_fitted_model_binary_fixture():
    """Returns a GPS model that is already fit with data with a binary outcome"""
    return GPS_fitted_model_binary()


def GPS_fitted_model_binary():
    """Example GPS model that is fit with data including a binary outcome"""

    df = full_binary_example_dataset()

    fixture = GPS_Classifier(
        gps_family="normal",
        treatment_grid_num=10,
        lower_grid_constraint=0.0,
        upper_grid_constraint=1.0,
        spline_order=3,
        n_splines=10,
        max_iter=100,
        random_seed=100,
        verbose=True,
    )
    fixture.fit(
        T=df["treatment"],
        X=df["x1"],
        y=df["outcome"],
    )

    return fixture


@pytest.fixture(scope="module")
def TMLE_fitted_model_continuous_fixture():
    """Returns a TMLE model that is already fit with data with a continuous outcome"""
    return TMLE_fitted_model_continuous()


def TMLE_fitted_model_continuous():
    """Example TMLE model that is fit with data including a continuous outcome"""

    df = full_continuous_example_dataset()

    fixture = TMLE_Regressor(
        random_seed=100,
        verbose=True,
    )
    fixture.fit(
        T=df["treatment"],
        X=df[["x1", "x2"]],
        y=df["outcome"],
    )

    return fixture
--------------------------------------------------------------------------------
/docs/contribute.rst:
--------------------------------------------------------------------------------
.. _contribute:

==================
Contributing guide
==================

Thank you for considering contributing to causal-curve. Contributions from anyone
are welcome.
There are many ways to contribute to the package, such as
reporting bugs, adding new features, and improving the documentation. The
following sections give more details on how to contribute.

**Important links**:

- The project is hosted on GitHub: https://github.com/ronikobrosly/causal-curve


Submitting a bug report or a feature request
--------------------------------------------

If you experience a bug using causal-curve or if you would like to see a new
feature being added to the package, feel free to open an issue on GitHub:
https://github.com/ronikobrosly/causal-curve/issues

Bug report
^^^^^^^^^^

A good bug report usually contains:

- a description of the bug,
- a self-contained example to reproduce the bug if applicable,
- a description of the difference between the actual and expected results,
- the versions of the dependencies of causal-curve.

The last point can easily be done with the following commands::

    import numpy; print("NumPy", numpy.__version__)
    import pandas; print("pandas", pandas.__version__)
    import sklearn; print("scikit-learn", sklearn.__version__)

These guidelines make reproducing the bug easier, which makes fixing it easier.


Feature request
^^^^^^^^^^^^^^^

A good feature request usually contains:

- a description of the requested feature,
- a description of the relevance of this feature to causal inference,
- references if applicable, with links to the papers if they are in open access.

This makes reviewing the relevance of the requested feature easier.


Contributing code
-----------------

In order to contribute code, you need to create a pull request on
https://github.com/ronikobrosly/causal-curve/pulls

How to contribute
^^^^^^^^^^^^^^^^^

To contribute to causal-curve, you need to fork the repository then submit a
pull request:

1. Fork the repository.

2. Clone your fork of the causal-curve repository from your GitHub account to your
   local disk::

    git clone https://github.com/yourusername/causal-curve.git
    cd causal-curve

   where ``yourusername`` is your GitHub username.

3. Install the development dependencies::

    pip install pytest pylint black

4. Install causal-curve in editable mode::

    pip install -e .

5. Add the ``upstream`` remote. It creates a reference to the main repository
   that can be used to keep your repository synchronized with the latest changes
   on the main repository::

    git remote add upstream https://github.com/ronikobrosly/causal-curve.git

6. Fetch the ``upstream`` remote then create a new branch where you will make
   your changes and switch to it::

    git fetch upstream
    git checkout -b my-feature upstream/main

   where ``my-feature`` is the name of your new branch (it's good practice to have
   an explicit name). You can now start making changes.

7. Make the changes that you want on your new branch on your local machine.
   When you are done, add the changed files using ``git add`` and then
   ``git commit``::

    git add modified_files
    git commit

   Then push your commits to your GitHub account using ``git push``::

    git push origin my-feature

8. Create a pull request from your work.
   The base fork is the fork you
   would like to merge changes into, that is ``ronikobrosly/causal-curve`` on the
   ``main`` branch. The head fork is the repository where you made your
   changes, that is ``yourusername/causal-curve`` on the ``my-feature`` branch.
   Add a title and a description of your pull request, then click on
   **Create Pull Request**.


Pull request checklist
^^^^^^^^^^^^^^^^^^^^^^

Before pushing to your GitHub account, there are a few rules that are
usually worth complying with.

- **Make sure that your code passes tests**. You can do this by running the
  whole test suite with the ``pytest`` command. If you are experienced with
  ``pytest``, you can run specific tests that are relevant for your changes.
  It is still worth running the whole test suite when you are done making
  changes, since it does not take very long.
  For more information, please refer to the
  `pytest documentation `_.
  If your code does not pass tests but you are looking for help, feel free
  to open the pull request anyway (but mention it in your pull request).

- **Make sure to add tests if you add new code**. It is important to test
  new code to make sure that it behaves as expected. Ideally code coverage
  should increase with any new pull request. You can check code coverage
  using ``pytest-cov``::

    pip install pytest-cov
    pytest --cov causal_curve

- **Make sure that the documentation renders properly**. To build the
  documentation, please refer to the :ref:`contribute_documentation` guidelines.

- **Make sure that your PR does not add PEP8 violations**. You can run ``black``
  and ``pylint`` to check only the modified code, as shown below.
  Feel free to submit another pull request if you find other PEP8 violations.
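For example, from the repository root (a sketch; substitute whichever files you actually
modified for the path below)::

    black causal_curve/gps_core.py
    pylint causal_curve/gps_core.py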

.. _contribute_documentation:

Contributing to the documentation
---------------------------------

Documentation is as important as code. If you see typos, find docstrings
unclear or want to add examples illustrating functionalities provided in
causal-curve, feel free to open an issue to report it or a pull request if you
want to fix it.


Building the documentation
^^^^^^^^^^^^^^^^^^^^^^^^^^

Building the documentation requires installing some additional packages::

    pip install sphinx==3.0.2 sphinx-rtd-theme numpydoc

To build the documentation, you must be in the ``docs`` folder::

    cd docs

To generate the website with the example gallery, run the following command::

    make html

The documentation will be generated in the ``_build/html`` directory. You can double
click on ``index.html`` to open the index page, which will look like
the first page that you see on the online documentation. Then you can move to
the pages that you modified and have a look at your changes.

Finally, repeat this process until you are satisfied with your changes and
open a pull request describing the changes you made.
--------------------------------------------------------------------------------
/paper/paper.md:
--------------------------------------------------------------------------------
---
title: 'causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves'
tags:
  - Python
  - causal inference
  - causality
  - machine learning

authors:
  - name: Roni W. Kobrosly
    orcid: 0000-0003-0363-9662
    affiliation: "1, 2" # (Multiple affiliations must be quoted)
affiliations:
  - name: Department of Environmental Medicine and Public Health, Icahn School of Medicine at Mount Sinai, New York, NY, USA
    index: 1
  - name: Flowcast, 44 Tehama St, San Francisco, CA, USA
    index: 2
date: 1 July 2020
bibliography: paper.bib

---

# Summary

In academia and industry, randomized controlled experiments (colloquially "A/B tests")
are considered the gold standard approach for assessing the impact of a treatment or intervention.
However, for ethical or financial reasons, these experiments may not always be feasible to carry out.
"Causal inference" methods are a set of approaches that attempt to estimate causal effects
from observational rather than experimental data, correcting for the biases that are inherent
to analyzing observational data (e.g. confounding and selection bias) [@Hernán:2020].

Although significant research and implementation effort has gone towards methods in
causal inference to estimate the effects of binary treatments (e.g. what was the effect of
treatment "A" or "B"?), much less has gone towards estimating the effects of continuous treatments.
This is unfortunate because there are a great number of inquiries in research
and industry that could benefit from tools to estimate the effect of
continuous treatments, such as estimating how:

- the number of minutes per week of aerobic exercise causes positive health outcomes,
after controlling for confounding effects.
- increasing or decreasing the price of a product would impact demand (price elasticity).
- changing neighborhood income inequality (as measured by the continuous Gini index)
might or might not be causally related to the neighborhood crime rate.
- blood lead levels are causally related to neurodevelopmental delays in children.

`causal-curve` is a Python package created to address this gap; it is designed to perform
causal inference when the treatment of interest is continuous in nature.
From the observational data that is provided by the user, it estimates the
"causal dose-response curve" (or simply the "causal curve").

In the current release of the package there are two unique model classes for
constructing the causal dose-response curve: the Generalized Propensity Score (GPS) and the
Targeted Maximum Likelihood Estimation (TMLE) tools. There is also a tool
to assess causal mediation effects in the presence of a continuous mediator and treatment.

`causal-curve` attempts to make the user experience as painless as possible:

- This package's API was designed to resemble that of `scikit-learn`, as this is a commonly
used Python predictive modeling framework familiar to most machine learning practitioners.
60 | - All of the major classes contained in `causal-curve` readily use Pandas DataFrames and Series as 61 | inputs, so that the package integrates easily with the standard Python data analysis tools. 62 | - A full, end-to-end example of applying the package to a causal inference problem (the analysis of health data) 63 | is provided. In addition to this, shorter tutorials for each of the three major classes are available online in the documentation, along with full documentation of all of their parameters, methods, and attributes. 64 | 65 | This package includes a suite of unit and integration tests made using the pytest framework. The 66 | repo containing the latest project code is integrated with TravisCI for continuous integration. Code 67 | coverage is monitored via codecov and is presently above 90%. 68 | 69 | 70 | # Methods 71 | 72 | The `GPS` method was originally described by Hirano [@Hirano:2004], 73 | and expanded by Moodie [@Moodie:2010] and more recently by Galagate [@Galagate:2016]. GPS is 74 | an extension of the standard propensity score method and is essentially the treatment assignment density calculated 75 | at a particular treatment (and covariate) value. Similar to the standard propensity score approach, 76 | the GPS random variable is used to balance covariates. At the core of this tool, generalized linear 77 | models are used to estimate the GPS, and generalized additive models are used to estimate the smoothed 78 | final causal curve. Compared with this package’s TMLE method, this GPS method is more 79 | computationally efficient and better suited for large datasets, but produces significantly wider confidence intervals.
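To make this concrete, when the treatment model is fit as a normal-errors GLM (one common choice, following Hirano's formulation; the notation below is for exposition only and does not correspond to the package's internal variable names), the estimated GPS for observation $i$ is the fitted conditional treatment density evaluated at the observed treatment and covariate values:

$$\hat{R}_i = \frac{1}{\sqrt{2 \pi \hat{\sigma}^2}} \exp\left(-\frac{(T_i - X_i^\top \hat{\beta})^2}{2 \hat{\sigma}^2}\right)$$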
80 | 81 | 82 | ![Example of a causal curve generated by the GPS tool.\label{fig:example}](welcome_plot.png) 83 | 84 | 85 | The `TMLE` method is based on van der Laan's work on an approach to causal inference that would 86 | employ powerful machine learning approaches to estimate a causal effect [@van_der_Laan:2010]. 87 | TMLE involves predicting the outcome from the treatment and covariates using a machine learning model, 88 | then predicting treatment assignment from the covariates. TMLE also employs a substitution “targeting” 89 | step to correct for covariate imbalance and to estimate an unbiased causal effect. 90 | Currently, there is no other implementation of TMLE that is suitable for continuous treatments. The 91 | implementation in `causal-curve` constructs the final curve through a series of binary treatment comparisons 92 | across the user-specified range of treatment values and then by connecting them. 93 | Compared with the package’s GPS method, this TMLE method is doubly robust 94 | against model misspecification, incorporates more powerful machine learning techniques internally, and produces significantly 95 | smaller confidence intervals; however, it is less computationally efficient. 96 | 97 | `causal-curve` allows for continuous mediation assessment with the `Mediation` tool. As described 98 | by Imai, this tool provides a general approach to mediation analysis that invokes the potential 99 | outcomes / counterfactual framework [@Imai:2010]. While this approach can handle a 100 | continuous mediator and outcome, as put forward by Imai, it only allows for a binary treatment. As 101 | mentioned above with the `TMLE` approach, the tool creates a series of binary treatment comparisons 102 | and connects them to show the user how mediation varies as a function of the treatment. An interpretable, 103 | overall mediation proportion is provided as well. 104 | 105 | 106 | # Statement of Need 107 | 108 | While there are a few established Python packages related to causal inference, to the best of 109 | the author's knowledge, there is no Python package available that can provide support for 110 | continuous treatments as `causal-curve` does. Similarly, the author isn't aware of any Python 111 | implementation of a causal mediation analysis for continuous treatments and mediators. Finally, 112 | the tutorials available in the documentation introduce the concept of continuous treatments 113 | and are instructive as to how the results of such analyses should be interpreted. 114 | 115 | 116 | # Acknowledgements 117 | 118 | We acknowledge the valuable feedback from Miguel-Angel Luque, Erica Moodie, and Mark van der Laan 119 | during the creation of this project. 120 | 121 | 122 | # References 123 | -------------------------------------------------------------------------------- /docs/full_example.rst: -------------------------------------------------------------------------------- 1 | .. _full_example: 2 | 3 | ============================================================= 4 | Health data: generating causal curves and examining mediation 5 | ============================================================= 6 | 7 | To provide an end-to-end example of the sorts of analyses `causal-curve` can be used for, we'll 8 | begin with a health topic. A notebook containing the pipeline to produce the following 9 | output `is available here `_. 10 | Note: Specific examples of the individual `causal-curve` tools with 11 | code are available elsewhere in this documentation. 12 | 13 | 14 | The causal effect of blood lead levels on cognitive performance in children 15 | --------------------------------------------------------------------------- 16 | 17 | Despite the banning of the use of lead-based paint and the use of lead in gasoline in the United 18 | States, lead exposure remains an enormous public health problem for children and adolescents. This 19 | is particularly true for poorer children living in older homes in inner-city environments. 20 | For children, there is no known safe level of exposure to lead, and even small levels of 21 | lead measured in their blood have been shown to affect IQ and academic achievement. 22 | One of the scariest parts of lead exposure is that its effects are permanent. Blood lead levels (BLLs) 23 | of 5 ug/dL or higher are considered elevated. 24 | 25 | There is much research on, and many government programs for, lead abatement. In terms of 26 | public policy, it would be helpful to understand how childhood cognitive outcomes would be affected by 27 | reducing BLLs in children. This is the causal question to answer, with blood lead 28 | levels being the continuous treatment, and the cognitive outcomes being the outcome of interest. 29 | 30 | .. image:: https://upload.wikimedia.org/wikipedia/commons/6/69/LeadPaint1.JPG 31 | 32 | (Photo attribution: Thester11 / CC BY (https://creativecommons.org/licenses/by/3.0)) 33 | 34 | To explore that problem, we can analyze data collected from the National Health and Nutrition 35 | Examination Survey (NHANES) III. This was a large, national study of families throughout the United 36 | States, carried out between 1988 and 1994. Participants took part in extensive interviews 37 | and medical examinations, and provided biological samples. As part of this project, BLLs 38 | were measured, and four scaled sub-tests of the Wechsler Intelligence Scale for Children-Revised 39 | and the Wide Range Achievement Test-Revised (WISC/WRAT) cognitive test were carried out. This data 40 | is de-identified and publicly available on the Centers for Disease Control and Prevention (CDC) 41 | government website. 42 | 43 | After processing the data and dropping missing values, there were 1,764 children between 44 | 6 and 12 years of age with complete data. BLLs among these children were log-normally 45 | distributed, as one would expect: 46 | 47 | .. image:: ../imgs/full_example/BLL_dist.png 48 | 49 | The four scaled sub-tests of the WISC/WRAT included a math test, a reading test, a block design 50 | test (a test of spatial visualization ability and motor skill), and a digit spanning test 51 | (a test of memory). Their distributions are shown here: 52 | 53 | .. image:: ../imgs/full_example/test_dist.png 54 | 55 | Using a well-known study by Bruce Lanphear conducted in 2000 as a guide, we used the following 56 | features as potentially confounding "nuisance" variables: 57 | 58 | - Child age 59 | - Child sex (in 1988 - 1994 the CDC assumed binary sex) 60 | - Child race/ethnicity 61 | - The education level of the guardian 62 | - Whether someone smokes in the child's home 63 | - Whether the child spent time in a neonatal intensive care unit as a baby 64 | - Whether the child is experiencing food insecurity (is food sometimes not available due to lack of resources?). 65 | 66 | In our "experiment", the above confounders will be controlled for.
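To make this step concrete, fitting one of these curves with the TMLE tool would look roughly like the following (a sketch only; the DataFrame and column names below are illustrative and are not the actual NHANES III variable names)::

    from causal_curve import TMLE_Regressor

    tmle = TMLE_Regressor(random_seed=111)
    tmle.fit(
        T=df["blood_lead_level"],
        X=df[["age", "sex", "race_ethnicity", "guardian_education",
              "home_smoking", "nicu_stay", "food_insecurity"]],
        y=df["math_scaled_score"],
    )

    # Point estimates and 95% confidence intervals across the treatment grid
    cdrc_results = tmle.calculate_CDRC(0.95)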
67 | 68 | By using either the GPS or TMLE tools included in `causal-curve`, one can generate the causal 69 | dose-response curves for BLLs in relation to the four outcomes: 70 | 71 | .. image:: ../imgs/full_example/test_causal_curves.png 72 | 73 | Note that the lower limit of detection for the blood lead test in this version of NHANES was 74 | 0.7 ug/dL. So lead levels below that value are not possible. 75 | 76 | In the case of the math test, these results indicate that reducing BLLs in this population 77 | to their lowest value would cause scaled math scores to increase by around 2 points, relative 78 | to BLLs around 10 ug/dL. Similar results are found for the reading and block design test, 79 | although the digit spanning test causal curve appears possibly flat (although with the sparse 80 | observations at the higher end of the BLL range and the wide confidence intervals it is 81 | difficult to say). 82 | 83 | The above curves differ from standard regression curves in a few big ways: 84 | 85 | - Even though the data that we used to generate these curves are observational, if causal inference assumptions are met, these curves can be interpreted as causal. 86 | - These models were created using the potential outcomes / counterfactual framework, while standard models are not. Also, the approach we used here essentially simulates experimental conditions by balancing out treatment assignment across the various confounders, and controlling for their effects. 87 | - Even if complex interactions between the variables are modelled, these curves average over the various interaction effects and subgroups. In this sense, these are "marginal" curves. 88 | - These curves should not be used to make predictions at the individual level. These are population-level estimates and should remain that way. 89 | 90 | 91 | 92 | Do blood lead levels mediate the relationship between poverty and cognitive performance? 93 | ---------------------------------------------------------------------------------------- 94 | 95 | There is a well-known link between household income and child academic performance. Now that we 96 | have some evidence of a potentially causal relationship between BLLs and test performance in 97 | children, one might wonder if lead exposure might mediate the relationship between household income 98 | and academic performance. In other words, in this population does low income cause one to be 99 | exposed more to lead, which in turn causes lower performance? Or is household income linked 100 | with academic performance directly, or through other variables? 101 | 102 | NHANES III captured each household's Poverty Index Ratio (the ratio of total family income to 103 | the federal poverty level for the year of the interview). For this example, let's focus just 104 | on the math test as an outcome. Using `causal-curve`'s mediation tool, 105 | we found that the overall, mediating indirect effect of BLLs is 0.20 (0.17 - 0.23). This means 106 | that lead exposure accounts for 20% of the relationship between low income and low test 107 | performance in this population. The mediation tool also allows you to see how the indirect effect 108 | varies as a function of the treatment. As the plot shows, the mediating effect is relatively flat, 109 | although interestingly there is a hint of an increase as income increases relative to the poverty line. 110 | 111 | .. image:: ../imgs/full_example/mediation_curve.png
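For reference, the mediation analysis above maps onto the ``Mediation`` tool roughly as follows (again a sketch; the column names are illustrative)::

    from causal_curve import Mediation

    med = Mediation(random_seed=111)
    med.fit(
        T=df["poverty_index_ratio"],   # treatment: income relative to the poverty line
        M=df["blood_lead_level"],      # candidate mediator
        y=df["math_scaled_score"],     # outcome
    )

    # Overall direct and indirect effects, with 95% confidence intervals
    med_results = med.calculate_effects(0.95)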
112 | 113 | 114 | References 115 | ---------- 116 | 117 | Centers for Disease Control and Prevention. NHANES III (1988-1994). 118 | https://wwwn.cdc.gov/nchs/nhanes/nhanes3/default.aspx. Accessed on July 2, 2020. 119 | 120 | Centers for Disease Control and Prevention. Blood Lead Levels in Children. 121 | https://www.cdc.gov/nceh/lead/prevention/blood-lead-levels.htm. Accessed on July 2, 2020. 122 | 123 | Environmental Protection Agency. Learn about Lead. https://www.epa.gov/lead/learn-about-lead. 124 | Accessed on July 2, 2020. 125 | 126 | Pirkle JL, Kaufmann RB, Brody DJ, Hickman T, Gunter EW, Paschal DC. Exposure of the 127 | U.S. population to lead, 1991-1994. Environmental Health Perspectives, 106(11), 1998, pp. 745–750. 128 | 129 | Lanphear BP, Dietrich K, Auinger P, Cox C. Cognitive Deficits Associated with 130 | Blood Lead Concentrations <10 µg/dL in US Children and Adolescents. 131 | Public Health Reports, 115, 2000, pp. 521-529. 132 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # A comma-separated list of package or module names from where C extensions may 4 | # be loaded. Extensions are loading into the active Python interpreter and may 5 | # run arbitrary code. 6 | extension-pkg-whitelist= 7 | 8 | # Specify a score threshold to be exceeded before program exits with error. 9 | fail-under=10.0 10 | 11 | # Add files or directories to the blacklist. They should be base names, not 12 | # paths. 13 | ignore=CVS 14 | 15 | # Add files or directories matching the regex patterns to the blacklist. The 16 | # regex matches against base names, not paths. 17 | ignore-patterns= 18 | 19 | # Python code to execute, usually for sys.path manipulation such as 20 | # pygtk.require(). 21 | #init-hook= 22 | 23 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 24 | # number of processors available to use.
25 | jobs=1 26 | 27 | # Control the amount of potential inferred values when inferring a single 28 | # object. This can help the performance when dealing with large functions or 29 | # complex, nested conditions. 30 | limit-inference-results=100 31 | 32 | # List of plugins (as comma separated values of python module names) to load, 33 | # usually to register additional checkers. 34 | load-plugins= 35 | 36 | # Pickle collected data for later comparisons. 37 | persistent=yes 38 | 39 | # When enabled, pylint would attempt to guess common misconfiguration and emit 40 | # user-friendly hints instead of false-positive error messages. 41 | suggestion-mode=yes 42 | 43 | # Allow loading of arbitrary C extensions. Extensions are imported into the 44 | # active Python interpreter and may run arbitrary code. 45 | unsafe-load-any-extension=no 46 | 47 | 48 | [MESSAGES CONTROL] 49 | 50 | # Only show warnings with the listed confidence levels. Leave empty to show 51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 52 | confidence= 53 | 54 | # Disable the message, report, category or checker with the given id(s). You 55 | # can either give multiple identifiers separated by comma (,) or put this 56 | # option multiple times (only on the command line, not in the configuration 57 | # file where it should appear only once). You can also use "--disable=all" to 58 | # disable everything first and then reenable specific checks. For example, if 59 | # you want to run only the similarities checker, you can use "--disable=all 60 | # --enable=similarities". If you want to run only the classes checker, but have 61 | # no Warning level messages displayed, use "--disable=all --enable=classes 62 | # --disable=W". 63 | disable=too-many-locals, 64 | super-init-not-called, 65 | attribute-defined-outside-init, 66 | too-many-instance-attributes, 67 | invalid-name, 68 | too-few-public-methods, 69 | consider-using-dict-comprehension, 70 | too-many-arguments, 71 | no-name-in-module, 72 | duplicate-code, 73 | print-statement, 74 | parameter-unpacking, 75 | unpacking-in-except, 76 | old-raise-syntax, 77 | backtick, 78 | long-suffix, 79 | old-ne-operator, 80 | old-octal-literal, 81 | import-star-module-level, 82 | non-ascii-bytes-literal, 83 | raw-checker-failed, 84 | bad-inline-option, 85 | locally-disabled, 86 | file-ignored, 87 | suppressed-message, 88 | useless-suppression, 89 | deprecated-pragma, 90 | use-symbolic-message-instead, 91 | apply-builtin, 92 | basestring-builtin, 93 | buffer-builtin, 94 | cmp-builtin, 95 | coerce-builtin, 96 | execfile-builtin, 97 | file-builtin, 98 | long-builtin, 99 | raw_input-builtin, 100 | reduce-builtin, 101 | standarderror-builtin, 102 | unicode-builtin, 103 | xrange-builtin, 104 | coerce-method, 105 | delslice-method, 106 | getslice-method, 107 | setslice-method, 108 | no-absolute-import, 109 | old-division, 110 | dict-iter-method, 111 | dict-view-method, 112 | next-method-called, 113 | metaclass-assignment, 114 | indexing-exception, 115 | raising-string, 116 | reload-builtin, 117 | oct-method, 118 | hex-method, 119 | nonzero-method, 120 | cmp-method, 121 | input-builtin, 122 | round-builtin, 123 | intern-builtin, 124 | unichr-builtin, 125 | map-builtin-not-iterating, 126 | zip-builtin-not-iterating, 127 | range-builtin-not-iterating, 128 | filter-builtin-not-iterating, 129 | using-cmp-argument, 130 | eq-without-hash, 131 | div-method, 132 | idiv-method, 133 | rdiv-method, 134 | exception-message-attribute, 135 | invalid-str-codec, 136 | sys-max-int, 137 | 
bad-python3-import, 138 | deprecated-string-function, 139 | deprecated-str-translate-call, 140 | deprecated-itertools-function, 141 | deprecated-types-field, 142 | next-method-defined, 143 | dict-items-not-iterating, 144 | dict-keys-not-iterating, 145 | dict-values-not-iterating, 146 | deprecated-operator-function, 147 | deprecated-urllib-function, 148 | xreadlines-attribute, 149 | deprecated-sys-function, 150 | exception-escape, 151 | comprehension-escape 152 | 153 | # Enable the message, report, category or checker with the given id(s). You can 154 | # either give multiple identifier separated by comma (,) or put this option 155 | # multiple time (only on the command line, not in the configuration file where 156 | # it should appear only once). See also the "--disable" option for examples. 157 | enable=c-extension-no-member 158 | 159 | 160 | [REPORTS] 161 | 162 | # Python expression which should return a score less than or equal to 10. You 163 | # have access to the variables 'error', 'warning', 'refactor', and 'convention' 164 | # which contain the number of messages in each category, as well as 'statement' 165 | # which is the total number of statements analyzed. This score is used by the 166 | # global evaluation report (RP0004). 167 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 168 | 169 | # Template used to display messages. This is a python new-style format string 170 | # used to format the message information. See doc for all details. 171 | #msg-template= 172 | 173 | # Set the output format. Available formats are text, parseable, colorized, json 174 | # and msvs (visual studio). You can also give a reporter class, e.g. 175 | # mypackage.mymodule.MyReporterClass. 176 | output-format=text 177 | 178 | # Tells whether to display a full report or only the messages. 179 | reports=no 180 | 181 | # Activate the evaluation score. 182 | score=yes 183 | 184 | 185 | [REFACTORING] 186 | 187 | # Maximum number of nested blocks for function / method body 188 | max-nested-blocks=5 189 | 190 | # Complete name of functions that never returns. When checking for 191 | # inconsistent-return-statements if a never returning function is called then 192 | # it will be considered as an explicit return statement and no message will be 193 | # printed. 194 | never-returning-functions=sys.exit 195 | 196 | 197 | [LOGGING] 198 | 199 | # The type of string formatting that logging methods do. `old` means using % 200 | # formatting, `new` is for `{}` formatting. 201 | logging-format-style=old 202 | 203 | # Logging modules to check that the string format arguments are in logging 204 | # function parameter format. 205 | logging-modules=logging 206 | 207 | 208 | [SPELLING] 209 | 210 | # Limits count of emitted suggestions for spelling mistakes. 211 | max-spelling-suggestions=4 212 | 213 | # Spelling dictionary name. Available dictionaries: none. To make it work, 214 | # install the python-enchant package. 215 | spelling-dict= 216 | 217 | # List of comma separated words that should not be checked. 218 | spelling-ignore-words= 219 | 220 | # A path to a file that contains the private dictionary; one word per line. 221 | spelling-private-dict-file= 222 | 223 | # Tells whether to store unknown words to the private dictionary (see the 224 | # --spelling-private-dict-file option) instead of raising a message. 225 | spelling-store-unknown-words=no 226 | 227 | 228 | [MISCELLANEOUS] 229 | 230 | # List of note tags to take in consideration, separated by a comma. 
231 | notes=FIXME, 232 | XXX, 233 | TODO 234 | 235 | # Regular expression of note tags to take in consideration. 236 | #notes-rgx= 237 | 238 | 239 | [TYPECHECK] 240 | 241 | # List of decorators that produce context managers, such as 242 | # contextlib.contextmanager. Add to this list to register other decorators that 243 | # produce valid context managers. 244 | contextmanager-decorators=contextlib.contextmanager 245 | 246 | # List of members which are set dynamically and missed by pylint inference 247 | # system, and so shouldn't trigger E1101 when accessed. Python regular 248 | # expressions are accepted. 249 | generated-members= 250 | 251 | # Tells whether missing members accessed in mixin class should be ignored. A 252 | # mixin class is detected if its name ends with "mixin" (case insensitive). 253 | ignore-mixin-members=yes 254 | 255 | # Tells whether to warn about missing members when the owner of the attribute 256 | # is inferred to be None. 257 | ignore-none=yes 258 | 259 | # This flag controls whether pylint should warn about no-member and similar 260 | # checks whenever an opaque object is returned when inferring. The inference 261 | # can return multiple potential results while evaluating a Python object, but 262 | # some branches might not be evaluated, which results in partial inference. In 263 | # that case, it might be useful to still emit no-member and other checks for 264 | # the rest of the inferred objects. 265 | ignore-on-opaque-inference=yes 266 | 267 | # List of class names for which member attributes should not be checked (useful 268 | # for classes with dynamically set attributes). This supports the use of 269 | # qualified names. 270 | ignored-classes=optparse.Values,thread._local,_thread._local 271 | 272 | # List of module names for which member attributes should not be checked 273 | # (useful for modules/projects where namespaces are manipulated during runtime 274 | # and thus existing member attributes cannot be deduced by static analysis). It 275 | # supports qualified module names, as well as Unix pattern matching. 276 | ignored-modules= 277 | 278 | # Show a hint with possible names when a member name was not found. The aspect 279 | # of finding the hint is based on edit distance. 280 | missing-member-hint=yes 281 | 282 | # The minimum edit distance a name should have in order to be considered a 283 | # similar match for a missing member name. 284 | missing-member-hint-distance=1 285 | 286 | # The total number of similar names that should be taken in consideration when 287 | # showing a hint for a missing member. 288 | missing-member-max-choices=1 289 | 290 | # List of decorators that change the signature of a decorated function. 291 | signature-mutators= 292 | 293 | 294 | [VARIABLES] 295 | 296 | # List of additional names supposed to be defined in builtins. Remember that 297 | # you should avoid defining new builtins when possible. 298 | additional-builtins= 299 | 300 | # Tells whether unused global variables should be treated as a violation. 301 | allow-global-unused-variables=yes 302 | 303 | # List of strings which can identify a callback function by name. A callback 304 | # name must start or end with one of those strings. 305 | callbacks=cb_, 306 | _cb 307 | 308 | # A regular expression matching the name of dummy variables (i.e. expected to 309 | # not be used). 310 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 311 | 312 | # Argument names that match this expression will be ignored. 
Default to name 313 | # with leading underscore. 314 | ignored-argument-names=_.*|^ignored_|^unused_ 315 | 316 | # Tells whether we should check for unused import in __init__ files. 317 | init-import=no 318 | 319 | # List of qualified module names which can have objects that can redefine 320 | # builtins. 321 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 322 | 323 | 324 | [FORMAT] 325 | 326 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 327 | expected-line-ending-format= 328 | 329 | # Regexp for a line that is allowed to be longer than the limit. 330 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$ 331 | 332 | # Number of spaces of indent required inside a hanging or continued line. 333 | indent-after-paren=4 334 | 335 | # String used as indentation unit. This is usually "    " (4 spaces) or "\t" (1 336 | # tab). 337 | indent-string='    ' 338 | 339 | # Maximum number of characters on a single line. 340 | max-line-length=100 341 | 342 | # Maximum number of lines in a module. 343 | max-module-lines=1000 344 | 345 | # Allow the body of a class to be on the same line as the declaration if body 346 | # contains single statement. 347 | single-line-class-stmt=no 348 | 349 | # Allow the body of an if to be on the same line as the test if there is no 350 | # else. 351 | single-line-if-stmt=no 352 | 353 | 354 | [SIMILARITIES] 355 | 356 | # Ignore comments when computing similarities. 357 | ignore-comments=yes 358 | 359 | # Ignore docstrings when computing similarities. 360 | ignore-docstrings=yes 361 | 362 | # Ignore imports when computing similarities. 363 | ignore-imports=no 364 | 365 | # Minimum lines number of a similarity. 366 | min-similarity-lines=4 367 | 368 | 369 | [BASIC] 370 | 371 | # Naming style matching correct argument names. 372 | argument-naming-style=snake_case 373 | 374 | # Regular expression matching correct argument names. Overrides argument- 375 | # naming-style. 376 | #argument-rgx= 377 | 378 | # Naming style matching correct attribute names. 379 | attr-naming-style=snake_case 380 | 381 | # Regular expression matching correct attribute names. Overrides attr-naming- 382 | # style. 383 | #attr-rgx= 384 | 385 | # Bad variable names which should always be refused, separated by a comma. 386 | bad-names=foo, 387 | bar, 388 | baz, 389 | toto, 390 | tutu, 391 | tata 392 | 393 | # Bad variable names regexes, separated by a comma. If names match any regex, 394 | # they will always be refused 395 | bad-names-rgxs= 396 | 397 | # Naming style matching correct class attribute names. 398 | class-attribute-naming-style=any 399 | 400 | # Regular expression matching correct class attribute names. Overrides class- 401 | # attribute-naming-style. 402 | #class-attribute-rgx= 403 | 404 | # Naming style matching correct class names. 405 | class-naming-style=PascalCase 406 | 407 | # Regular expression matching correct class names. Overrides class-naming- 408 | # style. 409 | #class-rgx= 410 | 411 | # Naming style matching correct constant names. 412 | const-naming-style=UPPER_CASE 413 | 414 | # Regular expression matching correct constant names. Overrides const-naming- 415 | # style. 416 | #const-rgx= 417 | 418 | # Minimum line length for functions/classes that require docstrings, shorter 419 | # ones are exempt. 420 | docstring-min-length=-1 421 | 422 | # Naming style matching correct function names. 423 | function-naming-style=snake_case 424 | 425 | # Regular expression matching correct function names. Overrides function- 426 | # naming-style.
427 | #function-rgx= 428 | 429 | # Good variable names which should always be accepted, separated by a comma. 430 | good-names=i, 431 | j, 432 | k, 433 | ex, 434 | Run, 435 | _ 436 | 437 | # Good variable names regexes, separated by a comma. If names match any regex, 438 | # they will always be accepted 439 | good-names-rgxs= 440 | 441 | # Include a hint for the correct naming format with invalid-name. 442 | include-naming-hint=no 443 | 444 | # Naming style matching correct inline iteration names. 445 | inlinevar-naming-style=any 446 | 447 | # Regular expression matching correct inline iteration names. Overrides 448 | # inlinevar-naming-style. 449 | #inlinevar-rgx= 450 | 451 | # Naming style matching correct method names. 452 | method-naming-style=snake_case 453 | 454 | # Regular expression matching correct method names. Overrides method-naming- 455 | # style. 456 | #method-rgx= 457 | 458 | # Naming style matching correct module names. 459 | module-naming-style=snake_case 460 | 461 | # Regular expression matching correct module names. Overrides module-naming- 462 | # style. 463 | #module-rgx= 464 | 465 | # Colon-delimited sets of names that determine each other's naming style when 466 | # the name regexes allow several styles. 467 | name-group= 468 | 469 | # Regular expression which should only match function or class names that do 470 | # not require a docstring. 471 | no-docstring-rgx=^_ 472 | 473 | # List of decorators that produce properties, such as abc.abstractproperty. Add 474 | # to this list to register other decorators that produce valid properties. 475 | # These decorators are taken in consideration only for invalid-name. 476 | property-classes=abc.abstractproperty 477 | 478 | # Naming style matching correct variable names. 479 | variable-naming-style=snake_case 480 | 481 | # Regular expression matching correct variable names. Overrides variable- 482 | # naming-style. 483 | #variable-rgx= 484 | 485 | 486 | [STRING] 487 | 488 | # This flag controls whether inconsistent-quotes generates a warning when the 489 | # character used as a quote delimiter is used inconsistently within a module. 490 | check-quote-consistency=no 491 | 492 | # This flag controls whether the implicit-str-concat should generate a warning 493 | # on implicit string concatenation in sequences defined over several lines. 494 | check-str-concat-over-line-jumps=no 495 | 496 | 497 | [IMPORTS] 498 | 499 | # List of modules that can be imported at any level, not just the top level 500 | # one. 501 | allow-any-import-level= 502 | 503 | # Allow wildcard imports from modules that define __all__. 504 | allow-wildcard-with-all=no 505 | 506 | # Analyse import fallback blocks. This can be used to support both Python 2 and 507 | # 3 compatible code, which means that the block might have code that exists 508 | # only in one or another interpreter, leading to false positives when analysed. 509 | analyse-fallback-blocks=no 510 | 511 | # Deprecated modules which should not be used, separated by a comma. 512 | deprecated-modules=optparse,tkinter.tix 513 | 514 | # Create a graph of external dependencies in the given file (report RP0402 must 515 | # not be disabled). 516 | ext-import-graph= 517 | 518 | # Create a graph of every (i.e. internal and external) dependencies in the 519 | # given file (report RP0402 must not be disabled). 520 | import-graph= 521 | 522 | # Create a graph of internal dependencies in the given file (report RP0402 must 523 | # not be disabled). 
524 | int-import-graph= 525 | 526 | # Force import order to recognize a module as part of the standard 527 | # compatibility libraries. 528 | known-standard-library= 529 | 530 | # Force import order to recognize a module as part of a third party library. 531 | known-third-party=enchant 532 | 533 | # Couples of modules and preferred modules, separated by a comma. 534 | preferred-modules= 535 | 536 | 537 | [CLASSES] 538 | 539 | # List of method names used to declare (i.e. assign) instance attributes. 540 | defining-attr-methods=__init__, 541 | __new__, 542 | setUp, 543 | __post_init__ 544 | 545 | # List of member names, which should be excluded from the protected access 546 | # warning. 547 | exclude-protected=_asdict, 548 | _fields, 549 | _replace, 550 | _source, 551 | _make 552 | 553 | # List of valid names for the first argument in a class method. 554 | valid-classmethod-first-arg=cls 555 | 556 | # List of valid names for the first argument in a metaclass class method. 557 | valid-metaclass-classmethod-first-arg=cls 558 | 559 | 560 | [DESIGN] 561 | 562 | # Maximum number of arguments for function / method. 563 | max-args=5 564 | 565 | # Maximum number of attributes for a class (see R0902). 566 | max-attributes=7 567 | 568 | # Maximum number of boolean expressions in an if statement (see R0916). 569 | max-bool-expr=5 570 | 571 | # Maximum number of branch for function / method body. 572 | max-branches=12 573 | 574 | # Maximum number of locals for function / method body. 575 | max-locals=15 576 | 577 | # Maximum number of parents for a class (see R0901). 578 | max-parents=7 579 | 580 | # Maximum number of public methods for a class (see R0904). 581 | max-public-methods=20 582 | 583 | # Maximum number of return / yield for function / method body. 584 | max-returns=6 585 | 586 | # Maximum number of statements in function / method body. 587 | max-statements=50 588 | 589 | # Minimum number of public methods for a class (see R0903). 590 | min-public-methods=2 591 | 592 | 593 | [EXCEPTIONS] 594 | 595 | # Exceptions that will emit a warning when being caught. Defaults to 596 | # "BaseException, Exception". 597 | overgeneral-exceptions=BaseException, 598 | Exception 599 | -------------------------------------------------------------------------------- /causal_curve/tmle_core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the Targeted Maximum Likelihood Estimation (TMLE) model class 3 | """ 4 | import numpy as np 5 | import pandas as pd 6 | from pandas.api.types import is_float_dtype, is_numeric_dtype 7 | from pygam import LinearGAM, s 8 | from scipy.interpolate import interp1d 9 | from sklearn.neighbors import KernelDensity 10 | from sklearn.ensemble import GradientBoostingRegressor 11 | from statsmodels.nonparametric.kernel_regression import KernelReg 12 | 13 | from causal_curve.core import Core 14 | 15 | 16 | class TMLE_Core(Core): 17 | """ 18 | Constructs a causal dose-response curve via a modified version of Targeted 19 | Maximum Likelihood Estimation (TMLE) across a grid of the treatment values. 20 | Gradient boosting is used for prediction of the Q model and G models, simple 21 | kernel regression is used to process those model results, and a generalized 22 | additive model is used in the final step to construct the final curve. 23 | Assumes continuous treatment and outcome variable. 24 | 25 | WARNING: 26 | 27 | -The treatment values should be roughly normally-distributed for this tool 28 | to work.
Otherwise you may encounter internal math errors. 29 | 30 | -This algorithm assumes you've already performed the necessary transformations to 31 | categorical covariates (i.e. these variables are already one-hot encoded and 32 | one of the categories is excluded for each set of dummy variables). 33 | 34 | -Please take care to ensure that the "ignorability" assumption is met (i.e. 35 | all strong confounders are captured in your covariates and there is no 36 | informative censoring), otherwise your results will be biased, sometimes strongly so. 37 | 38 | Parameters 39 | ---------- 40 | 41 | treatment_grid_num: int, optional (default = 100) 42 | Takes the treatment, and creates a quantile-based grid across its values. 43 | For instance, if the number 6 is selected, this means the algorithm will only take 44 | the 6 treatment variable values at approximately the 0, 20, 40, 60, 80, and 100th 45 | percentiles to estimate the causal dose response curve. 46 | Higher value here means the final curve will be more finely estimated, 47 | but also increases computation time. Default is usually a reasonable number. 48 | 49 | lower_grid_constraint: float, optional(default = 0.01) 50 | This adds an optional constraint of the lower side of the treatment grid. 51 | Sometimes data near the minimum values of the treatment are few in number 52 | and thus generate unstable estimates. By default, this clips the bottom 1 percentile 53 | or lower of treatment values. This can be as low as 0, indicating there is no 54 | lower limit to how much treatment data is considered. 55 | 56 | upper_grid_constraint: float, optional (default = 0.99) 57 | See above parameter. Just like above, but this is an upper constraint. 58 | By default, this clips the top 99th percentile or higher of treatment values. 59 | This can be as high as 1.0, indicating there is no upper limit to how much 60 | treatment data is considered. 61 | 62 | n_estimators: int, optional (default = 200) 63 | Optional argument to set the number of learners to use when sklearn 64 | creates TMLE's Q and G models. 65 | 66 | learning_rate: float, optional (default = 0.01) 67 | Optional argument to set the sklearn's learning rate for TMLE's Q and G models. 68 | 69 | max_depth: int, optional (default = 3) 70 | Optional argument to set sklearn's maximum depth when creating TMLE's Q and G models. 71 | 72 | bandwidth: float, optional (default = 0.5) 73 | Optional argument to set the bandwidth parameter of the internal 74 | kernel density estimation and kernel regression methods. 75 | 76 | random_seed: int, optional (default = None) 77 | Sets the random seed. 78 | 79 | verbose: bool, optional (default = False) 80 | Determines whether the user will get verbose status updates. 81 | 82 | 83 | Attributes 84 | ---------- 85 | 86 | grid_values: array of shape (treatment_grid_num, ) 87 | The gridded values of the treatment variable. Equally spaced. 
88 | 89 | final_gam: `pygam.LinearGAM` class 90 | trained final model of `LinearGAM` class, from pyGAM library 91 | 92 | pseudo_out: array of shape (observations, ) 93 | Adjusted, pseudo-outcome observations 94 | 95 | 96 | Methods 97 | ---------- 98 | fit: (self, T, X, y) 99 | Fits the causal dose-response model 100 | 101 | calculate_CDRC: (self, ci, CDRC_grid_num) 102 | Calculates the CDRC (and confidence interval) from TMLE estimation 103 | 104 | 105 | Examples 106 | -------- 107 | 108 | >>> # With continuous outcome 109 | >>> from causal_curve import TMLE_Regressor 110 | >>> tmle = TMLE_Regressor() 111 | >>> tmle.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome']) 112 | >>> tmle_results = tmle.calculate_CDRC(0.95) 113 | >>> point_estimate = tmle.point_estimate(np.array([5.0])) 114 | >>> point_estimate_interval = tmle.point_estimate_interval(np.array([5.0]), 0.95) 115 | 116 | 117 | References 118 | ---------- 119 | 120 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation 121 | of continuous treatment effects. Journal of the Royal Statistical Society, 122 | Series B. 79(4), 2017, pp.1229-1245. 123 | 124 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: The International 125 | Journal of Biostatistics, 2(1), 2006. 126 | 127 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted 128 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010. 129 | 130 | """ 131 | 132 | def __init__( 133 | self, 134 | treatment_grid_num=100, 135 | lower_grid_constraint=0.01, 136 | upper_grid_constraint=0.99, 137 | n_estimators=200, 138 | learning_rate=0.01, 139 | max_depth=3, 140 | bandwidth=0.5, 141 | random_seed=None, 142 | verbose=False, 143 | ): 144 | 145 | self.treatment_grid_num = treatment_grid_num 146 | self.lower_grid_constraint = lower_grid_constraint 147 | self.upper_grid_constraint = upper_grid_constraint 148 | self.n_estimators = n_estimators 149 | self.learning_rate = learning_rate 150 | self.max_depth = max_depth 151 | self.bandwidth = bandwidth 152 | self.random_seed = random_seed 153 | self.verbose = verbose 154 | 155 | def _validate_init_params(self): 156 | """ 157 | Checks that the params used when instantiating TMLE model are formatted correctly 158 | """ 159 | 160 | # Checks for treatment_grid_num 161 | if not isinstance(self.treatment_grid_num, int): 162 | raise TypeError( 163 | f"treatment_grid_num parameter must be an integer, " 164 | f"but found type {type(self.treatment_grid_num)}" 165 | ) 166 | 167 | if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 10: 168 | raise ValueError( 169 | f"treatment_grid_num parameter should be >= 10 so your final curve " 170 | f"has enough resolution, but found value {self.treatment_grid_num}" 171 | ) 172 | 173 | if ( 174 | isinstance(self.treatment_grid_num, int) 175 | ) and self.treatment_grid_num >= 1000: 176 | raise ValueError("treatment_grid_num parameter is too high!") 177 | 178 | # Checks for lower_grid_constraint 179 | if not isinstance(self.lower_grid_constraint, float): 180 | raise TypeError( 181 | f"lower_grid_constraint parameter must be a float, " 182 | f"but found type {type(self.lower_grid_constraint)}" 183 | ) 184 | 185 | if ( 186 | isinstance(self.lower_grid_constraint, float) 187 | ) and self.lower_grid_constraint < 0: 188 | raise ValueError( 189 | f"lower_grid_constraint parameter cannot be < 0, " 190 | f"but found value {self.lower_grid_constraint}" 191 | ) 192 | 193 | if ( 
194 | isinstance(self.lower_grid_constraint, float) 195 | ) and self.lower_grid_constraint >= 1.0: 196 | raise ValueError( 197 | f"lower_grid_constraint parameter cannot >= 1.0, " 198 | f"but found value {self.lower_grid_constraint}" 199 | ) 200 | 201 | # Checks for upper_grid_constraint 202 | if not isinstance(self.upper_grid_constraint, float): 203 | raise TypeError( 204 | f"upper_grid_constraint parameter must be a float, " 205 | f"but found type {type(self.upper_grid_constraint)}" 206 | ) 207 | 208 | if ( 209 | isinstance(self.upper_grid_constraint, float) 210 | ) and self.upper_grid_constraint <= 0: 211 | raise ValueError( 212 | f"upper_grid_constraint parameter cannot be <= 0, " 213 | f"but found value {self.upper_grid_constraint}" 214 | ) 215 | 216 | if ( 217 | isinstance(self.upper_grid_constraint, float) 218 | ) and self.upper_grid_constraint > 1.0: 219 | raise ValueError( 220 | f"upper_grid_constraint parameter cannot > 1.0, " 221 | f"but found value {self.upper_grid_constraint}" 222 | ) 223 | 224 | # Checks for lower_grid_constraint isn't higher than upper_grid_constraint 225 | if self.lower_grid_constraint >= self.upper_grid_constraint: 226 | raise ValueError( 227 | "lower_grid_constraint should be lower than upper_grid_constraint!" 228 | ) 229 | 230 | # Checks for n_estimators 231 | if not isinstance(self.n_estimators, int): 232 | raise TypeError( 233 | f"n_estimators parameter must be an integer, " 234 | f"but found type {type(self.n_estimators)}" 235 | ) 236 | 237 | if (self.n_estimators < 10) or (self.n_estimators > 100000): 238 | raise TypeError("n_estimators parameter must be between 10 and 100000") 239 | 240 | # Checks for learning_rate 241 | if not isinstance(self.learning_rate, (int, float)): 242 | raise TypeError( 243 | f"learning_rate parameter must be an integer or float, " 244 | f"but found type {type(self.learning_rate)}" 245 | ) 246 | 247 | if (self.learning_rate <= 0) or (self.learning_rate >= 1000): 248 | raise TypeError( 249 | "learning_rate parameter must be greater than 0 and less than 1000" 250 | ) 251 | 252 | # Checks for max_depth 253 | if not isinstance(self.max_depth, int): 254 | raise TypeError( 255 | f"max_depth parameter must be an integer, " 256 | f"but found type {type(self.max_depth)}" 257 | ) 258 | 259 | if self.max_depth <= 0: 260 | raise TypeError("max_depth parameter must be greater than 0") 261 | 262 | # Checks for bandwidth 263 | if not isinstance(self.bandwidth, (int, float)): 264 | raise TypeError( 265 | f"bandwidth parameter must be an integer or float, " 266 | f"but found type {type(self.bandwidth)}" 267 | ) 268 | 269 | if (self.bandwidth <= 0) or (self.bandwidth >= 1000): 270 | raise TypeError( 271 | "bandwidth parameter must be greater than 0 and less than 1000" 272 | ) 273 | 274 | # Checks for random_seed 275 | if not isinstance(self.random_seed, (int, type(None))): 276 | raise TypeError( 277 | f"random_seed parameter must be an int, but found type {type(self.random_seed)}" 278 | ) 279 | 280 | if (isinstance(self.random_seed, int)) and self.random_seed < 0: 281 | raise ValueError("random_seed parameter must be > 0") 282 | 283 | # Checks for verbose 284 | if not isinstance(self.verbose, bool): 285 | raise TypeError( 286 | f"verbose parameter must be a boolean type, but found type {type(self.verbose)}" 287 | ) 288 | 289 | def _validate_fit_data(self): 290 | """Verifies that T, X, and y are formatted the right way""" 291 | # Checks for T column 292 | if not is_float_dtype(self.t_data): 293 | raise TypeError("Treatment data must be of 
type float") 294 | 295 | # Make sure all X columns are float or int 296 | if isinstance(self.x_data, pd.Series): 297 | if not is_numeric_dtype(self.x_data): 298 | raise TypeError( 299 | "All covariate (X) columns must be int or float type (i.e. must be numeric)" 300 | ) 301 | 302 | elif isinstance(self.x_data, pd.DataFrame): 303 | for column in self.x_data: 304 | if not is_numeric_dtype(self.x_data[column]): 305 | raise TypeError( 306 | """All covariate (X) columns must be int or float type 307 | (i.e. must be numeric)""" 308 | ) 309 | 310 | # Checks for Y column 311 | if not is_float_dtype(self.y_data): 312 | raise TypeError("Outcome data must be of type float") 313 | 314 | def _validate_calculate_CDRC_params(self, ci): 315 | """Validates the parameters given to `calculate_CDRC`""" 316 | 317 | if not isinstance(ci, float): 318 | raise TypeError( 319 | f"`ci` parameter must be an float, but found type {type(ci)}" 320 | ) 321 | 322 | if isinstance(ci, float) and ((ci <= 0) or (ci >= 1.0)): 323 | raise ValueError("`ci` parameter should be between (0, 1)") 324 | 325 | def fit(self, T, X, y): 326 | """Fits the TMLE causal dose-response model. For now, this only 327 | accepts pandas columns. You *must* provide at least one covariate column. 328 | 329 | Parameters 330 | ---------- 331 | T: array-like, shape (n_samples,) 332 | A continuous treatment variable 333 | X: array-like, shape (n_samples, m_features) 334 | Covariates, where n_samples is the number of samples 335 | and m_features is the number of features 336 | y: array-like, shape (n_samples,) 337 | Outcome variable 338 | 339 | Returns 340 | ---------- 341 | self : object 342 | 343 | """ 344 | self.rand_seed_wrapper(self.random_seed) 345 | 346 | self.t_data = T.reset_index(drop=True, inplace=False) 347 | self.x_data = X.reset_index(drop=True, inplace=False) 348 | self.y_data = y.reset_index(drop=True, inplace=False) 349 | 350 | # Validate this input data 351 | self._validate_fit_data() 352 | 353 | # Capture covariate and treatment column names 354 | self.treatment_col_name = self.t_data.name 355 | 356 | if len(self.x_data.shape) == 1: 357 | self.covariate_col_names = [self.x_data.name] 358 | else: 359 | self.covariate_col_names = self.x_data.columns.values.tolist() 360 | 361 | # Note the size of the data 362 | self.num_rows = len(self.t_data) 363 | 364 | # Produce expanded versions of the inputs 365 | self.if_verbose_print("Transforming data for the Q-model and G-model") 366 | ( 367 | self.grid_values, 368 | self.fully_expanded_x, 369 | self.fully_expanded_t_and_x, 370 | ) = self._transform_inputs() 371 | 372 | # Fit G-model and get relevent predictions 373 | self.if_verbose_print( 374 | "Fitting G-model and making treatment assignment predictions..." 375 | ) 376 | self.g_model_preds, self.g_model_2_preds = self._g_model() 377 | 378 | # Fit Q-model and get relevent predictions 379 | self.if_verbose_print("Fitting Q-model and making outcome predictions...") 380 | self.q_model_preds = self._q_model() 381 | 382 | # Calculating treatment assignment adjustment using G-model's predictions 383 | self.if_verbose_print( 384 | "Calculating treatment assignment adjustment using G-model's predictions..." 
385 | ) 386 | ( 387 | self.n_interpd_values, 388 | self.var_n_interpd_values, 389 | ) = self._treatment_assignment_correction() 390 | 391 | # Adjusting outcome using Q-model's predictions 392 | self.if_verbose_print("Adjusting outcome using Q-model's predictions...") 393 | self.outcome_adjust, self.expand_outcome_adjust = self._outcome_adjustment() 394 | 395 | # Calculating corrected pseudo-outcome values 396 | self.if_verbose_print("Calculating corrected pseudo-outcome values...") 397 | self.pseudo_out = (self.y_data - self.outcome_adjust) / ( 398 | self.n_interpd_values / self.var_n_interpd_values 399 | ) + self.expand_outcome_adjust 400 | 401 | # Training final GAM model using pseudo-outcome values 402 | self.if_verbose_print("Training final GAM model using pseudo-outcome values...") 403 | self.final_gam = self._fit_final_gam() 404 | 405 | def calculate_CDRC(self, ci=0.95): 406 | """Using the results of the fitted model, this generates a dataframe of CDRC point estimates 407 | at each of the values of the treatment grid. Connecting these estimates will produce 408 | the overall estimated CDRC. Confidence interval is returned as well. 409 | 410 | Parameters 411 | ---------- 412 | ci: float (default = 0.95) 413 | The desired confidence interval to produce. Default value is 0.95, corresponding 414 | to 95% confidence intervals. bounded (0, 1.0). 415 | 416 | Returns 417 | ---------- 418 | dataframe: Pandas dataframe 419 | Contains treatment grid values, the CDRC point estimate at that value, 420 | and the associated lower and upper confidence interval bounds at that point. 421 | 422 | self: object 423 | 424 | """ 425 | self.rand_seed_wrapper(self.random_seed) 426 | 427 | self._validate_calculate_CDRC_params(ci) 428 | 429 | self.if_verbose_print( 430 | """ 431 | Generating predictions for each value of treatment grid, 432 | and averaging to get the CDRC...""" 433 | ) 434 | 435 | # Create CDRC predictions from the trained, final GAM 436 | 437 | self._cdrc_preds = self._cdrc_predictions_continuous(ci) 438 | 439 | return pd.DataFrame( 440 | self._cdrc_preds, 441 | columns=["Treatment", "Causal_Dose_Response", "Lower_CI", "Upper_CI"], 442 | ).round(3) 443 | 444 | def _transform_inputs(self): 445 | """Takes the treatment and covariates and transforms so they can 446 | be used by the algo""" 447 | 448 | # Create treatment grid 449 | grid_values = np.linspace( 450 | start=self.t_data.min(), stop=self.t_data.max(), num=self.treatment_grid_num 451 | ) 452 | 453 | # Create expanded treatment array 454 | expanded_t = np.array([]) 455 | for treat_value in grid_values: 456 | expanded_t = np.append(expanded_t, ([treat_value] * self.num_rows)) 457 | 458 | # Create expanded treatment array with covariates 459 | expanded_t_and_x = pd.concat( 460 | [ 461 | pd.DataFrame(expanded_t), 462 | pd.concat([self.x_data] * self.treatment_grid_num).reset_index( 463 | drop=True, inplace=False 464 | ), 465 | ], 466 | axis=1, 467 | ignore_index=True, 468 | ) 469 | 470 | expanded_t_and_x.columns = [self.treatment_col_name] + self.covariate_col_names 471 | 472 | fully_expanded_t_and_x = pd.concat( 473 | [pd.concat([self.x_data, self.t_data], axis=1), expanded_t_and_x], 474 | axis=0, 475 | ignore_index=True, 476 | ) 477 | 478 | fully_expanded_x = fully_expanded_t_and_x[self.covariate_col_names] 479 | 480 | return grid_values, fully_expanded_x, fully_expanded_t_and_x 481 | 482 | def _g_model(self): 483 | """Produces the G-model and gets treatment assignment predictions""" 484 | 485 | t = self.t_data.to_numpy() 486 | X = 
self.x_data.to_numpy() 487 | 488 | g_model = GradientBoostingRegressor( 489 | n_estimators=self.n_estimators, 490 | max_depth=self.max_depth, 491 | learning_rate=self.learning_rate, 492 | random_state=self.random_seed, 493 | ).fit(X=X, y=t) 494 | g_model_preds = g_model.predict(self.fully_expanded_x) 495 | 496 | g_model2 = GradientBoostingRegressor( 497 | n_estimators=self.n_estimators, 498 | max_depth=self.max_depth, 499 | learning_rate=self.learning_rate, 500 | random_state=self.random_seed, 501 | ).fit(X=X, y=((t - g_model_preds[0 : self.num_rows]) ** 2)) 502 | g_model_2_preds = g_model2.predict(self.fully_expanded_x) 503 | 504 | return g_model_preds, g_model_2_preds 505 | 506 | def _q_model(self): 507 | """Produces the Q-model and gets outcome predictions using the provided treatment 508 | values, when the treatment is completely present and not present. 509 | """ 510 | 511 | X = pd.concat([self.t_data, self.x_data], axis=1).to_numpy() 512 | y = self.y_data.to_numpy() 513 | 514 | q_model = GradientBoostingRegressor( 515 | n_estimators=self.n_estimators, 516 | max_depth=self.max_depth, 517 | learning_rate=self.learning_rate, 518 | random_state=self.random_seed, 519 | ).fit(X=X, y=y) 520 | q_model_preds = q_model.predict(self.fully_expanded_t_and_x) 521 | 522 | return q_model_preds 523 | 524 | def _treatment_assignment_correction(self): 525 | """Uses the G-model and its predictions to adjust treatment assignment""" 526 | 527 | t_standard = ( 528 | self.fully_expanded_t_and_x[self.treatment_col_name] - self.g_model_preds 529 | ) / np.sqrt(self.g_model_2_preds) 530 | 531 | interpd_values = ( 532 | interp1d( 533 | self.one_dim_estimate_density(t_standard.values)[0], 534 | self.one_dim_estimate_density(t_standard.values[0 : self.num_rows])[1], 535 | kind="linear", 536 | )(t_standard) 537 | / np.sqrt(self.g_model_2_preds) 538 | ) 539 | 540 | n_interpd_values = interpd_values[0 : self.num_rows] 541 | 542 | temp_interpd = interpd_values[self.num_rows :] 543 | 544 | zeros_mat = np.zeros((self.num_rows, self.treatment_grid_num)) 545 | 546 | for i in range(0, self.treatment_grid_num): 547 | lower = i * self.num_rows 548 | upper = i * self.num_rows + self.num_rows 549 | zeros_mat[:, i] = temp_interpd[lower:upper] 550 | 551 | var_n_interpd_values = self.pred_from_loess( 552 | train_x=self.grid_values, 553 | train_y=zeros_mat.mean(axis=0), 554 | x_to_pred=self.t_data, 555 | ) 556 | 557 | return n_interpd_values, var_n_interpd_values 558 | 559 | def _outcome_adjustment(self): 560 | """Uses the Q-model and its predictions to adjust the outcome""" 561 | 562 | outcome_adjust = self.q_model_preds[0 : self.num_rows] 563 | 564 | temp_outcome_adjust = self.q_model_preds[self.num_rows :] 565 | 566 | zero_mat = np.zeros((self.num_rows, self.treatment_grid_num)) 567 | for i in range(0, self.treatment_grid_num): 568 | lower = i * self.num_rows 569 | upper = i * self.num_rows + self.num_rows 570 | zero_mat[:, i] = temp_outcome_adjust[lower:upper] 571 | 572 | expand_outcome_adjust = self.pred_from_loess( 573 | train_x=self.grid_values, 574 | train_y=zero_mat.mean(axis=0), 575 | x_to_pred=self.t_data, 576 | ) 577 | 578 | return outcome_adjust, expand_outcome_adjust 579 | 580 | def _fit_final_gam(self): 581 | """We now regress the original treatment values against the pseudo-outcome values""" 582 | 583 | return LinearGAM( 584 | s(0, n_splines=30, spline_order=3), max_iter=500, lam=self.bandwidth 585 | ).fit(self.t_data, y=self.pseudo_out) 586 | 587 | def one_dim_estimate_density(self, series): 588 | """ 589 | Takes 
in a numpy array, returns grid values for KDE and predicted probabilities 590 | """ 591 | series_grid = np.linspace( 592 | start=series.min(), stop=series.max(), num=self.num_rows 593 | ) 594 | 595 | kde = KernelDensity(kernel="gaussian", bandwidth=self.bandwidth).fit( 596 | series.reshape(-1, 1) 597 | ) 598 | 599 | return series_grid, np.exp(kde.score_samples(series_grid.reshape(-1, 1))) 600 | 601 | def pred_from_loess(self, train_x, train_y, x_to_pred): 602 | """ 603 | Trains simple loess regression and returns predictions 604 | """ 605 | kr_model = KernelReg( 606 | endog=train_y, exog=train_x, var_type="c", bw=[self.bandwidth] 607 | ) 608 | 609 | return kr_model.fit(x_to_pred)[0] 610 | -------------------------------------------------------------------------------- /causal_curve/mediation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defines the Mediation test class 3 | """ 4 | 5 | from pprint import pprint 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from pandas.api.types import is_float_dtype 10 | from pygam import LinearGAM, s 11 | 12 | from causal_curve.core import Core 13 | 14 | 15 | class Mediation(Core): 16 | """ 17 | Given three continuous variables (a treatment or independent variable of interest, 18 | a potential mediator, and an outcome variable of interest), Mediation provides a method 19 | to determine the average direct and indirect effect. 20 | 21 | Parameters 22 | ---------- 23 | 24 | treatment_grid_num: int, optional (default = 10) 25 | Takes the treatment, and creates a quantile-based grid across its values. 26 | For instance, if the number 6 is selected, this means the algorithm will only take 27 | the 6 treatment variable values at approximately the 0, 20, 40, 60, 80, and 100th 28 | percentiles to estimate the causal dose response curve. 29 | Higher value here means the final curve will be more finely estimated, 30 | but also increases computation time. Default is usually a reasonable number. 31 | 32 | lower_grid_constraint: float, optional(default = 0.01) 33 | This adds an optional constraint of the lower side of the treatment grid. 34 | Sometimes data near the minimum values of the treatment are few in number 35 | and thus generate unstable estimates. By default, this clips the bottom 1 percentile 36 | or lower of treatment values. This can be as low as 0, indicating there is no 37 | lower limit to how much treatment data is considered. 38 | 39 | upper_grid_constraint: float, optional (default = 0.99) 40 | See above parameter. Just like above, but this is an upper constraint. 41 | By default, this clips the top 99th percentile or higher of treatment values. 42 | This can be as high as 1.0, indicating there is no upper limit to how much 43 | treatment data is considered. 44 | 45 | bootstrap_draws: int, optional (default = 500) 46 | Bootstrapping is used as part of the mediation test. The parameter determines 47 | the number of draws from the original data to create a single bootstrap replicate. 48 | 49 | bootstrap_replicates: int, optional (default = 100) 50 | Bootstrapping is used as part of the mediation test. The parameter determines 51 | the number of bootstrapping runs to perform / number of new datasets to create. 52 | 53 | spline_order: int, optional (default = 3) 54 | Order of the splines to use fitting the final GAM. 55 | Must be integer >= 1. Default value creates cubic splines. 56 | 57 | n_splines: int, optional (default = 5) 58 | Number of splines to use for the mediation and outcome GAMs. 
59 |         Must be an integer >= 2.
60 |
61 |     lambda_: int or float, optional (default = 0.5)
62 |         Strength of the smoothing penalty. Must be positive.
63 |         Larger values enforce stronger smoothing.
64 |
65 |     max_iter: int, optional (default = 100)
66 |         Maximum number of iterations allowed for the maximum likelihood algorithm to converge.
67 |
68 |     random_seed: int, optional (default = None)
69 |         Sets the random seed.
70 |
71 |     verbose: bool, optional (default = False)
72 |         Determines whether the user will get verbose status updates.
73 |
74 |
75 |     Attributes
76 |     ----------
77 |
78 |     grid_values: array of shape (treatment_grid_num, )
79 |         The gridded values of the treatment variable. Quantile-based, so not
80 |         necessarily equally spaced.
81 |
82 |     Methods
83 |     ----------
84 |     fit: (self, T, M, y)
85 |         Fits the trio of relevant variables using generalized additive models.
86 |
87 |     calculate_mediation: (self, ci)
88 |         Calculates the average direct and indirect effects.
89 |
90 |
91 |     Examples
92 |     --------
93 |     >>> from causal_curve import Mediation
94 |     >>> med = Mediation(treatment_grid_num = 100, random_seed = 512)
95 |     >>> med.fit(T = df['Treatment'], M = df['Mediator'], y = df['Outcome'])
96 |     >>> med_results = med.calculate_mediation(0.95)
97 |
98 |
99 |     References
100 |     ----------
101 |
102 |     Imai K., Keele L., Tingley D. A General Approach to Causal Mediation Analysis. Psychological
103 |     Methods. 15(4), 2010, pp.309–334.
104 |
105 |     """
106 |
107 |     def __init__(
108 |         self,
109 |         treatment_grid_num=10,
110 |         lower_grid_constraint=0.01,
111 |         upper_grid_constraint=0.99,
112 |         bootstrap_draws=500,
113 |         bootstrap_replicates=100,
114 |         spline_order=3,
115 |         n_splines=5,
116 |         lambda_=0.5,
117 |         max_iter=100,
118 |         random_seed=None,
119 |         verbose=False,
120 |     ):
121 |
122 |         self.treatment_grid_num = treatment_grid_num
123 |         self.lower_grid_constraint = lower_grid_constraint
124 |         self.upper_grid_constraint = upper_grid_constraint
125 |         self.bootstrap_draws = bootstrap_draws
126 |         self.bootstrap_replicates = bootstrap_replicates
127 |         self.spline_order = spline_order
128 |         self.n_splines = n_splines
129 |         self.lambda_ = lambda_
130 |         self.max_iter = max_iter
131 |         self.random_seed = random_seed
132 |         self.verbose = verbose
133 |
134 |         # Validate the params
135 |         self._validate_init_params()
136 |         self.rand_seed_wrapper(self.random_seed)
137 |
138 |         if self.verbose:
139 |             print("Using the following params for the mediation analysis:")
140 |             pprint(self.get_params(), indent=4)
141 |
142 |     def _validate_init_params(self):
143 |         """
144 |         Checks that the params used when instantiating the mediation tool are formatted correctly
145 |         """
146 |
147 |         # Checks for treatment_grid_num
148 |         if not isinstance(self.treatment_grid_num, int):
149 |             raise TypeError(
150 |                 f"treatment_grid_num parameter must be an integer, "
151 |                 f"but found type {type(self.treatment_grid_num)}"
152 |             )
153 |
154 |         if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 4:
155 |             raise ValueError(
156 |                 f"treatment_grid_num parameter should be >= 4 so the internal models "
157 |                 f"have enough resolution, but found value {self.treatment_grid_num}"
158 |             )
159 |
160 |         if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num > 100:
161 |             raise ValueError("treatment_grid_num parameter is too high!")
162 |
163 |         # Checks for lower_grid_constraint
164 |         if not isinstance(self.lower_grid_constraint, float):
165 |             raise TypeError(
166 |                 f"lower_grid_constraint parameter must be a float, "
167 |                 f"but found type {type(self.lower_grid_constraint)}"
168 |             )
169 |
170 |         if (
171 |             isinstance(self.lower_grid_constraint, float)
172 |         ) and self.lower_grid_constraint < 0:
173 |             raise ValueError(
174 |                 f"lower_grid_constraint parameter cannot be < 0, "
175 |                 f"but found value {self.lower_grid_constraint}"
176 |             )
177 |
178 |         if (
179 |             isinstance(self.lower_grid_constraint, float)
180 |         ) and self.lower_grid_constraint >= 1.0:
181 |             raise ValueError(
182 |                 f"lower_grid_constraint parameter cannot be >= 1.0, "
183 |                 f"but found value {self.lower_grid_constraint}"
184 |             )
185 |
186 |         # Checks for upper_grid_constraint
187 |         if not isinstance(self.upper_grid_constraint, float):
188 |             raise TypeError(
189 |                 f"upper_grid_constraint parameter must be a float, "
190 |                 f"but found type {type(self.upper_grid_constraint)}"
191 |             )
192 |
193 |         if (
194 |             isinstance(self.upper_grid_constraint, float)
195 |         ) and self.upper_grid_constraint <= 0:
196 |             raise ValueError(
197 |                 f"upper_grid_constraint parameter cannot be <= 0, "
198 |                 f"but found value {self.upper_grid_constraint}"
199 |             )
200 |
201 |         if (
202 |             isinstance(self.upper_grid_constraint, float)
203 |         ) and self.upper_grid_constraint > 1.0:
204 |             raise ValueError(
205 |                 f"upper_grid_constraint parameter cannot be > 1.0, "
206 |                 f"but found value {self.upper_grid_constraint}"
207 |             )
208 |
209 |         # Checks for bootstrap_draws
210 |         if not isinstance(self.bootstrap_draws, int):
211 |             raise TypeError(
212 |                 f"bootstrap_draws parameter must be an int, "
213 |                 f"but found type {type(self.bootstrap_draws)}"
214 |             )
215 |
216 |         if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws < 100:
217 |             raise ValueError(
218 |                 f"bootstrap_draws parameter cannot be < 100, "
219 |                 f"but found value {self.bootstrap_draws}"
220 |             )
221 |
222 |         if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws > 500000:
223 |             raise ValueError(
224 |                 f"bootstrap_draws parameter cannot be > 500000, "
225 |                 f"but found value {self.bootstrap_draws}"
226 |             )
227 |
228 |         # Checks for bootstrap_replicates
229 |         if not isinstance(self.bootstrap_replicates, int):
230 |             raise TypeError(
231 |                 f"bootstrap_replicates parameter must be an int, "
232 |                 f"but found type {type(self.bootstrap_replicates)}"
233 |             )
234 |
235 |         if (
236 |             isinstance(self.bootstrap_replicates, int)
237 |         ) and self.bootstrap_replicates < 50:
238 |             raise ValueError(
239 |                 f"bootstrap_replicates parameter cannot be < 50, "
240 |                 f"but found value {self.bootstrap_replicates}"
241 |             )
242 |
243 |         if (
244 |             isinstance(self.bootstrap_replicates, int)
245 |         ) and self.bootstrap_replicates > 100000:
246 |             raise ValueError(
247 |                 f"bootstrap_replicates parameter cannot be > 100000, "
248 |                 f"but found value {self.bootstrap_replicates}"
249 |             )
250 |
251 |         # Check that lower_grid_constraint is lower than upper_grid_constraint
252 |         if self.lower_grid_constraint >= self.upper_grid_constraint:
253 |             raise ValueError(
254 |                 "lower_grid_constraint should be lower than upper_grid_constraint!"
255 |             )
256 |
257 |         # Checks for spline_order
258 |         if not isinstance(self.spline_order, int):
259 |             raise TypeError(
260 |                 f"spline_order parameter must be an integer, "
261 |                 f"but found type {type(self.spline_order)}"
262 |             )
263 |
264 |         if (isinstance(self.spline_order, int)) and self.spline_order < 1:
265 |             raise ValueError(
266 |                 f"spline_order parameter should be >= 1, but found {self.spline_order}"
267 |             )
268 |
269 |         if (isinstance(self.spline_order, int)) and self.spline_order >= 30:
270 |             raise ValueError("spline_order parameter is too high!")
271 |
272 |         # Checks for n_splines
273 |         if not isinstance(self.n_splines, int):
274 |             raise TypeError(
275 |                 f"n_splines parameter must be an integer, but found type {type(self.n_splines)}"
276 |             )
277 |
278 |         if (isinstance(self.n_splines, int)) and self.n_splines < 2:
279 |             raise ValueError(
280 |                 f"n_splines parameter should be >= 2, but found {self.n_splines}"
281 |             )
282 |
283 |         if (isinstance(self.n_splines, int)) and self.n_splines >= 100:
284 |             raise ValueError("n_splines parameter is too high!")
285 |
286 |         # Checks for lambda_
287 |         if not isinstance(self.lambda_, (int, float)):
288 |             raise TypeError(
289 |                 f"lambda_ parameter must be an int or float, but found type {type(self.lambda_)}"
290 |             )
291 |
292 |         if (isinstance(self.lambda_, (int, float))) and self.lambda_ <= 0:
293 |             raise ValueError(
294 |                 f"lambda_ parameter should be > 0, but found {self.lambda_}"
295 |             )
296 |
297 |         if (isinstance(self.lambda_, (int, float))) and self.lambda_ >= 1000:
298 |             raise ValueError("lambda_ parameter is too high!")
299 |
300 |         # Checks for max_iter
301 |         if not isinstance(self.max_iter, int):
302 |             raise TypeError(
303 |                 f"max_iter parameter must be an int, but found type {type(self.max_iter)}"
304 |             )
305 |
306 |         if (isinstance(self.max_iter, int)) and self.max_iter <= 10:
307 |             raise ValueError(
308 |                 "max_iter parameter is too low! Results won't be reliable!"
309 |             )
310 |
311 |         if (isinstance(self.max_iter, int)) and self.max_iter >= 1e6:
312 |             raise ValueError("max_iter parameter is unnecessarily high!")
313 |
314 |         # Checks for random_seed
315 |         if not isinstance(self.random_seed, (int, type(None))):
316 |             raise TypeError(
317 |                 f"random_seed parameter must be an int, but found type {type(self.random_seed)}"
318 |             )
319 |
320 |         if (isinstance(self.random_seed, int)) and self.random_seed < 0:
321 |             raise ValueError("random_seed parameter must be >= 0")
322 |
323 |         # Checks for verbose
324 |         if not isinstance(self.verbose, bool):
325 |             raise TypeError(
326 |                 f"verbose parameter must be a boolean type, but found type {type(self.verbose)}"
327 |             )
328 |
329 |     def _validate_fit_data(self):
330 |         """Verifies that T, M, and y are formatted the right way"""
331 |         # Checks for T column
332 |         if not is_float_dtype(self.T):
333 |             raise TypeError("Treatment data must be of type float")
334 |
335 |         # Checks for M column
336 |         if not is_float_dtype(self.M):
337 |             raise TypeError("Mediation data must be of type float")
338 |
339 |         # Checks for Y column
340 |         if not is_float_dtype(self.y):
341 |             raise TypeError("Outcome data must be of type float")
342 |
343 |     def _grid_values(self):
344 |         """Produces initial grid values for the treatment variable"""
345 |         return np.quantile(
346 |             self.T,
347 |             q=np.linspace(
348 |                 start=self.lower_grid_constraint,
349 |                 stop=self.upper_grid_constraint,
350 |                 num=self.treatment_grid_num,
351 |             ),
352 |         )
353 |
354 |     def _collect_mean_t_levels(self):
355 |         """Collects the mean treatment value within each treatment bucket in the grid_values"""
356 |
357 |         t_bin_means = []
358 |
359 |         for index, _ in enumerate(self.grid_values):
360 |             if index == (len(self.grid_values) - 1):
361 |                 continue
362 |
363 |             t_bin_means.append(
364 |                 self.T[
365 |                     (
366 |                         (self.T >= self.grid_values[index])
367 |                         & (self.T <= self.grid_values[index + 1])
368 |                     )
369 |                 ].mean()
370 |             )
371 |
372 |         return t_bin_means
373 |
374 |     def fit(self, T, M, y):
375 |         """Fits models so that mediation analysis can be run.
376 |         For now, this only accepts pandas Series.
377 |
378 |         Parameters
379 |         ----------
380 |         T: array-like, shape (n_samples,)
381 |             A continuous treatment variable
382 |         M: array-like, shape (n_samples,)
383 |             A continuous mediation variable
384 |         y: array-like, shape (n_samples,)
385 |             A continuous outcome variable
386 |
387 |         Returns
388 |         ----------
389 |         self : object
390 |
391 |         """
392 |         self.rand_seed_wrapper(self.random_seed)
393 |
394 |         self.T = T.reset_index(drop=True, inplace=False)
395 |         self.M = M.reset_index(drop=True, inplace=False)
396 |         self.y = y.reset_index(drop=True, inplace=False)
397 |
398 |         # Validate this input data
399 |         self._validate_fit_data()
400 |
401 |         self.n = len(y)
402 |
403 |         # Create grid_values
404 |         self.grid_values = self._grid_values()
405 |
406 |         # Loop through the comparisons in the grid_values
407 |         if self.verbose:
408 |             print("Beginning main loop through treatment bins...")
409 |
410 |         # Collect loop results in this list
411 |         self.final_bootstrap_results = []
412 |
413 |         # Begin main loop
414 |         for index, _ in enumerate(self.grid_values):
415 |             if index == 0:
416 |                 continue
417 |             if self.verbose:
418 |                 print(
419 |                     f"***** Starting iteration {index} of {len(self.grid_values) - 1} *****"
420 |                 )
421 |
422 |             temp_low_treatment = self.grid_values[index - 1]
423 |             temp_high_treatment = self.grid_values[index]
424 |
425 |             bootstrap_results = self._bootstrap_analysis(
426 |                 temp_low_treatment, temp_high_treatment
427 |             )
428 |
429 |             self.final_bootstrap_results.append(bootstrap_results)
430 |         return self
431 |     def calculate_mediation(self, ci=0.95):
432 |         """Conducts mediation analysis on the fit data
433 |
434 |         Parameters
435 |         ----------
436 |         ci: float (default = 0.95)
437 |             The desired bootstrap confidence interval to produce. Default value is 0.95,
438 |             corresponding to 95% confidence intervals. Must be within (0, 1.0).
439 |
440 |         Returns
441 |         ----------
442 |         dataframe: Pandas dataframe
443 |             Contains the treatment grid values along with the proportions of the
444 |             direct and indirect effects at each grid value.
445 |             The bootstrap confidence interval that is returned might not be symmetric.
446 |
447 |
448 |
449 |         """
450 |         self.rand_seed_wrapper(self.random_seed)
451 |
452 |         # Collect effect results in these lists
453 |         self.t_bin_means = self._collect_mean_t_levels()
454 |         self.prop_direct_list = []
455 |         self.prop_indirect_list = []
456 |         general_indirect = []
457 |
458 |         lower = (1 - ci) / 2
459 |         upper = ci + lower
460 |
461 |         # Calculate results for each treatment bin
462 |         for index, _ in enumerate(self.grid_values):
463 |
464 |             if index == (len(self.grid_values) - 1):
465 |                 continue
466 |
467 |             temp_bootstrap_results = self.final_bootstrap_results[index]
468 |
469 |             mean_results = {
470 |                 key: temp_bootstrap_results[key].mean()
471 |                 for key in temp_bootstrap_results
472 |             }
473 |
474 |             tau_coef = (
475 |                 mean_results["d1"]
476 |                 + mean_results["d0"]
477 |                 + mean_results["z1"]
478 |                 + mean_results["z0"]
479 |             ) / 2
480 |             n0 = mean_results["d0"] / tau_coef
481 |             n1 = mean_results["d1"] / tau_coef
482 |             n_avg = (n0 + n1) / 2
483 |
484 |             tau_general = (
485 |                 temp_bootstrap_results["d1"]
486 |                 + temp_bootstrap_results["d0"]
487 |                 + temp_bootstrap_results["z1"]
488 |                 + temp_bootstrap_results["z0"]
489 |             ) / 2
490 |             nu_0_general = temp_bootstrap_results["d0"] / tau_general
491 |             nu_1_general = temp_bootstrap_results["d1"] / tau_general
492 |             nu_avg_general = (nu_0_general + nu_1_general) / 2
493 |
494 |             self.prop_direct_list.append(1 - n_avg)
495 |             self.prop_indirect_list.append(n_avg)
496 |             general_indirect.append(nu_avg_general)
497 |
498 |         general_indirect = pd.concat(general_indirect)
499 |
500 |         # Bootstrap these general_indirect values
501 |         bootstrap_overall_means = []
502 |         for _ in range(0, 1000):
503 |             bootstrap_overall_means.append(
504 |                 general_indirect.sample(frac=0.25, replace=True).mean()
505 |             )
506 |
507 |         bootstrap_overall_means = np.array(bootstrap_overall_means)
508 |
509 |         final_results = pd.DataFrame(
510 |             {
511 |                 "Treatment_Value": self.t_bin_means,
512 |                 "Proportion_Direct_Effect": self.prop_direct_list,
513 |                 "Proportion_Indirect_Effect": self.prop_indirect_list,
514 |             }
515 |         ).round(4)
516 |
517 |         # Clip Proportion_Direct_Effect and Proportion_Indirect_Effect to [0, 1]
518 |         final_results["Proportion_Direct_Effect"].clip(lower=0, upper=1.0, inplace=True)
519 |         final_results["Proportion_Indirect_Effect"].clip(
520 |             lower=0, upper=1.0, inplace=True
521 |         )
522 |
523 |         # Calculate the overall mean indirect effect
524 |         total_prop_mean = round(np.array(self.prop_indirect_list).mean(), 4)
525 |         total_prop_lower = self.clip_negatives(
526 |             round(np.percentile(bootstrap_overall_means, q=lower * 100), 4)
527 |         )
528 |         total_prop_upper = self.clip_negatives(
529 |             round(np.percentile(bootstrap_overall_means, q=upper * 100), 4)
530 |         )
531 |
532 |         print(
533 |             f"""\n\nMean indirect effect proportion:
534 |             {total_prop_mean} ({total_prop_lower} - {total_prop_upper})
535 |             """
536 |         )
537 |         return final_results
538 |
539 |     def _bootstrap_analysis(self, temp_low_treatment, temp_high_treatment):
540 |         """The top-level bootstrap routine used by the fit method"""
541 |
542 |         bootstrap_collection = []
543 |
544 |         for _ in range(0, self.bootstrap_replicates):
545 |             # Create single bootstrap replicate
546 |             temp_t, temp_m, temp_y = self._create_bootstrap_replicate()
547 |             # Create the models from this
548 |             temp_mediator_model, temp_outcome_model = self._fit_gams(
549 |                 temp_t, temp_m, temp_y
550 |             )
551 |             # Make mediator predictions
552 |             predict_m1, predict_m0 = self._mediator_prediction(
553 |                 temp_mediator_model,
554 |                 temp_t,
555 |                 temp_m,
556 |                 temp_low_treatment,
557 |                 temp_high_treatment,
558 |             )
559 |             # Make outcome predictions
560 |             outcome_preds = self._outcome_prediction(
561 |                 temp_low_treatment,
562 |                 temp_high_treatment,
563 |                 predict_m1,
564 |                 predict_m0,
565 |                 temp_outcome_model,
566 |             )
567 |             # Collect the replicate results here
568 |             bootstrap_collection.append(outcome_preds)
569 |
570 |         # Convert this into a dataframe
571 |         bootstrap_results = pd.DataFrame(bootstrap_collection)
572 |
573 |         return bootstrap_results
574 |
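    # NOTE: a sketch (not library code) of what each bootstrap replicate
    # contributes, following Imai, Keele & Tingley (2010). In the dict returned
    # by _outcome_prediction below:
    #   d1, d0 -> indirect (mediated) effects, with treatment held high / low
    #   z1, z0 -> direct effects, with the mediator fixed at its level under
    #             high / low treatment
    # calculate_mediation then averages these into a total effect and a
    # proportion mediated; e.g., with invented values d1=0.42, d0=0.38,
    # z1=0.21, z0=0.19:
    #   tau = (0.42 + 0.38 + 0.21 + 0.19) / 2          # = 0.6, total effect
    #   nu_avg = ((0.38 / 0.6) + (0.42 / 0.6)) / 2     # ~0.667, proportion indirect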
575 |     def _create_bootstrap_replicate(self):
576 |         """Creates a single bootstrap replicate from the data"""
577 |         temp_t = self.T.sample(n=self.bootstrap_draws, replace=True)
578 |         temp_m = self.M.iloc[temp_t.index]
579 |         temp_y = self.y.iloc[temp_t.index]
580 |
581 |         return temp_t, temp_m, temp_y
582 |
583 |     def _fit_gams(self, temp_t, temp_m, temp_y):
584 |         """Fits the mediator and outcome GAMs"""
585 |         temp_mediator_model = LinearGAM(
586 |             s(0, n_splines=self.n_splines, spline_order=self.spline_order),
587 |             fit_intercept=True,
588 |             max_iter=self.max_iter,
589 |             lam=self.lambda_,
590 |         )
591 |         temp_mediator_model.fit(temp_t, temp_m)
592 |
593 |         temp_outcome_model = LinearGAM(
594 |             s(0, n_splines=self.n_splines, spline_order=self.spline_order)
595 |             + s(1, n_splines=self.n_splines, spline_order=self.spline_order),
596 |             fit_intercept=True,
597 |             max_iter=self.max_iter,
598 |             lam=self.lambda_,
599 |         )
600 |         temp_outcome_model.fit(pd.concat([temp_t, temp_m], axis=1), temp_y)
601 |
602 |         return temp_mediator_model, temp_outcome_model
603 |
604 |     def _mediator_prediction(
605 |         self,
606 |         temp_mediator_model,
607 |         temp_t,
608 |         temp_m,
609 |         temp_low_treatment,
610 |         temp_high_treatment,
611 |     ):
612 |         """Makes predictions from the mediator model"""
613 |
614 |         m1_mean = temp_mediator_model.predict(temp_high_treatment)[0]
615 |         m0_mean = temp_mediator_model.predict(temp_low_treatment)[0]
616 |
617 |         std_dev = np.sqrt(
618 |             (temp_mediator_model.deviance_residuals(temp_t, temp_m) ** 2).sum()
619 |             / (self.n - (len(temp_mediator_model.get_params()["terms"]._terms) + 1))
620 |         )
621 |         est_error = np.random.normal(loc=0, scale=std_dev, size=self.n)
622 |
623 |         predict_m1 = m1_mean + est_error
624 |         predict_m0 = m0_mean + est_error
625 |
626 |         return predict_m1, predict_m0
627 |
628 |     def _outcome_prediction(
629 |         self,
630 |         temp_low_treatment,
631 |         temp_high_treatment,
632 |         predict_m1,
633 |         predict_m0,
634 |         temp_outcome_model,
635 |     ):
636 |         """Makes predictions from the outcome model"""
637 |
638 |         outcome_preds = {}
639 |
640 |         inputs = [
641 |             ["d1", temp_high_treatment, temp_high_treatment, predict_m1, predict_m0],
642 |             ["d0", temp_low_treatment, temp_low_treatment, predict_m1, predict_m0],
643 |             ["z1", temp_high_treatment, temp_low_treatment, predict_m1, predict_m1],
644 |             ["z0", temp_high_treatment, temp_low_treatment, predict_m0, predict_m0],
645 |         ]
646 |
647 |         for element in inputs:
648 |
649 |             # Set treatment values
650 |             t_1 = element[1]
651 |             t_0 = element[2]
652 |
653 |             # Set mediator values
654 |             m_1 = element[3]
655 |             m_0 = element[4]
656 |
657 |             pr_1 = temp_outcome_model.predict(
658 |                 np.column_stack((np.repeat(t_1, self.n), m_1))
659 |             )
660 |
661 |             pr_0 = temp_outcome_model.predict(
662 |                 np.column_stack((np.repeat(t_0, self.n), m_0))
663 |             )
664 |
665 |             outcome_preds[element[0]] = (pr_1 - pr_0).mean()
666 |
667 |         return outcome_preds
--------------------------------------------------------------------------------
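A minimal end-to-end sketch of the Mediation workflow defined in the file above, using synthetic data. The data-generating coefficients and column names are invented purely for illustration; only the Mediation API itself comes from the library.

import numpy as np
import pandas as pd

from causal_curve import Mediation

rng = np.random.default_rng(512)
n = 500

# Simulate a treatment, a mediator that depends on it, and an outcome that
# depends on both, so part of the treatment's effect flows through the mediator
df = pd.DataFrame()
df["Treatment"] = rng.normal(loc=10.0, scale=2.0, size=n)
df["Mediator"] = 0.5 * df["Treatment"] + rng.normal(loc=0.0, scale=1.0, size=n)
df["Outcome"] = (
    0.3 * df["Treatment"]
    + 0.8 * df["Mediator"]
    + rng.normal(loc=0.0, scale=1.0, size=n)
)

# bootstrap_replicates=50 is the smallest value the validator accepts; it keeps
# this demo fast at the cost of noisier confidence intervals
med = Mediation(treatment_grid_num=10, bootstrap_replicates=50, random_seed=512)
med.fit(T=df["Treatment"], M=df["Mediator"], y=df["Outcome"])

# Returns a DataFrame with direct/indirect effect proportions per treatment bin,
# and prints the overall mean indirect effect proportion with its bootstrap CI
med_results = med.calculate_mediation(0.95)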