├── tests
│   ├── __init__.py
│   ├── unit
│   │   ├── __init__.py
│   │   ├── test_tmle_core.py
│   │   ├── test_gps_classifier.py
│   │   ├── test_core.py
│   │   ├── test_gps_core.py
│   │   ├── test_mediation.py
│   │   ├── test_tmle_regressor.py
│   │   └── test_gps_regressor.py
│   ├── integration
│   │   ├── __init__.py
│   │   ├── test_tmle.py
│   │   ├── test_mediation.py
│   │   └── test_gps.py
│   ├── test_helpers.py
│   └── conftest.py
├── imgs
│   ├── curves.png
│   ├── cdrc
│   │   └── CDRC.png
│   ├── tmle_plot.png
│   ├── welcome_plot.png
│   ├── binary_OR_fig.png
│   ├── full_example
│   │   ├── BLL_dist.png
│   │   ├── test_dist.png
│   │   ├── lead_paint_can.jpeg
│   │   ├── mediation_curve.png
│   │   └── test_causal_curves.png
│   └── mediation
│       ├── diabetes_DAG.png
│       └── mediation_effect.png
├── paper
│   ├── welcome_plot.png
│   ├── paper.bib
│   └── paper.md
├── docs
│   ├── modules.rst
│   ├── requirements.txt
│   ├── citation.rst
│   ├── Makefile
│   ├── make.bat
│   ├── install.rst
│   ├── causal_curve.rst
│   ├── conf.py
│   ├── GPS_Classifier.rst
│   ├── TMLE_Regressor.rst
│   ├── index.rst
│   ├── GPS_Regressor.rst
│   ├── Mediation_example.rst
│   ├── changelog.rst
│   ├── intro.rst
│   ├── contribute.rst
│   └── full_example.rst
├── codecov.yml
├── requirements.txt
├── causal_curve
│   ├── __init__.py
│   ├── core.py
│   ├── gps_classifier.py
│   ├── tmle_regressor.py
│   ├── gps_regressor.py
│   ├── tmle_core.py
│   └── mediation.py
├── .travis.yml
├── LICENSE
├── .gitignore
├── setup.py
├── README.md
└── .pylintrc
/tests/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/unit/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/curves.png
--------------------------------------------------------------------------------
/imgs/cdrc/CDRC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/cdrc/CDRC.png
--------------------------------------------------------------------------------
/imgs/tmle_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/tmle_plot.png
--------------------------------------------------------------------------------
/imgs/welcome_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/welcome_plot.png
--------------------------------------------------------------------------------
/imgs/binary_OR_fig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/binary_OR_fig.png
--------------------------------------------------------------------------------
/paper/welcome_plot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/paper/welcome_plot.png
--------------------------------------------------------------------------------
/docs/modules.rst:
--------------------------------------------------------------------------------
1 | causal_curve
2 | ============
3 |
4 | .. toctree::
5 |    :maxdepth: 4
6 |
7 |    causal_curve
8 |
--------------------------------------------------------------------------------
/imgs/full_example/BLL_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/BLL_dist.png
--------------------------------------------------------------------------------
/imgs/full_example/test_dist.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/test_dist.png
--------------------------------------------------------------------------------
/imgs/mediation/diabetes_DAG.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/mediation/diabetes_DAG.png
--------------------------------------------------------------------------------
/imgs/mediation/mediation_effect.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/mediation/mediation_effect.png
--------------------------------------------------------------------------------
/imgs/full_example/lead_paint_can.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/lead_paint_can.jpeg
--------------------------------------------------------------------------------
/imgs/full_example/mediation_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/mediation_curve.png
--------------------------------------------------------------------------------
/imgs/full_example/test_causal_curves.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ronikobrosly/causal-curve/HEAD/imgs/full_example/test_causal_curves.png
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 | coverage:
2 |   status:
3 |     project:
4 |       default:
5 |         threshold: 0%
6 |     patch:
7 |       default:
8 |         threshold: 0%
9 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | coverage
3 | future
4 | joblib
5 | numpy
6 | numpydoc
7 | pandas
8 | patsy
9 | progressbar2
10 | pygam
11 | pytest
12 | python-dateutil
13 | python-utils
14 | pytz
15 | scikit-learn
16 | scipy
17 | six
18 | sphinx_rtd_theme
19 | statsmodels
20 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | black
2 | coverage
3 | future
4 | joblib
5 | numpy
6 | numpydoc
7 | pandas
8 | patsy
9 | progressbar2
10 | pygam
11 | pytest
12 | python-dateutil
13 | python-utils
14 | pytz
15 | scikit-learn
16 | scipy
17 | six
18 | sphinx_rtd_theme
19 | statsmodels
20 |
--------------------------------------------------------------------------------
/docs/citation.rst:
--------------------------------------------------------------------------------
1 | Citation
2 | ========
3 |
4 | Please consider citing us in your academic or industry project.
5 |
6 | Kobrosly, R. W., (2020). causal-curve: A Python Causal Inference Package to Estimate
7 | Causal Dose-Response Curves. Journal of Open Source Software, 5(52), 2523, https://doi.org/10.21105/joss.02523
8 |
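9 | For LaTeX users, the reference above corresponds to a BibTeX entry along these
10 | lines (a suggested entry: the key name is arbitrary and the field values are
11 | taken from the citation above)::
12 |
13 |     @article{Kobrosly2020,
14 |         author  = {Kobrosly, R. W.},
15 |         title   = {causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves},
16 |         journal = {Journal of Open Source Software},
17 |         year    = {2020},
18 |         volume  = {5},
19 |         number  = {52},
20 |         pages   = {2523},
21 |         doi     = {10.21105/joss.02523},
22 |     }
23 |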
--------------------------------------------------------------------------------
/tests/test_helpers.py:
--------------------------------------------------------------------------------
1 | """Misc helper functions for tests"""
2 |
3 | from pandas.testing import assert_frame_equal
4 |
5 |
6 | def assert_df_equal(observed_frame, expected_frame, check_less_precise):
7 | """Assert that two pandas dataframes are equal, ignoring ordering of columns."""
8 | assert_frame_equal(
9 | observed_frame.sort_index(axis=1),
10 | expected_frame.sort_index(axis=1),
11 | check_names=True,
12 | check_less_precise=check_less_precise,
13 | )
14 |
--------------------------------------------------------------------------------
/causal_curve/__init__.py:
--------------------------------------------------------------------------------
1 | """causal_curve module"""
2 |
3 | import warnings
4 |
5 | from statsmodels.genmod.generalized_linear_model import DomainWarning
6 |
7 | from causal_curve.gps_classifier import GPS_Classifier
8 | from causal_curve.gps_regressor import GPS_Regressor
9 |
10 | from causal_curve.tmle_regressor import TMLE_Regressor
11 | from causal_curve.mediation import Mediation
12 |
13 |
14 | # Suppress statsmodel warning for gamma family GLM
15 | warnings.filterwarnings("ignore", category=DomainWarning)
16 | warnings.filterwarnings("ignore", category=UserWarning)
17 |
--------------------------------------------------------------------------------
/.travis.yml:
--------------------------------------------------------------------------------
1 | language: python
2 |
3 | # sudo false implies containerized builds
4 | sudo: false
5 |
6 | python:
7 |   - 3.6
8 |
9 | env:
10 |   global:
11 |     # test invocation
12 |     - TESTFOLDER="tests"
13 |
14 | before_install:
15 |   # Install the package dependencies with pip
16 |   - pip install black coverage future joblib numpy numpydoc pandas patsy progressbar2 pygam pytest python-dateutil python-utils pytz scikit-learn scipy six sphinx_rtd_theme statsmodels
17 |
18 | install:
19 |   - python setup.py install
20 |
21 | script:
22 |   - coverage run -m pytest $TESTFOLDER
23 |
24 | after_success:
25 |   - bash <(curl -s https://codecov.io/bash)
26 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | 	@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | 	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/tests/integration/test_tmle.py:
--------------------------------------------------------------------------------
1 | """ Integration tests of the tmle.py module """
2 |
3 | import pandas as pd
4 |
5 | from causal_curve import TMLE_Regressor
6 |
7 |
8 | def test_full_tmle_flow(continuous_dataset_fixture):
9 | """
10 | Tests the full flow of the TMLE tool
11 | """
12 |
13 | tmle = TMLE_Regressor(
14 | random_seed=100,
15 | verbose=True,
16 | )
17 | tmle.fit(
18 | T=continuous_dataset_fixture["treatment"],
19 | X=continuous_dataset_fixture[["x1", "x2"]],
20 | y=continuous_dataset_fixture["outcome"],
21 | )
22 | tmle_results = tmle.calculate_CDRC(0.95)
23 |
24 | assert isinstance(tmle_results, pd.DataFrame)
25 | check = tmle_results.columns == [
26 | "Treatment",
27 | "Causal_Dose_Response",
28 | "Lower_CI",
29 | "Upper_CI",
30 | ]
31 | assert check.all()
32 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | 	set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | 	echo.
18 | 	echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | 	echo.installed, then set the SPHINXBUILD environment variable to point
20 | 	echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | 	echo.may add the Sphinx directory to PATH.
22 | 	echo.
23 | 	echo.If you don't have Sphinx installed, grab it from
24 | 	echo.http://sphinx-doc.org/
25 | 	exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Roni Kobrosly
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/tests/integration/test_mediation.py:
--------------------------------------------------------------------------------
1 | """ Integration tests of the mediation.py module """
2 |
3 | import pandas as pd
4 |
5 | from causal_curve import Mediation
6 |
7 |
8 | def test_full_mediation_flow(mediation_fixture):
9 | """
10 | Tests the full flow of the Mediation tool
11 | """
12 |
13 | med = Mediation(
14 | treatment_grid_num=10,
15 | lower_grid_constraint=0.01,
16 | upper_grid_constraint=0.99,
17 | bootstrap_draws=100,
18 | bootstrap_replicates=50,
19 | spline_order=3,
20 | n_splines=5,
21 | lambda_=0.5,
22 | max_iter=20,
23 | random_seed=None,
24 | verbose=True,
25 | )
26 | med.fit(
27 | T=mediation_fixture["treatment"],
28 | M=mediation_fixture["mediator"],
29 | y=mediation_fixture["outcome"],
30 | )
31 |
32 | med_results = med.calculate_mediation(0.95)
33 |
34 | assert isinstance(med_results, pd.DataFrame)
35 | check = med_results.columns == [
36 | "Treatment_Value",
37 | "Proportion_Direct_Effect",
38 | "Proportion_Indirect_Effect",
39 | ]
40 | assert check.all()
41 |
--------------------------------------------------------------------------------
/tests/unit/test_tmle_core.py:
--------------------------------------------------------------------------------
1 | """ Unit tests of the tmle.py module """
2 |
3 | import pytest
4 |
5 | from causal_curve.tmle_core import TMLE_Core
6 |
7 |
8 | def test_tmle_fit(continuous_dataset_fixture):
9 | """
10 | Tests the fit method GPS tool
11 | """
12 |
13 | tmle = TMLE_Core(
14 | random_seed=100,
15 | verbose=True,
16 | )
17 | tmle.fit(
18 | T=continuous_dataset_fixture["treatment"],
19 | X=continuous_dataset_fixture[["x1", "x2"]],
20 | y=continuous_dataset_fixture["outcome"],
21 | )
22 |
23 | assert tmle.num_rows == 500
24 | assert tmle.fully_expanded_t_and_x.shape == (50500, 3)
25 |
26 |
27 | def test_bad_param_calculate_CDRC_TMLE(TMLE_fitted_model_continuous_fixture):
28 | """
29 | Tests the TMLE `calculate_CDRC` when the `ci` param is bad
30 | """
31 |
32 | with pytest.raises(Exception) as bad:
33 | observed_result = TMLE_fitted_model_continuous_fixture.calculate_CDRC(
34 | np.array([50]), ci={"param": 0.95}
35 | )
36 |
37 | with pytest.raises(Exception) as bad:
38 | observed_result = TMLE_fitted_model_continuous_fixture.calculate_CDRC(
39 | np.array([50]), ci=1.05
40 | )
41 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.DS_Store
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | pip-wheel-metadata/
20 | share/python-wheels/
21 | *.egg-info/
22 | .installed.cfg
23 | *.egg
24 | MANIFEST
25 | *.manifest
26 | *.spec
27 | pip-log.txt
28 | pip-delete-this-directory.txt
29 | htmlcov/
30 | .tox/
31 | .nox/
32 | .coverage
33 | .coverage.*
34 | .cache
35 | nosetests.xml
36 | coverage.xml
37 | *.cover
38 | *.py,cover
39 | .hypothesis/
40 | .pytest_cache/
41 | *.mo
42 | *.pot
43 | *.log
44 | local_settings.py
45 | db.sqlite3
46 | db.sqlite3-journal
47 | instance/
48 | .webassets-cache
49 | .scrapy
50 | docs/_build/
51 | target/
52 | .ipynb_checkpoints
53 | profile_default/
54 | ipython_config.py
55 | .python-version
56 | __pypackages__/
57 | celerybeat-schedule
58 | celerybeat.pid
59 | *.sage.py
60 | .env
61 | .venv
62 | env/
63 | venv/
64 | ENV/
65 | env.bak/
66 | venv.bak/
67 | .spyderproject
68 | .spyproject
69 | .ropeproject
70 | /site
71 | .mypy_cache/
72 | .dmypy.json
73 | dmypy.json
74 | .pyre/
75 | .idea/
76 | .idea_modules/
77 | workdir
78 | reports
79 | tests/integration/api/snapshots
80 | tests/integration/api/webcache
81 |
--------------------------------------------------------------------------------
/tests/unit/test_gps_classifier.py:
--------------------------------------------------------------------------------
1 | """ Unit tests for the GPS_Core class """
2 |
3 | import numpy as np
4 | from pygam import LinearGAM
5 | import pytest
6 |
7 | from causal_curve.gps_core import GPS_Core
8 | from tests.conftest import full_continuous_example_dataset
9 |
10 |
11 | def test_predict_log_odds_method_good(GPS_fitted_model_binary_fixture):
12 | """
13 | Tests the GPS `estimate_log_odds` method using appropriate data (with a binary outcome)
14 | """
15 | observed_result = GPS_fitted_model_binary_fixture.estimate_log_odds(np.array([0.5]))
16 | assert isinstance(observed_result[0][0], float)
17 | assert len(observed_result[0]) == 1
18 |
19 | observed_result = GPS_fitted_model_binary_fixture.estimate_log_odds(
20 | np.array([0.5, 0.6, 0.7])
21 | )
22 | assert isinstance(observed_result[0][0], float)
23 | assert len(observed_result[0]) == 3
24 |
25 |
26 | def test_predict_log_odds_method_bad(GPS_fitted_model_continuous_fixture):
27 | """
28 | Tests the GPS `estimate_log_odds` method using appropriate data (with a continuous outcome)
29 | """
30 | with pytest.raises(Exception) as bad:
31 | observed_result = GPS_fitted_model_continuous_fixture.estimate_log_odds(
32 | np.array([50])
33 | )
34 |
--------------------------------------------------------------------------------
/docs/install.rst:
--------------------------------------------------------------------------------
1 | .. _install:
2 |
3 | =====================================
4 | Installation, testing and development
5 | =====================================
6 |
7 | Dependencies
8 | ------------
9 |
10 | causal-curve requires:
11 |
12 | - black
13 | - coverage
14 | - future
15 | - joblib
16 | - numpy
17 | - numpydoc
18 | - pandas
19 | - patsy
20 | - progressbar2
21 | - pygam
22 | - pytest
23 | - python-dateutil
24 | - python-utils
25 | - pytz
26 | - scikit-learn
27 | - scipy
28 | - six
29 | - sphinx_rtd_theme
30 | - statsmodels
31 |
32 |
33 |
34 | User installation
35 | -----------------
36 |
37 | If you already have a working installation of numpy, pandas, pygam, scipy, and statsmodels,
38 | you can easily install causal-curve using ``pip``::
39 |
40 |     pip install causal-curve
41 |
42 |
43 | You can also get the latest version of causal-curve by cloning the repository::
44 |
45 |     git clone https://github.com/ronikobrosly/causal-curve.git
46 |     cd causal-curve
47 |     pip install .
48 |
49 |
50 | Testing
51 | -------
52 |
53 | After installation, you can launch the test suite from outside the source
54 | directory using ``pytest``::
55 |
56 |     pytest
57 |
58 |
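59 | To also measure test coverage while running the suite (this mirrors the repository's
60 | CI configuration in ``.travis.yml``)::
61 |
62 |     coverage run -m pytest tests
63 |
64 |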
65 | Development
66 | -----------
67 |
68 | Please reach out if you are interested in adding additional tools,
69 | or have ideas on how to improve the package!
70 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | with open("README.md", "r") as fh:
4 |     long_description = fh.read()
5 |
6 | setuptools.setup(
7 |     name="causal-curve",
8 |     version="1.0.6",
9 |     author="Roni Kobrosly",
10 |     author_email="roni.kobrosly@gmail.com",
11 |     description="A python library with tools to perform causal inference using \
12 | observational data when the treatment of interest is continuous.",
13 |     long_description=long_description,
14 |     long_description_content_type="text/markdown",
15 |     url="https://github.com/ronikobrosly/causal-curve",
16 |     packages=setuptools.find_packages(include=["causal_curve"]),
17 |     classifiers=[
18 |         "Programming Language :: Python :: 3",
19 |         "License :: OSI Approved :: MIT License",
20 |         "Operating System :: OS Independent",
21 |     ],
22 |     python_requires=">=3.6",
23 |     install_requires=[
24 |         "black",
25 |         "coverage",
26 |         "future",
27 |         "joblib",
28 |         "numpy",
29 |         "numpydoc",
30 |         "pandas",
31 |         "patsy",
32 |         "progressbar2",
33 |         "pygam",
34 |         "pytest",
35 |         "python-dateutil",
36 |         "python-utils",
37 |         "pytz",
38 |         "scikit-learn",
39 |         "scipy",
40 |         "six",
41 |         "sphinx_rtd_theme",
42 |         "statsmodels",
43 |     ],
44 | )
45 |
--------------------------------------------------------------------------------
/docs/causal_curve.rst:
--------------------------------------------------------------------------------
1 | causal\_curve package
2 | =====================
3 |
4 |
5 | causal\_curve.core module
6 | -------------------------
7 |
8 | .. automodule:: causal_curve.core
9 |     :members:
10 |     :undoc-members:
11 |     :show-inheritance:
12 |
13 | causal\_curve.gps_core module
14 | -----------------------------
15 |
16 | .. automodule:: causal_curve.gps_core
17 |     :members:
18 |     :undoc-members:
19 |     :show-inheritance:
20 |
21 | causal\_curve.gps_regressor module
22 | ----------------------------------
23 |
24 | .. automodule:: causal_curve.gps_regressor
25 |     :members:
26 |     :undoc-members:
27 |     :show-inheritance:
28 |
29 |
30 | causal\_curve.gps_classifier module
31 | -----------------------------------
32 |
33 | .. automodule:: causal_curve.gps_classifier
34 |     :members:
35 |     :undoc-members:
36 |     :show-inheritance:
37 |
38 |
39 | causal\_curve.tmle_core module
40 | ------------------------------
41 |
42 | .. automodule:: causal_curve.tmle_core
43 |     :members:
44 |     :undoc-members:
45 |     :show-inheritance:
46 |
47 | causal\_curve.tmle_regressor module
48 | -----------------------------------
49 |
50 | .. automodule:: causal_curve.tmle_regressor
51 |     :members:
52 |     :undoc-members:
53 |     :show-inheritance:
54 |
55 | causal\_curve.mediation module
56 | ------------------------------
57 |
58 | .. automodule:: causal_curve.mediation
59 |     :members:
60 |     :undoc-members:
61 |     :show-inheritance:
62 |
63 |
64 |
65 |
66 |
67 | Module contents
68 | ---------------
69 |
70 | .. automodule:: causal_curve
71 |     :members:
72 |     :undoc-members:
73 |     :show-inheritance:
74 |
--------------------------------------------------------------------------------
/tests/unit/test_core.py:
--------------------------------------------------------------------------------
1 | """ General unit tests of the causal-curve package """
2 |
3 | import numpy as np
4 |
5 | from causal_curve.core import Core
6 |
7 |
8 | def test_get_params():
9 |     """
10 |     Tests the `get_params` method of the Core base class
11 |     """
12 |
13 |     core = Core()
14 |     core.a = 5
15 |     core.b = 10
16 |
17 |     observed_results = core.get_params()
18 |
19 |     assert observed_results == {"a": 5, "b": 10}
20 |
21 |
22 | def test_if_verbose_print(capfd):
23 |     """
24 |     Tests the `if_verbose_print` method of the Core base class
25 |     """
26 |
27 |     core = Core()
28 |     core.verbose = True
29 |
30 |     core.if_verbose_print("This is a test")
31 |     out, err = capfd.readouterr()
32 |
33 |     assert out == "This is a test\n"
34 |
35 |     core.verbose = False
36 |
37 |     core.if_verbose_print("This is a test")
38 |     out, err = capfd.readouterr()
39 |
40 |     assert out == ""
41 |
42 |
43 | def test_rand_seed_wrapper():
44 |     """
45 |     Tests the `rand_seed_wrapper` method of the Core base class
46 |     """
47 |
48 |     core = Core()
49 |     core.rand_seed_wrapper(123)
50 |
51 |     assert np.random.get_state()[1][0] == 123
52 |
53 |
54 | def test_calculate_z_score():
55 |     """
56 |     Tests the `calculate_z_score` method of the Core base class
57 |     """
58 |
59 |     core = Core()
60 |     assert round(core.calculate_z_score(0.95), 3) == 1.960
61 |     assert round(core.calculate_z_score(0.90), 3) == 1.645
62 |
63 |
64 | def test_clip_negatives():
65 |     """
66 |     Tests the `clip_negatives` method of the Core base class
67 |     """
68 |
69 |     core = Core()
70 |     assert core.clip_negatives(0.5) == 0.5
71 |     assert core.clip_negatives(-1.5) == 0
72 |
--------------------------------------------------------------------------------
/tests/unit/test_gps_core.py:
--------------------------------------------------------------------------------
1 | """ Unit tests for the GPS_Core class """
2 |
3 | import numpy as np
4 | from pygam import LinearGAM
5 | import pytest
6 |
7 | from causal_curve.gps_core import GPS_Core
8 | from tests.conftest import full_continuous_example_dataset
9 |
10 |
11 | @pytest.mark.parametrize(
12 |     ("df_fixture", "family"),
13 |     [
14 |         (full_continuous_example_dataset, "normal"),
15 |         (full_continuous_example_dataset, "lognormal"),
16 |         (full_continuous_example_dataset, "gamma"),
17 |         (full_continuous_example_dataset, None),
18 |     ],
19 | )
20 | def test_gps_fit(df_fixture, family):
21 |     """
22 |     Tests the fit method of the GPS_Core tool
23 |     """
24 |
25 |     gps = GPS_Core(
26 |         gps_family=family,
27 |         treatment_grid_num=10,
28 |         lower_grid_constraint=0.0,
29 |         upper_grid_constraint=1.0,
30 |         spline_order=3,
31 |         n_splines=10,
32 |         max_iter=100,
33 |         random_seed=100,
34 |         verbose=True,
35 |     )
36 |     gps.fit(
37 |         T=df_fixture()["treatment"],
38 |         X=df_fixture()["x1"],
39 |         y=df_fixture()["outcome"],
40 |     )
41 |
42 |     assert isinstance(gps.gam_results, LinearGAM)
43 |     assert gps.gps.shape == (500,)
44 |
45 |
46 | def test_bad_param_calculate_CDRC(GPS_fitted_model_continuous_fixture):
47 |     """
48 |     Tests the GPS `calculate_CDRC` method when the `ci` param is bad
49 |     """
50 |
51 |     with pytest.raises(Exception) as bad:
52 |         observed_result = GPS_fitted_model_continuous_fixture.calculate_CDRC(
53 |             np.array([50]), ci={"param": 0.95}
54 |         )
55 |
56 |     with pytest.raises(Exception) as bad:
57 |         observed_result = GPS_fitted_model_continuous_fixture.calculate_CDRC(
58 |             np.array([50]), ci=1.05
59 |         )
60 |
--------------------------------------------------------------------------------
/tests/integration/test_gps.py:
--------------------------------------------------------------------------------
1 | """ Integration tests of the gps.py module """
2 |
3 | import pandas as pd
4 |
5 | from causal_curve import GPS_Regressor, GPS_Classifier
6 |
7 |
8 | def test_full_continuous_gps_flow(continuous_dataset_fixture):
9 | """
10 | Tests the full flow of the GPS tool when used with a continuous outcome
11 | """
12 |
13 | gps = GPS_Regressor(
14 | treatment_grid_num=10,
15 | lower_grid_constraint=0.0,
16 | upper_grid_constraint=1.0,
17 | spline_order=3,
18 | n_splines=10,
19 | max_iter=100,
20 | random_seed=100,
21 | verbose=True,
22 | )
23 | gps.fit(
24 | T=continuous_dataset_fixture["treatment"],
25 | X=continuous_dataset_fixture[["x1", "x2"]],
26 | y=continuous_dataset_fixture["outcome"],
27 | )
28 | gps_results = gps.calculate_CDRC(0.95)
29 |
30 | assert isinstance(gps_results, pd.DataFrame)
31 | check = gps_results.columns == [
32 | "Treatment",
33 | "Causal_Dose_Response",
34 | "Lower_CI",
35 | "Upper_CI",
36 | ]
37 | assert check.all()
38 |
39 |
40 | def test_binary_gps_flow(binary_dataset_fixture):
41 | """
42 | Tests the full flow of the GPS tool when used with a binary outcome
43 | """
44 |
45 | gps = GPS_Classifier(
46 | gps_family="normal",
47 | treatment_grid_num=10,
48 | lower_grid_constraint=0.0,
49 | upper_grid_constraint=1.0,
50 | spline_order=3,
51 | n_splines=10,
52 | max_iter=100,
53 | random_seed=100,
54 | verbose=True,
55 | )
56 | gps.fit(
57 | T=binary_dataset_fixture["treatment"],
58 | X=binary_dataset_fixture["x1"],
59 | y=binary_dataset_fixture["outcome"],
60 | )
61 | gps_results = gps.calculate_CDRC(0.95)
62 |
63 | assert isinstance(gps_results, pd.DataFrame)
64 | check = gps_results.columns == [
65 | "Treatment",
66 | "Causal_Odds_Ratio",
67 | "Lower_CI",
68 | "Upper_CI",
69 | ]
70 | assert check.all()
71 |
--------------------------------------------------------------------------------
/paper/paper.bib:
--------------------------------------------------------------------------------
1 |
2 |
3 | @book{Galagate:2016,
4 | Adsurl = {https://drum.lib.umd.edu/handle/1903/18170},
5 | Author = {{Galagate}, D.},
6 | Title = {Causal Inference with a Continuous Treatment and Outcome: Alternative Estimators for Parametric Dose-Response function with Applications.},
7 | Booktitle = {Causal Inference with a Continuous Treatment and Outcome: Alternative Estimators for Parametric Dose-Response function with Applications.},
8 | Publisher = {Digital Repository at the University of Maryland},
9 | Year = 2016
10 | }
11 |
12 | @article{Moodie:2010,
13 | author = {{Moodie}, E. and {Stephens}, D.},
14 | title = "{Estimation of dose–response functions for longitudinal data using the generalised propensity score.}",
15 | journal = {Statistical Methods in Medical Research},
16 | year = 2010,
17 | volume = 21,
18 | doi = {10.1177/0962280209340213},
19 | }
20 |
21 | @book{Hirano:2004,
22 | Author = {{Hirano}, K. and {Imbens}, G.},
23 | Booktitle = {Applied bayesian modeling and causal inference from incomplete-data perspectives, by Gelman A and Meng XL. ~Published by Wiley, Oxford, UK.},
24 | Publisher = {Wiley},
25 | Title = {{The propensity score with continuous treatments}},
26 | Year = 2004
27 | }
28 |
29 | @article{van_der_Laan:2010,
30 | author = {{van der Laan}, M. and {Gruber}, S.},
31 | title = "{Collaborative double robust penalized targeted maximum likelihood estimation.}",
32 | journal = {The International Journal of Biostatistics},
33 | year = 2010,
34 | volume = 6,
35 | doi = {10.2202/1557-4679.1181},
36 | }
37 |
38 | @article{Imai:2010,
39 | author = {{Imai}, K. and {Keele}, L. and {Tingley}, D.},
40 | title = "{A General Approach to Causal Mediation Analysis.}",
41 | journal = {Psychological Methods},
42 | year = 2010,
43 | volume = 15,
44 | doi = {10.1037/a0020761}
45 | }
46 |
47 | @book{Hernán:2020,
48 | Author = {{Hernán}, M. and {Robins}, J.},
49 | Booktitle = {Causal Inference: What If.},
50 | Publisher = {Chapman & Hall},
51 | Title = {{Causal Inference: What If.}},
52 | Year = 2020
53 | }
54 |
--------------------------------------------------------------------------------
/causal_curve/core.py:
--------------------------------------------------------------------------------
1 | """
2 | Core classes (with basic methods) that will be invoked when other, model classes are defined
3 | """
4 |
5 | import numpy as np
6 | from scipy.stats import norm
7 |
8 |
9 | class Core:
10 |     """Base class for causal_curve module"""
11 |
12 |     def __init__(self):
13 |         pass
14 |
15 |     __version__ = "1.0.6"
16 |
17 |     def get_params(self):
18 |         """Returns a dict of all of the object's user-facing parameters
19 |
20 |         Parameters
21 |         ----------
22 |         None
23 |
24 |         Returns
25 |         -------
26 |         dict: the object's user-facing attribute names and their values
27 |         """
28 |         attrs = self.__dict__
29 |         return dict(
30 |             [(k, v) for k, v in list(attrs.items()) if (k[0] != "_") and (k[-1] != "_")]
31 |         )
32 |
33 |     def if_verbose_print(self, string):
34 |         """Prints the input statement if verbose is set to True
35 |
36 |         Parameters
37 |         ----------
38 |         string: str, some string to be printed
39 |
40 |         Returns
41 |         -------
42 |         None
43 |
44 |         """
45 |         if self.verbose:
46 |             print(string)
47 |
48 |     @staticmethod
49 |     def rand_seed_wrapper(random_seed=None):
50 |         """Sets the random seed using numpy
51 |
52 |         Parameters
53 |         ----------
54 |         random_seed: int, random seed number
55 |
56 |         Returns
57 |         -------
58 |         None
59 |         """
60 |         if random_seed is not None:
61 |             np.random.seed(random_seed)
62 |
63 |     @staticmethod
64 |     def calculate_z_score(ci):
65 |         """Calculates the critical z-score for a desired two-sided
66 |         confidence interval width.
67 |
68 |         Parameters
69 |         ----------
70 |         ci: float, the confidence interval width (e.g. 0.95)
71 |
72 |         Returns
73 |         -------
74 |         Float, critical z-score value
75 |         """
76 |         return norm.ppf((1 + ci) / 2)
77 |
78 |     @staticmethod
79 |     def clip_negatives(number):
80 |         """Helper function to clip negative numbers to zero
81 |
82 |         Parameters
83 |         ----------
84 |         number: int or float, any number that needs a floor at zero
85 |
86 |         Returns
87 |         -------
88 |         Int or float of modified value
89 |
90 |         """
91 |         if number < 0:
92 |             return 0
93 |         return number
94 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # Configuration file for the Sphinx documentation builder.
2 | #
3 | # This file only contains a selection of the most common options. For a full
4 | # list see the documentation:
5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
6 |
7 | # -- Path setup --------------------------------------------------------------
8 |
9 | # If extensions (or modules to document with autodoc) are in another directory,
10 | # add these directories to sys.path here. If the directory is relative to the
11 | # documentation root, use os.path.abspath to make it absolute, like shown here.
12 |
13 | import os
14 | import sys
15 |
16 | sys.path.insert(0, os.path.abspath("../"))
17 |
18 |
19 | # -- Project information -----------------------------------------------------
20 |
21 | project = "causal_curve"
22 | copyright = "2020, Roni Kobrosly"
23 | author = "Roni Kobrosly"
24 |
25 | # The full version, including alpha/beta/rc tags
26 | release = "1.0.6"
27 |
28 | # -- General configuration ---------------------------------------------------
29 |
30 | # Add any Sphinx extension module names here, as strings. They can be
31 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
32 | # ones.
33 | extensions = [
34 | "sphinx.ext.autodoc",
35 | "sphinx.ext.autosummary",
36 | "numpydoc",
37 | ]
38 |
39 | # Add any paths that contain templates here, relative to this directory.
40 | templates_path = ["_templates"]
41 |
42 | # List of patterns, relative to source directory, that match files and
43 | # directories to ignore when looking for source files.
44 | # This pattern also affects html_static_path and html_extra_path.
45 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
46 |
47 | # The name of the Pygments (syntax highlighting) style to use.
48 | pygments_style = "sphinx"
49 |
50 | # this is needed for some reason...
51 | # see https://github.com/numpy/numpydoc/issues/69
52 | numpydoc_show_class_members = False
53 |
54 | # generate autosummary even if no references
55 | autosummary_generate = True
56 |
57 | master_doc = "index"
58 |
59 | # -- Options for HTML output -------------------------------------------------
60 |
61 | # The theme to use for HTML and HTML Help pages. See the documentation for
62 | # a list of builtin themes.
63 | #
64 | html_theme = "sphinx_rtd_theme"
65 |
66 | # Add any paths that contain custom static files (such as style sheets) here,
67 | # relative to this directory. They are copied after the builtin static files,
68 | # so a file named "default.css" will overwrite the builtin "default.css".
69 | html_static_path = []
70 |
--------------------------------------------------------------------------------
/docs/GPS_Classifier.rst:
--------------------------------------------------------------------------------
1 | .. _GPS_Classifier:
2 |
3 | ============================================================
4 | GPS_Classifier Tool (continuous treatments, binary outcomes)
5 | ============================================================
6 |
7 | As with the other GPS tool, we calculate generalized propensity scores (GPS) but
8 | with the classifier we can estimate the point-by-point causal contribution of
9 | a continuous treatment to a binary outcome. The GPS_Classifier does this by
10 | estimating the log odds of a positive outcome, and the corresponding odds ratios, along
11 | the entire range of treatment values:
12 |
13 | .. image:: ../imgs/binary_OR_fig.png
14 |
15 |
16 | Currently, the causal-curve package does not contain a TMLE implementation that is appropriate for a binary outcome,
17 | so the GPS_Classifier tool will have to suffice for this sort of outcome.
18 |
19 | This tool works much like the GPS_Regressor tool; as long as the outcome series in your dataframe contains
20 | binary integer values (e.g. 0's and 1's) the ``fit()`` method will work as it's supposed to:
21 |
22 | >>> df.head(5) # a pandas dataframe with your data
23 | X_1 X_2 Treatment Outcome
24 | 0 0.596685 0.162688 0.000039 1
25 | 1 1.014187 0.916101 0.000197 0
26 | 2 0.932859 1.328576 0.000223 0
27 | 3 1.140052 0.555203 0.000339 0
28 | 4 1.613471 0.340886 0.000438 1
29 |
30 | With this dataframe, we can now calculate the GPS to estimate the causal relationship between
31 | treatment and outcome. Let's use the default settings of the GPS tool:
32 |
33 | >>> from causal_curve import GPS_Classifier
34 | >>> gps = GPS_Classifier()
35 | >>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome'])
36 | >>> gps_results = gps.calculate_CDRC(0.95)
37 |
38 | The ``gps_results`` object (a dataframe) now contains all of the data to produce the above plot.
39 |
40 | If you'd like to estimate the log odds at a specific point on the curve, use the
41 | ``predict_log_odds`` to do so.
42 |
49 |
50 | References
51 | ----------
52 |
53 | Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative
54 | Estimators for Parametric Dose-Response function with Applications. PhD thesis, 2016.
55 |
56 | Moodie E and Stephens DA. Estimation of dose–response functions for
57 | longitudinal data using the generalised propensity score. In: Statistical Methods in
58 | Medical Research 21(2), 2010, pp.149–166.
59 |
60 | Hirano K and Imbens GW. The propensity score with continuous treatments.
61 | In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference
62 | from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84.
63 |
--------------------------------------------------------------------------------
/docs/TMLE_Regressor.rst:
--------------------------------------------------------------------------------
1 | .. _TMLE_Regressor:
2 |
3 | ================================================================
4 | TMLE_Regressor Tool (continuous treatments, continuous outcomes)
5 | ================================================================
6 |
7 |
8 | In this example, we use this package's Targeted Maximum Likelihood Estimation (TMLE)
9 | tool to estimate the marginal causal curve of some continuous treatment on a continuous outcome,
10 | accounting for some mild confounding effects.
11 |
12 | The TMLE algorithm is doubly robust, meaning that as long as one of the two models contained
13 | within the tool (the ``g`` or ``q`` models) performs well, the overall tool will correctly
14 | estimate the causal curve.
15 |
16 | Compared with the package's GPS methods, the TMLE method incorporates more powerful machine learning techniques internally (gradient boosting)
17 | and produces significantly smaller confidence intervals. However, it is less computationally efficient
18 | and will take longer to run. In addition, **the treatment values provided should
19 | be roughly normally-distributed**, otherwise you may encounter internal math errors.
20 |
21 | Let's first generate some simple toy data:
22 |
23 |
24 | >>> import matplotlib.pyplot as plt
25 | import numpy as np
26 | import pandas as pd
27 | from causal_curve import TMLE_Regressor
28 | np.random.seed(200)
29 |
30 | >>> def generate_data(t, A, sigma, omega, noise=0, n_outliers=0, random_state=0):
31 |         y = A * np.exp(-sigma * t) * np.sin(omega * t)
32 |         rnd = np.random.RandomState(random_state)
33 |         error = noise * rnd.randn(t.size)
34 |         outliers = rnd.randint(0, t.size, n_outliers)
35 |         error[outliers] *= 35
36 |         return y + error
37 |
38 | >>> treatment = np.linspace(0, 10, 1000)
39 | outcome = generate_data(
40 |     t = treatment,
41 |     A = 2,
42 |     sigma = 0.1,
43 |     omega = (0.1 * 2 * np.pi),
44 |     noise = 0.1,
45 |     n_outliers = 5
46 | )
47 | x1 = np.random.uniform(0,10,1000)
48 | x2 = (np.random.uniform(0,10,1000) * 3)
49 |
50 | >>> df = pd.DataFrame(
51 |     {
52 |         'x1': x1,
53 |         'x2': x2,
54 |         'treatment': treatment,
55 |         'outcome': outcome
56 |     }
57 | )
58 |
59 |
60 | All we do now is employ the TMLE_Regressor class, with mostly default settings:
61 |
62 |
63 | >>> from causal_curve import TMLE_Regressor
64 | tmle = TMLE_Regressor(
65 |     random_seed=111,
66 |     bandwidth=10
67 | )
68 |
69 | >>> tmle.fit(T = df['treatment'], X = df[['x1', 'x2']], y = df['outcome'])
70 | tmle_results = tmle.calculate_CDRC(0.95)
71 |
72 | The resulting dataframe contains all of the data you need to generate the following plot:
73 |
74 | .. image:: ../imgs/tmle_plot.png
75 |
76 | To generate user-specified points along the curve, use the ``point_estimate`` and ``point_estimate_interval`` methods:
77 |
78 | >>> tmle.point_estimate(np.array([5.5]))
79 | tmle.point_estimate_interval(np.array([5.5]))
80 |
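81 | To reproduce a plot like the one above from the ``calculate_CDRC`` output, here is a
82 | minimal sketch. It assumes only the column names returned by ``calculate_CDRC``
83 | (``Treatment``, ``Causal_Dose_Response``, ``Lower_CI``, and ``Upper_CI``) and the
84 | ``matplotlib`` import from the start of this example:
85 |
86 | >>> plt.plot(tmle_results['Treatment'], tmle_results['Causal_Dose_Response'])
87 | plt.fill_between(
88 |     tmle_results['Treatment'],
89 |     tmle_results['Lower_CI'],
90 |     tmle_results['Upper_CI'],
91 |     alpha=0.25
92 | )
93 | plt.xlabel('Treatment')
94 | plt.ylabel('Outcome')
95 | plt.show()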
96 |
97 | References
98 | ----------
99 |
100 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation
101 | of continuous treatment effects. Journal of the Royal Statistical Society, Series B. 79(4), 2017, pp.1229-1245.
102 |
103 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: U.C. Berkeley Division of
104 | Biostatistics Working Paper Series, 2006.
105 |
106 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted
107 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010.
108 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | Welcome to causal-curve's documentation!
2 | ========================================
3 |
4 |
5 | .. toctree::
6 |    :maxdepth: 2
7 |    :hidden:
8 |    :caption: Getting Started
9 |
10 |    intro
11 |    install
12 |    contribute
13 |
14 |
15 | .. toctree::
16 |    :maxdepth: 1
17 |    :hidden:
18 |    :caption: End-to-end demonstration
19 |
20 |    full_example
21 |
22 |
23 | .. toctree::
24 |    :maxdepth: 1
25 |    :hidden:
26 |    :caption: Package Tools
27 |
28 |    GPS_Regressor
29 |    GPS_Classifier
30 |    TMLE_Regressor
31 |    Mediation_example
32 |
33 | .. toctree::
34 |    :maxdepth: 1
35 |    :hidden:
36 |    :caption: Module details
37 |
38 |    modules
39 |
40 |
41 | .. toctree::
42 |    :maxdepth: 1
43 |    :hidden:
44 |    :caption: Additional Information
45 |
46 |    changelog
47 |    citation
48 |
49 |
50 | .. toctree::
51 |    :maxdepth: 2
52 |    :caption: Contents:
53 |
54 |
55 | **causal-curve** is a Python package with tools to perform causal inference
56 | when the treatment of interest is continuous.
57 |
58 | .. image:: ../imgs/welcome_plot.png
59 |
60 |
61 | Summary
62 | -------
63 |
64 | (**Version 1.0.0 released in Jan 2021!**)
65 |
66 | There are many available methods to perform causal inference when your intervention of interest is binary,
67 | but few methods exist to handle continuous treatments. This is unfortunate because there are many
68 | scenarios (in industry and research) where these methods would be useful. This library attempts to
69 | address this gap, providing tools to estimate causal curves (AKA causal dose-response curves).
70 | Both continuous and binary outcomes can be modeled with this package.
71 |
72 |
73 | Quick example (of the ``GPS_Regressor`` tool)
74 | ---------------------------------------------
75 |
76 | **causal-curve** uses a sklearn-like API that should feel familiar to python machine learning users.
77 | This includes ``_Regressor`` and ``_Classifier`` models, and ``fit()`` methods.
78 |
79 | The following example estimates the causal dose-response curve (CDRC) by calculating
80 | generalized propensity scores.
81 |
82 | >>> from causal_curve import GPS_Regressor
83 | >>> import numpy as np
84 |
85 | >>> gps = GPS_Regressor(treatment_grid_num = 200, random_seed = 512)
86 |
87 | >>> df # a pandas dataframe with your data
88 | X_1 X_2 Treatment Outcome
89 | 0 0.596685 0.162688 0.000039 -0.270533
90 | 1 1.014187 0.916101 0.000197 -0.266979
91 | 2 0.932859 1.328576 0.000223 1.921979
92 | 3 1.140052 0.555203 0.000339 1.461526
93 | 4 1.613471 0.340886 0.000438 2.064511
94 |
95 | >>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome'])
96 | >>> gps_results = gps.calculate_CDRC(ci = 0.95)
97 | >>> gps_point = gps.point_estimate(np.array([0.0003]))
98 | >>> gps_point_interval = gps.point_estimate_interval(np.array([0.0003]), ci = 0.95)
99 |
100 | 1. First we import the ``GPS_Regressor`` class.
101 |
102 | 2. Then we instantiate the class, providing any of the optional parameters.
103 |
104 | 3. Prepare and organize your treatment, covariate, and outcome data into a pandas dataframe.
105 |
106 | 4. Fit the model to your data by calling the ``.fit()`` method.
107 |
108 | 5. Estimate the points of the causal curve (along with 95% confidence interval bounds) with the ``.calculate_CDRC()`` method.
109 |
110 | 6. Generate point estimates along the causal curve with the ``.point_estimate()``, ``.point_estimate_interval()``, and ``.estimate_log_odds()`` methods.
111 |
112 | 7. Explore or plot your results!
113 |
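114 | For that final step, here is a minimal plotting sketch (it assumes ``matplotlib`` is
115 | available, and uses the columns returned by ``calculate_CDRC``: ``Treatment``,
116 | ``Causal_Dose_Response``, ``Lower_CI``, and ``Upper_CI``):
117 |
118 | >>> import matplotlib.pyplot as plt
119 | >>> plt.plot(gps_results['Treatment'], gps_results['Causal_Dose_Response'])
120 | >>> plt.fill_between(gps_results['Treatment'], gps_results['Lower_CI'], gps_results['Upper_CI'], alpha=0.25)
121 | >>> plt.xlabel('Treatment')
122 | >>> plt.ylabel('Outcome')
123 | >>> plt.show()
124 |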
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # causal-curve
2 |
3 | [](https://travis-ci.org/ronikobrosly/causal-curve)
4 | [](https://codecov.io/gh/ronikobrosly/causal-curve)
5 | [](https://zenodo.org/badge/latestdoi/256017107)
6 |
7 | Python tools to perform causal inference when the treatment of interest is continuous.
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 | ## Table of Contents
17 |
18 | - [Overview](#overview)
19 | - [Installation](#installation)
20 | - [Documentation](#documentation)
21 | - [Contributing](#contributing)
22 | - [Citation](#citation)
23 | - [References](#references)
24 |
25 | ## Overview
26 |
27 | (**Version 1.0.0 released in January 2021!**)
28 |
29 | There are many implemented methods to perform causal inference when your intervention of interest is binary,
30 | but few methods exist to handle continuous treatments.
31 |
32 | This is unfortunate because there are many scenarios (in industry and research) where these methods would be useful.
33 | For example, when you would like to:
34 |
35 | * Estimate the causal response to increasing or decreasing the price of a product across a wide range.
36 | * Understand how the number of minutes per week of aerobic exercise causes positive health outcomes.
37 | * Estimate how decreasing order wait time will impact customer satisfaction, after controlling for confounding effects.
38 | * Estimate how changing neighborhood income inequality (Gini index) could be causally related to neighborhood crime rate.
39 |
40 | This library attempts to address this gap, providing tools to estimate causal curves (AKA causal dose-response curves).
41 | Both continuous and binary outcomes can be modeled against a continuous treatment.
42 |
43 | ## Installation
44 |
45 | Available via PyPI:
46 |
47 | `pip install causal-curve`
48 |
49 | You can also get the latest version of causal-curve by cloning the repository:
50 |
51 | ```
52 | git clone -b main https://github.com/ronikobrosly/causal-curve.git
53 | cd causal-curve
54 | pip install .
55 | ```
56 |
57 | ## Documentation
58 |
59 | [Documentation, tutorials, and examples are available at readthedocs.org](https://causal-curve.readthedocs.io/en/latest/)
60 |
61 |
62 | ## Contributing
63 |
64 | Your help is absolutely welcome! Please do reach out or create a feature branch!
65 |
66 | ## Citation
67 |
68 | Kobrosly, R. W., (2020). causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves. Journal of Open Source Software, 5(52), 2523, [https://doi.org/10.21105/joss.02523](https://doi.org/10.21105/joss.02523)
69 |
70 | ## References
71 |
72 | Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative
73 | Estimators for Parametric Dose-Response function with Applications. PhD thesis, 2016.
74 |
75 | Hirano K and Imbens GW. The propensity score with continuous treatments.
76 | In: Gelman A and Meng XL (eds) Applied bayesian modeling and causal inference
77 | from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84.
78 |
79 | Imai K, Keele L, Tingley D. A General Approach to Causal Mediation Analysis. Psychological
80 | Methods. 15(4), 2010, pp.309–334.
81 |
82 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation
83 | of continuous treatment effects. Journal of the Royal Statistical Society, Series B. 79(4), 2017, pp.1229-1245.
84 |
85 | Moodie E and Stephens DA. Estimation of dose–response functions for
86 | longitudinal data using the generalised propensity score. In: Statistical Methods in
87 | Medical Research 21(2), 2010, pp.149–166.
88 |
89 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted
90 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010.
91 |
92 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: U.C. Berkeley Division of
93 | Biostatistics Working Paper Series, 2006.
94 |
--------------------------------------------------------------------------------
/tests/unit/test_mediation.py:
--------------------------------------------------------------------------------
1 | """ Unit tests of the Mediation.py module """
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from causal_curve import Mediation
7 |
8 |
9 | def test_mediation_fit(mediation_fixture):
10 | """
11 | Tests the fit method of Mediation tool
12 | """
13 |
14 | med = Mediation(
15 | treatment_grid_num=10,
16 | lower_grid_constraint=0.01,
17 | upper_grid_constraint=0.99,
18 | bootstrap_draws=100,
19 | bootstrap_replicates=50,
20 | spline_order=3,
21 | n_splines=5,
22 | lambda_=0.5,
23 | max_iter=20,
24 | random_seed=None,
25 | verbose=True,
26 | )
27 | med.fit(
28 | T=mediation_fixture["treatment"],
29 | M=mediation_fixture["mediator"],
30 | y=mediation_fixture["outcome"],
31 | )
32 |
33 | assert len(med.final_bootstrap_results) == 9
34 |
35 |
36 | @pytest.mark.parametrize(
37 | (
38 | "treatment_grid_num",
39 | "lower_grid_constraint",
40 | "upper_grid_constraint",
41 | "bootstrap_draws",
42 | "bootstrap_replicates",
43 | "spline_order",
44 | "n_splines",
45 | "lambda_",
46 | "max_iter",
47 | "random_seed",
48 | "verbose",
49 | ),
50 | [
51 | (10.5, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
52 | (0, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
53 | (1e6, 0.01, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
54 | (10, "hehe", 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
55 | (10, -1.0, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
56 | (10, 1.5, 0.99, 10, 10, 3, 5, 0.5, 100, None, True),
57 | (10, 0.1, "hehe", 10, 10, 3, 5, 0.5, 100, None, True),
58 | (10, 0.1, -1.0, 10, 10, 3, 5, 0.5, 100, None, True),
59 | (10, 0.1, 1.5, 10, 10, 3, 5, 0.5, 100, None, True),
60 | (10, 0.1, 0.9, 10.5, 10, 3, 5, 0.5, 100, None, True),
61 | (10, 0.1, 0.9, -2, 10, 3, 5, 0.5, 100, None, True),
62 | (10, 0.1, 0.9, 10000000, 10, 3, 5, 0.5, 100, None, True),
63 | (10, 0.1, 0.9, 100, "10", 3, 5, 0.5, 100, None, True),
64 | (10, 0.1, 0.9, 100, -1, 3, 5, 0.5, 100, None, True),
65 | (10, 0.1, 0.9, 100, 10000000, 3, 5, 0.5, 100, None, True),
66 | (10, 0.1, 0.9, 100, 200, "3", 5, 0.5, 100, None, True),
67 | (10, 0.1, 0.9, 100, 200, 1, 5, 0.5, 100, None, True),
68 | (10, 0.1, 0.9, 100, 200, 1e6, 5, 0.5, 100, None, True),
69 | (10, 0.1, 0.9, 100, 200, 5, "10", 0.5, 100, None, True),
70 | (10, 0.1, 0.9, 100, 200, 5, 1, 0.5, 100, None, True),
71 | (10, 0.1, 0.9, 100, 200, 5, 10, "0.5", 100, None, True),
72 | (10, 0.1, 0.9, 100, 200, 5, 10, -0.5, 100, None, True),
73 | (10, 0.1, 0.9, 100, 200, 5, 10, 1e7, 100, None, True),
74 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, "100", None, True),
75 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 1, None, True),
76 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 1e8, None, True),
77 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, "None", True),
78 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, -5, True),
79 | (10, 0.1, 0.9, 100, 200, 5, 10, 1, 100, 123, "True"),
80 | (10, 30.0, 10.0, 100, 200, 5, 10, 1, 100, None, True),
81 | ],
82 | )
83 | def test_bad_mediation_instantiation(
84 | treatment_grid_num,
85 | lower_grid_constraint,
86 | upper_grid_constraint,
87 | bootstrap_draws,
88 | bootstrap_replicates,
89 | spline_order,
90 | n_splines,
91 | lambda_,
92 | max_iter,
93 | random_seed,
94 | verbose,
95 | ):
96 | with pytest.raises(Exception) as bad:
97 | Mediation(
98 | treatment_grid_num=treatment_grid_num,
99 | lower_grid_constraint=lower_grid_constraint,
100 | upper_grid_constraint=upper_grid_constraint,
101 | bootstrap_draws=bootstrap_draws,
102 | bootstrap_replicates=bootstrap_replicates,
103 | spline_order=spline_order,
104 | n_splines=n_splines,
105 | lambda_=lambda_,
106 | max_iter=max_iter,
107 | random_seed=random_seed,
108 | verbose=verbose,
109 | )
110 |
--------------------------------------------------------------------------------
/tests/unit/test_tmle_regressor.py:
--------------------------------------------------------------------------------
1 | """ Unit tests of the tmle.py module """
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from causal_curve import TMLE_Regressor
7 |
8 |
9 | def test_point_estimate_method_good(TMLE_fitted_model_continuous_fixture):
10 | """
11 | Tests the GPS `point_estimate` method using appropriate data (with a continuous outcome)
12 | """
13 |
14 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate(
15 | np.array([50])
16 | )
17 | assert isinstance(observed_result[0][0], float)
18 | assert len(observed_result[0]) == 1
19 |
20 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate(
21 | np.array([40, 50, 60])
22 | )
23 | assert isinstance(observed_result[0][0], float)
24 | assert len(observed_result[0]) == 3
25 |
26 |
27 | def test_point_estimate_interval_method_good(TMLE_fitted_model_continuous_fixture):
28 | """
29 | Tests the GPS `point_estimate_interval` method using appropriate data (with a continuous outcome)
30 | """
31 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate_interval(
32 | np.array([50])
33 | )
34 | assert isinstance(observed_result[0][0], float)
35 | assert observed_result.shape == (1, 2)
36 |
37 | observed_result = TMLE_fitted_model_continuous_fixture.point_estimate_interval(
38 | np.array([40, 50, 60])
39 | )
40 | assert isinstance(observed_result[0][0], float)
41 | assert observed_result.shape == (3, 2)
42 |
43 |
44 | @pytest.mark.parametrize(
45 | (
46 | "treatment_grid_num",
47 | "lower_grid_constraint",
48 | "upper_grid_constraint",
49 | "n_estimators",
50 | "learning_rate",
51 | "max_depth",
52 | "bandwidth",
53 | "random_seed",
54 | "verbose",
55 | ),
56 | [
57 | # treatment_grid_num
58 | (100.0, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False),
59 | ("100.0", 0.01, 0.99, 200, 0.01, 3, 0.5, None, False),
60 | (2, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False),
61 | (500000, 0.01, 0.99, 200, 0.01, 3, 0.5, None, False),
62 | # lower_grid_constraint
63 | (100, {0.01: "a"}, 0.99, 200, 0.01, 3, 0.5, None, False),
64 | (100, -0.01, 0.99, 200, 0.01, 3, 0.5, None, False),
65 | (100, 6.05, 0.99, 200, 0.01, 3, 0.5, None, False),
66 | # upper_grid_constraint
67 | (100, 0.01, [1, 2, 3], 200, 0.01, 3, 0.5, None, False),
68 | (100, 0.01, -0.05, 200, 0.01, 3, 0.5, None, False),
69 | (100, 0.01, 5.99, 200, 0.01, 3, 0.5, None, False),
70 | (100, 0.9, 0.2, 200, 0.01, 3, 0.5, None, False),
71 | # n_estimators
72 | (100, 0.01, 0.99, "3.0", 0.01, 3, 0.5, None, False),
73 | (100, 0.01, 0.99, -5, 0.01, 3, 0.5, None, False),
74 | (100, 0.01, 0.99, 10000000, 0.01, 3, 0.5, None, False),
75 | # learning_rate
76 | (100, 0.01, 0.99, 200, ["a", "b"], 3, 0.5, None, False),
77 | (100, 0.01, 0.99, 200, 5000000, 3, 0.5, None, False),
78 | # max_depth
79 | (100, 0.01, 0.99, 200, 0.01, "a", 0.5, None, False),
80 | (100, 0.01, 0.99, 200, 0.01, -6, 0.5, None, False),
81 | # bandwidth
82 | (100, 0.01, 0.99, 200, 0.01, 3, "b", None, False),
83 | (100, 0.01, 0.99, 200, 0.01, 3, -10, None, False),
84 | # random seed
85 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, "b", False),
86 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, -10, False),
87 | # verbose
88 | (100, 0.01, 0.99, 200, 0.01, 3, 0.5, None, "Verbose"),
89 | ],
90 | )
91 | def test_bad_tmle_instantiation(
92 | treatment_grid_num,
93 | lower_grid_constraint,
94 | upper_grid_constraint,
95 | n_estimators,
96 | learning_rate,
97 | max_depth,
98 | bandwidth,
99 | random_seed,
100 | verbose,
101 | ):
102 | with pytest.raises(Exception) as bad:
103 | TMLE_Regressor(
104 | treatment_grid_num=treatment_grid_num,
105 | lower_grid_constraint=lower_grid_constraint,
106 | upper_grid_constraint=upper_grid_constraint,
107 | n_estimators=n_estimators,
108 | learning_rate=learning_rate,
109 | max_depth=max_depth,
110 | bandwidth=bandwidth,
111 | random_seed=random_seed,
112 | verbose=verbose,
113 | )
114 |
--------------------------------------------------------------------------------
/docs/GPS_Regressor.rst:
--------------------------------------------------------------------------------
1 | .. _GPS_Regressor:
2 |
3 | ================================================================
4 | GPS_Regressor Tool (continuous treatments, continuous outcomes)
5 | ================================================================
6 |
7 |
8 | In this example, we use this package's GPS_Regressor tool to estimate the marginal causal curve of some
9 | continuous treatment on a continuous outcome, accounting for some mild confounding effects.
10 | To put this differently, the result will be an estimate of the average
11 | of each individual's dose-response to the treatment. To do this, we calculate
12 | generalized propensity scores (GPS) and use them to adjust the treatment's prediction of the outcome for confounding.
13 |
14 | Compared with the package's TMLE method, the GPS methods are more computationally efficient
15 | and better suited for large datasets, but they produce wider confidence intervals.
16 |
17 | In this example we use simulated data originally developed by Hirano and Imbens but adapted by others
18 | (see references). The advantage of this simulated data is that it allows us
19 | to compare the estimate we produce against the true, analytically-derived causal curve.
20 |
21 | Let :math:`t_i` be the treatment for the i-th unit, let :math:`x_1` and :math:`x_2` be the
22 | confounding covariates, and let :math:`y_i` be the outcome measure. We assume that the covariates
23 | and treatment are exponentially-distributed, and the treatment variable is associated with the
24 | covariates in the following way:
25 |
26 | >>> import numpy as np
27 | >>> import pandas as pd
28 | >>> from scipy.stats import expon
29 |
30 | >>> np.random.seed(333)
31 | >>> n = 5000
32 | >>> x_1 = expon.rvs(size=n, scale = 1)
33 | >>> x_2 = expon.rvs(size=n, scale = 1)
34 | >>> treatment = expon.rvs(size=n, scale = (1/(x_1 + x_2)))
35 |
36 | The GPS is given by
37 |
38 | .. math::
39 |
40 |     f(t, x_1, x_2) = (x_1 + x_2) e^{-(x_1 + x_2) t}
41 |
42 | If we generate the outcome variable by summing the treatment and GPS, the true causal
43 | curve is derived analytically to be:
44 |
45 | .. math::
46 |
47 | f(t) = t + \frac{2}{(1 + t)^3}
48 |
49 |
50 | The following code completes the data generation:
51 |
52 | >>> gps = ((x_1 + x_2) * np.exp(-(x_1 + x_2) * treatment))
53 | >>> outcome = treatment + gps + np.random.normal(size = n, scale = 1)
54 |
55 | >>> truth_func = lambda treatment: (treatment + (2/(1 + treatment)**3))
56 | >>> vfunc = np.vectorize(truth_func)
57 | >>> true_outcome = vfunc(treatment)
58 |
59 | >>> df = pd.DataFrame(
60 | >>> {
61 | >>> 'X_1': x_1,
62 | >>> 'X_2': x_2,
63 | >>> 'Treatment': treatment,
64 | >>> 'GPS': gps,
65 | >>> 'Outcome': outcome,
66 | >>> 'True_outcome': true_outcome
67 | >>> }
68 | >>> ).sort_values('Treatment', ascending = True)
69 |
70 | With this dataframe, we can now calculate the GPS to estimate the causal relationship between
71 | treatment and outcome. Let's use the default settings of the GPS_Regressor tool:
72 |
73 | >>> from causal_curve import GPS_Regressor
74 | >>> gps = GPS_Regressor()
75 | >>> gps.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome'])
76 | >>> gps_results = gps.calculate_CDRC(0.95)
77 |
78 | You now have everything needed to produce the following plot with matplotlib. In this example with only mild confounding,
79 | the GPS-based estimate of the true causal curve has approximately
80 | half the error of a simple LOESS estimate using only the treatment and the outcome.
81 |
82 | .. image:: ../imgs/cdrc/CDRC.png
83 |
84 | The GPS_Regressor tool also allows you to estimate a specific set of points along the causal curve.
85 | Use the `point_estimate` and `point_estimate_interval` methods to produce a point estimate
86 | and prediction interval, respectively.
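87 | 
88 | For example, here is a minimal usage sketch (reusing the fitted `gps` object from
89 | above; the treatment values passed in are arbitrary and should fall within the
90 | range of the observed treatment):
91 | 
92 | >>> gps.point_estimate(np.array([2.0]))
93 | >>> gps.point_estimate_interval(np.array([2.0]), ci = 0.95)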
87 |
88 | References
89 | ----------
90 |
91 | Galagate, D. Causal Inference with a Continuous Treatment and Outcome: Alternative
92 | Estimators for Parametric Dose-Response Functions with Applications. PhD thesis, 2016.
93 |
94 | Moodie E and Stephens DA. Estimation of dose–response functions for
95 | longitudinal data using the generalised propensity score. In: Statistical Methods in
96 | Medical Research 21(2), 2010, pp.149–166.
97 |
98 | Hirano K and Imbens GW. The propensity score with continuous treatments.
99 | In: Gelman A and Meng XL (eds) Applied Bayesian modeling and causal inference
100 | from incomplete-data perspectives. Oxford, UK: Wiley, 2004, pp.73–84.
101 |
--------------------------------------------------------------------------------
/docs/Mediation_example.rst:
--------------------------------------------------------------------------------
1 | .. _Mediation_example:
2 |
3 | ============================================================
4 | Mediation Tool (continuous treatment, mediator, and outcome)
5 | ============================================================
6 |
7 |
8 | In trying to explore the causal relationships between various elements, oftentimes you'll use
9 | your domain knowledge to sketch out your initial ideas about the causal connections.
10 | See the following causal DAG of the expected relationships between smoking, diabetes, obesity, age,
11 | and mortality (Havumaki et al.):
12 |
13 | .. image:: ../imgs/mediation/diabetes_DAG.png
14 |
15 | At some point though, it's helpful to validate these ideas with empirical tests.
16 | This tool provides a test that can estimate the amount of mediation that occurs between
17 | a treatment, a purported mediator, and an outcome. In keeping with the causal curve theme,
18 | this tool uses a test developed by Imai et al. that handles a continuous treatment and
19 | mediator.
20 |
21 | In this example we use the following simulated data, and assume that the `mediator`
22 | variable has been identified as a mediator through expert judgement.
23 |
24 | >>> import numpy as np
25 | >>> import pandas as pd
26 |
27 | >>> np.random.seed(132)
28 | >>> n_obs = 500
29 |
30 | >>> treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
31 | >>> mediator = np.random.normal(loc=70.0 + treatment, scale=8.0, size=n_obs)
32 | >>> outcome = np.random.normal(loc=(treatment + mediator - 50), scale=10.0, size=n_obs)
33 |
34 | >>> df = pd.DataFrame(
35 | >>> {
36 | >>> "treatment": treatment,
37 | >>> "mediator": mediator,
38 | >>> "outcome": outcome
39 | >>> }
40 | >>> )
41 |
42 |
43 | Now we can instantiate the Mediation class:
44 |
45 | >>> from causal_curve import Mediation
46 | >>> med = Mediation(
47 | >>> bootstrap_draws=100,
48 | >>> bootstrap_replicates=100,
49 | >>> spline_order=3,
50 | >>> n_splines=5,
51 | >>> verbose=True,
52 | >>> )
53 |
54 |
55 | We then fit the data to the `med` object:
56 |
57 | >>> med.fit(
58 | >>> T=df["treatment"],
59 | >>> M=df["mediator"],
60 | >>> y=df["outcome"],
61 | >>> )
62 |
63 | With the internal models of the mediation test fit with data, we can now run the
64 | `calculate_mediation` method to produce the final report:
65 |
66 | >>> med.calculate_mediation(ci = 0.95)
67 | >>>
68 | >>> ----------------------------------
69 | >>> Mean indirect effect proportion: 0.5238 (0.5141 - 0.5344)
70 | >>>
71 | >>> Treatment_Value Proportion_Direct_Effect Proportion_Indirect_Effect
72 | >>> 35.1874 0.4743 0.5257
73 | >>> 41.6870 0.4638 0.5362
74 | >>> 44.6997 0.4611 0.5389
75 | >>> 47.5672 0.4745 0.5255
76 | >>> 50.1900 0.4701 0.5299
77 | >>> 52.7526 0.4775 0.5225
78 | >>> 56.0204 0.4727 0.5273
79 | >>> 60.5174 0.4940 0.5060
80 | >>> 66.7243 0.4982 0.5018
81 |
82 | The final analysis tells us that overall, the mediator is estimated to account for
83 | around 52% (+/- 1%) of the effect of the treatment on the outcome. This indicates that
84 | moderate mediation is occurring here. The remaining 48% occurs through a direct effect of the
85 | treatment on the outcome.
86 |
87 | So long as we are confident that the mediator doesn't play another role in the causal graph
88 | (it isn't a confounder of the treatment and outcome association), this supports the idea that
89 | the mediator is in fact a mediator.
90 |
91 | The report also shows how this mediation effect varies as a function of the continuous treatment.
92 | In this case, it looks like the effect is relatively flat (as expected). With a little processing
93 | and some basic interpolation, we can plot this mediation effect:
94 |
95 | .. image:: ../imgs/mediation/mediation_effect.png
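96 | 
97 | As a rough illustration, the interpolation step behind such a plot might look like
98 | the following sketch (the treatment values and indirect-effect proportions are
99 | copied from the report above, and matplotlib is assumed to be installed):
100 | 
101 | >>> import matplotlib.pyplot as plt
102 | >>> from scipy.interpolate import interp1d
103 | 
104 | >>> t_vals = [35.1874, 41.6870, 44.6997, 47.5672, 50.1900,
105 | >>>           52.7526, 56.0204, 60.5174, 66.7243]
106 | >>> indirect = [0.5257, 0.5362, 0.5389, 0.5255, 0.5299,
107 | >>>             0.5225, 0.5273, 0.5060, 0.5018]
108 | 
109 | >>> smooth = interp1d(t_vals, indirect, kind="cubic")
110 | >>> grid = np.linspace(min(t_vals), max(t_vals), 200)
111 | >>> plt.plot(grid, smooth(grid))
112 | >>> plt.xlabel("Treatment value")
113 | >>> plt.ylabel("Proportion of indirect (mediated) effect")
114 | >>> plt.show()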
96 |
97 |
98 |
99 | References
100 | ----------
101 |
102 | Imai K., Keele L., Tingley D. A General Approach to Causal Mediation Analysis. Psychological
103 | Methods. 15(4), 2010, pp.309–334.
104 |
105 | Havumaki J., Eisenberg M.C. Mathematical modeling of directed acyclic graphs to explore
106 | competing causal mechanisms underlying epidemiological study data. medRxiv preprint.
107 | doi: https://doi.org/10.1101/19007922. Accessed June 23, 2020.
108 |
--------------------------------------------------------------------------------
/causal_curve/gps_classifier.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the Generalized Propensity Score (GPS) classifier model class
3 | """
4 | from pprint import pprint
5 |
6 | import numpy as np
7 | from scipy.special import logit
8 |
9 | from causal_curve.gps_core import GPS_Core
10 |
11 |
12 | class GPS_Classifier(GPS_Core):
13 | """
14 |     A GPS tool that handles binary outcomes. Inherits from the GPS_Core
15 |     base class. See that base class's code and docstring for more details.
16 |
17 |
18 | Methods
19 | ----------
20 |
21 | estimate_log_odds: (self, T)
22 | Calculates the predicted log odds of the highest integer class. Can
23 | only be used when the outcome is binary.
24 |
25 | """
26 |
27 | def __init__(
28 | self,
29 | gps_family=None,
30 | treatment_grid_num=100,
31 | lower_grid_constraint=0.01,
32 | upper_grid_constraint=0.99,
33 | spline_order=3,
34 | n_splines=30,
35 | lambda_=0.5,
36 | max_iter=100,
37 | random_seed=None,
38 | verbose=False,
39 | ):
40 |
41 | self.gps_family = gps_family
42 | self.treatment_grid_num = treatment_grid_num
43 | self.lower_grid_constraint = lower_grid_constraint
44 | self.upper_grid_constraint = upper_grid_constraint
45 | self.spline_order = spline_order
46 | self.n_splines = n_splines
47 | self.lambda_ = lambda_
48 | self.max_iter = max_iter
49 | self.random_seed = random_seed
50 | self.verbose = verbose
51 |
52 | # Validate the params
53 | self._validate_init_params()
54 | self.rand_seed_wrapper()
55 |
56 | self.if_verbose_print("Using the following params for GPS model:")
57 | if self.verbose:
58 | pprint(self.get_params(), indent=4)
59 |
60 | def _cdrc_predictions_binary(self, ci):
61 | """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
62 | we're making predictions using the original treatment and gps_at_grid.
63 | To be used when the outcome of interest is binary.
64 | """
65 |         # To keep track of cdrc predictions, we create an empty 3d array of shape
66 | # (n_samples, treatment_grid_num, 2). The last dimension is of length 2 because
67 | # we are going to keep track of the point estimate (log-odds) of the prediction, as well as
68 | # the standard error of the prediction interval (again, this is for the log odds)
69 | cdrc_preds = np.zeros((len(self.T), self.treatment_grid_num, 2), dtype=float)
70 |
71 | # Loop through each of the grid values, predict point estimate and get prediction interval
72 | for i in range(0, self.treatment_grid_num):
73 |
74 | temp_T = np.repeat(self.grid_values[i], repeats=len(self.T))
75 | temp_gps = self.gps_at_grid[:, i]
76 |
77 | temp_cdrc_preds = logit(
78 | self.gam_results.predict_proba(np.column_stack((temp_T, temp_gps)))
79 | )
80 |
81 | temp_cdrc_interval = logit(
82 | self.gam_results.confidence_intervals(
83 | np.column_stack((temp_T, temp_gps)), width=ci
84 | )
85 | )
86 |
87 | standard_error = (
88 | temp_cdrc_interval[:, 1] - temp_cdrc_preds
89 | ) / self.calculate_z_score(ci)
90 |
91 | cdrc_preds[:, i, 0] = temp_cdrc_preds
92 | cdrc_preds[:, i, 1] = standard_error
93 |
94 | return np.round(cdrc_preds, 3)
95 |
96 | def estimate_log_odds(self, T):
97 | """Calculates the estimated log odds of the highest integer class. Can
98 |         only be used when the outcome is binary. Can be estimated for a single
99 |         data point or run in batch for many observations. Extrapolation will produce
100 | untrustworthy results; the provided treatment should be within
101 | the range of the training data.
102 |
103 | Parameters
104 | ----------
105 | T: Numpy array, shape (n_samples,)
106 | A continuous treatment variable.
107 |
108 | Returns
109 | ----------
110 | array: Numpy array
111 | Contains a set of log odds
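112 | 
113 |         Examples
114 |         ----------
115 |         >>> # A hypothetical sketch: `gps_clf` is assumed to be a GPS_Classifier
116 |         >>> # that has already been fit on data with a binary outcome.
117 |         >>> log_odds = gps_clf.estimate_log_odds(np.array([10.0, 20.0, 30.0]))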
112 | """
113 | if self.outcome_type != "binary":
114 | raise TypeError("Your outcome must be binary to use this function!")
115 |
116 | return np.apply_along_axis(self._create_log_odds, 0, T.reshape(1, -1))
117 |
118 | def _create_log_odds(self, T):
119 |         """Takes a single treatment value and produces the log odds of the higher
120 | integer class, in the case of a binary outcome.
121 | """
122 | return logit(
123 | self.gam_results.predict_proba(
124 | np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1)
125 | )
126 | )
127 |
--------------------------------------------------------------------------------
/docs/changelog.rst:
--------------------------------------------------------------------------------
1 | .. _changelog:
2 |
3 | ==========
4 | Change Log
5 | ==========
6 |
7 | Version 1.0.6
8 | -------------
9 | - Latest version of python black can now run. Linted tmle_core.py.
10 |
11 | Version 1.0.5
12 | -------------
13 | - Removed `master` branch, replaced with `main`
14 | - Removed all mention of `master` branch from documentation
15 |
16 | Version 1.0.4
17 | -------------
18 | - Fixed TMLE plot and code errors in documentation
19 |
20 | Version 1.0.3
21 | -------------
22 | - Fixed bug with `random_seed` functionality in all tools
23 |
24 | Version 1.0.2
25 | -------------
26 | - Updated end-to-end example notebook in `/examples` folder
27 | - Fixed various class docstrings that still referenced the old v0.5.2 API
28 | - Fixed bug where custom class input parameters weren't being used
29 |
30 |
31 | Version 1.0.1
32 | -------------
33 | - Added to TMLE overview in the docs (including plot)
34 |
35 |
36 | Version 1.0.0: **Major Update**
37 | -------------------------------
38 | - Overhaul of the TMLE tool to make it dramatically more accurate and user-friendly.
39 | - Improved TMLE example documentation
40 | - Much like with `scikit-learn`, there are now separate model classes used for predicting binary or continuous outcomes
41 | - Updating documentation to reflect API changes
42 | - Added more tests
43 | - Linted with `pylint` (added `.pylintrc` file)
44 |
45 |
46 | Version 0.5.2
47 | -------------
48 | - Fixed bug that prevented `causal-curve` modules from being shown in Sphinx documentation
49 | - Augmented tests to capture more error states and improve code coverage
50 |
51 |
52 | Version 0.5.1
53 | -------------
54 | - Removed working test file
55 |
56 |
57 | Version 0.5.0
58 | -------------
59 | - Added new `predict`, `predict_interval`, and `predict_log_odds` methods to GPS tool
60 | - Slight updates to doc to reflect new features
61 |
62 |
63 | Version 0.4.1
64 | -------------
65 | - When using the GPS tool with a treatment containing negative values, only the normal GLM family can be selected
66 | - Added 'sphinx_rtd_theme' to dependency list in `.travis.yml` and `install.rst`
67 | - core.py base class now has __version__ attribute
68 |
69 |
70 | Version 0.4.0
71 | -------------
72 | - Added support for binary outcomes in GPS tool
73 | - Small changes to repo README
74 |
75 |
76 | Version 0.3.8
77 | -------------
78 | - Added citation (yay!)
79 |
80 |
81 | Version 0.3.7
82 | -------------
83 | - Bumped version for PyPi
84 |
85 |
86 | Version 0.3.6
87 | -------------
88 | - Fixed bug in Mediation.calculate_mediation that would clip treatments < 0 or > 1
89 | - Fixed incorrect horizontal axis labels in lead example
90 | - Fixed typos in documentation
91 | - Added links to resources so users could learn more about causal inference theory
92 |
93 |
94 | Version 0.3.5
95 | -------------
96 | - Re-organized documentation
97 | - Added `Introduction` section to explain purpose and need for the package
98 |
99 |
100 | Version 0.3.4
101 | -------------
102 | - Removed XGBoost as dependency.
103 | - Now using sklearn's gradient boosting implementation.
104 |
105 |
106 | Version 0.3.3
107 | -------------
108 | - Misc edits to paper and bibliography
109 |
110 |
111 | Version 0.3.2
112 | -------------
113 | - Fixed random seed issue with Mediation tool
114 | - Fixed Mediation bootstrap issue. Confidence interval bounded [0,1]
115 | - Fixed issue with all classes not accepting non-sequential indices in pandas DataFrames/Series
116 | - Class init checks for all classes now print cleaner errors on bad input
117 |
118 |
119 | Version 0.3.1
120 | -------------
121 | - Small fixes to end-to-end example documentation
122 | - Enlarged image in paper
123 |
124 |
125 | Version 0.3.0
126 | -------------
127 | - Added full, end-to-end example of package usage to documentation
128 | - Cleaned up documentation
129 | - Added example folder with end-to-end notebook
130 | - Added manuscript to paper folder
131 |
132 |
133 | Version 0.2.4
134 | -------------
135 | - Strengthened unit tests
136 |
137 |
138 | Version 0.2.3
139 | -------------
140 | - codecov integration
141 |
142 |
143 | Version 0.2.2
144 | -------------
145 | - Travis CI integration
146 |
147 |
148 | Version 0.2.1
149 | -------------
150 | - Fixed Mediation tool error / removed `tqdm` from requirements
151 | - Misc documentation cleanup / revisions
152 |
153 |
154 | Version 0.2.0
155 | -------------
156 | - Added new Mediation class
157 | - Updated documentation to reflect this
158 | - Added unit and integration tests for Mediation methods
159 |
160 |
161 | Version 0.1.3
162 | -------------
163 | - Simplifying unit and integration tests.
164 |
165 |
166 | Version 0.1.2
167 | -------------
168 |
169 | - Added unit and integration tests
170 |
171 |
172 | Version 0.1.1
173 | -------------
174 |
175 | - setup.py fix
176 |
177 |
178 | Version 0.1.0
179 | -------------
180 |
181 | - Added new TMLE class
182 | - Updated documentation to reflect new TMLE method
183 | - Renamed CDRC method to more appropriate `GPS` method
184 | - Small docstring corrections to GPS method
185 |
186 |
187 | Version 0.0.10
188 | --------------
189 |
190 | - Bug fix in GPS estimation method
191 |
192 |
193 | Version 0.0.9
194 | -------------
195 |
196 | - Project created
197 |
--------------------------------------------------------------------------------
/docs/intro.rst:
--------------------------------------------------------------------------------
1 | .. _intro:
2 |
3 | ============================
4 | Introduction to causal-curve
5 | ============================
6 |
7 | In academia and industry, randomized controlled experiments (or simply experiments or "A/B tests") are considered the gold standard approach for assessing the true, causal impact
8 | of a treatment or intervention. For example:
9 |
10 | * We want to increase the number of times per day new customers log into our business's website. Will it help if we send daily emails out to our customers? We take a group of 2000 new business customers; half are randomly chosen to receive daily emails, while the other half receives one email per week. We follow both groups forward in time for a month and compare each group's average number of logins per day.
11 |
12 | However, for ethical or financial reasons, experiments may not always be feasible to carry out.
13 |
14 | * It's not ethical to randomly assign some people to receive a possible carcinogen in pill form while others receive a sugar pill, and then see which group is more likely to develop cancer.
15 | * It's not feasible to increase the household incomes of some New York neighborhoods, while leaving others unchanged to see if changing a neighborhood's income inequality would improve the local crime rate.
16 |
17 | "Causal inference" methods are a set of approaches that attempt to estimate causal effects
18 | from observational rather than experimental data, correcting for the biases that are inherent
19 | to analyzing observational data (e.g. confounding and selection bias) (Hernán and Robins, 2020).
20 |
21 | As long as you have varying observational data on some treatment, your outcome of interest,
22 | and potentially confounding variables across your units of analysis (in addition to meeting the assumptions described below),
23 | you can essentially estimate the results of a proper experiment and make causal claims.
24 |
25 |
26 | Interpreting the causal curve
27 | ------------------------------
28 |
29 | Two of the methods contained within this package produce causal curves for continuous treatments
30 | (see the GPS and TMLE methods). Both continuous and binary outcomes can be modeled
31 | (though only the `GPS_Classifier` tool can handle binary outcomes).
32 |
33 | **Continuous outcome:**
34 |
35 | .. image:: ../imgs/welcome_plot.png
36 |
37 | Using the above causal curve as an example, we see that employing a treatment value between 50 and 60
38 | causally produces the highest outcome values. We also see that
39 | the treatment produces a smaller effect if lower or higher than that range. The confidence
40 | intervals become wider on the parts of the curve where we have fewer data points (near the minimum and
41 | maximum treatment values).
42 |
43 | This curve differs from a simple bivariate plot of the treatment and outcome or even a similar-looking plot
44 | generated through standard multivariable regression modeling in a few important ways:
45 |
46 | * This curve represents the estimated causal effect of a treatment on an outcome, not the association between treatment and outcome.
47 | * This curve represents a population-level effect, and should not be used to infer effects at the individual-level (or whatever the unit of analysis is).
48 | * To generate a similar-looking plot using multivariable regression, you would have to hold covariates constant, and any treatment effect that is inferred occurs within the levels of the covariates specified in the model. The causal curve averages out across all of these strata and gives us the population marginal effect.
49 |
50 | **Binary outcome:**
51 |
52 | .. image:: ../imgs/binary_OR_fig.png
53 |
54 | In the case of a binary outcome, the `GPS_Classifier` tool can be used to estimate a curve of odds ratios. Every
55 | point on the curve is relative to the lowest treatment value. The highest effect (relative to the lowest treatment value)
56 | is around a treatment value of -1.2. At this point in the treatment, the odds of the positive class
57 | occurring are 5.6 times higher compared with the lowest treatment value. This curve is always on
58 | the relative scale; the odds ratio for the lowest point is always 1.0 because it is
59 | relative to itself. Odds ratios are bounded [0, inf) and cannot take on a negative value. Note that
60 | the confidence intervals at any given point on the curve aren't symmetric.
61 |
62 |
63 | A caution about causal inference assumptions
64 | --------------------------------------------
65 |
66 | There is a well-documented set of assumptions one must make to infer causal effects from
67 | observational data. These are covered elsewhere in more detail, but briefly:
68 |
69 | - Causes always occur before effects: The treatment variable needs to have occurred before the outcome.
70 | - SUTVA: The treatment status of a given individual does not affect the potential outcomes of any other individuals.
71 | - Positivity: Any individual has a positive probability of receiving all values of the treatment variable.
72 | - Ignorability: All major confounding variables are included in the data you provide.
73 |
74 | Violations of these assumptions will lead to biased results and incorrect conclusions!
75 |
76 | In addition, any covariates that are included in `causal-curve` models are assumed to only
77 | be **confounding** variables.
78 |
79 | None of the methods provided in causal-curve rely on inference via instrumental variables; they rely
80 | only on the data from the observed treatment, confounders, and the outcome of interest (as in the GPS example above).
81 |
82 |
83 | References
84 | ----------
85 |
86 | Hernán M. and Robins J. Causal Inference: What If. Chapman & Hall, 2020.
87 |
88 | Ahern J, Hubbard A, and Galea S. Estimating the Effects of Potential Public Health Interventions
89 | on Population Disease Burden: A Step-by-Step Illustration of Causal Inference Methods. American Journal of Epidemiology.
90 | 169(9), 2009. pp.1140–1147.
91 |
--------------------------------------------------------------------------------
/causal_curve/tmle_regressor.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the Targeted Maximum Likelihood Estimation (TMLE) regressor model class
3 | """
4 | from pprint import pprint
5 |
6 | import numpy as np
7 |
8 | from causal_curve.tmle_core import TMLE_Core
9 |
10 |
11 | class TMLE_Regressor(TMLE_Core):
12 | """
13 |     A TMLE tool that handles continuous outcomes. Inherits from the TMLE_Core
14 |     base class. See that base class's code and docstring for more details.
15 |
16 | Methods
17 | ----------
18 |
19 | point_estimate: (self, T)
20 | Calculates point estimate within the CDRC given treatment values.
21 | Can only be used when outcome is continuous.
22 | """
23 |
24 | def __init__(
25 | self,
26 | treatment_grid_num=100,
27 | lower_grid_constraint=0.01,
28 | upper_grid_constraint=0.99,
29 | n_estimators=200,
30 | learning_rate=0.01,
31 | max_depth=3,
32 | bandwidth=0.5,
33 | random_seed=None,
34 | verbose=False,
35 | ):
36 |
37 | self.treatment_grid_num = treatment_grid_num
38 | self.lower_grid_constraint = lower_grid_constraint
39 | self.upper_grid_constraint = upper_grid_constraint
40 | self.n_estimators = n_estimators
41 | self.learning_rate = learning_rate
42 | self.max_depth = max_depth
43 | self.bandwidth = bandwidth
44 | self.random_seed = random_seed
45 | self.verbose = verbose
46 |
47 | # Validate the params
48 | self._validate_init_params()
49 | self.rand_seed_wrapper()
50 |
51 | self.if_verbose_print("Using the following params for TMLE model:")
52 | if self.verbose:
53 | pprint(self.get_params(), indent=4)
54 |
55 | def _cdrc_predictions_continuous(self, ci):
56 | """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
57 | we're making predictions using the original treatment against the pseudo-outcome.
58 | To be used when the outcome of interest is continuous.
59 | """
60 |
61 | # To keep track of cdrc predictions, we create an empty 2d array of shape
62 | # (treatment_grid_num, 4). The last dimension is of length 4 because
63 | # we are going to keep track of the treatment, point estimate of the prediction, as well as
64 | # the lower and upper bounds of the prediction interval
65 | cdrc_preds = np.zeros((self.treatment_grid_num, 4), dtype=float)
66 |
67 | # Loop through each of the grid values, get point estimate and get estimate interval
68 | for i in range(0, self.treatment_grid_num):
69 | temp_T = self.grid_values[i]
70 | temp_cdrc_preds = self.final_gam.predict(temp_T)
71 | temp_cdrc_interval = self.final_gam.confidence_intervals(temp_T, width=ci)
72 | temp_cdrc_lower_bound = temp_cdrc_interval[:, 0]
73 | temp_cdrc_upper_bound = temp_cdrc_interval[:, 1]
74 | cdrc_preds[i, 0] = temp_T
75 | cdrc_preds[i, 1] = temp_cdrc_preds
76 | cdrc_preds[i, 2] = temp_cdrc_lower_bound
77 | cdrc_preds[i, 3] = temp_cdrc_upper_bound
78 |
79 | return cdrc_preds
80 |
81 | def point_estimate(self, T):
82 | """Calculates point estimate within the CDRC given treatment values.
83 |         Can only be used when outcome is continuous. Can be estimated for a single
84 |         data point or run in batch for many observations. Extrapolation will produce
85 | untrustworthy results; the provided treatment should be within
86 | the range of the training data.
87 |
88 | Parameters
89 | ----------
90 | T: Numpy array, shape (n_samples,)
91 | A continuous treatment variable.
92 |
93 | Returns
94 | ----------
95 | array: Numpy array
96 | Contains a set of CDRC point estimates
97 |
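98 |         Examples
99 |         ----------
100 |         >>> # A hypothetical usage sketch: `tmle` is assumed to be a TMLE_Regressor
101 |         >>> # that has already been fit via its `fit` method.
102 |         >>> preds = tmle.point_estimate(np.array([50]))
103 | 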
98 | """
99 | return np.apply_along_axis(self._create_point_estimate, 0, T.reshape(1, -1))
100 |
101 | def _create_point_estimate(self, T):
102 | """Takes a single treatment value and produces a single point estimate
103 | in the case of a continuous outcome.
104 | """
105 | return self.final_gam.predict(np.array([T]).reshape(1, -1))
106 |
107 | def point_estimate_interval(self, T, ci=0.95):
108 | """Calculates the prediction confidence interval associated with a point estimate
109 | within the CDRC given treatment values. Can only be used
110 |         when outcome is continuous. Can be estimated for a single data point
111 |         or run in batch for many observations. Extrapolation will produce
112 | untrustworthy results; the provided treatment should be within
113 | the range of the training data.
114 |
115 | Parameters
116 | ----------
117 | T: Numpy array, shape (n_samples,)
118 | A continuous treatment variable.
119 | ci: float (default = 0.95)
120 | The desired confidence interval to produce. Default value is 0.95, corresponding
121 |             to 95% confidence intervals. Bounded (0, 1.0).
122 |
123 | Returns
124 | ----------
125 | array: Numpy array
126 |             Contains a set of CDRC prediction intervals ([lower bound, upper bound])
127 |
128 | """
129 | return np.apply_along_axis(
130 | self._create_point_estimate_interval, 0, T.reshape(1, -1), width=ci
131 | ).T.reshape(-1, 2)
132 |
133 | def _create_point_estimate_interval(self, T, width):
134 | """Takes a single treatment value and produces confidence interval
135 | associated with a point estimate in the case of a continuous outcome.
136 | """
137 | return self.final_gam.prediction_intervals(
138 | np.array([T]).reshape(1, -1), width=width
139 | )
140 |
--------------------------------------------------------------------------------
/tests/unit/test_gps_regressor.py:
--------------------------------------------------------------------------------
1 | """ Unit tests for the GPS_Core class """
2 |
3 | import numpy as np
5 | import pytest
6 |
7 | from causal_curve import GPS_Regressor
8 |
9 |
10 | def test_point_estimate_method_good(GPS_fitted_model_continuous_fixture):
11 | """
12 | Tests the GPS `point_estimate` method using appropriate data (with a continuous outcome)
13 | """
14 |
15 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate(np.array([50]))
16 | assert isinstance(observed_result[0][0], float)
17 | assert len(observed_result[0]) == 1
18 |
19 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate(
20 | np.array([40, 50, 60])
21 | )
22 | assert isinstance(observed_result[0][0], float)
23 | assert len(observed_result[0]) == 3
24 |
25 |
26 | def test_point_estimate_interval_method_good(GPS_fitted_model_continuous_fixture):
27 | """
28 | Tests the GPS `point_estimate_interval` method using appropriate data (with a continuous outcome)
29 | """
30 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
31 | np.array([50])
32 | )
33 | assert isinstance(observed_result[0][0], float)
34 | assert observed_result.shape == (1, 2)
35 |
36 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
37 | np.array([40, 50, 60])
38 | )
39 | assert isinstance(observed_result[0][0], float)
40 | assert observed_result.shape == (3, 2)
41 |
42 |
43 | def test_point_estimate_method_bad(GPS_fitted_model_continuous_fixture):
44 | """
45 |     Tests that the GPS `point_estimate` method raises an error when the outcome type is not continuous
46 | """
47 |
48 | GPS_fitted_model_continuous_fixture.outcome_type = "binary"
49 |
50 | with pytest.raises(Exception) as bad:
51 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate(
52 | np.array([50])
53 | )
54 |
55 |
56 | def test_point_estimate_interval_method_bad(GPS_fitted_model_continuous_fixture):
57 | """
58 |     Tests that the GPS `point_estimate_interval` method raises an error when the outcome type is not continuous
59 | """
60 |
61 | GPS_fitted_model_continuous_fixture.outcome_type = "binary"
62 |
63 | with pytest.raises(Exception) as bad:
64 | observed_result = GPS_fitted_model_continuous_fixture.point_estimate_interval(
65 | np.array([50])
66 | )
67 |
68 |
69 | @pytest.mark.parametrize(
70 | (
71 | "gps_family",
72 | "treatment_grid_num",
73 | "lower_grid_constraint",
74 | "upper_grid_constraint",
75 | "spline_order",
76 | "n_splines",
77 | "lambda_",
78 | "max_iter",
79 | "random_seed",
80 | "verbose",
81 | ),
82 | [
83 | (546, 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
84 | ("linear", 10, 0, 1.0, 3, 10, 0.5, 100, 100, True),
85 | (None, "hehe", 0, 1.0, 3, 10, 0.5, 100, 100, True),
86 | (None, 2, 0, 1.0, 3, 10, 0.5, 100, 100, True),
87 | (None, 100000, 0, 1.0, 3, 10, 0.5, 100, 100, True),
88 | (None, 10, "hehe", 1.0, 3, 10, 0.5, 100, 100, True),
89 | (None, 10, -1.0, 1.0, 3, 10, 0.5, 100, 100, True),
90 | (None, 10, 1.5, 1.0, 3, 10, 0.5, 100, 100, True),
91 | (None, 10, 0, "hehe", 3, 10, 0.5, 100, 100, True),
92 | (None, 10, 0, 1.5, 3, 10, 0.5, 100, 100, True),
93 | (None, 100, -3.0, 0.99, 3, 30, 0.5, 100, None, True),
94 | (None, 100, 0.01, 1, 3, 30, 0.5, 100, None, True),
95 | (None, 100, 0.01, -4.5, 3, 30, 0.5, 100, None, True),
96 | (None, 100, 0.01, 5.5, 3, 30, 0.5, 100, None, True),
97 | (None, 100, 0.99, 0.01, 3, 30, 0.5, 100, None, True),
98 | (None, 100, 0.01, 0.99, 3.0, 30, 0.5, 100, None, True),
99 | (None, 100, 0.01, 0.99, -2, 30, 0.5, 100, None, True),
100 | (None, 100, 0.01, 0.99, 30, 30, 0.5, 100, None, True),
101 | (None, 100, 0.01, 0.99, 3, 30.0, 0.5, 100, None, True),
102 | (None, 100, 0.01, 0.99, 3, -2, 0.5, 100, None, True),
103 | (None, 100, 0.01, 0.99, 3, 500, 0.5, 100, None, True),
104 | (None, 100, 0.01, 0.99, 3, 30, 0.5, 100.0, None, True),
105 | (None, 100, 0.01, 0.99, 3, 30, 0.5, -100, None, True),
106 | (None, 100, 0.01, 0.99, 3, 30, 0.5, 10000000000, None, True),
107 | (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, 234.5, True),
108 | (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, -5, True),
109 | (None, 100, 0.01, 0.99, 3, 30, 0.5, 100, None, 4.0),
110 | (None, 10, 0, -1, 3, 10, 0.5, 100, 100, True),
111 | (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
112 | (None, 10, 0, 1, "splines", 10, 0.5, 100, 100, True),
113 | (None, 10, 0, 1, 0, 10, 0.5, 100, 100, True),
114 | (None, 10, 0, 1, 200, 10, 0.5, 100, 100, True),
115 | (None, 10, 0, 1, 3, 0, 0.5, 100, 100, True),
116 | (None, 10, 0, 1, 3, 1e6, 0.5, 100, 100, True),
117 | (None, 10, 0, 1, 3, 10, 0.5, 100, 100, True),
118 | (None, 10, 0, 1, 3, 10, 0.5, "many", 100, True),
119 | (None, 10, 0, 1, 3, 10, 0.5, 5, 100, True),
120 | (None, 10, 0, 1, 3, 10, 0.5, 1e7, 100, True),
121 | (None, 10, 0, 1, 3, 10, 0.5, 100, "random", True),
122 | (None, 10, 0, 1, 3, 10, 0.5, 100, -1.5, True),
123 | (None, 10, 0, 1, 3, 10, 0.5, 100, 111, "True"),
124 | (None, 100, 0.01, 0.99, 3, 30, "lambda", 100, None, True),
125 | (None, 100, 0.01, 0.99, 3, 30, -1.0, 100, None, True),
126 | (None, 100, 0.01, 0.99, 3, 30, 2000.0, 100, None, True),
127 | ],
128 | )
129 | def test_bad_gps_instantiation(
130 | gps_family,
131 | treatment_grid_num,
132 | lower_grid_constraint,
133 | upper_grid_constraint,
134 | spline_order,
135 | n_splines,
136 | lambda_,
137 | max_iter,
138 | random_seed,
139 | verbose,
140 | ):
141 | """
142 |     Tests for exceptions when the GPS_Regressor class is called with bad inputs.
143 | """
144 | with pytest.raises(Exception) as bad:
145 | GPS_Regressor(
146 | gps_family=gps_family,
147 | treatment_grid_num=treatment_grid_num,
148 | lower_grid_constraint=lower_grid_constraint,
149 | upper_grid_constraint=upper_grid_constraint,
150 | spline_order=spline_order,
151 | n_splines=n_splines,
152 | lambda_=lambda_,
153 | max_iter=max_iter,
154 | random_seed=random_seed,
155 | verbose=verbose,
156 | )
157 |
--------------------------------------------------------------------------------
/causal_curve/gps_regressor.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the Generalized Propensity Score (GPS) regressor model class
3 | """
4 | from pprint import pprint
5 |
6 | import numpy as np
7 |
8 | from causal_curve.gps_core import GPS_Core
9 |
10 |
11 | class GPS_Regressor(GPS_Core):
12 | """
13 |     A GPS tool that handles continuous outcomes. Inherits from the GPS_Core
14 |     base class. See that base class's code and docstring for more details.
15 |
16 | Methods
17 | ----------
18 |
19 | point_estimate: (self, T)
20 | Calculates point estimate within the CDRC given treatment values.
21 | Can only be used when outcome is continuous.
22 |
23 | point_estimate_interval: (self, T, ci)
24 | Calculates the prediction confidence interval associated with a point estimate
25 | within the CDRC given treatment values. Can only be used when outcome is continuous.
26 |
27 | """
28 |
29 | def __init__(
30 | self,
31 | gps_family=None,
32 | treatment_grid_num=100,
33 | lower_grid_constraint=0.01,
34 | upper_grid_constraint=0.99,
35 | spline_order=3,
36 | n_splines=30,
37 | lambda_=0.5,
38 | max_iter=100,
39 | random_seed=None,
40 | verbose=False,
41 | ):
42 |
43 | self.gps_family = gps_family
44 | self.treatment_grid_num = treatment_grid_num
45 | self.lower_grid_constraint = lower_grid_constraint
46 | self.upper_grid_constraint = upper_grid_constraint
47 | self.spline_order = spline_order
48 | self.n_splines = n_splines
49 | self.lambda_ = lambda_
50 | self.max_iter = max_iter
51 | self.random_seed = random_seed
52 | self.verbose = verbose
53 |
54 | # Validate the params
55 | self._validate_init_params()
56 | self.rand_seed_wrapper()
57 |
58 | self.if_verbose_print("Using the following params for GPS model:")
59 | if self.verbose:
60 | pprint(self.get_params(), indent=4)
61 |
62 | def _cdrc_predictions_continuous(self, ci):
63 | """Returns the predictions of CDRC for each value of the treatment grid. Essentially,
64 | we're making predictions using the original treatment and gps_at_grid.
65 | To be used when the outcome of interest is continuous.
66 | """
67 |
68 | # To keep track of cdrc predictions, we create an empty 3d array of shape
69 | # (n_samples, treatment_grid_num, 3). The last dimension is of length 3 because
70 | # we are going to keep track of the point estimate of the prediction, as well as
71 | # the lower and upper bounds of the prediction interval
72 | cdrc_preds = np.zeros((len(self.T), self.treatment_grid_num, 3), dtype=float)
73 |
74 | # Loop through each of the grid values, predict point estimate and get prediction interval
75 | for i in range(0, self.treatment_grid_num):
76 | temp_T = np.repeat(self.grid_values[i], repeats=len(self.T))
77 | temp_gps = self.gps_at_grid[:, i]
78 | temp_cdrc_preds = self.gam_results.predict(
79 | np.column_stack((temp_T, temp_gps))
80 | )
81 | temp_cdrc_interval = self.gam_results.confidence_intervals(
82 | np.column_stack((temp_T, temp_gps)), width=ci
83 | )
84 | temp_cdrc_lower_bound = temp_cdrc_interval[:, 0]
85 | temp_cdrc_upper_bound = temp_cdrc_interval[:, 1]
86 | cdrc_preds[:, i, 0] = temp_cdrc_preds
87 | cdrc_preds[:, i, 1] = temp_cdrc_lower_bound
88 | cdrc_preds[:, i, 2] = temp_cdrc_upper_bound
89 |
90 | return np.round(cdrc_preds, 3)
91 |
92 | def point_estimate(self, T):
93 | """Calculates point estimate within the CDRC given treatment values.
94 |         Can only be used when outcome is continuous. Can be estimated for a single
95 |         data point or run in batch for many observations. Extrapolation will produce
96 | untrustworthy results; the provided treatment should be within
97 | the range of the training data.
98 |
99 | Parameters
100 | ----------
101 | T: Numpy array, shape (n_samples,)
102 | A continuous treatment variable.
103 |
104 | Returns
105 | ----------
106 | array: Numpy array
107 | Contains a set of CDRC point estimates
108 |
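109 |         Examples
110 |         ----------
111 |         >>> # A hypothetical usage sketch: `gps` is assumed to be a GPS_Regressor
112 |         >>> # that has already been fit via its `fit` method.
113 |         >>> preds = gps.point_estimate(np.array([50]))
114 | 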
109 | """
110 | if self.outcome_type != "continuous":
111 | raise TypeError("Your outcome must be continuous to use this function!")
112 |
113 | return np.apply_along_axis(self._create_point_estimate, 0, T.reshape(1, -1))
114 |
115 | def _create_point_estimate(self, T):
116 | """Takes a single treatment value and produces a single point estimate
117 | in the case of a continuous outcome.
118 | """
119 | return self.gam_results.predict(
120 | np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1)
121 | )
122 |
123 | def point_estimate_interval(self, T, ci=0.95):
124 | """Calculates the prediction confidence interval associated with a point estimate
125 | within the CDRC given treatment values. Can only be used
126 |         when outcome is continuous. Can be estimated for a single data point
127 |         or run in batch for many observations. Extrapolation will produce
128 | untrustworthy results; the provided treatment should be within
129 | the range of the training data.
130 |
131 | Parameters
132 | ----------
133 | T: Numpy array, shape (n_samples,)
134 | A continuous treatment variable.
135 | ci: float (default = 0.95)
136 | The desired confidence interval to produce. Default value is 0.95, corresponding
137 |             to 95% confidence intervals. Bounded (0, 1.0).
138 |
139 | Returns
140 | ----------
141 | array: Numpy array
142 |             Contains a set of CDRC prediction intervals ([lower bound, upper bound])
143 |
144 | """
145 | if self.outcome_type != "continuous":
146 | raise TypeError("Your outcome must be continuous to use this function!")
147 |
148 | return np.apply_along_axis(
149 | self._create_point_estimate_interval, 0, T.reshape(1, -1), width=ci
150 | ).T.reshape(-1, 2)
151 |
152 | def _create_point_estimate_interval(self, T, width):
153 | """Takes a single treatment value and produces confidence interval
154 | associated with a point estimate in the case of a continuous outcome.
155 | """
156 | return self.gam_results.prediction_intervals(
157 | np.array([T[0], self.gps_function(T).mean()]).reshape(1, -1), width=width
158 | )
159 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | """Common fixtures for tests using pytest framework"""
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import pytest
6 | from scipy.stats import norm
7 |
8 | from causal_curve import GPS_Regressor, GPS_Classifier, TMLE_Regressor
9 |
10 |
11 | @pytest.fixture(scope="module")
12 | def continuous_dataset_fixture():
13 | """Returns full_continuous_example_dataset (with a continuous outcome)"""
14 | return full_continuous_example_dataset()
15 |
16 |
17 | def full_continuous_example_dataset():
18 | """Example dataset with a treatment, two covariates, and continuous outcome variable"""
19 |
20 | np.random.seed(500)
21 |
22 | n_obs = 500
23 |
24 | treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
25 | x_1 = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
26 | x_2 = np.random.normal(loc=0, scale=10.0, size=n_obs)
27 | outcome = treatment + x_1 + x_2 + np.random.normal(loc=50.0, scale=3.0, size=n_obs)
28 |
29 | fixture = pd.DataFrame(
30 | {"treatment": treatment, "x1": x_1, "x2": x_2, "outcome": outcome}
31 | )
32 | fixture.reset_index(drop=True, inplace=True)
33 |
34 | return fixture
35 |
36 |
37 | @pytest.fixture(scope="module")
38 | def binary_dataset_fixture():
39 | """Returns full_binary_example_dataset (with a binary outcome)"""
40 | return full_binary_example_dataset()
41 |
42 |
43 | def full_binary_example_dataset():
44 | """Example dataset with a treatment, two covariates, and binary outcome variable"""
45 |
46 | np.random.seed(500)
47 | treatment = np.linspace(
48 | start=0,
49 | stop=100,
50 | num=100,
51 | )
52 | x_1 = norm.rvs(size=100, loc=50, scale=5)
53 | outcome = [
54 | 0,
55 | 0,
56 | 0,
57 | 0,
58 | 0,
59 | 0,
60 | 0,
61 | 0,
62 | 0,
63 | 0,
64 | 0,
65 | 0,
66 | 0,
67 | 0,
68 | 0,
69 | 0,
70 | 0,
71 | 0,
72 | 0,
73 | 0,
74 | 0,
75 | 0,
76 | 0,
77 | 0,
78 | 0,
79 | 0,
80 | 0,
81 | 1,
82 | 0,
83 | 0,
84 | 0,
85 | 0,
86 | 0,
87 | 0,
88 | 0,
89 | 0,
90 | 0,
91 | 0,
92 | 1,
93 | 1,
94 | 0,
95 | 0,
96 | 0,
97 | 0,
98 | 0,
99 | 0,
100 | 0,
101 | 0,
102 | 0,
103 | 1,
104 | 1,
105 | 1,
106 | 1,
107 | 1,
108 | 1,
109 | 0,
110 | 1,
111 | 1,
112 | 1,
113 | 1,
114 | 1,
115 | 0,
116 | 1,
117 | 1,
118 | 1,
119 | 1,
120 | 1,
121 | 0,
122 | 1,
123 | 1,
124 | 1,
125 | 1,
126 | 1,
127 | 1,
128 | 1,
129 | 1,
130 | 1,
131 | 1,
132 | 1,
133 | 1,
134 | 1,
135 | 1,
136 | 1,
137 | 1,
138 | 1,
139 | 1,
140 | 1,
141 | 1,
142 | 1,
143 | 1,
144 | 1,
145 | 1,
146 | 1,
147 | 1,
148 | 1,
149 | 1,
150 | 1,
151 | 1,
152 | 1,
153 | 1,
154 | ]
155 |
156 | fixture = pd.DataFrame({"treatment": treatment, "x1": x_1, "outcome": outcome})
157 | fixture.reset_index(drop=True, inplace=True)
158 |
159 | return fixture
160 |
161 |
162 | @pytest.fixture(scope="module")
163 | def mediation_fixture():
164 | """Returns mediation_dataset"""
165 | return mediation_dataset()
166 |
167 |
168 | def mediation_dataset():
169 | """Example dataset to test / demonstrate mediation with a treatment,
170 | a mediator, and an outcome variable"""
171 |
172 | np.random.seed(500)
173 |
174 | n_obs = 500
175 |
176 | treatment = np.random.normal(loc=50.0, scale=10.0, size=n_obs)
177 | mediator = np.random.normal(loc=70.0 + treatment, scale=8.0, size=n_obs)
178 | outcome = np.random.normal(loc=(treatment + mediator - 50), scale=10.0, size=n_obs)
179 |
180 | fixture = pd.DataFrame(
181 | {"treatment": treatment, "mediator": mediator, "outcome": outcome}
182 | )
183 |
184 | fixture.reset_index(drop=True, inplace=True)
185 |
186 | return fixture
187 |
188 |
189 | @pytest.fixture(scope="module")
190 | def GPS_fitted_model_continuous_fixture():
191 | """Returns a GPS model that is already fit with data with a continuous outcome"""
192 | return GPS_fitted_model_continuous()
193 |
194 |
195 | def GPS_fitted_model_continuous():
196 | """Example GPS model that is fit with data including a continuous outcome"""
197 |
198 | df = full_continuous_example_dataset()
199 |
200 | fixture = GPS_Regressor(
201 | treatment_grid_num=10,
202 | lower_grid_constraint=0.0,
203 | upper_grid_constraint=1.0,
204 | spline_order=3,
205 | n_splines=10,
206 | max_iter=100,
207 | random_seed=100,
208 | verbose=True,
209 | )
210 | fixture.fit(
211 | T=df["treatment"],
212 | X=df["x1"],
213 | y=df["outcome"],
214 | )
215 |
216 | return fixture
217 |
218 |
219 | @pytest.fixture(scope="module")
220 | def GPS_fitted_model_binary_fixture():
221 |     """Returns a GPS model that is already fit with data with a binary outcome"""
222 | return GPS_fitted_model_binary()
223 |
224 |
225 | def GPS_fitted_model_binary():
226 |     """Example GPS model that is fit with data including a binary outcome"""
227 |
228 | df = full_binary_example_dataset()
229 |
230 | fixture = GPS_Classifier(
231 | gps_family="normal",
232 | treatment_grid_num=10,
233 | lower_grid_constraint=0.0,
234 | upper_grid_constraint=1.0,
235 | spline_order=3,
236 | n_splines=10,
237 | max_iter=100,
238 | random_seed=100,
239 | verbose=True,
240 | )
241 | fixture.fit(
242 | T=df["treatment"],
243 | X=df["x1"],
244 | y=df["outcome"],
245 | )
246 |
247 | return fixture
248 |
249 |
250 | @pytest.fixture(scope="module")
251 | def TMLE_fitted_model_continuous_fixture():
252 | """Returns a TMLE model that is already fit with data with a continuous outcome"""
253 | return TMLE_fitted_model_continuous()
254 |
255 |
256 | def TMLE_fitted_model_continuous():
257 |     """Example TMLE model that is fit with data including a continuous outcome"""
258 |
259 | df = full_continuous_example_dataset()
260 |
261 | fixture = TMLE_Regressor(
262 | random_seed=100,
263 | verbose=True,
264 | )
265 | fixture.fit(
266 | T=df["treatment"],
267 | X=df[["x1", "x2"]],
268 | y=df["outcome"],
269 | )
270 |
271 | return fixture
272 |
--------------------------------------------------------------------------------
/docs/contribute.rst:
--------------------------------------------------------------------------------
1 | .. _contribute:
2 |
3 | ==================
4 | Contributing guide
5 | ==================
6 |
7 | Thank you for considering contributing to causal-curve. Contributions from anyone
8 | are welcome. There are many ways to contribute to the package, such as
9 | reporting bugs, adding new features, and improving the documentation. The
10 | following sections give more details on how to contribute.
11 |
12 | **Important links**:
13 |
14 | - The project is hosted on GitHub: https://github.com/ronikobrosly/causal-curve
15 |
16 |
17 | Submitting a bug report or a feature request
18 | --------------------------------------------
19 |
20 | If you experience a bug using causal-curve or if you would like to see a new
21 | feature being added to the package, feel free to open an issue on GitHub:
22 | https://github.com/ronikobrosly/causal-curve/issues
23 |
24 | Bug report
25 | ^^^^^^^^^^
26 |
27 | A good bug report usually contains:
28 |
29 | - a description of the bug,
30 | - a self-contained example to reproduce the bug if applicable,
31 | - a description of the difference between the actual and expected results,
32 | - the versions of the dependencies of causal-curve.
33 |
34 | The last point can easily be done with the following commands::
35 |
36 | import numpy; print("NumPy", numpy.__version__)
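37 |     # scipy and pandas (assumed here as other core causal-curve dependencies):
38 |     import scipy; print("SciPy", scipy.__version__)
39 |     import pandas; print("Pandas", pandas.__version__)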
37 |
38 | These guidelines make reproducing the bug easier, which makes fixing it easier.
39 |
40 |
41 | Feature request
42 | ^^^^^^^^^^^^^^^
43 |
44 | A good feature request usually contains:
45 |
46 | - a description of the requested feature,
47 | - a description of the relevance of this feature to causal inference,
48 | - references if applicable, with links to the papers if they are in open access.
49 |
50 | This makes reviewing the relevance of the requested feature easier.
51 |
52 |
53 | Contributing code
54 | -----------------
55 |
56 | In order to contribute code, you need to create a pull request on
57 | https://github.com/ronikobrosly/causal-curve/pulls
58 |
59 | How to contribute
60 | ^^^^^^^^^^^^^^^^^
61 |
62 | To contribute to causal-curve, you need to fork the repository then submit a
63 | pull request:
64 |
65 | 1. Fork the repository.
66 |
67 | 2. Clone your fork of the causal-curve repository from your GitHub account to your
68 | local disk::
69 |
70 | git clone https://github.com/yourusername/causal-curve.git
71 | cd causal-curve
72 |
73 | where ``yourusername`` is your GitHub username.
74 |
75 | 3. Install the development dependencies::
76 |
77 | pip install pytest pylint black
78 |
79 | 4. Install causal-curve in editable mode::
80 |
81 | pip install -e .
82 |
83 | 5. Add the ``upstream`` remote. It creates a reference to the main repository
84 | that can be used to keep your repository synchronized with the latest changes
85 | on the main repository::
86 |
87 | git remote add upstream https://github.com/ronikobrosly/causal-curve.git
88 |
89 | 6. Fetch the ``upstream`` remote then create a new branch where you will make
90 | your changes and switch to it::
91 |
92 | git fetch upstream
93 | git checkout -b my-feature upstream/main
94 |
95 | where ``my-feature`` is the name of your new branch (it's good practice to have
96 | an explicit name). You can now start making changes.
97 |
98 | 7. Make the changes that you want on your new branch on your local machine.
99 | When you are done, add the changed files using ``git add`` and then
100 | ``git commit``::
101 |
102 | git add modified_files
103 | git commit
104 |
105 | Then push your commits to your GitHub account using ``git push``::
106 |
107 | git push origin my-feature
108 |
109 | 8. Create a pull request from your work. The base fork is the fork you
110 | would like to merge changes into, that is ``ronikobrosly/causal-curve`` on the
111 | ``main`` branch. The head fork is the repository where you made your
112 | changes, that is ``yourusername/causal-curve`` on the ``my-feature`` branch.
113 | Add a title and a description of your pull request, then click on
114 | **Create Pull Request**.
115 |
116 |
117 | Pull request checklist
118 | ^^^^^^^^^^^^^^^^^^^^^^
119 |
120 | Before pushing to your GitHub account, there are a few rules that are
121 | usually worth complying with.
122 |
123 | - **Make sure that your code passes tests**. You can do this by running the
124 | whole test suite with the ``pytest`` command. If you are experienced with
125 | ``pytest``, you can run specific tests that are relevant for your changes.
126 |   It is still worth running the whole test suite when you are done making
127 |   changes, since it does not take very long.
128 | For more information, please refer to the
129 |   `pytest documentation <https://docs.pytest.org/>`_.
130 |   If your code does not pass the tests but you are looking for help, feel free
131 |   to open the pull request anyway (but mention the failing tests in it).
132 |
133 | - **Make sure to add tests if you add new code**. It is important to test
134 | new code to make sure that it behaves as expected. Ideally code coverage
135 | should increase with any new pull request. You can check code coverage
136 | using ``pytest-cov``::
137 |
138 | pip install pytest-cov
139 |     pytest --cov causal_curve
140 |
141 | - **Make sure that the documentation renders properly**. To build the
142 | documentation, please refer to the :ref:`contribute_documentation` guidelines.
143 |
144 | - **Make sure that your PR does not add PEP8 violations**. You can run `black`
145 | and `pylint` to only test the modified code.
146 | Feel free to submit another pull request if you find other PEP8 violations.
147 |
148 | .. _contribute_documentation:
149 |
150 | Contributing to the documentation
151 | ---------------------------------
152 |
153 | Documentation is as important as code. If you see typos, find docstrings
154 | unclear or want to add examples illustrating functionalities provided in
155 | causal-curve, feel free to open an issue to report it or a pull request if you
156 | want to fix it.
157 |
158 |
159 | Building the documentation
160 | ^^^^^^^^^^^^^^^^^^^^^^^^^^
161 |
162 | Building the documentation requires installing some additional packages::
163 |
164 | pip install sphinx==3.0.2 sphinx-rtd-theme numpydoc
165 |
166 | To build the documentation, you must be in the ``docs`` folder::
167 | 
168 |     cd docs
169 |
170 | To generate the website with the example gallery, run the following command::
171 |
172 | make html
173 |
174 | The documentation will be generated in the ``_build/html`` directory. You can double
175 | click on ``index.html`` to open the index page, which will look like
176 | the first page that you see on the online documentation. Then you can move to
177 | the pages that you modified and have a look at your changes.
178 |
179 | Finally, repeat this process until you are satisfied with your changes and
180 | open a pull request describing the changes you made.
181 |
--------------------------------------------------------------------------------
/paper/paper.md:
--------------------------------------------------------------------------------
1 | ---
2 | title: 'causal-curve: A Python Causal Inference Package to Estimate Causal Dose-Response Curves'
3 | tags:
4 | - Python
5 | - causal inference
6 | - causality
7 | - machine learning
8 |
9 | authors:
10 | - name: Roni W. Kobrosly
11 | orcid: 0000-0003-0363-9662
12 | affiliation: "1, 2" # (Multiple affiliations must be quoted)
13 | affiliations:
14 | - name: Department of Environmental Medicine and Public Health, Icahn School of Medicine at Mount Sinai, New York, NY, USA
15 | index: 1
16 | - name: Flowcast, 44 Tehama St, San Francisco, CA, USA
17 | index: 2
18 | date: 1 July 2020
19 | bibliography: paper.bib
20 |
21 | ---
22 |
23 | # Summary
24 |
25 | In academia and industry, randomized controlled experiments (colloquially "A/B tests")
26 | are considered the gold standard approach for assessing the impact of a treatment or intervention.
27 | However, for ethical or financial reasons, these experiments may not always be feasible to carry out.
28 | "Causal inference" methods are a set of approaches that attempt to estimate causal effects
29 | from observational rather than experimental data, correcting for the biases that are inherent
30 | to analyzing observational data (e.g. confounding and selection bias) [@Hernán:2020].
31 |
32 | Although significant research and implementation effort has gone towards methods in
33 | causal inference to estimate the effects of binary treatments (e.g. what was the effect of
34 | treatment "A" or "B"?), much less has gone towards estimating the effects of continuous treatments.
35 | This is unfortunate because there are a great number of inquiries in research
36 | and industry that could benefit from tools to estimate the effect of
37 | continuous treatments, such as estimating how:
38 |
39 | - the number of minutes per week of aerobic exercise causes positive health outcomes,
40 | after controlling for confounding effects.
41 | - increasing or decreasing the price of a product would impact demand (price elasticity).
42 | - changing neighborhood income inequality (as measured by the continuous Gini index)
43 | might or might not be causally related to the neighborhood crime rate.
44 | - blood lead levels are causally related to neurodevelopment delays in children.
45 |
46 | `causal-curve` is a Python package created to address this gap; it is designed to perform
47 | causal inference when the treatment of interest is continuous in nature.
48 | From the observational data that is provided by the user, it estimates the
49 | "causal dose-response curve" (or simply the "causal curve").
50 |
51 | In the current release of the package there are two unique model classes for
52 | constructing the causal dose-response curve: the Generalized Propensity Score (GPS) and the
53 | Targeted Maximum Likelihood Estimation (TMLE) tools. There is also a tool
54 | to assess causal mediation effects in the presence of a continuous mediator and treatment.
55 |
56 | `causal-curve` attempts to make the user experience as painless as possible:
57 |
58 | - This package's API was designed to resemble that of `scikit-learn`, as this is a commonly
59 | used Python predictive modeling framework familiar to most machine learning practitioners.
60 | - All of the major classes contained in `causal-curve` readily use Pandas DataFrames and Series as
61 | inputs, so that the package integrates easily with standard Python data analysis tools.
62 | - A full, end-to-end example of applying the package to a causal inference problem (the analysis of health data)
63 | is provided. In addition, shorter tutorials for each of the three major classes are available online in the documentation, along with full documentation of all of their parameters, methods, and attributes.
64 |
65 | This package includes a suite of unit and integration tests written with the pytest framework. The
66 | repo containing the latest project code is integrated with TravisCI for continuous integration. Code
67 | coverage is monitored via codecov and is presently above 90%.
68 |
69 |
70 | # Methods
71 |
72 | The `GPS` method was originally described by Hirano [@Hirano:2004],
73 | and expanded by Moodie [@Moodie:2010] and more recently by Galagate [@Galagate:2016]. GPS is
74 | an extension of the standard propensity score method and is essentially the treatment assignment density calculated
75 | at a particular treatment (and covariate) value. Similar to the standard propensity score approach,
76 | the GPS random variable is used to balance covariates. At the core of this tool, generalized linear
77 | models are used to estimate the GPS, and generalized additive models are used to estimate the smoothed
78 | final causal curve. Compared with this package’s TMLE method, this GPS method is more
79 | computationally efficient and better suited for large datasets, but produces significantly wider confidence intervals.
80 |
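As a brief illustration of the package's `scikit-learn`-style API, a GPS analysis can be sketched as follows (a minimal sketch on synthetic data; the `GPS_Regressor` call pattern mirrors the TMLE docstring example shipped with the package):

```python
import numpy as np
import pandas as pd
from causal_curve import GPS_Regressor

# Synthetic example data: two confounders, a continuous treatment, a continuous outcome
rng = np.random.default_rng(42)
n = 500
df = pd.DataFrame({"X_1": rng.normal(size=n), "X_2": rng.normal(size=n)})
df["Treatment"] = df["X_1"] + rng.normal(size=n)
df["Outcome"] = 2.0 * df["Treatment"] + df["X_2"] + rng.normal(size=n)

gps = GPS_Regressor()
gps.fit(T=df["Treatment"], X=df[["X_1", "X_2"]], y=df["Outcome"])

# DataFrame holding the causal dose-response curve with 95% confidence intervals
gps_results = gps.calculate_CDRC(0.95)
```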
81 |
82 | ![](welcome_plot.png)
83 |
84 |
85 | The `TMLE` method is based on van der Laan's work on an approach to causal inference that
86 | employs powerful machine learning approaches to estimate a causal effect [@van_der_Laan:2010].
87 | TMLE involves predicting the outcome from the treatment and covariates using a machine learning model,
88 | then predicting treatment assignment from the covariates. TMLE also employs a substitution “targeting”
89 | step to correct for covariate imbalance and to estimate an unbiased causal effect.
90 | Currently, there is no other implementation of TMLE that is suitable for continuous treatments. The
91 | implementation in `causal-curve` constructs the final curve through a series of binary treatment comparisons
92 | across the user-specified range of treatment values, connecting the resulting estimates.
93 | Compared with the package’s GPS method, this TMLE method is doubly robust
94 | against model misspecification, incorporates more powerful machine learning techniques internally, and produces significantly
95 | smaller confidence intervals; however, it is less computationally efficient.
96 |
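The TMLE tool follows the same interface; the lines below are adapted from the package's own docstring example (reusing the synthetic `df` and numpy import from the previous sketch):

```python
from causal_curve import TMLE_Regressor

tmle = TMLE_Regressor()
tmle.fit(T=df["Treatment"], X=df[["X_1", "X_2"]], y=df["Outcome"])
tmle_results = tmle.calculate_CDRC(0.95)

# Point estimate (and interval) of the outcome at a specific treatment value
point_estimate = tmle.point_estimate(np.array([5.0]))
point_estimate_interval = tmle.point_estimate_interval(np.array([5.0]), 0.95)
```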
97 | `causal-curve` allows for continuous mediation assessment with the `Mediation` tool. As described
98 | by Imai, this is a general approach to mediation analysis that invokes the potential
99 | outcomes / counterfactual framework [@Imai:2010]. While the approach can handle a
100 | continuous mediator and outcome, as put forward by Imai it only allows for a binary treatment. As
101 | mentioned above with the `TMLE` approach, the tool creates a series of binary treatment comparisons
102 | and connects them to show the user how mediation varies as a function of the treatment. An interpretable,
103 | overall mediation proportion is provided as well.
104 |
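A sketch of the mediation tool's use (hypothetical: `Mediation` is defined in `causal_curve/mediation.py`, but the `fit` / `calculate_mediation` call pattern below is an assumption modeled on the package's other tools, and the `Mediator` column is invented for illustration):

```python
from causal_curve import Mediation

# Hypothetical continuous mediator lying on the causal path (added to the synthetic df above)
df["Mediator"] = 0.5 * df["Treatment"] + rng.normal(size=n)

med = Mediation()
med.fit(T=df["Treatment"], M=df["Mediator"], y=df["Outcome"])

# Overall mediation proportion, plus how the indirect effect varies over the treatment
med_results = med.calculate_mediation(0.95)
```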
105 |
106 | # Statement of Need
107 |
108 | While there are a few established Python packages related to causal inference, to the best of
109 | the author's knowledge, there is no Python package available that can provide support for
110 | continuous treatments as `causal-curve` does. Similarly, the author isn't aware of any Python
111 | implementation of a causal mediation analysis for continuous treatments and mediators. Finally,
112 | the tutorials available in the documentation introduce the concept of continuous treatments
113 | and are instructive as to how the results of their analysis should be interpreted.
114 |
115 |
116 | # Acknowledgements
117 |
118 | We acknowledge the valuable feedback from Miguel-Angel Luque, Erica Moodie, and Mark van der Laan
119 | during the creation of this project.
120 |
121 |
122 | # References
123 |
--------------------------------------------------------------------------------
/docs/full_example.rst:
--------------------------------------------------------------------------------
1 | .. _full_example:
2 |
3 | =============================================================
4 | Health data: generating causal curves and examining mediation
5 | =============================================================
6 |
7 | To provide an end-to-end example of the sorts of analyses `causal-curve` can be used for, we'll
8 | begin with a health topic. A notebook containing the pipeline to produce the following
9 | output `is available here `_.
10 | Note: Specific examples of the individual `causal-curve` tools with
11 | code are available elsewhere in this documentation.
12 |
13 |
14 | The causal effect of blood lead levels on cognitive performance in children
15 | ---------------------------------------------------------------------------
16 |
17 | Despite the banning of lead-based paint and of lead in gasoline in the United
18 | States, lead exposure remains an enormous public health problem for children and adolescents. This
19 | is particularly true for poorer children living in older homes in inner-city environments.
20 | For children, there is no known safe level of exposure to lead, and even small levels of
21 | lead measured in their blood have been shown to affect IQ and academic achievement.
22 | One of the scariest parts of lead exposure is that its effects are permanent. Blood lead levels (BLLs)
23 | of 5 ug/dL or higher are considered elevated.
24 |
25 | There is a great deal of research on, and many government programs for, lead abatement. In terms of
26 | public policy, it would be helpful to understand how childhood cognitive outcomes would be affected by
27 | reducing BLLs in children. This is the causal question to answer, with blood lead
28 | levels being the continuous treatment, and the cognitive outcomes being the outcome of interest.
29 |
30 | .. image:: https://upload.wikimedia.org/wikipedia/commons/6/69/LeadPaint1.JPG
31 |
32 | (Photo attribution: Thester11 / CC BY (https://creativecommons.org/licenses/by/3.0))
33 |
34 | To explore that problem, we can analyze data collected from the National Health and Nutrition
35 | Examination Survey (NHANES) III. This was a large, national study of families throughout the United
36 | States, carried out between 1988 and 1994. Participants were involved in extensive interviews,
37 | medical examinations, and provided biological samples. As part of this project, BLLs
38 | were measured, and four scaled sub-tests of the Wechsler Intelligence Scale for Children-Revised
39 | and the Wide Range Achievement Test-Revised (WISC/WRAT) cognitive test were carried out. This data
40 | is de-identified and publicly available on the Centers for Disease Control and Prevention (CDC)
41 | government website.
42 |
43 | After the data were processed and missing values were dropped, there were 1,764 children between
44 | 6 and 12 years of age with complete data. BLLs among these children were log-normally
45 | distributed, as one would expect:
46 |
47 | .. image:: ../imgs/full_example/BLL_dist.png
48 |
49 | The four scaled sub-tests of the WISC/WRAT included a math test, a reading test, a block design
50 | test (a test of spatial visualization ability and motor skill), and a digit spanning test
51 | (a test of memory). Their distributions are shown here:
52 |
53 | .. image:: ../imgs/full_example/test_dist.png
54 |
55 | Using a well-known study by Bruce Lanphear conducted in 2000 as a guide, we used the following
56 | features as potentially confounding "nuisance" variables:
57 |
58 | - Child age
59 | - Child sex (in 1988 - 1994 the CDC assumed binary sex)
60 | - Child race/ethnicity
61 | - The education level of the guardian
62 | - Whether someone smokes in the child's home
63 | - Whether the child spent time in a neonatal intensive care unit as a baby
64 | - Whether the child is experiencing food insecurity (is food sometimes not available due to lack of resources?).
65 |
66 | In our "experiment", the above confounders will be controlled for.
67 |
68 | By using either the GPS or TMLE tools included in `causal-curve` one can generate the causal
69 | dose-response curves for BLLs in relation to the four outcomes:
70 |
71 | .. image:: ../imgs/full_example/test_causal_curves.png
72 |
73 | Note that the lower limit of detection for the blood lead test in this version of NHANES was
74 | 0.7 ug/dL. So lead levels below that value are not possible.
75 |
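As a rough sketch of what generating one of these curves looks like in code (the DataFrame
``nhanes_df``, the list ``confounder_cols``, and the column names are all hypothetical here;
the full pipeline lives in the notebook linked above)::

    from causal_curve import GPS_Regressor

    gps = GPS_Regressor()
    gps.fit(
        T=nhanes_df["blood_lead_level"],  # continuous treatment
        X=nhanes_df[confounder_cols],     # the confounders listed above
        y=nhanes_df["math_test_score"],   # one of the four cognitive outcomes
    )
    math_curve = gps.calculate_CDRC(0.95)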
76 | In the case of the math test, these results indicate that reducing BLLs in this population
77 | to their lowest value would cause scaled math scores to increase by around 2 points, relative
78 | to BLLs around 10 ug/dL. Similar results are found for the reading and block design tests,
79 | though the digit spanning test's causal curve appears possibly flat (although with the sparse
80 | observations at the higher end of the BLL range and the wide confidence intervals it is
81 | difficult to say).
82 |
83 | The above curves differ from standard regression curves in a few big ways:
84 |
85 | - Even though the data that we used to generate these curves are observational, if causal inference assumptions are met, these curves can be interpreted as causal.
86 | - These models were created using the potential outcomes / counterfactual framework, while standard models are not. Also, the approach we used here essentially simulates experimental conditions by balancing out treatment assignment across the various confounders, and controlling for their effects.
87 | - Even if complex interactions between the variables are modelled, these curves average over the various interaction effects and subgroups. In this sense, these are "marginal" curves.
88 | - These curves should not be used to make predictions at the individual level. These are population level estimates and should remain that way.
89 |
90 |
91 |
92 | Do blood lead levels mediate the relationship between poverty and cognitive performance?
93 | ----------------------------------------------------------------------------------------
94 |
95 | There is a well-known link between household income and child academic performance. Now that we
96 | have some evidence of a potentially causal relationship between BLLs and test performance in
97 | children, one might wonder whether lead exposure mediates the relationship between household income
98 | and academic performance. In other words, in this population does low income cause one to be
99 | exposed more to lead, which in turn causes lower performance? Or is household income linked
100 | with academic performance directly, or through other variables?
101 |
102 | NHANES III captured each household's Poverty Index Ratio (the ratio of total family income to
103 | the federal poverty level for the year of the interview). For this example, let's focus just
104 | on the math test as an outcome. Using `causal-curve`'s mediation tool,
105 | we found that the overall, mediating indirect effect of BLLs is 0.20 (0.17 - 0.23). This means
106 | that lead exposure accounts for 20% of the relationship between low income and low test
107 | performance in this population. The mediation tool also allows you to see how the indirect effect
108 | varies as a function of the treatment. As the plot shows, the mediating effect is relatively flat,
109 | although, interestingly, there is a hint of an increase as income increases relative to the poverty line.
110 |
111 | .. image:: ../imgs/full_example/mediation_curve.png
112 |
113 |
114 | References
115 | ----------
116 |
117 | Centers for Disease Control and Prevention. NHANES III (1988-1994).
118 | https://wwwn.cdc.gov/nchs/nhanes/nhanes3/default.aspx. Accessed on July 2, 2020.
119 |
120 | Centers for Disease Control and Prevention. Blood Lead Levels in Children.
121 | https://www.cdc.gov/nceh/lead/prevention/blood-lead-levels.htm. Accessed on July 2, 2020.
122 |
123 | Environmental Protection Agency. Learn about Lead. https://www.epa.gov/lead/learn-about-lead.
124 | Accessed on July 2, 2020.
125 |
126 | Pirkle JL, Kaufmann RB, Brody DJ, Hickman T, Gunter EW, Paschal DC. Exposure of the
127 | U.S. population to lead, 1991-1994. Environmental Health Perspectives, 106(11), 1998, pp. 745–750.
128 |
129 | Lanphear BP, Dietrich K, Auinger P, Cox C. Cognitive Deficits Associated with
130 | Blood Lead Concentrations <10 ug/dL in US Children and Adolescents.
131 | In: Public Health Reports, 115, 2000, pp.521-529.
132 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 |
3 | # A comma-separated list of package or module names from where C extensions may
4 | # be loaded. Extensions are loading into the active Python interpreter and may
5 | # run arbitrary code.
6 | extension-pkg-whitelist=
7 |
8 | # Specify a score threshold to be exceeded before program exits with error.
9 | fail-under=10.0
10 |
11 | # Add files or directories to the blacklist. They should be base names, not
12 | # paths.
13 | ignore=CVS
14 |
15 | # Add files or directories matching the regex patterns to the blacklist. The
16 | # regex matches against base names, not paths.
17 | ignore-patterns=
18 |
19 | # Python code to execute, usually for sys.path manipulation such as
20 | # pygtk.require().
21 | #init-hook=
22 |
23 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
24 | # number of processors available to use.
25 | jobs=1
26 |
27 | # Control the amount of potential inferred values when inferring a single
28 | # object. This can help the performance when dealing with large functions or
29 | # complex, nested conditions.
30 | limit-inference-results=100
31 |
32 | # List of plugins (as comma separated values of python module names) to load,
33 | # usually to register additional checkers.
34 | load-plugins=
35 |
36 | # Pickle collected data for later comparisons.
37 | persistent=yes
38 |
39 | # When enabled, pylint would attempt to guess common misconfiguration and emit
40 | # user-friendly hints instead of false-positive error messages.
41 | suggestion-mode=yes
42 |
43 | # Allow loading of arbitrary C extensions. Extensions are imported into the
44 | # active Python interpreter and may run arbitrary code.
45 | unsafe-load-any-extension=no
46 |
47 |
48 | [MESSAGES CONTROL]
49 |
50 | # Only show warnings with the listed confidence levels. Leave empty to show
51 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED.
52 | confidence=
53 |
54 | # Disable the message, report, category or checker with the given id(s). You
55 | # can either give multiple identifiers separated by comma (,) or put this
56 | # option multiple times (only on the command line, not in the configuration
57 | # file where it should appear only once). You can also use "--disable=all" to
58 | # disable everything first and then reenable specific checks. For example, if
59 | # you want to run only the similarities checker, you can use "--disable=all
60 | # --enable=similarities". If you want to run only the classes checker, but have
61 | # no Warning level messages displayed, use "--disable=all --enable=classes
62 | # --disable=W".
63 | disable=too-many-locals,
64 | super-init-not-called,
65 | attribute-defined-outside-init,
66 | too-many-instance-attributes,
67 | invalid-name,
68 | too-few-public-methods,
69 | consider-using-dict-comprehension,
70 | too-many-arguments,
71 | no-name-in-module,
72 | duplicate-code,
73 | print-statement,
74 | parameter-unpacking,
75 | unpacking-in-except,
76 | old-raise-syntax,
77 | backtick,
78 | long-suffix,
79 | old-ne-operator,
80 | old-octal-literal,
81 | import-star-module-level,
82 | non-ascii-bytes-literal,
83 | raw-checker-failed,
84 | bad-inline-option,
85 | locally-disabled,
86 | file-ignored,
87 | suppressed-message,
88 | useless-suppression,
89 | deprecated-pragma,
90 | use-symbolic-message-instead,
91 | apply-builtin,
92 | basestring-builtin,
93 | buffer-builtin,
94 | cmp-builtin,
95 | coerce-builtin,
96 | execfile-builtin,
97 | file-builtin,
98 | long-builtin,
99 | raw_input-builtin,
100 | reduce-builtin,
101 | standarderror-builtin,
102 | unicode-builtin,
103 | xrange-builtin,
104 | coerce-method,
105 | delslice-method,
106 | getslice-method,
107 | setslice-method,
108 | no-absolute-import,
109 | old-division,
110 | dict-iter-method,
111 | dict-view-method,
112 | next-method-called,
113 | metaclass-assignment,
114 | indexing-exception,
115 | raising-string,
116 | reload-builtin,
117 | oct-method,
118 | hex-method,
119 | nonzero-method,
120 | cmp-method,
121 | input-builtin,
122 | round-builtin,
123 | intern-builtin,
124 | unichr-builtin,
125 | map-builtin-not-iterating,
126 | zip-builtin-not-iterating,
127 | range-builtin-not-iterating,
128 | filter-builtin-not-iterating,
129 | using-cmp-argument,
130 | eq-without-hash,
131 | div-method,
132 | idiv-method,
133 | rdiv-method,
134 | exception-message-attribute,
135 | invalid-str-codec,
136 | sys-max-int,
137 | bad-python3-import,
138 | deprecated-string-function,
139 | deprecated-str-translate-call,
140 | deprecated-itertools-function,
141 | deprecated-types-field,
142 | next-method-defined,
143 | dict-items-not-iterating,
144 | dict-keys-not-iterating,
145 | dict-values-not-iterating,
146 | deprecated-operator-function,
147 | deprecated-urllib-function,
148 | xreadlines-attribute,
149 | deprecated-sys-function,
150 | exception-escape,
151 | comprehension-escape
152 |
153 | # Enable the message, report, category or checker with the given id(s). You can
154 | # either give multiple identifier separated by comma (,) or put this option
155 | # multiple time (only on the command line, not in the configuration file where
156 | # it should appear only once). See also the "--disable" option for examples.
157 | enable=c-extension-no-member
158 |
159 |
160 | [REPORTS]
161 |
162 | # Python expression which should return a score less than or equal to 10. You
163 | # have access to the variables 'error', 'warning', 'refactor', and 'convention'
164 | # which contain the number of messages in each category, as well as 'statement'
165 | # which is the total number of statements analyzed. This score is used by the
166 | # global evaluation report (RP0004).
167 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10)
168 |
169 | # Template used to display messages. This is a python new-style format string
170 | # used to format the message information. See doc for all details.
171 | #msg-template=
172 |
173 | # Set the output format. Available formats are text, parseable, colorized, json
174 | # and msvs (visual studio). You can also give a reporter class, e.g.
175 | # mypackage.mymodule.MyReporterClass.
176 | output-format=text
177 |
178 | # Tells whether to display a full report or only the messages.
179 | reports=no
180 |
181 | # Activate the evaluation score.
182 | score=yes
183 |
184 |
185 | [REFACTORING]
186 |
187 | # Maximum number of nested blocks for function / method body
188 | max-nested-blocks=5
189 |
190 | # Complete name of functions that never returns. When checking for
191 | # inconsistent-return-statements if a never returning function is called then
192 | # it will be considered as an explicit return statement and no message will be
193 | # printed.
194 | never-returning-functions=sys.exit
195 |
196 |
197 | [LOGGING]
198 |
199 | # The type of string formatting that logging methods do. `old` means using %
200 | # formatting, `new` is for `{}` formatting.
201 | logging-format-style=old
202 |
203 | # Logging modules to check that the string format arguments are in logging
204 | # function parameter format.
205 | logging-modules=logging
206 |
207 |
208 | [SPELLING]
209 |
210 | # Limits count of emitted suggestions for spelling mistakes.
211 | max-spelling-suggestions=4
212 |
213 | # Spelling dictionary name. Available dictionaries: none. To make it work,
214 | # install the python-enchant package.
215 | spelling-dict=
216 |
217 | # List of comma separated words that should not be checked.
218 | spelling-ignore-words=
219 |
220 | # A path to a file that contains the private dictionary; one word per line.
221 | spelling-private-dict-file=
222 |
223 | # Tells whether to store unknown words to the private dictionary (see the
224 | # --spelling-private-dict-file option) instead of raising a message.
225 | spelling-store-unknown-words=no
226 |
227 |
228 | [MISCELLANEOUS]
229 |
230 | # List of note tags to take in consideration, separated by a comma.
231 | notes=FIXME,
232 | XXX,
233 | TODO
234 |
235 | # Regular expression of note tags to take in consideration.
236 | #notes-rgx=
237 |
238 |
239 | [TYPECHECK]
240 |
241 | # List of decorators that produce context managers, such as
242 | # contextlib.contextmanager. Add to this list to register other decorators that
243 | # produce valid context managers.
244 | contextmanager-decorators=contextlib.contextmanager
245 |
246 | # List of members which are set dynamically and missed by pylint inference
247 | # system, and so shouldn't trigger E1101 when accessed. Python regular
248 | # expressions are accepted.
249 | generated-members=
250 |
251 | # Tells whether missing members accessed in mixin class should be ignored. A
252 | # mixin class is detected if its name ends with "mixin" (case insensitive).
253 | ignore-mixin-members=yes
254 |
255 | # Tells whether to warn about missing members when the owner of the attribute
256 | # is inferred to be None.
257 | ignore-none=yes
258 |
259 | # This flag controls whether pylint should warn about no-member and similar
260 | # checks whenever an opaque object is returned when inferring. The inference
261 | # can return multiple potential results while evaluating a Python object, but
262 | # some branches might not be evaluated, which results in partial inference. In
263 | # that case, it might be useful to still emit no-member and other checks for
264 | # the rest of the inferred objects.
265 | ignore-on-opaque-inference=yes
266 |
267 | # List of class names for which member attributes should not be checked (useful
268 | # for classes with dynamically set attributes). This supports the use of
269 | # qualified names.
270 | ignored-classes=optparse.Values,thread._local,_thread._local
271 |
272 | # List of module names for which member attributes should not be checked
273 | # (useful for modules/projects where namespaces are manipulated during runtime
274 | # and thus existing member attributes cannot be deduced by static analysis). It
275 | # supports qualified module names, as well as Unix pattern matching.
276 | ignored-modules=
277 |
278 | # Show a hint with possible names when a member name was not found. The aspect
279 | # of finding the hint is based on edit distance.
280 | missing-member-hint=yes
281 |
282 | # The minimum edit distance a name should have in order to be considered a
283 | # similar match for a missing member name.
284 | missing-member-hint-distance=1
285 |
286 | # The total number of similar names that should be taken in consideration when
287 | # showing a hint for a missing member.
288 | missing-member-max-choices=1
289 |
290 | # List of decorators that change the signature of a decorated function.
291 | signature-mutators=
292 |
293 |
294 | [VARIABLES]
295 |
296 | # List of additional names supposed to be defined in builtins. Remember that
297 | # you should avoid defining new builtins when possible.
298 | additional-builtins=
299 |
300 | # Tells whether unused global variables should be treated as a violation.
301 | allow-global-unused-variables=yes
302 |
303 | # List of strings which can identify a callback function by name. A callback
304 | # name must start or end with one of those strings.
305 | callbacks=cb_,
306 | _cb
307 |
308 | # A regular expression matching the name of dummy variables (i.e. expected to
309 | # not be used).
310 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_
311 |
312 | # Argument names that match this expression will be ignored. Default to name
313 | # with leading underscore.
314 | ignored-argument-names=_.*|^ignored_|^unused_
315 |
316 | # Tells whether we should check for unused import in __init__ files.
317 | init-import=no
318 |
319 | # List of qualified module names which can have objects that can redefine
320 | # builtins.
321 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io
322 |
323 |
324 | [FORMAT]
325 |
326 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF.
327 | expected-line-ending-format=
328 |
329 | # Regexp for a line that is allowed to be longer than the limit.
330 | ignore-long-lines=^\s*(# )?<?https?://\S+>?$
331 |
332 | # Number of spaces of indent required inside a hanging or continued line.
333 | indent-after-paren=4
334 |
335 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
336 | # tab).
337 | indent-string=' '
338 |
339 | # Maximum number of characters on a single line.
340 | max-line-length=100
341 |
342 | # Maximum number of lines in a module.
343 | max-module-lines=1000
344 |
345 | # Allow the body of a class to be on the same line as the declaration if body
346 | # contains single statement.
347 | single-line-class-stmt=no
348 |
349 | # Allow the body of an if to be on the same line as the test if there is no
350 | # else.
351 | single-line-if-stmt=no
352 |
353 |
354 | [SIMILARITIES]
355 |
356 | # Ignore comments when computing similarities.
357 | ignore-comments=yes
358 |
359 | # Ignore docstrings when computing similarities.
360 | ignore-docstrings=yes
361 |
362 | # Ignore imports when computing similarities.
363 | ignore-imports=no
364 |
365 | # Minimum lines number of a similarity.
366 | min-similarity-lines=4
367 |
368 |
369 | [BASIC]
370 |
371 | # Naming style matching correct argument names.
372 | argument-naming-style=snake_case
373 |
374 | # Regular expression matching correct argument names. Overrides argument-
375 | # naming-style.
376 | #argument-rgx=
377 |
378 | # Naming style matching correct attribute names.
379 | attr-naming-style=snake_case
380 |
381 | # Regular expression matching correct attribute names. Overrides attr-naming-
382 | # style.
383 | #attr-rgx=
384 |
385 | # Bad variable names which should always be refused, separated by a comma.
386 | bad-names=foo,
387 | bar,
388 | baz,
389 | toto,
390 | tutu,
391 | tata
392 |
393 | # Bad variable names regexes, separated by a comma. If names match any regex,
394 | # they will always be refused
395 | bad-names-rgxs=
396 |
397 | # Naming style matching correct class attribute names.
398 | class-attribute-naming-style=any
399 |
400 | # Regular expression matching correct class attribute names. Overrides class-
401 | # attribute-naming-style.
402 | #class-attribute-rgx=
403 |
404 | # Naming style matching correct class names.
405 | class-naming-style=PascalCase
406 |
407 | # Regular expression matching correct class names. Overrides class-naming-
408 | # style.
409 | #class-rgx=
410 |
411 | # Naming style matching correct constant names.
412 | const-naming-style=UPPER_CASE
413 |
414 | # Regular expression matching correct constant names. Overrides const-naming-
415 | # style.
416 | #const-rgx=
417 |
418 | # Minimum line length for functions/classes that require docstrings, shorter
419 | # ones are exempt.
420 | docstring-min-length=-1
421 |
422 | # Naming style matching correct function names.
423 | function-naming-style=snake_case
424 |
425 | # Regular expression matching correct function names. Overrides function-
426 | # naming-style.
427 | #function-rgx=
428 |
429 | # Good variable names which should always be accepted, separated by a comma.
430 | good-names=i,
431 | j,
432 | k,
433 | ex,
434 | Run,
435 | _
436 |
437 | # Good variable names regexes, separated by a comma. If names match any regex,
438 | # they will always be accepted
439 | good-names-rgxs=
440 |
441 | # Include a hint for the correct naming format with invalid-name.
442 | include-naming-hint=no
443 |
444 | # Naming style matching correct inline iteration names.
445 | inlinevar-naming-style=any
446 |
447 | # Regular expression matching correct inline iteration names. Overrides
448 | # inlinevar-naming-style.
449 | #inlinevar-rgx=
450 |
451 | # Naming style matching correct method names.
452 | method-naming-style=snake_case
453 |
454 | # Regular expression matching correct method names. Overrides method-naming-
455 | # style.
456 | #method-rgx=
457 |
458 | # Naming style matching correct module names.
459 | module-naming-style=snake_case
460 |
461 | # Regular expression matching correct module names. Overrides module-naming-
462 | # style.
463 | #module-rgx=
464 |
465 | # Colon-delimited sets of names that determine each other's naming style when
466 | # the name regexes allow several styles.
467 | name-group=
468 |
469 | # Regular expression which should only match function or class names that do
470 | # not require a docstring.
471 | no-docstring-rgx=^_
472 |
473 | # List of decorators that produce properties, such as abc.abstractproperty. Add
474 | # to this list to register other decorators that produce valid properties.
475 | # These decorators are taken in consideration only for invalid-name.
476 | property-classes=abc.abstractproperty
477 |
478 | # Naming style matching correct variable names.
479 | variable-naming-style=snake_case
480 |
481 | # Regular expression matching correct variable names. Overrides variable-
482 | # naming-style.
483 | #variable-rgx=
484 |
485 |
486 | [STRING]
487 |
488 | # This flag controls whether inconsistent-quotes generates a warning when the
489 | # character used as a quote delimiter is used inconsistently within a module.
490 | check-quote-consistency=no
491 |
492 | # This flag controls whether the implicit-str-concat should generate a warning
493 | # on implicit string concatenation in sequences defined over several lines.
494 | check-str-concat-over-line-jumps=no
495 |
496 |
497 | [IMPORTS]
498 |
499 | # List of modules that can be imported at any level, not just the top level
500 | # one.
501 | allow-any-import-level=
502 |
503 | # Allow wildcard imports from modules that define __all__.
504 | allow-wildcard-with-all=no
505 |
506 | # Analyse import fallback blocks. This can be used to support both Python 2 and
507 | # 3 compatible code, which means that the block might have code that exists
508 | # only in one or another interpreter, leading to false positives when analysed.
509 | analyse-fallback-blocks=no
510 |
511 | # Deprecated modules which should not be used, separated by a comma.
512 | deprecated-modules=optparse,tkinter.tix
513 |
514 | # Create a graph of external dependencies in the given file (report RP0402 must
515 | # not be disabled).
516 | ext-import-graph=
517 |
518 | # Create a graph of every (i.e. internal and external) dependencies in the
519 | # given file (report RP0402 must not be disabled).
520 | import-graph=
521 |
522 | # Create a graph of internal dependencies in the given file (report RP0402 must
523 | # not be disabled).
524 | int-import-graph=
525 |
526 | # Force import order to recognize a module as part of the standard
527 | # compatibility libraries.
528 | known-standard-library=
529 |
530 | # Force import order to recognize a module as part of a third party library.
531 | known-third-party=enchant
532 |
533 | # Couples of modules and preferred modules, separated by a comma.
534 | preferred-modules=
535 |
536 |
537 | [CLASSES]
538 |
539 | # List of method names used to declare (i.e. assign) instance attributes.
540 | defining-attr-methods=__init__,
541 | __new__,
542 | setUp,
543 | __post_init__
544 |
545 | # List of member names, which should be excluded from the protected access
546 | # warning.
547 | exclude-protected=_asdict,
548 | _fields,
549 | _replace,
550 | _source,
551 | _make
552 |
553 | # List of valid names for the first argument in a class method.
554 | valid-classmethod-first-arg=cls
555 |
556 | # List of valid names for the first argument in a metaclass class method.
557 | valid-metaclass-classmethod-first-arg=cls
558 |
559 |
560 | [DESIGN]
561 |
562 | # Maximum number of arguments for function / method.
563 | max-args=5
564 |
565 | # Maximum number of attributes for a class (see R0902).
566 | max-attributes=7
567 |
568 | # Maximum number of boolean expressions in an if statement (see R0916).
569 | max-bool-expr=5
570 |
571 | # Maximum number of branch for function / method body.
572 | max-branches=12
573 |
574 | # Maximum number of locals for function / method body.
575 | max-locals=15
576 |
577 | # Maximum number of parents for a class (see R0901).
578 | max-parents=7
579 |
580 | # Maximum number of public methods for a class (see R0904).
581 | max-public-methods=20
582 |
583 | # Maximum number of return / yield for function / method body.
584 | max-returns=6
585 |
586 | # Maximum number of statements in function / method body.
587 | max-statements=50
588 |
589 | # Minimum number of public methods for a class (see R0903).
590 | min-public-methods=2
591 |
592 |
593 | [EXCEPTIONS]
594 |
595 | # Exceptions that will emit a warning when being caught. Defaults to
596 | # "BaseException, Exception".
597 | overgeneral-exceptions=BaseException,
598 | Exception
599 |
--------------------------------------------------------------------------------
/causal_curve/tmle_core.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the Targeted Maximum Likelihood Estimation (TMLE) model class
3 | """
4 | import numpy as np
5 | import pandas as pd
6 | from pandas.api.types import is_float_dtype, is_numeric_dtype
7 | from pygam import LinearGAM, s
8 | from scipy.interpolate import interp1d
9 | from sklearn.neighbors import KernelDensity
10 | from sklearn.ensemble import GradientBoostingRegressor
11 | from statsmodels.nonparametric.kernel_regression import KernelReg
12 |
13 | from causal_curve.core import Core
14 |
15 |
16 | class TMLE_Core(Core):
17 | """
18 | Constructs a causal dose response curve via a modified version of Targeted
19 | Maximum Likelihood Estimation (TMLE) across a grid of the treatment values.
20 | Gradient boosting is used for prediction of the Q and G models, simple
21 | kernel regression is used to process those model results, and a generalized
22 | additive model is used in the final step to construct the final curve.
23 | Assumes continuous treatment and outcome variable.
24 |
25 | WARNING:
26 |
27 | -The treatment values should be roughly normally-distributed for this tool
28 | to work. Otherwise you may encounter internal math errors.
29 |
30 | -This algorithm assumes you've already performed the necessary transformations to
31 | categorical covariates (i.e. these variables are already one-hot encoded and
32 | one of the categories is excluded for each set of dummy variables).
33 |
34 | -Please take care to ensure that the "ignorability" assumption is met (i.e.
35 | all strong confounders are captured in your covariates and there is no
36 | informative censoring), otherwise your results will be biased, sometimes strongly so.
37 |
38 | Parameters
39 | ----------
40 |
41 | treatment_grid_num: int, optional (default = 100)
42 | Takes the treatment, and creates a quantile-based grid across its values.
43 | For instance, if the number 6 is selected, this means the algorithm will only take
44 | the 6 treatment variable values at approximately the 0, 20, 40, 60, 80, and 100th
45 | percentiles to estimate the causal dose response curve.
46 | Higher value here means the final curve will be more finely estimated,
47 | but also increases computation time. Default is usually a reasonable number.
48 |
49 | lower_grid_constraint: float, optional(default = 0.01)
50 | This adds an optional constraint of the lower side of the treatment grid.
51 | Sometimes data near the minimum values of the treatment are few in number
52 | and thus generate unstable estimates. By default, this clips the bottom 1 percentile
53 | or lower of treatment values. This can be as low as 0, indicating there is no
54 | lower limit to how much treatment data is considered.
55 |
56 | upper_grid_constraint: float, optional (default = 0.99)
57 | See above parameter. Just like above, but this is an upper constraint.
58 | By default, this clips the top 99th percentile or higher of treatment values.
59 | This can be as high as 1.0, indicating there is no upper limit to how much
60 | treatment data is considered.
61 |
62 | n_estimators: int, optional (default = 200)
63 | Optional argument to set the number of learners to use when sklearn
64 | creates TMLE's Q and G models.
65 |
66 | learning_rate: float, optional (default = 0.01)
67 | Optional argument to set the sklearn's learning rate for TMLE's Q and G models.
68 |
69 | max_depth: int, optional (default = 3)
70 | Optional argument to set sklearn's maximum depth when creating TMLE's Q and G models.
71 |
72 | bandwidth: float, optional (default = 0.5)
73 | Optional argument to set the bandwidth parameter of the internal
74 | kernel density estimation and kernel regression methods.
75 |
76 | random_seed: int, optional (default = None)
77 | Sets the random seed.
78 |
79 | verbose: bool, optional (default = False)
80 | Determines whether the user will get verbose status updates.
81 |
82 |
83 | Attributes
84 | ----------
85 |
86 | grid_values: array of shape (treatment_grid_num, )
87 | The gridded values of the treatment variable. Equally spaced.
88 |
89 | final_gam: `pygam.LinearGAM` class
90 | trained final model of `LinearGAM` class, from pyGAM library
91 |
92 | pseudo_out: array of shape (observations, )
93 | Adjusted, pseudo-outcome observations
94 |
95 |
96 | Methods
97 | ----------
98 | fit: (self, T, X, y)
99 | Fits the causal dose-response model
100 |
101 | calculate_CDRC: (self, ci, CDRC_grid_num)
102 | Calculates the CDRC (and confidence interval) from TMLE estimation
103 |
104 |
105 | Examples
106 | --------
107 |
108 | >>> # With continuous outcome
109 | >>> from causal_curve import TMLE_Regressor
110 | >>> tmle = TMLE_Regressor()
111 | >>> tmle.fit(T = df['Treatment'], X = df[['X_1', 'X_2']], y = df['Outcome'])
112 | >>> tmle_results = tmle.calculate_CDRC(0.95)
113 | >>> point_estimate = tmle.point_estimate(np.array([5.0]))
114 | >>> point_estimate_interval = tmle.point_estimate_interval(np.array([5.0]), 0.95)
115 |
116 |
117 | References
118 | ----------
119 |
120 | Kennedy EH, Ma Z, McHugh MD, Small DS. Nonparametric methods for doubly robust estimation
121 | of continuous treatment effects. Journal of the Royal Statistical Society,
122 | Series B. 79(4), 2017, pp.1229-1245.
123 |
124 | van der Laan MJ and Rubin D. Targeted maximum likelihood learning. In: The International
125 | Journal of Biostatistics, 2(1), 2006.
126 |
127 | van der Laan MJ and Gruber S. Collaborative double robust penalized targeted
128 | maximum likelihood estimation. In: The International Journal of Biostatistics 6(1), 2010.
129 |
130 | """
131 |
132 | def __init__(
133 | self,
134 | treatment_grid_num=100,
135 | lower_grid_constraint=0.01,
136 | upper_grid_constraint=0.99,
137 | n_estimators=200,
138 | learning_rate=0.01,
139 | max_depth=3,
140 | bandwidth=0.5,
141 | random_seed=None,
142 | verbose=False,
143 | ):
144 |
145 | self.treatment_grid_num = treatment_grid_num
146 | self.lower_grid_constraint = lower_grid_constraint
147 | self.upper_grid_constraint = upper_grid_constraint
148 | self.n_estimators = n_estimators
149 | self.learning_rate = learning_rate
150 | self.max_depth = max_depth
151 | self.bandwidth = bandwidth
152 | self.random_seed = random_seed
153 | self.verbose = verbose
154 |
155 | def _validate_init_params(self):
156 | """
157 | Checks that the params used when instantiating TMLE model are formatted correctly
158 | """
159 |
160 | # Checks for treatment_grid_num
161 | if not isinstance(self.treatment_grid_num, int):
162 | raise TypeError(
163 | f"treatment_grid_num parameter must be an integer, "
164 | f"but found type {type(self.treatment_grid_num)}"
165 | )
166 |
167 | if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 10:
168 | raise ValueError(
169 | f"treatment_grid_num parameter should be >= 10 so your final curve "
170 | f"has enough resolution, but found value {self.treatment_grid_num}"
171 | )
172 |
173 | if (
174 | isinstance(self.treatment_grid_num, int)
175 | ) and self.treatment_grid_num >= 1000:
176 | raise ValueError("treatment_grid_num parameter is too high!")
177 |
178 | # Checks for lower_grid_constraint
179 | if not isinstance(self.lower_grid_constraint, float):
180 | raise TypeError(
181 | f"lower_grid_constraint parameter must be a float, "
182 | f"but found type {type(self.lower_grid_constraint)}"
183 | )
184 |
185 | if (
186 | isinstance(self.lower_grid_constraint, float)
187 | ) and self.lower_grid_constraint < 0:
188 | raise ValueError(
189 | f"lower_grid_constraint parameter cannot be < 0, "
190 | f"but found value {self.lower_grid_constraint}"
191 | )
192 |
193 | if (
194 | isinstance(self.lower_grid_constraint, float)
195 | ) and self.lower_grid_constraint >= 1.0:
196 | raise ValueError(
197 | f"lower_grid_constraint parameter cannot >= 1.0, "
198 | f"but found value {self.lower_grid_constraint}"
199 | )
200 |
201 | # Checks for upper_grid_constraint
202 | if not isinstance(self.upper_grid_constraint, float):
203 | raise TypeError(
204 | f"upper_grid_constraint parameter must be a float, "
205 | f"but found type {type(self.upper_grid_constraint)}"
206 | )
207 |
208 | if (
209 | isinstance(self.upper_grid_constraint, float)
210 | ) and self.upper_grid_constraint <= 0:
211 | raise ValueError(
212 | f"upper_grid_constraint parameter cannot be <= 0, "
213 | f"but found value {self.upper_grid_constraint}"
214 | )
215 |
216 | if (
217 | isinstance(self.upper_grid_constraint, float)
218 | ) and self.upper_grid_constraint > 1.0:
219 | raise ValueError(
220 | f"upper_grid_constraint parameter cannot > 1.0, "
221 | f"but found value {self.upper_grid_constraint}"
222 | )
223 |
224 | # Checks for lower_grid_constraint isn't higher than upper_grid_constraint
225 | if self.lower_grid_constraint >= self.upper_grid_constraint:
226 | raise ValueError(
227 | "lower_grid_constraint should be lower than upper_grid_constraint!"
228 | )
229 |
230 | # Checks for n_estimators
231 | if not isinstance(self.n_estimators, int):
232 | raise TypeError(
233 | f"n_estimators parameter must be an integer, "
234 | f"but found type {type(self.n_estimators)}"
235 | )
236 |
237 | if (self.n_estimators < 10) or (self.n_estimators > 100000):
238 | raise TypeError("n_estimators parameter must be between 10 and 100000")
239 |
240 | # Checks for learning_rate
241 | if not isinstance(self.learning_rate, (int, float)):
242 | raise TypeError(
243 | f"learning_rate parameter must be an integer or float, "
244 | f"but found type {type(self.learning_rate)}"
245 | )
246 |
247 | if (self.learning_rate <= 0) or (self.learning_rate >= 1000):
248 | raise TypeError(
249 | "learning_rate parameter must be greater than 0 and less than 1000"
250 | )
251 |
252 | # Checks for max_depth
253 | if not isinstance(self.max_depth, int):
254 | raise TypeError(
255 | f"max_depth parameter must be an integer, "
256 | f"but found type {type(self.max_depth)}"
257 | )
258 |
259 | if self.max_depth <= 0:
260 | raise TypeError("max_depth parameter must be greater than 0")
261 |
262 | # Checks for bandwidth
263 | if not isinstance(self.bandwidth, (int, float)):
264 | raise TypeError(
265 | f"bandwidth parameter must be an integer or float, "
266 | f"but found type {type(self.bandwidth)}"
267 | )
268 |
269 | if (self.bandwidth <= 0) or (self.bandwidth >= 1000):
270 | raise TypeError(
271 | "bandwidth parameter must be greater than 0 and less than 1000"
272 | )
273 |
274 | # Checks for random_seed
275 | if not isinstance(self.random_seed, (int, type(None))):
276 | raise TypeError(
277 | f"random_seed parameter must be an int, but found type {type(self.random_seed)}"
278 | )
279 |
280 | if (isinstance(self.random_seed, int)) and self.random_seed < 0:
281 | raise ValueError("random_seed parameter must be > 0")
282 |
283 | # Checks for verbose
284 | if not isinstance(self.verbose, bool):
285 | raise TypeError(
286 | f"verbose parameter must be a boolean type, but found type {type(self.verbose)}"
287 | )
288 |
289 | def _validate_fit_data(self):
290 | """Verifies that T, X, and y are formatted the right way"""
291 | # Checks for T column
292 | if not is_float_dtype(self.t_data):
293 | raise TypeError("Treatment data must be of type float")
294 |
295 | # Make sure all X columns are float or int
296 | if isinstance(self.x_data, pd.Series):
297 | if not is_numeric_dtype(self.x_data):
298 | raise TypeError(
299 | "All covariate (X) columns must be int or float type (i.e. must be numeric)"
300 | )
301 |
302 | elif isinstance(self.x_data, pd.DataFrame):
303 | for column in self.x_data:
304 | if not is_numeric_dtype(self.x_data[column]):
305 | raise TypeError(
306 | """All covariate (X) columns must be int or float type
307 | (i.e. must be numeric)"""
308 | )
309 |
310 | # Checks for Y column
311 | if not is_float_dtype(self.y_data):
312 | raise TypeError("Outcome data must be of type float")
313 |
314 | def _validate_calculate_CDRC_params(self, ci):
315 | """Validates the parameters given to `calculate_CDRC`"""
316 |
317 | if not isinstance(ci, float):
318 | raise TypeError(
319 | f"`ci` parameter must be an float, but found type {type(ci)}"
320 | )
321 |
322 | if isinstance(ci, float) and ((ci <= 0) or (ci >= 1.0)):
323 | raise ValueError("`ci` parameter should be between (0, 1)")
324 |
325 | def fit(self, T, X, y):
326 | """Fits the TMLE causal dose-response model. For now, this only
327 | accepts pandas columns. You *must* provide at least one covariate column.
328 |
329 | Parameters
330 | ----------
331 | T: array-like, shape (n_samples,)
332 | A continuous treatment variable
333 | X: array-like, shape (n_samples, m_features)
334 | Covariates, where n_samples is the number of samples
335 | and m_features is the number of features
336 | y: array-like, shape (n_samples,)
337 | Outcome variable
338 |
339 | Returns
340 | ----------
341 | self : object
342 |
343 | """
344 | self.rand_seed_wrapper(self.random_seed)
345 |
346 | self.t_data = T.reset_index(drop=True, inplace=False)
347 | self.x_data = X.reset_index(drop=True, inplace=False)
348 | self.y_data = y.reset_index(drop=True, inplace=False)
349 |
350 | # Validate this input data
351 | self._validate_fit_data()
352 |
353 | # Capture covariate and treatment column names
354 | self.treatment_col_name = self.t_data.name
355 |
356 | if len(self.x_data.shape) == 1:
357 | self.covariate_col_names = [self.x_data.name]
358 | else:
359 | self.covariate_col_names = self.x_data.columns.values.tolist()
360 |
361 | # Note the size of the data
362 | self.num_rows = len(self.t_data)
363 |
364 | # Produce expanded versions of the inputs
365 | self.if_verbose_print("Transforming data for the Q-model and G-model")
366 | (
367 | self.grid_values,
368 | self.fully_expanded_x,
369 | self.fully_expanded_t_and_x,
370 | ) = self._transform_inputs()
371 |
372 | # Fit G-model and get relevant predictions
373 | self.if_verbose_print(
374 | "Fitting G-model and making treatment assignment predictions..."
375 | )
376 | self.g_model_preds, self.g_model_2_preds = self._g_model()
377 |
378 | # Fit Q-model and get relevant predictions
379 | self.if_verbose_print("Fitting Q-model and making outcome predictions...")
380 | self.q_model_preds = self._q_model()
381 |
382 | # Calculating treatment assignment adjustment using G-model's predictions
383 | self.if_verbose_print(
384 | "Calculating treatment assignment adjustment using G-model's predictions..."
385 | )
386 | (
387 | self.n_interpd_values,
388 | self.var_n_interpd_values,
389 | ) = self._treatment_assignment_correction()
390 |
391 | # Adjusting outcome using Q-model's predictions
392 | self.if_verbose_print("Adjusting outcome using Q-model's predictions...")
393 | self.outcome_adjust, self.expand_outcome_adjust = self._outcome_adjustment()
394 |
395 | # Calculating corrected pseudo-outcome values
396 | self.if_verbose_print("Calculating corrected pseudo-outcome values...")
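        # Doubly robust pseudo-outcome, in the style of Kennedy et al. (2017)
        # (cited in the class docstring):
        #     (y - Q(T, X)) / (g(T|X) / g(T)) + m(T)
        # where Q(T, X) is the Q-model outcome prediction, g(T|X) the conditional
        # treatment density derived from the G-model, g(T) its marginal density,
        # and m(T) the outcome regression averaged over the treatment grid.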
397 | self.pseudo_out = (self.y_data - self.outcome_adjust) / (
398 | self.n_interpd_values / self.var_n_interpd_values
399 | ) + self.expand_outcome_adjust
400 |
401 | # Training final GAM model using pseudo-outcome values
402 | self.if_verbose_print("Training final GAM model using pseudo-outcome values...")
403 | self.final_gam = self._fit_final_gam()
404 |
405 | def calculate_CDRC(self, ci=0.95):
406 | """Using the results of the fitted model, this generates a dataframe of CDRC point estimates
407 | at each of the values of the treatment grid. Connecting these estimates will produce
408 | the overall estimated CDRC. Confidence interval is returned as well.
409 |
410 | Parameters
411 | ----------
412 | ci: float (default = 0.95)
413 | The desired confidence interval to produce. Default value is 0.95, corresponding
414 | to 95% confidence intervals. bounded (0, 1.0).
415 |
416 | Returns
417 | ----------
418 | dataframe: Pandas dataframe
419 | Contains treatment grid values, the CDRC point estimate at that value,
420 | and the associated lower and upper confidence interval bounds at that point.
421 |
422 | self: object
423 |
424 | """
425 | self.rand_seed_wrapper(self.random_seed)
426 |
427 | self._validate_calculate_CDRC_params(ci)
428 |
429 | self.if_verbose_print(
430 | """
431 | Generating predictions for each value of treatment grid,
432 | and averaging to get the CDRC..."""
433 | )
434 |
435 | # Create CDRC predictions from the trained, final GAM
436 |
437 | self._cdrc_preds = self._cdrc_predictions_continuous(ci)
438 |
439 | return pd.DataFrame(
440 | self._cdrc_preds,
441 | columns=["Treatment", "Causal_Dose_Response", "Lower_CI", "Upper_CI"],
442 | ).round(3)
443 |
444 | def _transform_inputs(self):
445 | """Takes the treatment and covariates and transforms so they can
446 | be used by the algo"""
447 |
448 | # Create treatment grid
449 | grid_values = np.linspace(
450 | start=self.t_data.min(), stop=self.t_data.max(), num=self.treatment_grid_num
451 | )
452 |
453 | # Create expanded treatment array
454 | expanded_t = np.array([])
455 | for treat_value in grid_values:
456 | expanded_t = np.append(expanded_t, ([treat_value] * self.num_rows))
457 |
458 | # Create expanded treatment array with covariates
459 | expanded_t_and_x = pd.concat(
460 | [
461 | pd.DataFrame(expanded_t),
462 | pd.concat([self.x_data] * self.treatment_grid_num).reset_index(
463 | drop=True, inplace=False
464 | ),
465 | ],
466 | axis=1,
467 | ignore_index=True,
468 | )
469 |
470 | expanded_t_and_x.columns = [self.treatment_col_name] + self.covariate_col_names
471 |
472 | fully_expanded_t_and_x = pd.concat(
473 | [pd.concat([self.x_data, self.t_data], axis=1), expanded_t_and_x],
474 | axis=0,
475 | ignore_index=True,
476 | )
477 |
478 | fully_expanded_x = fully_expanded_t_and_x[self.covariate_col_names]
479 |
480 | return grid_values, fully_expanded_x, fully_expanded_t_and_x
481 |
482 | def _g_model(self):
483 | """Produces the G-model and gets treatment assignment predictions"""
484 |
485 | t = self.t_data.to_numpy()
486 | X = self.x_data.to_numpy()
487 |
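        # First gradient-boosted model: estimates the conditional mean of the
        # treatment given the covariates, E[T|X].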
488 | g_model = GradientBoostingRegressor(
489 | n_estimators=self.n_estimators,
490 | max_depth=self.max_depth,
491 | learning_rate=self.learning_rate,
492 | random_state=self.random_seed,
493 | ).fit(X=X, y=t)
494 | g_model_preds = g_model.predict(self.fully_expanded_x)
495 |
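        # Second model: fit on the squared residuals of the first model, so its
        # predictions estimate the conditional variance of the treatment, Var(T|X).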
496 | g_model2 = GradientBoostingRegressor(
497 | n_estimators=self.n_estimators,
498 | max_depth=self.max_depth,
499 | learning_rate=self.learning_rate,
500 | random_state=self.random_seed,
501 | ).fit(X=X, y=((t - g_model_preds[0 : self.num_rows]) ** 2))
502 | g_model_2_preds = g_model2.predict(self.fully_expanded_x)
503 |
504 | return g_model_preds, g_model_2_preds
505 |
506 | def _q_model(self):
507 | """Produces the Q-model and gets outcome predictions using the provided treatment
508 | values, when the treatment is completely present and not present.
509 | """
510 |
511 | X = pd.concat([self.t_data, self.x_data], axis=1).to_numpy()
512 | y = self.y_data.to_numpy()
513 |
514 | q_model = GradientBoostingRegressor(
515 | n_estimators=self.n_estimators,
516 | max_depth=self.max_depth,
517 | learning_rate=self.learning_rate,
518 | random_state=self.random_seed,
519 | ).fit(X=X, y=y)
520 | q_model_preds = q_model.predict(self.fully_expanded_t_and_x)
521 |
522 | return q_model_preds
523 |
524 | def _treatment_assignment_correction(self):
525 | """Uses the G-model and its predictions to adjust treatment assignment"""
526 |
527 | t_standard = (
528 | self.fully_expanded_t_and_x[self.treatment_col_name] - self.g_model_preds
529 | ) / np.sqrt(self.g_model_2_preds)
530 |
531 | interpd_values = (
532 | interp1d(
533 |                 self.one_dim_estimate_density(t_standard.values)[0],
534 |                 self.one_dim_estimate_density(t_standard.values)[1],
535 | kind="linear",
536 | )(t_standard)
537 | / np.sqrt(self.g_model_2_preds)
538 | )
539 |
540 | n_interpd_values = interpd_values[0 : self.num_rows]
541 |
542 | temp_interpd = interpd_values[self.num_rows :]
543 |
544 | zeros_mat = np.zeros((self.num_rows, self.treatment_grid_num))
545 |
546 | for i in range(0, self.treatment_grid_num):
547 | lower = i * self.num_rows
548 | upper = i * self.num_rows + self.num_rows
549 | zeros_mat[:, i] = temp_interpd[lower:upper]
550 |
551 | var_n_interpd_values = self.pred_from_loess(
552 | train_x=self.grid_values,
553 | train_y=zeros_mat.mean(axis=0),
554 | x_to_pred=self.t_data,
555 | )
556 |
557 | return n_interpd_values, var_n_interpd_values
558 |
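    # What the correction above computes, in brief: `t_standard` is the
    # treatment residual standardized by the G-model's conditional sd, so
    # density(t_standard) / sd estimates the conditional density of the
    # treatment given the covariates (a generalized propensity score).
    # The interp1d call simply evaluates the KDE of `t_standard`, fit on a
    # grid, at every standardized value.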
559 | def _outcome_adjustment(self):
560 | """Uses the Q-model and its predictions to adjust the outcome"""
561 |
562 | outcome_adjust = self.q_model_preds[0 : self.num_rows]
563 |
564 | temp_outcome_adjust = self.q_model_preds[self.num_rows :]
565 |
566 | zero_mat = np.zeros((self.num_rows, self.treatment_grid_num))
567 | for i in range(0, self.treatment_grid_num):
568 | lower = i * self.num_rows
569 | upper = i * self.num_rows + self.num_rows
570 | zero_mat[:, i] = temp_outcome_adjust[lower:upper]
571 |
572 | expand_outcome_adjust = self.pred_from_loess(
573 | train_x=self.grid_values,
574 | train_y=zero_mat.mean(axis=0),
575 | x_to_pred=self.t_data,
576 | )
577 |
578 | return outcome_adjust, expand_outcome_adjust
579 |
580 | def _fit_final_gam(self):
581 | """We now regress the original treatment values against the pseudo-outcome values"""
582 |
583 | return LinearGAM(
584 | s(0, n_splines=30, spline_order=3), max_iter=500, lam=self.bandwidth
585 | ).fit(self.t_data, y=self.pseudo_out)
586 |
587 | def one_dim_estimate_density(self, series):
588 | """
589 | Takes in a numpy array, returns grid values for KDE and predicted probabilities
590 | """
591 | series_grid = np.linspace(
592 | start=series.min(), stop=series.max(), num=self.num_rows
593 | )
594 |
595 | kde = KernelDensity(kernel="gaussian", bandwidth=self.bandwidth).fit(
596 | series.reshape(-1, 1)
597 | )
598 |
599 | return series_grid, np.exp(kde.score_samples(series_grid.reshape(-1, 1)))
600 |
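    # Quick illustration (assumed toy input): KernelDensity.score_samples
    # returns log-densities, so np.exp recovers the density values, as in
    # the method above.
    #
    # >>> import numpy as np
    # >>> from sklearn.neighbors import KernelDensity
    # >>> series = np.random.normal(size=200)
    # >>> kde = KernelDensity(kernel="gaussian", bandwidth=0.5).fit(series.reshape(-1, 1))
    # >>> dens = np.exp(kde.score_samples(np.array([[0.0], [1.0]])))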
601 | def pred_from_loess(self, train_x, train_y, x_to_pred):
602 | """
603 |         Trains a simple kernel (locally weighted) regression and returns predictions
604 | """
605 | kr_model = KernelReg(
606 | endog=train_y, exog=train_x, var_type="c", bw=[self.bandwidth]
607 | )
608 |
609 | return kr_model.fit(x_to_pred)[0]
610 |
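    # Quick illustration (assumed toy input): statsmodels' KernelReg performs
    # Nadaraya-Watson-style kernel regression, and its fit() returns a
    # (conditional mean, marginal effects) tuple, hence the [0] above.
    #
    # >>> import numpy as np
    # >>> from statsmodels.nonparametric.kernel_regression import KernelReg
    # >>> x = np.linspace(0, 10, 50)
    # >>> y = np.sin(x) + np.random.normal(scale=0.1, size=50)
    # >>> model = KernelReg(endog=y, exog=x, var_type="c", bw=[0.5])
    # >>> preds = model.fit(np.array([2.0, 5.0]))[0]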
--------------------------------------------------------------------------------
/causal_curve/mediation.py:
--------------------------------------------------------------------------------
1 | """
2 | Defines the Mediation test class
3 | """
4 |
5 | from pprint import pprint
6 |
7 | import numpy as np
8 | import pandas as pd
9 | from pandas.api.types import is_float_dtype
10 | from pygam import LinearGAM, s
11 |
12 | from causal_curve.core import Core
13 |
14 |
15 | class Mediation(Core):
16 | """
17 | Given three continuous variables (a treatment or independent variable of interest,
18 | a potential mediator, and an outcome variable of interest), Mediation provides a method
19 |     to determine the average direct and indirect effects.
20 |
21 | Parameters
22 | ----------
23 |
24 | treatment_grid_num: int, optional (default = 10)
25 |         Takes the treatment and creates a quantile-based grid across its values.
26 |         For instance, if the number 6 is selected, this means the algorithm will only take
27 |         the 6 treatment variable values at approximately the 0, 20, 40, 60, 80, and 100th
28 |         percentiles to estimate the causal dose response curve.
29 |         A higher value here means the final curve will be more finely estimated,
30 |         but it also increases computation time. The default is usually a reasonable number.
31 |
32 |     lower_grid_constraint: float, optional (default = 0.01)
33 | This adds an optional constraint of the lower side of the treatment grid.
34 | Sometimes data near the minimum values of the treatment are few in number
35 | and thus generate unstable estimates. By default, this clips the bottom 1 percentile
36 | or lower of treatment values. This can be as low as 0, indicating there is no
37 | lower limit to how much treatment data is considered.
38 |
39 | upper_grid_constraint: float, optional (default = 0.99)
40 | See above parameter. Just like above, but this is an upper constraint.
41 | By default, this clips the top 99th percentile or higher of treatment values.
42 | This can be as high as 1.0, indicating there is no upper limit to how much
43 | treatment data is considered.
44 |
45 | bootstrap_draws: int, optional (default = 500)
46 | Bootstrapping is used as part of the mediation test. The parameter determines
47 | the number of draws from the original data to create a single bootstrap replicate.
48 |
49 | bootstrap_replicates: int, optional (default = 100)
50 | Bootstrapping is used as part of the mediation test. The parameter determines
51 | the number of bootstrapping runs to perform / number of new datasets to create.
52 |
53 | spline_order: int, optional (default = 3)
54 |         Order of the splines to use in fitting the final GAM.
55 |         Must be an integer >= 1. The default value creates cubic splines.
56 |
57 | n_splines: int, optional (default = 5)
58 |         Number of splines to use for the mediator and outcome GAMs.
59 |         Must be an integer >= 2.
60 |
61 | lambda_: int or float, optional (default = 0.5)
62 |         Strength of smoothing penalty. Must be a positive number.
63 | Larger values enforce stronger smoothing.
64 |
65 | max_iter: int, optional (default = 100)
66 | Maximum number of iterations allowed for the maximum likelihood algo to converge.
67 |
68 | random_seed: int, optional (default = None)
69 | Sets the random seed.
70 |
71 | verbose: bool, optional (default = False)
72 | Determines whether the user will get verbose status updates.
73 |
74 |
75 | Attributes
76 | ----------
77 |
78 | grid_values: array of shape (treatment_grid_num, )
79 |         The quantile-based gridded values of the treatment variable (equally spaced in percentile terms).
80 |
81 |
82 | Methods
83 | ----------
84 | fit: (self, T, M, y)
85 | Fits the trio of relevant variables using generalized additive models.
86 |
87 |     calculate_mediation: (self, ci)
88 | Calculates the average direct and indirect effects.
89 |
90 |
91 | Examples
92 | --------
93 | >>> from causal_curve import Mediation
94 | >>> med = Mediation(treatment_grid_num = 200, random_seed = 512)
95 | >>> med.fit(T = df['Treatment'], M = df['Mediator'], y = df['Outcome'])
96 |     >>> med_results = med.calculate_mediation(0.95)
97 |
98 |
99 | References
100 | ----------
101 |
102 | Imai K., Keele L., Tingley D. A General Approach to Causal Mediation Analysis. Psychological
103 | Methods. 15(4), 2010, pp.309–334.
104 |
105 | """
106 |
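    # An end-to-end sketch with synthetic data (illustrative only; the column
    # names and effect sizes below are made up):
    #
    # >>> import numpy as np
    # >>> import pandas as pd
    # >>> from causal_curve import Mediation
    # >>> n = 1000
    # >>> T = np.random.uniform(0, 10, size=n)
    # >>> M = 2.0 * T + np.random.normal(size=n)  # mediator driven by treatment
    # >>> y = 0.5 * T + 1.5 * M + np.random.normal(size=n)
    # >>> df = pd.DataFrame({"Treatment": T, "Mediator": M, "Outcome": y})
    # >>> med = Mediation(random_seed=512)
    # >>> med.fit(T=df["Treatment"], M=df["Mediator"], y=df["Outcome"])
    # >>> results = med.calculate_mediation(0.95)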
107 | def __init__(
108 | self,
109 | treatment_grid_num=10,
110 | lower_grid_constraint=0.01,
111 | upper_grid_constraint=0.99,
112 | bootstrap_draws=500,
113 | bootstrap_replicates=100,
114 | spline_order=3,
115 | n_splines=5,
116 | lambda_=0.5,
117 | max_iter=100,
118 | random_seed=None,
119 | verbose=False,
120 | ):
121 |
122 | self.treatment_grid_num = treatment_grid_num
123 | self.lower_grid_constraint = lower_grid_constraint
124 | self.upper_grid_constraint = upper_grid_constraint
125 | self.bootstrap_draws = bootstrap_draws
126 | self.bootstrap_replicates = bootstrap_replicates
127 | self.spline_order = spline_order
128 | self.n_splines = n_splines
129 | self.lambda_ = lambda_
130 | self.max_iter = max_iter
131 | self.random_seed = random_seed
132 | self.verbose = verbose
133 |
134 | # Validate the params
135 | self._validate_init_params()
136 |         self.rand_seed_wrapper(self.random_seed)
137 |
138 | if self.verbose:
139 | print("Using the following params for the mediation analysis:")
140 | pprint(self.get_params(), indent=4)
141 |
142 | def _validate_init_params(self):
143 | """
144 | Checks that the params used when instantiating mediation tool are formatted correctly
145 | """
146 |
147 | # Checks for treatment_grid_num
148 | if not isinstance(self.treatment_grid_num, int):
149 | raise TypeError(
150 | f"treatment_grid_num parameter must be an integer, "
151 | f"but found type {type(self.treatment_grid_num)}"
152 | )
153 |
154 | if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num < 4:
155 | raise ValueError(
156 | f"treatment_grid_num parameter should be >= 4 so the internal models "
157 | f"have enough resolution, but found value {self.treatment_grid_num}"
158 | )
159 |
160 | if (isinstance(self.treatment_grid_num, int)) and self.treatment_grid_num > 100:
161 | raise ValueError("treatment_grid_num parameter is too high!")
162 |
163 | # Checks for lower_grid_constraint
164 | if not isinstance(self.lower_grid_constraint, float):
165 | raise TypeError(
166 | f"lower_grid_constraint parameter must be a float, "
167 | f"but found type {type(self.lower_grid_constraint)}"
168 | )
169 |
170 | if (
171 | isinstance(self.lower_grid_constraint, float)
172 | ) and self.lower_grid_constraint < 0:
173 | raise ValueError(
174 | f"lower_grid_constraint parameter cannot be < 0, "
175 | f"but found value {self.lower_grid_constraint}"
176 | )
177 |
178 | if (
179 | isinstance(self.lower_grid_constraint, float)
180 | ) and self.lower_grid_constraint >= 1.0:
181 | raise ValueError(
182 | f"lower_grid_constraint parameter cannot >= 1.0, "
183 | f"but found value {self.lower_grid_constraint}"
184 | )
185 |
186 | # Checks for upper_grid_constraint
187 | if not isinstance(self.upper_grid_constraint, float):
188 | raise TypeError(
189 | f"upper_grid_constraint parameter must be a float, "
190 | f"but found type {type(self.upper_grid_constraint)}"
191 | )
192 |
193 | if (
194 | isinstance(self.upper_grid_constraint, float)
195 | ) and self.upper_grid_constraint <= 0:
196 | raise ValueError(
197 | f"upper_grid_constraint parameter cannot be <= 0, "
198 | f"but found value {self.upper_grid_constraint}"
199 | )
200 |
201 | if (
202 | isinstance(self.upper_grid_constraint, float)
203 | ) and self.upper_grid_constraint > 1.0:
204 | raise ValueError(
205 | f"upper_grid_constraint parameter cannot > 1.0, "
206 | f"but found value {self.upper_grid_constraint}"
207 | )
208 |
209 | # Checks for bootstrap_draws
210 | if not isinstance(self.bootstrap_draws, int):
211 | raise TypeError(
212 | f"bootstrap_draws parameter must be a int, "
213 | f"but found type {type(self.bootstrap_draws)}"
214 | )
215 |
216 | if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws < 100:
217 | raise ValueError(
218 | f"bootstrap_draws parameter cannot be < 100, "
219 | f"but found value {self.bootstrap_draws}"
220 | )
221 |
222 | if (isinstance(self.bootstrap_draws, int)) and self.bootstrap_draws > 500000:
223 | raise ValueError(
224 | f"bootstrap_draws parameter cannot > 500000, "
225 | f"but found value {self.bootstrap_draws}"
226 | )
227 |
228 | # Checks for bootstrap_replicates
229 | if not isinstance(self.bootstrap_replicates, int):
230 | raise TypeError(
231 | f"bootstrap_replicates parameter must be a int, "
232 | f"but found type {type(self.bootstrap_replicates)}"
233 | )
234 |
235 | if (
236 | isinstance(self.bootstrap_replicates, int)
237 | ) and self.bootstrap_replicates < 50:
238 | raise ValueError(
239 | f"bootstrap_replicates parameter cannot be < 50, "
240 | f"but found value {self.bootstrap_replicates}"
241 | )
242 |
243 | if (
244 | isinstance(self.bootstrap_replicates, int)
245 | ) and self.bootstrap_replicates > 100000:
246 | raise ValueError(
247 | f"bootstrap_replicates parameter cannot > 100000, "
248 | f"but found value {self.bootstrap_replicates}"
249 | )
250 |
251 |         # Check that lower_grid_constraint isn't higher than upper_grid_constraint
252 | if self.lower_grid_constraint >= self.upper_grid_constraint:
253 | raise ValueError(
254 | "lower_grid_constraint should be lower than upper_grid_constraint!"
255 | )
256 |
257 | # Checks for spline_order
258 | if not isinstance(self.spline_order, int):
259 | raise TypeError(
260 | f"spline_order parameter must be an integer, "
261 | f"but found type {type(self.spline_order)}"
262 | )
263 |
264 |         if (isinstance(self.spline_order, int)) and self.spline_order < 1:
265 | raise ValueError(
266 | f"spline_order parameter should be >= 1, but found {self.spline_order}"
267 | )
268 |
269 | if (isinstance(self.spline_order, int)) and self.spline_order >= 30:
270 | raise ValueError("spline_order parameter is too high!")
271 |
272 | # Checks for n_splines
273 | if not isinstance(self.n_splines, int):
274 | raise TypeError(
275 | f"n_splines parameter must be an integer, but found type {type(self.n_splines)}"
276 | )
277 |
278 | if (isinstance(self.n_splines, int)) and self.n_splines < 2:
279 | raise ValueError(
280 | f"n_splines parameter should be >= 2, but found {self.n_splines}"
281 | )
282 |
283 | if (isinstance(self.n_splines, int)) and self.n_splines >= 100:
284 | raise ValueError("n_splines parameter is too high!")
285 |
286 | # Checks for lambda_
287 | if not isinstance(self.lambda_, (int, float)):
288 | raise TypeError(
289 | f"lambda_ parameter must be an int or float, but found type {type(self.lambda_)}"
290 | )
291 |
292 | if (isinstance(self.lambda_, (int, float))) and self.lambda_ <= 0:
293 | raise ValueError(
294 | f"lambda_ parameter should be >= 2, but found {self.lambda_}"
295 | )
296 |
297 | if (isinstance(self.lambda_, (int, float))) and self.lambda_ >= 1000:
298 | raise ValueError("lambda_ parameter is too high!")
299 |
300 | # Checks for max_iter
301 | if not isinstance(self.max_iter, int):
302 | raise TypeError(
303 | f"max_iter parameter must be an int, but found type {type(self.max_iter)}"
304 | )
305 |
306 | if (isinstance(self.max_iter, int)) and self.max_iter <= 10:
307 | raise ValueError(
308 | "max_iter parameter is too low! Results won't be reliable!"
309 | )
310 |
311 | if (isinstance(self.max_iter, int)) and self.max_iter >= 1e6:
312 | raise ValueError("max_iter parameter is unnecessarily high!")
313 |
314 | # Checks for random_seed
315 | if not isinstance(self.random_seed, (int, type(None))):
316 | raise TypeError(
317 | f"random_seed parameter must be an int, but found type {type(self.random_seed)}"
318 | )
319 |
320 | if (isinstance(self.random_seed, int)) and self.random_seed < 0:
321 | raise ValueError("random_seed parameter must be > 0")
322 |
323 | # Checks for verbose
324 | if not isinstance(self.verbose, bool):
325 | raise TypeError(
326 | f"verbose parameter must be a boolean type, but found type {type(self.verbose)}"
327 | )
328 |
329 | def _validate_fit_data(self):
330 | """Verifies that T, M, and y are formatted the right way"""
331 | # Checks for T column
332 | if not is_float_dtype(self.T):
333 | raise TypeError("Treatment data must be of type float")
334 |
335 | # Checks for M column
336 | if not is_float_dtype(self.M):
337 | raise TypeError("Mediation data must be of type float")
338 |
339 | # Checks for Y column
340 | if not is_float_dtype(self.y):
341 | raise TypeError("Outcome data must be of type float")
342 |
343 | def _grid_values(self):
344 | """Produces initial grid values for the treatment variable"""
345 | return np.quantile(
346 | self.T,
347 | q=np.linspace(
348 | start=self.lower_grid_constraint,
349 | stop=self.upper_grid_constraint,
350 | num=self.treatment_grid_num,
351 | ),
352 | )
353 |
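    # Worked example of the quantile grid above (toy numbers): with
    # treatment_grid_num = 5 and the default constraints, the grid sits at
    # the 1st, 25.5th, 50th, 74.5th, and 99th percentiles of T.
    #
    # >>> import numpy as np
    # >>> T = np.arange(1, 101, dtype=float)
    # >>> np.quantile(T, q=np.linspace(0.01, 0.99, num=5))
    # array([ 1.99 , 26.245, 50.5  , 74.755, 99.01 ])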
354 | def _collect_mean_t_levels(self):
355 | """Collects the mean treatment value within each treatment bucket in the grid_values"""
356 |
357 | t_bin_means = []
358 |
359 | for index, _ in enumerate(self.grid_values):
360 | if index == (len(self.grid_values) - 1):
361 | continue
362 |
363 | t_bin_means.append(
364 | self.T[
365 | (
366 | (self.T >= self.grid_values[index])
367 | & (self.T <= self.grid_values[index + 1])
368 | )
369 | ].mean()
370 | )
371 |
372 | return t_bin_means
373 |
374 | def fit(self, T, M, y):
375 | """Fits models so that mediation analysis can be run.
376 | For now, this only accepts pandas columns.
377 |
378 | Parameters
379 | ----------
380 | T: array-like, shape (n_samples,)
381 | A continuous treatment variable
382 | M: array-like, shape (n_samples,)
383 | A continuous mediation variable
384 | y: array-like, shape (n_samples,)
385 | A continuous outcome variable
386 |
387 | Returns
388 | ----------
389 | self : object
390 |
391 | """
392 | self.rand_seed_wrapper(self.random_seed)
393 |
394 | self.T = T.reset_index(drop=True, inplace=False)
395 | self.M = M.reset_index(drop=True, inplace=False)
396 | self.y = y.reset_index(drop=True, inplace=False)
397 |
398 | # Validate this input data
399 | self._validate_fit_data()
400 |
401 | self.n = len(y)
402 |
403 | # Create grid_values
404 | self.grid_values = self._grid_values()
405 |
406 | # Loop through the comparisons in the grid_values
407 | if self.verbose:
408 | print("Beginning main loop through treatment bins...")
409 |
410 | # Collect loop results in this list
411 | self.final_bootstrap_results = []
412 |
413 | # Begin main loop
414 | for index, _ in enumerate(self.grid_values):
415 | if index == 0:
416 | continue
417 | if self.verbose:
418 | print(
419 | f"***** Starting iteration {index} of {len(self.grid_values) - 1} *****"
420 | )
421 |
422 | temp_low_treatment = self.grid_values[index - 1]
423 | temp_high_treatment = self.grid_values[index]
424 |
425 | bootstrap_results = self._bootstrap_analysis(
426 | temp_low_treatment, temp_high_treatment
427 | )
428 |
429 | self.final_bootstrap_results.append(bootstrap_results)
430 |
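    # Note on the loop above: iteration i compares the adjacent grid values
    # (grid_values[i - 1], grid_values[i]) as the "low" and "high" treatment
    # levels, so treatment_grid_num grid points yield
    # treatment_grid_num - 1 bootstrapped bin comparisons.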
431 | def calculate_mediation(self, ci=0.95):
432 | """Conducts mediation analysis on the fit data
433 |
434 | Parameters
435 | ----------
436 | ci: float (default = 0.95)
437 | The desired bootstrap confidence interval to produce. Default value is 0.95,
438 |             corresponding to 95% confidence intervals. Bounded (0, 1.0).
439 |
440 | Returns
441 | ----------
442 | dataframe: Pandas dataframe
443 | Contains the estimate of the direct and indirect effects
444 | and the proportion of indirect effects across the treatment grid values.
445 | The bootstrap confidence interval that is returned might not be symmetric.
446 |
447 | self : object
448 |
449 | """
450 | self.rand_seed_wrapper(self.random_seed)
451 |
452 | # Collect effect results in these lists
453 | self.t_bin_means = self._collect_mean_t_levels()
454 | self.prop_direct_list = []
455 | self.prop_indirect_list = []
456 | general_indirect = []
457 |
458 | lower = (1 - ci) / 2
459 | upper = ci + lower
460 |
461 | # Calculate results for each treatment bin
462 | for index, _ in enumerate(self.grid_values):
463 |
464 | if index == (len(self.grid_values) - 1):
465 | continue
466 |
467 | temp_bootstrap_results = self.final_bootstrap_results[index]
468 |
469 | mean_results = {
470 | key: temp_bootstrap_results[key].mean()
471 | for key in temp_bootstrap_results
472 | }
473 |
474 | tau_coef = (
475 | mean_results["d1"]
476 | + mean_results["d0"]
477 | + mean_results["z1"]
478 | + mean_results["z0"]
479 | ) / 2
480 | n0 = mean_results["d0"] / tau_coef
481 | n1 = mean_results["d1"] / tau_coef
482 | n_avg = (n0 + n1) / 2
483 |
484 | tau_general = (
485 | temp_bootstrap_results["d1"]
486 | + temp_bootstrap_results["d0"]
487 | + temp_bootstrap_results["z1"]
488 | + temp_bootstrap_results["z0"]
489 | ) / 2
490 | nu_0_general = temp_bootstrap_results["d0"] / tau_general
491 | nu_1_general = temp_bootstrap_results["d1"] / tau_general
492 | nu_avg_general = (nu_0_general + nu_1_general) / 2
493 |
494 | self.prop_direct_list.append(1 - n_avg)
495 | self.prop_indirect_list.append(n_avg)
496 | general_indirect.append(nu_avg_general)
497 |
498 | general_indirect = pd.concat(general_indirect)
499 |
500 | # Bootstrap these general_indirect values
501 | bootstrap_overall_means = []
502 | for _ in range(0, 1000):
503 | bootstrap_overall_means.append(
504 | general_indirect.sample(frac=0.25, replace=True).mean()
505 | )
506 |
507 | bootstrap_overall_means = np.array(bootstrap_overall_means)
508 |
509 | final_results = pd.DataFrame(
510 | {
511 | "Treatment_Value": self.t_bin_means,
512 | "Proportion_Direct_Effect": self.prop_direct_list,
513 | "Proportion_Indirect_Effect": self.prop_indirect_list,
514 | }
515 | ).round(4)
516 |
517 |         # Clip Proportion_Direct_Effect and Proportion_Indirect_Effect to [0, 1]
518 |         for effect_col in ["Proportion_Direct_Effect", "Proportion_Indirect_Effect"]:
519 |             final_results[effect_col] = final_results[effect_col].clip(
520 |                 lower=0, upper=1.0
521 |             )
522 |
523 |         # Calculate the overall mean indirect effect
524 | total_prop_mean = round(np.array(self.prop_indirect_list).mean(), 4)
525 | total_prop_lower = self.clip_negatives(
526 | round(np.percentile(bootstrap_overall_means, q=lower * 100), 4)
527 | )
528 | total_prop_upper = self.clip_negatives(
529 | round(np.percentile(bootstrap_overall_means, q=upper * 100), 4)
530 | )
531 |
532 | print(
533 | f"""\n\nMean indirect effect proportion:
534 | {total_prop_mean} ({total_prop_lower} - {total_prop_upper})
535 | """
536 | )
537 | return final_results
538 |
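    # Sketch of the returned frame (values below are illustrative, not real
    # output):
    #
    # >>> med.calculate_mediation(0.95)
    #    Treatment_Value  Proportion_Direct_Effect  Proportion_Indirect_Effect
    # 0           1.1432                    0.1201                      0.8799
    # 1           2.0877                    0.1615                      0.8385
    # ...
    # The printed summary reports the mean indirect-effect proportion with a
    # bootstrap interval; the rows give the split for each treatment bin.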
539 | def _bootstrap_analysis(self, temp_low_treatment, temp_high_treatment):
540 | """The top-level function used in the fitting method"""
541 |
542 | bootstrap_collection = []
543 |
544 | for _ in range(0, self.bootstrap_replicates):
545 | # Create single bootstrap replicate
546 | temp_t, temp_m, temp_y = self._create_bootstrap_replicate()
547 | # Create the models from this
548 | temp_mediator_model, temp_outcome_model = self._fit_gams(
549 | temp_t, temp_m, temp_y
550 | )
551 | # Make mediator predictions
552 | predict_m1, predict_m0 = self._mediator_prediction(
553 | temp_mediator_model,
554 | temp_t,
555 | temp_m,
556 | temp_low_treatment,
557 | temp_high_treatment,
558 | )
559 | # Make outcome predictions
560 | outcome_preds = self._outcome_prediction(
561 | temp_low_treatment,
562 | temp_high_treatment,
563 | predict_m1,
564 | predict_m0,
565 | temp_outcome_model,
566 | )
567 | # Collect the replicate results here
568 | bootstrap_collection.append(outcome_preds)
569 |
570 | # Convert this into a dataframe
571 | bootstrap_results = pd.DataFrame(bootstrap_collection)
572 |
573 | return bootstrap_results
574 |
575 | def _create_bootstrap_replicate(self):
576 | """Creates a single bootstrap replicate from the data"""
577 | temp_t = self.T.sample(n=self.bootstrap_draws, replace=True)
578 | temp_m = self.M.iloc[temp_t.index]
579 | temp_y = self.y.iloc[temp_t.index]
580 |
581 | return temp_t, temp_m, temp_y
582 |
583 | def _fit_gams(self, temp_t, temp_m, temp_y):
584 | """Fits the mediator and outcome GAMs"""
585 | temp_mediator_model = LinearGAM(
586 | s(0, n_splines=self.n_splines, spline_order=self.spline_order),
587 | fit_intercept=True,
588 | max_iter=self.max_iter,
589 | lam=self.lambda_,
590 | )
591 | temp_mediator_model.fit(temp_t, temp_m)
592 |
593 | temp_outcome_model = LinearGAM(
594 | s(0, n_splines=self.n_splines, spline_order=self.spline_order)
595 | + s(1, n_splines=self.n_splines, spline_order=self.spline_order),
596 | fit_intercept=True,
597 | max_iter=self.max_iter,
598 | lam=self.lambda_,
599 | )
600 | temp_outcome_model.fit(pd.concat([temp_t, temp_m], axis=1), temp_y)
601 |
602 | return temp_mediator_model, temp_outcome_model
603 |
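    # Model structure fitted above, in formula terms:
    #   mediator model:  M ~ s(T)
    #   outcome model:   Y ~ s(T) + s(M)
    # Both are penalized LinearGAMs sharing n_splines, spline_order, lam,
    # and max_iter.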
604 | def _mediator_prediction(
605 | self,
606 | temp_mediator_model,
607 | temp_t,
608 | temp_m,
609 | temp_low_treatment,
610 | temp_high_treatment,
611 | ):
612 | """Makes predictions based on the mediator models"""
613 |
614 | m1_mean = temp_mediator_model.predict(temp_high_treatment)[0]
615 | m0_mean = temp_mediator_model.predict(temp_low_treatment)[0]
616 |
617 | std_dev = (
618 | (temp_mediator_model.deviance_residuals(temp_t, temp_m) ** 2).sum()
619 | ) / (self.n - (len(temp_mediator_model.get_params()["terms"]._terms) + 1))
620 |
621 | est_error = np.random.normal(loc=0, scale=std_dev, size=self.n)
622 |
623 | predict_m1 = m1_mean + est_error
624 | predict_m0 = m0_mean + est_error
625 |
626 | return predict_m1, predict_m0
627 |
628 | def _outcome_prediction(
629 | self,
630 | temp_low_treatment,
631 | temp_high_treatment,
632 | predict_m1,
633 | predict_m0,
634 | temp_outcome_model,
635 | ):
636 | """Makes predictions based on the outcome models"""
637 |
638 | outcome_preds = {}
639 |
640 | inputs = [
641 | ["d1", temp_high_treatment, temp_high_treatment, predict_m1, predict_m0],
642 | ["d0", temp_low_treatment, temp_low_treatment, predict_m1, predict_m0],
643 | ["z1", temp_high_treatment, temp_low_treatment, predict_m1, predict_m1],
644 | ["z0", temp_high_treatment, temp_low_treatment, predict_m0, predict_m0],
645 | ]
646 |
647 | for element in inputs:
648 |
649 | # Set treatment values
650 | t_1 = element[1]
651 | t_0 = element[2]
652 |
653 | # Set mediator values
654 | m_1 = element[3]
655 | m_0 = element[4]
656 |
657 | pr_1 = temp_outcome_model.predict(
658 | np.column_stack((np.repeat(t_1, self.n), m_1))
659 | )
660 |
661 | pr_0 = temp_outcome_model.predict(
662 | np.column_stack((np.repeat(t_0, self.n), m_0))
663 | )
664 |
665 | outcome_preds[element[0]] = (pr_1 - pr_0).mean()
666 |
667 | return outcome_preds
668 |
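    # How the four quantities above map onto Imai, Keele & Tingley (2010),
    # writing f(t, m) for the outcome model and t0/t1 for the low/high
    # treatment levels of the current bin:
    #   d1 = E[f(t1, M(t1)) - f(t1, M(t0))]  indirect (mediated) effect at t1
    #   d0 = E[f(t0, M(t1)) - f(t0, M(t0))]  indirect (mediated) effect at t0
    #   z1 = E[f(t1, M(t1)) - f(t0, M(t1))]  direct effect holding M at M(t1)
    #   z0 = E[f(t1, M(t0)) - f(t0, M(t0))]  direct effect holding M at M(t0)
    # calculate_mediation later combines these into the total effect
    # tau = (d1 + d0 + z1 + z0) / 2 and the mediated proportions d / tau.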
--------------------------------------------------------------------------------